# Initializations

## imports

In [None]:
from torchvision.datasets import MNIST
from torch.utils.data import Dataset,DataLoader
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC

import torch
import torchvision
from torch.autograd import Variable
from collections import Counter

import seaborn as sns
import numpy as np
from collections import Counter
from operator import itemgetter
import matplotlib.pyplot as plt
import random

# Data preparation

## Configs

In [None]:
# Batch sizes used by the PyTorch data loaders.
batch_size_train = 128
batch_size_test = 100
# Seed shared by every random number generator seeded at the end of this cell.
random_seed = 12453211
# Must be between 0 and 1. A training label is kept when a uniform draw is
# greater than this value; otherwise it is looked up in the noise map.
random_threshold = 0.6
# One of "symmetric", "assymetric" (this exact spelling is what the code
# compares against) or "original" (default).
mode = "original"
# One of "balanced", "imbalanced" or None.
partiality = None

# For Dataset 1 - Balanced dataset                     - mode = "original"   and partiality = "balanced"
# For Dataset 2 - Imbalanced dataset (original MNIST)  - mode = "original"   and partiality = None
# For Dataset 3 - Balanced symmetric noise             - mode = "symmetric"  and partiality = "balanced"
# For Dataset 4 - Balanced asymmetric noise            - mode = "assymetric" and partiality = "balanced"
# For Dataset 5 - Imbalanced symmetric noise           - mode = "symmetric"  and partiality = "imbalanced"
# For Dataset 6 - Imbalanced asymmetric noise          - mode = "assymetric" and partiality = "imbalanced"

# Fraction of each class that is kept when an imbalanced training set is built.
imbalanced_weights = {
    0: 0.3,
    1: 0.3,
    2: 1.0,
    3: 1.0,
    4: 1.0,
    5: 0.3,
    6: 1.0,
    7: 0.3,
    8: 1.0,
    9: 1.0
}

# Symmetric noise: the pairs 1<->9, 2<->7 and 5<->8 are swapped with each
# other; every other class maps to itself.
symmetric_noise = {
    0: 0,
    1: 9,
    9: 1,
    2: 7,
    7: 2,
    3: 3,
    4: 4,
    5: 8,
    8: 5,
    6: 6
}

# Asymmetric noise: one-directional flips 3->4, 4->8, 8->3 and 9->0; every
# other class maps to itself.
asymmetric_noise = {
    0: 0,
    1: 1,
    2: 2,
    3: 4,
    4: 8,
    5: 5,
    6: 6,
    7: 7,
    8: 3,
    9: 0
}

# Seed every generator the notebook draws from. The label-noise transform
# uses Python's `random` module, so it has to be seeded as well for the noisy
# labels to be reproducible.
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)

## Symmetric and Asymmetric Noise

In [None]:
# Build the training set. In the two noisy modes a `target_transform` is
# attached that may replace a label through the configured noise map.
if mode == "symmetric" or mode == "assymetric":
    train_set = torchvision.datasets.MNIST(
        '.',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ]),
        # Keep the label when the uniform draw exceeds `random_threshold`,
        # otherwise look it up in the noise map selected by `mode`.
        # The draw comes from Python's `random` module.
        # NOTE(review): the transform presumably runs on every item access, so
        # the noisy label of a sample may differ between passes - confirm.
        target_transform = lambda y: 
        (y if random.random() > random_threshold else symmetric_noise[y]) 
        if mode == "symmetric" else 
        (y if random.random() > random_threshold else asymmetric_noise[y])
    )

# Clean training set: same image transform, labels untouched.
# NOTE(review): any other value of `mode` leaves `train_set` undefined.
if mode == "original":
    train_set = torchvision.datasets.MNIST(
        '.',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,))
        ])
    )

# Test split: no target transform is attached here.
test_set = torchvision.datasets.MNIST(
    '.',
    train=False, 
    download=True,
    transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
        (0.1307,), (0.3081,))
    ])
)

# Split the training data into 50000 training and 10000 validation samples.
# NOTE(review): the validation subset is taken from the same dataset object,
# so in the noisy modes its labels presumably go through the noise transform
# as well - confirm this is intended.
train_set, val_set = torch.utils.data.random_split(train_set, [50000, 10000])

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size_train,
    shuffle=True
)

# The validation loader reuses the training batch size.
valid_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=batch_size_train,
    shuffle=True,
)

test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size_test,
    shuffle=False
)

## Balanced and Imbalanced data

## To numpy data

In [None]:
def data_loader_to_numpy(data_loader):
    """Drain a loader of (inputs, labels) batches into two NumPy arrays.

    Every batch element is converted with ``.numpy()`` and the batches are
    joined along the first axis.

    :param data_loader: iterable yielding ``(x, y)`` pairs of tensors
    :return: tuple ``(all_x, all_y)`` of concatenated arrays
    """
    converted = [
        (inputs.numpy(), targets.numpy()) for inputs, targets in data_loader
    ]
    all_x = np.concatenate([pair[0] for pair in converted], axis=0)
    all_y = np.concatenate([pair[1] for pair in converted], axis=0)
    return all_x, all_y
    
