# BYOL

In this notebook we are going to implement [BYOL: Bootstrap Your Own Latent](https://arxiv.org/pdf/2006.07733.pdf) and compare the results of a classification task before and after pretraining the model with BYOL.

### Data Augmentations

In [None]:
import random
from typing import Callable, Tuple
import torch
import torchvision
from torch import nn, Tensor
from torchvision import transforms as T
from torch.nn import functional as F


class RandomApply(nn.Module):
    """Apply a wrapped transform with probability ``p``.

    One number is drawn from Python's ``random`` module on every ``forward``
    call, so the decision covers the whole input tensor at once.

    Args:
        fn: Callable applied to the input when the draw succeeds.
        p: Probability threshold the draw is compared against.
    """

    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        """Return ``fn(x)`` if the random draw is below ``p``, else ``x`` unchanged."""
        should_apply = random.random() < self.p
        return self.fn(x) if should_apply else x


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Build the augmentation pipeline used to create BYOL views.

    The pipeline performs, in order:
        1. resize images to ``image_size``
        2. RandomApply color jitter
        3. RandomApply grayscale
        4. RandomApply horizontal flip
        5. RandomApply gaussian blur with kernel_size (3, 3), sigma (1.5, 1.5)
        6. RandomApply RandomResizedCrop to ``image_size``
        7. Normalize with fixed per-channel mean and std

    Every optional step is wrapped in ``RandomApply`` with probability 0.3.
    ``RandomApply`` draws once per call, so a step is applied either to the
    whole input tensor or not at all.

    Args:
        image_size: ``(height, width)`` the images are resized / cropped to.

    Returns:
        An ``nn.Sequential`` that maps an image tensor to an augmented,
        normalized tensor.
    """
    return nn.Sequential(
        T.Resize(image_size),
        # ColorJitter() with its default arguments has every strength set to
        # 0 and therefore leaves the image untouched; explicit non-zero
        # strengths are required for this step to do anything.
        RandomApply(
            T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1),
            0.3,
        ),
        # Three output channels keep the shape compatible with the
        # three-channel Normalize at the end of the pipeline.
        RandomApply(T.Grayscale(num_output_channels=3), 0.3),
        # p=1 makes the flip deterministic; the randomness comes from RandomApply.
        RandomApply(T.RandomHorizontalFlip(p=1), 0.3),
        RandomApply(T.GaussianBlur(kernel_size=(3, 3), sigma=(1.5, 1.5)), 0.3),
        RandomApply(T.RandomResizedCrop(image_size), 0.3),
        T.Normalize(
            mean=torch.tensor([0.485, 0.456, 0.406]),
            std=torch.tensor([0.229, 0.224, 0.225]),
        ),
    )

# Model
We will use ResNet18 as our representation model.

In [None]:
def get_encoder_model():
    """Return a freshly initialised ResNet18 without its last top-level layer.

    Returns:
        An ``nn.Sequential`` holding every top-level child of
        ``torchvision.models.resnet18()`` except the final one (the
        fully-connected layer).
    """
    backbone = torchvision.models.resnet18()
    # Unpack all children and discard the trailing fully-connected layer.
    *feature_layers, _fc = backbone.children()
    return nn.Sequential(*feature_layers)
#get_encoder_model()(torch.randn(1,3,224,224)) #just to ensure it works

### Loss Function
We need to use NormalizedMSELoss as our loss function.
$$\mathrm{NormalizedMSELoss}(v_1, v_2) = \Vert \bar{v}_1 - \bar{v}_2\Vert_2^2 = 2 - 2 \cdot \frac{\langle v_1, v_2 \rangle}{\Vert v_1\Vert_2 \, \Vert v_2\Vert_2}, \qquad \text{where } \bar{v}_i = \frac{v_i}{\Vert v_i\Vert_2}$$

In [None]:
class NormalizedMSELoss(nn.Module):
    """Squared L2 distance between the L2-normalized inputs.

    For each pair of rows this equals ``2 - 2 * cosine_similarity``. The
    result is not reduced over the batch: one value is returned per row.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, view1: Tensor, view2: Tensor) -> Tensor:
        """Compute the per-sample normalized MSE.

        Args:
            view1: Tensor whose last dimension holds the feature vector.
            view2: Tensor of the same shape as ``view1``.

        Returns:
            Tensor with the last dimension reduced away, holding
            ``2 - 2 * <v1, v2> / (||v1|| * ||v2||)`` for every vector pair.
        """
        # F.normalize divides by max(norm, eps), so an all-zero vector yields
        # zeros instead of the NaN a plain division by the norm would produce.
        # Normalization and the dot product both use the last dimension.
        unit1 = F.normalize(view1, p=2, dim=-1)
        unit2 = F.normalize(view2, p=2, dim=-1)
        return 2 - 2 * (unit1 * unit2).sum(dim=-1)

