# Bootstrap Your Own Latent (BYOL) Tutorial

- 실습조교: 배진수(wlstn215@korea.ac.kr), 김정인(jungin_kim23@korea.ac.kr), 정진용(jy_jeong@korea.ac.kr)

In [None]:
''' github+colab 교육생분들 '''
# For trainees working with GitHub + Colab: uncomment the next line to clone
# the course repository into the Colab runtime.
# !git clone https://github.com/bogus215/LG-EDUCATION4.git

# Build an IPython image object from ./byol.png; as the last expression of a
# notebook cell it is rendered inline.
from IPython.display import Image
Image('./byol.png')

# Colab gpu 연결

## 런타임 -> 런타임유형 변경 -> 하드웨어 가속기 -> GPU

In [None]:
import torch
# Reports whether PyTorch can use a CUDA device. The hyperparameter cell below
# hard-codes DEVICE = 'cuda', so this should evaluate to True before continuing.
torch.cuda.is_available()

# 필요한 패키지 불러오기

In [None]:
import os
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models

from torch.utils.data import Dataset
from torchvision.transforms import transforms
from torchvision import transforms, datasets

import warnings
warnings.filterwarnings(action='ignore')

# 학습에 필요한 하이퍼파라미터 정의

In [None]:
# Key into ResNet.resnet_dict selecting the backbone ('resnet18' or 'resnet50').
ARCHITECTURE = 'resnet18'
# Hidden width of the projection MLP built inside ResNet.
MLP_HIDDEN_SIZE = 512
# Output width of the projection MLP; also used for every layer width of the predictor.
PROJECTION_SIZE = 128
# Passed as num_workers to every DataLoader.
NUM_WORKER = 0
# Number of epochs for both BYOL pretraining and the supervised stage
# (rebound to 10 in the t-SNE cell further down).
EPOCHS = 5
# Batch size for every DataLoader.
BATCH_SIZE = 256
# Adam learning rate and weight decay (pretraining and supervised stage).
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 1e-4
# Seed applied to NumPy and PyTorch (CPU and all CUDA devices) at the start of each stage.
SEED = 123
# Output size of the linear classifier in the supervised stage.
NUM_CLASS = 10
# The pretraining loop prints the current loss every this many iterations.
LOG_EVERY_N_STEPS = 20
# Weight kept from the old target parameters in the momentum update;
# (1 - MOMENTUM) is taken from the online parameters.
MOMENTUM = 0.996
# Device index used with torch.cuda.device(...).
GPU_INDEX = 0
# Device string passed to .to(...) for models and batches.
DEVICE = 'cuda'

# CIFAR10 Dataset 정의 (+Augmentation)

In [None]:
class CIFAR10Dataset(Dataset):
    """
    Small factory around ``torchvision.datasets.CIFAR10``.

    The instance only remembers a data directory and a mode string; the
    actual datasets are created on demand by ``get_pretrain_dataset`` (two
    augmented views per image) and ``get_dataset`` (tensor conversion plus
    normalisation only).
    """

    def __init__(self, 
                 data_dir: str = './data',
                 mode: str = 'train'):
        self.mode = mode
        self.data_dir = data_dir

    @staticmethod
    def get_transform(size: int = None,s: int = 1):
        """
        Build the random augmentation pipeline applied to each pretraining view.

        Args:
            size: side length handed to ``RandomResizedCrop``; the blur kernel
                size is derived from it as ``int(0.1 * size)``.
            s: multiplier applied to all four colour-jitter strengths.

        Returns:
            A ``transforms.Compose`` that maps a PIL image to a normalised tensor.
        """
        # Random change of brightness, contrast, saturation and hue.
        jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        to_normalised = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                             std=(0.2023, 0.1994, 0.2010))

        pipeline = [
            transforms.RandomResizedCrop(size=size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomApply([jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(kernel_size=int(0.1 * size)),
            transforms.ToTensor(),
            to_normalised,
        ]
        return transforms.Compose(pipeline)

    # For BYOL
    def get_pretrain_dataset(self):
        """
        Return the CIFAR-10 training split where each item is a pair of
        independently augmented views.

        Raises:
            NotImplementedError: if ``mode`` is anything other than 'train'.
        """
        if self.mode != 'train':
            raise NotImplementedError()

        two_views = MultiViewGenerator(base_transforms=self.get_transform(size=32))
        return datasets.CIFAR10(self.data_dir,
                                train=True,
                                transform=two_views,
                                download=True)

    # For Supervised Setting (Transfer Learning)
    def get_dataset(self):
        """
        Return the CIFAR-10 split selected by ``mode`` with tensor conversion
        and normalisation only (no random augmentation).

        'train' selects the training split and 'test' the test split.

        Raises:
            NotImplementedError: for any other mode.
        """
        if self.mode not in ('train', 'test'):
            raise NotImplementedError()

        plain = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                                 std=(0.2023, 0.1994, 0.2010)),
        ])
        return datasets.CIFAR10(self.data_dir,
                                train=(self.mode == 'train'),
                                transform=plain,
                                download=True)
            