# Materialise each split as a pair of NumPy arrays (inputs, labels).
train_x, train_y = data_loader_to_numpy(train_loader)
test_x, test_y = data_loader_to_numpy(test_loader)
valid_x, valid_y = data_loader_to_numpy(valid_loader)

# Quick check of the train/test array shapes (validation shapes are not printed).
print(train_x.shape, test_x.shape, train_y.shape, test_y.shape)

In [None]:
def make_imbalanced(ds_x, ds_y, imbalanced_weights=imbalanced_weights):
    """Subsample each class to a configured fraction of its original size.

    For every class listed in ``imbalanced_weights`` a random subset of
    ``int(weight * class_size)`` samples is kept. Samples are drawn *without*
    replacement, so a weight of 1.0 keeps the class unchanged and no sample is
    duplicated. The surviving samples are shuffled before being returned.

    :param ds_x: samples, indexable along the first axis
    :param ds_y: integer class label of every sample
    :param imbalanced_weights: mapping ``class label -> fraction to keep``;
        samples whose label is not a key of this mapping are dropped, and
        fractions above 1.0 are capped at the class size
    :return: tuple ``(x, y)`` of NumPy arrays with the subsampled data
    """
    ds_x = np.asarray(ds_x)
    ds_y = np.asarray(ds_y)

    kept_chunks = []
    for label, weight in imbalanced_weights.items():
        class_idxs = np.flatnonzero(ds_y == label)
        n_keep = int(weight * len(class_idxs))
        # A permutation prefix is a sample without replacement and also
        # works for classes that have no samples at all.
        chosen = np.random.permutation(class_idxs)[:n_keep]
        kept_chunks.append(chosen)
        print(f"class {label}: size={len(chosen)}")

    if kept_chunks:
        kept = np.concatenate(kept_chunks)
    else:
        kept = np.empty(0, dtype=np.intp)

    # Shuffle so the classes are interleaved again.
    np.random.shuffle(kept)

    # Return arrays (not tuples) so callers can keep using array methods
    # such as ``reshape`` on the result.
    return ds_x[kept], ds_y[kept]

In [None]:
def make_balanced(ds_x, ds_y):
    """Downsample every class to the size of the rarest class.

    Each class present in ``ds_y`` is reduced, by sampling without
    replacement, to the number of samples of the least frequent class. The
    result is shuffled so the classes are interleaved.

    :param ds_x: samples, indexable along the first axis
    :param ds_y: class label of every sample
    :return: tuple ``(x, y)`` of NumPy arrays with equally sized classes
    """
    ds_x = np.asarray(ds_x)
    ds_y = np.asarray(ds_y)
    if ds_y.size == 0:
        # Nothing to balance.
        return ds_x, ds_y

    classes, counts = np.unique(ds_y, return_counts=True)
    min_count = counts.min()

    kept = np.concatenate([
        np.random.permutation(np.flatnonzero(ds_y == label))[:min_count]
        for label in classes
    ])
    np.random.shuffle(kept)
    return ds_x[kept], ds_y[kept]

In [None]:
# Re-sample the training split according to `partiality`; any other value
# (including None) leaves train_x / train_y unchanged.
if partiality == "imbalanced":
    train_x, train_y = make_imbalanced(train_x, train_y)
elif partiality == "balanced":
    train_x, train_y = make_balanced(train_x, train_y)

## Distribution plotter

In [None]:
def distribution_plotter(df):
    """Draw a bar chart with the number of occurrences of every label.

    :param df: iterable of class labels
    The x-axis label is taken from the module-level ``mode`` setting.
    """
    label_counts = Counter(df)
    colours = sns.color_palette("husl")
    plt.figure(figsize=(18,5))
    sns.barplot(
        x=list(label_counts),
        y=list(label_counts.values()),
        palette=colours,
    )
    plt.xlabel(f'{mode}')


# Show the class distribution of the (possibly re-sampled) training labels.
distribution_plotter(train_y)

# Models

## validation functions

### draw confusion matrix

In [None]:
from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt

def conf_mat(y_true, y_pred):
    """Plot the confusion matrix of a set of predictions as a heatmap.

    :param y_true: ground-truth labels
    :param y_pred: predicted labels
    """
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(8,8))
    # fmt="d" prints the integer counts in full; the default annotation
    # format abbreviates larger counts to scientific notation.
    ax = sns.heatmap(cm, annot=True, fmt="d")
    # confusion_matrix puts true labels on rows and predictions on columns.
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")



In [None]:
from sklearn.metrics import classification_report

