# 필요한 패키지 불러오기

# In [None]:
import os
import numpy as np

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.models as models

from torch.utils.data import Dataset
from torchvision.transforms import transforms
from torchvision import transforms, datasets

import warnings
warnings.filterwarnings(action='ignore')

# 학습에 필요한 하이퍼파라미터 정의

# In [None]:
# Hyperparameters shared by the SimCLR pre-training and supervised stages below.
ARCHITECTURE = 'resnet18'  # backbone name passed to ResNet(base_model=...)
MLP_HIDDEN_SIZE = 512  # hidden width of the MLP projection head
PROJECTION_SIZE = 128  # output width of the MLP projection head
NUM_WORKER = 0  # num_workers given to every DataLoader
N_VIEWS = 2  # augmented views generated per image for pre-training
EPOCHS = 5  # epochs run by both the SimCLR and the supervised loops
BATCH_SIZE = 256  # mini-batch size given to every DataLoader
LEARNING_RATE = 3e-4  # Adam learning rate
WEIGHT_DECAY = 1e-4  # Adam weight decay
SEED = 42  # seed for numpy and torch (CPU and CUDA) RNGs
NUM_CLASS = 10  # output size of the linear classifier
LOG_EVERY_N_STEPS = 20  # SimCLR training prints metrics every N iterations
TEMPERATURE = 0.07  # divisor applied to the contrastive logits
GPU_INDEX = 0  # CUDA device index used for the torch.cuda.device context
DEVICE = 'cuda'  # device string models and tensors are moved to

# CIFAR10 Dataset 정의