### MLP
Here you will implement a simple MLP class with one hidden layer (followed by BatchNorm and a ReLU activation) and a linear output layer. This class will be used for both the projection and the prediction networks.

In [None]:
class MLP(nn.Module):
    """One-hidden-layer perceptron: Flatten -> Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        input_dim: Number of features per sample after flattening.
        projection_dim: Size of the output vector.
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, input_dim: int, projection_dim: int = 256, hidden_dim: int = 4096) -> None:
        super().__init__()

        # Layers are created in the same order and under the same attribute
        # names as before, which keeps state-dict keys and parameter order stable.
        self.ly1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.out = nn.Linear(in_features=hidden_dim, out_features=projection_dim)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_dim)
        self.flatter = nn.Flatten()

    def forward(self, x: Tensor) -> Tensor:
        """Flatten ``x`` per sample and map it to a ``projection_dim`` vector."""
        flat = self.flatter(x)
        hidden = self.ly1(flat)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        return self.out(hidden)

### Encoder + Projector Network
This is the network structure that is shared between online and target networks. It consists of our encoder model, followed by a projection MLP.

In [None]:
class EncoderProjecter(nn.Module):
    """Encoder followed by a projection MLP.

    Args:
        encoder: Module that maps an input batch to a feature tensor.
        hidden_dim: Hidden width of the projection MLP.
        projection_out_dim: Output size of the projection MLP.
        encoder_out_dim: Number of features per sample produced by
            ``encoder`` after flattening. Defaults to 512, the value that was
            previously hard-coded.
    """

    def __init__(self,
                 encoder: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256,
                 encoder_out_dim: int = 512
                 ) -> None:
        super().__init__()

        self.encoder = encoder
        # Forward the constructor arguments; they used to be accepted but
        # silently ignored in favour of the MLP defaults.
        self.projector = MLP(
            encoder_out_dim,
            projection_dim=projection_out_dim,
            hidden_dim=hidden_dim,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return the projection of the encoded input."""
        x = self.encoder(x)
        x = self.projector(x)
        return x

## BYOL

In [None]:
import copy

