<a href="https://colab.research.google.com/github/mishless/heroes-game-planning/blob/master/Rob_ADL_Project_George.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [0]:
import os
import urllib.request

from copy import deepcopy

import numpy as np

from scipy.special import logsumexp

from tqdm.auto import tqdm

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

import torch
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
from torch import nn

from datetime import datetime
from torchvision import models, transforms
from torchvision.datasets import CIFAR10, CIFAR100
from torch.utils.data import DataLoader

import argparse
from copy import deepcopy
from tqdm import trange
import time
from sklearn.metrics import confusion_matrix


# Connect to Google Drive

In [2]:
from google.colab import drive
drive.mount('/content/drive')
os.chdir("/content/drive")
main_path = os.path.join("My Drive", "ADL_Project")
if not os.path.exists(main_path):
    os.mkdir(main_path)
os.chdir(main_path)
os.getcwd()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


'/content/drive/My Drive/ADL_Project'

In [3]:
os.listdir(".")

['Data',
 'mwn',
 'meta_weight_net',
 'meta_model',
 'model',
 'Experiments',
 'CIFAR_10_Uniform noise_0_MWN',
 'CIFAR_10_Uniform noise_0.4_MWN',
 'CIFAR_10_Uniform noise_0.4_MWN_confusion_matrix.png',
 'CIFAR_10_Imbalance_200_MWN']

# Data

## Transformations

In [0]:
normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                                     std=[x / 255.0 for x in [63.0, 62.1, 66.7]])

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: F.pad(x.unsqueeze(0),
                                      (4, 4, 4, 4), mode='reflect').squeeze()),
    transforms.ToPILImage(),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize
])

## Prepare data (download)

In [0]:
def get_CIFAR_data(cifar=10, train_transform=None, test_transform=None,
                   train_target_transform=None, test_target_transform=None):
    
    if cifar == 10:
        dataset = CIFAR10
        path = os.path.join("Data", "cifar-10-batches-py")
    elif cifar == 100:
        dataset = CIFAR100
        path = os.path.join("Data", "cifar-100-python")
                
    train_data = dataset(path, train=True,
                         transform=train_transform, 
                         target_transform=train_target_transform, 
                         download= not os.path.exists(path))
        
    test_data = dataset(path, train=False, 
                        transform=test_transform, 
                        target_transform=test_target_transform, 
                        download= not os.path.exists(path))
    
    return train_data, test_data
      

## Creating imbalanced data

In [0]:
def generate_imbalance_data(dataset, 
                            factor=200, 
                            num_meta_per_class=10, 
                            num_of_corrupted=10,
                            seed=123):
    imbalance_factor = factor
    
    num_of_classes = np.unique(dataset.targets).shape[0]
    num_of_total_targets = len(dataset.targets)
    
    indices_dict = {target: np.where(np.array(dataset.targets) == target)[0] 
                    for target in range(num_of_classes)}
    
    meta_indices = []
    train_indices = []
    np.random.seed(seed)
    for target in indices_dict.keys():
        np.random.shuffle(indices_dict[target])
        meta_indices += indices_dict[target][: num_meta_per_class].tolist()
        
        imbalance_size = int((
            indices_dict[target].shape[0] - num_meta_per_class) / float(
            imbalance_factor) ** float(target / (num_of_classes-1) ) )
        
        train_indices +=  indices_dict[target][
            num_meta_per_class: num_meta_per_class + imbalance_size].tolist()
        
    train_data = deepcopy(dataset)
    meta_data = deepcopy(dataset)
    
    train_indices = np.array(train_indices).flatten()
    meta_indices = np.array(meta_indices).flatten()    
    
    train_data.data = train_data.data[train_indices]
    meta_data.data = meta_data.data[meta_indices]
    
    train_data.targets = np.array(train_data.targets)[train_indices]
    meta_data.targets = np.array(meta_data.targets)[meta_indices]
       
    return train_data, meta_data, None

## Uniform noise

