In [1]:
"Incase you don't have these."
!pip install simclr
!pip install timm
!pip install torchtyping
!pip install gdown

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting simclr
  Downloading simclr-1.0.2-py3-none-any.whl (21 kB)
Installing collected packages: simclr
Successfully installed simclr-1.0.2
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.6.11-py3-none-any.whl (548 kB)
[K     |████████████████████████████████| 548 kB 6.0 MB/s 
[?25hCollecting huggingface-hub
  Downloading huggingface_hub-0.11.0-py3-none-any.whl (182 kB)
[K     |████████████████████████████████| 182 kB 37.6 MB/s 
Installing collected packages: huggingface-hub, timm
Successfully installed huggingface-hub-0.11.0 timm-0.6.11
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchtyping
  Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)
Collecting typeguard>=2.11.1
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB

In [2]:
"Downloading the dog-dataset" 
!gdown --folder https://drive.google.com/drive/folders/1oudus89CoG9_7r3twbIhw2K_pgVr0D0X

Retrieving folder list
Retrieving folder 1lc8u7_LGR9s2QrWtlbNfL0Ti2UhfJj7T train
Retrieving folder 1JdkS9eMaunQtoZ8HpH1wnnNYqBsMk9Vy n02102040
Processing file 1vC-vCzEjUG1_Jvb-J4rV7FfG3w1S7igK ILSVRC2012_val_00000665.JPEG
Processing file 1Ak_Lw6B3tVS-yNfQx1gVWH-3HoXzkZSN ILSVRC2012_val_00001968.JPEG
Processing file 18XRIAErRiNh0aNe1Fjt1xfmRgAobFHHp ILSVRC2012_val_00002294.JPEG
Processing file 1yMMn4Ljb0uafb654x-WXsK7LfWYjz37V ILSVRC2012_val_00002315.JPEG
Processing file 1JQRfqwAa1Jud8IY8bDR7MxDAUaKChGIq ILSVRC2012_val_00004548.JPEG
Processing file 1M2Mh4s8igoz17Idmg75E0gu5ty8llJ3B ILSVRC2012_val_00004553.JPEG
Processing file 1z1Xz4VMW49Ma_oNU999qnESmspCewqcf ILSVRC2012_val_00007568.JPEG
Processing file 1WXHiWi_khf7_aTSN3B_fbY9o78NRGGSO ILSVRC2012_val_00008334.JPEG
Processing file 1jHIDgQSIj0-5P9sfGDXKQLzzrGlANQAY ILSVRC2012_val_00010994.JPEG
Processing file 1vqOeAMQSKbQ1x7EaXqVrd4A9BIAQr3mN ILSVRC2012_val_00012689.JPEG
Processing file 1s_agMj-u3rQhfMS5aYrIMcfRICPuQXOU ILSVRC2012_val_00

In [1]:
import torch
import torchvision

import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms

from simclr.modules.transformations import TransformsSimCLR
from simclr.modules import NT_Xent
from simclr.modules.identity import Identity

from torchtyping import TensorType, patch_typeguard
from torch.utils.data.dataloader import DataLoader
from typing import Callable, Iterator, Dict
from typeguard import typechecked

from matplotlib import pyplot as plt

from timm.data import create_dataset

patch_typeguard()  # use before @typechecked

In [2]:
"For SimCLRv1"
class SimCLR(nn.Module):

    def __init__(self, encoder, projection_dim, n_features):
        super(SimCLR, self).__init__()

        self.encoder = encoder
        self.n_features = n_features

        self.encoder.fc = Identity()
        
        self.projector = nn.Sequential(
            nn.Linear(self.n_features, self.n_features, bias=False, device = DEVICE),
            nn.ReLU(),
            nn.Linear(self.n_features, projection_dim, bias=False, device = DEVICE),
        )

    def forward(self, x_i, x_j):
        
        
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)
        
        z_i = self.projector(h_i)
        z_j = self.projector(h_j)
        
        x_i = self.encoder.conv1(x_i)
        x_i = self.encoder.bn1(x_i)
        x_i = self.encoder.relu(x_i)
        x_i = self.encoder.maxpool(x_i)
        
        #Extract the features from intermediate layers
        h_1 = self.encoder.layer1(x_i)
        h_2 = self.encoder.layer2(h_1)
        h_3 = self.encoder.layer3(h_2)
        h_4 = self.encoder.layer4(h_3)
        
        h_1.norm(dim = 1, p = 2)
        h_2.norm(dim = 1, p = 2)
        h_3.norm(dim = 1, p = 2)
        h_4.norm(dim = 1, p = 2)
        
        return h_i, h_j, z_i, z_j, h_1, h_2, h_3, h_4

In [3]:
# --- Experiment configuration ---
image_size = 448      # images are resized to image_size x image_size before encoding
batch_size = 10       # images per batch (also used for the feature-block loaders)
projection_dim = 128  # projection-head width: 128 for Imagenette, 64 for CIFAR-10

In [4]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
"Code below loads the imagenette data set."
tim_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])

