In [7]:
# Reload edited local modules automatically before each cell runs.
%reload_ext autoreload
%autoreload 3
import os

# Debugging aid: makes CUDA kernel launches synchronous so errors surface at the
# failing call. It slows GPU execution; note it is set before `import torch` below.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


import xarray as xr
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F

from sklearn.preprocessing import LabelEncoder

from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from sklearn.model_selection import train_test_split
from dvclive.lightning import DVCLiveLogger



# When executed from inside the notebooks/ folder (e.g. via nbconvert), move up one
# directory so the relative 'data/julia2018/...' paths used below resolve.
# NOTE(review): this is a substring test — any cwd containing 'notebooks' triggers it.
if 'notebooks' in os.getcwd():
    %cd ..

In [8]:
# Subject IDs are read from one reference atlas file. The class label of a subject is
# the first 4 characters of its ID. `with` closes the HDF5 handle, and `.values` reads
# the IDs into memory before the file is closed.
with xr.open_dataset('data/julia2018/timeseries_dosenbach2010.nc5', engine='h5netcdf') as reference_ds:
    subjects = reference_ds['subject'].values
subject_encoder = LabelEncoder().fit(subjects)
class_encoder = LabelEncoder().fit([subj[:4] for subj in subjects])

In [9]:
# load time-series and concatenate them along the regions dimension

def get_atlas_name(file_name):
    """Derive a short atlas name from a time-series file stem.

    Dashes become underscores, '_2mm' is dropped, 'difumo_' is fused into
    'difumo', and everything up to (and including) the first underscore is
    discarded, e.g. 'timeseries_difumo-64_2mm' -> 'difumo64'.

    Raises:
        IndexError: if the normalized name contains no underscore.
    """
    # Order matters: '-' must become '_' before the '_2mm' tag is stripped.
    substitutions = (('-', '_'), ('_2mm', ''), ('difumo_', 'difumo'))
    normalized = file_name
    for old, new in substitutions:
        normalized = normalized.replace(old, new)
    parts = normalized.split('_', maxsplit=1)
    return parts[1]

ts_files = {
    get_atlas_name(f.stem): f
    # sorted(): the feature (region) order must not depend on filesystem listing order
    for f in sorted(Path('data/julia2018/').glob('*.nc5'))
}
if not ts_files:
    raise FileNotFoundError("no '*.nc5' time-series files found in data/julia2018/")

X_region = []
reference_subjects = None  # subject order of the first file; every other file must match it

for atlas_name, ts_file in tqdm(ts_files.items()):
    # `with` closes the HDF5 handle; `.values` reads the data into memory before that.
    with xr.open_dataset(ts_file, engine='h5netcdf') as ds:
        ts = torch.tensor(ds['timeseries'].values).float()  # (n_subjects, n_timepoints, n_regions)
        file_subjects = ds['subject'].values

    # Concatenating along regions is only meaningful if row i is the same subject in every file.
    if reference_subjects is None:
        reference_subjects = file_subjects
    elif not np.array_equal(file_subjects, reference_subjects):
        raise ValueError(f'{ts_file}: subject order differs from the other time-series files')

    X_region.append(ts)

X_region = torch.cat(X_region, dim=-1)

# subjects (encoded subject ids)
subjects = subject_encoder.transform(reference_subjects)
y_subject = torch.tensor(subjects)  # (n_subjects,)

# classes: class_encoder was fit on the first 4 characters of each subject id,
# so it must be applied to the same prefixes (full ids would be unseen labels).
classes = class_encoder.transform([subj[:4] for subj in reference_subjects])
y_cls = torch.tensor(classes)  # (n_subjects,)

# L2-normalize each region's time-series over the time axis (dim=1).
X_region = F.normalize(X_region, dim=1)

  0%|          | 0/6 [00:00<?, ?it/s]

In [10]:


class ACNets(pl.LightningModule):
    """RNN autoencoder with classification heads on the encoder's final hidden state.

    The encoder RNN summarizes the input sequence into its final hidden state.
    The decoder RNN, started from that state, produces a sequence that
    `fc_decoder` maps back to `n_features` channels (reconstruction).
    Two linear heads predict a 2-way class and the subject identity.

    Args:
        n_features: number of input features per timepoint.
        hidden_size: hidden-state size of both encoder and decoder.
        n_subjects: number of outputs of the subject-identity head.
        lr: Adam learning rate (default 1e-2, the previously hard-coded value).
    """

    def __init__(self, n_features, hidden_size, n_subjects, lr=1e-2):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_subjects = n_subjects
        self.lr = lr

        self.encoder = nn.RNN(n_features, hidden_size, batch_first=True)
        # The decoder is initialized with the encoder's final hidden state, so its
        # hidden size must be `hidden_size`. fc_decoder projects back to n_features.
        # The old RNN(hidden_size, n_features) only worked when the two sizes were equal,
        # and in that case the parameter shapes here are identical.
        self.decoder = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.fc_decoder = nn.Linear(hidden_size, n_features)

        # classification heads
        self.fc_cls = nn.Linear(hidden_size, 2)
        self.fc_subj_cls = nn.Linear(hidden_size, n_subjects)

    def forward(self, x):
        """Encode, reconstruct and classify a batch.

        Args:
            x: input tensor, presumably (batch, n_timepoints, n_features) given
                batch_first=True — TODO confirm.

        Returns:
            (x_recon, y_cls, y_subj): reconstruction with n_timepoints steps and
            n_features channels, class logits (batch, 2), subject logits
            (batch, n_subjects).
        """
        batch_size, n_timepoints = x.size(0), x.size(1)

        _, h_enc = self.encoder(x)  # h_enc: (num_layers, batch, hidden_size)

        # NOTE(review): the decoder input is uniform random noise, so the reconstruction
        # (and val/loss_recon) is stochastic even in eval mode. This is kept as-is;
        # consider zeros or a teacher-forced input instead.
        dec_input = torch.rand(batch_size, n_timepoints, self.hidden_size,
                               device=h_enc.device, dtype=h_enc.dtype)
        y_dec, _ = self.decoder(dec_input, h_enc)
        x_recon = self.fc_decoder(y_dec)

        summary = h_enc[-1, :, :]  # last layer's final hidden state of the encoder
        y_cls = self.fc_cls(summary)
        y_subj = self.fc_subj_cls(summary)
        return x_recon, y_cls, y_subj

    def _shared_step(self, batch, stage, prog_bar):
        """Compute and log reconstruction + class losses for one batch under `stage/`."""
        x, y_cls, _y_subj = batch
        x_recon, y_cls_hat, _y_subj_hat = self(x)

        loss_recon = F.mse_loss(x_recon, x)
        loss_cls = F.cross_entropy(y_cls_hat, y_cls)
        loss = loss_recon + loss_cls
        accuracy = (y_cls_hat.argmax(dim=1) == y_cls).float().mean()

        self.log(f'{stage}/loss_recon', loss_recon)
        self.log(f'{stage}/loss_cls', loss_cls)
        self.log(f'{stage}/loss', loss, prog_bar=prog_bar)
        self.log(f'{stage}/accuracy', accuracy, prog_bar=prog_bar)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train', prog_bar=False)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, 'val', prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