In [0]:
def generate_noise_data(dataset, 
                        factor=0.4, 
                        num_meta_per_class=10, 
                        num_of_corrupted=10,
                        seed=123):

    noise_factor = factor

    num_of_classes = np.unique(dataset.targets).shape[0]
    num_of_total_targets = len(dataset.targets)

    indices_array = np.array([np.where(np.array(dataset.targets) == target)[0] 
                          for target in range(num_of_classes)])

    temp_indices = np.empty((num_meta_per_class, indices_array.shape[1] - num_meta_per_class))
    meta_indices = []
    np.random.seed(seed)
    for n in range(num_of_classes):
        np.random.shuffle(indices_array[n])
        meta_indices.append(indices_array[n, : num_meta_per_class])
        temp_indices[n] = indices_array[n, num_meta_per_class: ]

    train_data = deepcopy(dataset)
    meta_data = deepcopy(dataset)

    train_indices = temp_indices.flatten().astype(int)
    np.random.shuffle(train_indices)
    meta_indices = np.array(meta_indices).flatten().astype(int)

    train_data.data = train_data.data[train_indices]
    meta_data.data = meta_data.data[meta_indices]

    train_data.targets = np.array(train_data.targets)[train_indices]
    meta_data.targets = np.array(meta_data.targets)[meta_indices]

    corrupted_indices = []
    chosen_indices = [[x for x in range(num_of_classes) if x != t] for t in range(num_of_classes)]
    for i in range(num_of_classes):
      target_indices = np.where(train_data.targets == i)[0]
      for n in target_indices:
        if np.random.random() < factor:
          train_data.targets[n] = np.random.choice(chosen_indices[i], 1)
          corrupted_indices.append(n)

    np.random.shuffle(corrupted_indices)
    mask = np.full(train_data.targets.shape[0], False)
    mask[corrupted_indices[:num_of_corrupted]] = True

    corrupted_data = deepcopy(train_data)

    corrupted_data.data = deepcopy(train_data.data[mask])
    corrupted_data.targets = deepcopy(train_data.targets[mask])
    
    return train_data, meta_data, corrupted_data



## Flip noise

In [0]:
def generate_flip_data(dataset, 
                       factor=0.4, 
                       num_meta_per_class=10, 
                       num_of_corrupted=10,
                       seed=123):

  num_of_classes = np.unique(dataset.targets).shape[0]

  indices_array = np.array([np.where(np.array(dataset.targets) == target)[0] 
                        for target in range(num_of_classes)])

  temp_indices = np.empty((num_meta_per_class, indices_array.shape[1] - num_meta_per_class))
  meta_indices = []
  np.random.seed(seed)
  for n in range(num_of_classes):
      np.random.shuffle(indices_array[n])
      meta_indices.append(indices_array[n, : num_meta_per_class])
      temp_indices[n] = indices_array[n, num_meta_per_class: ]

  train_data = deepcopy(dataset)
  meta_data = deepcopy(dataset)

  train_indices = temp_indices.flatten().astype(int)
  np.random.shuffle(train_indices)
  meta_indices = np.array(meta_indices).flatten().astype(int)

  train_data.data = train_data.data[train_indices]
  meta_data.data = meta_data.data[meta_indices]

  train_data.targets = np.array(train_data.targets)[train_indices]
  meta_data.targets = np.array(meta_data.targets)[meta_indices]


  classes = list(range(num_of_classes))
  np.random.shuffle(classes)

  random_pair_classes = [[i, j] for i, j in zip(
      classes[: int(len(classes)/2)], 
      classes[int(len(classes)/2): ])]

  corrupted_indices = []
  for pair in random_pair_classes:
    target_indices_0 = np.where(train_data.targets == pair[0])[0]
    np.random.shuffle(target_indices_0)
    target_indices_1 = np.where(train_data.targets == pair[1])[0]
    np.random.shuffle(target_indices_1)

    size = int(target_indices_0.shape[0] * factor)

    train_data.targets[target_indices_0][: size] = pair[1]
    train_data.targets[target_indices_1][: size] = pair[0]

    corrupted_indices += target_indices_0.tolist() + target_indices_1.tolist()  

  np.random.shuffle(corrupted_indices)
  mask = np.full(train_data.targets.shape[0], False)
  mask[corrupted_indices[:num_of_corrupted]] = True

  corrupted_data = deepcopy(train_data)

  corrupted_data.data = deepcopy(train_data.data[mask])
  corrupted_data.targets = deepcopy(train_data.targets[mask])
       
  return train_data, meta_data, corrupted_data

# Models - Classes

## MLP

