In [None]:
import mne
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

from models import Brain2Vec
from data import TensorDataset

In [None]:
# NOTE(review): `transform` is not used in this cell (ds_raw below uses an identity
# transform); kept only in case later cells rely on it — remove if nothing does.
transform = Brain2Vec(
    n_times=256,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=1,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform

# Same dataset/split as the graph cells below, but with an identity transform so the
# sample comes back unchanged (a mapping with "data" and "ch_names" entries).
ds_raw = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=lambda x: x)

# Fetch the sample once instead of indexing the dataset twice (indexing may redo loading work).
sample_raw = ds_raw[50]

raw = mne.io.RawArray(
    data=sample_raw["data"],
    # assumes the recordings are sampled at 256 Hz — TODO confirm against the dataset
    info=mne.create_info(sfreq=256, ch_names=sample_raw["ch_names"]),
)

# Trailing `;` suppresses the echoed repr of the returned figure object.
raw.plot(show_scrollbars=False, show_scalebars=False);

In [None]:
# Graph view of sample 50 with n_times=256, cross_connections=1 (straight edges).
ds = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
    n_times=256,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=1,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform)

g = to_networkx(ds[50], node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

x = nx.get_node_attributes(g, "x")
size = g.graph["n_nodes"]
# Node `key` reuses the 2-D position of entry `key % size` and is shifted right by 0.2
# per block of `size` nodes (presumably one block per stacked sub-graph — confirm).
pos = {key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * 0.2, 0.0]) for key in x}

fig, ax = plt.subplots(1, 1, figsize=(8, 8))
nx.draw_networkx(
    g,
    ax=ax,
    pos=pos,
    arrows=True,
    node_size=700,
    # colour each node by the L2 norm of its feature vector
    node_color=[np.linalg.norm(x[key]) for key in pos],
)
ax.set_title("sample 50 | n_times=256 | cross_connections=1")
# tight_layout must run after drawing so it sees the final artists/tick state.
fig.tight_layout()
# fig.show() only warns on the non-GUI inline backend; plt.show() renders cleanly.
plt.show()

In [None]:
# Graph view of sample 50 with n_times=128, cross_connections=1 (curved edges).
ds = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
    n_times=128,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=1,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform)

g = to_networkx(ds[50], node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

x = nx.get_node_attributes(g, "x")
size = g.graph["n_nodes"]
# Node `key` reuses the 2-D position of entry `key % size` and is shifted right by 0.2
# per block of `size` nodes (presumably one block per stacked sub-graph — confirm).
pos = {key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * 0.2, 0.0]) for key in x}

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
nx.draw_networkx(
    g,
    ax=ax,
    pos=pos,
    arrows=True,  # FancyArrowPatch edges, which the curved connectionstyle below needs
    node_size=700,
    # colour each node by the L2 norm of its feature vector
    node_color=[np.linalg.norm(x[key]) for key in pos],
    connectionstyle="arc3,rad=0.1",
)
ax.set_title("sample 50 | n_times=128 | cross_connections=1")
# tight_layout must run after drawing so it sees the final artists/tick state.
fig.tight_layout()
# fig.show() only warns on the non-GUI inline backend; plt.show() renders cleanly.
plt.show()

In [None]:
# Graph view of sample 50 with n_times=64, cross_connections=1 (curved edges).
ds = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
    n_times=64,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=1,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform)

g = to_networkx(ds[50], node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

x = nx.get_node_attributes(g, "x")
size = g.graph["n_nodes"]
# Node `key` reuses the 2-D position of entry `key % size` and is shifted right by 0.2
# per block of `size` nodes (presumably one block per stacked sub-graph — confirm).
pos = {key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * 0.2, 0.0]) for key in x}

fig, ax = plt.subplots(1, 1, figsize=(30, 8))
nx.draw_networkx(
    g,
    ax=ax,
    pos=pos,
    arrows=True,  # FancyArrowPatch edges, which the curved connectionstyle below needs
    node_size=700,
    # colour each node by the L2 norm of its feature vector
    node_color=[np.linalg.norm(x[key]) for key in pos],
    connectionstyle="arc3,rad=0.1",
)
ax.set_title("sample 50 | n_times=64 | cross_connections=1")
# tight_layout must run after drawing so it sees the final artists/tick state.
fig.tight_layout()
# fig.show() only warns on the non-GUI inline backend; plt.show() renders cleanly.
plt.show()

