In [1]:
import pickle 
import pandas as pd 
import numpy as np 
import os 
from glob import glob 
import tqdm 
from PIL import Image 
from tqdm import tqdm 
import wandb
import math

from sklearn.metrics import f1_score,accuracy_score

from src.Dataset import CifarDataset,label_unlabel_load,dataset_load
from src.Models import Model,PiModel
from src.Loss import PiCriterion

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

import warnings 
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def make_transform():
    """Build the stochastic augmentation pipeline used on training batches.

    The returned ``transforms.Compose`` applies, in order:
      1. colour jitter (brightness/contrast/saturation 0.4, hue 0.1) with p=0.8,
      2. a random resized crop to 32x32,
      3. a Gaussian blur with kernel size ``int(0.1 * 32)`` (== 3, odd as required).
    """
    jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                    saturation=0.4, hue=0.1)
    blur_kernel = int(0.1 * 32)  # ~10% of the image side
    pipeline = [
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomResizedCrop(32),
        transforms.GaussianBlur(kernel_size=blur_kernel),
    ]
    return transforms.Compose(pipeline)
 
def make_valid(dataset = 'cifar10', n_valid=5000, seed=None):
    """Sample a validation subset (without replacement) from the training split.

    NOTE(review): the images are drawn from the *training* split returned by
    ``dataset_load``; nothing here removes them from the training set built by
    ``label_unlabel_load``. Unless that helper excludes them, validation scores
    are measured on images the model also trains on -- confirm.

    Args:
        dataset: dataset name forwarded to ``dataset_load``.
        n_valid: number of images to sample (default 5000, as before).
        seed: if given, a private ``np.random.default_rng(seed)`` is used so the
            split is reproducible independently of global state; if ``None``
            (default) the global NumPy RNG is used, exactly as before.

    Returns:
        dict with ``'imgs'`` and ``'labels'`` holding ``n_valid`` entries each.

    Raises:
        ValueError: if ``n_valid`` is negative or exceeds the training-set size.
    """
    (train_imgs, train_labels), _test = dataset_load(dataset)
    n_train = len(train_imgs)
    if not 0 <= n_valid <= n_train:
        raise ValueError(f'n_valid={n_valid} must be between 0 and the '
                         f'training-set size ({n_train})')
    # np.random.choice(int, ...) draws the same indices as choice(np.arange(int), ...).
    rng = np.random if seed is None else np.random.default_rng(seed)
    idx = rng.choice(n_train, n_valid, replace=False)
    return {'imgs': train_imgs[idx], 'labels': train_labels[idx]}

def train(model,criterion,optimizer,train_loader,cfg,transformer,te,epoch=None):
    """Run one training epoch of the Pi-model / temporal-ensembling objective.

    Args:
        model: called as ``model(x, True)`` on each augmented batch.
        criterion: called as ``criterion(y_pred, z_pred, labels, epoch)`` and
            must return ``(loss, labelled_loss, unlabelled_loss, weight)``.
        optimizer: torch optimizer stepping ``model``'s parameters.
        train_loader: iterable of ``(images, labels)`` batches.
        cfg: config dict; ``cfg['device']`` is the target device.
        transformer: augmentation applied to each float32 batch.
        te: ensemble object providing ``predict(step)`` and ``update(outputs)``.
        epoch: current epoch passed to ``criterion``. If ``None`` (default), the
            module-level global ``epoch`` is read, for compatibility with the
            notebook loop that sets it.

    Returns:
        ``(mean_loss, mean_labelled_loss, mean_unlabelled_loss, weight)`` where
        ``weight`` is the value returned for the last batch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
        NameError: if ``epoch`` is not passed and no global ``epoch`` exists.
    """
    if epoch is None:
        try:
            epoch = globals()['epoch']
        except KeyError:
            raise NameError('train() needs `epoch`: pass it explicitly or '
                            'define a global `epoch`') from None

    model.train()
    outputs = []
    total_loss, tl_loss, tu_loss = [], [], []
    weight = None
    for step, (batch_img, batch_labels) in enumerate(tqdm(train_loader)):
        batch_img = transformer(batch_img.type(torch.float32).to(cfg['device']))
        batch_labels = batch_labels.to(cfg['device'])
        y_pred = model(batch_img, True)
        # NOTE(review): te.predict(step) returns ensemble rows for batch
        # *position* `step`, but the notebook's train_loader uses shuffle=True,
        # so these rows belong to whichever samples held that position last
        # epoch, not to the samples in `batch_img`. Aligning them needs
        # per-sample indices from the dataset -- TODO.
        z_pred = te.predict(step).to(cfg['device'])

        loss, tl, tu, weight = criterion(y_pred, z_pred, batch_labels, epoch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # NOTE(review): raw model outputs are accumulated; temporal ensembling
        # (Laine & Aila) accumulates softmax probabilities -- confirm what
        # `model` returns.
        outputs.append(y_pred.detach().cpu().numpy())
        total_loss.append(loss.item())
        tl_loss.append(tl.item())
        tu_loss.append(tu.item())

    if not outputs:
        raise ValueError('train_loader yielded no batches')
    te.update(np.concatenate(outputs))
    return np.mean(total_loss), np.mean(tl_loss), np.mean(tu_loss), weight

def valid(model,test_loader,cfg):
    """Evaluate ``model`` on every batch of ``test_loader``.

    Args:
        model: called as ``model(x, False)``; its outputs are scored per class.
        test_loader: iterable of ``(images, labels)`` batches.
        cfg: config dict; ``cfg['device']`` is the target device.

    Returns:
        ``(f1, acc)``: macro-averaged F1 and accuracy of the arg-max class
        predictions against the labels.
    """
    model.eval()
    labels = []
    y_preds = []
    with torch.no_grad():
        for batch_imgs, batch_labels in test_loader:
            batch_imgs = batch_imgs.type(torch.float32).to(cfg['device'])
            scores = model(batch_imgs, False)
            # softmax is monotonic, so argmax over raw scores gives the same class.
            y_preds.extend(torch.argmax(scores, dim=1).cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())
    y_true = np.array(labels)
    y_pred = np.array(y_preds)
    # sklearn metrics take (y_true, y_pred) in that order.
    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)
    return f1, acc