def clf_metrics(y_true, y_pred, n_class=10):
    """Print precision, recall and F1 for each of the ``n_class`` classes.

    :param y_true: ground-truth labels
    :param y_pred: predicted labels
    :param n_class: number of classes; the report covers labels
        ``0 .. n_class - 1`` even when one of them never occurs
    """
    class_labels = list(range(n_class))
    class_names = [str(label) for label in class_labels]
    print(classification_report(
        y_true,
        y_pred,
        labels=class_labels,
        target_names=class_names,
    ))
    


## SVM

### preprocessing data

In [None]:
def preprocess(x, y):
    """Flatten every sample of ``x`` into one feature vector.

    The first axis is always kept as the sample axis, so a batch holding a
    single sample stays two-dimensional.

    :param x: samples with shape ``(n_samples, ...)``; sequences without a
        ``reshape`` method (lists, tuples) are converted to a NumPy array
    :param y: labels, returned unchanged
    :return: tuple ``(x_flat, y)`` with ``x_flat`` of shape ``(n_samples, -1)``
    """
    if not hasattr(x, "reshape"):
        x = np.asarray(x)
    # Reshape directly instead of squeezing first: squeeze() would also drop
    # the sample axis when it has length 1.
    return x.reshape((x.shape[0], -1)), y

# Flatten the samples of every split for the classical (non-convolutional) models.
train_x, train_y = preprocess(train_x, train_y)
test_x, test_y = preprocess(test_x, test_y)
valid_x, valid_y = preprocess(valid_x, valid_y)

# Notebook display of the resulting training matrix shape.
train_x.shape

### model definition

In [None]:
# Linear-kernel support vector classifier with a one-vs-rest shaped decision
# function; seeded with the notebook-wide random seed.
svm = SVC(
    kernel='linear',
    decision_function_shape='ovr',
    random_state=random_seed,
    verbose=True,
) 

# Fit on the flattened training data and predict the test split.
svm.fit(train_x, train_y)
y_pred = svm.predict(test_x)

In [None]:
# Notebook display of the shape of the fitted SVM coefficient array.
svm.coef_.shape

In [None]:
# Confusion matrix of the SVM predictions on the test split.
conf_mat(test_y, y_pred)

### model report

In [None]:
# Per-class classification report of the SVM predictions on the test split.
clf_metrics(test_y, y_pred)

# Logistic Regression

## Model definition

In [None]:
# Side length of one (square) input image. The cells below use
# `input_features * input_features` as the flattened input size, so this must
# be the side length (28 for MNIST), not the flattened length. `train_x` has
# already been flattened by `preprocess`, which is why the value is derived
# from the number of elements of a sample rather than from its first
# dimension (that would give 784 and a 784*784-wide model).
input_features = int(round(np.sqrt(train_x[0].size)))
# Number of target classes.
output_features = 10
num_epochs = 1
learning_rate = 0.001

In [None]:
class LogisticRegression(torch.nn.Module):
    """Multinomial logistic regression expressed as a single linear layer.

    ``forward`` returns the raw output of the linear layer; no activation is
    applied inside the module.

    Note: this class reuses the name of the scikit-learn estimator imported
    at the top of the file and therefore replaces it from here on.
    """

    def __init__(self, n_input_features, output_features):
        """
        :param n_input_features: size of one flattened input sample
        :param output_features: number of outputs (one per class)
        """
        super().__init__()
        self.linear = torch.nn.Linear(n_input_features, output_features)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)


# Instantiate the classifier with `input_features * input_features` inputs
# and `output_features` outputs.
model = LogisticRegression(input_features * input_features, output_features)


In [None]:
# Cross-entropy loss on the model outputs, optimised with plain SGD at the
# configured learning rate.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [None]:
# Debug aid: print the image and label tensor shapes of every batch yielded
# by the training loader (this iterates over the whole loader).
for batch in train_loader:
    images,labels = batch
    print(images.shape,labels.shape)

## Training the logistic regression model


In [None]:
# Train the logistic regression model for `num_epochs` passes over the
# training loader.
for epoch in range(num_epochs):
    for images, labels in train_loader:
        # Flatten every image into one row of
        # input_features * input_features values. Inputs do not need
        # gradients; only the model parameters are optimised.
        images = images.view(-1, input_features * input_features)

        # Clear gradients w.r.t. parameters
        optimizer.zero_grad()

        # Forward pass to get output/logits
        outputs = model(images)

        # Calculate Loss: softmax --> cross entropy loss
        loss = criterion(outputs, labels)

        # Getting gradients w.r.t. parameters
        loss.backward()

        # Updating parameters
        optimizer.step()

## Testing the logistic regression model

In [None]:
# Flat lists with one predicted / true class index per test sample.
predictions = []
real_classes = []