class BYOL(nn.Module):
    """Online / target network pair trained with the BYOL objective.

    The online branch is ``encoder -> projector -> predictor``; the target
    branch is a frozen copy of ``encoder -> projector`` that is only changed
    through ``soft_update_target_network``.

    Args:
        model: Encoder module shared as the starting point of both branches.
        hidden_dim: Hidden width of the projector and predictor MLPs.
        projection_out_dim: Output size of the projector and predictor MLPs.
        target_decay: Fraction of the target weights kept on each soft update.
        encoder_out_dim: Number of features per sample produced by ``model``
            after flattening. Defaults to 512, the value that was previously
            hard-coded.
    """

    def __init__(self,
                 model: nn.Module,
                 hidden_dim: int = 4096,
                 projection_out_dim: int = 256,
                 target_decay: float = 0.99,
                 encoder_out_dim: int = 512
                 ) -> None:
        super().__init__()

        # Submodules are registered in the order online_network,
        # online_predictor, target_network so that children() keeps its order.
        self.online_network = nn.Sequential(
            model,
            MLP(encoder_out_dim, projection_dim=projection_out_dim, hidden_dim=hidden_dim),
        )  # encoder + projector
        self.online_predictor = MLP(
            projection_out_dim,
            projection_dim=projection_out_dim,
            hidden_dim=hidden_dim,
        )

        # The target must own its weights. Wrapping the same ``model``
        # instance a second time would share the encoder between both
        # branches, so freezing the target would freeze the online encoder too.
        self.target_network = copy.deepcopy(self.online_network)
        for parameter in self.target_network.parameters():
            parameter.requires_grad = False

        self.tau = target_decay

        # NOTE: calling ``.train()`` on the whole BYOL module later would put
        # the target network back into training mode as well.
        self.target_network.eval()
        self.loss_function = NormalizedMSELoss()

    @torch.no_grad()
    def soft_update_target_network(self) -> None:
        """Move the target weights towards the online weights.

        Each target parameter becomes ``tau * target + (1 - tau) * online``.
        Buffers (e.g. BatchNorm running statistics) are copied from the
        online network, because the target runs in eval mode and would
        otherwise keep the statistics it was created with.
        """
        online_and_target = zip(
            self.online_network.parameters(), self.target_network.parameters()
        )
        for parameter_o, parameter_t in online_and_target:
            parameter_t.mul_(self.tau).add_(parameter_o, alpha=1.0 - self.tau)

        for buffer_o, buffer_t in zip(self.online_network.buffers(), self.target_network.buffers()):
            buffer_t.copy_(buffer_o)

    def forward(self, view) -> Tuple[Tensor, Tensor]:
        """Return ``(online_projection, target_projection)`` of ``view``.

        The target projection is computed without gradient tracking.
        """
        online_projection = self.online_network(view)
        with torch.no_grad():
            target_projection = self.target_network(view)

        return online_projection, target_projection

    def loss(self, view1, view2):
        """Symmetric BYOL loss for two augmented views of the same batch.

        The online prediction of each view is compared with the target
        projection of the *other* view. The two per-sample losses are summed
        and the batch mean of that sum is returned as a scalar tensor.
        """
        online_projection_view1, target_projection_view1 = self.forward(view1)
        online_projection_view2, target_projection_view2 = self.forward(view2)

        online_prediction_view1 = self.online_predictor(online_projection_view1)
        online_prediction_view2 = self.online_predictor(online_projection_view2)

        # Cross-view pairing: prediction of view 1 vs. target of view 2 and
        # the other way round.
        loss = self.loss_function(online_prediction_view1, target_projection_view2)
        loss = loss + self.loss_function(online_prediction_view2, target_projection_view1)

        return torch.mean(loss)

# STL10 Datasets

We need 3 separate datasets from STL10 for this experiment:
1. `"train"` -- Contains only labeled training images. Used for supervised training.
2. `"train+unlabeled"` -- Contains the labeled training images plus a large number of unlabeled images. Used for self-supervised learning with BYOL.
3. `"test"` -- Labeled test images.  We use it both as a validation set, and for computing the final model accuracy.

In [None]:
from torchvision.datasets import STL10
from torchvision.transforms import ToTensor

# Labeled training split.
TRAIN_DATASET = STL10(
    root="data",
    split="train",
    transform=ToTensor(),
    download=True,
)
# Labeled training images together with the unlabeled images.
TRAIN_UNLABELED_DATASET = STL10(
    root="data",
    split="train+unlabeled",
    transform=ToTensor(),
    download=True,
)
# Labeled test split.
TEST_DATASET = STL10(
    root="data",
    split="test",
    transform=ToTensor(),
    download=True,
)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to data/stl10_binary.tar.gz


