In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import inspect
import os
import pickle

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from PIL import Image
from pyvis.network import Network
from tqdm.autonotebook import tqdm

from datasets import get_dataset
from gnn import GNN

In [None]:
# Experiment selection: which dataset to analyse and the saved model trained on it.
# dataset_name = "MUTAG"
# model_path = "models/MUTAG_model_new.pth"
dataset_name = "Shapes_Ones"
model_path = "models/Shapes_Ones_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The checkpoint holds a whole pickled model object (it is used directly as a module
# below), not a state_dict. torch>=2.6 defaults to weights_only=True, which refuses to
# unpickle arbitrary classes, so opt out explicitly wherever the keyword exists
# (it was added in torch 1.13; older versions always unpickle fully).
# SECURITY: full unpickling can execute arbitrary code - only load trusted checkpoints.
load_kwargs = {"weights_only": False} if "weights_only" in inspect.signature(torch.load).parameters else {}
nn = torch.load(model_path, fix_imports=True, map_location=device, **load_kwargs)
# NOTE(review): `nn` shadows the conventional `torch.nn` alias; later cells rely on this name.
nn.device = device
nn.to(device)
nn.eval()  # evaluation mode for all analysis below
dataset = get_dataset(dataset_name)
train_loader = dataset.get_train_loader()
test_loader = dataset.get_test_loader()
print(nn)

In [None]:
def get_masks(nn, loader, threshold=0):
    """Binary ReLU-activation masks of ``nn`` for every batch in ``loader``.

    For each batch, the outputs of all layers whose name contains both "Relu"
    and "Lin" (as reported by ``nn.get_all_layer_outputs``) are flattened and
    concatenated. An entry is 1 where the activation exceeds ``threshold`` and
    0 otherwise.

    Args:
        nn: model exposing ``get_all_layer_outputs(data)``, which yields
            ``(name, tensor)`` pairs.
        loader: iterable of batches with a ``y`` attribute. ``data.y.item()`` is
            used, so each batch must hold exactly one example.
        threshold: activation value above which a unit counts as "on".

    Returns:
        ``(masks, ys)``: ``masks`` is an integer NumPy array with one row per
        batch (always 2-D), and ``ys`` is a NumPy array of the batch labels.
    """
    all_masks = []
    ys = []
    with torch.no_grad():  # only activations are needed; skip building autograd graphs
        for data in tqdm(loader):
            all_outputs = nn.get_all_layer_outputs(data)
            output_vector = torch.concat(
                [o[1].flatten() for o in all_outputs if "Relu" in o[0] and "Lin" in o[0]]
            )
            mask = output_vector > threshold
            all_masks.append(mask.to(int))
            ys.append(data.y.item())
    # .cpu() before .numpy(): the model (and so its outputs) may be on CUDA, and
    # .numpy() only accepts CPU tensors. No .squeeze(): with a single batch it
    # would drop the row axis and break callers that slice masks[:, a:b].
    return torch.stack(all_masks, dim=0).detach().cpu().numpy(), np.array(ys)

# Activation masks and labels for every train and test example.
train_masks, train_ys = get_masks(nn, train_loader)
test_masks, test_ys = get_masks(nn, test_loader)
# Combined view: test rows are appended after the train rows.
all_masks = np.concatenate((train_masks, test_masks))
all_ys = np.concatenate((train_ys, test_ys))
(train_masks.shape, test_masks.shape, all_masks.shape)

In [None]:
# Number of units in each "Relu"+"Lin" layer output for one example, and, per layer
# name, the cumulative END offset of that layer's block inside the concatenated
# mask vector (np.cumsum of the sizes). A layer's block therefore spans
# [end offset of the previous layer, mask_indices[layer]).
# NOTE(review): names come from nn.layers.keys() but sizes from
# get_all_layer_outputs; zip pairs them by position, so this assumes both list the
# matching layers in the same order - confirm.
# NOTE(review): this passes the first element of a batch (`[0]`) while get_masks
# passes the whole batch; sizes should agree for one-graph batches - verify.
out_numel = [output.numel() for name, output in nn.get_all_layer_outputs(next(iter(train_loader))[0]) if "Relu" in name and "Lin" in name]
mask_indices = dict(zip([name for name in nn.layers.keys() if "Relu" in name and "Lin" in name], np.cumsum(out_numel)))
mask_indices

