# Demo notebook for running the proposed method

In [None]:
import sys
sys.path.append("..")
import numpy as np
import random
import pickle
import time
from tqdm.notebook import tqdm
import torch
import os
import datetime
import random
import math
import pandas as pd
import models
import shutil
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from collections import Counter
import warnings
# Our files
import train
import utils
import dataset

# Switch matplotlib to interactive mode so figures are drawn without blocking
# the notebook; plt.show() is called once up front.
plt.ion()
plt.show()
# IPython magics (only valid inside a notebook/IPython session): enable the
# autoreload extension in mode 2, so edits to the imported project modules are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Index of the current run; it doubles as the seed for every random number
# generator used below (PyTorch CPU, PyTorch CUDA, NumPy and Python's `random`).
run = 0
for seed_fn in (torch.manual_seed,
                torch.cuda.manual_seed_all,
                np.random.seed,
                random.seed):
    seed_fn(run)

# Set input parameters

In [None]:
# ---- Experiment configuration -------------------------------------------
# Everything below is read by the later cells; change values here only.
dataset_name = "cifar10"  # other options "cifar100"
data_path = '../datasets/'  # where the datasets have been saved
noisy_rate = 0.4  # noise ratio, only applies to cifar; other values [0.2, 0.4, 0.6, 0.8]
version = 0  # 3 versions of corrupted labels have been generated randomly [0, 1, 2]
# Load corrupted labels
corrupted_targets = np.load(
    f"../train_data/{dataset_name}_corrupted_{noisy_rate}_{version}.npy")
# for validation purposes, load ground truth
true_targets = np.load(f"../train_data/{dataset_name}_true.npy")
model_name = "resnet18"  # other options : "moco"
loss_name = "nfl_rce"  # other options :"elr", "ce"
arch = 'resnet18'  # model type for encoder
use_validation = False
use_protype = False
classification_arch = 'linear'  # other options "multilayer"
device = utils.get_device()  # device cpu/gpu
# The number of classes must follow the dataset: CIFAR-100 has 100 classes,
# every other value keeps the previous setting of 10 (CIFAR-10).
nb_classes = 100 if dataset_name == "cifar100" else 10
# params is a dictionary placeholder for all input parameters
params = utils.get_params(dataset_name,
                          img_size=32,
                          batch_size_classif=256,
                          batch_size_representation=128,
                          num_workers_classif=3,
                          num_workers_representation=3,
                          nb_classes=nb_classes,
                          model_name=model_name,
                          use_protype=use_protype,
                          use_validation=use_validation,
                          arch=arch,
                          data_path=data_path,
                          classification_arch=classification_arch)

# unique name for the set of experiments; used as prefix for every
# checkpoint / result file written by the following cells
exp = f"demo_{dataset_name}_{noisy_rate}_{model_name}_{loss_name}_{version}"

# BLOCK 1: contrastive pretraining and supervised learning

### 1.1 Unsupervised pretraining

In [None]:
# Step 1.1: representation training that starts without a checkpoint.
# placeholder for input data
label_data = {
    "labels": corrupted_targets,
    "true_targets": true_targets,
    "sup_sample_ids": None,  # use no samples with supervised loss
    "unsup_sample_ids": []  # exclude no samples from unsupervised loss
}

params['augmentation']['representation_train'] = 'moco'
params['augmentation']['representation_train_strong'] = 'moco'
params['adam'] = True
name = f"{exp}_representation_unsup"

# Make sure the output folder exists *before* training, so a missing
# directory cannot make the pickle dump fail after the run has completed.
os.makedirs('../data/results', exist_ok=True)

# Instantiate model and supervised loss
model, criterion_sup = utils.get_model_and_criterion(dataset_name,
                                                     loss_name=loss_name,
                                                     model=model_name,
                                                     p=params)

