In [59]:
import dgl
import numpy as np
import networkx as nx

import torch
from torch import nn

In [60]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Creating a dgl graph from connectivity matrix

In [61]:
TEST_PATH = "/Users/h1de0us/uni/mer-eeg-analysis/data/deap_filtered/s01_plv.npy"
connectivity_matrix = np.load(TEST_PATH)
connectivity_matrix.shape

(32, 32, 5)

In [62]:
connectivity_matrix = connectivity_matrix[:, :, -1]
connectivity_matrix

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.57803585, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.33041863, 0.56471143, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.28629317, 0.28253806, 0.27283744, ..., 0.        , 0.        ,
        0.        ],
       [0.34106162, 0.56037416, 0.60070708, ..., 0.27732514, 0.        ,
        0.        ],
       [0.37746005, 0.45468193, 0.43395707, ..., 0.25628626, 0.71649264,
        0.        ]])

In [63]:
threshold = 0.3
connectivity_matrix[connectivity_matrix < threshold] = 0 # remove weak connections
connectivity_matrix += np.rot90(np.fliplr(connectivity_matrix)) # make the matrix symmetric

In [64]:
nx_graph = nx.from_numpy_array(connectivity_matrix)
nx_graph = nx_graph.to_directed()


In [65]:
nx_graph.number_of_edges()

426

In [66]:
dgl_graph = dgl.from_networkx(nx_graph, edge_attrs=['weight'])
dgl_graph

Graph(num_nodes=32, num_edges=426,
      ndata_schemes={}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float32)})

### Centrality encoding

In [67]:
n_nodes = dgl_graph.number_of_nodes()
dim_feedforward = 64

centrality_encoding = nn.Embedding(n_nodes, dim_feedforward)

In [68]:
dgl_graph.in_degrees().shape

torch.Size([32])

In [69]:
centrality = centrality_encoding(dgl_graph.in_degrees())
centrality.shape # (n_nodes, dim_feedforward)

torch.Size([32, 64])

### Spatial encoding

In [70]:
n_heads = 4
spatial_encoding = nn.Embedding(n_nodes, n_heads)

In [71]:
spd = dgl.shortest_dist(dgl_graph)
spd.shape

torch.Size([32, 32])

In [72]:
spatial = spatial_encoding(spd)
spatial.shape # (n_nodes, n_nodes, n_heads)

torch.Size([32, 32, 4])

### Edge encoding

In [73]:
spd, paths = dgl.shortest_dist(dgl_graph, return_paths=True)
paths[0, 2] # Each path is a vector that consists of edge IDs with paddings of -1 at the end. (via documentation)

tensor([ 1, -1, -1])

In [74]:
path = paths[0, 2]
path = path[path >= 0]
path

tensor([1])

In [75]:
edge_encoder = nn.Embedding(n_nodes ** 2, n_heads)
edge_features = dgl_graph.edata['weight']

In [76]:
i, j = 12, 24

_, path = dgl.shortest_dist(dgl_graph, i, return_paths=True)
# path is a sequence of nodes, len(path) == max_path 
# -1 is a padding value
path = path[j]
path = path[path >= 0] # remove padding
edge_embeds = edge_encoder(path) # (n_spd, n_heads)
spd_features = edge_features[path] # (n_spd)
result = torch.mean(edge_embeds * spd_features.unsqueeze(-1), dim=0)
result, result.shape

(tensor([ 0.4382,  0.2501, -0.6002,  0.1003], grad_fn=<MeanBackward1>),
 torch.Size([4]))

### Parsing the data

In [77]:
import torch
import os
import pickle
import numpy as np
from tqdm import tqdm

class EEGDataset(torch.utils.data.Dataset):
    def __init__(self, 
                 data_path: str = "../data/deap_filtered",
                 duration : float = 3.0,
                 method : str = "plv",
                 participants_range : tuple = (0, 32),
                 n_trials=40):
        self.n_trials = n_trials

        participants = [file[:-4] for file in os.listdir(data_path) if file.endswith('.dat')][participants_range[0]:participants_range[1]]
        self.paths = []
        self.labels = []

        for participant in tqdm(participants):
            # labels
            labels_for_participant = pickle.load(open(f"{data_path}/{participant}.dat", "rb"), encoding="latin1")["labels"]
            self.labels.append(labels_for_participant)

            # data
            duration_str = str(duration)
            prefix = f"{data_path}/{participant}_{method}_{duration_str}_trial_" # inside the collate_fn we add postfixes for all trials
            self.paths.append(prefix)


    def __getitem__(self, idx):
        # each dataset item is one participant
        # for each participant, we have a list of 40 trials 
        # each trial is a tuple of (connectivity_matrix, label)
        # connectivity_matrix is a 3D numpy array of shape (n_channels, n_channels, n_bands)
        # label is a scalar
        return {
            "path_prefix": self.paths[idx],
            "labels": self.labels[idx],
            "n_trials": self.n_trials
        }


    def __len__(self):
        return len(self.paths)

