In [1]:
import pandas as pd
import numpy as np

from spn import LogicalArithmeticSPN
from spn.laspn import sort_dataset, generate_total_significant_mindices, generate_significant_mindices
from spn.laspn import save_all_mindices, read_all_mindices

In [2]:
def read_letter(f, c):
    """Read the bitmap of one letter from an open text file.

    Lines are consumed from ``f`` until a stripped line of at most two
    characters is met; that short line is treated as the label of the next
    record (or the end marker) and is returned to the caller. Every longer
    line must be exactly 5 characters wide and is converted to a tuple of
    ints: 1 for ``'*'`` and 0 for any other character.

    Parameters:
        f: text file object positioned just after the label line of the letter.
        c: label of the letter being read; used only in error messages.

    Returns:
        (rows, next_c): ``rows`` is a list of five 5-tuples of 0/1 and
        ``next_c`` is the stripped short line that ended the bitmap.

    Raises:
        TypeError: if a bitmap line is not 5 characters wide, or if the
            bitmap does not consist of exactly 5 lines.
    """
    rows = []
    next_c = None
    while next_c is None:
        text = f.readline().strip()
        width = len(text)
        if width <= 2:
            # Short line (also the empty string produced at end of file):
            # it terminates the current bitmap.
            next_c = text
        elif width > 5:
            raise TypeError("{0}: длина строки больше 5".format(c))
        elif width != 5:
            raise TypeError("{0}: длина != 5".format(c))
        else:
            rows.append(tuple(1 if ch == '*' else 0 for ch in text))
    if len(rows) != 5:
        raise TypeError("{0}: высота != 5".format(c))
    return rows, next_c

def flatten_data(data):
    """Concatenate the rows of a bitmap into one flat tuple.

    The rows are joined in order, so a list of five 5-tuples becomes a single
    25-tuple. The flattening is done in one pass (the previous repeated tuple
    concatenation copied the accumulated result for every row) and each row
    may be any iterable, not only a tuple.

    Parameters:
        data: iterable of rows, each row an iterable of cell values.

    Returns:
        A tuple with all cell values in row-major order; ``()`` for empty input.
    """
    return tuple(cell for row in data for cell in row)

def read_letters(path='data/letters.txt'):
    """Load all letter bitmaps and their labels from a text file.

    The file starts with a label line, followed by the bitmap lines of that
    letter (parsed by ``read_letter``), followed by the next label line, and
    so on; a line consisting of ``##`` ends the data.

    Parameters:
        path: location of the UTF-8 encoded data file. Defaults to the path
            that was previously hard-coded.

    Returns:
        (X, Y): ``X`` is a list of flattened bitmaps (one tuple per letter,
        produced by ``flatten_data``) and ``Y`` is the list of the
        corresponding labels, in file order.

    Raises:
        IndexError: if the first line of the file is empty.
        TypeError: propagated from ``read_letter`` for a malformed bitmap.
    """
    X = []
    Y = []
    # The context manager closes the file even when a malformed record raises.
    with open(path, 'rt', encoding='utf-8') as f:
        line = f.readline().strip()
        # The first label is reduced to its first character, as before. The
        # end marker is kept whole: truncating '##' to '#' made a file without
        # letters look like a letter labelled '#'.
        c = line if line == '##' else line[0]
        while c != '##':
            data, next_c = read_letter(f, c)
            X.append(flatten_data(data))
            Y.append(c)
            c = next_c
    return X, Y

In [3]:
# Load the 5x5 letter bitmaps (X: one flattened tuple of 0/1 per letter) and
# their labels (Y) from data/letters.txt.
X, Y = read_letters()

In [4]:
# Reorder the samples with sort_dataset (imported from spn.laspn).
# NOTE(review): judging by the listing printed further down, the resulting order
# looks like ascending number of set pixels — confirm against the library.
X, Y = sort_dataset(X, Y)
# Sanity check: samples and labels must have the same length.
print(len(X), len(Y))

26 26


In [5]:
# Distinct class labels, in the sorted order returned by np.unique.
C = list(np.unique(Y))
N = len(X)

