In [1]:
import pickle 
import pandas as pd 
import numpy as np 
import os 
from glob import glob 
import tqdm 
from PIL import Image 
from tqdm import tqdm 
import wandb
import math

from sklearn.metrics import f1_score,accuracy_score

from src.Dataset import CifarDataset,label_unlabel_load,dataset_load
from src.Models import Model,PiModel
from src.Loss import PiCriterion

import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms

import warnings 
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def make_transform():
    """Build the random augmentation pipeline applied to image batches.

    Returns:
        A ``transforms.Compose`` that applies, in order: colour jitter
        (with probability 0.8), a random resized crop to size 32, and a
        Gaussian blur.
    """
    jitter = transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.4,
        hue=0.1,
    )
    # 10% of the 32-pixel image side, truncated to an int (evaluates to 3).
    blur_kernel = int(0.1*32)
    steps = [
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomResizedCrop(32),
        transforms.GaussianBlur(kernel_size=blur_kernel),
    ]
    return transforms.Compose(steps)
 
def make_valid(dataset = 'cifar10', n_valid = 5000):
    """Draw a random labelled subset of the training split to use for validation.

    Args:
        dataset: Dataset name forwarded to ``dataset_load``.
        n_valid: Number of training examples to sample without replacement.
            Defaults to 5000, the value that used to be hard-coded.

    Returns:
        dict with keys ``'imgs'`` and ``'labels'`` holding the sampled
        images and their labels.

    Raises:
        ValueError: Raised by ``np.random.choice`` when ``n_valid`` is larger
            than the number of training images.
    """
    # Only the training split is needed; the test split is discarded.
    (train_imgs, train_labels), _ = dataset_load(dataset)
    # NOTE(review): the validation samples are taken from the training images.
    # If the training loader is built from the same images, validation scores
    # are measured on data the model has trained on - confirm this is intended.
    # An int first argument samples as if from np.arange(len(train_imgs)).
    idx = np.random.choice(len(train_imgs), n_valid, replace=False)
    return {'imgs': train_imgs[idx],
            'labels': train_labels[idx]}