# In [None]:
class CIFAR10Dataset(Dataset):
    """Factory for the torchvision CIFAR-10 datasets used in this file.

    The instance only stores where the data lives and which split is wanted;
    the actual ``datasets.CIFAR10`` objects are created by
    ``get_pretrain_dataset`` (multi-view, for SimCLR) and ``get_dataset``
    (single view, for the supervised stage).

    Args:
        data_dir: Directory the dataset is stored in / downloaded to. It is
            created if it does not exist.
        mode: ``'train'`` or ``'test'``.
    """

    # Per-channel mean / std applied by every transform built here, so the
    # encoder sees identically scaled inputs in pre-training and evaluation.
    MEAN = (0.4914, 0.4822, 0.4465)
    STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, data_dir: str = './data', mode: str = 'train'):
        self.data_dir = data_dir
        self.mode = mode
        # Create the storage directory when it is missing.
        os.makedirs(data_dir, exist_ok=True)

    @staticmethod
    def get_transform(size: int = None, s: float = 1):
        """Build the SimCLR augmentation pipeline.

        Args:
            size: Side length of the random resized crop. Required.
            s: Strength multiplier for the colour jitter.

        Returns:
            A ``transforms.Compose`` doing crop, flip, colour jitter,
            grayscale, Gaussian blur, tensor conversion and normalisation.

        Raises:
            ValueError: If ``size`` is not given.
        """
        if size is None:
            raise ValueError('size must be given to build the augmentation pipeline')

        normalize = transforms.Normalize(mean=CIFAR10Dataset.MEAN,
                                         std=CIFAR10Dataset.STD)
        color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
        data_transforms = transforms.Compose(
            [transforms.RandomResizedCrop(size=size),
             transforms.RandomHorizontalFlip(),
             transforms.RandomApply([color_jitter], p=0.8),
             transforms.RandomGrayscale(p=0.2),
             # Blur kernel is 10% of the crop size.
             GaussianBlur(kernel_size=int(0.1 * size)),
             transforms.ToTensor(),
             normalize])
        return data_transforms

    @staticmethod
    def get_eval_transform():
        """Build the augmentation-free transform for the supervised stage.

        Applies the same normalisation as ``get_transform`` so that an
        encoder pre-trained on normalised images is not fed raw ``[0, 1]``
        pixels afterwards.
        """
        return transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize(mean=CIFAR10Dataset.MEAN,
                                  std=CIFAR10Dataset.STD)])

    # For SimCLR
    def get_pretrain_dataset(self, n_views: int = 2):
        """Return the CIFAR-10 training split yielding ``n_views`` views per image.

        Raises:
            NotImplementedError: If ``mode`` is not ``'train'``.
        """
        if self.mode != 'train':
            raise NotImplementedError(
                f"pre-training data is only available for mode='train', got {self.mode!r}")

        view_gen = ViewGenerator(base_transform=self.get_transform(size=32),
                                 n_views=n_views)
        return datasets.CIFAR10(self.data_dir, train=True,
                                transform=view_gen, download=True)

    # For Supervised Setting (Transfer Learning)
    def get_dataset(self):
        """Return the single-view CIFAR-10 split selected by ``mode``.

        Raises:
            NotImplementedError: If ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        if self.mode not in ('train', 'test'):
            raise NotImplementedError(
                f"mode must be 'train' or 'test', got {self.mode!r}")

        # download=True for both splits: it is a no-op when the files already
        # exist and lets the test split be requested on a fresh data_dir.
        return datasets.CIFAR10(self.data_dir,
                                train=(self.mode == 'train'),
                                transform=self.get_eval_transform(),
                                download=True)
            

class GaussianBlur(object):
    """Randomly blur one 3-channel image with a separable Gaussian kernel.

    The requested ``kernel_size`` is rounded to the odd size
    ``2 * (kernel_size // 2) + 1``. Two depthwise convolutions (one per
    spatial axis) are created once; each call draws a new sigma and
    overwrites their weights with the matching 1-D Gaussian.
    """

    def __init__(self, kernel_size: int = None):
        half_width = kernel_size // 2
        odd_size = 2 * half_width + 1

        def depthwise(shape):
            # One filter per channel (groups=3), no bias, no implicit padding.
            return nn.Conv2d(in_channels=3,
                             out_channels=3,
                             kernel_size=shape,
                             stride=1,
                             padding=0,
                             bias=False,
                             groups=3)

        self.blur_h = depthwise((odd_size, 1))
        self.blur_v = depthwise((1, odd_size))
        self.k = odd_size
        self.r = half_width

        # Reflection padding by the radius keeps the output the input's size.
        self.blur = nn.Sequential(
            nn.ReflectionPad2d(half_width),
            self.blur_h,
            self.blur_v,
        )

        self.pil_to_tensor = transforms.ToTensor()
        self.tensor_to_pil = transforms.ToPILImage()

    def __call__(self, img):
        """Blur ``img`` and return the result converted back to PIL."""
        batch = self.pil_to_tensor(img).unsqueeze(0)

        # Sample the blur strength, then build a normalised 1-D Gaussian.
        sigma = np.random.uniform(0.1, 2.0)
        offsets = np.arange(-self.r, self.r + 1)
        kernel = np.exp(-np.power(offsets, 2) / (2 * sigma * sigma))
        kernel = kernel / kernel.sum()
        kernel = torch.from_numpy(kernel).view(1, -1).repeat(3, 1)

        # Same kernel for every channel, laid out along each conv's axis.
        self.blur_h.weight.data.copy_(kernel.view(3, 1, self.k, 1))
        self.blur_v.weight.data.copy_(kernel.view(3, 1, 1, self.k))

        with torch.no_grad():
            blurred = self.blur(batch).squeeze()

        return self.tensor_to_pil(blurred)
    
    
class ViewGenerator(object):
    """Apply one transform to the same input ``n_views`` times.

    With a random ``base_transform`` this yields ``n_views`` independently
    augmented versions of the input, returned as a list.
    """

    def __init__(self,
                 base_transform,
                 n_views: int = 2):
        self.base_transform = base_transform
        self.n_views = n_views

    def __call__(self, x):
        views = []
        for _ in range(self.n_views):
            views.append(self.base_transform(x))
        return views

# 학습 네트워크 정의

# In [None]:
class ResNet(nn.Module):
    """torchvision ResNet encoder followed by an MLP projection head.

    The encoder is the chosen backbone without its final ``fc`` layer; the
    projection head maps the encoder output to ``PROJECTION_SIZE`` features.

    Args:
        base_model: Backbone name, ``'resnet18'`` or ``'resnet50'``.

    Raises:
        KeyError: If ``base_model`` is not a supported backbone name.
    """

    # Backbones that may be requested; each is a constructor in torchvision.models.
    _SUPPORTED = ('resnet18', 'resnet50')

    def __init__(self, base_model: str = None):
        super(ResNet, self).__init__()

        # Filled lazily by get_basemodel so that only the requested backbone
        # is ever instantiated (building every variant up front wastes time
        # and memory). A plain dict, so entries are not registered as submodules.
        self.resnet_dict = {}

        resnet = self.get_basemodel(base_model)

        # CNN encoder: every backbone child except the final fc layer.
        self.encoder = nn.Sequential(
            *list(resnet.children())[:-1])

        # MLP projection sized from the backbone's fc input width.
        self.projection = MLPHead(in_channels=resnet.fc.in_features,
                                  mlp_hidden_size=MLP_HIDDEN_SIZE,
                                  projection_size=PROJECTION_SIZE)

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and return the projected features."""
        h = self.encoder(x)
        # Drop the trailing spatial dimensions left by the backbone's pooling.
        h = torch.flatten(h, start_dim=1)

        return self.projection(h)

    def get_basemodel(self, model_name: str = None):
        """Return the (randomly initialised) backbone called ``model_name``.

        The backbone is built on first request and cached in
        ``self.resnet_dict``; later calls return the same instance.

        Raises:
            KeyError: If ``model_name`` is not a supported backbone name.
        """
        if model_name not in self._SUPPORTED:
            raise KeyError(
                f'Unsupported base_model {model_name!r}; '
                f'choose one of {self._SUPPORTED}')

        if model_name not in self.resnet_dict:
            builder = getattr(models, model_name)
            self.resnet_dict[model_name] = builder(pretrained=False)

        return self.resnet_dict[model_name]
    

