## 0) The Necessary Imports

In [None]:
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [None]:
import random
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image # PIL is a library to process images
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import torchvision.datasets as datasets

import os
import time
import yaml

from google.colab import drive

In [None]:
# Mount Google Drive under /content/gdrive (was used when training in Colab).
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [None]:
!unzip gdrive/My\ Drive/proj/unlabeled_data.zip #Was used to load the training data in COLLAB, not needed after the first run in an environment

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: unlabeled_data/506940.PNG  
  inflating: unlabeled_data/506941.PNG  
  inflating: unlabeled_data/506942.PNG  
  inflating: unlabeled_data/506943.PNG  
  inflating: unlabeled_data/506944.PNG  
  inflating: unlabeled_data/506945.PNG  
  inflating: unlabeled_data/506946.PNG  
  inflating: unlabeled_data/506947.PNG  
  inflating: unlabeled_data/506948.PNG  
  inflating: unlabeled_data/506949.PNG  
  inflating: unlabeled_data/506950.PNG  
  inflating: unlabeled_data/506951.PNG  
  inflating: unlabeled_data/506952.PNG  
  inflating: unlabeled_data/506953.PNG  
  inflating: unlabeled_data/506954.PNG  
  inflating: unlabeled_data/506955.PNG  
  inflating: unlabeled_data/506956.PNG  
  inflating: unlabeled_data/506957.PNG  
  inflating: unlabeled_data/506958.PNG  
  inflating: unlabeled_data/506959.PNG  
  inflating: unlabeled_data/506960.PNG  
  inflating: unlabeled_data/506961.PNG  
  inflating: unlabeled_data/50696

In [None]:
# UnlabeledDataset addresses images as "<idx>.PNG" with idx counted from 0, so the
# file 512000.PNG is renamed to 511999.PNG (presumably index 511999 is otherwise
# missing -- TODO confirm against the extracted archive).
# The existence check makes the cell safe to re-run: once the rename has happened,
# running it again is a no-op instead of a FileNotFoundError.
_misnumbered_png = 'unlabeled_data/512000.PNG'
if os.path.exists(_misnumbered_png):
    os.rename(_misnumbered_png, 'unlabeled_data/511999.PNG')
#os.rename('/content/unlabeled_data/512000.PNG', '/content/unlabeled_data/511999.PNG')

## 1) Training a ResNet50 backbone using a slightly modified public github repo

The version used in the project is included as a folder in the submission. `main_vicreg.py` was changed to allow training on our dataset, and `hubconf.py` was changed to make extra sure that we are not using any pre-trained weights.