# One-vs-rest targets: YY[c][k] is 1 when sample k carries label c, else 0.
YY = {}
for c in C:
    Yc = [1 if y == c else 0 for y in Y]
    YY[c] = Yc


In [6]:
# One line per sample: zero-padded index, the 25 pixels as a 0/1 string,
# the label, and the number of set pixels.
for k, Xk in enumerate(X):
    bits = ''.join(map(str, Xk))
    print('{0:03d}'.format(k), bits, Y[k], sum(Xk))

000 0010000100001000010000100 I 5
001 0001000010000101001001100 J 7
002 1000101010001000010000100 Y 7
003 1000010000100001000011111 L 9
004 1000101010001000101010001 X 9
005 1000110001010100101000100 V 9
006 1111100100001000010000100 T 9
007 0111110000100001000001111 C 11
008 1000110001100011000101110 U 11
009 1000110010111001001010001 K 11
010 0111010001100011000101110 O 12
011 0111010001101011001001001 Q 12
012 1000110001101011010101010 W 12
013 1111010001111101000010000 P 12
014 0111110000011100000111110 S 13
015 0111110000100111000101110 G 13
016 1000110001111111000110001 H 13
017 1000111001101011001110001 N 13
018 1000111011101011000110001 M 13
019 1111100010001000100011111 Z 13
020 1111110000111111000010000 F 13
021 0111010001111111000110001 A 14
022 1111010001100011000111110 D 14
023 1111010001111101001010001 R 14
024 1111010001111101000111110 B 16
025 1111110000111111000011111 E 17


In [8]:
# Collect every item yielded by generate_total_significant_mindices(X) in MI
# and write its repr to letters5_all_mis.txt, one item per line.
# The context manager guarantees the file is flushed and closed even if the
# generator raises part-way through.
MI = []
with open("letters5_all_mis.txt", "wt") as h:
    for mis in generate_total_significant_mindices(X):
        MI.append(mis)
        h.write(repr(mis) + "\n")

0 17 [(12,), (2,), (17,)]
1 74 [(8,), (13,), (3,)]
2 70 [(4,), (0,), (6,)]
3 311 [(20,), (24,), (10,)]
4 287 [(16,), (6, 18), (6, 16)]
5 296 [(9,), (11,), (4, 9)]
6 272 [(1,), (0, 7), (2, 4)]
7 1223 [(1, 10), (3, 24), (3, 10)]
8 1194 [(19,), (14,), (0, 14)]
9 1179 [(10, 12), (10, 11), (5, 8)]
10 2045 [(1, 14), (2, 19), (2, 9)]
11 2213 [(14, 24), (9, 24), (12, 14)]
12 1939 [(5, 17), (17, 19), (14, 17)]
13 2426 [(10, 13), (12, 13), (2, 11)]
14 5043 [(13, 19), (13, 23), (11, 19)]
15 3142 [(13, 14), (10, 13, 19), (10, 13, 22)]
16 4466 [(11, 14), (13, 24), (19, 24)]
17 3780 [(6, 15), (18, 19), (6, 19)]
18 2486 [(8, 9), (8, 14), (8, 19)]
19 4914 [(1, 16), (8, 23), (16, 21)]
20 3123 [(3, 14, 20), (1, 14, 20), (0, 2, 14)]
21 7128 [(2, 19, 24), (1, 19, 24), (2, 13, 24)]
22 7047 [(9, 20, 22), (0, 3, 19), (9, 20, 21)]
23 6150 [(1, 13, 18), (13, 18, 20), (1, 11, 18)]
24 33329 [(10, 11, 23), (9, 11, 23), (11, 15, 21)]
25 68626 [(11, 14, 23), (13, 21, 24), (11, 14, 21)]


In [None]:
# Show the 0/1 target vector for the letter 'C' (1 where the sample's label is 'C').
print(YY['C'])