In [None]:
# Choose the column slice [start, end) of the mask vectors analysed below.
# mask_indices holds cumulative END offsets, so start=0 and
# end=mask_indices["Lin_1_Relu"] cover every selected unit up to and including
# Lin_1_Relu. Earlier alternatives are kept for reference; the "conv*" keys would
# not exist in mask_indices, which only keeps names containing "Relu" and "Lin".
# start, end = mask_indices["conv3_ReLU"], mask_indices["conv4_ReLU"]
# start, end = mask_indices["Lin_0_Relu"], mask_indices["Lin_1_Relu"]
start, end = 0, mask_indices["Lin_1_Relu"]
end-start

In [None]:
# Distinct mask rows within the [start, end) slice for train, test and both combined.
# Besides the distinct rows, each np.unique call returns the first row index of each
# distinct mask, the row -> distinct-mask mapping, and how many rows share each mask.
unique_train_masks, unique_train_indices, unique_train_inverse, unique_train_counts = np.unique(
    train_masks[:, start:end],
    axis=0,
    return_index=True,
    return_inverse=True,
    return_counts=True,
)
unique_test_masks, unique_test_indices, unique_test_inverse, unique_test_counts = np.unique(
    test_masks[:, start:end],
    axis=0,
    return_index=True,
    return_inverse=True,
    return_counts=True,
)
unique_masks, unique_indices, unique_inverse, unique_counts = np.unique(
    all_masks[:, start:end],
    axis=0,
    return_index=True,
    return_inverse=True,
    return_counts=True,
)
(len(unique_train_masks), len(unique_test_masks), len(unique_masks))

In [None]:
# Inspect the distinct train masks (one row per distinct mask) for the selected slice.
unique_train_masks

In [None]:
node_size = 800


def get_mask_graph(masks, ys, max_diff = 1, draw_dataset=None):
    """Build a graph whose nodes are the distinct rows (masks) of ``masks``.

    Each unique row becomes one node, keyed by the row as a tuple. Two nodes are
    joined when their masks differ in at least one and at most ``max_diff``
    positions (the number of entries where the XOR is non-zero).

    Args:
        masks: 2-D integer 0/1 array with one mask per row.
        ys: 1-D array of labels aligned with the rows of ``masks``.
        max_diff: largest number of differing bits that still creates an edge.
        draw_dataset: object with ``draw_graph(data=...)`` and ``data`` used to
            render a thumbnail per node. Defaults to the notebook-global ``dataset``.

    Returns:
        networkx.Graph. Node attributes are ``mask``, ``index`` (position among the
        unique masks), ``mask_indices`` (rows of ``masks`` with this mask),
        ``count``, ``hex`` (mask bits as a hex string), ``image`` (PIL thumbnail of
        the first row with this mask) and ``class_prop_<c>`` for each label ``c``
        in ``ys``. Edge attributes are ``bits_different`` and ``difference``
        (``np.nonzero`` of the XOR of the two masks).
    """
    if draw_dataset is None:
        draw_dataset = dataset
    unique_masks, unique_index, unique_inverse, unique_counts = np.unique(
        masks, axis=0, return_index=True, return_inverse=True, return_counts=True
    )
    # NumPy 2.0.0 returned `unique_inverse` as shape (n, 1) when `axis` is given
    # (reverted in 2.0.1); flatten so it is always one label per row.
    unique_inverse = np.asarray(unique_inverse).reshape(-1)
    print("Masks Shape:", masks.shape)
    print("Unique Masks Shape:", unique_masks.shape)
    print("Unique Index Shape:", unique_index.shape)
    print("Unique Inverse Shape:", unique_inverse.shape)
    print("Unique Counts Shape:", unique_counts.shape)
    classes = np.unique(ys)  # same for every node, so compute once
    G = nx.Graph()
    for i, mask in enumerate(unique_masks):
        mask_tuple = tuple(mask)
        mask_hex = hex(int("".join(mask.astype(str)), 2))
        member_indices = np.nonzero(unique_inverse == i)[0]
        # NOTE(review): unique_index[i] is a row of `masks`; this assumes
        # draw_dataset.data is ordered like those rows (e.g. train then test,
        # unshuffled) - verify against how `masks` was built.
        fig, _ = draw_dataset.draw_graph(data=draw_dataset.data[unique_index[i]])
        fig.canvas.draw()
        # buffer_rgba() holds straight (non-premultiplied) RGBA at the canvas's
        # physical pixel size; converting through an array keeps the size correct
        # on HiDPI canvases and gives PIL the matching "RGBA" mode.
        img = Image.fromarray(np.asarray(fig.canvas.buffer_rgba()).copy())
        plt.close(fig)  # close this figure specifically, not whatever is current
        class_proportions = {
            f"class_prop_{c}": np.sum(ys[member_indices] == c) / unique_counts[i]
            for c in classes
        }
        G.add_node(
            mask_tuple, mask=mask, index=i, mask_indices=member_indices,
            count=unique_counts[i], hex=mask_hex, image=img, **class_proportions,
        )
        # Compare against every node added so far (including itself, which has
        # 0 differing bits and is therefore skipped).
        for other_node in G.nodes:
            diff = mask ^ np.asarray(other_node)
            bits_different = np.sum(diff)
            if 0 < bits_different <= max_diff:
                G.add_edge(mask_tuple, other_node, bits_different=bits_different, difference=np.nonzero(diff))
    return G

