# VICReg 

VICReg (Variance-Invariance-Covariance Regularization) is a self-supervised method for learning visual representations without annotations.
It creates two augmented views of the same image using aggressive data augmentations (mostly cropping, as in SimCLR) and then
aligns the embeddings of the two views. Unlike contrastive methods, it does not rely on negative pairs to avoid representation collapse; instead, it relies on two key principles:

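1. Variance preservation: the variance of each feature in the latent space, computed over the batch, is kept above a threshold (maximize expressivity)
2. Covariance minimization: the covariance between different features, computed over the batch, is pushed toward zero (minimize redundancy)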
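
Following the notation of the VICReg paper, these two principles can be written as regularization terms. Let $Z = [z_1, \dots, z_n]$ be a batch of $n$ embeddings of dimension $d$, $\bar{z}$ their mean, and $z^j$ the vector of values taken by feature $j$ across the batch. The threshold $\gamma$ (fixed to 1 in the paper) and the small constant $\epsilon$ are hyperparameters.

$$
v(Z) = \frac{1}{d} \sum_{j=1}^{d} \max\left(0,\ \gamma - \sqrt{\mathrm{Var}(z^j) + \epsilon}\right)
$$

$$
C(Z) = \frac{1}{n-1} \sum_{k=1}^{n} (z_k - \bar{z})(z_k - \bar{z})^\top, \qquad c(Z) = \frac{1}{d} \sum_{i \neq j} \left[C(Z)\right]_{i,j}^2
$$

The hinge in $v$ only penalizes features whose standard deviation falls below $\gamma$, which rules out the collapsed solution where all embeddings are identical. The term $c$ sums the squared off-diagonal entries of the covariance matrix, which decorrelates the features.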

## Key components

- **Data augmentations**: VICReg uses cropping with resizing, Gaussian blur, Gaussian noise, and flipping to create diverse views.
- **Backbone**: a 3D CNN such as a 3D ResNet is usually employed to handle 3D volumes such as MRI.
- **Projection head**: a small MLP maps the CNN features to a latent space where the VICReg loss is applied. Its output dimension is usually much larger than its input dimension (typically 8192 vs. 2048), hence the name "expander".
- **VICReg loss**: this loss encourages the embeddings of the two views to align (invariance), the variance of each embedding feature to stay high (variance), and the covariance between features to be minimized (covariance).
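
The projection head and the loss described above can be sketched in plain PyTorch as follows. This is a minimal illustration and not the code used inside `nidl`: the names `make_expander`, `off_diagonal`, and `vicreg_loss` are hypothetical, and the default weights (25, 25, 1), `gamma=1`, and `eps=1e-4` are assumptions taken from the settings reported in the VICReg paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_expander(in_dim: int = 2048, hidden_dim: int = 8192) -> nn.Sequential:
    """Hypothetical 3-layer MLP projection head ("expander")."""
    return nn.Sequential(
        nn.Linear(in_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, hidden_dim), nn.BatchNorm1d(hidden_dim), nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, hidden_dim),
    )


def off_diagonal(m: torch.Tensor) -> torch.Tensor:
    """Return the off-diagonal entries of a square matrix as a 1D tensor."""
    mask = ~torch.eye(m.shape[0], dtype=torch.bool, device=m.device)
    return m[mask]


def vicreg_loss(z_a, z_b, sim_w=25.0, var_w=25.0, cov_w=1.0, gamma=1.0, eps=1e-4):
    """Sketch of the VICReg loss for two batches of embeddings of shape (n, d)."""
    n, d = z_a.shape

    # Invariance: mean squared error between the embeddings of the two views.
    inv = F.mse_loss(z_a, z_b)

    # Variance: hinge on the per-feature standard deviation over the batch.
    std_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_b = torch.sqrt(z_b.var(dim=0) + eps)
    var = F.relu(gamma - std_a).mean() + F.relu(gamma - std_b).mean()

    # Covariance: squared off-diagonal entries of each covariance matrix.
    za_c = z_a - z_a.mean(dim=0)
    zb_c = z_b - z_b.mean(dim=0)
    cov_a = (za_c.T @ za_c) / (n - 1)
    cov_b = (zb_c.T @ zb_c) / (n - 1)
    cov = off_diagonal(cov_a).pow(2).sum() / d + off_diagonal(cov_b).pow(2).sum() / d

    return sim_w * inv + var_w * var + cov_w * cov
```

With these definitions, a fully collapsed batch (all embeddings identical) has zero invariance and covariance terms but is heavily penalized by the variance term, which is what prevents collapse without negative pairs.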

**Reference**: [VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning](https://arxiv.org/abs/2105.04906)


```python
# This example requires the following dependencies to be installed:
# pip install -e nidl

from nidl.datasets.openbhb import OpenBHB
from nidl.models import VICReg
from nidl.volume.transforms import SimCLRTransform
from nidl.data.collate import TwoViewsCollateFunction
from torch.utils.data import DataLoader

# Training split of OpenBHB (VBM modality); adapt the path to your local copy.
dataset = OpenBHB("/neurospin/signatures/bd261576/openBHB", modality="vbm",
                  split="train")

# The collate function applies the transform to each sample to produce the
# two augmented views of every volume in the batch.
dataloader = DataLoader(dataset, collate_fn=TwoViewsCollateFunction(
                            SimCLRTransform(input_size=(1, 128, 128, 128))),
                        batch_size=64, shuffle=True, num_workers=10)

# VICReg model with a 3D ResNet-18 encoder.
vicreg = VICReg(
    encoder="resnet18_3d",
    n_embedding=16,
    max_epochs=100
)

# Self-supervised training on the two-view batches.
vicreg.fit(dataloader)
```
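
To make the role of the collate function in the example above more concrete, here is a minimal sketch of two-view batching. It is an illustrative stand-in and not the implementation of nidl's `TwoViewsCollateFunction` or `SimCLRTransform`: the names `random_view` and `TwoViewsCollate` are hypothetical, the augmentation is reduced to a random flip plus Gaussian noise, and each sample is assumed to be a single tensor of shape `(1, D, H, W)`.

```python
import torch


def random_view(volume: torch.Tensor, noise_std: float = 0.1) -> torch.Tensor:
    """Hypothetical augmentation: random flip along the last axis plus Gaussian noise."""
    if torch.rand(()) < 0.5:
        volume = torch.flip(volume, dims=[-1])
    return volume + noise_std * torch.randn_like(volume)


class TwoViewsCollate:
    """Illustrative collate function returning two augmented views per sample."""

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, samples):
        # Apply the random transform twice, independently, to every sample.
        view_a = torch.stack([self.transform(x) for x in samples])
        view_b = torch.stack([self.transform(x) for x in samples])
        return view_a, view_b
```

An instance such as `TwoViewsCollate(random_view)` can be passed as `collate_fn` to a `DataLoader`, so that each batch is a pair of tensors holding the two views of the same volumes.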