In [0]:
class MLP(nn.Module):

    def __init__(self, input_dim=1, hidden_dim=100, output_dim=1, 
                 activation_hidden=torch.relu, activation_output=torch.sigmoid,
                 initialization=nn.init.kaiming_normal_,
                 mu=None, std=None, bias=True):
        super(MLP, self).__init__()

        self.activation_hidden = activation_hidden
        self.activation_output = activation_output
        
        self.first_layer = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.second_layer = nn.Linear(hidden_dim, output_dim, bias=bias)

        if initialization == nn.init.kaiming_normal_:
            initialization(self.first_layer.weight)
            initialization(self.second_layer.weight)
        elif initialization == nn.init.normal_:
            mu = 0.0 if mu is None else mu
            std = 1.0 if std is None else std
            initialization(self.first_layer.weight, mu, std)
            initialization(self.second_layer.weight, mu, std)
            
    def forward(self, x):
        self.first_layer_output = self.activation_hidden(self.first_layer(x))
        return self.activation_output(self.second_layer(self.first_layer_output))


## Resnet-32

In [0]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
      

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


def resnet32():
    return ResNet(BasicBlock, num_blocks=[3,4,5,3])

## Wide-Resnet-28-10

# Experiment

## Meta-training

In [0]:
def meta_training(args, train_loader, meta_loader, model, meta_model, meta_weight_net, optimizers, loss_functions):
  
  def normalize_weights(weights, args):
    sum_weights = torch.sum(weights)
    denom = sum_weights if sum_weights != 0 else args.tau
    return (weights / denom).to(args.cuda)
 
  train_loss = 0
  train_loss_weighted = 0
  meta_loss = 0
  train_predictions = []
  train_targets = []
  
  model.train()

  #for (x, y), (x_meta, y_meta) in zip(train_loader, meta_loader):
  
  for enum, (x, y) in tqdm(enumerate(train_loader)):
#    if enum == 3:
#      break
    x_meta, y_meta = next(iter(meta_loader))
#    print('batch nr: ',enum)
#    if args.cuda == "cuda:0":
#      torch.cuda.empty_cache()
#    x.float(), y.float(), x_meta.float(), y_meta.float()
    x, y, x_meta, y_meta = x.to(args.cuda), y.to(args.cuda), x_meta.to(args.cuda), y_meta.to(args.cuda)

    # forward pass for meta-model using input data
    y_pred = meta_model(x)
    loss = loss_functions['model'](y_pred, y)

    # forward pass for meta-weight-net
    weights = meta_weight_net(loss.reshape(-1,1))
    
    # normalize weights and get weighted loss
    normalized_weights = normalize_weights(weights, args)
    weighted_loss = torch.mean(loss * normalized_weights)

    # backward pass for meta-model
    meta_model.zero_grad()
    weighted_loss.backward()
    
    # update parameters of meta-model
    optimizers['meta_model'].step()
    

    # ----- STEP 6 ----- #
    # forward pass for meta model using meta data
    y_meta_pred = meta_model(x_meta)
    loss_meta = torch.mean(loss_functions['model'](y_meta_pred, y_meta))

    meta_loss += loss_meta
    
    # backward pass for meta weight net
    meta_weight_net.zero_grad()
    loss_meta.backward()
    
    # update parameters of meta weight net
    optimizers['meta_weight_net'].step()
    

    # ----- STEP 7 ----- #
    # forward pass for model using input data
    y_pred = model(x)

    train_predictions.append(torch.argmax(y_pred,dim=1))
    train_targets.append(y)
    
    loss = loss_functions['model'](y_pred, y)

    # forward pass for updated meta-weight-net
    weights = meta_weight_net(loss.reshape(-1,1))
    
    # normalize weights 
    normalized_weights = normalize_weights(weights, args)
    
    # compute weighted loss
    weighted_loss = torch.mean(loss * normalized_weights)
    
    train_loss += torch.mean(loss)
    train_loss_weighted += weighted_loss.item()
    
    # backward pass for model
    model.zero_grad()
    weighted_loss.backward()
    
    # update parameters of model
    optimizers['model'].step()
    #meta_model = deepcopy(model)
    meta_model.load_state_dict(model.state_dict())
    
  train_loss /= len(train_loader)          # old comment: loss function already averages over batch size?
  train_loss_weighted /= len(train_loader) # old comment: loss function already averages over batch size?
  meta_loss /= len(train_loader)
  
  return model, meta_model, meta_weight_net, train_loss, train_loss_weighted, meta_loss, train_predictions, train_targets

## Evaluation