class GaussianBlur(object):
    """
    Blur a single 3-channel image on CPU with a separable Gaussian kernel.

    Two depthwise (``groups=3``) convolutions, one with a ``(k, 1)`` kernel
    and one with a ``(1, k)`` kernel, are applied after reflection padding.
    Their weights are overwritten on every call with a Gaussian whose sigma
    is drawn uniformly from [0.1, 2.0).
    """

    def __init__(self, kernel_size: int = None):
        half_width = kernel_size // 2
        odd_size = 2 * half_width + 1  # force an odd width so the kernel has a centre tap

        def depthwise(shape):
            # One independent single-channel filter per colour channel.
            return nn.Conv2d(in_channels=3,
                             out_channels=3,
                             kernel_size=shape,
                             stride=1,
                             padding=0,
                             bias=False,
                             groups=3)

        self.blur_h = depthwise((odd_size, 1))
        self.blur_v = depthwise((1, odd_size))
        self.k = odd_size
        self.r = half_width

        # Padding by the half width keeps the spatial size unchanged.
        self.blur = nn.Sequential(nn.ReflectionPad2d(half_width),
                                  self.blur_h,
                                  self.blur_v)

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """Return a blurred copy of ``img`` (converted to a tensor and back)."""
        batch = self.pil_to_tensor(img).unsqueeze(0)

        # Sample the blur strength and build the normalised 1-D Gaussian taps.
        sigma = np.random.uniform(0.1, 2.0)
        offsets = np.arange(-self.r, self.r + 1)
        taps = np.exp(-np.power(offsets, 2) / (2 * sigma * sigma))
        taps = taps / taps.sum()
        per_channel = torch.from_numpy(taps).view(1, -1).repeat(3, 1)

        # Load the same taps into both 1-D convolutions.
        self.blur_h.weight.data.copy_(per_channel.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(per_channel.view(3, 1, 1, self.k))

        with torch.no_grad():
            blurred = self.blur(batch).squeeze()

        return self.tensor_to_pil(blurred)
    
    
class MultiViewGenerator(object):
    """
    Wrap a transform so that each input yields two outputs.

    The wrapped transform is invoked twice on the same input; with a random
    transform this produces two different views of one image.
    """

    def __init__(self, base_transforms):
        self.transforms = base_transforms

    def __call__(self, x):
        """Return ``[view_1, view_2]``, each obtained by one call of the transform."""
        return [self.transforms(x) for _ in range(2)]

# 학습 네트워크 정의

In [None]:
class ResNet(nn.Module):
    """
    ResNet backbone followed by an MLP projection head.

    ``encoder`` holds every child module of the chosen torchvision ResNet
    except the last one (its final ``fc`` layer, whose ``in_features`` sets
    the projection input width). ``projection`` maps the flattened encoder
    output to ``PROJECTION_SIZE`` dimensions.

    Args:
        base_model: key of ``resnet_dict`` ('resnet18' or 'resnet50').
    """

    def __init__(self, base_model: str = None):
        super(ResNet, self).__init__()

        # Store constructors rather than instances: only the requested backbone
        # is built, so the weights of the unused architecture are neither
        # downloaded nor kept in memory.
        self.resnet_dict = {
            'resnet18': models.resnet18,
            'resnet50': models.resnet50,
        }

        resnet = self.get_basemodel(base_model)

        # Define CNN encoder (everything except the last child module)
        self.encoder = nn.Sequential(
            *list(resnet.children())[:-1])

        # Define MLP Projection
        self.projection = MLPHead(in_channels=resnet.fc.in_features,
                                  mlp_hidden_size=MLP_HIDDEN_SIZE,
                                  projection_size=PROJECTION_SIZE)

    def forward(self, x: torch.Tensor):
        """Encode ``x``, flatten everything but the batch axis, and project."""
        h = self.encoder(x)
        h = h.view(h.shape[0], -1)

        return self.projection(h)

    def get_basemodel(self, model_name: str = None):
        """
        Build and return the torchvision backbone registered under ``model_name``.

        The model is created with ``pretrained=True``. Each call builds a new
        instance.

        Raises:
            KeyError: if ``model_name`` is not a key of ``resnet_dict``.
        """
        try:
            build = self.resnet_dict[model_name]
        except KeyError:
            raise KeyError(
                f'Unsupported base_model {model_name!r}; '
                f'expected one of {sorted(self.resnet_dict)}') from None

        return build(pretrained=True)
    

class MLPHead(nn.Module):
    """
    Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_channels: width of the input features.
        mlp_hidden_size: width of the hidden layer.
        projection_size: width of the output.
    """

    def __init__(self, 
                 in_channels: int,
                 mlp_hidden_size: int,
                 projection_size):
        super().__init__()

        layers = [
            nn.Linear(in_channels, mlp_hidden_size),
            nn.BatchNorm1d(mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` and return the result."""
        projected = self.mlp(x)
        return projected

# 평가 지표 및 BYOL 학습 정의

In [None]:
def accuracy(output: torch.FloatTensor, 
             target: torch.LongTensor, 
             topk: tuple = (1, )):
    """
    Top-k accuracy in percent for each ``k`` in ``topk``.

    Args:
        output: per-class scores; rows are indexed by sample, columns by class.
        target: ground-truth class index per sample.
        topk: the values of k to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)

        # Indices of the largest_k highest scores per sample, best first.
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        ranked = ranked.t()  # one row per rank, one column per sample
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent_per_hit = 100.0 / num_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]
    

class BYOL(object):
    """
    BYOL trainer with an online network, a momentum target network and a predictor.

    Expected keyword arguments:
        model: the online network; it is moved to ``DEVICE``.
        predictor: module applied on top of the online network output.
        optimizer: optimizer that already holds the online and predictor parameters.

    The target network is a separate deep copy of ``model``. It is never
    updated by gradients, only by ``update_target_network_parameters``.
    """

    def __init__(self, **kwargs):
        import copy  # local import keeps this cell self-contained

        self.online_network = kwargs['model'].to(DEVICE)
        # The target must be its own object. Reusing the same module would make
        # initialize_target_network() freeze the online parameters as well and
        # turn the momentum update into a no-op.
        self.target_network = copy.deepcopy(self.online_network)

        # Only Online Network
        self.predictor = kwargs['predictor'].to(DEVICE)
        self.optimizer = kwargs['optimizer']

    @torch.no_grad()
    def update_target_network_parameters(self):
        """
        Momentum update of the target network:
        ``target = MOMENTUM * target + (1 - MOMENTUM) * online``.
        """
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data = param_k.data * MOMENTUM + param_q.data * (1. - MOMENTUM)

    @staticmethod
    def byolloss(x: torch.FloatTensor, y: torch.FloatTensor):
        """
        Per-sample loss ``2 - 2 * cos(x, y)`` computed on L2-normalised rows.

        Returns one value per row, i.e. no reduction over the batch.
        """
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)

        return 2 - 2 * (x * y).sum(dim=-1)

    def initialize_target_network(self):
        """
        Copy the online parameters into the target network and exclude the
        target parameters from gradient computation.
        """
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.copy_(param_q.data)  # Initialize
            param_k.requires_grad = False  # Not Update by Gradient

    def update(self, 
               batch_view_1: torch.FloatTensor, 
               batch_view_2: torch.FloatTensor):
        """
        Compute the symmetric BYOL loss for one batch of paired views.

        The online prediction of each view is regressed onto the target
        projection of the *other* view; the two terms are summed and averaged
        over the batch.

        Returns:
            A scalar tensor suitable for ``backward()``.
        """
        # Compute Online Feature
        predictions_from_view_1 = self.predictor(self.online_network(batch_view_1))
        predictions_from_view_2 = self.predictor(self.online_network(batch_view_2))

        # Compute Target Feature (no gradient flows into the target network)
        with torch.no_grad():
            targets_from_view_1 = self.target_network(batch_view_1)
            targets_from_view_2 = self.target_network(batch_view_2)

        # Cross-view pairing: prediction of view 1 vs. target of view 2, and vice versa.
        loss = self.byolloss(predictions_from_view_1, targets_from_view_2.detach())
        loss += self.byolloss(predictions_from_view_2, targets_from_view_1.detach())

        return loss.mean()

    def train(self, train_loader):
        """
        Run ``EPOCHS`` epochs of BYOL pretraining.

        Args:
            train_loader: iterable yielding ``((view_1, view_2), label)`` batches;
                the label is ignored. Must support ``len()``.

        Returns:
            The trained online network.
        """
        n_iter = 0

        # Initialize Target Network
        self.initialize_target_network()
        for epoch in range(EPOCHS):

            train_loss = 0.0
            for (batch_view_1, batch_view_2), _ in tqdm(train_loader):
                batch_view_1 = batch_view_1.to(DEVICE)
                batch_view_2 = batch_view_2.to(DEVICE)

                loss = self.update(batch_view_1=batch_view_1, batch_view_2=batch_view_2)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Update Target Encoder
                self.update_target_network_parameters()

                # Accumulate a plain float so the autograd graph of each
                # iteration can be released.
                loss_value = loss.item()
                train_loss += loss_value

                if n_iter % LOG_EVERY_N_STEPS == 0:
                    print(f'Loss: {loss_value} \n')

                n_iter += 1

            train_loss /= (len(train_loader))

            print('=' * 30)
            print(f'Epoch: {epoch + 1} \n'
                  f'Loss: {train_loss} \n')

        return self.online_network

# BYOL 코드 실행

In [None]:
def byol(ckpt=None):
    """
    Run BYOL pretraining on the CIFAR-10 training split.

    Args:
        ckpt: optional state dict loaded into ``model.encoder`` before training.

    Returns:
        The online network returned by ``BYOL.train``.
    """
    # Seed NumPy, PyTorch CPU and all CUDA devices, in that order.
    for seed_everything in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_everything(SEED)

    # Define Dataset, Dataloader
    pretrain_loader = torch.utils.data.DataLoader(
        dataset=CIFAR10Dataset(mode='train').get_pretrain_dataset(),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKER,
        drop_last=True)

    # Define Model, predictor, optimizer
    model = ResNet(base_model=ARCHITECTURE)
    if ckpt is not None:
        model.encoder.load_state_dict(ckpt)

    predictor = MLPHead(in_channels=PROJECTION_SIZE,
                        mlp_hidden_size=PROJECTION_SIZE,
                        projection_size=PROJECTION_SIZE)

    # A single Adam optimizer covers the online network and the predictor.
    trainable = [*model.parameters(), *predictor.parameters()]
    optimizer = torch.optim.Adam(trainable,
                                 lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)

    with torch.cuda.device(GPU_INDEX):
        trainer = BYOL(model=model,
                       predictor=predictor,
                       optimizer=optimizer)
        trained_online_network = trainer.train(train_loader=pretrain_loader)

    return trained_online_network


# Run BYOL pretraining; the resulting network is kept in the module-level name
# `pretrained_model`, which supervised() reads later.
if __name__ == '__main__':
    pretrained_model = byol()

# BYOL로 사전 학습한 인코더를 사용하여 지도 학습 수행

In [None]:
class Supervised(object):
    """
    Train a classifier on top of a frozen encoder and evaluate it each epoch.

    Expected keyword arguments:
        encoder: feature extractor; it is used without gradients and in eval mode.
        classifier: module mapping flattened encoder features to class logits.
        optimizer: optimizer holding the classifier parameters.
    """

    def __init__(self, **kwargs):
        self.encoder = kwargs['encoder'].to(DEVICE)
        self.classifier = kwargs['classifier'].to(DEVICE)

        self.optimizer = kwargs['optimizer']
        self.criterion = nn.CrossEntropyLoss()

    def _encode(self, images):
        """Return frozen encoder features flattened to (batch, features)."""
        with torch.no_grad():
            # flatten(start_dim=1) keeps the batch axis even for a batch of one
            # sample, which a bare squeeze() would drop.
            return self.encoder(images).flatten(start_dim=1)

    def train_test(self, train_loader, test_loader):
        """
        Train the classifier for ``EPOCHS`` epochs and report metrics after each one.

        Per epoch this prints the mean training loss, mean top-1 training
        accuracy and mean top-1 / top-5 test accuracy (averaged over batches).

        Returns:
            The tuple ``(encoder, classifier)``.
        """
        # The encoder is frozen. Eval mode stops BatchNorm running statistics
        # from drifting and makes test-time features independent of the batch.
        self.encoder.eval()

        for epoch in range(EPOCHS):

            # Train
            top1_train_acc, train_loss, train_batches = 0.0, 0.0, 0
            for images, targets in train_loader:
                images = images.to(DEVICE)
                targets = targets.to(DEVICE)

                features = self._encode(images)
                logits = self.classifier(features)
                loss = self.criterion(logits, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                top1 = accuracy(logits, targets, topk=(1, ))
                # Accumulate plain floats so no autograd graph is kept alive.
                top1_train_acc += top1[0].item()
                train_loss += loss.item()
                train_batches += 1

            top1_train_acc /= max(train_batches, 1)
            train_loss /= max(train_batches, 1)

            # Test
            top1_test_acc, top5_test_acc, test_batches = 0.0, 0.0, 0
            for images, targets in test_loader:
                images = images.to(DEVICE)
                targets = targets.to(DEVICE)

                with torch.no_grad():
                    logits = self.classifier(self._encode(images))

                top1, top5 = accuracy(logits, targets, topk=(1, 5))
                top1_test_acc += top1.item()
                top5_test_acc += top5.item()
                test_batches += 1

            top1_test_acc /= max(test_batches, 1)
            top5_test_acc /= max(test_batches, 1)

            print(f'Epoch: {epoch + 1} \n'
                  f'Loss: {train_loss} \n'
                  f'Top1 Train Accuracy: {top1_train_acc} \n'
                  f'Top1 Test Accuracy: {top1_test_acc} \n'
                  f'Top5 Test Accuracy: {top5_test_acc} \n')

        return self.encoder, self.classifier

# 지도 학습 실행

In [None]:
def supervised():
    """
    Train and evaluate a linear classifier on the encoder of the
    module-level ``pretrained_model``.

    The classifier input width is taken from the first layer of the model's
    projection head, so it matches the encoder output for every supported
    ``ARCHITECTURE`` instead of assuming 512 features.

    Returns:
        The tuple ``(encoder, classifier)`` returned by ``Supervised.train_test``.
    """
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    # Define Dataset, Dataloader
    train_dataset = CIFAR10Dataset(mode='train').get_dataset()
    test_dataset = CIFAR10Dataset(mode='test').get_dataset()

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKER,
        drop_last=False)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKER,
        drop_last=False)

    # Define Model, optimizer, linear classifier
    encoder = pretrained_model.encoder
    # The projection head consumes the flattened encoder output, so its first
    # Linear layer records the encoder feature width.
    feature_dim = pretrained_model.projection.mlp[0].in_features
    classifier = nn.Linear(feature_dim, NUM_CLASS)

    # Only the classifier is optimised; the encoder stays frozen.
    optimizer = torch.optim.Adam(
        classifier.parameters(), 
        lr=LEARNING_RATE, 
        weight_decay=WEIGHT_DECAY)

    with torch.cuda.device(GPU_INDEX):
        trainer = Supervised(
            encoder=encoder,
            classifier=classifier,
            optimizer=optimizer)

        supervised_encoder, supervised_classifier = trainer.train_test(
            train_loader=train_loader, test_loader=test_loader)

    return supervised_encoder, supervised_classifier


# Run the supervised stage; `supervised_model` is the tuple
# (encoder, classifier) returned by supervised().
if __name__ == '__main__':
    supervised_model = supervised()

# TSNE로 데이터 시각화 해보기

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Rebinds the global epoch count read by byol(), then pretrains again starting
# from the encoder weights returned by the supervised stage.
EPOCHS=10
pretrained_model = byol(ckpt=supervised_model[0].state_dict())

actual = []
deep_features = []

feature_extractor = pretrained_model.encoder
feature_extractor.eval()

# Define Dataset, Dataloader
test_dataset = CIFAR10Dataset(mode='test')
test_dataset = test_dataset.get_dataset()
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,
                                          shuffle=False,num_workers=NUM_WORKER,
                                          drop_last=False)

# Collect encoder features and labels for the whole test split.
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        # flatten(start_dim=1) always yields (batch, features); a bare squeeze()
        # would drop the batch axis for a one-sample batch and break
        # np.concatenate below.
        features = feature_extractor(images.to(DEVICE)).flatten(start_dim=1)
        deep_features.append(features.cpu().numpy())
        actual += labels.tolist()

# Project the features to 2-D; `cluster` and `actual` are used by the plotting cell.
tsne = TSNE(n_components=2, random_state=0)
cluster = tsne.fit_transform(np.concatenate(deep_features))
actual = np.array(actual)

In [None]:
# Scatter the 2-D t-SNE embedding, one colour and legend entry per CIFAR-10 class.
plt.figure(figsize=(10, 10))
cifar = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
for i, label in enumerate(cifar):
    # Positions of the samples whose ground-truth class index is i.
    idx = np.where(actual == i)
    xs = cluster[idx, 0]
    ys = cluster[idx, 1]
    plt.scatter(xs, ys, marker='.', label=label)

plt.legend()
plt.show()