In [1]:
import torch as pt
import torch.nn as T
import torch_geometric.nn as G

In [37]:
def aggregate_fn(aggregator, x, n_nodes, n_graphs, batch):
    """Mean-pool node features into per-sample representations.

    Args:
        aggregator: ``"vector"`` or ``"sequence"``.
        x: node-feature tensor whose first dimension indexes nodes
            (assumed ``(total_nodes, F)`` -- TODO confirm against callers).
        n_nodes: 1-D integer tensor; nodes per graph for each batch sample.
        n_graphs: 1-D integer tensor; number of graphs (sequence steps) per sample.
        batch: 1-D tensor mapping every node to its batch sample. Only used by
            ``"vector"``; ``"sequence"`` rebuilds a node -> graph index from
            ``n_nodes`` and ``n_graphs``.

    Returns:
        ``"vector"``: ``global_mean_pool(x, batch).unsqueeze(-1)``, i.e. ``(B, F, 1)``.
        ``"sequence"``: ``(B, max(n_graphs), F)``; row ``b`` holds the
            ``n_graphs[b]`` per-graph means of sample ``b``, zero-padded on dim 1.

    Raises:
        ValueError: if ``aggregator`` is not a supported name.
    """
    if aggregator == "vector":
        # NOTE(review): unsqueeze(-1) gives (B, F, 1) rather than a length-1
        # sequence (B, 1, F) -- confirm this is the intended layout.
        return G.pool.global_mean_pool(x, batch).unsqueeze(-1)

    if aggregator == "sequence":
        # Worked example (batch_size = 3):
        #   n_nodes   = [3, 2, 1]   each graph of sample b has n_nodes[b] nodes
        #   n_graphs  = [5, 4, 3]   sample b is a sequence of n_graphs[b] graphs
        #   batch     = [0]*15 + [1]*8 + [2]*3                          (26 nodes)
        #   graph_idx = [0,0,0,1,1,1,2,2,2,3,3,3,4,4,4,
        #                5,5,6,6,7,7,8,8, 9,10,11]                      (26 nodes)
        # Assumes nodes of one graph are contiguous in x and graphs are ordered by sample.
        nodes_per_graph = pt.repeat_interleave(n_nodes, n_graphs)
        # Single-argument repeat_interleave == repeat_interleave(arange(len(r)), r),
        # but the index is allocated on the same device as its input.
        graph_idx = pt.repeat_interleave(nodes_per_graph)
        graph_means = G.pool.global_mean_pool(x, graph_idx)  # one row per graph
        # split() takes plain ints, so this works even when n_graphs is on a GPU
        # (tensor_split requires tensor indices to live on the CPU).
        sequences = pt.split(graph_means, n_graphs.tolist())
        return T.utils.rnn.pad_sequence(sequences, batch_first=True)

    raise ValueError(f"unknown aggregator: {aggregator!r} (expected 'vector' or 'sequence')")

In [38]:
# Toy batch: 2 samples, each a sequence of 5 graphs with 3 nodes per graph.
pt.manual_seed(0)  # fixed seed so the toy tensors are reproducible across runs

n_nodes = pt.tensor([3, 3])
n_graphs = pt.tensor([5, 5])
nodes_per_sample = n_nodes * n_graphs  # tensor([15, 15])

x = pt.rand(int(nodes_per_sample.sum()), 256)
# Node -> sample index derived from the counts so it cannot drift out of sync
# with n_nodes / n_graphs: [0]*15 + [1]*15.
batch = pt.repeat_interleave(nodes_per_sample)

x_vec = aggregate_fn("vector", x, n_nodes, n_graphs, batch)
x_seq = aggregate_fn("sequence", x, n_nodes, n_graphs, batch)

assert x_vec.shape == (2, 256, 1)
assert x_seq.shape == (2, 5, 256)

In [48]:
# 3-layer bidirectional GRU over the pooled sequence; with bidirectional=True the
# per-step output width is 2 * hidden_size = 256.
gru = T.GRU(input_size=256, hidden_size=128, num_layers=3, bidirectional=True, batch_first=True, dropout=0.3)

