# Bootstrap Your Own Latent (BYOL)
## By: Karim Zakir
BYOL is a representation learning method. The idea behind it is that two different augmentations of the same image share the same "essence" and thus should have the same or similar embeddings. The method was originally published in [this paper](https://arxiv.org/abs/2006.07733v3) in June 2020. 

Contrastive learning is a deep learning technique in which we try to learn representations of different objects with the goal that embeddings of similar objects will also be similar. Most contrastive learning architectures/methods require "negative pairs", which make contrastive learning computationally expensive. Unlike those methods, BYOL does not require negative pairs, so it's a lot more efficient!
 
In BYOL, we start off by taking an image and applying two different augmentations to it ($t$ and $t'$). The two resulting views are fed to two different networks, called the online network and the target network. Let's first focus on the online network. After applying augmentation $t$ to a given image, we feed the result into the first part of the online network: an encoder, which can be any network that transforms an image into features.

In [80]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from copy import deepcopy
import math
from tqdm import tqdm

In [3]:
def online_augmentation():
    """Build the augmentation pipeline used for the first view (``t``).

    The pipeline is, in order: a random resized crop to 224x224 with bicubic
    interpolation, a random horizontal flip, colour jitter applied with
    probability 0.8, grayscale conversion (3 output channels) applied with
    probability 0.2, and a 23x23 Gaussian blur that is always applied.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    color_jitter = transforms.ColorJitter(
        brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
    )
    to_gray = transforms.Grayscale(num_output_channels=3)

    steps = [
        transforms.RandomResizedCrop(size=(224, 224), interpolation=bicubic),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomApply([to_gray], p=0.2),
        # Unlike the other view, the blur here is unconditional.
        transforms.GaussianBlur(kernel_size=(23, 23)),
    ]
    return transforms.Compose(steps)


def target_augmentation():
    """Build the augmentation pipeline used for the second view (``t'``).

    The pipeline is, in order: a random resized crop to 224x224 with bicubic
    interpolation, a random horizontal flip, colour jitter applied with
    probability 0.8, grayscale conversion (3 output channels) applied with
    probability 0.2, a 23x23 Gaussian blur applied with probability 0.1, and
    solarization (threshold 0.5) applied with probability 0.2.

    Returns:
        A ``transforms.Compose`` wrapping the steps above.
    """
    bicubic = transforms.InterpolationMode.BICUBIC
    color_jitter = transforms.ColorJitter(
        brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
    )
    to_gray = transforms.Grayscale(num_output_channels=3)
    blur = transforms.GaussianBlur(kernel_size=(23, 23))

    steps = [
        transforms.RandomResizedCrop(size=(224, 224), interpolation=bicubic),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomApply([to_gray], p=0.2),
        # Blur is only occasional for this view.
        transforms.RandomApply([blur], p=0.1),
        transforms.RandomSolarize(threshold=0.5, p=0.2),
    ]
    return transforms.Compose(steps)

In [111]:
class BYOL(nn.Module):
    """Bootstrap Your Own Latent (BYOL) model together with its training loop.

    Two networks are held:

    * ``online_network`` = encoder -> projector -> predictor, trained by
      gradient descent with Adam.
    * ``target_network`` = a deep copy of encoder -> projector, never touched
      by the optimizer; its weights follow the online weights through an
      exponential moving average (EMA) controlled by ``tau``.

    Args:
        online_augmentation: callable applied to an input batch to produce
            the first view.
        target_augmentation: callable applied to an input batch to produce
            the second view.
        encoder_model: backbone whose ``fc`` attribute is replaced by
            ``nn.Identity`` so that it outputs features. The feature width is
            read from ``encoder_model.fc.in_features`` when available and
            falls back to 2048 (the ResNet-50 width) otherwise.

    NOTE(review): the augmentations are called on the whole batch tensor at
    once; whether each sample gets independent random parameters depends on
    the callables passed in -- confirm this is the intended behaviour.
    """

    def __init__(self, online_augmentation, target_augmentation, encoder_model):

        super().__init__()

        self.online_augmentation = online_augmentation
        self.target_augmentation = target_augmentation

        # Read the feature width *before* the classification head is replaced,
        # so encoders other than ResNet-50 also work.
        feature_dim = getattr(
            getattr(encoder_model, "fc", None), "in_features", 2048
        )

        self.encoder = encoder_model
        self.encoder.fc = nn.Identity()

        self.projector = nn.Sequential(
            nn.Linear(feature_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 256)
        )

        self.predictor = nn.Sequential(
            nn.Linear(256, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 256)
        )

        self.online_network = nn.Sequential(
            self.encoder,
            self.projector,
            self.predictor
        )

        # Independent copy: from here on it only changes through the EMA update.
        self.target_network = deepcopy(nn.Sequential(
            self.encoder,
            self.projector
        ))
        for target_param in self.target_network.parameters():
            target_param.requires_grad_(False)

        # Outputs are L2-normalised before the loss, so the squared error is
        # proportional to (2 - 2 * cosine similarity).
        self.loss_fn = nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.online_network.parameters())
        self.tau_base = 0.996
        self.tau = self.tau_base

    def fit(self, train_loader, val_loader, epochs=1000, verbose=True):
        """Train for ``epochs`` epochs and record the loss after each one.

        Args:
            train_loader: iterable of ``(inputs, labels)`` batches used for
                training; labels are ignored.
            val_loader: iterable of ``(inputs, labels)`` batches used only for
                evaluation; labels are ignored.
            epochs: number of passes over ``train_loader``.
            verbose: when true, print the latest losses after every epoch.

        Returns:
            ``(train_loss, val_loss)``: two lists with one mean loss per epoch.
        """
        train_loss = []
        val_loss = []

        for epoch in tqdm(range(epochs)):

            self.train(train_loader)

            # Cosine schedule from tau_base up to 1. This must be an
            # assignment: accumulating the value pushes tau above 1.
            self.tau = 1 - (1 - self.tau_base) * (
                math.cos(math.pi * (epoch + 1) / epochs) + 1
            ) / 2

            train_loss.append(self.validate(train_loader))
            val_loss.append(self.validate(val_loader))
            if verbose:
                print(f"Epoch {epoch}")
                print(f"Train Loss: {train_loss[-1]}")
                print(f"Validation Loss: {val_loss[-1]}")

        return train_loss, val_loss

    def train(self, train_loader=True):
        """Run one training epoch, or switch train/eval mode.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` relies
        on. To keep both uses working, a ``bool`` argument is forwarded to
        ``nn.Module.train``; anything else is treated as a data loader.

        Args:
            train_loader: iterable of ``(inputs, labels)`` batches, or a
                ``bool`` selecting training (``True``) / evaluation
                (``False``) mode.

        Returns:
            ``self`` when called with a ``bool`` (as ``nn.Module.train``
            does); ``None`` after running an epoch.
        """
        if isinstance(train_loader, bool):
            return super().train(train_loader)

        self.online_network.train()
        self.target_network.train()

        for batch_X, _ in train_loader:
            self.optimizer.zero_grad()

            batch_loss = self.forward(batch_X)

            batch_loss.backward()
            self.optimizer.step()

            self._update_target()

        # Leave both networks in eval mode for the evaluation that follows.
        self.online_network.eval()
        self.target_network.eval()

    @torch.no_grad()
    def _update_target(self):
        """EMA step: target <- tau * target + (1 - tau) * online, in place."""
        # online_network lists encoder and projector parameters first, so zip
        # pairs them with the target's and stops before the predictor's.
        for online_param, target_param in zip(
            self.online_network.parameters(), self.target_network.parameters()
        ):
            target_param.mul_(self.tau).add_(online_param, alpha=1 - self.tau)

    def forward(self, batch):
        """Return the symmetrised BYOL loss for one batch of inputs.

        Each view goes through the online network and is compared with the
        target network's output for the *other* view; the two terms are summed.

        Args:
            batch: input batch passed to both augmentation callables.

        Returns:
            Scalar loss tensor.
        """
        view = self.online_augmentation(batch)
        view_prime = self.target_augmentation(batch)

        online_output = F.normalize(self.online_network(view))
        online_output_prime = F.normalize(self.online_network(view_prime))

        # No gradient flows into the target network.
        with torch.no_grad():
            target_output = F.normalize(self.target_network(view_prime))
            target_output_prime = F.normalize(self.target_network(view))

        loss = (
            self.loss_fn(online_output, target_output)
            + self.loss_fn(online_output_prime, target_output_prime)
        )
        return loss

    def validate(self, loader):
        """Return the mean per-batch loss over ``loader`` without gradients.

        Args:
            loader: iterable of ``(inputs, labels)`` batches; labels are
                ignored.

        Returns:
            Average of the per-batch losses as a ``float``, or ``nan`` when
            ``loader`` yields no batches.
        """
        total = 0.0
        batches = 0

        with torch.no_grad():
            for batch_X, _ in loader:
                # forward() already averages within the batch, so the epoch
                # figure is averaged over batches.
                total += self.forward(batch_X).item()
                batches += 1

        if batches == 0:
            return float("nan")
        return total / batches

In [112]:
# Instantiate BYOL with the two augmentation pipelines and a ResNet-50 backbone.
byol = BYOL(
    online_augmentation=online_augmentation(),
    target_augmentation=target_augmentation(),
    encoder_model=torchvision.models.resnet50(),
)

In [98]:
# Load the three STL10 splits (downloaded into the current directory if
# missing); every image is converted to a tensor.
train_labeled, train_unlabeled, test_labeled = (
    torchvision.datasets.STL10(
        ".", split=split_name, download=True, transform=transforms.ToTensor()
    )
    for split_name in ("train", "unlabeled", "test")
)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [77]:
# Wrap each dataset in a shuffling DataLoader that shares one batch size.
BATCH_SIZE = 128

train_ll = DataLoader(train_labeled, batch_size=BATCH_SIZE, shuffle=True)
train_ul = DataLoader(train_unlabeled, batch_size=BATCH_SIZE, shuffle=True)
# NOTE: this rebinds `test_labeled` from the dataset to its loader.
test_labeled = DataLoader(test_labeled, batch_size=BATCH_SIZE, shuffle=True)

In [113]:
# Train on the unlabeled split and track the loss on the labeled training split.
train_loss, val_loss = byol.fit(
    train_loader=train_ul,
    val_loader=train_ll,
    epochs=1000,
)

  0%|          | 0/1000 [00:35<?, ?it/s]


RuntimeError: [enforce fail at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\c10\core\impl\alloc_cpu.cpp:72] data. DefaultCPUAllocator: not enough memory: you tried to allocate 102760448 bytes.

In [107]:
# Display the online network's parameters; its encoder entries are keyed with
# the prefix "0." because the encoder is the first module in the Sequential.
byol.online_network.state_dict()

OrderedDict([('0.conv1.weight',
              tensor([[[[-0.0061, -0.0107, -0.0030,  ...,  0.0159,  0.0226, -0.0064],
                        [-0.0075,  0.0045,  0.0272,  ...,  0.0459, -0.0291, -0.0284],
                        [-0.0090, -0.0009, -0.0738,  ..., -0.0022,  0.0124,  0.0108],
                        ...,
                        [-0.0042,  0.0178, -0.0204,  ...,  0.0200, -0.0147, -0.0097],
                        [-0.0258,  0.0106, -0.0335,  ..., -0.0050,  0.0155,  0.0050],
                        [ 0.0214, -0.0898,  0.0110,  ...,  0.0368,  0.0060, -0.0151]],
              
                       [[ 0.0228, -0.0120,  0.0484,  ..., -0.0135,  0.0192, -0.0296],
                        [ 0.0007, -0.0210, -0.0023,  ...,  0.0173, -0.0153, -0.0153],
                        [ 0.0313, -0.0180, -0.0275,  ..., -0.0022, -0.0177, -0.0668],
                        ...,
                        [-0.0066, -0.0039,  0.0306,  ..., -0.0219,  0.0544,  0.0458],
                        [-0.0117, 

In [108]:
# Display the encoder's own parameters, for comparison with the online
# network's "0.*" entries.
byol.encoder.state_dict()

OrderedDict([('conv1.weight',
              tensor([[[[-0.0061, -0.0107, -0.0030,  ...,  0.0159,  0.0226, -0.0064],
                        [-0.0075,  0.0045,  0.0272,  ...,  0.0459, -0.0291, -0.0284],
                        [-0.0090, -0.0009, -0.0738,  ..., -0.0022,  0.0124,  0.0108],
                        ...,
                        [-0.0042,  0.0178, -0.0204,  ...,  0.0200, -0.0147, -0.0097],
                        [-0.0258,  0.0106, -0.0335,  ..., -0.0050,  0.0155,  0.0050],
                        [ 0.0214, -0.0898,  0.0110,  ...,  0.0368,  0.0060, -0.0151]],
              
                       [[ 0.0228, -0.0120,  0.0484,  ..., -0.0135,  0.0192, -0.0296],
                        [ 0.0007, -0.0210, -0.0023,  ...,  0.0173, -0.0153, -0.0153],
                        [ 0.0313, -0.0180, -0.0275,  ..., -0.0022, -0.0177, -0.0668],
                        ...,
                        [-0.0066, -0.0039,  0.0306,  ..., -0.0219,  0.0544,  0.0458],
                        [-0.0117, -0