In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell executes, so edits to local .py files are picked up.
%load_ext autoreload
%autoreload 2
# %matplotlib agg

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from PIL import Image
from pyvis.network import Network
from sklearn.metrics import confusion_matrix
from torch_geometric.datasets.graph_generator import ERGraph
from tqdm.autonotebook import tqdm

from datasets import get_dataset
from gnn import GNN # noqa: F401

In [None]:
# --- Experiment selection: exactly one dataset / checkpoint pair is active ---
dataset_name = "MUTAG"
model_path = "models/MUTAG_model_new.pth"
# Alternative pairs (uncomment one to switch experiments):
# dataset_name = "ENZYMES"
# model_path = "models/ENZYMES_model.pth"
# dataset_name = "Shapes_Ones"
# model_path = "models/Shapes_Ones_model.pth"
# dataset_name = "MNISTSuperpixels"
# model_path = "models/MNISTSuperpixels_model.pth"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code -- only load files from a trusted source.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules; confirm against the installed version.
nn = torch.load(model_path, map_location=device, fix_imports=True)
nn.device = device
nn.to(device)
nn.eval()

dataset = get_dataset(dataset_name)
train_loader, test_loader = dataset.get_train_loader(), dataset.get_test_loader()
print(nn)

In [None]:
add_random_class = True

if add_random_class:
    # Append synthetic graphs as one extra class: one ERGraph(20, 0.2) sample
    # for every five existing graphs, with a random one-hot feature per node
    # and the next unused class index as label.
    # NOTE(review): this cell is not idempotent -- re-running it appends more
    # graphs and increments dataset.num_classes again.
    graphs = list(dataset.data)
    one_hot_rows = torch.eye(dataset.num_node_features)
    num_random_graphs = len(graphs) // 5
    for graph_number in tqdm(range(num_random_graphs)):
        random_graph = ERGraph(20, 0.2)()
        feature_ids = np.random.randint(
            0, dataset.num_node_features, (random_graph.num_nodes,)
        )
        random_graph.x = one_hot_rows[feature_ids]
        random_graph.y = torch.tensor(dataset.num_classes)
        graphs.append(random_graph)
    dataset.data = graphs
    dataset.num_classes += 1

In [None]:
def get_relu_activations(nn, data, threshold=0):
    """Return the flattened "Lin"+"Relu" layer outputs of `nn` for `data`.

    Args:
        nn: model exposing ``get_all_layer_outputs(data)``, which yields
            ``(layer_name, output_tensor)`` pairs.
        data: input forwarded unchanged to ``nn.get_all_layer_outputs``.
        threshold: currently unused; kept so existing callers keep working.

    Returns:
        A 1-D ``numpy.ndarray`` holding the outputs of every layer whose name
        contains both "Relu" and "Lin", flattened and concatenated in the
        order the model reports them.
    """
    relu_outputs = [
        output.flatten()
        for name, output in nn.get_all_layer_outputs(data)
        if "Relu" in name and "Lin" in name
    ]
    # No .squeeze(): the concatenation is already 1-D, and squeezing would turn
    # a single activation into a 0-d array that callers cannot slice.
    # .cpu() is needed because .numpy() raises for tensors living on CUDA.
    return torch.cat(relu_outputs).detach().cpu().numpy()

In [None]:
def _is_lin_relu(layer_name):
    """True when the layer name contains both "Relu" and "Lin"."""
    return "Relu" in layer_name and "Lin" in layer_name


# Probe the network once with element 0 of the first training batch to learn
# how many values each matching layer emits.
sample_input = next(iter(train_loader))[0]
out_numel = [
    output.numel()
    for name, output in nn.get_all_layer_outputs(sample_input)
    if _is_lin_relu(name)
]
lin_relu_names = [name for name in nn.layers.keys() if _is_lin_relu(name)]
# The running total of element counts is each layer's (exclusive) end offset
# inside the concatenated activation vector.
mask_end_indices = dict(zip(lin_relu_names, np.cumsum(out_numel)))
mask_end_indices

In [None]:
# Slice of the concatenated activation vector to analyse: from the first
# element up to the end offset recorded for "Lin_1_Relu".
start = 0
end = mask_end_indices["Lin_1_Relu"]
# Alternative window:
# end = mask_end_indices["Lin_0_Relu"]
end - start