100%|██████████| 2640397119/2640397119 [01:26<00:00, 30380405.95it/s]


Extracting data/stl10_binary.tar.gz to data
Files already downloaded and verified
Files already downloaded and verified


Create dataloaders:

In [None]:
# your code
from torch.utils.data import DataLoader
# Batches of the labeled training split.
train_dataloader = DataLoader(
    TRAIN_DATASET,
    shuffle=True,
    batch_size=1024,
)
# Batches of the labeled + unlabeled split, loaded by two worker processes.
train_unlabeld_dataloader = DataLoader(
    TRAIN_UNLABELED_DATASET,
    shuffle=True,
    batch_size=1024,
    num_workers=2,
)
# Batches of the test split.
test_dataloader = DataLoader(
    TEST_DATASET,
    shuffle=True,
    batch_size=1024,
)

# Supervised Training without BYOL

First create a classifier model by simply adding a linear layer on top of the encoder model. Then train the model using the labeled training set. Performance should be pretty good already.

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Randomly initialised encoder with a 10-way linear head on its flattened
# 512-feature output.
encoder = get_encoder_model()
classifier = nn.Linear(in_features=512, out_features=10)
model = nn.Sequential(
    encoder,
    nn.Flatten(),
    classifier,
).to(DEVICE)

In [None]:
from tqdm import tqdm
epochs = 150
lr = 1e-2

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

for epoch in tqdm(range(epochs)):
    # Training mode: BatchNorm normalizes with batch statistics and updates
    # its running estimates. It has to be re-enabled after every evaluation.
    model.train()
    total_loss = 0.
    for batch, target in train_dataloader:
        batch = batch.to(DEVICE)
        target = target.to(DEVICE)

        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Sum of the per-batch mean losses over the epoch.
        total_loss += loss.item()

    if epoch % 20 == 0:
        # Evaluation mode: BatchNorm uses its running statistics, so test
        # batches neither influence each other nor leak into those statistics.
        model.eval()
        eval_loss = 0.
        correct = 0
        with torch.no_grad():
            for batch, target in test_dataloader:
                batch = batch.to(DEVICE)
                target = target.to(DEVICE)

                output = model(batch)
                # criterion returns the batch mean; weight it by the batch
                # size so that dividing by the dataset size below yields the
                # true per-sample mean.
                eval_loss += criterion(output, target).item() * batch.size(0)

                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        eval_loss /= len(test_dataloader.dataset)
        accuracy = correct / len(test_dataloader.dataset)
        print(f"Epoch {epoch}: Train loss = {total_loss:.4f}, Test loss = {eval_loss:.4f}, Test accuracy = {accuracy:.4f}")

  1%|          | 1/150 [00:17<44:17, 17.84s/it]

Epoch 0: Train loss = 17.9206, Test loss = 0.0027, Test accuracy = 0.1429


 14%|█▍        | 21/150 [02:12<15:45,  7.33s/it]

Epoch 20: Train loss = 6.1037, Test loss = 0.0013, Test accuracy = 0.4979


 27%|██▋       | 41/150 [04:08<13:03,  7.19s/it]

Epoch 40: Train loss = 1.0280, Test loss = 0.0020, Test accuracy = 0.5643


 41%|████      | 61/150 [06:06<10:48,  7.29s/it]

Epoch 60: Train loss = 0.1171, Test loss = 0.0029, Test accuracy = 0.5933


 54%|█████▍    | 81/150 [08:03<08:15,  7.19s/it]

Epoch 80: Train loss = 0.0041, Test loss = 0.0034, Test accuracy = 0.5897


 67%|██████▋   | 101/150 [10:00<05:51,  7.17s/it]

Epoch 100: Train loss = 0.0124, Test loss = 0.0035, Test accuracy = 0.5916


 81%|████████  | 121/150 [11:57<03:30,  7.27s/it]