class CallBack:
    """Per-epoch training hook: optional wandb logging, best-model
    checkpointing, divergence recovery and early stopping.

    Calling an instance returns one of:
      * ``True`` -- stop training: no improvement for more than
        ``cfg['Early_stop']`` epochs;
      * a model -- the best checkpoint, reloaded because the loss diverged;
      * ``None`` -- keep training.
    """

    # Loss above which training is considered diverged and the best checkpoint is restored.
    DIVERGENCE_LOSS = 100000

    def __init__(self,cfg,wandb=False):
        self.best_loss = np.inf
        self.best_epoch = 0
        self.cfg = cfg
        # Whether to log via wandb (the parameter shadows the module only inside __init__).
        self.wandb = wandb

    def _checkpoint_path(self):
        """Path of the best-model checkpoint for this run (``cfg['dir']``)."""
        return f"./Save_models/{self.cfg['dir']}/best.pt"

    def model_checkpoint(self,model,epoch,loss):
        """Save the whole ``model`` as the new best checkpoint and record loss/epoch."""
        path = self._checkpoint_path()
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(model, path)
        self.best_loss = loss
        self.best_epoch = epoch
        print(f'model saved | best loss :{self.best_loss} | epoch :{self.best_epoch}')

    def model_reloaded(self):
        """Load and return the full model pickled by ``model_checkpoint``."""
        path = self._checkpoint_path()
        try:
            # torch>=2.6 defaults to weights_only=True, which refuses to unpickle
            # a whole nn.Module. The file is one this class wrote itself.
            model = torch.load(path, weights_only=False)
        except TypeError:
            # torch<1.13 has no `weights_only` argument.
            model = torch.load(path)
        print('Model reloaded')
        return model

    def model_log(self,result_log):
        """Send ``result_log`` to the module-level ``wandb`` run."""
        wandb.log(result_log)

    def __call__(self,model,epoch,result_log):
        loss = result_log['loss']
        if self.wandb:
            self.model_log(result_log)

        if loss > self.DIVERGENCE_LOSS:
            # Never checkpoint a blown-up model; restore the best one if it exists.
            if self.best_loss < np.inf:
                print('Model Reloaded')
                return self.model_reloaded()
        elif loss < self.best_loss:
            self.model_checkpoint(model,epoch,loss)

        if epoch - self.cfg['Early_stop'] > self.best_epoch:
            print('Model Early stopped')
            return True
        return None
        
        