class MLPHead(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        in_channels: Width of the input features.
        mlp_hidden_size: Width of the hidden layer.
        projection_size: Width of the output features.
    """

    def __init__(self,
                 in_channels: int,
                 mlp_hidden_size: int,
                 projection_size):
        super().__init__()

        layers = [
            nn.Linear(in_channels, mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        projected = self.mlp(x)
        return projected

# 평가 지표 및 SimCLR 학습 정의

# In [None]:
def accuracy(output: torch.FloatTensor,
             target: torch.LongTensor,
             topk: tuple = (1, )):
    """Return top-k accuracies, in percent, for each ``k`` in ``topk``.

    Args:
        output: Scores with one row per sample and one column per class.
        target: Correct class index for every row of ``output``.
        topk: Values of ``k`` to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the largest_k best-scoring classes, best first;
        # transposed so that row j holds every sample's (j+1)-th guess.
        _, top_idx = output.topk(k=largest_k, dim=1, largest=True, sorted=True)
        top_idx = top_idx.t()
        hits = top_idx.eq(target.view(1, -1).expand_as(top_idx))

        scale = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
        

class SimCLR(object):
    """Contrastive pre-training loop for a model that outputs projections.

    Keyword Args:
        model: Network mapping a batch of images to feature vectors.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler attached to ``optimizer``.
    """

    def __init__(self, **kwargs):
        self.model = kwargs['model'].to(DEVICE)
        self.optimizer = kwargs['optimizer']
        self.scheduler = kwargs['scheduler']
        self.criterion = nn.CrossEntropyLoss().to(DEVICE)

    def infonce(self,
                features: torch.FloatTensor,
                n_views: int = 2):
        """Build InfoNCE logits and targets from stacked view features.

        ``features`` must hold the views concatenated along dim 0, i.e. all
        samples of view 1, then all samples of view 2, and so on. The batch
        size is derived from ``features`` itself, so partial batches work.

        Args:
            features: Tensor of shape ``(n_views * batch, dim)``.
            n_views: Number of views stacked in ``features``.

        Returns:
            ``(logits, labels)``: logits have the positive similarities in
            the leading column(s), followed by the negatives, all divided by
            ``TEMPERATURE``; labels are all zero (the first column is the
            target class).

        Raises:
            ValueError: If the number of rows is not a multiple of ``n_views``.
        """
        total = features.shape[0]
        if n_views <= 0 or total % n_views != 0:
            raise ValueError(
                f'features has {total} rows, which is not a multiple of n_views={n_views}')

        device = features.device
        batch_size = total // n_views

        # Sample id of every row: 0..batch-1 repeated once per view.
        sample_ids = torch.arange(batch_size, device=device).repeat(n_views)
        # True where two rows come from the same source image.
        positive_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

        # Cosine similarity between every pair of rows.
        features = F.normalize(features, dim=1)
        similarity = torch.matmul(features, features.T)

        # Remove the main diagonal (each row compared with itself).
        off_diagonal = ~torch.eye(total, dtype=torch.bool, device=device)
        positive_mask = positive_mask[off_diagonal].view(total, -1)
        similarity = similarity[off_diagonal].view(total, -1)

        positives = similarity[positive_mask].view(total, -1)
        negatives = similarity[~positive_mask].view(total, -1)

        logits = torch.cat([positives, negatives], dim=1) / TEMPERATURE
        labels = torch.zeros(total, dtype=torch.long, device=device)

        return logits, labels

    def train(self, train_loader):
        """Run ``EPOCHS`` epochs of contrastive training and return the model.

        Args:
            train_loader: Iterable yielding ``(views, _)`` where ``views`` is
                a sequence of image batches, one per augmented view.
        """
        n_iter = 0
        # Most recently logged accuracies; NaN until the first logging step.
        top1_value = top5_value = float('nan')

        for epoch in range(EPOCHS):
            # Accumulate Python floats, not tensors, so no autograd history
            # is kept alive across iterations.
            loss_sum = 0.0
            n_batches = 0

            for images, _ in tqdm(train_loader):
                n_views = len(images)
                images = torch.cat(images, dim=0).to(DEVICE)  # n_views * batch

                features = self.model(images)
                logits, labels = self.infonce(features=features, n_views=n_views)
                loss = self.criterion(logits, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                loss_value = loss.item()
                loss_sum += loss_value
                n_batches += 1

                if n_iter % LOG_EVERY_N_STEPS == 0:
                    top1, top5 = accuracy(logits, labels, topk=(1, 5))
                    top1_value, top5_value = top1.item(), top5.item()
                    # get_last_lr() reports the rate currently in use;
                    # get_lr() is only meant to be called inside step().
                    print(f'Loss: {loss_value} \n'
                          f'ACC/Top1: {top1_value} \n'
                          f'ACC/Top5: {top5_value} \n'
                          f'Learning Rate: {self.scheduler.get_last_lr()[0]}')

                n_iter += 1

            train_loss = loss_sum / max(n_batches, 1)

            # Warmup for the first 10 epochs
            if epoch >= 10:
                self.scheduler.step()

            print('=' * 30)
            print(f'Epoch: {epoch + 1} \n'
                  f'Loss: {train_loss} \n'
                  f'Top1 Accuracy: {top1_value} \n'
                  f'Top5 Accuracy: {top5_value} \n'
                  f'Learning Rate: {self.scheduler.get_last_lr()[0]}')

        return self.model

# SimCLR 코드 실행

# In [None]:
def main():
    """Pre-train the ResNet with SimCLR on CIFAR-10 and return the model."""
    # Seed numpy and torch (CPU, then every CUDA device).
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(SEED)

    # Multi-view training data and its loader.
    pretrain_set = CIFAR10Dataset(mode='train').get_pretrain_dataset(n_views=N_VIEWS)
    loader_options = dict(batch_size=BATCH_SIZE,
                          shuffle=True,
                          num_workers=NUM_WORKER,
                          drop_last=True)
    train_loader = torch.utils.data.DataLoader(dataset=pretrain_set, **loader_options)

    # Network, optimizer and cosine learning-rate schedule.
    model = ResNet(base_model=ARCHITECTURE)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE,
                                 weight_decay=WEIGHT_DECAY)
    # NOTE(review): T_max is the number of batches per epoch - confirm this
    # is the intended schedule length.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,
                                                           T_max=len(train_loader),
                                                           eta_min=0,
                                                           last_epoch=-1)

    with torch.cuda.device(GPU_INDEX):
        trainer = SimCLR(model=model,
                         optimizer=optimizer,
                         scheduler=scheduler)
        return trainer.train(train_loader=train_loader)