# Store the GRU results under their own names so the node-feature tensor `x`
# from the cell above is not silently replaced by the GRU output.
gru_out, gru_h = gru(x_seq)

gru_out.shape

torch.Size([2, 5, 256])

In [31]:
# Padded sequence layout from pad_sequence(batch_first=True): (batch, steps, features).
x_seq.size()

torch.Size([2, 5, 256])

In [None]:
import mne
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

from models import Brain2Vec, Brain2Seq
from data import TensorDataset

In [None]:
transform = Brain2Vec(
        n_times=256,
        n_outputs=1,
        loss_fn="ce",
        signal_transform="raw",
        node_transform="unipolar",
        edge_select="dynamic_lt",
        threshold=-0.5,
        gru_size=128
).transform

ds_raw = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=lambda x: x)
ds_data = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=transform)

# For each sample: an MNE trace plot of the untransformed signal, then the
# transformed graph with nodes coloured/labelled by np.linalg.norm of their features.
START_IDX = 60          # first sample to visualise
N_SAMPLES = 25          # number of consecutive samples
SFREQ = 256             # sampling frequency (Hz) passed to MNE
LABEL_Y_OFFSET = 0.007  # lift norm labels slightly above their nodes

for i in range(START_IDX, START_IDX + N_SAMPLES):
    # Index each dataset once per sample; indexing may load/transform data.
    sample_raw = ds_raw[i]
    g = to_networkx(ds_data[i], node_attrs=["x"], graph_attrs=["node_positions"], to_undirected=True)
    bg = "tomato" if sample_raw["labels"][0] == 1 else "white"

    node_feats = nx.get_node_attributes(g, "x")
    pos = {key: g.graph["node_positions"][key][:2] for key in node_feats}
    # Same key order as pos, so node_color lines up with the drawn nodes.
    norms = {key: np.linalg.norm(node_feats[key]) for key in pos}

    raw = mne.io.RawArray(
        data=sample_raw["data"],
        info=mne.create_info(sfreq=SFREQ, ch_names=sample_raw["ch_names"])
    )
    raw.plot(bgcolor=bg)

    plt.figure(figsize=(12, 12), facecolor=bg)
    nx.draw_networkx(g, pos=pos, node_size=700, node_color=list(norms.values()))
    nx.draw_networkx_labels(
        g,
        pos={key: [val[0], val[1] + LABEL_Y_OFFSET] for key, val in pos.items()},
        labels={key: f"{norm:0.2f}" for key, norm in norms.items()},
    )
    # fig.show() is a no-op (with a warning) on the non-GUI inline backend;
    # plt.show() renders this iteration's figures now, in order.
    plt.show()

In [None]:
ds = TensorDataset(name="chb_mit_window_5", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
        n_times=256,
        n_outputs=1,
        loss_fn="ce",
        signal_transform="wavelet",
        node_transform="unipolar",
        edge_select="dynamic",
        threshold=0.5,
        gru_size=128,
).transform)

# 3x3 grid of transformed graphs; nodes coloured/labelled by np.linalg.norm of
# their features, panels with label 1 get a tomato background.
START_IDX = 50
N_ROWS, N_COLS = 3, 3

fig, ax = plt.subplots(N_ROWS, N_COLS, figsize=(36, 36))

# ax.flat walks the grid row-major for any N_ROWS x N_COLS shape.
for i, axi in enumerate(ax.flat):
    # Index once so the drawn graph and its background colour come from the same sample.
    sample = ds[START_IDX + i]
    g = to_networkx(sample, node_attrs=["x"], graph_attrs=["node_positions"], to_undirected=True)

    node_feats = nx.get_node_attributes(g, "x")
    pos = {key: g.graph["node_positions"][key][:2] for key in node_feats}
    norms = {key: np.linalg.norm(node_feats[key]) for key in pos}

    if sample.y.item() == 1:
        axi.set_facecolor("tomato")

    nx.draw_networkx(g, ax=axi, pos=pos, node_size=700, node_color=list(norms.values()))
    nx.draw_networkx_labels(
        g,
        ax=axi,
        pos={key: [val[0], val[1] + 0.007] for key, val in pos.items()},
        labels={key: f"{norm:0.2f}" for key, norm in norms.items()},
    )

