# Dataset

## Description

## Pre-processing

### Resizing

### Random cropping

### Corrupting

In [None]:
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as T
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler

In [None]:
from logger import Logger

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# Prepare dataset
batch_size = 128
train_mean = [107.59252, 103.2752, 106.84143]
train_std = [63.439133, 59.521027, 63.240288]
# Preprocessing
# NOTE(review): `transform` is built here but is not applied to any of the
# TensorDatasets created in this cell - confirm whether the saved tensors are
# already normalized.
transform = T.Compose([
                T.Normalize(train_mean, train_std)
            ])

# Un-rotated splits. The baseline experiment cells reference train_loader,
# val_loader and test_loader, so these must be defined for a fresh
# "Restart & Run All" to succeed.
train_set = TensorDataset(torch.load('train_x.pt'), torch.load('train_y.pt'))
val_set = TensorDataset(torch.load('val_x.pt'), torch.load('val_y.pt'))
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))

# Rotated splits, consumed by train_both / evaluate_both, which index the
# labels as y[:, 0] (main task) and y[:, 1] (auxiliary task).
train_rotate_set = TensorDataset(torch.load('train_x_rotate.pt'), torch.load('train_y_rotate.pt'))
val_rotate_set = TensorDataset(torch.load('val_x_rotate.pt'), torch.load('val_y_rotate.pt'))
test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))

# Only the training loaders shuffle. The evaluation helpers average per-batch
# losses, so a fixed batch order keeps validation/test numbers repeatable.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

train_rotate_loader = DataLoader(train_rotate_set, batch_size=batch_size, shuffle=True)
val_rotate_loader = DataLoader(val_rotate_set, batch_size=batch_size, shuffle=False)
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=False)

# Training Procedures

## Regular Training

