## Precomputed SVM

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from sklearn import svm
X = np.array([[0, 0], [0.5, 0.5], [1, 1], [0.75, 0.75]])
y = [0, 1, 2, 2]
clf = svm.SVC(kernel='precomputed', probability=True)
gram = np.dot(X[:3], X[:3].T)
clf.fit(gram, y[:3])

for x, y in zip(X, y):
    preds = clf.predict_proba(np.dot(x, X[:3].T).reshape(1, -1))[0]
    pred_class = clf.classes_[np.argmin(preds)]
    print('{} -> Expected: {}, Got: {}, Probs: {}'.format(
        ", ".join('{:.2f}'.format(x) for x in x),
        y,
        pred_class,
        ",".join("{:.2f}".format(x) for x in preds))
    )

0.00, 0.00 -> Expected: 0, Got: 0, Probs: 0.24,0.28,0.48
0.50, 0.50 -> Expected: 1, Got: 1, Probs: 0.35,0.30,0.35
1.00, 1.00 -> Expected: 2, Got: 2, Probs: 0.48,0.28,0.24
0.75, 0.75 -> Expected: 2, Got: 2, Probs: 0.41,0.29,0.29


In [86]:
import functools
import wl
import networkx as nx
import matplotlib.pyplot as plt

def get_all_nodes(gs):
    return functools.reduce(lambda acc, x: acc | set(x), gs, set())

def get_wl_args(graphs):
    adjs = [nx.adjacency_matrix(g).toarray() for g in graphs]
    nodes = [g.nodes() for g in graphs]
    return adjs, nodes


g1 = nx.Graph()
g1.add_edge('A', 'B')
g1.add_edge('B', 'C')

g2 = nx.Graph()
g2.add_edge('A', 'B')
g2.add_edge('B', 'C')
g2.add_edge('B', 'D')

g3 = nx.Graph()
g3.add_edge('A', 'B')
g3.add_edge('B', 'C')
g3.add_edge('B', 'D')

g4 = nx.Graph()
g4.add_edge('A', 'B')
g4.add_edge('B', 'C')

g5 = nx.Graph()
g5.add_edge('A', 'B')
g5.add_edge('D', 'C')

all_graphs = (g1, g2, g3, g4, g5)

DEBUG = False
H = 1

all_nodes = get_all_nodes((g1, g2, g3, g4))

adjs, nodes = get_wl_args((g1, g2))
K, phi, label_lookups = wl.WL_compute(ad_list=adjs, node_label=nodes, all_nodes=all_nodes, h = H, DEBUG=DEBUG)
print(K[-1])
adjs, nodes = get_wl_args([g1, g2, g3])
print(K_new[-1])

adjs, nodes = get_wl_args((g1, g2, g1, g2, g3))
K, phi, label_lookups = wl.WL_compute(ad_list=adjs, node_label=nodes, all_nodes=all_nodes, h = H, DEBUG=DEBUG)
print(K[-1])

[[ 6.  3.]
 [ 3.  8.]]
[[ 6.  3.  6.  3.  3.]
 [ 3.  8.  3.  8.  8.]
 [ 6.  3.  6.  3.  3.]
 [ 3.  8.  3.  8.  8.]
 [ 3.  8.  3.  8.  8.]]
[[ 6.  3.  6.  3.  3.]
 [ 3.  8.  3.  8.  8.]
 [ 6.  3.  6.  3.  3.]
 [ 3.  8.  3.  8.  8.]
 [ 3.  8.  3.  8.  8.]]


In [83]:
# Augment matrix from shape (x, y) to (x+1, y+1), filling with zeros
a = np.linspace(0, 5, 6).reshape(3, 2)
b = np.zeros(np.array(a.shape) + 1)
b[:3,:2] = a

print(a)


[[ 0.  1.]
 [ 2.  3.]
 [ 4.  5.]]
[[ 0.  1.  0.]
 [ 2.  3.  0.]
 [ 4.  5.  0.]
 [ 0.  0.  0.]]


In [10]:
a = np.array([0, 1, 2])
b = np.bincount(a)
b, a[b]

(array([1, 1, 1]), array([1, 1, 1]))

In [None]:
FIG_SIZE = (10, 5)

for idx, graph in enumerate(all_graphs):
    plt.figure(figsize = FIG_SIZE)
    pos = nx.circular_layout(graph)
    nx.draw_networkx(graph, pos = pos)
    
    vals = pos.values()
    y_max = max(x[0] for x in vals)
    y_min = min(x[0] for x in vals)
    x_max = max(x[1] for x in vals)
    x_min = min(x[1] for x in vals)
    plt.text(x = x_min + 0.1, y = y_max + 0.1, s = 'g' + str(idx), fontsize = 20)
    print(y_max, y_min)
    plt.show()