In [78]:
dataset = EEGDataset()

100%|██████████| 32/32 [00:04<00:00,  6.74it/s]


In [142]:
def collate_fn(participant): # participant is actually the whole batch, as there are 40 trials for each participant
    matrices = []
    labels = []

    if isinstance(participant, list):
        participant = participant[0]
    for trial in range(participant["n_trials"]):
        path = participant["path_prefix"] + str(trial) + ".npy"
        connectivity_matrix = torch.from_numpy(np.load(path)) # matrices for all bands
        matrices.extend(torch.split(connectivity_matrix, 1, dim=-1)) # split along the last axis
        labels.extend([participant["labels"][trial] for b in range(connectivity_matrix.shape[-1])]) # replicate the label for each band

    return {
        "matrices": torch.stack(matrices, dim=0).squeeze(-1),
        "labels": torch.as_tensor(labels) # n_trials * b_bands,  (valence, arousal, dominance, liking)
    }

In [143]:
collate_fn(dataset[0])["labels"].shape #
# 160 = n_trials * n_bands
# 32 = n_channels
# shape = n_trials * n_bands, 4

torch.Size([160, 4])

In [81]:
# create datasets
train_dataset = EEGDataset(participants_range=(1, 32))
val_dataset = EEGDataset(participants_range=(33, 40))

# create dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)


100%|██████████| 31/31 [00:03<00:00,  8.92it/s]
0it [00:00, ?it/s]


In [83]:
# train loop

import torch.optim as optim
from torch.nn import MSELoss

optimizer = optim.Adam(model.parameters(), lr=3e-4)
criterion = MSELoss()

In [88]:
def process_batch(model, 
                  batch, 
                  optimizer,
                  criterion,
                  device,
                  is_train=True):
    """
    :param model: torch.nn.Module, current model
    :param batch: dict, a batch that is being processed
    """
    for tensor_for_gpu in batch.keys():
        batch[tensor_for_gpu] = batch[tensor_for_gpu].to(device)
    predictions, attention = model(batch)
    if type(predictions) is dict:
        batch.update(predictions)
    else:
        batch["predictions"] = predictions
        batch["attention"] = attention

    optimizer.zero_grad()
    loss = criterion(batch["predictions"], batch["labels"])
    if is_train:
        loss.backward()
        optimizer.step()
    batch["loss"] = loss

    return batch



def train_epoch(epoch, model, optimizer, criterion, train_loader, device, log_step=5):
    """
    Training logic for an epoch

    :param epoch: Integer, current training epoch.
    :return: A log that contains average loss and metric in this epoch.
    """
    model.train()
    # DEBUG
    # torch.autograd.set_detect_anomaly(True)
    for batch_idx, batch in enumerate(
            tqdm(train_loader, desc="train")):
        try:
            batch = process_batch(model,
                                batch,
                                optimizer,
                                criterion,
                                device
                                )
        except RuntimeError as e:
            if "out of memory" in str(e):
                print("OOM on batch. Skipping batch.")
                for p in model.parameters():
                    if p.grad is not None:
                        del p.grad  # free some memory
                torch.cuda.empty_cache()
                continue
            else:
                raise e
        print( # plot loss after every participant
            "Train Epoch: {} {} Loss: {:.6f}".format(
            epoch, batch_idx, batch["loss"].item()
        ))

def evaluate_epoch(model, criterion, val_loader, device):
    """
    Evaluate after training an epoch

    :return: A log that contains information about the evaluation
    """
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(
                tqdm(val_loader, desc="val")):
            batch = process_batch(model,
                                batch,
                                optimizer,
                                criterion,
                                device,
                                is_train=False
                                )
            val_loss += batch["loss"].item()
    val_loss /= len(val_loader)
    print("Val set: Average loss: {:.4f}".format(val_loss))
    return {
        "val_loss": val_loss
    }

In [167]:
one_batch_dataset = EEGDataset(participants_range=(1, 2))
train_loader = torch.utils.data.DataLoader(one_batch_dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

from src.model import GraphormerModel

model = GraphormerModel(
    n_nodes = 32,
    n_layers = 1,
    n_heads = 4,
    embed_dim = 8,
    dim_feedforward=16,

)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for batch in train_loader:
    predictions, attention = model(batch)
    print(predictions.shape)

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.83it/s]


torch.Size([160, 4])


In [137]:
# training the model

n_epochs = 10

from src.model import GraphormerModel

model = GraphormerModel(
    n_nodes = 32,
    n_layers = 1,
    n_heads = 4,
    embed_dim = 64,
    dim_feedforward=32,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

for epoch in range(n_epochs):
    train_epoch(epoch, model, optimizer, criterion, train_loader, device)
    evaluate_epoch(model, criterion, val_loader, device)

cpu


train:   0%|          | 0/31 [00:08<?, ?it/s]


KeyboardInterrupt: 