In [None]:
# Graph view of sample 50 with n_times=64, cross_connections=2 (curved edges).
ds = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
    n_times=64,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=2,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform)

g = to_networkx(ds[50], node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

x = nx.get_node_attributes(g, "x")
size = g.graph["n_nodes"]
# Node `key` reuses the 2-D position of entry `key % size` and is shifted right by 0.2
# per block of `size` nodes (presumably one block per stacked sub-graph — confirm).
pos = {key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * 0.2, 0.0]) for key in x}

fig, ax = plt.subplots(1, 1, figsize=(30, 8))
nx.draw_networkx(
    g,
    ax=ax,
    pos=pos,
    arrows=True,  # FancyArrowPatch edges, which the curved connectionstyle below needs
    node_size=700,
    # colour each node by the L2 norm of its feature vector
    node_color=[np.linalg.norm(x[key]) for key in pos],
    connectionstyle="arc3,rad=0.1",
)
ax.set_title("sample 50 | n_times=64 | cross_connections=2")
# tight_layout must run after drawing so it sees the final artists/tick state.
fig.tight_layout()
# fig.show() only warns on the non-GUI inline backend; plt.show() renders cleanly.
plt.show()

In [None]:
# Graph view of sample 50 with n_times=64, cross_connections=3 (curved edges).
ds = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
    n_times=64,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=3,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform)

g = to_networkx(ds[50], node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

x = nx.get_node_attributes(g, "x")
size = g.graph["n_nodes"]
# Node `key` reuses the 2-D position of entry `key % size` and is shifted right by 0.2
# per block of `size` nodes (presumably one block per stacked sub-graph — confirm).
pos = {key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * 0.2, 0.0]) for key in x}

fig, ax = plt.subplots(1, 1, figsize=(30, 8))
nx.draw_networkx(
    g,
    ax=ax,
    pos=pos,
    arrows=True,  # FancyArrowPatch edges, which the curved connectionstyle below needs
    node_size=700,
    # colour each node by the L2 norm of its feature vector
    node_color=[np.linalg.norm(x[key]) for key in pos],
    connectionstyle="arc3,rad=0.1",
)
ax.set_title("sample 50 | n_times=64 | cross_connections=3")
# tight_layout must run after drawing so it sees the final artists/tick state.
fig.tight_layout()
# fig.show() only warns on the non-GUI inline backend; plt.show() renders cleanly.
plt.show()

In [None]:
# Graph view of sample 50 with n_times=32, cross_connections=1 (curved edges).
ds = TensorDataset(name="chb_mit_window_1", split="train", method="kfold", folds=5, k=-1, transform=Brain2Vec(
    n_times=32,
    n_outputs=2,
    loss_fn="ce",
    normalization="micro",
    cross_connections=1,
    signal_transform="raw",
    node_transform="unipolar",
    edge_select="dynamic_gt",
    threshold=0.5,
    aggregator="vector",
    gru_size=128,
).transform)

g = to_networkx(ds[50], node_attrs=["x"], graph_attrs=["n_nodes", "n_graphs", "node_positions"], to_undirected=True)

x = nx.get_node_attributes(g, "x")
size = g.graph["n_nodes"]
# Node `key` reuses the 2-D position of entry `key % size` and is shifted right by 0.2
# per block of `size` nodes (presumably one block per stacked sub-graph — confirm).
pos = {key: g.graph["node_positions"][key % size][:2] + np.array([(key // size) * 0.2, 0.0]) for key in x}

fig, ax = plt.subplots(1, 1, figsize=(60, 8))
nx.draw_networkx(
    g,
    ax=ax,
    pos=pos,
    arrows=True,  # FancyArrowPatch edges, which the curved connectionstyle below needs
    node_size=700,
    # colour each node by the L2 norm of its feature vector
    node_color=[np.linalg.norm(x[key]) for key in pos],
    connectionstyle="arc3,rad=0.1",
)
ax.set_title("sample 50 | n_times=32 | cross_connections=1")
# tight_layout must run after drawing so it sees the final artists/tick state.
fig.tight_layout()
# fig.show() only warns on the non-GUI inline backend; plt.show() renders cleanly.
plt.show()