In [2]:
%%capture
!pip install mne
!pip install pytorch-lightning

In [5]:
from glob import glob
import scipy.io
import torch.nn as nn
import torch
import numpy as np
import mne

In [6]:
input=torch.randn(3,22,15000)
input.shape

torch.Size([3, 22, 15000])

In [7]:
class Block(nn.Module):
  def __init__(self,inplace):
    super().__init__()
    self.conv1=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=2,stride=2,padding=0)
    self.conv2=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=4,stride=2,padding=1)
    self.conv3=nn.Conv1d(in_channels=inplace,out_channels=32,kernel_size=8,stride=2,padding=3)
    self.relu=nn.ReLU()

  def forward(self,x):
    x1=self.relu(self.conv1(x))
    x2=self.relu(self.conv2(x))
    x3=self.relu(self.conv3(x))
    x=torch.cat([x1,x3,x3],dim=1)
    return x
     

class ChronoNet(nn.Module):
  def __init__(self,channel):
    super().__init__()
    self.block1=Block(channel)
    self.block2=Block(96)
    self.block3=Block(96)
    self.gru1=nn.GRU(input_size=96,hidden_size=32,batch_first=True)
    self.gru2=nn.GRU(input_size=32,hidden_size=32,batch_first=True)
    self.gru3=nn.GRU(input_size=64,hidden_size=32,batch_first=True)
    self.gru4=nn.GRU(input_size=96,hidden_size=32,batch_first=True)
    self.gru_linear=nn.Linear(64,1)
    self.flatten=nn.Flatten()
    self.fc1=nn.Linear(32,1)
    self.relu=nn.ReLU()
  def forward(self,x):
    x=self.block1(x)
    x=self.block2(x)
    x=self.block3(x)
    x=x.permute(0,2,1)
    gru_out1,_=self.gru1(x)
    gru_out2,_=self.gru2(gru_out1)
    gru_out=torch.cat([gru_out1,gru_out2],dim=2)
    gru_out3,_=self.gru3(gru_out)
    gru_out=torch.cat([gru_out1,gru_out2,gru_out3],dim=2)
    #print('gru_out',gru_out.shape)
    linear_out=self.relu(self.gru_linear(gru_out.permute(0,2,1)))
    gru_out4,_=self.gru4(linear_out.permute(0,2,1))
    x=self.flatten(gru_out4)
    x=self.fc1(x)
    return x

In [30]:
input=torch.randn(3,14,512)
input.shape
model=ChronoNet(14)
out=model(input)
out.shape

torch.Size([3, 1])

In [31]:
IDD_data_path='./data/Data/CleanData/CleanData_TDC/Rest'
TDC_data_path='./data/Data/CleanData/CleanData_IDD/Rest'

In [32]:
def convertmat2mne(data):
  ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
  ch_types = ['eeg'] * 14
  sampling_freq=128
  info = mne.create_info(ch_names, ch_types=ch_types, sfreq=sampling_freq)
  info.set_montage('standard_1020')
  data=mne.io.RawArray(data, info)
  data.set_eeg_reference()
  data.filter(l_freq=1,h_freq=30)
  epochs=mne.make_fixed_length_epochs(data,duration=4,overlap=0)
  return epochs.get_data()

In [33]:
%%capture
idd_subject=[]
for idd in glob(IDD_data_path+'/*.mat'):
  data=scipy.io.loadmat(idd)['clean_data']
  data=convertmat2mne(data)
  idd_subject.append(data)

In [34]:
%%capture
tdc_subject=[]
for tdc in glob(TDC_data_path+'/*.mat'):
  data=scipy.io.loadmat(tdc)['clean_data']
  data=convertmat2mne(data)
  tdc_subject.append(data)

In [35]:
len(idd_subject),len(tdc_subject)

(7, 7)

In [36]:
control_epochs_labels=[len(i)*[0] for i in tdc_subject]
patients_epochs_labels=[len(i)*[1] for i in idd_subject]
print(len(control_epochs_labels),len(patients_epochs_labels))

7 7


In [37]:
data_list=tdc_subject+idd_subject
label_list=control_epochs_labels+patients_epochs_labels
groups_list=[[i]*len(j) for i, j in enumerate(data_list)]
print(len(data_list),len(label_list),len(groups_list))

14 14 14


