### SimCLR

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# https://zablo.net/blog/post/understanding-implementing-simclr-guide-eli5-pytorch/
class ContrastiveLoss(nn.Module):
    def __init__(self, batch_size, verbose=False, temperature=0.5):
        super().__init__()
        self.verbose = verbose
        self.batch_size = batch_size
        self.register_buffer("temperature", torch.tensor(temperature))
        self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float())
        if self.verbose:
            print(f"negatives_mask: {self.negatives_mask}")
            print(f"negatives_mask shape: {self.negatives_mask.shape}")
            
    def forward(self, emb_i, emb_j):
        """
        emb_i and emb_j are batches of embeddings, where corresponding indices are pairs
        z_i, z_j as per SimCLR paper
        """
        z_i = F.normalize(emb_i, dim=1)
        z_j = F.normalize(emb_j, dim=1)

        representations = torch.cat([z_i, z_j], dim=0)
        
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1), representations.unsqueeze(0), dim=2)
        
        sim_ij = torch.diag(similarity_matrix, self.batch_size)
        sim_ji = torch.diag(similarity_matrix, -self.batch_size)
        
        positives = torch.cat([sim_ij, sim_ji], dim=0)
        
        if self.verbose:
            print(f"representations: {representations.unsqueeze(1).shape}")
            print(f"representations: {representations.unsqueeze(0).shape}")
            print(f"similarity_matrix: {similarity_matrix.shape}")
            print(f"similarity_matrix: {similarity_matrix}")
            print(f"sim_ij: {sim_ij}")
            print(f"sim_ji: {sim_ji}")
            print(f"positives: {positives.shape}")
        
        nominator = torch.exp(positives / self.temperature)
        denominator = self.negatives_mask * torch.exp(similarity_matrix / self.temperature)
    
        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / (2 * self.batch_size)
        return loss

In [3]:
batch_size = 4
feature_dim = 512
emb_1, emb_2 = torch.rand(batch_size, feature_dim), torch.rand(batch_size, feature_dim)

In [4]:
simclr_loss = ContrastiveLoss(batch_size, verbose=True)

negatives_mask: tensor([[0., 1., 1., 1., 1., 1., 1., 1.],
        [1., 0., 1., 1., 1., 1., 1., 1.],
        [1., 1., 0., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 1., 1., 1., 0., 1., 1.],
        [1., 1., 1., 1., 1., 1., 0., 1.],
        [1., 1., 1., 1., 1., 1., 1., 0.]])
negatives_mask shape: torch.Size([8, 8])


In [5]:
loss = simclr_loss(emb_1, emb_2)

representations: torch.Size([8, 1, 512])
representations: torch.Size([1, 8, 512])
similarity_matrix: torch.Size([8, 8])
similarity_matrix: tensor([[1.0000, 0.7604, 0.7783, 0.7217, 0.7632, 0.7457, 0.7285, 0.7567],
        [0.7604, 1.0000, 0.7543, 0.7191, 0.7400, 0.7507, 0.7658, 0.7726],
        [0.7783, 0.7543, 1.0000, 0.7375, 0.7657, 0.7597, 0.7595, 0.7669],
        [0.7217, 0.7191, 0.7375, 1.0000, 0.7286, 0.7354, 0.7313, 0.7466],
        [0.7632, 0.7400, 0.7657, 0.7286, 1.0000, 0.7539, 0.7424, 0.7580],
        [0.7457, 0.7507, 0.7597, 0.7354, 0.7539, 1.0000, 0.7514, 0.7562],
        [0.7285, 0.7658, 0.7595, 0.7313, 0.7424, 0.7514, 1.0000, 0.7578],
        [0.7567, 0.7726, 0.7669, 0.7466, 0.7580, 0.7562, 0.7578, 1.0000]])
sim_ij: tensor([0.7632, 0.7507, 0.7595, 0.7466])
sim_ji: tensor([0.7632, 0.7507, 0.7595, 0.7466])
positives: torch.Size([8])


In [6]:
loss

tensor(1.9368)

In [7]:
import timm

