<a href="https://colab.research.google.com/github/bradley-ray/implementing-deep-learn-papers/blob/master/a_simple_framework_for_contrastive_learning_of_visual_representations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A Simple Framework for Contrastive Learning of Visual Representations
* [paper](https://arxiv.org/abs/2002.05709)
* [weights](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv1) (`gs://simclr-checkpoints/simclrv1`)

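SimCLR uses a ResNet-50 as its base encoder $f(\cdot)$: the representation $h = f(x)$ is the 2048-dimensional output of the final average-pooling layer, with the classification head removed.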

In [None]:
from typing import Optional

from torch import nn


class ResNetBlock(nn.Module):
    """Bottleneck residual block (1x1 -> 3x3 -> 1x1 conv) of ResNet-50.

    Every conv is followed by batch norm; ReLU is applied after the first two
    convs and after the residual addition. When the block changes the spatial
    resolution or the channel count, the shortcut is a strided 1x1 conv plus
    batch norm (projection shortcut) so ``out + identity`` is shape-compatible;
    otherwise the shortcut is the identity.

    Args:
        num_channels: Bottleneck width, i.e. channels of the 1x1 and 3x3 convs.
        out_channels: Channels produced by the block. Defaults to
            ``4 * num_channels`` (the ResNet-50 expansion factor).
        block_idx: Position of the block within its stage. Only used to pick
            the default ``stride`` (2 for the first block, 1 otherwise).
        in_channels: Channels of the input tensor. Defaults to
            ``out_channels``, which is right for every block except the first
            one of a stage.
        stride: Spatial stride of the block. Defaults to 2 when
            ``block_idx == 0`` and 1 otherwise.
    """

    def __init__(self, num_channels: int, out_channels: Optional[int] = None,
                 block_idx: int = 0, *, in_channels: Optional[int] = None,
                 stride: Optional[int] = None):
        super().__init__()
        if out_channels is None:
            out_channels = num_channels * 4
        if in_channels is None:
            in_channels = out_channels
        if stride is None:
            stride = 2 if block_idx == 0 else 1

        # Convs feed straight into batch norm, so their bias would be redundant.
        # Downsampling happens in the first 1x1 conv (original ResNet v1
        # placement; torchvision's "v1.5" variant moves it to the 3x3 conv).
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=num_channels,
                               kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(num_channels)
        # padding=1 keeps the 3x3 conv from shrinking the feature map, which
        # would otherwise break the residual addition.
        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels,
                               kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.conv3 = nn.Conv2d(in_channels=num_channels, out_channels=out_channels,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

        if stride != 1 or in_channels != out_channels:
            # Projection shortcut: match the main path's shape before adding.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape (N, in_channels, H, W)."""
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # Out-of-place add so no tensor autograd may still need is mutated.
        return self.relu(out + identity)

    def load_weights(self, weights):
        """Load pretrained weights into this block (not implemented yet)."""
        raise NotImplementedError("ResNetBlock.load_weights is not implemented yet")


class ResNet50(nn.Module):
    """ResNet-50 encoder used as SimCLR's base network.

    Maps images of shape (N, 3, H, W) to representations of shape (N, 2048),
    taken after the global average-pooling layer (no classification head).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Adaptive pooling yields a 1x1 map for any input size; for the
        # standard 224x224 input it equals AvgPool2d(kernel_size=7).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Arguments: input channels, bottleneck width, block count, first stride.
        # conv2_x uses stride 1 because the max pool has already downsampled.
        self.conv2_x = self._make_stage(64, 64, num_blocks=3, stride=1)
        self.conv3_x = self._make_stage(256, 128, num_blocks=4, stride=2)
        self.conv4_x = self._make_stage(512, 256, num_blocks=6, stride=2)
        self.conv5_x = self._make_stage(1024, 512, num_blocks=3, stride=2)

    @staticmethod
    def _make_stage(in_channels, num_channels, num_blocks, stride):
        """Build one stage: a (possibly downsampling) first block plus identity blocks."""
        first = ResNetBlock(num_channels, block_idx=0,
                            in_channels=in_channels, stride=stride)
        rest = [ResNetBlock(num_channels, block_idx=i) for i in range(1, num_blocks)]
        return nn.Sequential(first, *rest)

    def forward(self, x):
        """Encode ``x`` (N, 3, H, W) into a (N, 2048) representation."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.max_pool(out)
        out = self.conv2_x(out)
        out = self.conv3_x(out)
        out = self.conv4_x(out)
        out = self.conv5_x(out)
        out = self.avg_pool(out)
        # Drop the trailing 1x1 spatial dims to get one vector per image.
        return out.flatten(1)

    def load_weights(self, weights):
        """Load pretrained weights into the encoder (not implemented yet)."""
        raise NotImplementedError("ResNet50.load_weights is not implemented yet")