In [14]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib.pyplot as plt 
import torch
import torch.utils.data as data
import torch.nn as nn
import torch.nn.functional as F
import torchvision 
import torchvision.transforms as transforms
import kornia as K
import numpy as np
import easydict
import os
import sys
from datetime import datetime


sys.path.append("..")
import mixmatch
import datasets
import transformations as custom_transforms
import utils
import losses
import models
import ramps 
import wide_resnet

# CIFAR-10 reference settings.
# NOTE: the insertion order of the keys below is the order in which they are
# dumped to the log file, so keep it stable.
args = easydict.EasyDict()
args.train_iterations = 100000   # the training loop runs while count < train_iterations
args.K = 2                       # forwarded to mixmatch.mixmatch(K=...)
args.T = 0.5                     # forwarded to mixmatch.mixmatch(T=...)
args.alpha = 0.75                # forwarded to mixmatch.mixmatch(alpha=...)
args.lam_u = 75                  # final weight of the unsupervised loss term
args.rampup_length = 100000      # forwarded to ramps.linear_rampup(rampup_length=...)
args.n_labeled = 250
args.batch_size = 64
args.lr = 0.002
args.ewa_coef = 0.95             # smoothing factor of the exponentially weighted training loss
args.device = utils.get_device(1)  # NOTE(review): meaning of the argument `1` is defined in utils - confirm
args.cifar_root = './CIFAR10'
args.cifar_download = True
args.mean_teacher_coef = None    # falsy -> the model itself produces the MixMatch guesses

# Naming of the run artefacts
args.basename = 'foo'
args.call_prefix = '-2'
args.res_path = './'
args.new_log = False             # True -> overwrite an existing log instead of appending

# Periods (in iterations) for logging, checkpointing and validation
args.log_period = 1000
args.save_period = 10000
args.validation_period = 1000

args.logpath = os.path.join(args.res_path, 'logs/log-' + args.basename + args.call_prefix + '.txt')
args.model_path = os.path.join(args.res_path, 'models/' + args.basename + args.call_prefix + '-')

# Create the output folders up front: opening the log (or saving a checkpoint
# thousands of iterations later) would otherwise fail with FileNotFoundError.
for out_path in (args.logpath, args.model_path):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Start a fresh log when requested or when none exists yet, otherwise append.
# A single context manager guarantees the handle is flushed and closed.
log_mode = 'w' if (args.new_log or not os.path.exists(args.logpath)) else 'a'
with open(args.logpath, log_mode) as log_file:
    print(f"# Starting at {datetime.now()}", file=log_file)
    print("with args:\n" + "\n".join(f"{key} : {value}" for key, value in args.items()), file=log_file)
    print(f"logpath: {args.logpath}", file=log_file)
    print(f"modelpath: {args.model_path}<name>.pt", file=log_file)
 
# Datasets and dataloaders: the helper returns three loaders, unpacked here as
# labeled-train, unlabeled-train and validation.
(
    labeled_dataloader,
    unlabeled_dataloader,
    validation_dataloader,
) = datasets.get_CIFAR10(
    args.cifar_root,
    args.n_labeled,
    args.batch_size,
    download=args.cifar_download,
)

# Number of target classes handed to the network constructor.
num_classes = 10


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Files already downloaded and verified


In [3]:
# Transformations
# Augmentations that are currently disabled (kept for reference):
#k1 = custom_transforms.GaussianNoiseChannelwise((0.0001, 0.0001, 0.0001))
#k2 = K.augmentation.RandomGaussianBlur((3,3),sigma=(1.,1.),p = 0.5)
k3 = K.augmentation.RandomHorizontalFlip(p=0.5)
k4 = K.augmentation.RandomVerticalFlip(p=0.5)
#k5 = K.augmentation.RandomAffine([-5., 5.], [0.1, 0.1], [0.8, 1.2], [0., 0.15])

# The same flip modules are shared by all three lists.
img_trans = nn.ModuleList([k3, k4])
mask_trans = nn.ModuleList([k3, k4])  # only for segmentation
invert_trans = nn.ModuleList([k3, k4])
num_classes = 10
transform = custom_transforms.MyAugmentation(img_trans, mask_trans, invert_trans)

