In [None]:
import pickle 
import pandas as pd 
import numpy as np 
import os 
from glob import glob 
import tqdm 
from PIL import Image 
from tqdm import tqdm 
import wandb
import math

from sklearn.metrics import f1_score,accuracy_score

from src.Dataset import CifarDataset,label_unlabel_load,dataset_load
from src.Models import Model,PiModel
from src.Loss import PiCriterion

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import Dataset 
from torchvision import transforms
import timm 

import warnings 
warnings.filterwarnings('ignore')

# 데이터셋 

In [None]:
def from_pickle_to_img(file,name):
    """Load one pickled CIFAR batch file and return its images and labels.

    Args:
        file: Path of the pickled batch file to read.
        name: 'cifar10' (labels read from b'labels') or 'cifar100'
            (labels read from b'fine_labels').

    Returns:
        (imgs, labels): imgs is data[b'data'] rearranged to channel-last
        shape (N, 32, 32, 3); labels is a 1-D numpy array.

    Raises:
        ValueError: if name is neither 'cifar10' nor 'cifar100'.
    """
    label_keys = {'cifar10': b'labels', 'cifar100': b'fine_labels'}
    if name not in label_keys:
        raise ValueError(f"unknown dataset name {name!r}; expected 'cifar10' or 'cifar100'")

    # NOTE(security): pickle.load can execute arbitrary code -- only point this
    # at dataset files obtained from a trusted source.
    with open(file, 'rb') as f:
        data = pickle.load(f,encoding='bytes')

    # Each row is stored flattened channel-first (all R, then all G, then all B),
    # for CIFAR-100 as well as CIFAR-10, so both need the same
    # reshape-then-transpose to become (N, H, W, C) images.
    batch_imgs = np.asarray(data[b'data']).reshape(-1,3,32,32).transpose(0,2,3,1)
    batch_labels = data[label_keys[name]]
    return batch_imgs, np.array(batch_labels)

def load_cifar10(root='./Dataset/cifar-10-batches-py'):
    """Load the CIFAR-10 train and test splits from the pickled batch files.

    The batch files are selected by name (data_batch_* and test_batch) rather
    than by their position in the directory listing, so other files in the
    folder and the order glob returns them in cannot change the result.

    Args:
        root: Directory holding the data_batch_* and test_batch files.

    Returns:
        ((train_imgs, train_labels), (test_imgs, test_labels)) as numpy arrays.

    Raises:
        FileNotFoundError: if no data_batch_* file exists under root.
    """
    train_files = sorted(glob(os.path.join(root,'data_batch_*')))
    if not train_files:
        raise FileNotFoundError(f'no data_batch_* files found under {root}')

    train_parts = [from_pickle_to_img(file,'cifar10') for file in train_files]
    train_imgs = np.concatenate([batch_imgs for batch_imgs,_ in train_parts])
    train_labels = np.concatenate([batch_labels for _,batch_labels in train_parts])

    test_imgs,test_labels = from_pickle_to_img(os.path.join(root,'test_batch'),'cifar10')
    # from_pickle_to_img returns a transposed view; hand back a contiguous array
    # like the concatenated training images.
    test_imgs = np.ascontiguousarray(test_imgs)
    return (train_imgs,train_labels),(test_imgs,test_labels)

def load_cifar100(root='./Dataset/cifar-100-python'):
    """Load the CIFAR-100 train and test splits from their pickled files.

    The 'train' and 'test' files are addressed by name instead of by their
    position in a sorted directory listing, so extra files in the folder
    cannot shift which file is read.

    Args:
        root: Directory holding the 'train' and 'test' pickle files.

    Returns:
        ((train_imgs, train_labels), (test_imgs, test_labels)).
    """
    train_imgs,train_labels = from_pickle_to_img(os.path.join(root,'train'),'cifar100')
    test_imgs,test_labels = from_pickle_to_img(os.path.join(root,'test'),'cifar100')
    return (train_imgs,train_labels),(test_imgs,test_labels)

# Load the dataset, then turn part of the training labels into "unlabeled" markers
def label_unlabel_load(cfg):
    """Load a dataset and mark a fraction of each class's training samples as unlabeled.

    Args:
        cfg: Mapping with
            'dataset'       -- forwarded to dataset_load,
            'unlabel_ratio' -- fraction of every class whose label is replaced by -1.

    Returns:
        (train, test): two dicts with keys 'imgs' and 'labels'. In train['labels']
        the samples chosen as unlabeled carry the value -1.
    """
    (train_imgs,train_labels),(test_imgs,test_labels) = dataset_load(cfg['dataset'])

    # Work on a copy so the label array handed out by dataset_load is not
    # overwritten with -1 as a side effect.
    train_labels = np.array(train_labels,copy=True)

    for label in np.unique(train_labels):
        label_idx = (train_labels == label).nonzero()[0]
        # Per-class sampling keeps the labeled subset class-balanced.
        n_unlabel = int(len(label_idx)*cfg['unlabel_ratio'])
        unlabel_idx = np.random.choice(label_idx,n_unlabel,replace=False)
        train_labels[unlabel_idx] = -1

    train = {'imgs':train_imgs,
            'labels':train_labels}
    test = {'imgs':test_imgs,
            'labels':test_labels}
    return train, test

