In [4]:
import pickle
import os
import numpy as np
import gc

#import mne
#mne.set_log_level(verbose='WARNING')

import torch
from sklearn.metrics import mean_absolute_error, mean_squared_error,r2_score,explained_variance_score
import torch.nn as nn
#import pandas as pd
import warnings
#warnings.simplefilter(action='ignore', category=FutureWarning)

In [9]:
class LearnGraph(nn.Module):
    """Learnable graph whose Laplacian is used to reconstruct missing channels.

    The adjacency matrix ``A`` is a free ``(n_channels, n_channels)`` parameter.
    ``forward`` builds ``L = D - A`` (``D`` = diagonal of the column sums of ``A``)
    and predicts the removed channels ``x_e`` from the observed ones ``x_b`` by
    solving ``L_ee x_e + L_eb x_b = 0`` in closed form, i.e.
    ``x_e = -L_ee^{-1} L_eb x_b``.

    Args:
        n_channels: number of electrodes in the montage. Defaults to 76, the
            channel count of the dataset used in this notebook.
    """

    def __init__(self, n_channels=76):
        super().__init__()
        # Adjacency matrix to learn, initialised as a fully connected graph.
        self.adj_matrix = torch.nn.Parameter(torch.ones(n_channels, n_channels))
        # You can also initialise from a pre-built graph (domain knowledge or a spatial graph):
        # self.adj_matrix = torch.nn.Parameter(torch.load('graph_adj_init.pt'))

    def forward(self, x, channels_to_keep, channels_to_remove):
        """Predict the removed channels from the observed ones.

        Args:
            x: observed signal of shape (batch, len(channels_to_keep), time); its
                channel axis must be ordered like ``channels_to_keep``.
            channels_to_keep: indices (ints) of the observed channels.
            channels_to_remove: indices (ints) of the channels to reconstruct.

        Returns:
            Tensor of shape (batch, len(channels_to_remove), time).
        """
        adj = self.adj_matrix
        # Column-sum degrees; nothing here symmetrises the learned adjacency.
        D = torch.diag(adj.sum(dim=0))
        L = D - adj
        remove_idx = torch.as_tensor(channels_to_remove, dtype=torch.long, device=L.device)
        keep_idx = torch.as_tensor(channels_to_keep, dtype=torch.long, device=L.device)
        # L_e[i, j] = L[remove[i], remove[j]];  L_eeb[i, j] = L[remove[i], keep[j]]
        L_e = L[remove_idx][:, remove_idx].float()
        L_eeb = L[remove_idx][:, keep_idx].float()
        # Solve L_e @ prod = L_eeb rather than forming inv(L_e) explicitly (better conditioned).
        prod = torch.linalg.solve(L_e, L_eeb)
        return -torch.einsum('ef,bft->bet', prod, x)

In [6]:
dset_name_path = 'Schirrmeister'
# Root folder of the datasets; override with the EEG_DATASETS_DIR environment variable.
path = os.environ.get('EEG_DATASETS_DIR', '/users/local/eeg_datasets/')

# Load the dataset you want to learn from: <path>/<name>/<name>.npy
# (its shape is printed below as N x Channels x Timesteps).
arr = torch.from_numpy(np.load(os.path.join(path, dset_name_path, dset_name_path + '.npy'))).float()

# You can try to see if the code works with a random array. It needs 76 channels
# to match LearnGraph's default size and the channel indices drawn during training.
# arr = torch.rand((6000, 76, 1000))

In [22]:
# Show the tensor layout: a header line, then the shape on the next line.
print('NxChannelsxTimesteps', arr.shape, sep='\n')

NxChannelsxTimesteps
torch.Size([13484, 76, 2001])


In [18]:
lr = 0.1
n_epoch = 2000
n_iter = 10  # optimisation steps per epoch
# Number of samples used to solve the reconstruction problem in parallel (bounded by RAM).
# Samples are drawn without replacement, so this must not exceed the number of samples in `arr`.
n_samples = 5000
seed = 0  # fix the RNGs so the random sample/channel draws are reproducible
np.random.seed(seed)
torch.manual_seed(seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = LearnGraph().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [23]:
print("R²: ")
for j in range(n_epoch):
    list_r2 = [] 
    for i in range(n_iter):
        idx = sorted(list(np.random.choice(np.arange(arr.shape[0]), n_samples, replace=False))) # randomly select samples 
        K = 38 
        channels_to_remove = sorted(list(np.random.choice(np.arange(76), K, replace=False))) # randomly select K electrodes to remove 
        channels_to_keep = [x for x in np.arange(76) if x not in channels_to_remove]
        seb = arr[idx][:,channels_to_keep]  # observed part of the signal
        se = arr[idx][:,channels_to_remove] # missing part of the signal
        data, target = seb.to(device), se.to(device)
        optimizer.zero_grad()
        interpo = model(data,channels_to_keep,channels_to_remove) # interpolate using the closed form
        true_channel = target
        residu = (true_channel - interpo)**2 
        r2 = 1 - (residu.sum())/(((true_channel - true_channel.mean())**2).sum())
        loss = 1 - r2
        loss.backward()
        optimizer.step()
    print("{:3d}:{:.3f}".format(j,1-loss.detach().to('cpu').numpy()),end = ' ')
        

R²: 
  0:0.569   1:0.676   2:0.716   3:0.775   4:0.775   5:0.817   6:0.802   7:0.822   8:0.809   9:0.796  10:0.830  11:0.774  12:0.786  13:0.764  14:0.828  15:0.830  16:0.837  17:0.857  18:0.858  19:0.836  20:0.864  21:0.854  22:0.867  23:0.864  24:0.870  25:0.863  26:0.855  27:0.879  28:0.883  29:0.886  30:0.846  31:0.795  32:0.808  33:0.826  34:0.836  35:0.845  36:0.879  37:0.881  38:0.850  39:0.873  40:0.852  41:0.878 

KeyboardInterrupt: 

In [27]:
dset_name_path = 'Schirrmeister'

# Electrode names for the dataset, stored next to the .npy file as <name>_channels.pkl.
# Presumably ordered like the channel axis of `arr` — TODO confirm.
# NOTE: pickle.load can execute arbitrary code; only load channel files from a trusted source.
channels_file = os.path.join(path, dset_name_path, dset_name_path + '_channels.pkl')
with open(channels_file, "rb") as input_file:
    channels = pickle.load(input_file)
print(channels)

['AF3', 'AF4', 'AF7', 'AF8', 'AFF5h', 'AFF6h', 'AFz', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'CCP3h', 'CCP4h', 'CCP5h', 'CCP6h', 'CP1', 'CP2', 'CP3', 'CP4', 'CP5', 'CP6', 'CPz', 'Cz', 'F1', 'F2', 'F3', 'F4', 'F5', 'F6', 'F7', 'F8', 'FC1', 'FC2', 'FC3', 'FC4', 'FC5', 'FC6', 'FCC3h', 'FCC4h', 'FCC5h', 'FCC6h', 'FCz', 'FT7', 'FT8', 'Fp1', 'Fp2', 'Fpz', 'Fz', 'Iz', 'O1', 'O2', 'Oz', 'P1', 'P10', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9', 'PO3', 'PO4', 'PO5', 'PO6', 'PO7', 'PO8', 'POz', 'Pz', 'T7', 'T8', 'TP7', 'TP8']