In [None]:
# Set up training pipelines
def train_main(model, optimizer, loader_train, loader_val, epochs=1, model_path=None, early_stop_patience = 0):
    """
    Train the main branch
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - loader_train: Iterable yielding (x, y) training batches
    - loader_val: Validation batches, handed to evaluate_main after every epoch
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    - model_path: (Optional) File that receives the model's state_dict each time
      the validation loss reaches a new best. The Logger is pickled after every
      epoch to the same path with its extension replaced by '.pkl'. When None,
      nothing is written to disk.
    - early_stop_patience: (Optional) Forwarded to Logger.check_early_stop after
      every epoch; training stops as soon as that call returns a truthy value.

    Returns: Logger object with loss and accuracy data
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    logger = Logger()
    # Swap only the final extension so dots elsewhere in the path are preserved.
    log_path = None if model_path is None else os.path.splitext(model_path)[0] + '.pkl'
    best_loss = float('inf')  # lowest validation loss observed so far
    for e in range(epochs):
        # evaluate_main() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        num_correct = 0
        num_samples = 0
        total_loss = 0.0
        count = 0
        for t, (x, y) in enumerate(loader_train):
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)
            total_loss += loss.item()

            _, preds = scores.max(1)
            # .item() keeps the running count a plain Python int.
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"\r[Epoch {e}, Batch {t}] train_loss: {loss.item()}", end='')
            count += 1

        # Conclude Epoch
        train_loss = total_loss / count  # mean of the per-batch losses
        train_acc = float(num_correct) / num_samples
        val_loss, val_acc = evaluate_main(model, loader_val)
        logger.log(train_loss, train_acc, val_loss, val_acc)

        # Persist the metrics every epoch so an interrupted run keeps its history.
        if log_path is not None:
            with open(log_path, 'wb') as output_file:
                pickle.dump(logger, output_file)

        # Early Stopping
        if logger.check_early_stop(early_stop_patience):
            print("[Early Stopped]")
            break

        if val_loss < best_loss:
            message = (f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, "
                       f"val_loss improved from {best_loss:.4f} to {val_loss:.4f}.")
            if model_path is not None:
                torch.save(model.state_dict(), model_path)
                message += f" Saving model to {model_path}."
            print(message)
            # Only a new best moves the reference value; otherwise a later,
            # worse epoch could overwrite the best checkpoint.
            best_loss = val_loss
        else:
            print(f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, "
                  f"val_loss did not improve from {best_loss:.4f}")
    return logger

        

In [None]:
def train_both(model, optimizer, loader_train, loader_val, epochs=1, model_path=None, early_stop_patience = 0):
    """
    Train the main and auxiliary branch
    Inputs:
    - model: A PyTorch Module giving the model to train. Its output is indexed
      as scores[0] (main branch) and scores[1] (auxiliary branch).
    - optimizer: An Optimizer object we will use to train the model
    - loader_train: Iterable yielding (x, y) training batches, where y[:, 0]
      is the main-task label and y[:, 1] the auxiliary-task label
    - loader_val: Validation batches, handed to evaluate_both after every epoch
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    - model_path: (Optional) File that receives the model's state_dict each time
      the validation loss reaches a new best. The Logger is pickled after every
      epoch to the same path with its extension replaced by '.pkl'. When None,
      nothing is written to disk.
    - early_stop_patience: (Optional) Forwarded to Logger.check_early_stop after
      every epoch; training stops as soon as that call returns a truthy value.

    Returns: Logger object with loss and accuracy data
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    logger = Logger()
    # Swap only the final extension so dots elsewhere in the path are preserved.
    log_path = None if model_path is None else os.path.splitext(model_path)[0] + '.pkl'
    best_loss = float('inf')  # lowest validation loss observed so far
    for e in range(epochs):
        # evaluate_both() leaves the model in eval mode, so re-enable training
        # mode at the start of every epoch.
        model.train()
        num_correct = 0
        num_samples = 0
        running_loss = 0.0
        count = 0
        for t, (x, y) in enumerate(loader_train):
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss_main = F.cross_entropy(scores[0], y[:, 0])
            loss_auxillary = F.cross_entropy(scores[1], y[:, 1])
            # Both branches are optimized jointly with equal weight.
            loss = loss_main + loss_auxillary
            # Only the main-branch loss is logged, which matches the loss that
            # evaluate_both reports for validation.
            running_loss += loss_main.item()

            _, preds = scores[0].max(1)
            # .item() keeps the running count a plain Python int.
            num_correct += (preds == y[:, 0]).sum().item()
            num_samples += preds.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # The progress line shows the combined (main + auxiliary) loss.
            print(f"\r[Epoch {e}, Batch {t}] train_loss: {loss.item()}", end='')
            count += 1

        # Conclude Epoch
        train_loss = running_loss / count  # mean of the per-batch main losses
        train_acc = float(num_correct) / num_samples
        val_loss, val_acc = evaluate_both(model, loader_val)
        logger.log(train_loss, train_acc, val_loss, val_acc)

        # Persist the metrics every epoch so an interrupted run keeps its history.
        if log_path is not None:
            with open(log_path, 'wb') as output_file:
                pickle.dump(logger, output_file)

        # Early Stopping
        if logger.check_early_stop(early_stop_patience):
            print("[Early Stopped]")
            break

        if val_loss < best_loss:
            message = (f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, "
                       f"val_loss improved from {best_loss:.4f} to {val_loss:.4f}.")
            if model_path is not None:
                torch.save(model.state_dict(), model_path)
                message += f" Saving model to {model_path}."
            print(message)
            # Only a new best moves the reference value; otherwise a later,
            # worse epoch could overwrite the best checkpoint.
            best_loss = val_loss
        else:
            print(f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, "
                  f"val_loss did not improve from {best_loss:.4f}")
    return logger


In [None]:
def evaluate_main(model, loader):
    """
    Score a single-output model on every batch of a loader.

    Inputs:
    - model: A PyTorch Module whose forward pass returns class scores
    - loader: Iterable yielding (x, y) batches

    Outputs: (mean of the per-batch cross-entropy losses,
              fraction of samples whose top-scoring class equals the label)
    """
    model.eval()  # set model to evaluation mode
    batch_losses = []
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            targets = targets.to(device=device, dtype=torch.long)
            logits = model(inputs)
            batch_losses.append(F.cross_entropy(logits, targets).item())
            predicted = logits.max(1)[1]  # index of the highest score per row
            hits += int((predicted == targets).sum())
            seen += predicted.size(0)
    accuracy = float(hits) / seen
    return sum(batch_losses) / len(batch_losses), accuracy

In [None]:
def evaluate_both(model, loader):
    """
    Evaluate main branch accuracy in model with two predictions

    Inputs:
    - model: A PyTorch Module whose output is indexed as scores[0] for the
      main branch; any further outputs are ignored here.
    - loader: Iterable yielding (x, y) batches, where y[:, 0] holds the
      main-task label.

    Outputs: (mean of the per-batch main-branch cross-entropy losses,
              main-branch accuracy over all samples)
    """
    num_correct = 0
    num_samples = 0
    ave_loss = 0.0
    count = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            loss_main = F.cross_entropy(scores[0], y[:, 0])
            _, preds = scores[0].max(1)
            # .item() keeps the running count a plain Python int.
            num_correct += (preds == y[:, 0]).sum().item()
            num_samples += preds.size(0)
            ave_loss += loss_main.item()
            count += 1
        acc = float(num_correct) / num_samples
        return ave_loss / count, acc


