In [1]:
#https://github.com/MaxLikesMath/Barlow-Twins-Pytorch/tree/main

In [2]:
import torch
import torch.nn as nn
'''
Implementation of Barlow Twins (https://arxiv.org/abs/2103.03230), adapted for ease of use for experiments from
https://github.com/facebookresearch/barlowtwins, with some modifications using code from 
https://github.com/lucidrains/byol-pytorch
'''

'\nImplementation of Barlow Twins (https://arxiv.org/abs/2103.03230), adapted for ease of use for experiments from\nhttps://github.com/facebookresearch/barlowtwins, with some modifications using code from \nhttps://github.com/lucidrains/byol-pytorch\n'

In [3]:


def flatten(t):
    return t.reshape(t.shape[0], -1)

class NetWrapper(nn.Module):
    # from https://github.com/lucidrains/byol-pytorch
    def __init__(self, net, layer = -2):
        super().__init__()
        self.net = net
        self.layer = layer

        self.hidden = None
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, __, output):
        self.hidden = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    def get_representation(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden = self.hidden
        self.hidden = None
        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x):
        representation = self.get_representation(x)

        return representation



def off_diagonal(x):
    # return a flattened view of the off-diagonal elements of a square matrix
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()


class BarlowTwins(nn.Module):
    '''
    Adapted from https://github.com/facebookresearch/barlowtwins for arbitrary backbones, and arbitrary choice of which
    latent representation to use. Designed for models which can fit on a single GPU (though training can be parallelized
    across multiple as with any other model). Support for larger models can be done easily for individual use cases by
    by following PyTorch's model parallelism best practices.
    '''
    def __init__(self, backbone, latent_id, projection_sizes, lambd, scale_factor=1):
        '''

        :param backbone: Model backbone
        :param latent_id: name (or index) of the layer to be fed to the projection MLP
        :param projection_sizes: size of the hidden layers in the projection MLP
        :param lambd: tradeoff function
        :param scale_factor: Factor to scale loss by, default is 1
        '''
        super().__init__()
        self.backbone = backbone
        self.backbone = NetWrapper(self.backbone, latent_id)
        self.lambd = lambd
        self.scale_factor = scale_factor
        # projector
        sizes = projection_sizes
        layers = []
        for i in range(len(sizes) - 2):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
            layers.append(nn.BatchNorm1d(sizes[i + 1]))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        self.projector = nn.Sequential(*layers)

        # normalization layer for the representations z1 and z2
        self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

    def forward(self, y1, y2):
        z1 = self.backbone(y1)
        z2 = self.backbone(y2)
        z1 = self.projector(z1)
        z2 = self.projector(z2)

        # empirical cross-correlation matrix
        c = torch.mm(self.bn(z1).T, self.bn(z2))
        c.div_(z1.shape[0])


        # use --scale-loss to multiply the loss by a constant factor
        # see the Issues section of the readme
        on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
        off_diag = off_diagonal(c).pow_(2).sum()
        loss = self.scale_factor*(on_diag + self.lambd * off_diag)
        return loss

In [4]:

import torchvision


model = torchvision.models.resnet18(zero_init_residual=True)
proj = [512, 512, 512, 512]
twins = BarlowTwins(model, 'avgpool', proj, 0.5)
inp1 = torch.rand(2,3,224,224)
inp2 = torch.rand(2,3,224,224)
outs =twins(inp1, inp2)
#model = model_utils.extract_latent.LatentHook(model, ['avgpool'])
#out, dicti = model(inp1)
print(outs)
#print(model)

tensor(128094.7344, grad_fn=<MulBackward0>)


In [5]:
from torchvision import datasets
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions import Bernoulli

import argparse
import os
import random
import numpy as np

In [6]:
device = 'cuda'

In [25]:
BATCH_SIZE=256
batch_size = BATCH_SIZE

In [14]:
from PIL import Image
from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import random

In [17]:
image_size=224*224

In [18]:
from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import random
'''
#####
Adapted from https://github.com/facebookresearch/barlowtwins
#####
'''


class GaussianBlur(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            sigma = random.random() * 1.9 + 0.1
            return img.filter(ImageFilter.GaussianBlur(sigma))
        else:
            return img


class Solarization(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if random.random() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class Transform:
    def __init__(self, transform=None, transform_prime=None):
        '''

        :param transform: Transforms to be applied to first input
        :param transform_prime: transforms to be applied to second
        '''
        if transform == None:
            self.transform = transforms.Compose([
                transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform = transform
        if transform_prime == None:

            self.transform_prime = transforms.Compose([
                transforms.RandomResizedCrop(224, interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=0.1),
                Solarization(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:
            self.transform_prime = transform_prime

    def __call__(self, x):
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2

In [19]:
mean=[0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

In [42]:
image_size=[28,28]
transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.RandomResizedCrop(image_size,
                                            interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
# --- data loading --- #
train_data = datasets.CIFAR10('./data', train=True, download=True,
                            transform=transform)
test_data = datasets.CIFAR10('./data', train=False,
                           transform=transform)
# pin memory provides improved transfer speed
kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=BATCH_SIZE, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(test_data,
                                          batch_size=BATCH_SIZE, shuffle=True, **kwargs)

Files already downloaded and verified


In [47]:
transform = transforms.Compose([
                transforms.Resize((224,224)),
                transforms.RandomResizedCrop(image_size,
                                            interpolation=Image.BICUBIC),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                            saturation=0.2, hue=0.1)],
                    p=0.8
                ),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])

#For the transform argument for the dataset, pass in 
# Twins.transform_utils.Transform(transform_1, transform_2)
#If transforms are None, the Imagenet default is used.
dataset = datasets.CIFAR10('./data', train=True, download=True,transform=Transform(transform, transform))

loader = torch.utils.data.DataLoader(dataset,
                                        batch_size=batch_size,
                                        shuffle=True)

Files already downloaded and verified


In [53]:
x1[0].shape

torch.Size([256, 3, 28, 28])

In [54]:
x1[1].shape

torch.Size([256, 3, 28, 28])

In [55]:
import torch
from torchvision import models
import torchvision.transforms as transforms
import torchvision.datasets as dsets

#This is just any generic model
model = torchvision.models.resnet18(zero_init_residual=True)

#Optional: define transformations for your specific dataset.
#Generally, it is best to use the original augmentations in the
#paper, replacing the Imagenet normalization with the normalization
#for your dataset.

#Make the BT instance, passing the model, the latent rep layer id,
# hidden units for the projection MLP, the tradeoff factor,
# and the loss scale.
model = torchvision.models.resnet18(zero_init_residual=True)

learner = BarlowTwins(model, 'avgpool', [512,1024, 1024, 1024],
                      3.9e-3, 1)

optimizer = torch.optim.Adam(learner.parameters(), lr=0.001)

#Single training epoch
for batch_idx, ((x1,x2), _) in enumerate(loader):
    print(batch_idx)
    loss = learner(x1, x2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