G = get_mask_graph(all_masks, all_ys, max_diff=3)
pos = nx.kamada_kawai_layout(G)
fig, ax = plt.subplots()

# Draw the edges with white, borderless node markers; each node's graph image is
# overlaid as a small separate axes further below.
nx.draw(
    G, pos, edge_color='black', width=1, linewidths=1,
    node_size=node_size, node_color='white', alpha=1,
    ax=ax, edgecolors=(0,0,0,0),
)
# Label only edges that actually carry a "weight" attribute. Building the dict
# directly also works for a graph with no edges, where taking the first edge
# would raise StopIteration.
weight_labels = {(u, v): d['weight'] for u, v, d in G.edges(data=True) if 'weight' in d}
if weight_labels:
    nx.draw_networkx_edge_labels(
        G, pos,
        edge_labels=weight_labels,
        font_color='red', ax=ax,
    )

# Transform from data coordinates (scaled between xlim and ylim) to display coordinates
tr_figure = ax.transData.transform
# Transform from display to figure coordinates
tr_axes = fig.transFigure.inverted().transform

# Icon width/height in figure-fraction units, heuristically derived from the x-axis data range
icon_size = (ax.get_xlim()[1] - ax.get_xlim()[0]) * 0.05
icon_center = icon_size / 2.0

# Add the respective image to each node as an overlaid axes centred on the node
for n in G.nodes:
    xf, yf = tr_figure(pos[n])
    xa, ya = tr_axes((xf, yf))
    icon_ax = fig.add_axes([xa - icon_center, ya - icon_center, icon_size, icon_size])
    icon_ax.imshow(G.nodes[n]["image"])
    icon_ax.axis("off")

# Hide the main axes explicitly; plt.axis('off') would act on the last icon axes.
ax.set_axis_off()
plt.show()

