In [1]:
import os
import sys
import numpy as np

import torch
import torchvision
import torch.nn as nn
from tqdm.notebook import tqdm
from torch import allclose
from datetime import datetime
import torchvision.transforms as T
from torch.testing import assert_allclose
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.models as models

import kornia
from kornia import augmentation as K
import kornia.augmentation.functional as F
import kornia.augmentation.random_generator as rg
from torchvision.transforms import functional as tvF

In [2]:
uid = 'byol'
dataset_name = 'stl10'
data_dir = 'dataset'
ckpt_dir = "./ckpt/"+str(datetime.now().strftime('%m%d%H%M%S'))
log_dir = "runs/"+str(datetime.now().strftime('%m%d%H%M%S'))

if not os.path.exists(data_dir):
    os.makedirs(data_dir)
    
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

if not os.path.exists(log_dir):
    os.makedirs(log_dir)
    

In [3]:
# transformations

_MEAN =  [0.5, 0.5, 0.5]
_STD  =  [0.2, 0.2, 0.2]



class InitalTransformation():
    def __init__(self):
        self.transform = T.Compose([
            T.ToTensor(),
            transforms.Normalize(_MEAN,_STD),
        ])

    def __call__(self, x):
        x = self.transform(x)
        return  x


def gpu_transformer(image_size,s=.2):
        
    train_transform = nn.Sequential(

                kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
                kornia.augmentation.RandomHorizontalFlip(p=0.5),
                kornia.augmentation.ColorJitter(0.8*s,0.8*s,0.8*s,0.2*s,p=0.3),
                kornia.augmentation.RandomGrayscale(p=0.05),)

    test_transform = nn.Sequential(  
                kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
                kornia.augmentation.RandomHorizontalFlip(p=0.5),
                kornia.augmentation.ColorJitter(0.8*s,0.8*s,0.8*s,0.2*s,p=0.3),
                kornia.augmentation.RandomGrayscale(p=0.05),)

    return train_transform , test_transform
                
def get_clf_train_test_transform(image_size,s=.2):
        
    train_transform = nn.Sequential(
                
                kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
                kornia.augmentation.RandomHorizontalFlip(p=0.5),
#                 kornia.augmentation.Normalize(CIFAR_MEAN_,CIFAR_STD_),
            )

    test_transform = nn.Sequential(  
                kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
                kornia.augmentation.RandomHorizontalFlip(p=0.5),
                # kornia.augmentation.RandomGrayscale(p=0.05),
                # kornia.augmentation.Normalize(CIFAR_MEAN_,CIFAR_STD_)
        )

    return train_transform , test_transform

In [4]:
def get_train_test_dataloaders(dataset = "cifar10", data_dir="./dataset", batch_size = 64,num_workers = 4, download=True): 
    
    train_loader = torch.utils.data.DataLoader(
        dataset = torchvision.datasets.CIFAR10(data_dir, train=True, transform=InitalTransformation(), download=download),
        shuffle=True,
        batch_size= batch_size,
        num_workers = num_workers
    )
    

    test_loader = torch.utils.data.DataLoader(
        dataset = torchvision.datasets.CIFAR10(data_dir, train=False, transform=InitalTransformation(), download=download),
        shuffle=True,
        batch_size= batch_size,
        num_workers = num_workers
        )
    return train_loader, test_loader

In [5]:
import copy
from torch import nn
import torchvision.models as models

def loss_fn(q1,q2, z1t,z2t):
    
    l1 = - F.cosine_similarity(q1, z1t.detach(), dim=-1).mean()
    l2 = - F.cosine_similarity(q2, z2t.detach(), dim=-1).mean()
    
    return (l1+l2)/2


class MLPHead(nn.Module):
    def __init__(self, in_dim, hidden_size=4096, projection_size=256):
        super(MLPHead, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size)
        )

    def forward(self, x):
        return self.net(x)
    