# Model, optimizer and eval function.
# The model is moved to the target device *before* the optimizer and the
# optional teacher are built from it, so everything derived from its
# parameters lives on the same device.
model = wide_resnet.WideResNet(num_classes)
model.to(args.device)
opt = torch.optim.Adam(params=model.parameters(), lr=args.lr)
eval_loss_fn = losses.kl_divergence

# Either resume from a checkpoint or start with empty metric histories.
# flush=True keeps each message ordered before any work that follows it.
with open(args.logpath, 'a') as log_file:
    if args.get('load_path', None) is not None:
        print(f"Loading checkpoint : {args.load_path}", file=log_file, flush=True)
        # NOTE(review): `net` and `net_args` are not used afterwards; this
        # presumably relies on load_checkpoint updating `model` in place - confirm.
        count, metrics, net, opt, net_args = utils.load_checkpoint(args.device, model, opt, args.load_path)
        ewa_loss = metrics['train_criterion_ewa'][-1]
    else:
        print("Creating new network!", file=log_file, flush=True)

        metrics = easydict.EasyDict()
        metrics['train_criterion'] = np.empty(0)
        metrics['train_criterion_ewa'] = np.empty(0)
        metrics['val_loss'] = np.empty(0)
        metrics['val_acc'] = np.empty(0)
        metrics['train_loss'] = np.empty(0)
        metrics['train_acc'] = np.empty(0)
        count = 0
        ewa_loss = 0

    print('    Total params: %.2fM' % (sum(p.numel() for p in model.parameters())/1000000.0),
          file=log_file, flush=True)

# Preparation for the training loop: explicit iterators so the labeled and
# unlabeled loaders can be restarted independently when one is exhausted.
labeled_train_iter = iter(labeled_dataloader)
unlabeled_train_iter = iter(unlabeled_dataloader)

# Use a teacher for the MixMatch guesses if a coefficient is configured,
# otherwise the model itself is used.
if args.mean_teacher_coef:
    mixmatch_clf = models.Mean_Teacher(model, args.mean_teacher_coef)
else:
    mixmatch_clf = model

model.train()

def _append_eval_metrics(split, dataloader):
    """Evaluate `model` on `dataloader` and append the loss and accuracy to
    metrics['<split>_loss'] and metrics['<split>_acc']."""
    ls, acc = wide_resnet.evaluate(model, eval_loss_fn, dataloader, args.device)
    metrics[f'{split}_loss'] = np.append(metrics[f'{split}_loss'], ls.detach().cpu().numpy())
    metrics[f'{split}_acc'] = np.append(metrics[f'{split}_acc'], acc.detach().cpu().numpy())