imagenet_ds_test = create_dataset(name = '', root = 'imagenetmini', 
                                    transform = tim_transform)

img_testloader = torch.utils.data.DataLoader(
    imagenet_ds_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=2,
)

In [6]:
simclr = 'v2' #Available strings: 'v1', 'v2', False
#False is going to download a supervised contrastive learning model

!git clone https://github.com/tonylins/simclr-converter.git v1
!git clone https://github.com/Separius/SimCLRv2-Pytorch.git v2

from v2.resnet import ResNet, ContrastiveHead

class FeatureResNet(ResNet):
    """SimCLRv2 ResNet that also returns the output of each module in `self.net`.

    forward() returns (h, X[1], X[2], X[3], X[4]), where X[k] is the output of the
    k-th module of `self.net` (X[0] is not returned) and h is the last output
    averaged over dims 2 and 3, optionally passed through `self.fc`.
    """

    def forward(self, x, apply_fc=False, normalize=False):
        X = []
        for module in self.net:
            x = module(x)
            X.append(x)

        # Pool the last output instead of calling self.net(x) a second time.
        # Assumes self.net is a sequential container (it is both iterated and
        # called in ResNet), so chaining its modules equals self.net(x).
        h = X[-1].mean(dim=[2, 3])
        if apply_fc:
            h = self.fc(h)

        b1, b2, b3, b4 = X[1], X[2], X[3], X[4]
        if normalize:
            # Optional L2 normalisation along dim 1; off by default (raw outputs).
            b1, b2, b3, b4 = (
                nn.functional.normalize(b, p=2, dim=1) for b in (b1, b2, b3, b4)
            )
        return h, b1, b2, b3, b4
    
def get_resnet(depth=50, width_multiplier=1, sk_ratio=0):  # sk_ratio=0.0625 is recommended
    """Build a FeatureResNet of the given depth together with its ContrastiveHead.

    Raises KeyError for a depth other than 50, 101, 152 or 200.
    """
    blocks_per_stage = {
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
        200: [3, 24, 36, 3],
    }
    backbone = FeatureResNet(blocks_per_stage[depth], width_multiplier, sk_ratio)
    head = ContrastiveHead(backbone.channels_out)
    return backbone, head

if simclr == 'v1':
    !cd v1 && gdown --folder https://drive.google.com/drive/folders/1AlT5YcS1JZFA2JZU2zwoVkdIXFPSnATQ
    !cd v1 && python convert.py ResNet50_2x/model.ckpt-225206 resnet50-2x.pth
    
    from v1.resnet_wider import resnet50x2
    encoder = resnet50x2()
    sd = torch.load('v1/resnet50-2x.pth', map_location=DEVICE)
    encoder.load_state_dict(sd['state_dict'])
    n_features = encoder.fc.in_features  # get dimensions of last fully-connected layer
    model = SimCLR(encoder, projection_dim, n_features).to(DEVICE)
    
elif simclr == 'v2':
    !cd v2 && python download.py r50_2x_sk1
    !cd v2 && python convert.py r50_2x_sk1/model.ckpt-250228 --ema


    model, _= get_resnet(50, width_multiplier = 2, sk_ratio = 0.0625)
    pth_path = 'v2/r50_2x_sk1_ema.pth'
    model.load_state_dict(torch.load(pth_path)['resnet'])
    model.to(DEVICE)
else:
    !cd v2 && python download.py r50_2x_sk1 --simclr_category=supervised
    !cd v2 && python convert.py r50_2x_sk1/model.ckpt-250228 --supervised    
        
    model, _= get_resnet(50, width_multiplier = 2, sk_ratio = 0.0625)
    pth_path = 'v2/r50_2x_sk1.pth'
    model.load_state_dict(torch.load(pth_path)['resnet'])
    model.to(DEVICE)

fatal: destination path 'v1' already exists and is not an empty directory.
fatal: destination path 'v2' already exists and is not an empty directory.

  0%|          | 0/5 [00:00<?, ?it/s]
100%|##########| 5/5 [00:00<?, ?it/s]


In [7]:
def inference(loader, simclr_model, device, version=None):
    """Run `simclr_model` over `loader` and collect its four intermediate feature blocks.

    Args:
        loader: iterable of (x, y) batches; `len(loader)` is used for progress output.
        simclr_model: for version 'v1', a module called as model(x, x); otherwise a
            module called as model(x). In both cases the last four outputs are
            taken as the feature blocks.
        device: device each input batch is moved to.
        version: 'v1', 'v2' or False; defaults to the notebook-level `simclr` setting.

    Returns:
        (labels, blocks): labels as a numpy array, and a list of four numpy arrays,
        each stacking that block's outputs along the sample axis.
    """
    if version is None:
        version = simclr

    labels_vector = []
    feature_vector_blocks = [[] for _ in range(4)]

    # Pretrained feature extraction must use the stored BatchNorm statistics rather
    # than per-batch ones, so run in eval mode and restore the caller's mode after.
    was_training = simclr_model.training
    simclr_model.eval()
    try:
        for step, (x, y) in enumerate(loader):
            x = x.to(device)

            with torch.no_grad():
                outputs = simclr_model(x, x) if version == 'v1' else simclr_model(x)

            for store, block in zip(feature_vector_blocks, outputs[-4:]):
                store.extend(block.detach().cpu().numpy())

            labels_vector.extend(y.numpy())

            if step % 20 == 0:
                print(f"Step [{step}/{len(loader)}]\t Computing features...")
    finally:
        simclr_model.train(was_training)

    return np.array(labels_vector), [np.array(store) for store in feature_vector_blocks]