In [None]:
# Per-graph table: label, size, selected ReLU activations and network output.
a = pd.DataFrame()
a["data"] = list(dataset.data)
# Labels are either plain ints or tensors; normalise them via .item().
a["y"] = a["data"].apply(lambda g: g.y if isinstance(g.y, int) else g.y.item())
a["num_nodes"] = a["data"].apply(lambda g: g.num_nodes)
a["num_edges"] = a["data"].apply(lambda g: g.num_edges)
tqdm.pandas(desc="Gathering ReLU Masks")
a["selected_activations"] = a["data"].progress_apply(lambda g: get_relu_activations(nn, g)[start:end])
# Smallest non-zero activation. When every selected unit is inactive there is
# nothing to take the minimum of, so record NaN instead of raising ValueError.
a["minimum_activation"] = a["selected_activations"].apply(
    lambda acts: acts[acts != 0].min() if np.any(acts != 0) else np.nan
)
a["maximum_activation"] = a["selected_activations"].apply(lambda acts: np.max(acts))
tqdm.pandas(desc="Getting Outputs")
# .cpu() is required before .numpy() whenever the model output lives on CUDA.
a["output"] = a["data"].progress_apply(lambda g: nn(g).detach().cpu().numpy().flatten())
a["prediction"] = a["output"].apply(lambda out: np.argmax(out))
a["correct"] = a["prediction"] == a["y"]
if add_random_class:
    # Mark every graph of the appended random class (label num_classes - 1) as
    # correct. The previous chained form `a[cond]["correct"] = 1` assigned into
    # a temporary copy and never changed `a`; .loc writes into the frame itself.
    # Guarded so that a real class is never overwritten when no random class
    # was added.
    a.loc[a["y"] == dataset.num_classes - 1, "correct"] = True

In [None]:
def get_normal(nn, mask):
    """Placeholder for the normal of the network output on one linear region.

    Not implemented yet: the arguments are ignored and None is returned.

    Sketch of the intended steps (kept from the original notes):
      * pick the layers of ``nn.layers`` whose names mark them as "Lin"/"Relu"
      * turn ``mask`` into a tensor on ``device`` and view it as ``(1, -1)``
      * get the normal of the nn output on the linear region defined by the mask
    """
    return None

In [None]:
# Smallest per-graph "minimum_activation" value across the whole table.
a["minimum_activation"].min()

In [None]:
# Inspect the raw network output stored for the graph at row position 32.
a["output"].iloc[32]

In [None]:
# Distribution of graph sizes, coloured by class label, with a box-plot margin.
nodes_per_graph_fig = px.histogram(
    a,
    x="num_nodes",
    color="y",
    marginal="box",
    title="Number of Nodes per Graph",
)
nodes_per_graph_fig

In [None]:
threshold = 0
# Binarise every activation vector: 1 where the value exceeds `threshold`, else 0.
a["mask"] = a["selected_activations"].apply(
    lambda acts: np.greater(acts, threshold).astype(int)
)

# One row per graph; then collapse identical rows into distinct masks.
masks = np.stack(list(a["mask"]))
unique_masks, unique_inverse, unique_counts = np.unique(
    masks, axis=0, return_inverse=True, return_counts=True
)
for shape_label, shape_array in (
    ("Masks", masks),
    ("Unique Masks", unique_masks),
    ("Unique Inverse", unique_inverse),
    ("Unique Counts", unique_counts),
):
    print(f"{shape_label} Shape:", shape_array.shape)

In [None]:
# Preview at most the first 100 distinct masks (one mask per row).
unique_masks[:100]

In [None]:
# One row per distinct activation mask, with statistics over the graphs
# (rows of `a`) that produced it.
mask_df = pd.DataFrame()
mask_df["mask"] = list(unique_masks)
# Convert to plain Python ints so str(tuple) stays "(1, 0, ...)" instead of
# "(np.int64(1), ...)" under NumPy 2; the tuple is later used as a node id
# and shown in tooltips.
mask_df["tuple"] = mask_df["mask"].apply(lambda m: tuple(m.tolist()))
mask_df["count"] = unique_counts.tolist()