# Iterate until the desired number of iterations is reached
while count < args.train_iterations:

    if count == 0:  # baseline metrics before the first update
        _append_eval_metrics('val', validation_dataloader)
        _append_eval_metrics('train', labeled_dataloader)

    # Restart a loader when it is exhausted (the two loaders may differ in length).
    # Only StopIteration is caught so that KeyboardInterrupt and genuine
    # data-loading errors are not silently swallowed.
    try:
        data_l, labels = next(labeled_train_iter)
    except StopIteration:
        labeled_train_iter = iter(labeled_dataloader)
        data_l, labels = next(labeled_train_iter)

    try:
        data_u = next(unlabeled_train_iter)
    except StopIteration:
        unlabeled_train_iter = iter(unlabeled_dataloader)
        data_u = next(unlabeled_train_iter)

    data_l = data_l.to(args.device)
    labels = labels.to(args.device)
    data_u = data_u.to(args.device)

    # Corner case: batches of different sizes (e.g. an irregular last batch)
    # are truncated to the smaller of the two.
    current_batch_size = min(data_l.shape[0], data_u.shape[0])

    data_l = data_l[:current_batch_size]
    labels = labels[:current_batch_size]
    data_u = data_u[:current_batch_size]

    # Build the MixMatch batches without tracking gradients.
    with torch.no_grad():
        model.eval()
        l_batch, u_batch = mixmatch.mixmatch(labeled_batch=data_l,
                                             labels=labels,
                                             unlabeled_batch=data_u,
                                             clf=mixmatch_clf,
                                             augumentation=transform,
                                             K=args.K,
                                             T=args.T,
                                             alpha=args.alpha
                                             )

    x = torch.cat([l_batch[0], u_batch[0]], dim=0)
    targets_l, targets_u = l_batch[1], u_batch[1]

    # Interleave labeled and unlabeled samples between batches to obtain correct batchnorm calculation
    x_splitted = list(torch.split(x, current_batch_size))
    x_splitted = mixmatch.interleave(x_splitted, current_batch_size)

    # Forward (one pass per chunk; the loop variable no longer shadows `x`)
    model.train()
    logits = [model(x_chunk) for x_chunk in x_splitted]

    # Put interleaved samples back
    logits = mixmatch.interleave(logits, current_batch_size)
    logits_l = logits[0]
    logits_u = torch.cat(logits[1:], dim=0)

    # Loss
    # TODO: Deal with mask of valid regions of transformed images (for affine transformation) -> remove black parts
    loss_supervised = losses.soft_cross_entropy(logits_l, targets_l, reduction='mean')
    loss_unsupervised = losses.mse_softmax(logits_u, targets_u, reduction='mean')

    # Reference formulas used to cross-check the two loss helpers:
    # Lx = -torch.mean(torch.sum(F.log_softmax(logits_l, dim=1) * targets_l, dim=1))
    # Lu = torch.mean((torch.softmax(logits_u, dim=1) - targets_u)**2)
    # print(f"{loss_supervised=:.2f},{Lx=:.2f}")
    # print(f"{loss_unsupervised=:.2f}{Lu=:.2f}")

    # The unsupervised weight is scaled by the ramp-up value for the current iteration.
    lam_u = ramps.linear_rampup(current=count, rampup_length=args.rampup_length) * args.lam_u
    loss = loss_supervised + lam_u * loss_unsupervised

    # Optimizer step
    opt.zero_grad()
    loss.backward()
    opt.step()
    if args.mean_teacher_coef:
        mixmatch_clf.update_weights(model)

    # Exponentially weighted average of the training loss.
    # A plain Python float is used so the running average does not keep a
    # reference to the autograd graph of every past iteration.
    loss_value = loss.item()
    if count == 0 and ewa_loss == 0:
        ewa_loss = loss_value
    else:
        ewa_loss = args.ewa_coef * ewa_loss + (1 - args.ewa_coef) * loss_value

    # Save loss (every iteration); the smoothed value goes into its own history.
    metrics['train_criterion'] = np.append(metrics['train_criterion'], loss_value)
    metrics['train_criterion_ewa'] = np.append(metrics['train_criterion_ewa'], ewa_loss)

    is_last_iteration = count == args.train_iterations - 1

    # Compute validation metrics at the end of every validation period and on the last iteration
    if (count % args.validation_period == args.validation_period - 1) or is_last_iteration:
        _append_eval_metrics('val', validation_dataloader)
        _append_eval_metrics('train', labeled_dataloader)

    # Print log at the end of every log period and on the last iteration
    if (count % args.log_period == args.log_period - 1) or is_last_iteration:
        strtoprint = f"batch iteration: {str(count)} " + \
                     f"ewa loss: {ewa_loss:.2f} " + \
                     f"val loss: {metrics['val_loss'][-1]:.2f} " + \
                     f"val acc: {metrics['val_acc'][-1]:.2f} " + \
                     f"train loss: {metrics['train_loss'][-1]:.2f} " + \
                     f"train acc:  {metrics['train_acc'][-1]:.2f} "
        with open(args.logpath, 'a') as log_file:
            print(strtoprint, file=log_file)

    # Save checkpoint at the end of every save period and on the last iteration
    if (count % args.save_period == args.save_period - 1) or is_last_iteration:
        m_path = args.model_path + f"e{count}" + '-m.pt'
        with open(args.logpath, 'a') as log_file:
            print(f'# Saving Model : {m_path}', file=log_file)
        utils.save_checkpoint(count, metrics, model, opt, args, m_path)

    count += 1

    Total params: 1.47M


KeyboardInterrupt: 