# Barlow Twins

Barlow Twins is a self-supervised learning method inspired by the redundancy-reduction principle from the work of neuroscientist H. Barlow. It is conceptually simple, easy to implement, and learns high-quality representations from noisy image data. The method was proposed by Jure Zbontar, Li Jing, Ishan Misra, Yann LeCun, and Stéphane Deny on 4th March 2021.

To read more about it, please refer to [this](https://analyticsindiamag.com/a-guide-to-barlow-twins-self-supervised-learning-via-redundancy-reduction/) article.

# Usage Code

A PyTorch implementation of Barlow Twins is made available by Facebook Research in their GitHub repository. The model takes a long time to train (7 days on 16 V100 GPUs). Fortunately, weights pretrained on ImageNet are available. Let's see how to use this model for image classification.


## Loading Pretrained Model

In [1]:
!git clone https://github.com/facebookresearch/barlowtwins.git

Cloning into 'barlowtwins'...
remote: Enumerating objects: 78, done.
remote: Counting objects: 100% (78/78), done.
remote: Compressing objects: 100% (67/67), done.
remote: Total 78 (delta 47), reused 26 (delta 11), pack-reused 0
Unpacking objects: 100% (78/78), done.


In [2]:
import torch
model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')

Downloading: "https://github.com/facebookresearch/barlowtwins/archive/main.zip" to /root/.cache/torch/hub/main.zip
Downloading: "https://dl.fbaipublicfiles.com/barlowtwins/ep1000_bs2048_lrw0.2_lrb0.0048_lambd0.0051/resnet50.pth" to /root/.cache/torch/hub/checkpoints/resnet50.pth






In [3]:
!mkdir data
!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip 

--2021-06-21 10:07:59--  https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 64.233.191.128, 173.194.74.128, 173.194.192.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|64.233.191.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 68606236 (65M) [application/zip]
Saving to: ‘cats_and_dogs_filtered.zip’


2021-06-21 10:08:00 (102 MB/s) - ‘cats_and_dogs_filtered.zip’ saved [68606236/68606236]



In [6]:
import zipfile

# Extract the archive into data/; the context manager closes the file automatically.
local_zip = 'cats_and_dogs_filtered.zip'
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
    zip_ref.extractall('data/')

In [7]:
!ls data/cats_and_dogs_filtered/validation

cats  dogs


## Preprocessing Input Images

In [8]:
from torchvision import datasets, transforms
import torch
# Standard ImageNet-style evaluation preprocessing: resize, center-crop, convert to a
# tensor, then normalize with the ImageNet channel means and standard deviations.
val_dataset = datasets.ImageFolder('data/cats_and_dogs_filtered/validation', transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
]))
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=32)

In [13]:
val_dataset

Dataset ImageFolder
    Number of datapoints: 1000
    Root location: /content/data/cats_and_dogs_filtered/validation
    StandardTransform
Transform: Compose(
               Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )
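
`ImageFolder` derives integer labels from the sub-folder names: the names are sorted and numbered from 0. The snippet below is a stand-alone sketch that mimics this rule using only the standard library. `find_classes` here is a local helper written for illustration, not the torchvision method; on a real dataset the same information is available through `val_dataset.classes` and `val_dataset.class_to_idx`.

```python
import os
import tempfile


def find_classes(root):
    """Mimic ImageFolder's labeling rule: sorted sub-folder names map to 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Build a throwaway directory tree with the same two folders as the validation split.
with tempfile.TemporaryDirectory() as root:
    for name in ("dogs", "cats"):
        os.makedirs(os.path.join(root, name))
    classes, class_to_idx = find_classes(root)
    print(classes)       # ['cats', 'dogs']
    print(class_to_idx)  # {'cats': 0, 'dogs': 1}
```

Because the loader above is not shuffled, the first batches contain only images from the `cats` folder, i.e. label 0.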

## Inference

In [12]:
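from tqdm.notebook import tqdm  # tqdm.tqdm_notebook is deprecated
import numpy as np
preds=[]
model.eval()  # evaluation mode: batch norm uses its running statistics
with torch.no_grad():  # gradients are not needed for inference
  for images,labels in tqdm(val_loader):
    preds.extend(model(images).numpy())
    print(preds[-1].shape)  # one score vector per image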

(1000,)
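
One caveat worth checking before interpreting these scores: the repository's `hubconf.py` appears to load the checkpoint with `strict=False`, and the published `resnet50.pth` seems to contain backbone weights only. If that is the case, the final `fc` layer is randomly initialized and the 1000 outputs are not meaningful class scores. A common alternative for self-supervised backbones is to drop the head, extract the pooled features, and train a small linear classifier on them. The sketch below shows the extraction step only; `as_feature_extractor` and `extract_features` are hypothetical helper names introduced here, not part of the repository.

```python
import torch
from torch import nn


def as_feature_extractor(resnet):
    """Replace the classification head with identity so forward() returns pooled features."""
    resnet.fc = nn.Identity()
    return resnet


def extract_features(backbone, loader, device="cpu"):
    """Run `backbone` over every batch in `loader`; return (features, labels) tensors."""
    backbone = backbone.to(device).eval()
    feats, labels = [], []
    with torch.no_grad():  # no gradients needed for feature extraction
        for images, targets in loader:
            feats.append(backbone(images.to(device)).cpu())
            labels.append(targets)
    return torch.cat(feats), torch.cat(labels)
```

With the model and loader from this notebook, `extract_features(as_feature_extractor(model), val_loader)` would be expected to return a feature tensor with one 2048-dimensional row per image, which can then be fed to any linear classifier trained on the `train` split.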

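In [14]:
# labels.txt is not part of the cats_and_dogs_filtered archive, so this cell raises
# FileNotFoundError unless you place a class-index file at this path yourself.
# Judging by the parsing code in the next cell, each line is expected to look like
# "<id> <1-based class index> <class name>".
with open('data/cats_and_dogs_filtered/labels.txt','r') as f:
  label_lines=f.readlines()

In [11]:
# Map 0-based class index -> class name. The file lines are kept in `label_lines`
# so they do not clash with the `labels` batch variable from the inference loop.
label_map={int(i.split()[1])-1 : i.split()[2] for i in label_lines}
label_map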

In [None]:
predicted_labels=list(map(lambda x:label_map[x],np.argmax(np.array(preds),axis=1)))

In [None]:
batch=next(iter(val_loader))

In [None]:
preds=model(batch[0]).detach().numpy()

In [None]:
predicted_labels=list(map(lambda x:label_map[x],np.argmax(np.array(preds),axis=1)))

In [None]:
class UnNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized image tensor of size (C, H, W).
        Returns:
            Tensor: Un-normalized image (the input is modified in place).
        """
        for t, m, s in zip(tensor, self.mean, self.std):
            t.mul_(s).add_(m)
            # The normalize code -> t.sub_(m).div_(s)
        return tensor
unorm = UnNormalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225])
# Clone each image first so the in-place un-normalization leaves the batch untouched.
images=[unorm(i.clone()) for i in batch[0]]

In [None]:
import matplotlib.pyplot as plt
fig,axes=plt.subplots(8,4,figsize=(40,20))
for index,i in enumerate(images):
  ax=axes[index//4,index%4]  # row-major position in the 8x4 grid
  # (C, H, W) -> (H, W, C), clipped to the [0, 1] range that imshow expects
  ax.imshow(np.clip(i.permute(1,2,0).numpy(),0,1))
  ax.title.set_text(predicted_labels[index])
plt.show()

In [None]:
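import torch
import torchvision
from torch import nn

def off_diagonal(x):
    # return the off-diagonal elements of a square matrix as a flat tensor
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

class BarlowTwins(nn.Module):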
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.backbone = torchvision.models.resnet50(zero_init_residual=True)
        self.backbone.fc = nn.Identity()

        # projector
        sizes = [2048] + list(map(int, args.projector.split('-')))
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        z1 = self.projector(self.backbone(y1))
        z2 = self.projector(self.backbone(y2))

        # empirical cross-correlation matrix, averaged over the batch
        c = self.bn(z1).T @ self.bn(z2)
        c.div_(z1.shape[0])

        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum().mul(self.args.scale_loss)
        off_diag = off_diagonal(c).pow_(2).sum().mul(self.args.scale_loss)
        loss = on_diag + self.args.lambd * off_diag
        return loss
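
Written out, the `forward` method above corresponds to the following objective. This is a transcription of the code rather than a quotation from the paper: $\hat{z}^{(1)}$ and $\hat{z}^{(2)}$ denote the batch-normalized projector outputs `self.bn(z1)` and `self.bn(z2)`, $\lambda$ is `args.lambd`, $s$ is `args.scale_loss`, and the cross-correlation matrix $C$ is assumed to be averaged over a batch of size $N$.

$$
C_{ij} = \frac{1}{N}\sum_{b=1}^{N} \hat{z}^{(1)}_{b,i}\,\hat{z}^{(2)}_{b,j},
\qquad
\mathcal{L} = s\left[\sum_{i}\left(C_{ii}-1\right)^{2} + \lambda\sum_{i}\sum_{j\neq i} C_{ij}^{2}\right]
$$

The first term pushes the diagonal of $C$ towards 1, so the two views of an image agree feature by feature. The second term pushes the off-diagonal entries towards 0, which reduces redundancy between different features.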

In [None]:
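from PIL import Image
from torchvision import transforms

# GaussianBlur and Solarization are custom PIL-based augmentations with a single
# probability argument `p` (not torchvision's transforms of similar name); they
# must be defined before this class is instantiated.
class Transform: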
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=1.0),
            Solarization(p=0.0),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])
        self.transform_prime = transforms.Compose([
            transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=0.1),
            Solarization(p=0.2),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2
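
The `Transform` class relies on two helpers, `GaussianBlur(p=...)` and `Solarization(p=...)`, that are not defined in this notebook. The following is a minimal sketch of what such helpers could look like, modeled on the usual PIL-based versions used in self-supervised pipelines. The blur radius range of 0.1 to 2.0 is an assumption, not a value taken from this document.

```python
import random

from PIL import Image, ImageFilter, ImageOps


class GaussianBlur:
    """Blur a PIL image with probability p, using a randomly drawn radius."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            sigma = random.uniform(0.1, 2.0)  # assumed radius range
            return img.filter(ImageFilter.GaussianBlur(sigma))
        return img


class Solarization:
    """Solarize a PIL image (invert pixel values of 128 and above) with probability p."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        return img
```

Both helpers take and return PIL images, so they have to appear before `transforms.ToTensor()` in the pipeline, as they do in `Transform` above.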