In [None]:
import pickle 
import pandas as pd 
import numpy as np 
import os 
from glob import glob 
import tqdm 
from PIL import Image 
from tqdm import tqdm 
import wandb
import math

from sklearn.metrics import f1_score,accuracy_score

from src.Dataset import CifarDataset,label_unlabel_load,img_load_all
from src.Models import Model,PiModel
from src.Loss import PiCriterion

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

import warnings 
# NOTE(review): this silences *every* warning for the whole session, including
# deprecation warnings from torch/sklearn that point at real bugs. Consider
# narrowing it (e.g. by category or module) once the noisy source is known.
warnings.filterwarnings('ignore')

In [28]:
def make_transform(cfg):
    """Build the stochastic augmentation pipeline used to create the two training views.

    Args:
        cfg: config dict (may be None). ``cfg.get('img_size', 32)`` sets the crop
            size; 32 is the value that was previously hard-coded.

    Returns:
        A ``transforms.Compose`` that applies, in order: colour jitter with
        probability 0.8, a random resized crop to ``img_size``, and a Gaussian blur.
    """
    img_size = (cfg or {}).get('img_size', 32)
    color_jitter = transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    # Blur kernel is ~10% of the image side. GaussianBlur requires a positive odd
    # kernel size, so even values are bumped up by one (for 32 this stays 3).
    kernel_size = max(1, int(0.1 * img_size))
    if kernel_size % 2 == 0:
        kernel_size += 1
    label_transform = transforms.Compose([
        transforms.RandomApply([color_jitter], p=0.8),
        transforms.RandomResizedCrop(img_size),
        transforms.GaussianBlur(kernel_size=kernel_size),
    ])
    return label_transform
def make_valid(dataset = 'cifar10', n_valid=5000, seed=None):
    """Sample a random validation subset from the *training* split of ``dataset``.

    Args:
        dataset: dataset name forwarded to ``img_load_all``.
        n_valid: number of training images to sample without replacement
            (default 5000, the previously hard-coded value).
        seed: optional seed for a dedicated ``np.random.default_rng``. When None,
            the global NumPy RNG is used, exactly as before, so seeding
            ``np.random`` globally also makes this reproducible.

    Returns:
        dict with keys ``'imgs'`` and ``'labels'`` holding the sampled items.

    NOTE(review): the subset is drawn from the training images. If the training
    loader also contains these images (e.g. with unlabel_ratio=0), the validation
    scores are not held-out. Confirm against ``label_unlabel_load``.
    """
    (train_imgs, train_labels), _ = img_load_all(dataset)
    n_train = len(train_imgs)
    rng = np.random if seed is None else np.random.default_rng(seed)
    # choice(n, ...) yields the same indices as choice(np.arange(n), ...).
    idx = rng.choice(n_train, n_valid, replace=False)
    valid_set = {'imgs': train_imgs[idx],
                 'labels': train_labels[idx]}
    return valid_set