# Inference only: no gradients are needed, so no autograd graph is built.
with torch.no_grad():
    for images, labels in test_loader:
        # Flatten every image into one row, whatever the image size is.
        images = images.view(images.size(0), -1)

        # Forward pass only to get logits/output
        outputs = model(images)

        # The predicted class is the index of the largest output.
        _, predicted = torch.max(outputs, 1)

        predictions.extend(predicted.tolist())
        real_classes.extend(labels.tolist())

## Confusion matrix and predictions for Logistic regression

In [None]:
# Confusion matrix and per-class report of the logistic regression
# predictions on the test split.
conf_mat(real_classes,predictions)
clf_metrics(real_classes,predictions)

## Saving the model

In [None]:
import os

# Make sure the target directory exists; saving fails when it is missing.
os.makedirs('models', exist_ok=True)
# Persist only the learned parameters (the state dict), not the module itself.
torch.save(model.state_dict(), 'models/logistic.pkl')

# SV Implementation

## Imports

In [None]:
import numpy as np
import keras
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.models import Model
from keras.layers import Input, Conv2D, Dense, MaxPooling2D, Dropout, Flatten, Activation, BatchNormalization
import numpy as np
import keras.backend as K
from keras.callbacks import Callback, LearningRateScheduler

## Configs

In [None]:
# Module-level loss weights.
# NOTE(review): train() declares its own alpha/beta defaults and the call at
# the end of the file does not pass these values - confirm they are still used.
alpha = 1.0
beta = 1.0

## LoggerCallback

In [None]:
class LoggerCallback(Callback):
    """
    Log train/val loss and acc into file for later plots.

    After every epoch the accumulated loss and accuracy histories are written
    to ``.npy`` files inside the ``log/`` directory.
    """
    def __init__(self, model, X_train, y_train, y_train_clean, X_test, y_test, dataset,
                 model_name, noise_ratio, asym, epochs, alpha, beta):
        """
        :param model: the model being trained
        :param X_train: training images
        :param y_train: training labels; the second dimension is used as the
            number of classes
        :param y_train_clean: training labels before noise was added
        :param X_test: test images
        :param y_test: test labels
        :param dataset: dataset name, used in the log file names
        :param model_name: model name, used in the log file names
        :param noise_ratio: noise ratio, used in the log file names
        :param asym: whether the noise is asymmetric; selects the file naming
        :param epochs: number of training epochs
        :param alpha: loss weight, used in the non-asymmetric file names
        :param beta: loss weight
        """
        super(LoggerCallback, self).__init__()
        self.model = model
        self.X_train = X_train
        self.y_train = y_train
        self.y_train_clean = y_train_clean
        self.X_test = X_test
        self.y_test = y_test
        self.n_class = y_train.shape[1]
        self.dataset = dataset
        self.model_name = model_name
        self.noise_ratio = noise_ratio
        self.asym = asym
        self.epochs = epochs
        self.alpha = alpha
        self.beta = beta

        # Per-epoch histories.
        self.train_loss = []
        self.test_loss = []
        self.train_acc = []
        self.test_acc = []
        # Per-class placeholders; never filled in by this class.
        self.train_loss_class = [None]*self.n_class
        self.train_acc_class = [None]*self.n_class

        # the followings are used to estimate LID
        self.lid_k = 20
        self.lid_subset = 128
        self.lids = []

        # complexity - Critical Sample Ratio (csr)
        self.csr_subset = 500
        self.csr_batchsize = 100
        self.csrs = []

    def on_epoch_end(self, epoch, logs=None):
        """
        Append this epoch's metrics to the histories and save them to disk.

        :param epoch: index of the epoch that just finished (unused)
        :param logs: metric dictionary supplied by Keras
        """
        import os

        logs = logs or {}
        # The accuracy metric is reported as 'acc' by older Keras versions
        # and as 'accuracy' by newer ones; accept either spelling so the
        # history does not silently fill up with None.
        tr_acc = logs.get('acc', logs.get('accuracy'))
        tr_loss = logs.get('loss')
        val_loss = logs.get('val_loss')
        val_acc = logs.get('val_acc', logs.get('val_accuracy'))

        self.train_loss.append(tr_loss)
        self.test_loss.append(val_loss)
        self.train_acc.append(tr_acc)
        self.test_acc.append(val_acc)

        print('ALL acc:', self.test_acc)

        # np.save does not create missing directories.
        os.makedirs('log', exist_ok=True)

        # Row 0 holds the training history, row 1 the validation history.
        loss_history = np.stack((np.array(self.train_loss), np.array(self.test_loss)))
        acc_history = np.stack((np.array(self.train_acc), np.array(self.test_acc)))

        if self.asym:
            name_parts = (self.model_name, self.dataset, self.noise_ratio)
            np.save('log/asym_loss_%s_%s_%s.npy' % name_parts, loss_history)
            np.save('log/asym_acc_%s_%s_%s.npy' % name_parts, acc_history)
            np.save('log/asym_class_loss_%s_%s_%s.npy' % name_parts,
                    np.array(self.train_loss_class))
            np.save('log/asym_class_acc_%s_%s_%s.npy' % name_parts,
                    np.array(self.train_acc_class))
        else:
            name_parts = (self.model_name, self.dataset, self.noise_ratio, self.alpha)
            np.save('log/loss_%s_%s_%s_%s.npy' % name_parts, loss_history)
            np.save('log/acc_%s_%s_%s_%s.npy' % name_parts, acc_history)

        return