In [0]:
def evaluate(args, data_loader, model, loss_function):
    
    predictions = []
    targets = []

    evaluation_loss = 0

    with torch.no_grad():

      for x_eval, y_eval in data_loader:

   #       if args.cuda == "cuda:0":
   #         torch.cuda.empty_cache()

          #  x_eval.float(), y_eval.float()
          x_eval, y_eval = x_eval.to(args.cuda), y_eval.to(args.cuda)

          # forward pass
          y_pred = model(x_eval)

          predictions.append(torch.argmax(y_pred,dim=1))
          targets.append(y_eval)

          # compute loss
          loss = loss_function(y_pred, y_eval)
          #print(loss.cpu())
          evaluation_loss += torch.mean(loss)

    evaluation_loss /= len(data_loader)
    return evaluation_loss, predictions, targets

In [0]:
def weight_variation(args, w_previous, corrupted_data_loader, model, meta_weight_net, model_loss_function):
  
  x, y_noisy = next(iter(corrupted_data_loader))
  
#  if args.cuda == 'cuda:0':
#    torch.cuda.empty_cache()
  
  x, y_noisy = x.to(args.cuda), y_noisy.to(args.cuda)
  
  y_pred = model(x)
  loss = model_loss_function(y_pred, y_noisy)
  w = meta_weight_net(loss.reshape(-1,1))
  
  w_variation = w - w_previous
  w_variation_mean = Variable(torch.mean(w_variation)).cpu().numpy()
  w_variation_std = Variable(torch.std(w_variation)).cpu().numpy()
  
  return w, w_variation_mean, w_variation_std
    

## Perform experiment

In [0]:
def perform_experiment(args, train_loader, meta_loader, test_loader, corrupted_data_loader, model, meta_model, meta_weight_net, optimizers, loss_functions):
    
  def compute_accuracy(torch_predictions, torch_targets):
    predictions = Variable(torch.cat(torch_predictions)).cpu().numpy()
    targets = Variable(torch.cat(torch_targets)).cpu().numpy()
    accuracy = np.sum(np.where(targets == predictions,1,0)) / targets.size
    return accuracy
  
  history = {'train_losses': [],
             'train_losses_weighted': [],
             'meta_losses': [],
             'test_losses': [],
             'train_accuracy': [],
             'test_accuracy': [],
             'test_targets': []}
  
  best_test_model_loss = np.inf
  best_model_info = {}
  best_test_model = None
  best_test_model_predictions = None
  best_test_model_targets = None
  best_test_model_accuracy = None
  
  if args.experiment_type == "Uniform noise": # ONLY FOR UNIFORM
    w_previous = 0
    history['weight_variation_means'] = []
    history['weight_variation_stds'] = []
  
  cifar_type = str(args.cifar_type) if args.dataset == 'CIFAR' else ""
  path = os.path.join(args.directory, args.dataset + cifar_type)
  path = args.directory
  
  epochs = trange(args.epochs_dic[args.experiment_type], leave=True)
  for epoch in epochs:
        
    if epoch in args.lr_schedule:
      optimizers['model'].param_groups[0]['lr'] = args.lr_schedule[epoch]
    
    time_start = time.time()
    
    # training
    model, meta_model, meta_weight_net, train_loss, train_loss_weighted, meta_loss, train_predictions, train_targets = meta_training(args, train_loader, meta_loader,  model, meta_model, meta_weight_net, optimizers, loss_functions) #args
    train_accuracy = compute_accuracy(train_predictions, train_targets)
    
    #print(np.unique(torch.cat(train_targets).cpu().numpy(),return_counts=True))

    # evaluation on test set
    test_loss, test_predictions, test_targets = evaluate(args, test_loader, model, loss_functions['model'])
    test_accuracy = compute_accuracy(test_predictions, test_targets)
   
    # weight variation
    if args.experiment_type == "Uniform noise":
      w, w_variation_mean, w_variation_std = weight_variation(args, w_previous, corrupted_data_loader, model, meta_weight_net, loss_functions['model'])
      history['weight_variation_means'].append(w_variation_mean)
      history['weight_variation_stds'].append(w_variation_std)
      w_previous = w
  
    time_for_epoch = time.time() - time_start
    
    def rounding(tensor,decimals):
      return torch.round(tensor * 10**decimals) / (10**decimals)
    
    # update progress bar
    epochs.set_description("Time for epoch: {}, \
Train loss: {}, \
Train loss weighted: {}, \
Meta loss: {}, \
Test loss: {}".format(time_for_epoch
                     ,rounding(train_loss,5)
                     ,round(train_loss_weighted,5)
                     ,rounding(meta_loss,5)
                     ,rounding(test_loss,5)
                     ))
                           
    if test_loss < best_test_model_loss:
      best_test_model = model
      best_test_model_predictions = torch.cat(test_predictions).cpu().numpy()
      best_test_model_targets = torch.cat(test_targets).cpu().numpy()
      best_test_model_accuracy = test_accuracy

    torch.save(model, path + '_current_model')
                           
    # append history
    history['train_losses'].append(train_loss)
    history['train_losses_weighted'].append(train_loss_weighted)
    history['meta_losses'].append(meta_loss)
    history['test_losses'].append(test_loss)
    history['train_accuracy'].append(train_accuracy)
    history['test_accuracy'].append(test_accuracy)

  # saving
  best_model_info['model'] = best_test_model
  best_model_info['model_predictions'] = best_test_model_predictions
  best_model_info['test_targets'] = best_test_model_targets

  torch.save({'history': history, 'model_info': best_model_info}, path + '_best_model_info')
                           
  # create confusion matrix
  cm_counts_test, cm_percent_test = compute_confusion_matrix(true_targets=best_test_model_targets,
                                                             pred_targets=best_test_model_predictions)
  # plot and saves confusion matrix
  plot_confusion_matrix(cm_percent_test,args,path,close=False)
  
  
  # plot weight variation curves
  if args.experiment_type == "Uniform noise": # ONLY FOR UNIFORM
 #   history['weight_variation_means'][0] = None