class SimCLR_Encoder(nn.Module):
    def __init__(self, backbone_name, in_dim, out_dim, pretrained=False):
        super().__init__()
        # simclr backbone
        self.simclr_backbone = timm.create_model(backbone_name, pretrained=pretrained)
        # simclr projector
        in_features = self.simclr_backbone.fc.in_features
        self.simclr_backbone.fc = nn.Identity()
        self.simclr_projector = nn.Sequential(
            nn.Linear(in_features, in_dim, bias=True),
            nn.BatchNorm1d(in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, in_dim, bias=True),
            nn.BatchNorm1d(in_dim),
            nn.ReLU(),
            nn.Linear(in_dim, out_dim, bias=False),
        )
    
    def forward(self, x):
        x = self.simclr_backbone(x)
        out = self.simclr_projector(x)
        return out

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
model = SimCLR_Encoder('resnet18', 512, 128)

In [9]:
model

SimCLR_Encoder(
  (simclr_backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act1): ReLU(inplace=True)
        (aa): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [10]:
from tqdm import tqdm
import requests
import gzip
import os
import numpy as np
def download_mnist(url,file_dict=None):
    if file_dict is not None:
        mnist_data=list()
        try:
            for i, key in enumerate(file_dict.keys()):    
                fname = file_dict[key]
                url = os.path.join(url_root,fname)                
                isExist = os.path.exists(fname)
                if not isExist:
                    response = requests.get(url, stream=True)
                    fsize=len(response.content)
                    print(url)
                    with open(fname, 'wb') as fout:
                        for data in tqdm(response.iter_content(), desc =fname, total=fsize):
                            fout.write(data)
                
                with gzip.open(fname, "rb") as f_in:                
                    if fname.find('idx3') != -1:        
                        mnist_data.append(np.frombuffer(f_in.read(), np.uint8, offset=16).reshape(-1, 28, 28)) #if images        
                    else:                               
                        mnist_data.append(np.frombuffer(f_in.read(), np.uint8, offset=8))  #if labels
            #return mnist_data in a list format ==> [[train_images], [train_labels], [test_images], [test_labels]] 
            return mnist_data
        except Exception as e:
            print("Something went wrong:", e)
    else:
        print("file_dict cannot be None")

In [11]:
url_root = 'http://yann.lecun.com/exdb/mnist'
file_dict={
    'train_images':'train-images-idx3-ubyte.gz',
    'train_labels':'train-labels-idx1-ubyte.gz',
    'test_images':'t10k-images-idx3-ubyte.gz',
    'test_labels':'t10k-labels-idx1-ubyte.gz'
}
dataset= download_mnist(url_root,file_dict)

In [12]:
train_images=dataset[0]
#train_labels=dataset[1]
val_images=dataset[2]
#test_labels=dataset[3]

In [13]:
from PIL import ImageOps, ImageFilter, Image

In [14]:
class GaussianBlur(object):
    def __init__(self, p, sigma_min, sigma_max):
        self.p = p
        self.sigma_min, self.sigma_max = sigma_min, sigma_max
    def __call__(self, img):
        if np.random.rand() < self.p:
            sigma = np.random.rand() * (self.sigma_max - self.sigma_min)+ self.sigma_min
            return img.filter(ImageFilter.GaussianBlur(sigma))
        else:
            return img

In [15]:
from torchvision import transforms
size = 28
data_transform = transforms.Compose([transforms.RandomResizedCrop(size=size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.RandomApply(
                                         [
                                             transforms.ColorJitter(
                                                 brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=0, hue=0,
                                             )
                                         ],
                                         p=0.8,
                                     ),
                                     transforms.RandomGrayscale(p=0.2),
                                     GaussianBlur(p=1.0, sigma_min=0.1, sigma_max=2.0),
                                     transforms.ToTensor()])

In [16]:
from torch.utils.data import Dataset, DataLoader

class MNISTCustomDataset(Dataset):
    def __init__(self, images, transform=None, transform_p=None):
        self.images = images
        self.transform = transform
        self.transform_p = transform_p
    def __getitem__(self, idx):
        image = Image.fromarray(self.images[idx]).convert("RGB")
        image_1 = self.transform(image)
        image_2 = self.transform_p(image)
        return image_1, image_2
    def __len__(self):
        return self.images.shape[0]

In [17]:
train_dataset = MNISTCustomDataset(train_images, transform=data_transform, transform_p=data_transform)
val_dataset = MNISTCustomDataset(val_images, transform=data_transform, transform_p=data_transform)

In [18]:
batch_size =  64
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2, pin_memory=True,
                          shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=2, pin_memory=True,
                          shuffle=True, drop_last=True)

In [19]:
from tqdm import trange
def train(model, optimizer, criterion, num_epochs, batch_size, train_loader, val_loader):
    train_losses, val_losses = [], []
    for epoch in trange(num_epochs):
        # train model
        model.train()
        accum_train_loss = 0
        for (x_1, x_2) in train_loader:
            out_1, out_2 = model(x_1.cuda()).squeeze(), model(x_2.cuda()).squeeze()
            loss = criterion(out_1, out_2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            accum_train_loss += loss.item()
        train_losses.append(accum_train_loss / len(train_loader))

        # eval model
        model.eval()
        accum_val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                x1, x2 = batch[0], batch[1]
                out_1, out_2 = model(x_1.cuda()).squeeze(), model(x_2.cuda()).squeeze()
                loss = criterion(out_1, out_2)
                accum_val_loss += loss.item()
            val_losses.append(accum_val_loss / len(val_loader))
            
        print(f"epoch {epoch+1} Train Loss: {accum_train_loss / len(train_loader)}, Validation Loss: {accum_val_loss / len(val_loader)}")
            
    return train_losses, val_losses

In [21]:
from torch.optim import Adam
model = SimCLR_Encoder('resnet18', 512, 128).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = ContrastiveLoss(batch_size).cuda()
num_epochs = 10
train(model, optimizer, criterion, num_epochs, batch_size, train_loader, val_loader)

 10%|█         | 1/10 [00:37<05:40, 37.87s/it]

epoch 1 Train Loss: 4.430632861726694, Validation Loss: 4.167326927185059


 20%|██        | 2/10 [01:15<05:01, 37.75s/it]

epoch 2 Train Loss: 4.047993498652569, Validation Loss: 3.7923147678375244


 30%|███       | 3/10 [01:53<04:24, 37.73s/it]

epoch 3 Train Loss: 3.894592859956218, Validation Loss: 3.635878086090088


 40%|████      | 4/10 [02:30<03:46, 37.71s/it]

epoch 4 Train Loss: 3.791337649366772, Validation Loss: 3.7594809532165527


 50%|█████     | 5/10 [03:08<03:08, 37.71s/it]

epoch 5 Train Loss: 3.7187681450024486, Validation Loss: 3.6511712074279785


 60%|██████    | 6/10 [03:46<02:30, 37.69s/it]

epoch 6 Train Loss: 3.6734036498543037, Validation Loss: 3.6855881214141846


 70%|███████   | 7/10 [04:23<01:52, 37.63s/it]

epoch 7 Train Loss: 3.631831737312716, Validation Loss: 3.5930111408233643


 80%|████████  | 8/10 [05:01<01:15, 37.66s/it]

epoch 8 Train Loss: 3.6076426801203154, Validation Loss: 3.5604779720306396


 90%|█████████ | 9/10 [05:39<00:37, 37.62s/it]

epoch 9 Train Loss: 3.5817133818008604, Validation Loss: 3.567700147628784


100%|██████████| 10/10 [06:16<00:00, 37.68s/it]

epoch 10 Train Loss: 3.5671773470834838, Validation Loss: 3.5599286556243896





([4.430632861726694,
  4.047993498652569,
  3.894592859956218,
  3.791337649366772,
  3.7187681450024486,
  3.6734036498543037,
  3.631831737312716,
  3.6076426801203154,
  3.5817133818008604,
  3.5671773470834838],
 [4.167326927185059,
  3.7923147678375244,
  3.635878086090088,
  3.7594809532165527,
  3.6511712074279785,
  3.6855881214141846,
  3.5930111408233643,
  3.5604779720306396,
  3.567700147628784,
  3.5599286556243896])