class SGDLearningRateTracker(Callback):
    """Print the optimizer's learning-rate state at the start of every epoch."""

    def __init__(self, model):
        """
        :param model: model whose ``optimizer`` is inspected
        """
        super(SGDLearningRateTracker, self).__init__()
        self.model = model

    def on_epoch_begin(self, epoch, logs=None):
        """
        Read ``lr``, ``decay`` and ``iterations`` from the optimizer and print
        them together with the decayed rate ``lr / (1 + decay * iterations)``.

        :param epoch: index of the epoch about to start (unused)
        :param logs: metric dictionary supplied by Keras (unused)
        """
        init_lr = float(K.get_value(self.model.optimizer.lr))
        decay = float(K.get_value(self.model.optimizer.decay))
        iterations = float(K.get_value(self.model.optimizer.iterations))
        lr = init_lr * (1. / (1. + decay * iterations))
        print('init lr: %.4f, current lr: %.4f, decay: %.4f, iterations: %s' % (init_lr, lr, decay, iterations))


In [None]:
def SVModel(input_tensor=None, input_shape = (28, 28, 1), num_classes=10):
    """
    Build a small convolutional classifier.

    Architecture: two conv(3x3) -> batch-norm -> ReLU -> max-pool(2x2) stages
    with 32 and 64 filters, a 128-unit dense layer with batch-norm and ReLU
    (the activation layer is named 'lid'), and a softmax output layer.

    :param input_tensor: optional tensor to use as the model input; when it
        is not already a Keras tensor it is wrapped in an ``Input`` layer
    :param input_shape: shape of one input image
    :param num_classes: number of output classes
    :return: the (uncompiled) Keras model
    """
    if input_tensor is None:
        img_input = Input(shape=input_shape)
    else:
        # The check must look at the tensor, not at the shape tuple.
        if not K.is_keras_tensor(input_tensor):
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

    x = Conv2D(32, (3, 3), padding='same', kernel_initializer="he_normal", name='conv1')(img_input)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x)

    x = Conv2D(64, (3, 3), padding='same', kernel_initializer="he_normal", name='conv2')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x)

    x = Flatten()(x)

    x = Dense(128, kernel_initializer="he_normal", name='fc1')(x)
    x = BatchNormalization()(x)
    x = Activation('relu', name='lid')(x)
    # x = Dropout(0.2)(x)

    x = Dense(num_classes, kernel_initializer="he_normal")(x)
    x = Activation(tf.nn.softmax)(x)

    model = Model(img_input, x)
    return model

#model_svm = SVModel()

## Utilities

In [None]:
import numpy as np
from keras.callbacks import LearningRateScheduler

# Set random seed
# This reseeds NumPy's global generator and thereby replaces the seed that
# the configuration cell set from `random_seed`.
np.random.seed(123)

def other_class(n_classes, current_class):
    """
    Pick, uniformly at random, one class index different from ``current_class``.

    :param n_classes: number of classes in the task
    :param current_class: the class index that must not be returned
    :return: one random class index in ``[0, n_classes - 1]`` that is
        ``!= current_class``
    :raises ValueError: if ``current_class`` is outside ``[0, n_classes - 1]``
    """
    if current_class < 0 or current_class >= n_classes:
        error_str = "current_class must be within the range [0, n_classes - 1]"
        raise ValueError(error_str)

    candidates = list(range(n_classes))
    candidates.remove(current_class)
    # Drawn from NumPy's global generator, so it follows np.random.seed().
    return np.random.choice(candidates)

def get_lr_scheduler(dataset):
    """
    Build the step-wise learning-rate schedule used for ``dataset``.

    Each schedule has three constant phases separated by two epoch
    boundaries:

    * 'mnist':     0.1 up to epoch 10, 0.01 up to epoch 30, then 0.001
    * 'cifar-10':  0.01 up to epoch 40, 0.001 up to epoch 80, then 0.0001
    * 'cifar-100': 0.1 up to epoch 80, 0.01 up to epoch 120, then 0.001

    :param dataset: dataset name
    :return: a ``LearningRateScheduler`` callback, or ``None`` when no
        schedule is defined for ``dataset``
    """
    if dataset == 'mnist':
        boundaries, rates = (10, 30), (0.1, 0.01, 0.001)
    elif dataset == 'cifar-10':
        boundaries, rates = (40, 80), (0.01, 0.001, 0.0001)
    elif dataset == 'cifar-100':
        boundaries, rates = (80, 120), (0.1, 0.01, 0.001)
    else:
        return None

    first_drop, second_drop = boundaries
    initial_rate, middle_rate, final_rate = rates

    def scheduler(epoch):
        # Checked from the latest phase backwards.
        if epoch > second_drop:
            return final_rate
        if epoch > first_drop:
            return middle_rate
        return initial_rate

    return LearningRateScheduler(scheduler)