#    history['weight_variation_stds'][0] = None
    plot_weight_variation_curves(np.array(history['weight_variation_means']),
                                 np.array(history['weight_variation_stds']),
                                 args,path,close=False)
     

## Plots

In [0]:
def plot_train_and_meta_loss(train_loss, meta_loss, args, path, close=True):
    cifar_type = args.cifar_type if args.dataset == 'CIFAR' else ""
    plt.figure()
    plt.title('{}{}-{}-{}'.format(args.dataset,cifar_type,args.experiment_type,args.factor))
    plt.ylabel('loss')
    plt.xlabel('epochs')
    plt.plot(train_loss, label='training loss')
    plt.plot(meta_loss, label='meta loss')
    plt.legend(loc='best')
    plt.tight_layout()
#    mng = plt.get_current_fix_manager()
#    mng.full_screen_toggle()
    plt.savefig(path + '_loss_plot.png')
    if close: plt.close() 
      
def plot_weight_variation_curves(weight_variation_means, weight_variation_stds, args, path, close=True):
  
    upper_bound = weight_variation_means+weight_variation_stds
    lower_bound = weight_variation_means-weight_variation_stds
  
    x = np.arange(1,len(weight_variation_means)+1)
    
    cifar_type = args.cifar_type if args.dataset == 'CIFAR' else ""
    plt.figure()
    plt.title('{}{}-{}-{}'.format(args.dataset,cifar_type,args.experiment_type,args.factor))
    plt.ylabel('weights')
    plt.xlabel('epochs')
    plt.plot(x,weight_variation_means.tolist(), label='weight_variation', color='r')
    plt.fill_between(x=x,
                     y1=upper_bound,
                     y2=lower_bound,
                     facecolor='pink')
    plt.legend(loc='best')
    plt.tight_layout()
    plt.savefig(path + '_' + str(args.factor) + '_weight_variation_curve.png')
    if close: plt.close() 
      
def compute_confusion_matrix(true_targets,pred_targets):
    matrix = confusion_matrix(y_true=true_targets.astype(int),
                              y_pred=pred_targets.astype(int))
    return matrix, matrix.astype('float') / np.sum(matrix, axis=1).reshape(-1,1)
 
def plot_confusion_matrix(matrix,args,path,close=True):
    cifar_type = args.cifar_type if args.dataset == 'CIFAR' else ""
    fig, ax = plt.subplots()
    im = ax.imshow(matrix, cmap='Blues', origin='lower')
    plt.colorbar(im)
    ax.set_title('{}-{}{}-{}-{}'.format(args.model_type,args.dataset,cifar_type,args.experiment_type,args.factor))
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    ax.set_xticks(np.arange(matrix.shape[1]))
    ax.set_yticks(np.arange(matrix.shape[0]))
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            text = ax.text(j, i, matrix[i, j],ha="center", va="center", color="r")
    plt.tight_layout(pad=0, w_pad=0, h_pad=0)
    plt.get_current_fig_manager().window.showMaximized()
    plt.savefig(path + '_confusion_matrix.png')
    if close: plt.close()