def get_features(context_model, test_loader, device):
    """Extract labels and feature blocks for every batch of `test_loader`.

    Thin wrapper over `inference`; returns its (labels, blocks) pair unchanged.
    """
    return inference(test_loader, context_model, device)


def create_data_loaders_from_arrays(y_test, block_test, batch_size, B=True):
    """Pair each feature block with the labels as a TensorDataset, optionally wrapped in a DataLoader.

    Args:
        y_test: numpy label array, one entry per sample.
        block_test: sequence of numpy feature arrays (one per encoder stage), each
            with the sample axis first and the same length as `y_test`.
        batch_size: batch size of the returned loaders.
        B: if true, return DataLoaders (fixed order, trailing partial batch dropped,
            2 worker processes); otherwise return the TensorDatasets themselves.

    Returns:
        A tuple with one DataLoader (or TensorDataset) per entry of `block_test`.
    """
    labels = torch.from_numpy(y_test)
    datasets = tuple(
        torch.utils.data.TensorDataset(torch.from_numpy(block), labels)
        for block in block_test
    )
    if not B:
        return datasets

    return tuple(
        torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            drop_last=True,
            num_workers=2,
        )
        for dataset in datasets
    )

In [8]:
# B selects the container type for the feature blocks:
# True -> DataLoaders (batched), anything else -> raw TensorDatasets.
B = True

print("### Creating features from pre-trained context model ###")
test_y, test_blocks = get_features(model, img_testloader, DEVICE)
block_1, block_2, block_3, block_4 = create_data_loaders_from_arrays(
    test_y, test_blocks, batch_size=batch_size, B=B
)

### Creating features from pre-trained context model ###
Step [0/1]	 Computing features...


In [9]:
# Save one batch of the resized input images to dog.npy, channels-last (N, H, W, 3).
# B=True: the first batch. Otherwise: batch 23 of img_testloader.
# NOTE(review): with B not True the feature-block cells pick *sample* 23 of a
# TensorDataset, not batch 23 — confirm which pairing is intended.
target_step = 0 if B == True else 23

for step, (images, _) in enumerate(img_testloader):
    if step != target_step:
        continue

    print('images.shape:', images.shape)
    image = images.cpu().detach().numpy()
    image = np.moveaxis(image, 1, -1)  # (N, 3, H, W) -> (N, H, W, 3)
    print('image.shape:', image.shape)
    np.save('dog.npy', image)
    break
else:
    # Without this, a too-short loader would silently leave dog.npy unwritten.
    print(f'img_testloader has no batch {target_step}; dog.npy was not written.')

images.shape: torch.Size([10, 3, 448, 448])
image.shape: (10, 448, 448, 3)


In [10]:
# Save one item of each intermediate feature block to '<prefix>_block<k>.npy',
# with the feature-channel axis moved last.
#   B=True : block_k are DataLoaders yielding (N, C, H, W) batches; the first batch is saved.
#   else   : block_k are TensorDatasets yielding single (C, H, W) samples; sample 23 is saved.
prefix = simclr if simclr in ('v1', 'v2') else 'supervised'
target_step = 0 if B == True else 23

for block_idx, block in enumerate((block_1, block_2, block_3, block_4), start=1):
    for step, (images, _) in enumerate(block):
        if step != target_step:
            continue

        print('images.shape:', images.shape)
        image = images.cpu().detach().numpy()

        # Feature channels sit at axis -3 for both a batch (N, C, H, W) and a single
        # sample (C, H, W); a fixed axis 1 would swap H/W in the unbatched case.
        image = np.moveaxis(image, -3, -1)
        print('image.shape:', image.shape)

        np.save(f'{prefix}_block{block_idx}.npy', image)
        break
    else:
        # Without this, a too-short block would silently produce no file.
        print(f'block_{block_idx} has no item at index {target_step}; nothing saved.')

images.shape: torch.Size([10, 512, 112, 112])
image.shape: (10, 112, 112, 512)
images.shape: torch.Size([10, 1024, 56, 56])
image.shape: (10, 56, 56, 1024)
images.shape: torch.Size([10, 2048, 28, 28])
image.shape: (10, 28, 28, 2048)
images.shape: torch.Size([10, 4096, 14, 14])
image.shape: (10, 14, 14, 4096)