In [None]:
# Fit a LogicalArithmeticSPN on the 'C'-versus-rest targets, passing the
# multi-indices collected in MI.
# NOTE(review): the meaning of total=False and is_sorted=True is defined by the
# spn package; is_sorted=True presumably relies on X having gone through
# sort_dataset — confirm.
spn = LogicalArithmeticSPN(total=False)
spn.fit(X, YY['C'], mindices=MI, is_sorted=True)

In [None]:
# Inspect the fitted model: its multi-indices and the weight stored for each
# (plot_spn_digraph reads the same two attributes, pairing them by position).
print(spn.mindices)
print(spn.weights)

In [None]:
# Compare the model's outputs on the training samples with the 'C' targets.
# NOTE(review): if evaluate_all returns a plain list this is a single bool; if it
# returns a NumPy array the comparison is element-wise — confirm which is intended.
spn.evaluate_all(X) == YY['C']

In [None]:
from IPython import display

# Wrap the model's LaTeX form in an eqnarray environment and render it.
# spn.latex()[1:-1] drops the first and last character of the returned string
# (presumably surrounding math delimiters — confirm against the spn package).
text = ''.join((
    r'\begin{eqnarray}',
    r"\mathrm{spn} &=& %s\\" % (spn.latex()[1:-1]),
    r'\end{eqnarray}',
))
display.Latex(text)

In [None]:
import networkx as nx

In [None]:
def plot_spn_digraph(spn, ax=None):
    """Draw the structure of a fitted SPN as a directed graph in concentric shells.

    For every multi-index ``spn.mindices[i]`` a node ``P<i+1>`` is created and
    connected from the input nodes ``x<t+1>`` for each ``t`` in that
    multi-index. Every ``P`` node feeds the node ``S`` through an edge labelled
    with ``spn.weights[i]``; ``S`` feeds ``H`` and ``H`` feeds ``y``.

    Parameters:
        spn: object exposing ``mindices`` (sequence of index tuples) and
            ``weights`` (sequence indexable by the same positions).
        ax: matplotlib axes to draw on. When omitted, the current axes of
            pyplot are used, which is what the function always did before.

    Returns:
        None. The graph is drawn on ``ax``.
    """
    if ax is None:
        # Imported here so the function does not rely on some other cell
        # having imported pyplot into the notebook namespace first.
        import matplotlib.pyplot as plt
        ax = plt.gca()

    g = nx.DiGraph()
    x_list = []   # input nodes, in order of first appearance
    P_list = []   # one node per multi-index
    edge_labels = {}

    for i, mi in enumerate(spn.mindices):
        P = 'P' + str(i + 1)
        P_list.append(P)
        for t in mi:
            x = 'x' + str(t + 1)
            if x not in x_list:
                x_list.append(x)
            g.add_edge(x, P)
        g.add_edge(P, 'S')
        edge_labels[(P, 'S')] = spn.weights[i]

    # The output chain is the same for every multi-index, so it is added once.
    # It is skipped for a model without multi-indices so that the graph stays
    # empty in that case, exactly as before.
    if P_list:
        g.add_edge('S', 'H')
        g.add_edge('H', 'y')

    # Colour is chosen by the first character of the node name.
    palette = {'x': 'gray', 'P': 'g', 'S': 'b', 'H': 'm', 'y': 'gray'}
    node_colors = [palette[node[0]] for node in g.nodes]

    pos = nx.drawing.layout.shell_layout(
        g, nlist=[x_list, P_list, ['S'], ['H'], ['y']], center=(0, 0))
    # The keyword is `with_labels`; the former `with_label` is not a
    # draw_networkx parameter and is rejected by current networkx releases.
    nx.draw_networkx(g, pos=pos, with_labels=True, arrows=True, ax=ax,
                     node_color=node_colors)
    nx.draw_networkx_edge_labels(g, pos=pos, ax=ax, edge_labels=edge_labels)

In [None]:
import matplotlib.pyplot as plt
# Draw the network fitted for the letter 'C' on a 15x9 inch figure and save
# the figure to letter5_C.jpg in the working directory.
plt.figure(figsize=(15,9))
plot_spn_digraph(spn)
plt.savefig("letter5_C.jpg")