def train(model,criterion,optimizer,train_loader,cfg,transformer,epoch=None):
    """Run one training epoch, feeding two augmented views of every batch to the model.

    Args:
        model: Network called as ``model(images, True)``.
            # presumably the flag selects a training-time code path - confirm in src.Models
        criterion: Callable ``criterion(pred_1, pred_2, labels, epoch)`` returning
            ``(loss, tl, tu, weight)``; ``loss`` is the tensor that is back-propagated.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_loader: Iterable of ``(images, labels)`` batches.
        cfg: Config dict; only ``cfg['device']`` is read here.
        transformer: Augmentation callable applied independently twice per batch.
        epoch: Current epoch index passed to ``criterion``. When omitted, the
            module-level ``epoch`` variable is used, which keeps existing
            ``train(...)`` calls inside a ``for epoch in ...`` loop working.

    Returns:
        ``(mean total loss, mean tl, mean tu, weight)`` where the means are taken
        over the batches of this epoch and ``weight`` is the value returned by
        ``criterion`` for the last batch.

    Raises:
        NameError: If ``epoch`` is not given and no global ``epoch`` exists.
        ValueError: If ``train_loader`` yields no batches.
    """
    if epoch is None:
        # Backward compatibility: earlier versions read `epoch` as a global.
        try:
            epoch = globals()['epoch']
        except KeyError:
            raise NameError(
                "train() needs `epoch`: pass it explicitly or define a global `epoch`"
            ) from None

    model.train()
    tl_loss = []
    tu_loss = []
    total_loss = []
    weight = None
    for batch_img,batch_labels in tqdm(train_loader):
        # Convert/move the batch once; both views are derived from it.
        batch_img = batch_img.type(torch.float32).to(cfg['device'])
        batch_labels = batch_labels.to(cfg['device'])

        # Two independent random augmentations of the same images.
        batch_img_1 = transformer(batch_img)
        batch_img_2 = transformer(batch_img)

        y_pred_1 = model(batch_img_1,True)
        y_pred_2 = model(batch_img_2,True)
        loss,tl,tu,weight = criterion(y_pred_1,y_pred_2,batch_labels,epoch)

        total_loss.append(loss.detach().cpu().numpy())
        tl_loss.append(tl.detach().cpu().numpy())
        tu_loss.append(tu.detach().cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if not total_loss:
        # Previously this surfaced as an UnboundLocalError on `weight`.
        raise ValueError("train_loader produced no batches")

    return np.mean(total_loss),np.mean(tl_loss),np.mean(tu_loss),weight

def valid(model,test_loader,cfg):
    """Evaluate ``model`` on a labelled loader and return macro-F1 and accuracy.

    Args:
        model: Network called as ``model(images, False)``.
            # assumes the output is (batch, num_classes) scores - TODO confirm
        test_loader: Iterable of ``(images, labels)`` batches.
        cfg: Config dict; only ``cfg['device']`` is read here.

    Returns:
        ``(f1, acc)``: macro-averaged F1 score and plain accuracy over all
        batches of ``test_loader``.
    """
    labels = []
    y_preds = []
    model.eval()
    # No gradients are needed anywhere in evaluation.
    with torch.no_grad():
        for batch_imgs,batch_labels in test_loader:
            batch_imgs = batch_imgs.type(torch.float32).to(cfg['device'])
            scores = model(batch_imgs,False)
            # Softmax is monotonic, so argmax over the raw scores picks the same
            # class; this also avoids the deprecated implicit-dim F.softmax call.
            batch_pred = torch.argmax(scores,dim=1)

            y_preds.extend(batch_pred.cpu().numpy())
            labels.extend(batch_labels.detach().cpu().numpy())

    y_true = np.array(labels)
    y_pred = np.array(y_preds)
    # scikit-learn metrics take (y_true, y_pred) in that order.
    f1 = f1_score(y_true,y_pred,average='macro')
    acc = accuracy_score(y_true,y_pred)
    return f1, acc

In [3]:
# Open the Weights & Biases run (project 'BA_SSL', run name 'Try2').
wandb.init(project='BA_SSL', name='Try2')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcrimama-[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# ---- Experiment configuration ----
cfg = {}
cfg['dataset'] = 'cifar10'
cfg['model_name'] = 'resnet18'
cfg['unlabel_ratio'] = 0
cfg['batch_size'] = 100 
cfg['device'] = 'cuda:0'
cfg['lr'] = 0.003 
cfg['beta1'] = 0.8
cfg['beta2'] = 0.999 
cfg['epochs'] = 300 
cfg['std'] = 0.15 

# Where the checkpoint with the lowest training loss is written.
SAVE_PATH = './Save_models/best.pt'

# ---- Data ----
train_set,test_set = label_unlabel_load(cfg)
# Use the configured dataset instead of make_valid's default.
valid_set = make_valid(cfg['dataset'])

train_dataset = CifarDataset(train_set, unlabel=False)
valid_dataset = CifarDataset(valid_set, unlabel=False)
test_dataset  = CifarDataset(test_set,unlabel=False)

train_loader = DataLoader(train_dataset,batch_size=cfg['batch_size'],shuffle=True)
valid_loader = DataLoader(valid_dataset,batch_size=cfg['batch_size'],shuffle=False)
test_loader  = DataLoader(test_dataset,batch_size=cfg['batch_size'],shuffle=False)

transformer = make_transform()

# ---- Model, loss, optimizer ----
#model = PiModel(device='cuda')
# Place the model on the same device the batches are moved to in train()/valid().
model = Model(cfg['model_name']).to(cfg['device'])
criterion = PiCriterion(cfg)
optimizer = torch.optim.Adam(model.parameters(),lr=cfg['lr'],betas=(cfg['beta1'],cfg['beta2']))

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(SAVE_PATH),exist_ok=True)

# ---- Training loop ----
# NOTE: despite its name, `best_epoch` holds the lowest *training loss* seen so far.
best_epoch = np.inf 
for epoch in range(cfg['epochs']):
    # train() picks up the current `epoch` from the notebook's globals.
    loss,tl_loss,tu_loss,weight =  train(model,criterion,optimizer,train_loader,cfg,transformer)
    # `auc` is the accuracy returned by valid(), not an area-under-curve score.
    f1 , auc = valid(model,valid_loader,cfg)
    print(f'\n Epochs : {epoch}')
    print(f'\n loss : {loss} | tl_loss : {tl_loss} | tu_loss : {tu_loss}')
    print(f'\n valid f1 : {f1}')
    print(f'\n valid auc : {auc}')
    
    if loss < best_epoch:
        torch.save(model,SAVE_PATH)
        best_epoch = loss 
        print(f'model saved | best loss :{best_epoch}')
    wandb.log({'loss':loss,
               'tl_loss':tl_loss,
               'tu_loss':tu_loss,
               'weight':weight,
               'valid_f1':f1,
               'valid_acc':auc
               })

# ---- Final evaluation on the test split (uses the last-epoch weights) ----
# valid() takes (model, loader, cfg); passing the transformer raised a TypeError.
f1 , auc = valid(model,test_loader,cfg)
print(f"\n F1 score : {f1} | Accuracy : {auc}")

100%|██████████| 500/500 [00:17<00:00, 28.87it/s]



 Epochs : 0

 loss : 1.9154711961746216 | tl_loss : 1.9154711961746216 | tu_loss : 0.0

 valid f1 : 0.2952153543952718

 valid auc : 0.3502
model saved | best loss :1.9154711961746216


100%|██████████| 500/500 [00:17<00:00, 28.32it/s]



 Epochs : 1

 loss : 1.7168867588043213 | tl_loss : 1.6206587553024292 | tu_loss : 0.09622790664434433

 valid f1 : 0.3855233803068166

 valid auc : 0.445
model saved | best loss :1.7168867588043213


100%|██████████| 500/500 [00:17<00:00, 28.18it/s]



 Epochs : 2

 loss : 1.5330302715301514 | tl_loss : 1.4223341941833496 | tu_loss : 0.110695980489254

 valid f1 : 0.4793518422329397

 valid auc : 0.52
model saved | best loss :1.5330302715301514


100%|██████████| 500/500 [00:17<00:00, 28.06it/s]



 Epochs : 3

 loss : 1.4351662397384644 | tl_loss : 1.3105820417404175 | tu_loss : 0.12458419799804688

 valid f1 : 0.38286537821779826

 valid auc : 0.4292
model saved | best loss :1.4351662397384644


100%|██████████| 500/500 [00:17<00:00, 28.01it/s]



 Epochs : 4

 loss : 1.407201886177063 | tl_loss : 1.2669321298599243 | tu_loss : 0.14026974141597748

 valid f1 : 0.44498737115896453

 valid auc : 0.4958
model saved | best loss :1.407201886177063


100%|██████████| 500/500 [00:17<00:00, 27.91it/s]



 Epochs : 5

 loss : 1.372666597366333 | tl_loss : 1.218932867050171 | tu_loss : 0.1537337750196457

 valid f1 : 0.5330347321997548

 valid auc : 0.568
model saved | best loss :1.372666597366333


100%|██████████| 500/500 [00:17<00:00, 28.04it/s]



 Epochs : 6

 loss : 1.3100656270980835 | tl_loss : 1.1463325023651123 | tu_loss : 0.16373297572135925

 valid f1 : 0.6239629592958174

 valid auc : 0.642
model saved | best loss :1.3100656270980835


100%|██████████| 500/500 [00:17<00:00, 28.05it/s]



 Epochs : 7

 loss : 1.3090547323226929 | tl_loss : 1.137319803237915 | tu_loss : 0.1717347949743271

 valid f1 : 0.5935173103555849

 valid auc : 0.6152
model saved | best loss :1.3090547323226929


100%|██████████| 500/500 [00:17<00:00, 27.97it/s]



 Epochs : 8

 loss : 1.2967170476913452 | tl_loss : 1.1098322868347168 | tu_loss : 0.18688462674617767

 valid f1 : 0.6120155710460701

 valid auc : 0.6306


KeyboardInterrupt: 