# Run SimCLR pre-training when executed as a script. The result is kept in the
# module-level name `pretrained_model`, which the supervised-stage main()
# defined later in this file reads as a global.
if __name__ == '__main__':
    pretrained_model = main()

# SimCLR로 사전 학습한 인코더를 사용하여 지도 학습 수행(전이 학습 개념)

# In [None]:
class Supervised(object):
    """Train a classifier on top of a frozen encoder and evaluate it.

    Keyword Args:
        encoder: Feature extractor; it is never updated here.
        classifier: Module mapping encoder features to class logits.
        optimizer: Optimizer over the classifier's parameters.
    """

    def __init__(self, **kwargs):
        self.encoder = kwargs['encoder'].to(DEVICE)
        self.classifier = kwargs['classifier'].to(DEVICE)
        self.optimizer = kwargs['optimizer']
        self.criterion = nn.CrossEntropyLoss().to(DEVICE)

    def _encode(self, images):
        """Return flattened encoder features for ``images`` without gradients."""
        with torch.no_grad():
            features = self.encoder(images)
            # flatten(1) keeps the batch dimension even for a batch of one,
            # which squeeze() would collapse.
            return torch.flatten(features, start_dim=1)

    def train_test(self, train_loader, test_loader):
        """Train for ``EPOCHS`` epochs, evaluating on the test set after each.

        Loss and accuracies are averaged per sample, so a smaller final batch
        does not skew the reported numbers.

        Args:
            train_loader: Iterable of ``(images, targets)`` training batches.
            test_loader: Iterable of ``(images, targets)`` test batches.

        Returns:
            ``(encoder, classifier)``. The encoder is left in eval mode.
        """
        # Freezing means more than disabling gradients: in train mode any
        # normalisation layers would keep updating their running statistics.
        self.encoder.eval()

        for epoch in range(EPOCHS):

            # Train
            self.classifier.train()
            n_train, loss_sum, top1_train_sum = 0, 0.0, 0.0
            for images, targets in train_loader:
                images = images.to(DEVICE)
                targets = targets.to(DEVICE)
                batch = targets.size(0)

                logits = self.classifier(self._encode(images))
                loss = self.criterion(logits, targets)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                top1 = accuracy(logits, targets, topk=(1, ))
                # .item() keeps the accumulators free of autograd history.
                top1_train_sum += top1[0].item() * batch
                loss_sum += loss.item() * batch
                n_train += batch

            top1_train_acc = top1_train_sum / max(n_train, 1)
            train_loss = loss_sum / max(n_train, 1)

            # Test
            self.classifier.eval()
            n_test, top1_test_sum, top5_test_sum = 0, 0.0, 0.0
            for images, targets in test_loader:
                images = images.to(DEVICE)
                targets = targets.to(DEVICE)
                batch = targets.size(0)

                with torch.no_grad():
                    logits = self.classifier(self._encode(images))

                top1, top5 = accuracy(logits, targets, topk=(1, 5))
                top1_test_sum += top1[0].item() * batch
                top5_test_sum += top5[0].item() * batch
                n_test += batch

            top1_test_acc = top1_test_sum / max(n_test, 1)
            top5_test_acc = top5_test_sum / max(n_test, 1)

            print(f'Epoch: {epoch + 1} \n'
                  f'Loss: {train_loss} \n'
                  f'Top1 Train Accuracy: {top1_train_acc} \n'
                  f'Top1 Test Accuracy: {top1_test_acc} \n'
                  f'Top5 Test Accuracy: {top5_test_acc} \n')

        return self.encoder, self.classifier

