<a href="https://colab.research.google.com/github/eisbetterthanpi/vision/blob/main/vicreg_tut.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# https://arxiv.org/pdf/2105.04906.pdf
# https://github.com/facebookresearch/vicreg


In [1]:
# @title augmentations
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# https://github.com/facebookresearch/vicreg/blob/main/augmentations.py

from PIL import ImageOps, ImageFilter
import numpy as np
import torchvision.transforms as transforms
from torchvision.transforms import InterpolationMode

class GaussianBlur(object):
    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        if np.random.rand() < self.p:
            sigma = np.random.rand() * 1.9 + 0.1
            # return img.filter(ImageFilter.GaussianBlur(sigma))
            return transforms.GaussianBlur(kernel_size=5, sigma=sigma)(img)
        else:
            return img

class Solarization(object):
    def __init__(self, p):
        self.p = p
    def __call__(self, img):
        if np.random.rand() < self.p:
            return ImageOps.solarize(img)
        else:
            return img


class TrainTransform(object):
    def __init__(self):
        self.transform = transforms.Compose([
                # transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),#224
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.8,),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=1.0),
                Solarization(p=0.0),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])
        self.transform_prime = transforms.Compose([
                # transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC),#224
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply([transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], p=0.8,),
                transforms.RandomGrayscale(p=0.2),
                GaussianBlur(p=0.1),
                Solarization(p=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ])

    def __call__(self, sample):
        # print("sample.shape",sample.shape)
        # sample=torch.squeeze(sample)
        # sample=transforms.ToPILImage()(sample)
        # sample = torch.vmap(transforms.ToPILImage(),sample)
        x1 = self.transform(sample)
        x2 = self.transform_prime(sample)
        # x2 = transforms.ToTensor()(sample)
        return x1, x2


In [2]:
# @title data
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor #, Lambda, Compose
import matplotlib.pyplot as plt
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

# training_data = datasets.FashionMNIST(root="data", train=True, download=True,transform=ToTensor(),)
# test_data = datasets.FashionMNIST(root="data", train=False, download=True,transform=ToTensor(),)

# need rgb imgs?
# training_data = datasets.CIFAR10(root="data", train=True, download=True,transform=transforms.Compose([TrainTransform(), ToTensor()]),)
training_data = datasets.CIFAR10(root="data", train=True, download=True,transform=TrainTransform(),)
# test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transforms.Compose([transforms.RandomResizedCrop(32, interpolation=InterpolationMode.BICUBIC), ToTensor()]),)
test_data = datasets.CIFAR10(root="data", train=False, download=True,transform=ToTensor(),)


batch_size = 64
train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)
ctraining_data = datasets.CIFAR10(root="data", train=True, download=True,transform=ToTensor(),)
ctrain_dataloader = DataLoader(ctraining_data, batch_size=batch_size)


dataiter = iter(test_dataloader)
x, labels = dataiter.next() # images, labels
# print(labels)
# print(y.shape)

import matplotlib.pyplot as plt
def imshow(img):
    # img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

import torchvision
# imshow(torchvision.utils.make_grid(x))
# imshow(torchvision.utils.make_grid(y))


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified


In [9]:
# @title vicreg
import torch.nn.functional as F

