# A Simple Framework for Contrastive Learning of Visual Representations
* [paper](https://arxiv.org/abs/2002.05709)
* [weights](https://console.cloud.google.com/storage/browser/simclr-checkpoints/simclrv1?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&prefix=&forceOnObjectsSortingFiltering=false)

## Notes for weights
- pytorch: downsample.0 is conv1x1, downsample.1 is batchnorm
- 3,4,6,3 is the number of blocks for each conv size
- total conv layers = 53 = $(3*3) + (4*3) + (6*3) + (3*3) + 4 + 1$
- resnet will have extra layers because we are not using all of them (the encoder stops after avgpool)
- tf uses '/' while pytorch uses '.'
- tf uses 'batch_normalization_{num}/moving_...' (eww) while pytorch uses 'bn{num}.running...' for batch norm statistics
- tf uses numbering from 0..52 (again, eww) while pytorch is based on layers, so a working conversion is needed

In [1]:
import torch
import torch.nn as nn

## Getting and Converting Weights

In [229]:
import tensorflow as tf
# convert simclr tf_weights to pytorch_weights

def to_pytorch(tf_weights, model_dict):
    """Convert SimCLR TensorFlow checkpoint tensors into a PyTorch state dict.

    Every TF variable name is rewritten to the matching PyTorch parameter
    name ('/' -> '.', 'batch_normalization' -> 'bn', flat conv/bn index ->
    'layerN.block.convK' ...) and the tensor layout is adjusted where the two
    frameworks disagree.

    Args:
        tf_weights: mapping from TF variable name to a NumPy array (anything
            accepted by ``torch.from_numpy``). Names containing 'Momentum'
            are skipped.
        model_dict: container of the valid PyTorch parameter names (e.g. the
            model's ``state_dict()``); only membership is checked.

    Returns:
        dict mapping PyTorch parameter names to converted tensors.

    Raises:
        AssertionError: if a converted name is not in ``model_dict`` or if
            two TF variables map to the same PyTorch name.
    """
    # Applied in order; later rules rely on earlier ones having run
    # (e.g. the head rule expects '/' to already be '.').
    renames = (
        ('/', '.'),
        ('base_model', 'encoder'),
        ('kernel', 'weight'),
        ('moving', 'running'),
        ('gamma', 'weight'),
        ('beta', 'bias'),
        ('variance', 'var'),
        ('batch_normalization', 'bn'),
        ('conv2d', 'conv'),
        ('head_supervised.linear_layer.dense', 'head.ff1'),
    )

    torch_weights = dict()
    for tf_name in tf_weights:
        # Optimizer slot variables are not model parameters.
        if 'Momentum' in tf_name:
            continue

        weights = torch.from_numpy(tf_weights[tf_name])

        layer = tf_name
        for old, new in renames:
            layer = layer.replace(old, new)

        beg, _, end = layer.partition('_')
        if end != '' and end[0].isdigit():
            # Numbered conv/bn inside a residual stage, e.g. 'encoder.conv_12.weight'.
            prefix, _ = beg.split('.')
            num, end_ = end.split('.')
            num = int(num)
            layer_idx = 'layer'
            # Map the flat TF index to (stage, index within the stage).
            if num <= 10:   # layer 1 (10 = 3blocks * 3convs + 1)
                layer_idx += '1'
            elif num <= 23: # layer 2 (23 = prev + 4blocks * 3convs + 1)
                num -= 10
                layer_idx += '2'
            elif num <= 42: # layer 3 (42 = prev + 6blocks * 3convs + 1)
                num -= 23
                layer_idx += '3'
            else:           # layer 4 (52 = prev + 3blocks * 3convs + 1)
                num -= 42
                layer_idx += '4'

            # The first entry of each stage is the shortcut projection of block 0.
            if num == 1 and 'conv' in layer:
                layer = '.'.join([prefix, layer_idx, '0', 'downsample.0', end_])
            elif num == 1 and 'bn' in layer:
                layer = '.'.join([prefix, layer_idx, '0', 'downsample.1', end_])
            else:
                # The remaining entries come in groups of three per block.
                num -= 2
                block, idx = divmod(num, 3)
                kind = 'conv' if 'conv' in layer else 'bn'
                layer = '.'.join([prefix, layer_idx, str(block), kind + str(idx + 1), end_])

        elif 'conv.' in layer:
            # Un-numbered stem convolution.
            layer = layer.replace('conv.', 'conv1.')
        elif 'bn.' in layer:
            # Un-numbered stem batch norm.
            layer = layer.replace('bn.', 'bn1.')

        # Reorder axes: conv kernels go from axes (0, 1, 2, 3) to (3, 2, 0, 1).
        if 'conv' in layer or 'downsample.0' in layer:
            weights = weights.permute(3, 2, 0, 1)
        elif 'ff' in layer and weights.dim() == 2:
            # Only the 2-D kernel needs transposing; `.T` on the 1-D bias is
            # deprecated in PyTorch and would be a no-op anyway.
            weights = weights.T

        assert layer in model_dict and layer not in torch_weights, (
            f"TF variable {tf_name!r} mapped to {layer!r}, which is either "
            "unknown to the model or was already converted"
        )
        torch_weights[layer] = weights
    return torch_weights

**TODO**: probably should turn this into a python function

In [None]:
# finetune-10%, 1x width
# Remove any existing ./1x directory, then recursively copy the checkpoint
# directory from the simclr-checkpoints bucket into the current directory.
!rm -r 1x
!gsutil -m cp -r \
  "gs://simclr-checkpoints/simclrv1/finetune_10pct/1x" \
  .

In [None]:
# finetune-100%, 1x width
# Remove any existing ./1x directory, then recursively copy the checkpoint
# directory from the simclr-checkpoints bucket into the current directory.
!rm -r 1x
!gsutil -m cp -r \
  "gs://simclr-checkpoints/simclrv1/finetune_100pct/1x" \
  .

In [230]:
# pretrained weights, 1x width
# Remove any existing ./1x directory, then recursively copy the checkpoint
# directory from the simclr-checkpoints bucket into the current directory.
!rm -r 1x
!gsutil -m cp -r \
  "gs://simclr-checkpoints/simclrv1/pretrain/1x" \
  .

Copying gs://simclr-checkpoints/simclrv1/pretrain/1x/checkpoint...
/ [0/9 files][    0.0 B/213.3 MiB]   0% Done                                    Copying gs://simclr-checkpoints/simclrv1/pretrain/1x/graph.pbtxt...
/ [0/9 files][    0.0 B/213.3 MiB]   0% Done                                    Copying gs://simclr-checkpoints/simclrv1/pretrain/1x/hub/saved_model.pb...
/ [0/9 files][    0.0 B/213.3 MiB]   0% Done                                    Copying gs://simclr-checkpoints/simclrv1/pretrain/1x/hub/variables/variables.index...
/ [0/9 files][    0.0 B/213.3 MiB]   0% Done                                    Copying gs://simclr-checkpoints/simclrv1/pretrain/1x/model.ckpt-225206.data-00000-of-00001...
/ [0/9 files][    0.0 B/213.3 MiB]   0% Done                                    Copying gs://simclr-checkpoints/simclrv1/pretrain/1x/model.ckpt-225206.index...
/ [0/9 files][    0.0 B/213.3 MiB]   0% Done                                    Copying gs://simclr-checkpoints/simclrv1/pre

SimCLR uses ResNet-50 for the encoder model.

In [231]:
from torchvision.models import resnet

class TrainHead(nn.Module):
    """Two-layer MLP head: Linear(2048, 2048) -> ReLU -> Linear(2048, 2048)."""

    def __init__(self):
        super().__init__()
        # The attribute names are also the state-dict key prefixes
        # ('ff1.*', 'ff2.*'), so they must stay as they are.
        self.ff1 = nn.Linear(2048, 2048)
        self.ff2 = nn.Linear(2048, 2048)

    def forward(self, x):
        """Apply ff1, a ReLU, then ff2 to ``x`` and return the result."""
        hidden = self.ff1(x).relu()
        return self.ff2(hidden)

class LinearClassifierHead(nn.Module):
    """Single linear layer mapping 2048 input features to 1000 outputs."""

    def __init__(self):
        super().__init__()
        # Stored as 'ff1', so its state-dict keys are 'ff1.weight' / 'ff1.bias'.
        self.ff1 = nn.Linear(in_features=2048, out_features=1000)

    def forward(self, x):
        """Return ``ff1`` applied to ``x``."""
        return self.ff1(x)

class Encoder(resnet.ResNet):
    """ResNet backbone whose forward pass stops after average pooling."""

    def __init__(self, block, layers):
        # Forward the block type and per-stage block counts to the base class.
        super().__init__(block, layers)

    def forward(self, x):
        """Run stem, the four residual stages and avgpool; return flattened features."""
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3, self.layer4,
            self.avgpool,
        )
        out = x
        for stage in pipeline:
            out = stage(out)
        # Collapse every dimension after the batch dimension.
        return torch.flatten(out, 1, -1)

class SimCLR(nn.Module):
    """Encoder built from Bottleneck blocks in a [3, 4, 6, 3] layout, plus an optional head.

    Args:
        head: module applied to the encoder output, or ``None`` to return
            the encoder features unchanged.
    """

    def __init__(self, head: nn.Module=None):
        super().__init__()
        self.encoder = Encoder(resnet.Bottleneck, [3,4,6,3])
        self.head = head

    def forward(self, x):
        """Encode ``x`` and, if a head is set, pass the features through it."""
        out = self.encoder(x)
        if self.head is not None:
            out = self.head(out)
        return out

    def load_pretrained_weights(self, checkpoint_dir='1x'):
        """Load a TensorFlow SimCLR checkpoint into this model.

        Args:
            checkpoint_dir: location handed to ``tf.train.list_variables`` and
                ``tf.train.load_checkpoint``. Defaults to '1x'.

        Entries of this model's state dict that the checkpoint does not
        provide keep their current values instead of being overwritten.
        """
        params = self.state_dict()

        reader = tf.train.load_checkpoint(checkpoint_dir)
        # list_variables yields (name, shape) pairs; only the name is needed.
        tf_tensors = {
            name: reader.get_tensor(name)
            for name, _shape in tf.train.list_variables(checkpoint_dir)
            if name != 'global_step'
        }

        loaded_params = to_pytorch(tf_tensors, params)
        # load_state_dict needs every key. Fill the gaps with the current
        # values rather than zeros so weights absent from the checkpoint
        # keep their initialisation.
        for name, value in params.items():
            loaded_params.setdefault(name, value)

        self.load_state_dict(loaded_params)


In [232]:
# Build the model with the 1000-way linear classifier head; swap in the
# commented-out line below to use the MLP TrainHead instead.
# model = SimCLR(TrainHead())
model = SimCLR(LinearClassifierHead())
# Reads the TensorFlow checkpoint from the local '1x' directory.
model.load_pretrained_weights()
model.eval()

# Smoke test: push one random (1, 3, 224, 224) input through the network and
# display the output shape.
img = torch.normal(0, 1, (1, 3, 224, 224))
model(img).shape

torch.Size([1, 1000])