## Online Training

In [None]:
def ttt_online(model, loader, loader_spinned, optimizer):
    """
    Online test-time training on the image-spinning (auxiliary) task.

    Inputs:
    - model: A PyTorch Module whose output is indexed as [1] for the auxiliary branch
    - loader: Batches forwarded to evaluate_main once the updates are done
    - loader_spinned: Iterable yielding (x, y) batches, where y[:, 1] holds the
      auxiliary-task label
    - optimizer: An Optimizer object stepped once per batch of loader_spinned

    Outputs: loss and accuracy, exactly as returned by evaluate_main(model, loader)

    The model's train/eval mode is not changed before the updates; whatever
    mode the caller left it in is used.
    NOTE(review): this function indexes model(x)[1], while evaluate_main passes
    the model output straight to F.cross_entropy - confirm that evaluate_main
    (rather than evaluate_both) is the intended evaluator for this model.
    """
    for inputs, labels in loader_spinned:
        inputs = inputs.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
        labels = labels.to(device=device, dtype=torch.long)
        outputs = model(inputs)
        aux_scores = outputs[1]
        aux_targets = labels[:, 1]
        aux_loss = F.cross_entropy(aux_scores, aux_targets)
        # One gradient step on the auxiliary loss only.
        optimizer.zero_grad()
        aux_loss.backward()
        optimizer.step()
    result = evaluate_main(model, loader)
    return result
    

# Experiment 1: Baseline ResNet18

In [None]:
# Experiment 1: Train a baseline ResNet18: no branch
lr = 1e-3  # learning rate passed to Adam
wd = 1e-4  # weight decay passed to Adam
from models import BaselineResNet
# NOTE(review): 58 is presumably the number of output classes - confirm against models.py
model_base_1 = BaselineResNet(58)
optimizer = optim.Adam(model_base_1.parameters(), lr=lr, weight_decay=wd)
# Requires train_loader and val_loader (the un-rotated splits) to exist in the kernel.
# train_main pickles the Logger next to model_path and returns it.
train_main(model_base_1, optimizer, train_loader, val_loader, epochs=25, model_path='model_base_1.pth', early_stop_patience=5)

## Evaluation
### Uncorrupted Images

In [None]:
from models import BaselineResNet
# NOTE(review): the training cell above saves 'model_base_1.pth' while this cell
# loads 'model_base_4.pth' - confirm which checkpoint is meant to be evaluated.
model_path = 'model_base_4.pth'
model_base_4 = BaselineResNet(58)
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when only the CPU is available.
params = torch.load(model_path, map_location=device)
model_base_4.load_state_dict(params)
model_base_4 = model_base_4.to(device=device)
# Requires test_loader (the un-rotated test split) to exist in the kernel.
evaluate_main(model_base_4, test_loader)

### Corrupted Images

# Experiment 2: ResNet18 with Auxiliary Branch (No Online Training)

In [None]:
# Experiment 2: Train a ResNet18 with auxiliary branch
lr = 1e-3  # learning rate passed to Adam
wd = 1e-5  # weight decay passed to Adam
from models import ResNetTwoBranch
model = ResNetTwoBranch()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
# Validate on the held-out rotated split. Passing the training loader as the
# validation loader would drive early stopping and checkpoint selection from
# data the model has already been fitted to.
train_both(model, optimizer, train_rotate_loader, val_rotate_loader, epochs=50, model_path='exp_2_model.pth', early_stop_patience=5)

## Evaluate

In [None]:
from models import ResNetTwoBranch
model_path = 'exp_2_model.pth'
model = ResNetTwoBranch()
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU machine still loads when only the CPU is available.
params = torch.load(model_path, map_location=device)
model.load_state_dict(params)
model = model.to(device=device)
# Reports the main-branch loss and accuracy on the rotated test split.
evaluate_both(model, test_rotate_loader)

In [None]:
# Experiment 3: ResNet18 with Auxiliary Branch (Online-Trained)

In [None]:
# Experiment 3: Do online training on the auxiliary branch, with pre-trained shared and main branch weights from experiment 1.