def train(model,criterion,optimizer,train_loader,cfg,transform,epoch=None):
    """Run one training epoch of the two-view (Pi-model style) objective.

    Every batch is augmented twice with ``transform``. Both views go through
    ``model``, and ``criterion(y_pred_1, y_pred_2, labels, epoch)`` returns
    ``(loss, tl, tu)``. Only ``loss`` is back-propagated.

    Args:
        model: network being trained; switched to train mode here.
        criterion: callable returning ``(loss, tl, tu)`` scalar tensors.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of ``(images, labels)`` batches.
        cfg: config dict; only ``cfg['device']`` is read.
        transform: stochastic augmentation applied to each float32 batch.
        epoch: epoch index forwarded to ``criterion``. When None (the original
            call signature), the module-level ``epoch`` variable is used instead.

    Returns:
        Tuple of the per-batch means of ``loss``, ``tl`` and ``tu``.

    Raises:
        NameError: if ``epoch`` is None and no global ``epoch`` is defined.
    """
    if epoch is None:
        # Backward compatibility: the original read the notebook's loop counter
        # via ``global epoch``. Passing ``epoch=`` explicitly avoids that hidden state.
        if 'epoch' not in globals():
            raise NameError("train() needs an epoch: pass epoch=... or define a global `epoch`")
        epoch = globals()['epoch']

    device = cfg['device']
    model.train()
    epoch_loss, tl_loss, tu_loss = [], [], []
    for batch_img, batch_labels in tqdm(train_loader):
        batch_img = batch_img.type(torch.float32).to(device)
        # NOTE(review): ``transform`` receives the whole batch tensor. torchvision
        # samples one set of random parameters per call, so every image in the
        # batch shares the same crop/jitter; only the two views differ. Confirm
        # this is intended (a per-image loop would give independent augmentations).
        batch_img_1 = transform(batch_img)
        batch_img_2 = transform(batch_img)
        batch_labels = batch_labels.to(device)

        y_pred_1 = model(batch_img_1)
        y_pred_2 = model(batch_img_2)
        loss, tl, tu = criterion(y_pred_1, y_pred_2, batch_labels, epoch)

        # .item() copies each scalar to the host and keeps no autograd graph alive.
        epoch_loss.append(loss.item())
        tl_loss.append(tl.item())
        tu_loss.append(tu.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    return np.mean(epoch_loss), np.mean(tl_loss), np.mean(tu_loss)

def valid(model,test_loader,transform,cfg):
    """Evaluate ``model`` on every batch of ``test_loader`` and return classification scores.

    Args:
        model: network mapping a float32 image batch to per-class scores; put
            into eval mode here (callers switch it back before training).
        test_loader: iterable yielding ``(images, labels)`` batches.
        transform: unused; kept so existing call sites keep working. Evaluation
            runs on un-augmented images.
        cfg: config dict; only ``cfg['device']`` is read.

    Returns:
        ``(f1, acc)``: macro-averaged F1 score and accuracy over all batches.
    """
    device = cfg['device']
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch_imgs, batch_labels in test_loader:
            scores = model(batch_imgs.type(torch.float32).to(device))
            # argmax over raw scores equals argmax over softmax(scores); the
            # original's softmax (without dim=, a deprecated call) was redundant.
            y_pred.extend(scores.argmax(dim=1).cpu().numpy())
            y_true.extend(batch_labels.cpu().numpy())
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    # sklearn expects (y_true, y_pred); the original passed the arguments swapped.
    f1 = f1_score(y_true, y_pred, average='macro')
    acc = accuracy_score(y_true, y_pred)
    return f1, acc

In [29]:
# --- Experiment configuration ---------------------------------------------------
cfg = {}
cfg['dataset'] = 'cifar10'
cfg['model_name'] = 'resnet18'
# NOTE(review): with unlabel_ratio=0 the logged run shows tu_loss == 0.0, i.e. the
# unsupervised consistency term contributes nothing. Confirm this is intended.
cfg['unlabel_ratio'] = 0
cfg['batch_size'] = 100
cfg['device'] = 'cuda:0'
cfg['lr'] = 0.003
cfg['beta1'] = 0.8
cfg['beta2'] = 0.999
cfg['epochs'] = 300
cfg['std'] = 0.15  # not read in the visible cells; presumably used by a src/ helper. Confirm.
cfg['seed'] = 42

# Seed the global NumPy and torch RNGs (validation sampling, loader shuffling,
# weight init, augmentation) so Restart & Run All reproduces the same run.
np.random.seed(cfg['seed'])
torch.manual_seed(cfg['seed'])

# --- Data -------------------------------------------------------------------------
train_set, test_set = label_unlabel_load(cfg)
valid_set = make_valid(cfg['dataset'])  # use the configured dataset, not the default
train_dataset = CifarDataset(train_set, unlabel=False)
valid_dataset = CifarDataset(valid_set, unlabel=False)
test_dataset = CifarDataset(test_set, unlabel=False)

train_loader = DataLoader(train_dataset, batch_size=cfg['batch_size'], shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=cfg['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=cfg['batch_size'], shuffle=False)

tl_transform = make_transform(cfg)

# --- Model ------------------------------------------------------------------------
# Use cfg['device'] everywhere so the model and batches land on the same device.
model = Model(cfg['model_name']).to(cfg['device'])
criterion = PiCriterion()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['lr'], betas=(cfg['beta1'], cfg['beta2']))



# Train for cfg['epochs'] epochs, scoring the validation loader after each one.
# The checkpoint kept is the epoch with the lowest mean *training* loss.
# NOTE(review): selecting on validation F1 would be more meaningful, but only once
# the validation images are confirmed to be disjoint from the training set.
SAVE_DIR = './Save_models'
os.makedirs(SAVE_DIR, exist_ok=True)  # torch.save does not create missing directories
best_loss = np.inf
# The loop variable must stay named `epoch`: train() can read it as a global.
for epoch in range(cfg['epochs']):
    loss, tl_loss, tu_loss = train(model, criterion, optimizer, train_loader, cfg, tl_transform)
    f1, acc = valid(model, valid_loader, tl_transform, cfg)
    print(f'\n Epochs : {epoch}')
    print(f'\n loss : {loss} | tl_loss : {tl_loss} | tu_loss : {tu_loss}')
    # These scores come from valid_loader (not the test set), and the second
    # metric is accuracy, not AUC.
    print(f'\n valid f1 : {f1}')
    print(f'\n valid acc : {acc}')

    if loss < best_loss:
        torch.save(model, os.path.join(SAVE_DIR, 'best.pt'))
        best_loss = loss
        print(f'model saved | best loss :{best_loss}')

100%|██████████| 500/500 [00:17<00:00, 28.68it/s]



 Epochs : 0

 loss : 1.9194456338882446 | tl_loss : 1.9194456338882446 | tu_loss : 0.0

 test f1 : 0.36534949682177015

 test auc : 0.4058
model saved | best loss :1.9194456338882446


 23%|██▎       | 113/500 [00:03<00:13, 28.39it/s]


KeyboardInterrupt: 

: 