In [None]:
# Import libraries and helper functions
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as nnf

import matplotlib.pyplot as plt
import random
import pickle
import pandas as pd

from dataprocessing import *
import sys
sys.path.append('..')
from utils import *
from helper_functions import *

import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import time
import scipy.spatial.distance
from scipy.spatial.distance import hamming
from scipy.spatial import distance
from sklearn.mixture import GaussianMixture
from scipy.stats import norm
import seaborn as sns

# Global variables
# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class Args:
    """Experiment settings shared by the cells of this notebook."""

    # --- data ---
    # Tag for the dataset; in the visible cells it only appears in output file names.
    dataset: str = "tinyimagenet"
    # Forwarded to the noisy training dataset and embedded in output file names.
    max_noise_rate: float = 0.4
    # Batch size used for both the train and the test DataLoader.
    batchsize: int = 100

    # --- model / optimisation ---
    # Name of the constructor looked up in torchvision.models.
    model: str = "mobilenet_v2"
    # Learning rate handed to the SGD optimizer.
    lr: float = 0.001
    # Number of training epochs; also the T_max of the cosine LR schedule.
    epochs: int = 200


args = Args()

In [None]:
# Generate the dataset containing hard and noisy samples
data_dir = './tiny-imagenet-200'
num_classes = 200

# Channel-wise normalisation (the commonly used ImageNet mean/std values).
# It is handed to the dataset classes as a separate argument instead of being
# appended to the transform pipelines below.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training images get a random horizontal flip; validation images are only
# converted to tensors.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]),
    val=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

trainset = TinyImageNetNoisyDatasetSpectrum(
    data_dir,
    transform=data_transforms['train'],
    normalize=normalize,
    mode='train',
    max_noise_rate=args.max_noise_rate,
)

# subsampling according to class label to make dataset imbalanced
def samples_per_class(label):
    """Return how many training samples to keep for class ``label``.

    Labels are bucketed by ``label % 40`` into five bands of eight residues,
    and each band keeps half as many samples as the one before it:

        0-7 -> 400, 8-15 -> 200, 16-23 -> 100, 24-31 -> 50, 32-39 -> 25

    Args:
        label: integer class index.

    Returns:
        int: number of samples to draw for that class.
    """
    # Computed once; nothing is printed here because the function is called
    # for every class while the subset is being built.
    position = label % 40
    # (exclusive upper bound on label % 40, samples kept)
    for upper_bound, count in ((8, 400), (16, 200), (24, 100), (32, 50)):
        if position < upper_bound:
            return count
    return 25
    
# For every class, randomly pick samples_per_class(cls) dataset positions.
# random.sample raises ValueError if a class has fewer matching positions
# than requested.
subset_indices = []
for cls in range(num_classes):
    # positions where trainset.original_label_data equals this class
    indices = [
        position
        for position, is_match in enumerate(trainset.original_label_data == cls)
        if is_match
    ]
    subset_indices.extend(random.sample(indices, samples_per_class(cls)))
    
# Both loaders use the same options: configured batch size, shuffling on,
# and loading in the main process (num_workers=0).
loader_options = dict(batch_size=args.batchsize, shuffle=True, num_workers=0)

# Training data: only the subsampled (class-imbalanced) positions of trainset.
class_dependent_hardness_dataset = torch.utils.data.Subset(trainset, subset_indices)
trainloader = torch.utils.data.DataLoader(class_dependent_hardness_dataset, **loader_options)

# Evaluation data: the 'val' mode of the plain Tiny ImageNet dataset.
testset = TinyImageNetDataset(
    data_dir,
    transform=data_transforms['val'],
    normalize=normalize,
    mode='val',
)
testloader = torch.utils.data.DataLoader(testset, **loader_options)





In [None]:
# Model and optimizer definitions
pretrained = False

# Constructor arguments common to every architecture; inception_v3 additionally
# gets aux_logits=False.
# NOTE(review): newer torchvision releases deprecate `pretrained` in favour of
# `weights` — confirm against the installed version.
model_kwargs = {'pretrained': pretrained, 'num_classes': num_classes}
if args.model == 'inception_v3':
    model_kwargs['aux_logits'] = False

net = models.__dict__[args.model](**model_kwargs).to(device)

if device == 'cuda':
    # net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=5e-4,
)
# Cosine learning-rate schedule spanning the full number of epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)


In [None]:
# Train the model on the hard and noisy dataset
file_name = './results/values/hardness_via_imbalance_{}_{}_epochs_{}_lr_{}_noise_{}_deployment.pkl'.\
    format(args.dataset, args.model, args.epochs, args.lr, args.max_noise_rate)
model_file_name = './checkpoint/hardness_via_imbalance_{}_{}_epochs_{}_lr_{}_noise_{}_deployment.pth'.\
    format(args.dataset, args.model, args.epochs, args.lr, args.max_noise_rate)

# Create both output directories up front: a missing results directory would
# make open() fail, and a missing checkpoint directory would otherwise only be
# discovered after the whole training run.
for output_path in (file_name, model_file_name):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(file_name, 'wb') as f:
    # One record is appended per epoch through a single Pickler.
    # NOTE(review): the Pickler (and therefore its memo) is shared by all
    # dumps, so the file should be read back with a single pickle.Unpickler
    # calling load() repeatedly — confirm against the reader.
    pickler = pickle.Pickler(f)
    for epoch in range(args.epochs):
        # train()/test() come from the project helpers; presumably they return
        # loss, top-1 and top-5 accuracy (train() also a per-epoch DataFrame)
        # — TODO confirm against their definitions.
        tl, ta, ta5, df_perepoch = train(epoch, net, trainloader, device, criterion, optimizer,
                                         model_name=args.model, num_classes=num_classes)
        tel, tea, tea5 = test(epoch, net, testloader, device, criterion)
        scheduler.step()
        pickler.dump([epoch, tl, ta, ta5, df_perepoch, tel, tea, tea5])

# The results file is closed at this point; save the final model and
# optimizer state.
state = {
    'net': net.state_dict(),
    'opt': optimizer.state_dict()
}
torch.save(state, model_file_name)