class CifarDataset(Dataset):
    """Dataset over a {'imgs': ..., 'labels': ...} dict.

    Args:
        data: Mapping with 'imgs' and 'labels' sequences, indexed in parallel.
        unlabel: When True, __getitem__ returns only the transformed image;
            otherwise it returns (image, label).
        transform: Callable applied to every image. When None, a pipeline
            containing only transforms.ToTensor() is used.
    """
    def __init__(self,data,unlabel=False,transform=None):
        super(CifarDataset,self).__init__()
        self.imgs = data['imgs']
        self.labels = data['labels']
        self.unlabel = unlabel
        self.transform = self.transfrom_init(transform)

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.imgs)

    def transfrom_init(self,transform):
        """Return transform, or the default ToTensor pipeline when it is None.

        The method name keeps its original spelling so existing callers keep working.
        """
        if transform is None:
            return transforms.Compose([transforms.ToTensor()])
        return transform

    def __getitem__(self,idx):
        """Return the transformed image at idx, with its label unless unlabel is set."""
        img = self.transform(self.imgs[idx])
        if self.unlabel:
            return img
        return img,self.labels[idx]

# 학습 모듈 

In [None]:
def make_transform():
    """Build the random augmentation pipeline used to create the two training views.

    The pipeline applies, in order: colour jitter (with probability 0.8),
    a random resized crop to 32 pixels, and a Gaussian blur whose kernel size
    is int(0.1*32).

    Returns:
        A transforms.Compose wrapping the three augmentation steps.
    """
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    steps = [
        transforms.RandomApply([jitter],p=0.8),
        transforms.RandomResizedCrop(32),
        transforms.GaussianBlur(kernel_size=int(0.1*32)),
    ]
    return transforms.Compose(steps)
 
def make_valid(dataset = 'cifar10',n_samples = 5000):
    """Draw a random validation subset from the training split.

    Note that the samples come from the training split returned by
    dataset_load, not from a held-out set.

    Args:
        dataset: Dataset name forwarded to dataset_load.
        n_samples: Number of training samples to draw without replacement.

    Returns:
        Dict with keys 'imgs' and 'labels' holding the selected samples.
    """
    (train_imgs,train_labels),_ = dataset_load(dataset)
    idx = np.random.choice(np.arange(len(train_imgs)),n_samples,replace=False)
    valid_set = {'imgs':train_imgs[idx],
        'labels':train_labels[idx]}
    return valid_set