In [None]:
# Export the mask graph as an interactive pyvis page (nx.html). Node thumbnails are
# written to images/<i>.png and referenced from the HTML.
os.makedirs("images", exist_ok=True)  # PIL's save() does not create missing folders
nt = Network(height="1000px", width="100%")
for i, (node, data) in enumerate(G.nodes(data=True)):
    # Label -> proportion from the "class_prop_<label>" node attributes. Strip the
    # prefix instead of taking the last character, so labels such as "10" or "1.0"
    # are shown in full.
    class_props = {c[len("class_prop_"):]: p for c, p in data.items() if c.startswith("class_prop_")}
    title = f"{node} \nCount: {data['count']}"
    for class_label, p in class_props.items():
        title += f"\nProportion of Class {class_label}: {p*100:.0f}%"
    title += f"\nSome Graph Indices: {data['mask_indices'][:10]}"
    label = " / ".join(f"{p*100:.0f}%" for p in class_props.values())
    count_str = f"Count: {data['count']}"
    # Roughly centre the count line under the proportions line.
    label += "\n" + " "*((len(label)-len(count_str))//2) + count_str
    image_path = f"images/{i}.png"
    data["image"].convert("RGB").save(image_path)
    nt.add_node(str(node), title=title, label=label, shape='image', image=image_path)
for u, v, data in G.edges(data=True):
    # value=1/bits_different: masks differing in fewer bits get a larger edge value.
    nt.add_edge(str(u), str(v), title= f"# Bits Different: {str(data['bits_different'])} \nDifference: {str(data['difference'][0].tolist())}", value=1/data['bits_different'])
nt.show_buttons()
nt.save_graph('nx.html')

In [None]:

# For each class: take the test examples with that label, find the distance from
# each one's mask (over the [start, end) slice) to the nearest distinct train mask,
# and compare these distances between examples predicted as that class and
# examples predicted as another class.
num_mask_bits = end - start
for class_index in range(dataset.num_classes):
    records = []
    # Walk the loader in step with test_masks / test_ys (the order get_masks used)
    # and filter afterwards, so each example stays paired with its own mask row.
    # NOTE(review): relies on test_loader yielding one example per batch in a fixed
    # (unshuffled) order, as get_masks already assumes; the assert checks this.
    paired = zip(test_loader, test_masks[:, start:end], test_ys)
    for test_data, test_mask, test_y in tqdm(paired, total=len(test_ys)):
        assert test_data.y.item() == test_y, "test_loader order differs from the order used by get_masks"
        if test_y != class_index:
            continue
        with torch.no_grad():
            prediction = torch.argmax(nn(test_data).squeeze()).item()
        test_mask = test_mask.astype(int)
        mask_dists_from_train = np.sum(np.abs(test_mask - unique_train_masks), axis=1)
        min_l1_idx = np.argmin(mask_dists_from_train)
        min_l1 = mask_dists_from_train[min_l1_idx]
        # Positions (within the slice) where this mask differs from its nearest train mask.
        deviations = np.flatnonzero(test_mask != unique_train_masks[min_l1_idx])
        records.append({"Class": test_data.y.item(), "Predicted Class": prediction, "Min Train Mask L1": min_l1, "Min Train Mask L1 Index": min_l1_idx, "Deviations": deviations})
    if not records:
        print(f"No test examples with class {class_index}; skipping")
        continue
    df = pd.DataFrame(records)

    predicted_same = df["Predicted Class"] == class_index
    mean_same = df.loc[predicted_same, "Min Train Mask L1"].mean()
    mean_other = df.loc[~predicted_same, "Min Train Mask L1"].mean()
    # Mean nearest-train-mask distance for examples predicted as this class vs. as
    # another class, then the same values as a fraction of the mask width.
    print(mean_same, mean_other)
    print(mean_same/num_mask_bits, mean_other/num_mask_bits)
    px.histogram(df, x="Min Train Mask L1", color="Predicted Class", marginal="rug").show()

In [None]:
# utm = unique_train_masks.copy()

# def create_tree(m, rows=None):
#     if rows is None: 
#         rows = np.arange(m.shape[0])
#     if rows.size == 1 or len(np.unique(m[rows], axis=0)) == 1:
#         return m[rows][0].astype(int)
#     split_index = np.argmin(np.abs(m[rows].sum(axis=0) - len(rows) // 2))
#     branch_0_rows = np.argwhere(m[rows, split_index] == 0).squeeze()
#     branch_1_rows = np.argwhere(m[rows, split_index] == 1).squeeze()
#     assert(branch_0_rows.size+branch_1_rows.size == rows.size)
#     # print(m.shape, branch_0_rows.shape, branch_1_rows.shape)
#     # print(branch_0_rows)
#     # return
#     if branch_0_rows.size == 0 and branch_1_rows.size == 0:
#         print("Oops", m[rows, split_index])

#     if branch_0_rows.size == 0:
#         return {(split_index, 1): create_tree(m, branch_1_rows)}
#     elif branch_1_rows.size == 0:
#         return {(split_index, 0): create_tree(m, branch_0_rows)}
#     else:
#         return {(split_index, 0): create_tree(m, branch_0_rows), (split_index, 1): create_tree(m, branch_1_rows)}

# tree = create_tree(utm)

# def print_tree(t, d=0, depth=0):
#     size = 0
#     if isinstance(t, dict):
#         for k in t.keys():
#             print(" "*d+str(k)+": "+str(t[k]))
#             sub_depth, sub_size = print_tree(t[k], d+2, depth+1)
#             size += sub_size
#         return sub_depth+1, size
#     else:
#         print(t)
#         return 1, 1
# print_tree(tree)