## Fit y-Aware using age as the auxiliary variable

For efficiency, we fit y-Aware on the VBM modality, using age as the auxiliary variable. We then plot the latent space and compare it with that of a y-Aware model fit on the FreeSurfer Destrieux-ROI modality.

In [None]:
import matplotlib.pyplot as plt
import nibabel
import numpy as np
import pandas as pd
import seaborn
import torch
from nilearn import datasets, plotting
from scipy.stats import pearsonr
from sklearn.linear_model import LinearRegression
from sklearn.manifold import TSNE
from sklearn.neighbors import KNeighborsRegressor
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from nidl.datasets import OpenBHB
from nidl.estimators.ssl import YAwareContrastiveLearning
from nidl.transforms import MultiViewsTransform
from nidl.volume.backbones import resnet18
from nidl.volume.transforms.augmentation.spatial import RandomErasing
from nidl.volume.transforms.preprocessing.spatial import CropOrPad


In [2]:
# Root directory of the OpenBHB dataset. Set the OPENBHB_ROOT environment
# variable to point elsewhere instead of editing this machine-specific default.
import os

root_dir = os.environ.get("OPENBHB_ROOT", "/home_local/bd261576/openBHB")

### Build scorer for age

Since y-Aware is an embedding model and we do not know the true latent factors of the data, we use k-nearest-neighbours (KNN) regression to test whether age can be decoded from the embedding.

In [3]:
def prediction_score(z, y, z_test=None, y_test=None, n_neighbors=5):
    """Fit a KNN regressor on embeddings ``z`` and measure how well it predicts targets.

    Parameters
    ----------
    z : array-like of shape (n_samples, n_features)
        Embeddings used to fit the regressor.
    y : array-like of shape (n_samples,) or (n_samples, n_targets)
        Targets (e.g. age) associated with ``z``.
    z_test, y_test : array-like, optional
        Held-out embeddings and targets to evaluate on; they must be given
        together. When both are omitted, the regressor is evaluated on the
        fitting data itself (the original behaviour). That in-sample score is
        optimistic, because every sample is one of its own nearest neighbours.
    n_neighbors : int, default=5
        Number of neighbours used by the regressor (5 matches scikit-learn's
        default, so the default behaviour is unchanged).

    Returns
    -------
    score : float
        Coefficient of determination R^2 on the evaluation set.
    y_pred : ndarray
        KNN predictions for the evaluation set.

    Raises
    ------
    ValueError
        If only one of ``z_test`` / ``y_test`` is provided.
    """
    if (z_test is None) != (y_test is None):
        raise ValueError("z_test and y_test must be provided together")

    knn_model = KNeighborsRegressor(n_neighbors=n_neighbors)
    knn_model.fit(z, y)

    if z_test is None:
        # In-sample evaluation (backward-compatible default; see docstring).
        z_eval, y_eval = z, y
    else:
        z_eval, y_eval = z_test, y_test
    return knn_model.score(z_eval, y_eval), knn_model.predict(z_eval)

### Load VBM modality and fit the y-Aware estimator

In [None]:
# SSL training samples yield two views, each passed through RandomErasing
# followed by CropOrPad(128). Evaluation samples only get CropOrPad(128),
# without the RandomErasing augmentation.
ssl_transforms = MultiViewsTransform(
    Compose([RandomErasing(num_iterations=5), CropOrPad(128)]),
    n_views=2,
)
test_transform = CropOrPad(128)

# Load the data. The "train" split is loaded twice: once with the multi-view
# SSL transform (for fitting) and once with the crop/pad-only transform. The
# "val" split is used as the test set.
_openbhb_kwargs = dict(target="age", modality="vbm")
dataset_ssl_train = OpenBHB(
    root_dir, split="train", transforms=ssl_transforms, **_openbhb_kwargs
)
dataset_train = OpenBHB(
    root_dir, split="train", transforms=test_transform, **_openbhb_kwargs
)
dataset_test = OpenBHB(
    root_dir, split="val", transforms=test_transform, **_openbhb_kwargs
)

# All loaders share batch size and worker count. Only the SSL training loader
# shuffles; the others keep dataset order (presumably so that extracted
# embeddings stay aligned with the dataset's targets).
_loader_kwargs = dict(batch_size=128, num_workers=10)
train_ssl_dataloader = DataLoader(dataset_ssl_train, shuffle=True, **_loader_kwargs)
train_dataloader = DataLoader(dataset_train, shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(dataset_test, shuffle=False, **_loader_kwargs)

In [None]:
# Kernel width on the auxiliary variable (age; presumably in years — confirm
# against the OpenBHB metadata).
sigma = 5

# y-Aware contrastive model with a ResNet-18 encoder and no projection head.
# NOTE(review): `bandwidth` receives sigma**2; confirm in the nidl docs whether
# this parameter expects a variance or a standard deviation.
yaware = YAwareContrastiveLearning(
    bandwidth=sigma ** 2,
    projection_head=None,
    encoder=resnet18(),
)