# A Simple Framework for Contrastive Learning of Visual Representations
* [paper](https://arxiv.org/abs/2002.05709)
* [weights](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv1?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=false)

In [1]:
import torch
import torch.nn as nn

SimCLR uses ResNet-50 for the encoder model.

In [34]:
from torchvision.models import resnet

class TrainHead(nn.Module):
    """Two-layer MLP head: ``ff2(relu(ff1(x)))``.

    The layer widths are configurable; the defaults reproduce the original
    fixed 2048 -> 2048 -> 2048 layout, so ``TrainHead()`` is unchanged.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Output width of the first linear layer.
        out_features: Output width of the second linear layer.
    """

    def __init__(
        self,
        in_features: int = 2048,
        hidden_features: int = 2048,
        out_features: int = 2048,
    ):
        super().__init__()
        self.ff1 = nn.Linear(in_features, hidden_features)
        self.ff2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear -> ReLU -> linear to ``x`` and return the result.

        No activation is applied after the second linear layer.
        """
        hidden = torch.relu(self.ff1(x))
        return self.ff2(hidden)

class LinearClassifierHead(nn.Module):
    """Placeholder head; no layers or forward pass are defined yet.

    NOTE(review): judging by the name, this is presumably meant to become a
    linear classifier on top of the encoder features -- confirm before
    implementing.
    """

    def __init__(self):
        # nn.Module keeps its parameter/submodule registries in state created
        # by its own __init__; without this call the instance breaks as soon
        # as it is called, moved, or switched with .train()/.eval().
        super().__init__()
        # TODO: define the classifier layers and a forward() method.

class Encoder(resnet.ResNet):
    """ResNet backbone whose forward pass ends at the pooled, flattened features.

    Construction is passed straight through to ``resnet.ResNet``; only
    ``forward`` is overridden.
    """

    def __init__(self, block, layers):
        super().__init__(block, layers)

    def forward(self, x):
        """Run ``x`` through the stem, the four layer groups and average pooling.

        The pooled activations are returned with every dimension after the
        first flattened into one.
        """
        # Modules are applied strictly in this order.
        pipeline = (
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.avgpool,
        )
        features = x
        for stage in pipeline:
            features = stage(features)
        return features.flatten(start_dim=1)

class SimCLR(nn.Module):
    """Encoder followed by an optional head module.

    Args:
        head: Module applied to the encoder output. When ``None`` (the
            default), ``forward`` returns the encoder output unchanged.
    """

    def __init__(self, head: "nn.Module | None" = None):
        super().__init__()
        # Bottleneck blocks with [3, 4, 6, 3] blocks per stage, i.e. the
        # ResNet-50 layout.
        self.encoder = Encoder(resnet.Bottleneck, [3, 4, 6, 3])
        self.head = head

    def forward(self, x):
        """Encode ``x`` and, if a head was supplied, pass the result through it."""
        features = self.encoder(x)
        # Identity test against None: `!= None` would dispatch to the head's
        # own comparison operator, which is not what is being asked here.
        if self.head is None:
            return features
        return self.head(features)

In [35]:
# Build the encoder + head model and switch it to evaluation mode
# (eval() returns the module itself, so the calls can be chained).
model = SimCLR(head=TrainHead()).eval()

# One random image-shaped input drawn from a standard normal distribution.
img = torch.normal(mean=0.0, std=1.0, size=(1, 3, 224, 224))
model(img).shape

torch.Size([1, 2048])