result = train.representation_training(
    model,
    checkpoint_path_file=None,  # no checkpoint is passed in
    label_data=label_data,
    p=params,
    epochs=2,
    checkpoint_folder=f"../data/models/{dataset_name}",
    name=name)
# Persist the returned result object next to the other experiment results
with open(f'../data/results/{name}.pickle', 'wb') as handle:
    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)

# `results` collects the outputs of the successive steps; later cells read
# checkpoints and model outputs from it by position.
results = [result]

### 1.2 Supervised training

In [None]:
# Step 1.2: supervised training, initialised from the checkpoint stored in
# results[0] (the output of the previous step).
# placeholder for input data
label_data = {"labels": corrupted_targets, "true_targets": true_targets}
name = f"{exp}_supervised_full"

# Make sure the output folder exists *before* training, so a missing
# directory cannot make the pickle dump fail after the run has completed.
os.makedirs('../data/results', exist_ok=True)

result = train.supervised_training(
    model,
    checkpoint_path_file=results[0]["checkpoint_path_file"],
    loss_name=loss_name,
    label_data=label_data,
    p=params,
    epochs=2,  # nb epochs to train entire model
    checkpoint_folder=f"../data/models/{dataset_name}",
    dataset_name=dataset_name,
    name=name,
    nb_epochs_output_training=2,  # nb epochs to train only classification head
    finetune_lr=False)
# Persist the returned result object next to the other experiment results
with open(f'../data/results/{name}.pickle', 'wb') as handle:
    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
results.append(result)

# BLOCK 2: Method improvements
### 2.1 GMM training

In [None]:
# Step 2.1: re-train only the classification head (epochs=0 for the full
# model) starting from the latest checkpoint.
# Restart the result list from the most recent `result`; from here on
# results[0] is that result.
# NOTE(review): this makes the cell non-idempotent - re-running it after a
# later step rebinds results[0] to whatever `result` currently holds.
results = [result]
# placeholder for input data
label_data = {"labels": corrupted_targets, "true_targets": true_targets}
# pass pseudo labels in fields "pred_train" and "pred_val"
pretrain_result = results[0]
for key in ("pred_train", "pred_val"):
    # copy a field only when the previous step actually produced it
    if key in pretrain_result["model_output"]:
        label_data[key] = pretrain_result["model_output"][key]

name = f"{exp}_gmm"

# Make sure the output folder exists *before* training, so a missing
# directory cannot make the pickle dump fail after the run has completed.
os.makedirs('../data/results', exist_ok=True)

result = train.supervised_training(
    model,
    checkpoint_path_file=results[-1]["checkpoint_path_file"],
    loss_name=loss_name,
    label_data=label_data,
    p=params,
    epochs=0,  # nb epochs to train entire model
    checkpoint_folder=f"../data/models/{dataset_name}",
    dataset_name=dataset_name,
    name=name,
    nb_epochs_output_training=2,  # nb epochs to train only classification head
    finetune_lr=True)
# Persist the returned result object next to the other experiment results
with open(f'../data/results/{name}.pickle', 'wb') as handle:
    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
results.append(result)

### 2.2 Supervised representation training

In [None]:
# Step 2.2: representation training that additionally receives per-sample
# weights taken from the previous step's result.
# placeholder for input data
label_data = {
    "labels": corrupted_targets,
    "true_targets": true_targets,
    "sup_sample_ids": None,
    "unsup_sample_ids": []
}
# pass pseudo labels in fields "pred_train" and "pred_val"
pretrain_result = results[0]
if "pred_train" in pretrain_result["model_output"]:
    label_data["pred_train"] = pretrain_result["model_output"]["pred_train"]
if "pred_val" in pretrain_result["model_output"]:
    label_data["pred_val"] = pretrain_result["model_output"]["pred_val"]