In [38]:
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut
from sklearn.preprocessing import StandardScaler
gkf=GroupKFold()
from sklearn.base import TransformerMixin,BaseEstimator
from sklearn.preprocessing import StandardScaler
#https://stackoverflow.com/questions/50125844/how-to-standard-scale-a-3d-matrix
class StandardScaler3D(BaseEstimator,TransformerMixin):
    #batch, sequence, channels
    def __init__(self):
        self.scaler = StandardScaler()

    def fit(self,X,y=None):
        self.scaler.fit(X.reshape(-1, X.shape[2]))
        return self

    def transform(self,X):
        return self.scaler.transform(X.reshape( -1,X.shape[2])).reshape(X.shape)

In [39]:
import numpy as np
data_array=np.concatenate(data_list)
label_array=np.concatenate(label_list)
group_array=np.concatenate(groups_list)
data_array=np.moveaxis(data_array,1,2)

print(data_array.shape,label_array.shape,group_array.shape)

(420, 512, 14) (420,) (420,)


In [40]:
accuracy=[]
for train_index, val_index in gkf.split(data_array, label_array, groups=group_array):
    train_features,train_labels=data_array[train_index],label_array[train_index]
    val_features,val_labels=data_array[val_index],label_array[val_index]
    scaler=StandardScaler3D()
    train_features=scaler.fit_transform(train_features)
    val_features=scaler.transform(val_features)
    train_features=np.moveaxis(train_features,1,2)
    val_features=np.moveaxis(val_features,1,2)

    break

In [41]:
train_features = torch.Tensor(train_features)
val_features = torch.Tensor(val_features)
train_labels = torch.Tensor(train_labels)
val_labels = torch.Tensor(val_labels)

In [42]:
len(val_features),len(val_labels)

(90, 90)

In [43]:
train_features.shape

torch.Size([330, 14, 512])

In [44]:
from pytorch_lightning import LightningModule,Trainer
import torchmetrics
from torch.utils.data import TensorDataset,DataLoader

In [57]:
from pytorch_lightning import LightningModule, Trainer
import torchmetrics
from torch.utils.data import TensorDataset, DataLoader

class ChronoModel(LightningModule):
    def __init__(self):
        super(ChronoModel, self).__init__()
        self.model = ChronoNet(14)
        self.lr = 1e-3
        self.bs = 12
        self.worker = 2
        self.acc = torchmetrics.Accuracy(task='binary')
        self.criterion = nn.BCEWithLogitsLoss()
        self.training_accuracies = []  # To store accuracies for training
        self.validation_accuracies = []  # To store accuracies for validation

    def forward(self, x):
        x = self.model(x)
        return x

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        dataset = TensorDataset(train_features, train_labels)
        dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=True)
        return dataloader

    def training_step(self, batch, batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        acc = self.acc(out.flatten(), label.long().flatten())
        self.training_accuracies.append(acc)  # Store training accuracy
        return {'loss': loss, 'acc': acc}

    def on_train_epoch_end(self):
        # Calculate and log the average accuracy and loss for the epoch
        avg_acc = torch.stack(self.training_accuracies).mean().detach().cpu().numpy().round(2)
        print(f'Train Accuracy: {avg_acc}')
        self.training_accuracies = []  # Reset for next epoch

    def val_dataloader(self):
        dataset = TensorDataset(val_features, val_labels)
        dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=False)
        return dataloader

    def validation_step(self, batch, batch_idx):
        signal, label = batch
        out = self(signal.float())
        loss = self.criterion(out.flatten(), label.float().flatten())
        acc = self.acc(out.flatten(), label.long().flatten())
        self.validation_accuracies.append(acc)  # Store validation accuracy
        return {'loss': loss, 'acc': acc}

    def on_validation_epoch_end(self):
        # Calculate and log the average accuracy and loss for the epoch
        avg_acc = torch.stack(self.validation_accuracies).mean().detach().cpu().numpy().round(2)
        print(f'Validation Accuracy: {avg_acc}')
        self.validation_accuracies = []  # Reset for next epoch


In [58]:
model=ChronoModel()

In [59]:
trainer=Trainer(max_epochs=1)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [60]:
trainer.fit(model)



  | Name      | Type              | Params | Mode 
--------------------------------------------------------
0 | model     | ChronoNet         | 133 K  | train
1 | acc       | BinaryAccuracy    | 0      | train
2 | criterion | BCEWithLogitsLoss | 0      | train
--------------------------------------------------------
133 K     Trainable params
0         Non-trainable params
133 K     Total params
0.534     Total estimated model params size (MB)
26        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                               | 0/? [00:00<…

C:\Users\iemma\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Validation Accuracy: 1.0


C:\Users\iemma\anaconda3\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.
C:\Users\iemma\anaconda3\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (28) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=1` reached.


Validation Accuracy: 0.3100000023841858
Train Accuracy: 0.5400000214576721
