In [None]:
%reload_ext autoreload
%autoreload 3
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"


import xarray as xr
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import torch
from sklearn.preprocessing import LabelEncoder

from torch import nn
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


In [None]:
# Fit the label encoders once, on a single reference file, so the same
# subject -> integer and class -> integer mappings are used everywhere below.
# The context manager closes the file as soon as the subject IDs have been
# read; `.values` materialises them as a NumPy array before the file is closed.
with xr.open_dataset('data/julia2018/timeseries_dosenbach2010.nc5') as reference_ds:
    subjects = reference_ds['subject'].values
subject_encoder = LabelEncoder().fit(subjects)
# The class label is taken to be the first four characters of each subject ID.
class_encoder = LabelEncoder().fit([subj[:4] for subj in subjects])

In [None]:
# load time-series and concatenate them along the regions dimension

def get_atlas_name(file_name):
    """Derive the atlas name from a time-series file stem.

    Hyphens are normalised to underscores, the ``_2mm`` marker is dropped,
    ``difumo_`` is collapsed to ``difumo``, and everything up to and including
    the first underscore is discarded.

    Raises:
        IndexError: if the normalised name contains no underscore.
    """
    normalized = file_name.replace('-', '_')
    normalized = normalized.replace('_2mm', '')
    normalized = normalized.replace('difumo_', 'difumo')
    # Split once only: the atlas part may itself contain underscores.
    prefix_and_atlas = normalized.split('_', 1)
    return prefix_and_atlas[1]

# Map atlas name -> time-series file. The paths are sorted so that the order
# in which atlases are concatenated (i.e. the region axis of `X_regions`) is
# the same on every run; `Path.glob` yields entries in arbitrary order.
ts_files = {
    get_atlas_name(f.stem): f
    for f in sorted(Path('data/julia2018/').glob('*.nc5'))
}
if not ts_files:
    raise FileNotFoundError("no '*.nc5' time-series files found in 'data/julia2018/'")

X_region = []
subject_ids = None  # subject IDs shared by all atlas files, in row order

for atlas_name, ts_file in tqdm(ts_files.items()):

    # Read what is needed into memory, then close the file right away.
    with xr.open_dataset(ts_file) as ds:
        ts_values = ds['timeseries'].values
        atlas_subject_ids = ds['subject'].values

    # All atlases are concatenated along the region axis, so row i must refer
    # to the same subject in every file; otherwise labels would be misaligned.
    if subject_ids is None:
        subject_ids = atlas_subject_ids
    elif not np.array_equal(subject_ids, atlas_subject_ids):
        raise ValueError(
            f"subjects in {ts_file} (atlas '{atlas_name}') do not match the other atlas files"
        )

    # time-series
    ts = torch.tensor(ts_values).to(torch.float32)
    ts = ts.permute(0, 2, 1)  # (n_subjects, n_regions, n_timepoints)
    X_region.append(ts)

X_regions = torch.cat(X_region, dim=1)

# subjects
subjects = subject_encoder.transform(subject_ids)
y_subject = torch.tensor(subjects)  # (n_subjects,)

# classes
# `class_encoder` was fitted on the 4-character ID prefixes, so the same
# prefixes (not the full subject IDs) must be passed to `transform`.
classes = class_encoder.transform([subj[:4] for subj in subject_ids])
y_class = torch.tensor(classes)  # (n_subjects,)

In [None]:


class ACNets(pl.LightningModule):
    """Conv1d + LSTM auto-encoder with a classification head.

    A 1-D convolution turns the input of shape
    ``(batch, input_size, n_timesteps)`` into features over
    ``n_timesteps - kernel_size + 1`` steps. An LSTM encoder consumes these
    features; its stacked final hidden states are fed to an LSTM decoder that
    reconstructs the features, and the last layer's hidden state feeds a
    linear classifier. Training minimises reconstruction MSE plus
    classification cross-entropy.

    Args:
        input_size: number of input channels (regions).
        feature_size: number of channels produced by the convolution.
        hidden_size: hidden size of the encoder LSTM.
        n_timesteps: length of the input along the time axis.
        kernel_size: convolution kernel size.
        n_classes: number of output classes of the classification head.

    Raises:
        ValueError: if ``n_timesteps`` is smaller than ``kernel_size``.
    """

    def __init__(self, input_size, feature_size, hidden_size, n_timesteps, kernel_size=3,
                 n_classes=2):
        super().__init__()

        # Length of the time axis after the (unpadded, stride-1) convolution.
        n_steps = n_timesteps - kernel_size + 1
        if n_steps < 1:
            raise ValueError(
                f'n_timesteps ({n_timesteps}) must be >= kernel_size ({kernel_size})')

        # Dummy batch used by Lightning (e.g. for the model summary). It is
        # derived from the constructor arguments so it always matches the model.
        self.example_input_array = torch.zeros(32, input_size, n_timesteps)

        self.feature_extractor = nn.Conv1d(input_size, feature_size, kernel_size)

        self.fc = nn.Linear(hidden_size, n_classes)

        # The number of LSTM layers is tied to the number of time steps so that
        # the stack of final hidden states (one per layer) forms a sequence of
        # length `n_steps`, which `forward` feeds to the decoder.
        # NOTE(review): this makes the LSTMs as deep as the sequence is long;
        # confirm this is intended rather than decoding from the encoder outputs.
        self.encoder = nn.LSTM(
            input_size=feature_size, hidden_size=hidden_size,
            num_layers=n_steps, batch_first=True)

        # decoder
        self.decoder = nn.LSTM(
            input_size=hidden_size, hidden_size=feature_size,
            num_layers=n_steps, batch_first=True)

        self.loss_recon = nn.MSELoss()
        self.loss_class = nn.CrossEntropyLoss()

    def forward(self, x):
        """Run the model on ``x`` of shape ``(batch, input_size, n_timesteps)``.

        Returns:
            Tuple ``(x_enc, x_recon, y_cls)``: the convolutional features as
            ``(batch, time, feature)``, their reconstruction by the decoder,
            and the classification logits.
        """
        features = self.feature_extractor(x)

        x_enc = features.permute(0, 2, 1)  # permute to batch, time, feature

        _, (h_enc, _) = self.encoder(x_enc)
        # h_enc is (layers, batch, hidden); treat the layer axis as time.
        x_recon, _ = self.decoder(h_enc.permute(1, 0, 2))

        # classification output from the last layer's final hidden state
        y_cls = self.fc(h_enc[-1])

        return x_enc, x_recon, y_cls

    def _shared_step(self, batch, stage):
        """Compute and log the losses and accuracy for one batch under ``stage/``."""
        x, y = batch
        x_enc, x_recon, y_hat = self(x)
        # NOTE(review): the reconstruction target `x_enc` is not detached, so
        # gradients also flow through the target; confirm this is intended.
        loss_recon = self.loss_recon(x_recon, x_enc)
        loss_cls = self.loss_class(y_hat, y)
        loss = loss_recon + loss_cls
        accuracy = (y_hat.argmax(dim=1) == y).float().mean()
        self.log(f'{stage}/loss_recon', loss_recon)
        self.log(f'{stage}/loss_cls', loss_cls)
        self.log(f'{stage}/loss', loss)
        self.log(f'{stage}/accuracy', accuracy)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the total training loss; metrics are logged under ``train/``."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Return the total validation loss; metrics are logged under ``val/``."""
        return self._shared_step(batch, 'val')

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

class Julia2018DataModule(pl.LightningDataModule):
    """Serves in-memory tensors ``X`` (inputs) and ``y`` (labels) in batches.

    Args:
        X: input tensor; its first dimension indexes samples.
        y: label tensor with the same first dimension as ``X``.
        batch_size: number of samples per batch for both loaders.
    """

    def __init__(self, X, y, batch_size=32):
        super().__init__()
        self.X = X
        self.y = y
        self.batch_size = batch_size
        # Build the dataset here rather than in `prepare_data`: Lightning runs
        # `prepare_data` on a single process only and advises against
        # assigning state there. Wrapping in-memory tensors is cheap.
        self.data = torch.utils.data.TensorDataset(self.X, self.y)

    def train_dataloader(self):
        # Shuffle so batches are not drawn in the stored sample order.
        return torch.utils.data.DataLoader(self.data, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # NOTE(review): validation uses the same samples as training, so the
        # `val/*` metrics do not measure generalisation; consider a held-out split.
        return torch.utils.data.DataLoader(self.data, batch_size=self.batch_size)


# Seed the Python, NumPy and PyTorch RNGs before the model is created so that
# weight initialisation and training are reproducible across runs.
SEED = 42
MAX_EPOCHS = 10000
pl.seed_everything(SEED)

datamodule = Julia2018DataModule(X_regions, y_class)
model = ACNets(
    X_regions.shape[1],  # number of regions (input channels)
    64,                  # feature_size
    hidden_size=48,
    n_timesteps=X_regions.shape[2],
)
trainer = pl.Trainer(max_epochs=MAX_EPOCHS, accelerator='cpu', log_every_n_steps=1, logger=True)
trainer.fit(model, datamodule=datamodule)
X_regions.shape