# Flatten defensively so the membership lookup always works on a 1-D inverse.
flat_inverse = np.asarray(unique_inverse).reshape(-1)
# Row positions in `a` of the graphs belonging to each mask (always 1-D).
mask_df["indices"] = [np.flatnonzero(flat_inverse == mask_id) for mask_id in mask_df.index]


def _gather(column):
    """Per mask, pick the values of a[column] for that mask's graphs (1-D array)."""
    values = a[column].values
    return mask_df["indices"].apply(lambda idx: values[idx])


mask_df["correct_proportion"] = mask_df["indices"].apply(lambda idx: float(np.mean(a["correct"].values[idx])))
mask_df["ys"] = _gather("y")
mask_df["predictions"] = _gather("prediction")
mask_df["num_nodes"] = _gather("num_nodes")
mask_df["mean_num_nodes"] = mask_df["num_nodes"].apply(lambda x: np.mean(x))
mask_df["std_num_nodes"] = mask_df["num_nodes"].apply(lambda x: np.std(x))
mask_df["num_edges"] = _gather("num_edges")
mask_df["mean_num_edges"] = mask_df["num_edges"].apply(lambda x: np.mean(x))
mask_df["std_num_edges"] = mask_df["num_edges"].apply(lambda x: np.std(x))
mask_df["confusion_matrix"] = mask_df.apply(lambda x: confusion_matrix(x["ys"], x["predictions"], labels=np.arange(dataset.num_classes)), axis=1)
mask_df["class_counts"] = mask_df["ys"].apply(lambda x: np.bincount(x, minlength=dataset.num_classes))
# Every mask has at least one graph, so the division is well defined.
mask_df["class_proportions"] = mask_df["class_counts"].apply(lambda counts: counts / counts.sum())
mask_df["mask_index"] = mask_df.index

In [None]:
# Expand the per-mask node-count arrays to one row per graph before binning.
nodes_by_mask = mask_df.explode("num_nodes")
px.histogram(nodes_by_mask, x="num_nodes", color="mask_index")

In [None]:
# Expand the per-mask edge-count arrays to one row per graph before binning.
edges_by_mask = mask_df.explode("num_edges")
px.histogram(edges_by_mask, x="num_edges", color="mask_index")

In [None]:
# How many graphs share each distinct mask.
instances_per_mask_fig = px.histogram(
    mask_df,
    x="count",
    nbins=100,
    title="Instances Per Mask",
)
instances_per_mask_fig.show()

In [None]:
# NOTE(review): this cell repeats the preceding "Instances Per Mask" histogram
# verbatim; consider deleting it or plotting a different quantity here.
px.histogram(mask_df, x="count", nbins=100, title="Instances Per Mask").show()

In [None]:
# Draw the first graph of the dataset as a quick visual check.
first_graph = dataset.data[0]
dataset.draw_graph(data=first_graph)

In [None]:
# Build a graph whose nodes are activation masks and whose edges connect masks
# that differ in only a few bits. Each node carries a thumbnail image and a
# tooltip summarising the graphs that produced the mask.
max_diff = 3            # connect two masks when 1..max_diff bits differ
max_thumbnails = 1000   # thumbnails are only rendered for mask ids below this
max_examples = 3        # example graphs drawn per thumbnail
image_dir = Path("images")
# Image.save does not create missing directories, so make sure it exists.
image_dir.mkdir(parents=True, exist_ok=True)

subset = mask_df
num_masks = len(subset)
# subset = mask_df[mask_df["count"]>1]
# num_masks = min(num_masks, len(subset))
# subset = subset.sample(num_masks, random_state=0)