# 지도 학습 실행

# In [None]:
def main():
    """Train and evaluate a linear classifier on the pre-trained encoder.

    Reads the module-level ``pretrained_model`` produced by the SimCLR stage
    and uses its ``encoder`` as a frozen feature extractor.

    Returns:
        ``(encoder, classifier)`` as returned by ``Supervised.train_test``.
    """
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)

    # Define Dataset, Dataloader
    train_dataset = CIFAR10Dataset(mode='train').get_dataset()
    test_dataset = CIFAR10Dataset(mode='test').get_dataset()

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKER,
        drop_last=False)

    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKER,
        drop_last=False)

    # Define encoder, optimizer, linear classifier
    encoder = pretrained_model.encoder

    # The encoder's output width equals the input width of the projection
    # head's first layer; reading it from the model (instead of assuming 512)
    # keeps the classifier correct for wider backbones as well.
    feature_dim = pretrained_model.projection.mlp[0].in_features
    classifier = nn.Linear(feature_dim, NUM_CLASS)

    optimizer = torch.optim.Adam(
        classifier.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY)

    with torch.cuda.device(GPU_INDEX):
        supervised = Supervised(
            encoder=encoder,
            classifier=classifier,
            optimizer=optimizer)

        supervised_encoder, supervised_classifier = supervised.train_test(
            train_loader=train_loader,
            test_loader=test_loader)

    return supervised_encoder, supervised_classifier


# Run the supervised stage when executed as a script; `supervised_model`
# receives the (encoder, classifier) tuple returned by main().
if __name__ == '__main__':
    supervised_model = main()