## Symmetric cross entropy

In [None]:
def symmetric_cross_entropy(alpha, beta):
    """
    Build a loss that is a weighted sum of two cross-entropy terms.

    The first term is the cross entropy of the predictions against the
    targets; the second swaps the roles of predictions and targets.

    :param alpha: weight of the first term
    :param beta: weight of the second (swapped) term
    :return: a function ``loss(y_true, y_pred)`` returning a scalar tensor
    """
    def loss(y_true, y_pred):
        # Clip whatever goes inside the logarithm so log(0) never occurs:
        # predictions for the first term, targets for the swapped term.
        y_pred_clipped = tf.clip_by_value(y_pred, 1e-7, 1.0)
        y_true_clipped = tf.clip_by_value(y_true, 1e-4, 1.0)

        # tf.math.log replaces tf.log, which TensorFlow 2 no longer provides.
        forward_term = tf.reduce_mean(
            -tf.reduce_sum(y_true * tf.math.log(y_pred_clipped), axis=-1))
        reverse_term = tf.reduce_mean(
            -tf.reduce_sum(y_pred * tf.math.log(y_true_clipped), axis=-1))

        return alpha*forward_term + beta*reverse_term
    return loss


## Create noisy symmetric and asymmetric data

In [None]:
import os
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from numpy.testing import assert_array_almost_equal

# Set random seed
# This reseeds NumPy's global generator again, so the noise generated below
# starts from this seed.
np.random.seed(123)

# Number of label classes for every known dataset name.
NUM_CLASSES = {'mnist': 10, 'svhn': 10, 'cifar-10': 10, 'cifar-100': 100}

def build_for_cifar100(size, noise):
    """Build a transition matrix that flips between two random classes.

    Two distinct classes are drawn at random; each is relabelled as the other
    with probability ``noise`` and kept with probability ``1 - noise``. All
    remaining classes keep their label.

    :param size: number of classes (the matrix is ``size x size``)
    :param noise: flip probability, must lie in ``[0, 1]``
    :return: the ``size x size`` transition matrix
    """
    assert 0. <= noise <= 1.

    transition = np.eye(size)
    first, second = np.random.choice(range(size), size=2, replace=False)
    keep = 1.0 - noise
    transition[first, second] = transition[second, first] = noise
    transition[first, first] = transition[second, second] = keep

    # Every row has to remain a probability distribution.
    assert_array_almost_equal(transition.sum(axis=1), 1, 1)
    return transition

def multiclass_noisify(y, P, random_state=0):
    """Flip classes according to the transition probability matrix ``P``.

    For every label ``c`` in ``y`` a new label is drawn from row ``P[c]``.

    :param y: 1-D integer array of labels, each between 0 and the number of
        classes - 1
    :param P: square, row-stochastic matrix; ``P[i, j]`` is the probability
        of turning label ``i`` into label ``j``
    :param random_state: seed of the private generator used for the draws
    :return: a new array with the (possibly flipped) labels; ``y`` itself is
        not modified
    """

    assert P.shape[0] == P.shape[1]
    assert np.max(y) < P.shape[0]

    # row stochastic matrix
    assert_array_almost_equal(P.sum(axis=1), np.ones(P.shape[1]))
    assert (P >= 0.0).all()

    new_y = y.copy()
    flipper = np.random.RandomState(random_state)

    for idx, current in enumerate(y):
        # draw a vector with only an 1
        one_hot = flipper.multinomial(1, P[current, :], 1)[0]
        # argmax yields the position of that 1 as a scalar; storing a
        # one-element array into a scalar slot is deprecated by NumPy.
        new_y[idx] = np.argmax(one_hot)

    return new_y