def _render_mask_thumbnail(row, image_path, max_examples=3):
    """Save a thumbnail with up to `max_examples` example graphs and a class pie."""
    example_indices = row["indices"][:max_examples]
    num_panels = len(example_indices) + 1  # +1 for the pie chart
    num_rows = int(np.ceil(np.sqrt(num_panels)))
    # Round up: floor division produced too few panels for 3 (2 graphs + pie),
    # silently dropping one example.
    num_cols = int(np.ceil(num_panels / num_rows))
    fig, axs = plt.subplots(num_rows, num_cols, figsize=(10, 10), squeeze=False)
    axs = axs.flatten()
    for panel, ax in enumerate(axs[:-1]):
        ax.axis("equal")
        ax.set_axis_off()
        # Panels beyond the available examples stay blank.
        if panel < len(example_indices):
            dataset.draw_graph(data=a["data"].values[example_indices[panel]], ax=ax)

    # The last panel always holds the class-proportion pie.
    pie_ax = axs[-1]
    pie_ax.pie(row["class_proportions"], labeldistance=.6, labels=list(range(dataset.num_classes)))
    pie_ax.axis("equal")
    pie_ax.set_axis_off()

    fig.canvas.draw()
    # buffer_rgba() holds straight (non-premultiplied) RGBA, i.e. mode "RGBA".
    img = Image.frombytes("RGBA", fig.canvas.get_width_height(), fig.canvas.buffer_rgba())
    plt.close(fig)
    img.convert("RGB").save(image_path)


def _mask_tooltip(mask_id, row):
    """Compose the multi-line hover text for one mask node."""
    indices = row["indices"]
    lines = [f"Mask {mask_id}"]
    if len(row["tuple"]) < 30:
        lines.append(f"{row['tuple']}")
    example_ids = indices[:5].tolist() if isinstance(indices, np.ndarray) else indices[:5]
    lines.append(f"Example Indices:\n {example_ids}")
    # Use the precomputed mask_df statistics: they match the node attributes
    # and avoid "±nan" (pandas' sample std) for masks with a single graph.
    lines.append(f"Average Number of Nodes: {row['mean_num_nodes']:.2f}±{row['std_num_nodes']:.2f}")
    lines.append(f"{row['count']} sample{'s' if row['count']>1 else ''}")
    lines.append(f"Correct Proportion: {row['correct_proportion']:.2f}")
    lines.extend(f"Class {cls} Proportion: {p:.2f}" for cls, p in enumerate(row["class_proportions"]))
    lines.append("Confusion Matrix:")
    lines.extend(str(cm_row) for cm_row in row["confusion_matrix"].tolist())
    return "\n".join(lines)


G = nx.Graph()
bar = tqdm(subset.iterrows(), total=num_masks)
for i, row in bar:
    image_path = image_dir / f"{i}.png"
    if i < max_thumbnails:
        _render_mask_thumbnail(row, image_path, max_examples)

    G.add_node(
        row["tuple"],
        title=_mask_tooltip(i, row),
        label=f"Mask {i}",
        image=image_path.as_posix(),
        shape="image",
        size=10*(np.log(row['count'])+3),
        **{k: str(v) for k, v in row.items()},
    )
    # Compare the new mask with every node added so far; the node itself gives
    # zero differing bits and is skipped by the range check below.
    for other_node in G.nodes:
        difference = np.nonzero(row["mask"] ^ other_node)[0]
        bits_different = len(difference)
        if 0 < bits_different <= max_diff:
            edge_title = f"# Bits Different: {bits_different}\nDifference: {str(difference.tolist())}"
            G.add_edge(row["tuple"], other_node, title=edge_title, label=str(bits_different), value=1/bits_different)
    bar.set_postfix({"Nodes": G.number_of_nodes(), "Edges": G.number_of_edges()})
# Relabel the tuple node ids with their string form, in place.
G = nx.relabel_nodes(G, {node: str(node) for node in G.nodes}, copy=False)
print(f"Number of Nodes: {G.number_of_nodes()}\nNumber of Edges: {G.number_of_edges()}")

In [None]:
# Show which attribute names are stored on nodes, using the first node.
first_node, first_attrs = next(iter(G.nodes(data=True)))
print(first_attrs.keys())

In [None]:
# Histogram of connected-component sizes: (distinct sizes, number of components).
component_sizes = [len(component) for component in nx.connected_components(G)]
component_sizes.sort()
np.unique(component_sizes, return_counts=True)

In [None]:
# Export an interactive pyvis view, but only for graphs below 2000 nodes.
small_enough_to_render = G.number_of_nodes() < 2000
if small_enough_to_render:
    repulsion_settings = dict(
        node_distance=300,
        central_gravity=0.2,
        spring_length=200,
        spring_strength=0.05,
    )
    nt = Network(height="1000px", width="100%")
    nt.from_nx(G)
    nt.show_buttons()
    nt.repulsion(**repulsion_settings)
    nt.save_graph('nx.html')