In [1]:
import os 
import warnings 
warnings.filterwarnings('ignore')
import numpy as np 
import pandas as pd 
import cv2 

import matplotlib.pyplot as plt 

from PIL import Image

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from torch.utils.data import Dataset 
from torch.utils.data import DataLoader
import torchvision 
from torchvision import transforms
from torchvision.datasets import STL10
import timm

#데이터 로더 
from Dataset import prepare_dataloader

#어그먼테이션 
from utils import augmenter

#model 
from utils import ResnetEncoder

# Load Dataset 

In [19]:
class Custom_Dset(Dataset):
    def __init__(self,dataset):
        super().__init__
        self.dataset = dataset 
        self.augmenter = self.__augmenter__() 

    def __len__(self):
        return len(self.dataset)

    def __augmenter__(self):
        augmentation = transforms.Compose([
            transforms.ToTensor()
        ])
        return augmentation

    def __getitem__(self,idx):
        img,label = self.dataset[idx]
        img = self.augmenter(img)

        return img, label 
        
def prepare_dataloader(batch_size):
    root = './Data'
    #train - label 
    label_train_stl10 = STL10(root=root,split='train')
    label_train_Dset = Custom_Dset(label_train_stl10)
    label_train_loader = DataLoader(label_train_Dset,batch_size=batch_size,shuffle=True)

    #train-unlabelled 
    unlabel_train_stl10 = STL10(root=root,split='unlabeled')
    unlabel_train_Dset = Custom_Dset(unlabel_train_stl10)
    unlabel_train_loader = DataLoader(unlabel_train_Dset,batch_size=batch_size,shuffle=True)

    #test 
    test_st10 = STL10(root=root,split='test')
    test_dset = Custom_Dset(test_st10)
    test_loader = DataLoader(test_dset,batch_size=batch_size,shuffle=False)

    print(len(label_train_stl10), len(unlabel_train_stl10),len(test_st10))
    
    return label_train_loader, unlabel_train_loader , test_loader

# Augmentation 
- SimCLR나 BYOL과 같은 다른 self-supervised 와 달리 NNCLR은 Augmentation에 대해 덜 의존적임 
- nearest-neighbors이 이미 같은 변수에 대해 풍부한 정보를 제공하기 때문 
- NNCLR은 complex augmentation에 덜 의존적이기 때문에 **Random Crop**, **Random Brightness** 만을 사용 

In [20]:
def augmenter(kwargs):
    brightness,_,scale = kwargs.values()
    return transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(size=96,scale=scale),
        transforms.ColorJitter(brightness=brightness)
    ])
    

# Encoder Architecture 

In [29]:
class ResnetEncoder(nn.Module):
    def __init__(self,vector_size=2048):
        super(ResnetEncoder,self).__init__()
        self.resnet50 = res50 = timm.create_model('resnet50',num_classes=vector_size,pretrained=True)

    def forward(self,x):
        x = self.resnet50(x)
        return x     

class ProjectionHead(nn.Module):
    def __init__(self):
        super(ProjectionHead,self).__init__()
        self.fc1 = self.linear_layer()
        self.fc2 = self.linear_layer()
        self.fc3 = nn.Sequential(nn.Linear(in_features=2048,out_features=256),
                                 nn.BatchNorm1d(256))

    def linear_layer(self,out_features=2048):
        Linear_layer = nn.Sequential(
            nn.Linear(in_features=2048,out_features=out_features),
            nn.BatchNorm1d(2048),
            nn.ReLU()
        )
        return Linear_layer 

    def forward(self,x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x 

class NNCLR_model(nn.Module):
    def __init__(self,Encoder,ProjectionHead):
        super(NNCLR_model,self).__init__()
        self.encoder = Encoder 
        self.head = ProjectionHead
    
    def forward(self,x):
        x = self.encoder(x)
        x = self.ehad(x)
        return x 

# Nearest Neighbor 

In [22]:
def nearest_neighbour(projections):
    support_similarities = torch.matmul(projections,feature_queue)
    nn_projections = torch.gather(feature_queue,torch.argmax(support_similarities,axis=1),axis=0)
    
    return_value = projections + (nn_projections - projections).detach()

    return return_value 

# Contrastive Loss 

In [23]:
def contrastive_loss(projection_1,projection_2,temperature):
    projection_1 = F.normalize(projection_1,p=2)
    projection_2 = F.normalize(projection_2,p=2)

    similarities_1_2_1 = (torch.matmul(nearest_neighbour(projection_1),projection_2)/temperature)
    similarities_1_2_2 = (torch.matmul(projection_2,nearest_neighbour(projection_1))/temperature)
    similarities_2_1_1 = (torch.matmul(nearest_neighbour(projection_2),projection_1)/temperature)
    similarities_2_1_2 = (torch.matmul(projection_1,self.nearest_neighbour(projection_2))/temperature)

    contrastive_batch_size= projection_1.shape[0]
    contrastive_labels = torch.range(0,contrastive_batch_size)
    loss = nn.CrossEntropyLoss(
        torch.concat(
            [
            contrastive_labels,
            contrastive_labels,
            contrastive_labels,
            contrastive_labels,
            ],
            axis=0),
        torch.concat(
            [
            similarities_1_2_1,
            similarities_1_2_2,
            similarities_2_1_1,
            similarities_2_1_2,
            ],
            axis=0
        ),
    )

    feature_quene.assign(torch.concat([projection_1,feature_quene[:-batch_size]],axis=0))
    return loss 

# Memory Bank 

# Training 

In [4]:
# 하이퍼 파라미터 

shuffle_buffer = 5000 
labelelled_train_images = 5000 
unlabelled_images = 100000

temperature = 0.1 
queue_size = 10000
contrastive_augmenter = {
    "brightness" : 0.5, 
    "name" : "contrastive_augmenter",
    "scale" : (0.2,1.0)
}
classification_augmenter = {
    "brightness": 0.2,
    "name": "classification_augmenter",
    "scale": (0.5, 1.0),
}
input_shape = (96,96,3)
width = 128 
num_epochs = 25 
steps_per_epoch = 200 
batch_size= 4096 
learning_rate = 1e-3 
device = 'cuda:0'

In [32]:
#데이터 로더 로드 
#label_train, unlabel_train, test = prepare_dataloader(batch_size=32)

#Augmentation 로드 
contrastive_augmentation = augmenter(contrastive_augmenter)
classification_augmentation = augmenter(classification_augmenter)

#모델 로드 
encoder = ResnetEncoder()
prj_head = ProjectionHead()
model = NNCLR_model(encoder,prj_head)

criterion = contrastive_loss
optimizer = torch.optim.Adam(lr=learning_rate,params=model.parameters())

feature_queue = Variable(F.normalize(torch.normal(mean=0.0,std=1.0,size=(queue_size,2048)),p=2),requires_grad=False)

for epoch in range(num_epochs):
    encoder.train()
    prj_head.train() 
    #이미지 데이터 로드 
    for (label_img,labels),(unlabel_img,_) in zip(label_train,unlabel_train):
        imgages = torch.concat((label_img,unlabel_img),axis=0)
        augmented_images_1 = contrastive_augmentation(imgages)
        augmented_images_2 = contrastive_augmentation(imgages)
    

        optimizer.zero_grad()

        features_1 = encoder(augmented_images_1)
        features_2 = encoder(augmented_images_2)

        projection_1 = prj_head(features_1)
        projection_2 = prj_head(features_2)

        loss = criterion(projection_1,projection_2)

        loss.backward()
        optimizer.step() 
        break
    break
    


NameError: name 'label_train' is not defined