def get_data(dataset='mnist', noise_ratio=0, asym=False, random_shuffle=False):
    """
    Get training images with specified ratio of syn/ayn label noise

    Noisy labels are cached in ``data/``: when the label file for this
    configuration exists it is loaded, otherwise the labels are generated and
    saved there.

    :param dataset: dataset name; it selects the number of classes, the noise
        pattern and the cache file names
    :param noise_ratio: percentage (0-100) of training labels to corrupt;
        0 disables the noise
    :param asym: use asymmetric (class-dependent) noise instead of symmetric
        noise
    :param random_shuffle: shuffle the training samples before returning
    :return: ``(X_train, y_train, y_train_clean, X_test, y_test)`` with
        one-hot encoded labels, or ``None`` when asymmetric noise is
        requested for a dataset that does not support it
    """
    # NOTE(review): the MNIST images are loaded whatever `dataset` says;
    # only the noise handling below depends on the dataset name.
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    X_train = X_train.reshape(-1, 28, 28, 1)
    X_test = X_test.reshape(-1, 28, 28, 1)
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')

    y_train_clean = np.copy(y_train)
    # generate random noisy labels
    if noise_ratio > 0:
        if asym:
            data_file = "data/asym_%s_train_labels_%s.npy" % (dataset, noise_ratio)
            if dataset == 'cifar-100':
                P_file = "data/asym_%s_P_value_%s.npy" % (dataset, noise_ratio)
        else:
            data_file = "data/%s_train_labels_%s.npy" % (dataset, noise_ratio)
        if os.path.isfile(data_file):
            # Reuse the cached noisy labels.
            y_train = np.load(data_file)
            if dataset == 'cifar-100' and asym:
                P = np.load(P_file)
        else:
            # np.save does not create missing directories, so make sure the
            # cache directory exists before anything is written to it.
            os.makedirs(os.path.dirname(data_file), exist_ok=True)
            if asym:
                if dataset == 'mnist':
                    # 1 < - 7, 2 -> 7, 3 -> 8, 5 <-> 6
                    source_class = [7, 2, 3, 5, 6]
                    target_class = [1, 7, 8, 6, 5]
                elif dataset == 'cifar-10':
                    # automobile < - truck, bird -> airplane, cat <-> dog, deer -> horse
                    source_class = [9, 2, 3, 5, 4]
                    target_class = [1, 0, 5, 3, 7]

                elif dataset == 'cifar-100':
                    P = np.eye(NUM_CLASSES[dataset])
                    n = noise_ratio/100.0
                    nb_superclasses = 20
                    nb_subclasses = 5

                    if n > 0.0:
                        # One flip block per group of nb_subclasses
                        # consecutive classes.
                        for i in np.arange(nb_superclasses):
                            init, end = i * nb_subclasses, (i+1) * nb_subclasses
                            P[init:end, init:end] = build_for_cifar100(nb_subclasses, n)

                        y_train_noisy = multiclass_noisify(y_train, P=P,
                                                           random_state=0)
                        actual_noise = (y_train_noisy != y_train).mean()
                        assert actual_noise > 0.0
                        y_train = y_train_noisy
                    np.save(P_file, P)

                else:
                    print('Asymmetric noise is not supported now for dataset: %s' % dataset)
                    return

                if dataset == 'mnist' or dataset == 'cifar-10':
                    # Relabel noise_ratio percent of every source class as
                    # its target class.
                    for s, t in zip(source_class, target_class):
                        cls_idx = np.where(y_train_clean == s)[0]
                        n_noisy = int(noise_ratio * cls_idx.shape[0] / 100)
                        noisy_sample_index = np.random.choice(cls_idx, n_noisy, replace=False)
                        y_train[noisy_sample_index] = t

            else:
                # Symmetric noise: the same number of samples is picked from
                # every class and each gets a random different label.
                n_samples = y_train.shape[0]
                n_noisy = int(noise_ratio * n_samples / 100)
                class_index = [np.where(y_train_clean == i)[0] for i in range(NUM_CLASSES[dataset])]
                class_noisy = int(n_noisy / NUM_CLASSES[dataset])

                noisy_idx = []
                for d in range(NUM_CLASSES[dataset]):
                    noisy_class_index = np.random.choice(class_index[d], class_noisy, replace=False)
                    noisy_idx.extend(noisy_class_index)

                for i in noisy_idx:
                    y_train[i] = other_class(n_classes=NUM_CLASSES[dataset], current_class=y_train[i])
            np.save(data_file, y_train)

        # print statistics
        print("Print noisy label generation statistics:")
        for i in range(NUM_CLASSES[dataset]):
            n_noisy = np.sum(y_train == i)
            print("Noisy class %s, has %s samples." % (i, n_noisy))

    if random_shuffle:
        # random shuffle
        idx_perm = np.random.permutation(X_train.shape[0])
        X_train, y_train, y_train_clean = X_train[idx_perm], y_train[idx_perm], y_train_clean[idx_perm]

    # one-hot-encode the labels
    y_train_clean = np_utils.to_categorical(y_train_clean, NUM_CLASSES[dataset])
    y_train = np_utils.to_categorical(y_train, NUM_CLASSES[dataset])
    y_test = np_utils.to_categorical(y_test, NUM_CLASSES[dataset])

    print("X_train:", X_train.shape)
    print("y_train:", y_train.shape)
    print("X_test:", X_test.shape)
    print("y_test", y_test.shape)

    return X_train, y_train, y_train_clean, X_test, y_test