Epoch 120: Train loss = 0.1291, Test loss = 0.0030, Test accuracy = 0.5797


 94%|█████████▍| 141/150 [13:53<01:04,  7.15s/it]

Epoch 140: Train loss = 0.0004, Test loss = 0.0033, Test accuracy = 0.5951


100%|██████████| 150/150 [14:43<00:00,  5.89s/it]


### Self-Supervised Training with BYOL

Now perform the self-supervised training. This is the most computationally intensive part of the script.

In [None]:
epochs = 20
lr = 1e-2

model_ = BYOL(get_encoder_model()).to(DEVICE)
optimizer = torch.optim.Adam(model_.parameters(), lr=lr)

augmenter = default_augmentation()
for epoch in tqdm(range(epochs)):
    total_loss = 0.
    # Labels are ignored: the self-supervised objective only needs the images.
    for data, _ in train_unlabeld_dataloader:
        data = data.to(DEVICE)
        # Two independent augmentation passes over the same batch.
        view1 = augmenter(data)
        view2 = augmenter(data)

        optimizer.zero_grad()

        # Both views must be passed; feeding view1 twice would drop the
        # second augmentation from the objective altogether.
        loss = model_.loss(view1, view2)

        loss.backward()
        optimizer.step()
        # Update the target network after every optimizer step.
        model_.soft_update_target_network()

        total_loss += loss.item()

    # No test metrics are reported here: the BYOL model has no classifier
    # head, and scoring the separately trained supervised `model` would only
    # repeat numbers that this loop cannot change.
    if epoch % 2 == 0:
        print(f"Epoch {epoch}: Train loss = {total_loss:.4f}")

  5%|▌         | 1/20 [06:03<1:55:08, 363.60s/it]

Epoch 0: Train loss = 138.6696, Test loss = 0.0032, Test accuracy = 0.5976


 15%|█▌        | 3/20 [18:11<1:43:04, 363.77s/it]

Epoch 2: Train loss = 3.7918, Test loss = 0.0033, Test accuracy = 0.5975


 25%|██▌       | 5/20 [30:18<1:30:56, 363.77s/it]

Epoch 4: Train loss = 0.2558, Test loss = 0.0033, Test accuracy = 0.5981


 35%|███▌      | 7/20 [42:25<1:18:44, 363.41s/it]

Epoch 6: Train loss = 0.1170, Test loss = 0.0033, Test accuracy = 0.5975


 45%|████▌     | 9/20 [54:30<1:06:32, 362.97s/it]

Epoch 8: Train loss = 0.0697, Test loss = 0.0033, Test accuracy = 0.5954


 55%|█████▌    | 11/20 [1:06:36<54:27, 363.01s/it]  

Epoch 10: Train loss = 0.0617, Test loss = 0.0033, Test accuracy = 0.5969


 65%|██████▌   | 13/20 [1:18:41<42:19, 362.74s/it]

Epoch 12: Train loss = 0.0471, Test loss = 0.0032, Test accuracy = 0.5995


 75%|███████▌  | 15/20 [1:30:46<30:13, 362.74s/it]

Epoch 14: Train loss = 0.0537, Test loss = 0.0033, Test accuracy = 0.5960


 85%|████████▌ | 17/20 [1:42:52<18:07, 362.66s/it]

Epoch 16: Train loss = 0.0647, Test loss = 0.0033, Test accuracy = 0.5971


 95%|█████████▌| 19/20 [1:54:58<06:02, 362.86s/it]

Epoch 18: Train loss = 0.0535, Test loss = 0.0033, Test accuracy = 0.5965


100%|██████████| 20/20 [2:01:01<00:00, 363.09s/it]


### Supervised Training Again

Extract the encoder network's state dictionary from BYOL, and load it into our ResNet18 model before starting training.  Then run supervised training, and watch the accuracy improve from last time!

