In [1]:
import pickle 
import pandas as pd 
import numpy as np 
import os 
from glob import glob 
import tqdm 
from PIL import Image 
from tqdm import tqdm 
import wandb
import math

from sklearn.metrics import f1_score,accuracy_score

from src.Dataset import CifarDataset,label_unlabel_load,dataset_load
from src.Models import Model,PiModel
from src.Loss import PiCriterion

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

import warnings 
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def make_transform():
    color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    transformer = transforms.Compose([
                                          transforms.RandomApply([color_jitter],p=0.8),
                                          transforms.RandomResizedCrop(32),
                                          transforms.GaussianBlur(kernel_size=int(0.1*32))
                                         ])
    return transformer
 
def make_valid(dataset = 'cifar10'):
    (train_imgs,train_labels),(test_imgs,test_labels) = dataset_load(dataset)
    idx = np.random.choice(np.arange(len(train_imgs)),5000,replace=False)
    valid_set = {'imgs':train_imgs[idx],
        'labels':train_labels[idx]}
    return valid_set 

def train(model,criterion,optimizer,train_loader,cfg,transformer):
    global epoch 
    model.train() 
    tl_loss = [] 
    tu_loss = [] 
    total_loss = [] 
    for batch_img,batch_labels in tqdm(train_loader):
        
        batch_img_1 = transformer(batch_img.type(torch.float32).to(cfg['device']))
        batch_img_2 = transformer(batch_img.type(torch.float32).to(cfg['device']))
        batch_labels = batch_labels.to(cfg['device'])
        
        y_pred_1 = model(batch_img_1,True)
        y_pred_2 = model(batch_img_2,True)
        loss,tl,tu,weight = criterion(y_pred_1,y_pred_2,batch_labels,epoch)
        
        total_loss.append(loss.detach().cpu().numpy())
        tl_loss.append(tl.detach().cpu().numpy())
        tu_loss.append(tu.detach().cpu().numpy())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    return np.mean(total_loss),np.mean(tl_loss),np.mean(tu_loss),weight

def valid(model,test_loader,cfg):
    labels = []
    y_preds = [] 
    model.eval() 
    for batch_imgs,batch_labels in test_loader:
        batch_imgs = batch_imgs.type(torch.float32).to(cfg['device'])
        with torch.no_grad():
            y_pred = model(batch_imgs,False)
        y_pred = torch.argmax(F.softmax(y_pred),dim=1)
        y_pred = y_pred.detach().cpu().numpy()
        
        y_preds.extend(y_pred)
        labels.extend(batch_labels.detach().cpu().numpy())    
    f1 = f1_score(np.array(y_preds),np.array(labels),average='macro')
    auc = accuracy_score(np.array(y_preds),np.array(labels))
    return f1, auc

In [None]:
cfg = {}
cfg['dataset'] = 'cifar10'
cfg['model_name'] = 'resnet18'
cfg['unlabel_ratio'] = 0.6
cfg['batch_size'] = 100 
cfg['device'] = 'cuda:0'
cfg['lr'] = 0.003 
cfg['beta1'] = 0.8
cfg['beta2'] = 0.999 
cfg['epochs'] = 300 
cfg['std'] = 0.15 
cfg['super_only'] = False



train_set,test_set = label_unlabel_load(cfg)
valid_set = make_valid()

train_dataset = CifarDataset(train_set, unlabel=False)
valid_dataset = CifarDataset(valid_set, unlabel=False)
test_dataset  = CifarDataset(test_set,unlabel=False)

train_loader = DataLoader(train_dataset,batch_size=cfg['batch_size'],shuffle=True)
valid_loader = DataLoader(valid_dataset,batch_size=cfg['batch_size'],shuffle=False)
test_loader  = DataLoader(test_dataset,batch_size=cfg['batch_size'],shuffle=False)

transformer = make_transform()

#model = PiModel(device='cuda')
model = Model(cfg['model_name']).to('cuda')
criterion = PiCriterion(cfg)
optimizer = torch.optim.Adam(model.parameters(),lr=cfg['lr'],betas=(cfg['beta1'],cfg['beta2']))



best_epoch = np.inf 
for epoch in range(cfg['epochs']):
    loss,tl_loss,tu_loss,weight =  train(model,criterion,optimizer,train_loader,cfg,transformer)
    f1 , auc = valid(model,valid_loader,cfg)
    print(f'\n Epochs : {epoch}')
    print(f'\n loss : {loss} | tl_loss : {tl_loss} | tu_loss : {tu_loss}')
    print(f'\n valid f1 : {f1}')
    print(f'\n valid auc : {auc}')
    
    if loss < best_epoch:
        torch.save(model,'./Save_models/best.pt')
        best_epoch = loss 
        print(f'model saved | best loss :{best_epoch}')
    '''
    wandb.log({'loss':loss,
               'tl_loss':tl_loss,
               'tu_loss':tu_loss,
               'weight':weight
               })
    '''
    #if  loss > 10000:
     #   model = torch.load('./Save_models/best.pt')
      #  print('Model reloaded')
f1 , auc = valid(model,test_loader,transformer,cfg)
print(f"\n F1 score : {f1} | Auccuracy :")

In [2]:
class CallBack:
    def __init__(self,cfg):
        self.best_loss = np.inf
        self.best_epoch = 0
        self.cfg = cfg 
    
    def model_checkpoint(self,epoch,loss):
        torch.save(model,f"./Save_models/{self.cfg['dir']}/best.pt")
        self.best_loss = loss 
        self.best_epoch = epoch
        print(f'model saved | best loss :{best_epoch}')
        
    def model_reloaded(self):
        model = torch.load(f"./Save_models/{self.cfg['dir']}/best.pt")
        print('Model reloaded')
        return model 
    
    def model_log(self,result_log):
        wandb.log(result_log)
        
        
    def __call__(self,epoch,result_log):
        #log 
        loss = result_log['loss']
        self.model_log(result_log)
        #check point 
        if loss < best_epoch:
            self.model_checkpoint(epoch,loss)
        #reloaded 
        if loss > 100000:
            return self.model_reloaded()
        #early stopping 
        if epoch + 15 > self.best_epoch:
            return True 

In [6]:
for batch_img,batch_labels in tqdm(train_loader):
    
    batch_img_1 = transformer(batch_img.type(torch.float32).to(cfg['device']))
    batch_img_2 = transformer(batch_img.type(torch.float32).to(cfg['device']))
    batch_labels = batch_labels.to(cfg['device'])
    
    y_pred_1 = model(batch_img_1,True)
    y_pred_2 = model(batch_img_2,True)
    loss,tl,tu,weight = criterion(y_pred_1,y_pred_2,batch_labels,epoch)

  0%|          | 0/500 [00:00<?, ?it/s]


In [15]:
label_idx = (batch_labels!=-1).nonzero().flatten()
batch_labels[label_idx].shap

tensor([2, 9, 9, 5, 6, 4, 6, 1, 6, 6, 2, 4, 8, 2, 2, 8, 9, 5, 9, 1, 7, 4, 4, 8,
        0, 9, 0, 6, 1, 5, 0, 9, 7, 5, 8, 8, 8, 1], device='cuda:0')