In [None]:
# Load the data with 40% label noise (asym keeps its default, False).
X_train, Y_train, y_train_clean, X_test, y_test = get_data(dataset='mnist', noise_ratio=40)
# get_data returns one-hot labels; turn the noisy training labels back into class indices.
Y_train = np.argmax(Y_train, axis=1)
# Reload the untouched MNIST labels to compare them with the noisy ones.
(_, Y_clean_train), (_, Y_clean_test) = mnist.load_data()
# Indices of the samples whose noisy label agrees / disagrees with the clean label.
clean_selected = np.argwhere(Y_train == Y_clean_train).reshape((-1,))
noisy_selected = np.argwhere(Y_train != Y_clean_train).reshape((-1,))
print("#correct labels: %s, #incorrect labels: %s" % (len(clean_selected), len(noisy_selected)))

## Model definition

In [None]:
def train(dataset='mnist', model_name='sl', batch_size=128, epochs=50, noise_ratio=0, asym=False, alpha = 1.0, beta = 1.0):
    """
    Train one model with the symmetric cross-entropy loss.

    The batches come from an ``ImageDataGenerator`` created with its default
    arguments, i.e. no augmentation options are configured here. Weights are
    checkpointed into ``model/`` after every epoch.

    :param dataset: dataset name, forwarded to ``get_data`` and used to pick
        the learning-rate schedule
    :param model_name: name used in the checkpoint and log file names
    :param batch_size: mini-batch size
    :param epochs: number of training epochs
    :param noise_ratio: percentage of training labels to corrupt
    :param asym: use asymmetric instead of symmetric label noise
    :param alpha: weight of the first term of the loss
    :param beta: weight of the second (swapped) term of the loss
    :return: None
    """
    import os

    print('Dataset: %s, model: %s, batch: %s, epochs: %s, noise ratio: %s%%, asymmetric: %s, alpha: %s, beta: %s' %
          (dataset, model_name, batch_size, epochs, noise_ratio, asym, alpha, beta))

    # load data
    X_train, y_train, y_train_clean, X_test, y_test = get_data(dataset, noise_ratio, asym=asym, random_shuffle=False)
    n_images = X_train.shape[0]
    image_shape = X_train.shape[1:]
    num_classes = y_train.shape[1]
    print("n_images", n_images, "num_classes", num_classes, "image_shape:", image_shape)

    # load model
    model = SVModel(input_tensor=None, input_shape=image_shape, num_classes=num_classes)
    # model.summary()

    optimizer = tf.keras.optimizers.SGD(lr=0.1, decay=1e-4, momentum=0.9)

    # create loss
    loss = symmetric_cross_entropy(alpha,beta)

    # model
    model.compile(
        loss=loss,
        optimizer=optimizer,
        metrics=['accuracy']
    )

    if asym:
        model_save_file = "model/asym_%s_%s_%s.{epoch:02d}.hdf5" % (model_name, dataset, noise_ratio)
    else:
        model_save_file = "model/%s_%s_%s.{epoch:02d}.hdf5" % (model_name, dataset, noise_ratio)

    # The checkpoint directory has to exist before the first save.
    os.makedirs(os.path.dirname(model_save_file), exist_ok=True)

    ## do real-time updates using callbacks
    callbacks = []

    # Save the weights after every epoch (the callback's default frequency);
    # the same checkpoint settings apply to every model_name.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(model_save_file,
                                  monitor='val_loss',
                                  verbose=0,
                                  save_best_only=False,
                                  save_weights_only=True)
    callbacks.append(cp_callback)

    # learning rate scheduler if use sgd; get_lr_scheduler returns None for
    # datasets without a schedule, and None is not a valid callback.
    lr_scheduler = get_lr_scheduler(dataset)
    if lr_scheduler is not None:
        callbacks.append(lr_scheduler)

    callbacks.append(SGDLearningRateTracker(model))

    # acc, loss, lid
    log_callback = LoggerCallback(model, X_train, y_train, y_train_clean, X_test, y_test, dataset, model_name, noise_ratio, asym, epochs, alpha, beta)
    callbacks.append(log_callback)

    # batch generator (created with default arguments)
    datagen = tf.keras.preprocessing.image.ImageDataGenerator()
    datagen.fit(X_train)

    # steps_per_epoch must be a whole number; round up so the final,
    # smaller batch is part of every epoch.
    steps_per_epoch = int(np.ceil(len(X_train) / batch_size))

    # train model
    model.fit_generator(datagen.flow(X_train, y_train, batch_size=batch_size),
                        steps_per_epoch=steps_per_epoch, epochs=epochs,
                        validation_data=(X_test, y_test),
                        verbose=1,
                        callbacks=callbacks)

# Run training with the default arguments of train() (no label noise).
train()