In [None]:
# The third child of the BYOL module is its target network; slicing off the
# last entry removes the projection MLP and leaves the encoder wrapped in an
# nn.Sequential.
byol_res = copy.deepcopy(list(model_.children())[2][:-1])
byol_res_state = byol_res.state_dict()

# Because of that wrapper every key starts with "0."; drop these two
# characters so the keys match a bare encoder.
byol_res_state_edited = {
    name[2:]: tensor for name, tensor in byol_res_state.items()
}

encoder_ = get_encoder_model()
encoder_.load_state_dict(byol_res_state_edited)

<All keys matched successfully>

In [None]:
classifier = nn.Linear(512, 10)
model_last = nn.Sequential(encoder_, torch.nn.Flatten(), classifier).to(DEVICE)
epochs = 150
lr = 1e-2

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_last.parameters(), lr=lr)

for epoch in tqdm(range(epochs)):
    # Training loop. Training mode is re-enabled every epoch because the
    # evaluation below switches the model to eval mode.
    model_last.train()
    total_loss = 0.
    for batch, target in train_dataloader:
        batch = batch.to(DEVICE)
        target = target.to(DEVICE)

        optimizer.zero_grad()
        output = model_last(batch)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Sum of the per-batch mean losses over the epoch.
        total_loss += loss.item()

    # Evaluation loop. It runs only on the epochs whose metrics are printed,
    # so no test pass is computed and then thrown away.
    if epoch % 20 == 0:
        # Eval mode: BatchNorm uses its running statistics, so test batches
        # neither influence each other nor leak into those statistics.
        model_last.eval()
        eval_loss = 0.
        correct = 0
        with torch.no_grad():
            for batch, target in test_dataloader:
                batch = batch.to(DEVICE)
                target = target.to(DEVICE)

                output = model_last(batch)
                # criterion returns the batch mean; weight it by the batch
                # size so that dividing by the dataset size below yields the
                # true per-sample mean.
                eval_loss += criterion(output, target).item() * batch.size(0)

                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        eval_loss /= len(test_dataloader.dataset)
        accuracy = correct / len(test_dataloader.dataset)

        # Print training and evaluation metrics
        print(f"Epoch {epoch}: Train loss = {total_loss:.4f}, Test loss = {eval_loss:.4f}, Test accuracy = {accuracy:.4f}")

  1%|          | 1/150 [00:11<27:47, 11.19s/it]

Epoch 0: Train loss = 18.3458, Test loss = 0.0025, Test accuracy = 0.1000


 14%|█▍        | 21/150 [03:51<23:32, 10.95s/it]

Epoch 20: Train loss = 5.5304, Test loss = 0.0014, Test accuracy = 0.4998


 27%|██▋       | 41/150 [07:33<20:13, 11.13s/it]

Epoch 40: Train loss = 0.6332, Test loss = 0.0023, Test accuracy = 0.5539


 41%|████      | 61/150 [11:14<16:28, 11.11s/it]

Epoch 60: Train loss = 0.0309, Test loss = 0.0031, Test accuracy = 0.5697


 54%|█████▍    | 81/150 [14:56<12:44, 11.08s/it]

Epoch 80: Train loss = 0.0004, Test loss = 0.0031, Test accuracy = 0.5863


 67%|██████▋   | 101/150 [18:37<09:01, 11.05s/it]

Epoch 100: Train loss = 0.0002, Test loss = 0.0032, Test accuracy = 0.5885


 81%|████████  | 121/150 [22:17<05:19, 11.00s/it]

Epoch 120: Train loss = 0.0002, Test loss = 0.0032, Test accuracy = 0.5893


 94%|█████████▍| 141/150 [25:57<01:38, 10.92s/it]

Epoch 140: Train loss = 0.0001, Test loss = 0.0033, Test accuracy = 0.5859


100%|██████████| 150/150 [27:36<00:00, 11.04s/it]
