# y-Aware 

y-Aware is a self-supervised model for visual representation learning based on contrastive learning with auxiliary variables (continuous meta-data such as a subject's age).
It creates two views of each image using aggressive data augmentations (mostly cropping and cutout). It then pulls together the representations of views whose auxiliary variables are similar, whether they come from the **same** image or from **distinct** images, while repelling the representations of other images.
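The sketch below makes the alignment rule concrete. It assumes a scalar auxiliary variable (age) and a Gaussian kernel, and it computes how much "positive" weight each view assigns to every other view in a batch of 2N views. The helper name `kernel_weights` and the batch layout are hypothetical and are not part of nidl.

```python
import numpy as np


def kernel_weights(y: np.ndarray, bandwidth: float) -> np.ndarray:
    """Row-normalized Gaussian kernel weights between the auxiliary variables of 2N views.

    Hypothetical helper for illustration (not part of nidl). Entry [i, k] is the
    share of view i's "positive" weight given to view k.
    """
    diff = y[:, None] - y[None, :]               # pairwise differences, shape (2N, 2N)
    w = np.exp(-0.5 * (diff / bandwidth) ** 2)   # Gaussian kernel
    np.fill_diagonal(w, 0.0)                     # a view is never paired with itself
    return w / w.sum(axis=1, keepdims=True)      # each row sums to 1


# Three subjects aged 20, 22 and 70, each seen through two augmented views,
# laid out as [view 1 of every subject, view 2 of every subject].
ages = np.array([20.0, 22.0, 70.0, 20.0, 22.0, 70.0])
w = kernel_weights(ages, bandwidth=5.0)
print(w[0].round(3))  # weights of the first view of the 20-year-old
```

With a bandwidth of 5 years, the first view of the 20-year-old puts about 0.35 of its weight on its own second view and about 0.32 on each view of the 22-year-old. The weight it gives to the 70-year-old is numerically negligible, so a similar subject acts almost like a second positive.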

## Key components

- **Data augmentations**: y-Aware uses cropping with resizing, cutout, Gaussian blur, Gaussian noise and flipping to create diverse views.
- **Kernel**: y-Aware uses a kernel (in the statistical sense, as in kernel density estimation) to define a similarity between images based on their auxiliary variables.
- **Backbone**: 3D CNNs such as a 3D ResNet are usually employed to process 3D volumes such as MRI scans.
- **Projection head**: a small MLP maps the CNN features to a latent space where the contrastive loss is applied. It can also be viewed as
  a learnable similarity function between features.
- **Contrastive loss**: the y-Aware InfoNCE loss encourages pairs that are similar according to the kernel to align and dissimilar pairs to repel.
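The projection head and the kernel-weighted contrastive loss can be sketched in PyTorch as follows. This is a minimal re-implementation written for illustration, not nidl's code. It assumes cosine similarity with a temperature, a Gaussian kernel on a scalar auxiliary variable, and a batch laid out as `[view 1 of all images, view 2 of all images]`. The names `ProjectionHead` and `y_aware_infonce`, the layer sizes, and the temperature value are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ProjectionHead(nn.Module):
    """Small MLP mapping backbone features to the space where the loss is applied (sizes are assumptions)."""

    def __init__(self, in_dim: int = 512, hidden_dim: int = 512, out_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        return self.net(h)


def y_aware_infonce(z: torch.Tensor, y: torch.Tensor,
                    bandwidth: float = 5.0, temperature: float = 0.1) -> torch.Tensor:
    """Kernel-weighted InfoNCE over 2N embeddings z (shape [2N, D]) with auxiliary y (shape [2N])."""
    n = z.shape[0]
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)

    # Cosine similarities scaled by the temperature; self-pairs are excluded from the softmax.
    z = F.normalize(z, dim=1)
    logits = (z @ z.T) / temperature
    logits = logits.masked_fill(self_mask, float("-inf"))
    log_prob = F.log_softmax(logits, dim=1)

    # Gaussian kernel on the auxiliary variable, normalized per anchor (row).
    diff = y[:, None] - y[None, :]
    w = torch.exp(-0.5 * (diff / bandwidth) ** 2).masked_fill(self_mask, 0.0)
    w = w / w.sum(dim=1, keepdim=True)

    # Each anchor is pulled toward every other view in proportion to its kernel weight;
    # the diagonal (-inf log-prob, zero weight) is zeroed to avoid 0 * -inf = nan.
    return -(w * log_prob.masked_fill(self_mask, 0.0)).sum(dim=1).mean()


torch.manual_seed(0)
head = ProjectionHead(in_dim=512, out_dim=128)
features = torch.randn(8, 512)                        # stand-in for backbone features of 2 x 4 views
ages = torch.tensor([20.0, 35.0, 50.0, 70.0]).repeat(2)  # same age for both views of an image
print(y_aware_infonce(head(features), ages).item())
```

When the bandwidth is very small and the auxiliary values are distinct, each row of weights becomes one-hot on the other view of the same image. In that case the expression reduces to the NT-Xent loss used by SimCLR. Larger bandwidths spread the positive weight over subjects with nearby values.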

**Reference**: [Contrastive Learning with Continuous Proxy Meta-Data for 3D MRI Classification](https://arxiv.org/abs/2106.08808)


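```python
# This example requires the following dependencies to be installed:
# pip install -e nidl

from nidl.datasets.openbhb import OpenBHB
from nidl.models import yAware
from nidl.volume.transforms import yAwareTransformStrong
from nidl.data.collate import TwoViewsCollateFunction
from torch.utils.data import DataLoader

# OpenBHB training split: VBM images, with age as the auxiliary variable.
# Replace the root directory with the location of your local copy of OpenBHB.
dataset = OpenBHB("/neurospin/signatures/bd261576/openBHB", target="age",
                  modality="vbm", split="train")

# The collate function applies the strong y-Aware augmentations to produce
# two views of each image in the batch.
dataloader = DataLoader(
    dataset,
    collate_fn=TwoViewsCollateFunction(
        yAwareTransformStrong(input_size=(1, 128, 128, 128))),
    batch_size=16, shuffle=True, num_workers=10)

# 3D ResNet-18 encoder; the Gaussian kernel on age (bandwidth 5.0) defines
# how close two subjects' ages must be for their representations to be aligned.
yaware = yAware(
    encoder="resnet18_3d",
    n_embedding=16,
    kernel_kwargs=dict(kernel="gaussian", bandwidth=5.0),
    max_epochs=100,
)

yaware.fit(dataloader)
```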
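The following sketch shows what two-view batching amounts to, using a hypothetical stand-in written in plain PyTorch. It is not nidl's `TwoViewsCollateFunction`. The names `crop_resize_cutout` and `TwoViewsCollate`, the `(volume, target)` sample format, and the crop and cutout fractions are all assumptions, chosen to mirror the crop-with-resize and cutout augmentations described above.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def crop_resize_cutout(x: torch.Tensor, crop_frac: float = 0.75, cut_frac: float = 0.25) -> torch.Tensor:
    """Hypothetical augmentation for a (C, D, H, W) volume.

    Takes a random crop covering `crop_frac` of each spatial side, resizes it back
    to the input size with trilinear interpolation, then zeroes a random cube (cutout).
    """
    _, d, h, w = x.shape
    size = [int(s * crop_frac) for s in (d, h, w)]
    start = [torch.randint(0, s - c + 1, (1,)).item() for s, c in zip((d, h, w), size)]
    crop = x[:, start[0]:start[0] + size[0],
                start[1]:start[1] + size[1],
                start[2]:start[2] + size[2]]
    out = F.interpolate(crop[None], size=(d, h, w), mode="trilinear", align_corners=False)[0]
    cut = [int(s * cut_frac) for s in (d, h, w)]
    pos = [torch.randint(0, s - c + 1, (1,)).item() for s, c in zip((d, h, w), cut)]
    out[:, pos[0]:pos[0] + cut[0], pos[1]:pos[1] + cut[1], pos[2]:pos[2] + cut[2]] = 0.0
    return out


class TwoViewsCollate:
    """Hypothetical collate: two independently augmented views per sample.

    Returns views of shape (2B, C, D, H, W) laid out as [view 1 of all samples,
    view 2 of all samples], and the auxiliary targets repeated to shape (2B,).
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, batch):
        volumes, targets = zip(*batch)
        view1 = [self.transform(v) for v in volumes]
        view2 = [self.transform(v) for v in volumes]
        views = torch.stack(view1 + view2)
        y = torch.tensor([float(t) for t in targets]).repeat(2)
        return views, y


# Toy dataset of four random 1x32x32x32 "volumes" paired with ages.
toy = [(torch.randn(1, 32, 32, 32), float(age)) for age in (20, 35, 50, 70)]
loader = DataLoader(toy, batch_size=4, collate_fn=TwoViewsCollate(crop_resize_cutout))
views, y = next(iter(loader))
print(views.shape)  # torch.Size([8, 1, 32, 32, 32])
print(y)            # the four ages, repeated once for the second views
```

Augmenting inside the collate function keeps the dataset itself transform-free. The two views of a sample therefore differ only by the random augmentation, while sharing the same auxiliary value.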