def accuracy_plot(args):
    cifar_type = args.cifar_type if args.dataset == 'CIFAR' else ""
    plt.figure()
    plt.title('{}{}-{}-{}'.format(args.dataset,cifar_type,args.experiment_type,args.factor))
    plt.ylabel('loss')
    plt.xlabel('epochs')
    plt.plot(train_loss, label='training loss')
    plt.plot(meta_loss, label='meta loss')
    plt.legend(loc='best')
    plt.tight_layout()
#    mng = plt.get_current_fix_manager()
#    mng.full_screen_toggle()
    plt.savefig(path + '_loss_plot.png')
    if close: plt.close() 
  

In [0]:
class Args():
    pass
    
args = Args()

args.tau = 1e-6
args.momentum = 0.9
args.weight_decay = 5e-4
args.nesterov = True

args.dataset = 'CIFAR' # "Clothing 1M"
args.cifar_type = 10 # 100
args.experiment_type = "Imbalance" # "Uniform noise" # "Flip noise"

args.model_type = "MWN" # Baseline, BaselineFT
args.seed = 123
args.num_meta_per_class = 10
args.num_of_corrupted = 10

args.num_workers = 6
args.pin_memory = True
args.cuda = "cuda:0" if torch.cuda.is_available() else "cpu"

args.model_signature = str(datetime.now())[0:19].replace(':','.')

args.factors_dict = {
    "Imbalance": [200, 100, 50, 20, 10, 1],
    "Uniform noise": [0, 0.4, 0.6],
    "Flip noise": [0, 0.2, 0.4],
}
args.factor = args.factors_dict[args.experiment_type][0]
args.factor = 200

cifar_type = str(args.cifar_type) if args.dataset == 'CIFAR' else ""

args.directory = os.path.join("Experiments",
                              args.experiment_type,
                              args.model_signature, 
                              args.dataset + '_'.join(['', cifar_type, args.experiment_type, str(args.factor), args.model_type])
                             )


args.batch_size_data_dic = {'Imbalance': 100,
                            'Uniform noise': 100,
                            'Flip noise': 100,
                            'Clothing 1M': 32
                            }

args.batch_size_meta_data_dic = {'Imbalance': 100,
                                 'Uniform noise': 100,
                                 'Flip noise': 100,
                                 'Clothing 1M': 32
                                 }
  
args.lr_model_dic = {'Imbalance': {0: 0.1, 80: 0.01, 90: 1e-3},
                     'Uniform noise': {0: 0.1, 36: 0.01, 38: 1e-3},
                     'Flip noise': {0: 0.1, 40: 0.01, 50: 1e-3},
                     'Clothing 1M': {0: 0.01, 5: 0.001}
                     }

args.lr_wnet_dic = {'Imbalance': 1e-5,
                    'Uniform noise': 1e-3,
                    'Flip noise': 1e-3,
                    'Clothing 1M': 1e-3
                    }

args.epochs_dic = {'Imbalance': 100,
                   'Uniform noise': 3, #40,
                   'Flip noise': 60,
                   'Clothing 1M': 10
                   }

args.weight_decay_dic = {'Imbalance': 5e-4,
                         'Uniform noise': 5e-4,
                         'Flip noise': 5e-4,
                         'Clothing 1M': 1e-3
                         }


args.lr_schedule = args.lr_model_dic[args.experiment_type]
args.batch_size = args.batch_size_data_dic[args.experiment_type]


kwargs_dataloader = {'batch_size': args.batch_size,
                     'num_workers': args.num_workers,
                     'pin_memory': args.pin_memory,
                     'shuffle': True}

kwargs_optimizer = {'lr': args.lr_model_dic[args.experiment_type][0],
                    'momentum': args.momentum,
                    'nesterov': True,
                    'weight_decay': args.weight_decay_dic[args.experiment_type]}

kwargs_optimizer_wnet = {'lr': args.lr_wnet_dic[args.experiment_type],
                          'momentum': args.momentum,
                          'nesterov': True,
                          'weight_decay': args.weight_decay_dic[args.experiment_type]}



args.transformations = {'CIFAR': {'train': train_transform, 'test': test_transform},
                        'Clothing 1M': {'train': None, 'test': None}}

args.get_dataset_function_dict = {'CIFAR': get_CIFAR_data, 'Clothing 1M': None}

train_data, test_data = args.get_dataset_function_dict[args.dataset](args.cifar_type,
                                                                 train_transform=args.transformations[args.dataset]['train'],
                                                                 test_transform=args.transformations[args.dataset]['test'])