# Weights come from the most recent entry of `results`
label_data["weights"] = torch.FloatTensor(results[-1]["weights"]).to(device)
# clip data to avoid nan loss: floor every weight at 0.01, element-wise and
# in place (independent of the tensor's number of dimensions)
label_data["weights"].clamp_(min=0.01)
params['augmentation']['representation_train'] = 'moco'
params['augmentation']['representation_train_strong'] = 'moco'
params['adam'] = True
name = f"{exp}_representation_sup"

# Make sure the output folder exists *before* training, so a missing
# directory cannot make the pickle dump fail after the run has completed.
os.makedirs('../data/results', exist_ok=True)

# Instantiate a fresh model and supervised loss
model, criterion_sup = utils.get_model_and_criterion(dataset_name,
                                                     loss_name=loss_name,
                                                     model=model_name,
                                                     p=params)

result = train.representation_training(
    model,
    checkpoint_path_file=results[0]["checkpoint_path_file"],
    label_data=label_data,
    p=params,
    epochs=2,  # nb epochs to train entire model
    checkpoint_folder=f"../data/models/{dataset_name}",
    name=name)
# Persist the returned result object next to the other experiment results
with open(f'../data/results/{name}.pickle', 'wb') as handle:
    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)

results.append(result)

### 2.3 Supervised training

In [None]:
# Step 2.3: final supervised training, initialised from the checkpoint of the
# most recent entry in `results`.
# placeholder for input data
label_data = {"labels": corrupted_targets, "true_targets": true_targets}
# pass pseudo labels in fields "pred_train" and "pred_val"
pretrain_result = results[0]
for key in ("pred_train", "pred_val"):
    # copy a field only when the earlier step actually produced it
    if key in pretrain_result["model_output"]:
        label_data[key] = pretrain_result["model_output"][key]
params['adam'] = False
name = f"{exp}_sup_final"

# Make sure the output folder exists *before* training, so a missing
# directory cannot make the pickle dump fail after the run has completed.
os.makedirs('../data/results', exist_ok=True)

# Instantiate a fresh model and supervised loss
model, criterion_sup = utils.get_model_and_criterion(dataset_name,
                                                     loss_name=loss_name,
                                                     model=model_name,
                                                     p=params)
result = train.supervised_training(
    model,
    checkpoint_path_file=results[-1]["checkpoint_path_file"],
    loss_name=loss_name,
    label_data=label_data,
    p=params,
    epochs=2,  # nb epochs to train entire model
    checkpoint_folder=f"../data/models/{dataset_name}",
    dataset_name=dataset_name,
    name=name,
    nb_epochs_output_training=2,  # nb epochs to train only classification head
    finetune_lr=False)
# Persist the returned result object next to the other experiment results
with open(f'../data/results/{name}.pickle', 'wb') as handle:
    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)

results.append(result)

## Baseline experiment to get the score of the original loss without contrastive learning

In [None]:
# Baseline: supervised training of a freshly created model on the corrupted
# labels, with no checkpoint passed in.
name = f"{exp}_baseline_corrupted"
# placeholder for input data
label_data = {"labels": corrupted_targets, "true_targets": true_targets}

# Make sure the output folder exists *before* training, so a missing
# directory cannot make the pickle dump fail after the run has completed.
os.makedirs('../data/results', exist_ok=True)

# Instantiate model and supervised loss
model, criterion_sup = utils.get_model_and_criterion(dataset_name,
                                                     loss_name=loss_name,
                                                     model=model_name,
                                                     p=params)

result = train.supervised_training(
    model,
    checkpoint_path_file=None,  # no checkpoint is passed in
    loss_name=loss_name,
    label_data=label_data,
    p=params,
    epochs=2,  # nb epochs to train entire model
    checkpoint_folder=f"../data/models/{dataset_name}",
    dataset_name=dataset_name,
    name=name,
    nb_epochs_output_training=0,  # nb epochs to train only classification head
    finetune_lr=False)
# Persist the returned result object next to the other experiment results
with open(f'../data/results/{name}.pickle', 'wb') as handle:
    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)