In [None]:
import matplotlib

import numpy as np
import networkx as nx
import graph_tool.all as gt

from aann.dataset import Dataset
from aann.utils.graph import get_model_graph, nx2gt

from aann.models import SimpleModel
from aann.utils.image import cv2pil, scale_image

In [None]:
def show_image(
    dataset_item: dict,
    *,
    shape: tuple = (8, 8),
    scale: float = 128,
    size: tuple = (256, 256),
):
    """Display a dataset item's flat image features as an enlarged picture.

    Args:
        dataset_item: Mapping with an "img_features" entry that exposes
            `.numpy()` (presumably a tensor holding shape[0] * shape[1]
            values -- confirm against Dataset).
        shape: Grid the flat feature vector is reshaped to before display.
        scale: Factor the features are multiplied by before the cast to
            uint8. NOTE(review): the default of 128 assumes feature values
            small enough that `value * 128` stays within 0..255 -- confirm.
        size: Target size handed to `scale_image`.

    Returns:
        None. The picture is emitted through `display`.
    """
    pixels = dataset_item["img_features"].numpy()
    # Scale first, then cast: casting the raw features would drop fractions.
    pixels = (pixels * scale).astype("uint8").reshape(shape)
    display(cv2pil(scale_image(pixels, size)))

In [None]:
def show_graph(gt_graph: gt.Graph):
    """Fit a nested block model to the graph and draw the result.

    The vertex property `layer` is passed as `clabel` in the block-state
    arguments and also used as the vertex fill colour; the vertex property
    `neuron` is passed as `rel_order`; the edge property `weight`, run
    through `gt.prop_to_size(..., power=2)`, supplies the edge colours.

    Args:
        gt_graph: graph-tool graph exposing vertex properties `layer` and
            `neuron` and an edge property `weight`.

    Returns:
        None. The return value of `draw` is discarded.
    """
    vertex_props = gt_graph.vp

    block_state = gt.minimize_nested_blockmodel_dl(
        gt_graph,
        state_args={"clabel": vertex_props.layer},
    )

    draw_options = {
        "bg_color": "black",
        # "vertex_text": vertex_props.neuron,
        "vertex_fill_color": vertex_props.layer,
        "output_size": (768, 768),
        "beta": 0.5,
        "rel_order": vertex_props.neuron,
        "vcmap": (matplotlib.cm.Set3, 1.0),
        "edge_gradient": [],
        "edge_color": gt.prop_to_size(gt_graph.ep.weight, power=2),
        "ecmap": (matplotlib.cm.magma, 1.0),
    }
    block_state.draw(**draw_options)

In [None]:
def get_node_map_by_layer(nx_graph: nx.DiGraph, layer: int) -> dict:
    """Build a neuron -> node lookup for a single layer of the graph.

    Args:
        nx_graph: Graph whose nodes carry "layer" and "neuron" attributes.
        layer: Value of the "layer" attribute to select.

    Returns:
        Dict mapping the "neuron" attribute of every selected node to the
        node itself. If several nodes of the layer share a "neuron" value,
        the one iterated last wins.

    Raises:
        KeyError: If a node lacks the "layer" attribute, or a selected node
            lacks the "neuron" attribute.
    """
    neuron_to_node = {}
    for node_id, attrs in nx_graph.nodes(data=True):
        if attrs["layer"] == layer:
            neuron_to_node[attrs["neuron"]] = node_id
    return neuron_to_node

In [None]:
# Create the project dataset wrapper and call its `load()`; later cells read
# `dataset.num_in_features`, `dataset.num_classes` and `dataset.train` from it.
dataset = Dataset()
dataset.load()

In [None]:
# Instantiate the model, configured from the dataset's `num_in_features`
# and `num_classes`.
model = SimpleModel(
    num_in_features=dataset.num_in_features,
    num_classes=dataset.num_classes,
)

In [None]:
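# Training is disabled: uncomment the two lines below to train on
# `dataset.train` and save the result. As written, the model is only
# restored through `model.load()` in the following cell.
# model.train(train_dp_dataset=dataset.train)
# model.save()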

In [None]:
# Restore the model through `load()` (presumably state written earlier by
# `model.save()` -- confirm where SimpleModel keeps it), then show the wrapped
# `model.model` object as the cell output.
model.load()
model.model

In [None]:
# Turn the model into a networkx graph. `min_weight=0.6` is presumably an
# edge-weight cutoff applied inside get_model_graph -- TODO confirm its exact
# meaning there.
nx_graph = get_model_graph(model, min_weight=0.6)

# Print the graph size as a quick sanity check.
print(f"nodes: {nx_graph.number_of_nodes()}")
print(f"edges: {nx_graph.number_of_edges()}")

In [None]:
# Convert the networkx graph for graph-tool; show_graph() reads the `layer` and
# `neuron` vertex properties and the `weight` edge property from the result.
gt_graph = nx2gt(nx_graph)

In [None]:
# Fit the nested block model and draw the converted graph.
show_graph(gt_graph)

In [None]:
# sorted(data["weight"] for *_, data in nx_graph.edges(data=True))

In [None]:
# neuron -> node lookups for graph layer 0 (`l1_node_map`) and layer 1
# (`l2_node_map`). The last cell indexes the first with feature positions and
# iterates the second as class -> class node.
l1_node_map = get_node_map_by_layer(nx_graph, 0)
l2_node_map = get_node_map_by_layer(nx_graph, 1)
# Show the layer-1 map as the cell output.
l2_node_map

In [None]:
# Create the iterator once, in its own cell, so that re-running the next cell
# steps on to the following training item instead of restarting.
train_iter = iter(dataset.train)

In [None]:
# Pull the next training item. Re-running this cell advances `train_iter`,
# so every execution inspects a different sample.
dataset_item = next(train_iter)
pred, confs = model.dataset_item_predict(dataset_item)
show_image(dataset_item)

print(f"class:      {dataset_item['y']}")
print(f"pred_class: {pred}")
print(f"confs:      {confs.astype('float16')}")

# Positions of the features that are still non-zero after the uint8 cast.
# NOTE(review): if the features are floats, this cast truncates every value in
# (0, 1) to 0, whereas show_image() multiplies the same features by 128 before
# its cast -- confirm this thresholding is intended.
image = dataset_item["img_features"].numpy().astype("uint8")
non_zero_idx = np.nonzero(image)[0]

# Map the active feature positions onto their layer-0 graph nodes.
# NOTE(review): raises KeyError when an active position has no entry in
# `l1_node_map` -- confirm every input neuron is present in `nx_graph`.
l1_active_nodes = [l1_node_map[idx] for idx in non_zero_idx]

# For every class node, count the simple paths that reach it from the active
# input nodes. The paths are counted straight off the iterator instead of
# first being collected into a list.
l2_path_cnt = {
    num_class: sum(
        1
        for l1_node in l1_active_nodes
        for _ in nx.all_simple_paths(nx_graph, l1_node, class_node)
    )
    for num_class, class_node in l2_node_map.items()
}

l2_path_cnt