def train(model,criterion,optimizer,train_loader,cfg,transformer):
    """Run one training epoch with two augmented views per batch.

    Reads the module-level variable `epoch` (set by the outer training loop)
    and forwards it to the criterion.

    Args:
        model: Network called as model(images, True).
        criterion: Called as criterion(pred_1, pred_2, labels, epoch); must
            return (loss, tl, tu, weight) -- presumably the total loss, its
            labeled and unlabeled terms and the current weighting; confirm
            against PiCriterion.
        optimizer: Optimizer stepping the model parameters.
        train_loader: Iterable yielding (images, labels) batches.
        cfg: Mapping; only cfg['device'] is used here.
        transformer: Augmentation applied independently twice to each batch.

    Returns:
        (mean total loss, mean tl, mean tu, weight from the last batch).
    """
    global epoch
    device = cfg['device']
    model.train()
    history = {'total': [], 'tl': [], 'tu': []}

    for batch_img,batch_labels in tqdm(train_loader):
        # Two independent random augmentations of the same batch.
        view_1 = transformer(batch_img.type(torch.float32).to(device))
        view_2 = transformer(batch_img.type(torch.float32).to(device))
        targets = batch_labels.to(device)

        pred_1 = model(view_1,True)
        pred_2 = model(view_2,True)
        loss,tl,tu,weight = criterion(pred_1,pred_2,targets,epoch)

        for key,value in (('total',loss),('tl',tl),('tu',tu)):
            history[key].append(value.detach().cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return np.mean(history['total']),np.mean(history['tl']),np.mean(history['tu']),weight

def valid(model,test_loader,cfg):
    """Evaluate the model on a loader and return macro F1 and accuracy.

    The model is switched to eval mode and left in it.

    Args:
        model: Network called as model(images, False); its output is treated
            as per-class scores along dim 1.
        test_loader: Iterable yielding (images, labels) batches.
        cfg: Mapping; only cfg['device'] is used here.

    Returns:
        (f1, accuracy): macro-averaged F1 score and accuracy over all batches.
    """
    labels = []
    y_preds = []
    model.eval()
    with torch.no_grad():
        for batch_imgs,batch_labels in test_loader:
            batch_imgs = batch_imgs.type(torch.float32).to(cfg['device'])
            logits = model(batch_imgs,False)
            # Softmax is monotonic, so the argmax of the raw scores is already
            # the predicted class; this also avoids softmax's implicit-dim call.
            y_pred = torch.argmax(logits,dim=1)
            y_preds.extend(y_pred.cpu().numpy())
            labels.extend(batch_labels.detach().cpu().numpy())

    y_true = np.array(labels)
    y_pred = np.array(y_preds)
    # scikit-learn metrics take (y_true, y_pred) in that order.
    f1 = f1_score(y_true,y_pred,average='macro')
    acc = accuracy_score(y_true,y_pred)
    return f1, acc

# 모델 

In [None]:
class Model(nn.Module):
    """Pretrained timm backbone followed by a linear classification head.

    Args:
        model_name: Name passed to timm.create_model.
        dataset_name: 'cifar10' gives a 10-way head; anything else a 100-way head.
    """
    def __init__(self,model_name='ssl_resnet50',dataset_name='cifar10'):
        super(Model,self).__init__()
        self.model_name = model_name
        self.encoder = self.pretrained_encoder(model_name)
        self.linear = self.output_layer(dataset_name)

    def pretrained_encoder(self,model_name):
        """Create the pretrained timm model and drop its last child module."""
        backbone = timm.create_model(model_name,pretrained=True)
        layers = list(backbone.children())
        return nn.Sequential(*layers[:-1])

    def output_layer(self,dataset_name):
        """Build the linear head sized from the encoder's last block.

        The input width is read from the third-from-last child of the final
        block of the encoder's second-to-last stage.
        NOTE(review): this indexing assumes a specific block layout -- confirm
        it matches the chosen timm architecture.
        """
        last_block = self.encoder[-2][-1]
        in_features = list(last_block.children())[-3].out_channels
        out_features = 10 if dataset_name == 'cifar10' else 100
        return nn.Linear(in_features = in_features,out_features= out_features)

    def forward(self,x,_):
        """Encode x and classify it; the second argument is accepted but ignored."""
        features = self.encoder(x)
        return self.linear(features)
    
class GaussianNoise(nn.Module):
    """Adds zero-mean Gaussian noise with standard deviation `std` to its input.

    The noise is drawn on every call with shape (x.size(0),) + input_shape and
    broadcast onto x, so the layer works for any batch size (e.g. a smaller
    final batch) and on whatever device x lives on.

    Args:
        batch_size: Nominal batch size; only used for the `shape` attribute.
        input_shape: Per-sample noise shape. The default (1, 32, 32) shares
            one noise map across all channels through broadcasting.
        std: Standard deviation of the noise.
        device: Kept for interface compatibility; the noise is created on the
            input's device.
    """

    def __init__(self, batch_size, input_shape=(1, 32, 32), std=0.05,device='cpu'):
        super(GaussianNoise, self).__init__()
        self.input_shape = tuple(input_shape)
        self.shape = (batch_size,) + self.input_shape
        self.std = std

    def forward(self, x):
        """Return x plus freshly sampled Gaussian noise."""
        # randn needs a floating dtype; fall back to float32 for integer input.
        dtype = x.dtype if x.is_floating_point() else torch.float32
        noise_shape = (x.size(0),) + self.input_shape
        noise = torch.randn(noise_shape, dtype=dtype, device=x.device) * self.std
        return x + noise
    
    
class PiModel(nn.Module):
    """Convolutional classifier with optional input noise.

    Layout: GaussianNoise (training only) -> two 3-conv blocks with 2x2
    max-pooling and dropout -> a 3x3 / 1x1 / 1x1 conv block -> global average
    pooling -> linear layer.

    Args:
        num_labels: Number of output classes.
        batch_size: Forwarded to GaussianNoise.
        std: Noise standard deviation forwarded to GaussianNoise.
        device: Device the sub-modules are moved to.
    """
    def __init__(self,num_labels=10,batch_size=100,std=0.15,device='cpu'):
        super(PiModel,self).__init__()
        self.noise = GaussianNoise(batch_size,std=std,device=device)
        self.conv1 = self.conv_block(3,128).to(device)
        self.conv2 = self.conv_block(128,256).to(device)
        self.conv3 = self.conv3_block().to(device)
        self.linear = nn.Linear(128,num_labels).to(device)

    def conv_block(self,input_channel,num_filters):
        """Three 3x3 same-padding convolutions, then 2x2 max-pool and dropout(0.5)."""
        return nn.Sequential(
                                nn.Conv2d(input_channel,num_filters,3,1,1),
                                nn.LeakyReLU(0.1),
                                nn.Conv2d(num_filters,num_filters,3,1,1),
                                nn.LeakyReLU(0.1),
                                nn.Conv2d(num_filters,num_filters,3,1,1),
                                nn.LeakyReLU(0.1),
                                nn.MaxPool2d(2,2),
                                nn.Dropout(0.5)
                             )

    def conv3_block (self):
        """Unpadded 3x3 convolution followed by two 1x1 convolutions (256 -> 512 -> 256 -> 128)."""
        return nn.Sequential(
                              nn.Conv2d(256,512,3,1,0),
                              nn.LeakyReLU(0.1),
                              nn.Conv2d(512,256,1,1),
                              nn.LeakyReLU(0.1),
                              nn.Conv2d(256,128,1,1),
                              nn.LeakyReLU(0.1)

        )

    def forward(self,x,train):
        """Classify x; input noise is added only when `train` is truthy.

        Returns a (batch, num_labels) tensor of class scores.
        """
        if train:
            x = self.noise(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # Global average pooling. flatten(1) keeps the batch axis; a bare
        # squeeze() would also drop it for a single-sample batch.
        x = F.avg_pool2d(x, x.size()[2:]).flatten(1)
        x = self.linear(x)
        return x

# 학습 

In [None]:
# Experiment configuration: every tunable value lives in this dict.
cfg = {}
cfg['dataset'] = 'cifar10'
cfg['model_name'] = 'resnet18'
cfg['unlabel_ratio'] = 0.6
cfg['batch_size'] = 100 
cfg['device'] = 'cuda:0'
cfg['lr'] = 0.003 
cfg['beta1'] = 0.8
cfg['beta2'] = 0.999 
cfg['epochs'] = 300 
cfg['std'] = 0.15 
cfg['super_only'] = False



# Data: partially labeled training split, a validation subset and the test split,
# all taken from the dataset named in cfg.
train_set,test_set = label_unlabel_load(cfg)
valid_set = make_valid(cfg['dataset'])

train_dataset = CifarDataset(train_set, unlabel=False)
valid_dataset = CifarDataset(valid_set, unlabel=False)
test_dataset  = CifarDataset(test_set,unlabel=False)

train_loader = DataLoader(train_dataset,batch_size=cfg['batch_size'],shuffle=True)
valid_loader = DataLoader(valid_dataset,batch_size=cfg['batch_size'],shuffle=False)
test_loader  = DataLoader(test_dataset,batch_size=cfg['batch_size'],shuffle=False)

transformer = make_transform()

#model = PiModel(device='cuda')
# Build the head for the configured dataset and place the model on the
# configured device, so it matches where train()/valid() send the batches.
model = Model(cfg['model_name'],cfg['dataset']).to(cfg['device'])
criterion = PiCriterion(cfg)
optimizer = torch.optim.Adam(model.parameters(),lr=cfg['lr'],betas=(cfg['beta1'],cfg['beta2']))


# torch.save does not create missing directories.
os.makedirs('./Save_models',exist_ok=True)

# NOTE: despite its name, best_epoch stores the lowest training loss seen so far.
best_epoch = np.inf 
for epoch in range(cfg['epochs']):
    # train() reads the loop variable `epoch` as a global.
    loss,tl_loss,tu_loss,weight =  train(model,criterion,optimizer,train_loader,cfg,transformer)
    # valid() returns (macro F1, accuracy); the second value is kept under the name `auc`.
    f1 , auc = valid(model,valid_loader,cfg)
    print(f'\n Epochs : {epoch}')
    print(f'\n loss : {loss} | tl_loss : {tl_loss} | tu_loss : {tu_loss}')
    print(f'\n valid f1 : {f1}')
    print(f'\n valid auc : {auc}')
    
    if loss < best_epoch:
        torch.save(model,'./Save_models/best.pt')
        best_epoch = loss 
        print(f'model saved | best loss :{best_epoch}')
    # Optional experiment tracking (disabled):
    # wandb.log({'loss':loss,
    #            'tl_loss':tl_loss,
    #            'tu_loss':tu_loss,
    #            'weight':weight
    #            })
    #if  loss > 10000:
     #   model = torch.load('./Save_models/best.pt')
      #  print('Model reloaded')

# Final evaluation on the test split; valid() takes (model, loader, cfg).
f1 , auc = valid(model,test_loader,cfg)
print(f"\n F1 score : {f1} | Accuracy : {auc}")