class BYOL(nn.Module):
    def __init__(self, backbone=None,**kwargs):
        super().__init__()
        if backbone is None:
            backbone = models.resnet50(pretrained=False)
        encoder = torch.nn.Sequential(*list(backbone.children())[:-1])
        projector = MLPHead(in_dim=backbone.fc.in_features)
        
        self.online_encoder = nn.Sequential(
            encoder,
            projector)
        
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.online_predictor = MLPHead(in_dim=256,hidden_size=1024, projection_size=256)
        


    @torch.no_grad()
    def update_moving_average(self, global_step, max_steps):
        
        tau = 1 - base_ema * (cos(pi*global_step/max_steps)+1)/2 
        
        for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data     
    
    def forward(self,x1,x2):
        
        z1 = self.online_encoder(x1)
        z2 = self.online_encoder(x2)
        
        q1 = self.online_predictor(z1)
        q2 = self.online_predictor(z2)
        
        with torch.no_grad():
            z1_t = self.target_encoder(x1)
            z2_t = self.target_encoder(x2)
       
        loss = loss_fn(q1, q2, z1_t, z2_t)
        
        return {"loss": loss}


In [6]:
if torch.cuda.is_available():
    dtype = torch.cuda.FloatTensor
    device = torch.device("cuda")
    # torch.cuda.set_device(device_id)
else:
    dtype = torch.FloatTensor
    device = torch.device("cpu")
    
print(device)


cuda


In [10]:
weight_decay = 1.5e-6
warmup_epochs =  10
warmup_lr = 0
momentum = 0.9
lr =  0.3
final_lr =  0
epochs = 8
stop_at_epoch = 100
batch_size = 64
knn_monitor = False
knn_interval = 5
knn_k = 200
image_size = (32,32)

In [11]:
train_loader, test_loader = get_train_test_dataloaders(batch_size=batch_size)
train_transform,test_transform = gpu_transformer(image_size)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
from lr_scheduler import LR_Scheduler
from lars import LARS

loss_ls = []
acc_ls = []

model = BYOL().to(device)


optimizer = LARS(model.named_modules(), lr=lr*batch_size/256, momentum=momentum, weight_decay=weight_decay)

scheduler = LR_Scheduler(optimizer, warmup_epochs, warmup_lr*batch_size/256,
                         epochs, lr*batch_size/256, final_lr*batch_size/256, 
                        len(train_loader),
                        constant_predictor_lr=True
                        )


min_loss = np.inf 
accuracy = 0

# start training 
global_progress = tqdm(range(0, epochs), desc=f'Training')
data_dict = {"loss": 100}
for epoch in global_progress:
    model.train()   
    local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')
    
    for idx, (image, label) in enumerate(local_progress):
        image = image.to(device)
        aug_image = train_transform(image)
        model.zero_grad()
        model_loss = model.forward(image.to(device, non_blocking=True), aug_image.to(device, non_blocking=True))

        data_dict['loss'] = loss.item() 
        
        loss_ls.append(loss.item())
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        data_dict.update({'lr': scheduler.get_last_lr()})
        local_progress.set_postfix(data_dict)
        logger.update_scalers(data_dict)
    
    current_loss = data_dict['loss']
    
#     if epoch % knn_interval == 0: 
#         accuracy = knn_monitor(model.backbone, memory_loader, test_loader, 'gpu', hide_progress=True) 
#         data_dict['accuracy'] = accuracy
#         acc_ls.append(accuracy)

    global_progress.set_postfix(data_dict)
    logger.update_scalers(data_dict)
    
    model_path = os.path.join(ckpt_dir, f"{uid}_{datetime.now().strftime('%m%d%H%M%S')}.pth")

    if min_loss > current_loss:
        min_loss = current_loss
        
        torch.save({
        'epoch':epoch+1,
        'state_dict': model.state_dict() }, model_path)
        print(f'Model saved at: {model_path}')

Training:   0%|          | 0/8 [00:00<?, ?it/s]

Epoch 0/8:   0%|          | 0/782 [00:00<?, ?it/s]

RuntimeError: mat1 dim 1 must match mat2 dim 0