class TemporalEnsemble:
    """Per-sample exponential moving average of network outputs (temporal ensembling).

    ``Z`` holds one row per training sample and one column per class. After
    each epoch ``update`` folds in that epoch's outputs as
    ``Z <- alpha * Z + (1 - alpha) * outputs``. ``predict`` returns the slice
    for one batch, divided by ``1 - alpha**t`` (startup bias correction) so
    early targets are not shrunk toward zero.

    Config keys read: ``dataset`` ('cifar10' or 'cifar100'), ``batch_size``,
    ``alpha``, and optional ``num_samples`` (rows of ``Z``, default 50000).
    """

    # Number of classes for each supported dataset.
    NUM_CLASSES = {'cifar10': 10, 'cifar100': 100}

    def __init__(self,cfg):
        self.cfg = cfg
        self.Z = self.make_zeros()
        self.alpha = cfg['alpha']
        # Number of completed update() calls; drives the bias correction.
        self.t = 0

    def make_zeros(self):
        """Return a float32 zero accumulator of shape (num_samples, num_classes).

        Raises:
            ValueError: for a dataset name not in ``NUM_CLASSES``.
        """
        dataset = self.cfg['dataset']
        if dataset not in self.NUM_CLASSES:
            raise ValueError(f'unsupported dataset {dataset!r}; '
                             f'expected one of {sorted(self.NUM_CLASSES)}')
        n_samples = self.cfg.get('num_samples', 50000)
        return torch.zeros(n_samples, self.NUM_CLASSES[dataset]).float()

    def predict(self,i):
        """Return the bias-corrected ensemble targets for batch position ``i``."""
        bs = self.cfg['batch_size']
        z = self.Z[i * bs:(i + 1) * bs]
        if self.t > 0 and self.alpha < 1:
            z = z / (1. - self.alpha ** self.t)
        return z

    def update(self,outputs):
        """Fold one epoch of outputs (rows aligned with ``Z``) into the average.

        Raises:
            ValueError: if ``outputs`` does not have the same shape as ``Z``.
        """
        outputs = torch.from_numpy(np.asarray(outputs)).to(self.Z.dtype)
        if outputs.shape != self.Z.shape:
            raise ValueError(f'outputs shape {tuple(outputs.shape)} does not match '
                             f'ensemble shape {tuple(self.Z.shape)}')
        self.Z = self.alpha * self.Z + (1. - self.alpha) * outputs
        self.t += 1

In [None]:
# Experiment configuration: every tunable value for this run lives in `cfg`.
cfg = {}
cfg['dataset'] = 'cifar10'
cfg['model_name'] = 'resnet18'
cfg['unlabel_ratio'] = 0.6
cfg['batch_size'] = 100
cfg['device'] = 'cuda:0'
cfg['lr'] = 0.003
cfg['beta1'] = 0.8
cfg['beta2'] = 0.999
cfg['epochs'] = 150
cfg['std'] = 0.15
cfg['super_only'] = False
cfg['Early_stop'] = 50
cfg['dir'] = 'test'
cfg['alpha'] = 0.6
cfg['seed'] = 0

# Seed NumPy's global RNG (used by make_valid) and torch (weight init, DataLoader
# shuffling, random transforms) so Restart & Run All reproduces the run.
np.random.seed(cfg['seed'])
torch.manual_seed(cfg['seed'])

train_set, test_set = label_unlabel_load(cfg)
valid_set = make_valid(cfg['dataset'])

train_dataset = CifarDataset(train_set, unlabel=False)
valid_dataset = CifarDataset(valid_set, unlabel=False)
test_dataset = CifarDataset(test_set, unlabel=False)

train_loader = DataLoader(train_dataset, batch_size=cfg['batch_size'], shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=cfg['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=cfg['batch_size'], shuffle=False)

transformer = make_transform()


def make_optimizer(net):
    """Adam over ``net``'s parameters with the configured lr and betas."""
    return torch.optim.Adam(net.parameters(), lr=cfg['lr'], betas=(cfg['beta1'], cfg['beta2']))


model = Model(cfg['model_name']).to(cfg['device'])
criterion = PiCriterion(cfg)
optimizer = make_optimizer(model)
callbacks = CallBack(cfg)
te = TemporalEnsemble(cfg)

for epoch in range(cfg['epochs']):
    # `train` reads this loop's `epoch` as a module-level global.
    loss, tl_loss, tu_loss, weight = train(model, criterion, optimizer, train_loader, cfg, transformer, te)
    f1, auc = valid(model, valid_loader, cfg)
    print(f'\n Epochs : {epoch}')
    print(f'\n loss : {loss} | tl_loss : {tl_loss} | tu_loss : {tu_loss}')
    print(f'\n valid f1 : {f1}')
    print(f'\n valid auc : {auc}')

    # The callback returns True (early stop), a reloaded model (after
    # divergence) or None (continue).
    action = callbacks(model, epoch, {'loss': loss,
                                      'tl_loss': tl_loss,
                                      'tu_loss': tu_loss,
                                      'weight': weight,
                                      })
    if action is True:
        break
    if action is not None:
        model = action.to(cfg['device'])
        # The old optimizer still points at the discarded model's parameters.
        optimizer = make_optimizer(model)

f1, auc = valid(model, test_loader, cfg)
print(f"\n F1 score : {f1} | Accuracy : {auc}")

In [11]:
# Re-score a previously saved checkpoint on the validation split.
checkpoint_path = './Save_models/0_resnet18_False_0.0/best.pt'
model = torch.load(checkpoint_path)
f1, auc = valid(model, valid_loader, cfg)
print(f1, auc)

0.7709015967313568 0.7746
