# SimCLR for 3D medical imaging

SimCLR is a self-supervised framework for learning visual representations with a contrastive objective. For each image it creates two augmented views, then trains the network to maximize agreement between these two views while pushing them away from the views of other images in the batch. Key findings include the importance of composing strong data augmentations, the benefit of a nonlinear projection head for representation quality, and the advantage of large batch sizes, which provide more negative examples. Combined, these elements allow SimCLR to approach or match supervised performance on ImageNet under linear evaluation, and to achieve strong transfer and semi-supervised learning results.
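To make the idea of two augmented views concrete for volumetric data, here is a minimal NumPy sketch, assuming a single-channel volume stored as a 3D array without a channel axis. The helper names (`random_resized_crop_3d`, `gaussian_blur_3d`, `simclr_views_3d`), the crop scale range, and the output size are illustrative assumptions; the blur probability of 0.5 and sigma range of 0.1 to 2.0 follow the original 2D SimCLR recipe and may not suit MRI. This is not the nidl `SimCLRTransform` implementation.

```python
import numpy as np


def random_resized_crop_3d(vol, out_shape, side_scale=(0.5, 1.0), rng=None):
    """Crop a random sub-volume and resize it to `out_shape` (nearest neighbour)."""
    rng = np.random.default_rng() if rng is None else rng
    # One random fraction, applied to every side, sets the crop size.
    frac = rng.uniform(*side_scale)
    crop = [max(1, int(round(s * frac))) for s in vol.shape]
    start = [rng.integers(0, s - c + 1) for s, c in zip(vol.shape, crop)]
    sub = vol[tuple(slice(b, b + c) for b, c in zip(start, crop))]
    # Nearest-neighbour resize: map each output index back into the crop.
    idx = [(np.arange(o) * c) // o for o, c in zip(out_shape, crop)]
    return sub[np.ix_(*idx)]


def gaussian_blur_3d(vol, sigma):
    """Separable Gaussian blur along each axis (zero padding at the borders)."""
    radius = max(1, int(3 * sigma))
    x = np.arange(-radius, radius + 1)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    kernel /= kernel.sum()
    out = vol.astype(np.float64)
    for axis in range(out.ndim):
        out = np.apply_along_axis(
            lambda line: np.convolve(line, kernel, mode="same"), axis, out)
    return out


def simclr_views_3d(vol, out_shape=(32, 32, 32), rng=None):
    """Return two independently augmented views of the same volume."""
    rng = np.random.default_rng() if rng is None else rng

    def augment():
        view = random_resized_crop_3d(vol, out_shape, rng=rng)
        if rng.random() < 0.5:  # blur roughly half of the views
            view = gaussian_blur_3d(view, sigma=rng.uniform(0.1, 2.0))
        return view.astype(np.float32)

    return augment(), augment()


rng = np.random.default_rng(0)
volume = rng.normal(size=(64, 64, 64))  # stand-in for a preprocessed MRI volume
v1, v2 = simclr_views_3d(volume, rng=rng)
print(v1.shape, v2.shape)  # (32, 32, 32) (32, 32, 32)
```

Nearest-neighbour resizing keeps the sketch dependency-free; a production pipeline would more likely use trilinear interpolation.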

## Key Components

- **Data Augmentations**: In this 3D adaptation, SimCLR uses random cropping, resizing, and Gaussian blur to create diverse views of the same volume.
- **Backbone**: A convolutional neural network, such as a 3D ResNet, encodes each augmented volume into a feature representation.
- **Projection Head**: A small multilayer perceptron (MLP) maps features into the space where the contrastive loss is applied. Training through this head improves the backbone features, which are the ones kept for downstream tasks.
- **Contrastive Loss**: The NT-Xent loss (normalized temperature-scaled cross-entropy, a form of InfoNCE) pulls the two views of the same image together and pushes views of different images apart.
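The projection head and the contrastive loss described above can be sketched in a few lines of NumPy. This is an illustrative version of NT-Xent under stated assumptions: the head is a two-layer ReLU MLP whose random weights stand in for learned ones, the names `projection_head` and `nt_xent` are hypothetical, and the temperature of 0.5 is only an example value. Row `i` of `z1` and row `i` of `z2` are the two views of the same image; every other row in the batch acts as a negative.

```python
import numpy as np


def projection_head(h, w1, w2):
    """Two-layer MLP g(h) = W2 ReLU(W1 h), applied row-wise."""
    return np.maximum(h @ w1, 0.0) @ w2


def nt_xent(z1, z2, temperature=0.5):
    """NT-Xent loss for N positive pairs (z1[i], z2[i]); the 2N - 2 other views are negatives."""
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)               # (2N, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # unit vectors -> cosine similarity
    sim = z @ z.T / temperature                        # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                     # a view is never its own candidate
    # Index of the positive for each row: view i in z1 <-> view i in z2.
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_prob = sim - row_max - np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(2 * n), pos].mean()


rng = np.random.default_rng(0)
h1, h2 = rng.normal(size=(8, 32)), rng.normal(size=(8, 32))  # stand-ins for backbone features
w1, w2 = rng.normal(size=(32, 32)), rng.normal(size=(32, 16))
loss = nt_xent(projection_head(h1, w1, w2), projection_head(h2, w1, w2))

eye = np.eye(4)
print(round(float(nt_xent(eye, eye)), 4))  # 0.5944
```

With perfectly aligned, mutually orthogonal embeddings the loss reduces to log(1 + (2N - 2) e^(-1/τ)), here log(1 + 6e^(-2)) for N = 4 and τ = 0.5. Larger batches add more negatives to the denominator, which is one reason SimCLR benefits from large batch sizes.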

**Reference**: [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709)

```python
# This example requires the following dependencies to be installed:
# pip install -e nidl  (editable install from a local checkout of nidl)

from torch.utils.data import DataLoader

from nidl.data.collate import TwoViewsCollateFunction
from nidl.datasets.openbhb import OpenBHB
from nidl.models import SimCLR
from nidl.volume.transforms import SimCLRTransform

# OpenBHB training split, VBM modality (the path points to a local copy of the dataset).
dataset = OpenBHB("/neurospin/signatures/bd261576/openBHB", modality="vbm",
                  split="train")

# The collate function applies SimCLRTransform to build two augmented views per volume.
dataloader = DataLoader(dataset, collate_fn=TwoViewsCollateFunction(
                            SimCLRTransform(input_size=(1, 128, 128, 128))),
                        batch_size=64, shuffle=True, num_workers=10)

# SimCLR with a 3D ResNet-18 encoder, trained for 100 epochs.
simclr = SimCLR(
    encoder="resnet18_3d",
    n_embedding=16,
    max_epochs=100
)

simclr.fit(dataloader)
```
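Conceptually, a two-view collate step turns a list of samples into two batches whose rows correspond: row `i` of the first batch and row `i` of the second are views of the same subject, which is how the loss identifies positive pairs. The sketch below is a hypothetical NumPy stand-in (`two_views_collate` is not part of nidl), assuming each dataset item is a single volume array; nidl's `TwoViewsCollateFunction` may differ in its details, for example by returning tensors or handling labels.

```python
import numpy as np


def two_views_collate(batch, transform):
    """Hypothetical: apply `transform` twice to each sample and stack each view."""
    # A random transform yields a different view on each call.
    view1 = np.stack([transform(sample) for sample in batch])
    view2 = np.stack([transform(sample) for sample in batch])
    return view1, view2


rng = np.random.default_rng(0)


def add_noise(x):
    """Toy random transform standing in for SimCLRTransform."""
    return x + rng.normal(scale=0.1, size=x.shape)


batch = [rng.normal(size=(1, 8, 8, 8)) for _ in range(4)]  # (channel, depth, height, width)
view1, view2 = two_views_collate(batch, add_noise)
print(view1.shape, view2.shape)  # (4, 1, 8, 8, 8) (4, 1, 8, 8, 8)
```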