# Lay out after drawing so tight_layout accounts for the final axes contents.
fig.tight_layout()
plt.show()

In [None]:
ds = TensorDataset(name="chb_mit_window_5", split="train", method="kfold", folds=5, k=-1, transform=Brain2Seq(
        n_times=256,
        n_outputs=1,
        loss_fn="ce",
        signal_transform="wavelet",
        node_transform="unipolar",
        edge_select="dynamic_lt",
        threshold=0.5,
        gru_size=128,
).transform)

# One row per sample: the graphs of each sequence are drawn side by side.
START_IDX = 50
N_PLOTS = 5
X_SHIFT = 0.2  # horizontal offset between consecutive graphs of a sequence

fig, ax = plt.subplots(N_PLOTS, 1, figsize=(70, 70))

for i, axi in enumerate(ax):
    # Index once so the drawn graph and its background colour come from the same sample.
    sample = ds[START_IDX + i]
    g = to_networkx(sample, node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

    node_feats = nx.get_node_attributes(g, "x")
    size = g.graph["n_nodes"]  # nodes per graph: node `key` belongs to graph key // size
    # Reuse the position of node `key % size` and shift graph t right by t * X_SHIFT.
    pos = {
        key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * X_SHIFT, 0.0])
        for key in node_feats
    }
    norms = {key: np.linalg.norm(node_feats[key]) for key in pos}

    if sample.y.item() == 1:
        axi.set_facecolor("tomato")

    nx.draw_networkx(g, ax=axi, pos=pos, node_size=700, node_color=list(norms.values()))
    nx.draw_networkx_labels(
        g,
        ax=axi,
        pos={key: [val[0], val[1] + 0.007] for key, val in pos.items()},
        labels={key: f"{norm:0.2f}" for key, norm in norms.items()},
    )

# Lay out after drawing so tight_layout accounts for the final axes contents.
fig.tight_layout()
plt.show()

In [None]:
ds = TensorDataset(name="chb_mit_window_30", split="train", method="kfold", folds=5, k=-1, transform=Brain2Seq(
        n_times=256,
        n_outputs=1,
        loss_fn="ce",
        signal_transform="wavelet",
        node_transform="unipolar",
        edge_select="dynamic",
        threshold=0.5,
).transform)

# One row per sample: the graphs of each (long) sequence are drawn side by side.
START_IDX = 50
N_PLOTS = 5
X_SHIFT = 0.2  # horizontal offset between consecutive graphs of a sequence

fig, ax = plt.subplots(N_PLOTS, 1, figsize=(420, 70))

for i, axi in enumerate(ax):
    # Index once so the drawn graph and its background colour come from the same sample.
    sample = ds[START_IDX + i]
    # Same Brain2Seq transform as the window_5 cell, so read the same graph
    # attributes (n_nodes / n_graphs) rather than graph_size / graph_length.
    g = to_networkx(sample, node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

    node_feats = nx.get_node_attributes(g, "x")
    size = g.graph["n_nodes"]  # nodes per graph: node `key` belongs to graph key // size
    # Reuse the position of node `key % size` and shift graph t right by t * X_SHIFT.
    pos = {
        key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * X_SHIFT, 0.0])
        for key in node_feats
    }
    norms = {key: np.linalg.norm(node_feats[key]) for key in pos}

    if sample.y.item() == 1:
        axi.set_facecolor("tomato")

    nx.draw_networkx(g, ax=axi, pos=pos, node_size=700, node_color=list(norms.values()))
    nx.draw_networkx_labels(
        g,
        ax=axi,
        pos={key: [val[0], val[1] + 0.007] for key, val in pos.items()},
        labels={key: f"{norm:0.2f}" for key, norm in norms.items()},
    )

# Lay out after drawing so tight_layout accounts for the final axes contents.
fig.tight_layout()
plt.show()