(Reference https://github.com/facebookresearch/vicreg)

This was explicitly stated to be allowed, as exemplified by the thread: https://campuswire.com/c/G55A3869E/feed/230

Code to download the original library:


```
%%bash

git clone https://github.com/facebookresearch/vicreg
```

Then, the following command is used to train a backbone with VICReg.

(the backbone in submission was trained for 3 epochs with base learning rate of 0.3. The output here reflects further experimentation which was not used in our last model due to time constraints)

In [None]:
!python -m torch.distributed.launch --nproc_per_node=1 /content/vicreg/main_vicreg.py --data-dir /content/unlabeled_data --exp-dir /content/exp --arch resnet50 --epochs 10 --batch-size 256 --base-lr 0.5

and will be removed in future. Use torchrun.
Note that --use_env is set by default in torchrun.
If your script expects `--local_rank` argument to be set, please
change it to read from `os.environ['LOCAL_RANK']` instead. See 
https://pytorch.org/docs/stable/distributed.html#launch-utility for 
further instructions

| distributed init (rank 0): env://, gpu 0
Namespace(arch='resnet50', base_lr=0.5, batch_size=256, cov_coeff=1.0, data_dir=PosixPath('/content/unlabeled_data'), device='cuda', dist_backend='nccl', dist_url='env://', epochs=10, exp_dir=PosixPath('/content/exp'), gpu=0, local_rank=0, log_freq_time=60, mlp='8192-8192-8192', num_workers=10, rank=0, sim_coeff=25.0, std_coeff=25.0, wd=1e-06, world_size=1)
/content/vicreg/main_vicreg.py --local_rank=0 --data-dir /content/unlabeled_data --exp-dir /content/exp --arch resnet50 --epochs 10 --batch-size 256 --base-lr 0.5
{"epoch": 0, "step": 41, "loss": 31.309356689453125, "time": 60, "lr": 0.001025}
{"epoch": 0, "step": 98, "loss": 26.1

# 2) Custom Model Building and Training

(Not used in final version of the backbone, but was used in earlier versions with lower precision)

### 2.1.1) Directly use classes from eval.py here to ensure compatibility

In [None]:
class UnlabeledDataset(torch.utils.data.Dataset):
    r"""
    Dataset over a folder of images named ``0.PNG``, ``1.PNG``, ...

    Args:
        root: Location of the dataset folder, usually it is /unlabeled
        transform: callable applied to each loaded RGB image before it is returned.
    """

    def __init__(self, root, transform):
        self.image_dir = root
        self.transform = transform
        # The dataset length is the number of directory entries in `root`.
        directory_entries = os.listdir(root)
        self.num_images = len(directory_entries)

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # Unlabeled images are indexed from 0 and stored as "<idx>.PNG".
        image_path = os.path.join(self.image_dir, f"{idx}.PNG")
        with open(image_path, "rb") as handle:
            rgb_image = Image.open(handle).convert("RGB")
        return self.transform(rgb_image)

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU so that
# later `.to(device)` calls do not fail on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


### 2.1.2) Make the necessary definitions

In [None]:
UNLABELED_DATASET_PATH = '/content/unlabeled_data'
BATCH_SIZE = 64

# Per-image transform: turn each PIL image into a tensor. The dataset calls
# `self.transform(img)`, so the transform must be a callable -- the
# `torchvision.transforms` module itself is not one and cannot be passed here.
# NOTE(review): default batch collation needs every image tensor to have the
# same size -- confirm this holds for the unlabeled images.
unlabeled_dataset = UnlabeledDataset(
    root=UNLABELED_DATASET_PATH,
    transform=transforms.ToTensor(),
)

# Shuffled loader over the unlabeled images, read by 4 worker processes.
unlabeled_loader = torch.utils.data.DataLoader(
    unlabeled_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
)

### 2.2) Defining data augmentation and loss functions

In [None]:
import torch
import torchvision.transforms as T

# Taken from online page, source: https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/transformations/simclr.py
# + changed to better match the augmentation described in the VICReg paper
class Augment:
    """
    A stochastic data augmentation module.

    Calling an instance runs the same random transform pipeline twice on the
    given input and returns both results. The two outputs are correlated views
    of the same example and are treated as a positive pair.
    """

    def __init__(self, img_size, s=1):
        # The four colour-jitter strengths are all scaled by `s`.
        jitter_strengths = (0.4 * s, 0.4 * s, 0.2 * s, 0.1 * s)
        jitter = T.ColorJitter(*jitter_strengths)
        # Gaussian blur with kernel size 23.
        gaussian_blur = T.GaussianBlur(23)

        pipeline_stages = [
            T.RandomResizedCrop(size=img_size),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([gaussian_blur], p=0.5),
            # T.RandomSolarize(threshold=192.0, p=0.1) is left disabled.
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.train_transform = torch.nn.Sequential(*pipeline_stages)

    def __call__(self, x):
        # Two independent passes through the random pipeline -> two views.
        view_a = self.train_transform(x)
        view_b = self.train_transform(x)
        return (view_a, view_b)

'\nimport torch\nimport torchvision.transforms as T\n\n# Taken from online page, will find source for reference\n# + changed with the augemntation described in VICReg paper\nclass Augment:\n   """\n   A stochastic data augmentation module\n   Transforms any given data example randomly\n   resulting in two correlated views of the same example,\n   denoted x ̃i and x ̃j, which we consider as a positive pair.\n   """\n\n   def __init__(self, img_size, s=1):\n       color_jitter = T.ColorJitter(\n           0.4 * s, 0.4 * s, 0.2 * s, 0.1 * s\n       )\n       # 10% of the image\n       blur = T.GaussianBlur(23) #T.GaussianBlur((3, 3), (0.1, 2.0))\n\n       self.train_transform = torch.nn.Sequential(\n           T.RandomResizedCrop(size=img_size),\n           T.RandomHorizontalFlip(p=0.5),  # with 0.5 probability\n           T.RandomApply([color_jitter], p=0.8),\n           T.RandomGrayscale(p=0.2),\n           T.RandomApply([blur], p=0.5),\n           #T.RandomSolarize(threshold=192.0, p=0

In [None]:
#The VICReg loss function
#Used source: https://github.com/untitled-ai/self_supervised/blob/7a43c4ae2c2d42cb68e688c8fec86948c9547e72/moco.py
#Constants from paper: https://openreview.net/pdf?id=xm6YD62D1Ub
def VICReg_loss(z_a, z_b, sim_coeff=25.0, std_coeff=25.0, cov_coeff=1.0, eps=0.0001):
    """Compute the VICReg loss between two batches of embeddings.

    The loss is a weighted sum of three terms:
      * invariance: mean squared error between ``z_a`` and ``z_b``;
      * variance: hinge ``relu(1 - std)`` on the per-dimension standard
        deviation of each branch, averaged over dimensions;
      * covariance: sum of squared off-diagonal covariance entries of each
        branch, divided by the embedding dimension.

    Args:
        z_a, z_b: 2-D tensors of identical shape (N, D).
        sim_coeff: weight of the invariance term (default 25).
        std_coeff: weight of the variance term (default 25).
        cov_coeff: weight of the covariance term (default 1).
        eps: constant added to the variance before the square root.

    Returns:
        Scalar tensor with the weighted loss.

    Raises:
        AssertionError: if the shapes differ or the inputs are not 2-D.
        ValueError: if the batch has fewer than 2 rows; variance and
            covariance divide by N - 1 and would silently yield NaN/inf.
    """
    assert z_a.shape == z_b.shape and len(z_a.shape) == 2

    N, D = z_a.shape
    if N < 2:
        raise ValueError(
            f"VICReg_loss needs at least 2 samples per batch, got {N}"
        )

    # invariance loss
    loss_inv = F.mse_loss(z_a, z_b)

    # variance loss
    std_z_a = torch.sqrt(z_a.var(dim=0) + eps)
    std_z_b = torch.sqrt(z_b.var(dim=0) + eps)
    loss_var = torch.mean(F.relu(1 - std_z_a)) + torch.mean(F.relu(1 - std_z_b))

    # covariance loss (computed on mean-centred embeddings)
    centered_a = z_a - z_a.mean(dim=0)
    centered_b = z_b - z_b.mean(dim=0)
    cov_z_a = ((centered_a.T @ centered_a) / (N - 1)).square()  # DxD
    cov_z_b = ((centered_b.T @ centered_b) / (N - 1)).square()  # DxD
    # Only off-diagonal entries are penalised: total minus the diagonal.
    loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D
    loss_c_b = (cov_z_b.sum() - cov_z_b.diagonal().sum()) / D
    loss_cov = loss_c_a + loss_c_b

    return loss_inv * sim_coeff + loss_var * std_coeff + loss_cov * cov_coeff

### 2.3) Building The Self-Supervised Architecture

Chosen regularization method: Information Maximization

In [None]:
class SSL_NN(nn.Module):
    """Backbone + pooling + expander network for self-supervised training.

    Args:
        backbone: module producing the feature representation.
        poller: pooling module applied to the backbone output when that output
            still has spatial dimensions (more than 2 dims). May be ``None``.
        representation_size: number of features entering the expander.
        expanded_size: width of every expander layer.
    """

    def __init__(self, backbone, poller, representation_size, expanded_size):
        super().__init__()
        self.representation_size = representation_size
        self.layer_size = expanded_size
        self.backbone = backbone
        self.poller = poller
        # Three linear layers; BatchNorm + ReLU after the first two,
        # BatchNorm only after the last.
        self.expander = nn.Sequential(  nn.Linear(in_features=self.representation_size, out_features=self.layer_size),
                                        nn.BatchNorm1d(self.layer_size),
                                        nn.ReLU(),
                                        nn.Linear(in_features=self.layer_size, out_features=self.layer_size),
                                        nn.BatchNorm1d(self.layer_size),
                                        nn.ReLU(),
                                        nn.Linear(in_features=self.layer_size, out_features=self.layer_size),
                                        nn.BatchNorm1d(self.layer_size))

    def forward(self, x):
        """Return expander embeddings, one row per input sample."""
        x = self.backbone(x)
        # Pool away spatial dimensions first. Reshaping an un-pooled
        # (B, C, H, W) map with view(-1, C) would yield B*H*W rows that mix
        # channel and spatial values instead of one C-vector per sample.
        if self.poller is not None and x.dim() > 2:
            x = self.poller(x)
        # Flatten per sample; keeps the batch dimension even when B == 1.
        x = torch.flatten(x, start_dim=1)
        x = self.expander(x)
        return x

In [None]:
# ResNet-50 built with pretrained=False, i.e. without requesting pre-trained weights.
resnet_model = torchvision.models.resnet50(pretrained=False)



In [None]:
# Backbone: the first 8 child modules of the ResNet-50 (does not include the
# average pooling). Output dimension: 2048
encoder = nn.Sequential(*list(resnet_model.children())[:8])

In [None]:
# Child module at index 8 of the ResNet-50, kept separately so it can be handed
# to SSL_NN as its `poller`.
avg_pool = list(resnet_model.children())[8]

Now, we also need an expander for our architecture:

In [None]:
# Using VICReg for regularization
# References: https://generallyintelligent.ai/open-source/2022-04-21-vicreg/
# https://www.arxiv-vanity.com/papers/2105.04906/
# Main paper: https://arxiv.org/pdf/2105.04906.pdf

repr_size = 2048  # number of features entering the expander
expanded_size = 8192  # width of every expander layer
ssl_model = SSL_NN(encoder, avg_pool, repr_size, expanded_size)

In [None]:
#ssl_model

### 2.4) Training the backbone network 

In [None]:
from collections import defaultdict, deque

# Module-level augmenter used by get_backbone_loss; 224 is the img_size given to Augment.
augmenter = Augment(224)

def get_backbone_loss(model, batch, criterion, device):
    """Run the forward pass and loss calculation for one batch.

    The batch is moved to `device`, turned into two augmented views by the
    module-level `augmenter`, and both views are passed through `model`.

    Returns:
        The value of `criterion` applied to the two model outputs (Tensor).
    """
    images = batch.to(device)

    # Two augmented views of the same batch. To inspect them side by side,
    # plot e.g. view_one[0].cpu().permute(1, 2, 0) with plt.imshow.
    view_one, view_two = augmenter(images)

    embedding_one = model(view_one)
    embedding_two = model(view_two)
    return criterion(embedding_one, embedding_two)

def step(loss, optimizer):
    """Run the backward pass for `loss` and apply one optimizer update."""
    # Clear gradients left over from the previous step before backpropagating.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

from tqdm.notebook import tqdm
    
def backbone_train(model, data_loader, criterion, optimizer, device, N_EPOCHS = 5):
    """Train `model` on `data_loader` for `N_EPOCHS` epochs.

    A batch that raises an exception is reported and skipped so one bad batch
    does not abort the run. Only `Exception` is caught, so KeyboardInterrupt
    still stops training.

    Args:
        model: network to train; moved to `device`.
        data_loader: iterable of batches passed to get_backbone_loss.
        criterion: loss function taking the two embedding batches.
        optimizer: optimizer updating the parameters of `model`.
        device: device the model and batches are moved to.
        N_EPOCHS: number of passes over `data_loader`.

    Returns:
        List with the mean training loss of each epoch, averaged over the
        batches that completed (NaN for an epoch in which none completed).
    """
    model.to(device)

    train_losses = []

    pbar = tqdm(range(N_EPOCHS))
    for _epoch in pbar:
        model.train()

        total_train_loss = 0.0
        completed_batches = 0

        for batch_number, batch in enumerate(tqdm(data_loader), start=1):
            try:
                loss = get_backbone_loss(model, batch, criterion, device)
                step(loss, optimizer)
                batch_loss = loss.item()
            except Exception as exc:
                # Best effort: report which batch failed and why, then go on.
                print("An exception occurred in Batch: ", batch_number, "|", repr(exc))
                continue

            total_train_loss += batch_loss
            completed_batches += 1
            if batch_number % 20 == 0:
                print("Batch: ", batch_number, " | Loss:" , batch_loss)

        # Average over batches that actually ran, not a hard-coded batch count,
        # so the figure stays right for any batch size or dataset size.
        if completed_batches:
            mean_train_loss = total_train_loss / completed_batches
        else:
            mean_train_loss = float("nan")
        print(mean_train_loss)
        train_losses.append(mean_train_loss)
        pbar.set_postfix({'train_loss': mean_train_loss})

    optimizer.zero_grad() #leave gradients at zero after training, just as a safety
    return train_losses

In [None]:
# Loss function and optimizer for the self-supervised run: VICReg loss, and
# SGD with momentum over all parameters of ssl_model.
criterion = VICReg_loss
optimizer = torch.optim.SGD(ssl_model.parameters(), lr=0.0001, momentum=0.9)

# The last positional argument is N_EPOCHS: train for 5 epochs.
backbone_train(ssl_model, unlabeled_loader, criterion,optimizer, device, 5)

In [None]:
#ssl_model

In [None]:
# Save the whole trained ssl_model object to the mounted Google Drive.
torch.save(ssl_model, '/content/gdrive/MyDrive/proj/ssl_model_new2.pt')

In [None]:
# Keep a handle on the backbone sub-module only.
backbone = ssl_model.backbone

In [None]:
# Save the backbone next to the full SSL model, under "MyDrive/proj" on the
# mounted Drive. The previous path omitted "MyDrive" and therefore pointed
# outside the Drive folder that the other save uses.
torch.save(backbone, '/content/gdrive/MyDrive/proj/backbone_model.pt')

In [None]:
#backbone

In [None]:
#ssl_model2 = torch.load("/content/gdrive/MyDrive/proj/ssl_model_new.pt")

In [None]:
#criterion = VICReg_loss
#optimizer = torch.optim.SGD(ssl_model2.parameters(), lr=0.00001, momentum=0.9)

#backbone_train(ssl_model2, unlabeled_loader, criterion, optimizer, device, 5)