kwargs_data_functions = {'dataset': train_data,
                         'factor': args.factor,
                         'num_meta_per_class': args.num_meta_per_class,
                         'num_of_corrupted': args.num_of_corrupted,
                         'seed': args.seed}


args.data_function = {'Imbalance': generate_imbalance_data,
                       'Uniform noise': generate_noise_data,
                       'Flip noise': generate_flip_data,
                       'Clothing 1M': None
                       }


data, meta, corrupted_data = args.data_function[args.experiment_type](**kwargs_data_functions)


#cifar_type = str(args.cifar_type) if args.dataset == 'CIFAR' else ""
#data_directory = os.path.join('Data',args.experiment_type + '_' + str(args.factor) \
#                                      + '_' + args.dataset + cifar_type)
#data_saved = {'data': data, 'meta': meta, 'corrupted_data': corrupted_data}
#torch.save(data_saved, data_directory)


In [0]:
for path in ["Experiments",
             os.path.join("Experiments", args.experiment_type),
             os.path.join("Experiments", args.experiment_type, args.model_signature),
             os.path.join("Experiments", args.experiment_type, args.model_signature),
                         args.dataset + '_'.join(['', cifar_type, args.experiment_type, str(args.factor), args.model_type])]:
  if not os.path.exists(path):
      os.mkdir(path)
  




# Run

In [0]:
train_loader = DataLoader(data, **kwargs_dataloader)
meta_loader = DataLoader(meta, **kwargs_dataloader)
test_loader = DataLoader(test_data, **kwargs_dataloader)
corrupted_data_loader = DataLoader(corrupted_data, **kwargs_dataloader) if args.experiment_type != 'Imbalance' else None

meta_weight_net = MLP().to(args.cuda)
model = resnet32().to(args.cuda)
meta_model = resnet32().to(args.cuda)

if torch.cuda.is_available():
  torch.backends.cudnn.benchmark = True

optimizers = {}
optimizers['model'] = torch.optim.SGD(model.parameters(), **kwargs_optimizer)
optimizers['meta_model'] = torch.optim.SGD(meta_model.parameters(), **kwargs_optimizer)
optimizers['meta_weight_net'] = torch.optim.SGD(meta_weight_net.parameters(), **kwargs_optimizer_wnet)

loss_functions = {}
loss_functions['model'] = nn.CrossEntropyLoss(reduction='none').to(args.cuda)


perform_experiment(args,
                   train_loader,
                   meta_loader,
                   test_loader,
                   corrupted_data_loader,
                   model,
                   meta_model,
                   meta_weight_net,
                   optimizers,
                   loss_functions)

  0%|          | 0/100 [00:00<?, ?it/s]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