def off_diagonal(x):
    # print("off_diagonal",x.shape)
    n, m = x.shape
    assert n == m
    return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# https://github.com/facebookresearch/vicreg/blob/main/resnet.py
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        # dim_out=10
        dim_class=10
        dim_exp=128
        self.conv = nn.Sequential( # nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            # nn.Conv2d(3, 8, 3, 1, 1), nn.BatchNorm2d(8), nn.ReLU(), nn.MaxPool2d(2, 2),
            # nn.Conv2d(8, 16, 5, 1, 2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2, 2),
            # nn.Conv2d(16, 16, 7, 1, 3), nn.BatchNorm2d(16), nn.ReLU(), #nn.MaxPool2d(2, 2),
            nn.Conv2d(3, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(), #nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 5, 1, 2), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, 7, 1, 3), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        self.lin = nn.Sequential(
            nn.Linear(256 * 8 * 8, 4096), nn.ReLU(),
            nn.Linear(4096, 512), nn.ReLU(),
            nn.Linear(512, 64),
        )

        f=[80,100,128]
        self.exp = nn.Sequential(
            nn.Linear(64, f[0]), nn.BatchNorm1d(f[0]), nn.ReLU(),
            nn.Linear(f[0], f[1]), nn.BatchNorm1d(f[1]), nn.ReLU(),
            nn.Linear(f[1], f[-1])
            )
        self.classifier = nn.Linear(64, dim_class)

    def vicreg(self, x, y): # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
        # invariance loss
        repr_loss = F.mse_loss(x, y)
        
        # x = torch.cat(FullGatherLayer.apply(x), dim=0)
        # y = torch.cat(FullGatherLayer.apply(y), dim=0)
        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # variance loss
        std_x = torch.sqrt(x.var(dim=0) + 0.0001)
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        # print(x.var(dim=0),y.var(dim=0))
        # print(std_x , std_y)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2
        # std_loss=0.02
        # print(torch.mean(F.relu(1 - std_x)) , torch.mean(F.relu(1 - std_y)))

        # # covariance loss
        # cov_x = (x.T @ x) / (self.args.batch_size - 1)
        # cov_y = (y.T @ y) / (self.args.batch_size - 1)
        # cov_loss = off_diagonal(cov_x).pow_(2).sum().div(self.num_features)\
        #  + off_diagonal(cov_y).pow_(2).sum().div(self.num_features)
        # loss = (self.args.sim_coeff * repr_loss + self.args.std_coeff * std_loss + self.args.cov_coeff * cov_loss)

        batch_size=x.size(dim=0)
        num_features=32
        sim_coeff=25.0 # λ / µ?
        std_coeff=25.0
        cov_coeff=1.0 # ν?

        # print("x.dim()",x.dim())
        if x.dim() == 1:
            x = x.view(-1, 1)
        if y.dim() == 1:
            y = y.view(-1, 1)
        x=x.T
        y=y.T
        # print("x",x.shape)
        cov_x = (x.T @ x) / (batch_size - 1)
        cov_y = (y.T @ y) / (batch_size - 1)
        # print("cov_x",cov_x.shape)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features)\
         + off_diagonal(cov_y).pow_(2).sum().div(num_features)
        # print("in vicreg",repr_loss , std_loss , cov_loss)
        loss = (sim_coeff * repr_loss + std_coeff * std_loss + cov_coeff * cov_loss)
        return loss
        
    def loss(self, sx,sy):
        sx = self.forward(sx)
        sy = self.forward(sy)
        vx = self.exp(sx)
        vy = self.exp(sy)
        loss = self.vicreg(vx,vy)
        return loss
    def forward(self, x):
        x = self.conv(x)
        # print("forward x",x.shape)
        x = nn.Flatten()(x)
        x = self.lin(x)
        return x
    def classify(self, x):
        x = self.classifier(x)
        return x
# softmax = nn.Softmax(dim=1)
# pred_probab = softmax(logits)
model = NeuralNetwork().to(device) # create an instance and move it to device (cache?)
# print(model)

# LARGE BATCH TRAINING OF CONVOLUTIONAL NETWORKS
# https://arxiv.org/pdf/1708.03888.pdf

# Barlow Twins: Self-Supervised Learning via Redundancy Reduction
# https://arxiv.org/pdf/2103.03230.pdf
# https://github.com/facebookresearch/barlowtwins/blob/main/main.py

# https://arxiv.org/search/?query=vicreg&searchtype=all


Using cuda device


In [20]:

X = torch.rand(64, 3, 32, 32, device=device)
logits = model(X)
print(logits.shape)
print(logits[0])
# print(logits[0].argmax(1))
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

torch.Size([64, 64])
tensor([-0.1402, -0.0820, -0.0451,  0.0304,  0.0086,  0.0953,  0.0503, -0.1340,
         0.0013, -0.0722,  0.1491,  0.1574,  0.0774, -0.0917, -0.0467, -0.1056,
        -0.0304, -0.0396,  0.0885,  0.0470, -0.0372,  0.0181,  0.0521, -0.0003,
         0.0947, -0.1287, -0.1436,  0.0934,  0.1214, -0.0337,  0.1340,  0.0584,
         0.0443, -0.0786, -0.2327,  0.1081, -0.0696,  0.1807,  0.0284,  0.0042,
        -0.0128, -0.0948, -0.1110, -0.1459, -0.0806,  0.0351,  0.0168,  0.0287,
         0.0673,  0.0133, -0.0546, -0.0468, -0.1250, -0.0520, -0.0145, -0.0196,
        -0.0129,  0.0471, -0.0644, -0.0025, -0.0216,  0.0376, -0.0811,  0.0220],
       device='cuda:0', grad_fn=<SelectBackward0>)
Predicted class: tensor([37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 11, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37, 37,
        37, 37, 37, 37, 37, 37

In [7]:
# @title train test function
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    # for batch, (x, y) in enumerate(dataloader):
    for batch, ((x,y), labels) in enumerate(dataloader):
        sx, sy = x.to(device), y.to(device)
        # print("sx sy",sx.shape,sy.shape)
        # pred = model(sx)
        # loss = loss_fn(pred, sy)
        loss = model.loss(sx,sy)
        optimizer.zero_grad() # reset gradients of model parameters, to prevent double-counting
        loss.backward() # Backpropagate gradients
        optimizer.step() # adjust the parameters by the gradients
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(x)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def train_classifier(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (x, y) in enumerate(dataloader):
    # for batch, ((x,y), labels) in enumerate(dataloader):
        sx, sy = x.to(device), y.to(device)
        sx = model(sx)
        pred = model.classify(sx)
        loss = loss_fn(pred, sy)
        optimizer.zero_grad() # reset gradients of model parameters, to prevent double-counting
        loss.backward() # Backpropagate gradients
        optimizer.step() # adjust the parameters by the gradients
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(x)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            x, y = X.to(device), y.to(device)
            sx = model(x)
            pred = model.classify(sx)
            loss = loss_fn(pred, y)
            # predicted, actual = classes[pred[0].argmax(0)], classes[y]
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")


In [11]:
#  0.1020,  0.0527,  0.0185,  0.0295, -0.0470, -0.0641,  0.0206, -0.1019
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
coptimizer = torch.optim.SGD(model.classifier.parameters(), lr=1e-3)

epochs = 5 #5 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    train_classifier(ctrain_dataloader, model, loss_fn, coptimizer)
    test(test_dataloader, model, loss_fn)
print("Done!")
torch.save(model.state_dict(), "model.pth")

# model = NeuralNetwork().to(device)
# model.load_state_dict(torch.load("model.pth"))


Epoch 1
-------------------------------
loss: 15.119040  [    0/50000]
loss: 14.850348  [ 6400/50000]
loss: 13.976443  [12800/50000]
loss: 15.651359  [19200/50000]
loss: 14.788509  [25600/50000]
loss: 14.581607  [32000/50000]
loss: 15.632578  [38400/50000]
loss: 14.567408  [44800/50000]
loss: 2.269583  [    0/50000]
loss: 2.296555  [ 6400/50000]
loss: 2.238818  [12800/50000]
loss: 2.267452  [19200/50000]
loss: 2.258715  [25600/50000]
loss: 2.259788  [32000/50000]
loss: 2.282304  [38400/50000]
loss: 2.229334  [44800/50000]
Test Error: 
 Accuracy: 17.2%, Avg loss: 2.266839 

Epoch 2
-------------------------------
loss: 14.898645  [    0/50000]
loss: 14.611539  [ 6400/50000]
loss: 14.496476  [12800/50000]
loss: 14.202031  [19200/50000]
loss: 14.261537  [25600/50000]
loss: 15.205136  [32000/50000]
loss: 16.486755  [38400/50000]
loss: 14.219546  [44800/50000]
loss: 2.266939  [    0/50000]
loss: 2.288191  [ 6400/50000]
loss: 2.225401  [12800/50000]
loss: 2.258638  [19200/50000]
loss: 2.2492

In [None]:
# @title save
from google.colab import drive
drive.mount('/content/gdrive')
PATH="/content/gdrive/MyDrive/torch_save/" # for saving to google drive
name='vicreg_tut.pth'
# PATH="/content/" # for saving on colab only
# name='model.pth'

torch.save(model.state_dict(), PATH+name)

model = NeuralNetwork().to(device)
model.load_state_dict(torch.load(PATH+name))


Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


<All keys matched successfully>

In [None]:

def test():
    model.eval()
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        n_class_correct = [0 for i in range(10)]
        n_class_samples = [0 for i in range(10)]
        for images, labels in test_dataloader:
            images = images.to(device)
            labels = labels.to(device)
            # outputs = model(images)
            sx = model(images)
            outputs = model.classify(sx)
            # max returns (value ,index)
            _, predicted = torch.max(outputs, 1)
            n_samples += labels.size(0)
            n_correct += (predicted == labels).sum().item()
            for i in range(batch_size):
                print(len(labels))
                label = labels[i]
                pred = predicted[i]
                if (label == pred):
                    n_class_correct[label] += 1
                n_class_samples[label] += 1
        acc = 100.0 * n_correct / n_samples
        print(f'Accuracy of the network: {acc} %')
        for i in range(10):
            acc = 100.0 * n_class_correct[i] / n_class_samples[i]
            print(f'Accuracy of {classes[i]}: {acc} %')

test()

In [None]:
classes = ["T-shirt/top","Trouser","Pullover","Dress","Coat","Sandal","Shirt","Sneaker","Bag","Ankle boot",]
model.eval()
import random
n=random.randint(0,1000)
print(n)
x, y = test_data[n]
# print(x)
with torch.no_grad():
    # pred = model(x.to(device))
    # print(pred)
    # predicted, actual = classes[pred[0].argmax(0)], classes[y]

    # x, y = x.to(device), y.to(device)
    x = x.to(device)
    sx = model(x)
    pred = model.classify(sx)
    pred = torch.argmax(pred, dim=1).item()
    print(pred)
    print(y)
    # predicted, actual = classes[pred[0].argmax(0)], classes[y]
    # print(f'Predicted: "{predicted}", Actual: "{actual}"')

