In [27]:
import xarray as xr
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
import torch
from sklearn.preprocessing import LabelEncoder

In [82]:
subjects = xr.open_dataset('data/julia2018/timeseries_dosenbach2010.nc5')['subject'].values
subject_encoder = LabelEncoder().fit(subjects)
class_encoder = LabelEncoder().fit([subj[:4] for subj in subjects])

In [137]:
# load time-series and concatenate them along the regions dimension

def get_atlas_name(file_name):
    name = file_name.replace('-','_').replace('_2mm','')
    name = name.replace('difumo_','difumo')
    return name.split('_', 1)[1]

ts_files = {
    get_atlas_name(f.stem): f
    for f in Path('data/julia2018/').glob('*.nc5')
}

X_region = []
y_subject = []
y_class = []

s = 0


for atlas_name, ts_file in tqdm(ts_files.items()):

    ds = xr.open_dataset(ts_file).map(lambda x: x.values)

    # time-series
    ts = torch.tensor(ds['timeseries'].values).to(torch.float32)
    ts = ts.permute(0, 2, 1)  # (n_subjects, n_regions, n_timepoints)
    X_region.append(ts)

    # subjects
    subjects = subject_encoder.transform(ds['subject'].values)
    subjects = torch.tensor(subjects).reshape(-1, 1).repeat(1, ts.shape[1])  # (n_subjects, n_regions)
    y_subject.append(subjects)

    # classes
    classes = class_encoder.transform(ds['subject'].values)
    classes = torch.tensor(classes).reshape(-1, 1).repeat(1, ts.shape[1])  # (n_subjects, n_regions)
    y_class.append(classes)


X_regions = torch.cat(X_region, dim=1)
y_subject = torch.cat(y_subject, dim=1)
y_class = torch.cat(y_class, dim=1)

  0%|          | 0/6 [00:00<?, ?it/s]

In [135]:
X_regions.shape, y_subject.shape, y_class.shape

(torch.Size([32, 948, 124]), torch.Size([192]), torch.Size([192]))

In [134]:
from torch import nn
import pytorch_lightning as pl


class ACNets(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()

        self.feature_extractor = nn.Conv1d(948, 1, 3)
        self.fc = nn.Linear(122,2)

    def forward(self, x):
        out = self.feature_extractor(x)
        print('[1]', out.shape)
        out = out.squeeze(1)
        print('[2]', out.shape)
        out = self.fc(out)
        out = nn.functional.softmax(out, dim=1)
        return out

    def training_step(self, batch, batch_idx):
        x,y  = batch
        y_hat = self(x)
        print(y.shape, y_hat.shape)

        loss = nn.functional.cross_entropy(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

model = ACNets()

class Julia2018DataModule(pl.LightningDataModule):
    def __init__(self, X, y):
        super().__init__()
        self.X = X
        self.y = y

        self.data = torch.utils.data.TensorDataset(self.X, self.y)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.data)


datamodule = Julia2018DataModule(X_regions, y_subject)

trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)


AssertionError: Size mismatch between tensors