Time for epoch: 143.80621027946472, Train loss: 1.1875399351119995, Train loss weighted: 0.01191, Meta loss: 3.5492098331451416, Test loss: 3.335899829864502:   1%|          | 1/100 [02:24<3:57:43, 144.08s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 141.6654007434845, Train loss: 0.8935399651527405, Train loss weighted: 0.00896, Meta loss: 3.166529893875122, Test loss: 3.122619867324829:   2%|▏         | 2/100 [04:46<3:54:18, 143.45s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 141.84407782554626, Train loss: 0.8090800046920776, Train loss weighted: 0.00811, Meta loss: 3.0084099769592285, Test loss: 2.962479829788208:   3%|▎         | 3/100 [07:08<3:51:17, 143.06s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 142.60373425483704, Train loss: 0.7378799915313721, Train loss weighted: 0.0074, Meta loss: 2.9455599784851074, Test loss: 2.9718098640441895:   4%|▍         | 4/100 [09:31<3:48:48, 143.01s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 142.62056756019592, Train loss: 0.7081699967384338, Train loss weighted: 0.0071, Meta loss: 2.8955700397491455, Test loss: 2.8141398429870605:   5%|▌         | 5/100 [11:53<3:46:22, 142.98s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 142.95960569381714, Train loss: 0.6564399600028992, Train loss weighted: 0.00658, Meta loss: 2.8811798095703125, Test loss: 2.7838199138641357:   6%|▌         | 6/100 [14:17<3:44:07, 143.06s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 143.09150552749634, Train loss: 0.6150699853897095, Train loss weighted: 0.00616, Meta loss: 2.8269898891448975, Test loss: 2.7498600482940674:   7%|▋         | 7/100 [16:40<3:41:52, 143.15s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 143.8964638710022, Train loss: 0.5797100067138672, Train loss weighted: 0.00581, Meta loss: 2.828739881515503, Test loss: 2.7627599239349365:   8%|▊         | 8/100 [19:04<3:39:59, 143.47s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 143.96771812438965, Train loss: 0.553380012512207, Train loss weighted: 0.00555, Meta loss: 2.73799991607666, Test loss: 2.7177999019622803:   9%|▉         | 9/100 [21:29<3:37:56, 143.69s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 144.48211812973022, Train loss: 0.5212000012397766, Train loss weighted: 0.00522, Meta loss: 2.744349956512451, Test loss: 2.649250030517578:  10%|█         | 10/100 [23:53<3:36:02, 144.02s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 144.18522024154663, Train loss: 0.4995799958705902, Train loss weighted: 0.00501, Meta loss: 2.6456298828125, Test loss: 2.777289867401123:  11%|█         | 11/100 [26:18<3:33:49, 144.15s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 143.3168032169342, Train loss: 0.47224998474121094, Train loss weighted: 0.00474, Meta loss: 2.6336898803710938, Test loss: 2.7408499717712402:  12%|█▏        | 12/100 [28:41<3:31:11, 144.00s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 144.6319179534912, Train loss: 0.45965999364852905, Train loss weighted: 0.00461, Meta loss: 2.612839937210083, Test loss: 2.568039894104004:  13%|█▎        | 13/100 [31:06<3:29:12, 144.28s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 146.90901398658752, Train loss: 0.4330099821090698, Train loss weighted: 0.00434, Meta loss: 2.5831098556518555, Test loss: 2.615769863128662:  14%|█▍        | 14/100 [33:34<3:28:04, 145.16s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 146.66781067848206, Train loss: 0.4293999969959259, Train loss weighted: 0.0043, Meta loss: 2.549099922180176, Test loss: 2.4760499000549316:  15%|█▌        | 15/100 [36:01<3:26:24, 145.70s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 147.38180446624756, Train loss: 0.39860999584198, Train loss weighted: 0.004, Meta loss: 2.5232598781585693, Test loss: 2.5995800495147705:  16%|█▌        | 16/100 [38:28<3:24:50, 146.32s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))




Time for epoch: 147.64171075820923, Train loss: 0.3958899974822998, Train loss weighted: 0.00397, Meta loss: 2.514240026473999, Test loss: 2.5803699493408203:  17%|█▋        | 17/100 [40:56<3:23:04, 146.80s/it]

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

In [0]:
corrupted_data.targets

# Main

In [0]:
torch.cuda.memory_allocated()

In [0]:
torch.cuda.max_memory_allocated()

In [0]:
def compute_confusion_matrix(true_targets,pred_targets):
    matrix = confusion_matrix(y_true=true_targets.astype(int),
                              y_pred=pred_targets.astype(int))
    return matrix, matrix.astype('float') / np.sum(matrix, axis=1).reshape(-1,1)
 
def plot_confusion_matrix(matrix,args,path,close=True):
    cifar_type = args.cifar_type if args.dataset == 'CIFAR' else ""
    fig, ax = plt.subplots()
    im = ax.imshow(matrix, cmap='Blues', origin='lower')
    plt.colorbar(im)
    ax.set_title('{}-{}{}-{}-{}'.format(args.model_type,args.dataset,cifar_type,args.experiment_type,args.factor))
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    ax.set_xticks(np.arange(matrix.shape[1]))
    ax.set_yticks(np.arange(matrix.shape[0]))
    for i in range(matrix.shape[0]):
        for j in range(matrix.shape[1]):
            text = ax.text(j, i, matrix[i, j],ha="center", va="center", color="r")
    #mng = plt.get_current_fig_manager()
    #mng.full_screen_toggle()
    plt.tight_layout()
    plt.savefig(path + '_confusion_matrix.png')
    if close: plt.close()
      

targets = np.array([0,1,2,3,4,5,6,7,8,9,2,3,4,5,6])
pred    = np.array([0,1,2,3,5,4,3,6,7,8,9,1,2,3,6])

cm, cm_rel = compute_confusion_matrix(targets, pred)
plot_confusion_matrix(cm_rel,args,path,close=False)

In [0]:
!nvidia-smi

In [0]:




plot_weight_variation_curves([0.05,0.07,0.09],[0.01,0.02,0.021],args,args.directory,close=False)