class Julia2018DataModule(pl.LightningDataModule):
    """Stratified train/validation split over per-subject tensors.

    Args:
        X: model inputs; the first dimension indexes rows (subjects).
        y_cls: class label per row of X; also used to stratify the split.
        y_subject: encoded subject id per row of X.
        train_ratio: fraction of rows assigned to the training split.
        batch_size: batch size of both dataloaders.
        shuffle: whether to shuffle the training dataloader (validation is never shuffled).
        random_state: forwarded to ``train_test_split``. With the default ``None``, the
            split depends on the global NumPy RNG state. Set it when running multiple
            processes so that all of them get the same split.
    """

    def __init__(self, X, y_cls, y_subject, train_ratio=0.75, batch_size=8, shuffle=False,
                 random_state=None):
        super().__init__()
        self.X = X
        self.y_cls = y_cls
        self.y_subject = y_subject
        self.train_ratio = train_ratio
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.random_state = random_state
        self.trn_data = None
        self.val_data = None

    def _split(self):
        """Build the train/validation TensorDatasets on first call; later calls are no-ops."""
        if self.trn_data is not None and self.val_data is not None:
            return
        n_rows = len(self.X)
        if not n_rows == len(self.y_cls) == len(self.y_subject):
            raise ValueError(
                f'X, y_cls and y_subject must have the same length, '
                f'got {n_rows}, {len(self.y_cls)} and {len(self.y_subject)}')

        # Split row positions (not label values) so the indices are always valid for X.
        trn_idx, val_idx = train_test_split(
            torch.arange(n_rows),
            train_size=self.train_ratio,
            stratify=self.y_cls,
            random_state=self.random_state)

        self.trn_data = torch.utils.data.TensorDataset(self.X[trn_idx], self.y_cls[trn_idx], self.y_subject[trn_idx])
        self.val_data = torch.utils.data.TensorDataset(self.X[val_idx], self.y_cls[val_idx], self.y_subject[val_idx])

    def prepare_data(self):
        # Kept for callers that invoke prepare_data() directly; the split is idempotent.
        self._split()

    def setup(self, stage=None):
        # Lightning runs setup() on every process (prepare_data() only on one),
        # so the datasets must also be built here.
        self._split()

    def train_dataloader(self):
        self._split()
        return torch.utils.data.DataLoader(self.trn_data, batch_size=self.batch_size, shuffle=self.shuffle)

    def val_dataloader(self):
        self._split()
        # Shuffling validation data only reorders batches; keep it deterministic.
        return torch.utils.data.DataLoader(self.val_data, batch_size=self.batch_size, shuffle=False)


# Seed Python/NumPy/torch RNGs. This makes the train/val split, the weight init and
# the model's random decoder input reproducible across runs.
SEED = 42
pl.seed_everything(SEED)

datamodule = Julia2018DataModule(X_region, y_cls, y_subject, batch_size=8, shuffle=False)

n_subjects, n_timepoints, n_features = X_region.shape
hidden_size = n_features

model = ACNets(n_features, hidden_size, n_subjects)

trainer = pl.Trainer(max_epochs=10, accelerator='auto', log_every_n_steps=1,
                     logger=DVCLiveLogger(save_dvc_exp=True, report='md'),
                     num_sanity_val_steps=0)
trainer.fit(model, datamodule=datamodule)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type   | Params
---------------------------------------
0 | encoder     | RNN    | 1.8 M 
1 | decoder     | RNN    | 1.8 M 
2 | fc_decoder  | Linear | 899 K 
3 | fc_cls      | Linear | 1.9 K 
4 | fc_subj_cls | Linear | 30.4 K
---------------------------------------
4.5 M     Trainable params
0         Non-trainable params
4.5 M     Total params
18.122    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


Updating lock file 'dvc.lock'


INFO:dvc.dvcfile:Updating lock file 'dvc.lock'
	.vscode/settings.json
	.vscode/settings.json
INFO:dvclive:To run with DVC, add this to /home/morteza/workspace/ACNets-MultiHead/dvc.yaml:
stages:
  dvclive:
    cmd: <python my_code_file.py my_args>
    deps:
    - <my_code_file.py>

