In [None]:
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as T
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler

In [None]:
from logger import Logger

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Set up training pipelines
def train_main(model, optimizer, loader_train, loader_val, epochs=1, model_path=None, early_stop_patience=0):
    """
    Train a single-output model (main branch only) with per-epoch validation.

    Inputs:
    - model: A PyTorch Module whose forward returns class scores.
    - optimizer: An Optimizer over the model's parameters.
    - loader_train: iterable of (x, y) training batches; y holds integer class labels.
    - loader_val: iterable of (x, y) validation batches, scored with evaluate_main.
    - epochs: (Optional) number of epochs to train for.
    - model_path: (Optional) where to save the state_dict of the best epoch
      (lowest validation loss seen so far). The Logger is pickled next to it
      with the extension replaced by ``.pkl``. When None, nothing is written.
    - early_stop_patience: forwarded to Logger.check_early_stop.

    Returns: (logger, model) -- the Logger holding per-epoch
    (train_loss, train_acc, val_loss, val_acc) and the model as it stands after
    the last epoch run (not necessarily the best checkpoint).

    Raises: ValueError if loader_train yields no batches.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    logger = Logger()
    # splitext only strips the final extension, so dotted directories/names survive.
    log_path = None if model_path is None else os.path.splitext(model_path)[0] + '.pkl'
    best_loss = float('inf')
    for e in range(epochs):
        # evaluate_main switches the model to eval mode, so re-enable training each epoch.
        model.train()
        num_correct = 0
        num_samples = 0
        total_loss = 0.0
        count = 0
        for t, (x, y) in enumerate(loader_train):
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)
            total_loss += loss.item()

            _, preds = scores.max(1)
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"\r[Epoch {e + 1}, Batch {t}] train_loss: {loss.item()}", end='')
            count += 1

        if count == 0:
            raise ValueError("train_main: loader_train yielded no batches")

        # Conclude epoch
        train_loss = total_loss / count
        train_acc = num_correct / num_samples
        val_loss, val_acc = evaluate_main(model, loader_val)
        logger.log(train_loss, train_acc, val_loss, val_acc)

        if log_path is not None:
            with open(log_path, 'wb') as output_file:
                pickle.dump(logger, output_file)

        # Checkpoint before the early-stop check so an improving epoch is never discarded.
        # best_loss only moves on improvement, so the saved weights are the best seen so far.
        if val_loss < best_loss:
            print(f"\r[Epoch {e + 1}] train_acc: {train_acc}, val_acc:{val_acc}, "
                  f"val_loss improved from {best_loss:.4f} to {val_loss:.4f}. Saving model to {model_path}.")
            if model_path is not None:
                torch.save(model.state_dict(), model_path)
            best_loss = val_loss
        else:
            print(f"\r[Epoch {e + 1}] train_acc: {train_acc}, val_acc:{val_acc}, "
                  f"val_loss did not improve from {best_loss:.4f}")

        # Early stopping
        if logger.check_early_stop(early_stop_patience):
            print("[Early Stopped]")
            break
    return logger, model

In [None]:
def train_both(model, optimizer, loader_train, loader_val, epochs=1, model_path=None, early_stop_patience=0):
    """
    Jointly train the main and auxiliary branches of a two-output model.

    Each step minimizes cross_entropy(scores[0], y[:, 0]) + cross_entropy(scores[1], y[:, 1]).
    The logged training loss/accuracy cover the main branch only, matching what
    evaluate_both reports for validation.

    Inputs:
    - model: A PyTorch Module whose forward returns (main_scores, auxiliary_scores).
    - optimizer: An Optimizer over the model's parameters.
    - loader_train: iterable of (x, y) batches; column 0 of y is the main label and
      column 1 the auxiliary label.
    - loader_val: iterable of (x, y) validation batches, scored with evaluate_both.
    - epochs: (Optional) number of epochs to train for.
    - model_path: (Optional) where to save the state_dict of the best epoch
      (lowest validation loss seen so far). The Logger is pickled next to it
      with the extension replaced by ``.pkl``. When None, nothing is written.
    - early_stop_patience: forwarded to Logger.check_early_stop.

    Returns: (logger, model) -- the Logger holding per-epoch
    (train_loss, train_acc, val_loss, val_acc) and the model as it stands after
    the last epoch run (not necessarily the best checkpoint).

    Raises: ValueError if loader_train yields no batches.
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    logger = Logger()
    # splitext only strips the final extension, so dotted directories/names survive.
    log_path = None if model_path is None else os.path.splitext(model_path)[0] + '.pkl'
    best_loss = float('inf')
    for e in range(epochs):
        # evaluate_both switches the model to eval mode, so re-enable training each epoch.
        model.train()
        num_correct = 0
        num_samples = 0
        running_loss = 0.0
        count = 0
        for t, (x, y) in enumerate(loader_train):
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss_main = F.cross_entropy(scores[0], y[:, 0])
            loss_auxiliary = F.cross_entropy(scores[1], y[:, 1])
            loss = loss_main + loss_auxiliary
            running_loss += loss_main.item()

            _, preds = scores[0].max(1)
            num_correct += (preds == y[:, 0]).sum().item()
            num_samples += preds.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"\r[Epoch {e + 1}, Batch {t}] train_loss: {loss.item()}", end='')
            count += 1

        if count == 0:
            raise ValueError("train_both: loader_train yielded no batches")

        # Conclude epoch
        train_loss = running_loss / count
        train_acc = num_correct / num_samples
        val_loss, val_acc = evaluate_both(model, loader_val)
        logger.log(train_loss, train_acc, val_loss, val_acc)

        if log_path is not None:
            with open(log_path, 'wb') as output_file:
                pickle.dump(logger, output_file)

        # Checkpoint before the early-stop check so an improving epoch is never discarded.
        # best_loss only moves on improvement, so the saved weights are the best seen so far.
        if val_loss < best_loss:
            print(f"\r[Epoch {e + 1}] train_acc: {train_acc}, val_acc:{val_acc}, "
                  f"val_loss improved from {best_loss:.4f} to {val_loss:.4f}. Saving model to {model_path}.")
            if model_path is not None:
                torch.save(model.state_dict(), model_path)
            best_loss = val_loss
        else:
            print(f"\r[Epoch {e + 1}] train_acc: {train_acc}, val_acc:{val_acc}, "
                  f"val_loss did not improve from {best_loss:.4f}")

        # Early stopping
        if logger.check_early_stop(early_stop_patience):
            print("[Early Stopped]")
            break
    return logger, model

In [None]:
def evaluate_main(model, loader):
    """
    Evaluate the main branch of a single-output model.

    Inputs:
    - model: module whose forward returns class scores (classes along dim 1).
    - loader: iterable of (x, y) batches; y holds integer class labels.
    Outputs: (loss, accuracy) -- the per-sample mean cross-entropy and the
    fraction of correct argmax predictions over every sample in the loader.
    Raises: ValueError if the loader yields no samples.
    """
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            # Sum per batch (not mean) so a smaller final batch gets its true weight.
            total_loss += F.cross_entropy(scores, y, reduction='sum').item()
            _, preds = scores.max(1)
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
    if num_samples == 0:
        raise ValueError("evaluate_main: loader yielded no samples")
    return total_loss / num_samples, num_correct / num_samples

def evaluate_both(model, loader):
    """
    Evaluate the main branch of a two-output model on two-column labels.

    Inputs:
    - model: module whose forward returns (main_scores, auxiliary_scores).
    - loader: iterable of (x, y) batches; only column 0 of y (the main label)
      is used here -- the auxiliary column is ignored.
    Outputs: (loss, accuracy) of the main branch -- per-sample mean
    cross-entropy and fraction of correct argmax predictions.
    Raises: ValueError if the loader yields no samples.
    """
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            main_scores = model(x)[0]
            main_labels = y[:, 0]
            # Sum per batch (not mean) so a smaller final batch gets its true weight.
            total_loss += F.cross_entropy(main_scores, main_labels, reduction='sum').item()
            _, preds = main_scores.max(1)
            num_correct += (preds == main_labels).sum().item()
            num_samples += preds.size(0)
    if num_samples == 0:
        raise ValueError("evaluate_both: loader yielded no samples")
    return total_loss / num_samples, num_correct / num_samples

In [None]:
def evaluate_non_rotate(model, loader):
    """
    Evaluate the main branch of a two-output model on plain (one-column) labels.

    Unlike evaluate_both, y here is a 1-D tensor of main-task labels with no
    auxiliary column, so it is compared directly against the first output.

    Inputs:
    - model: module whose forward returns (main_scores, auxiliary_scores).
    - loader: iterable of (x, y) batches; y holds integer main-task labels.
    Outputs: (loss, accuracy) of the main branch -- per-sample mean
    cross-entropy and fraction of correct argmax predictions.
    Raises: ValueError if the loader yields no samples.
    """
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            main_scores = model(x)[0]
            # Sum per batch (not mean) so a smaller final batch gets its true weight.
            total_loss += F.cross_entropy(main_scores, y, reduction='sum').item()
            _, preds = main_scores.max(1)
            num_correct += (preds == y).sum().item()
            num_samples += preds.size(0)
    if num_samples == 0:
        raise ValueError("evaluate_non_rotate: loader yielded no samples")
    return total_loss / num_samples, num_correct / num_samples

In [None]:
def ttt(model, loader, loader_spinned, optimizer):
    """
    Test-time training on the rotation (auxiliary) task, then evaluation.

    Runs one optimizer step per batch of `loader_spinned`, minimizing only the
    auxiliary loss cross_entropy(scores[1], y[:, 1]). The adapted model is then
    scored on `loader` with evaluate_non_rotate.

    Note: the model is in train mode during adaptation, so any BatchNorm layers
    it may contain will update their running statistics from these batches.

    Outputs: loss and accuracy, as returned by evaluate_non_rotate.
    """
    model = model.to(device)
    model.train()
    for images, labels in loader_spinned:
        images = images.to(device=device, dtype=torch.float32)
        labels = labels.to(device=device, dtype=torch.long)
        rotation_scores = model(images)[1]
        rotation_loss = F.cross_entropy(rotation_scores, labels[:, 1])
        optimizer.zero_grad()
        rotation_loss.backward()
        optimizer.step()
    return evaluate_non_rotate(model, loader)

# Experiment 1: Baseline ResNet18
## Training

In [None]:
# Experiment 1: Train a baseline ResNet18: no branch
lr = 1e-3
wd = 1e-4
batch_size = 1024
train_set = TensorDataset(torch.load('train_x.pt'), torch.load('train_y.pt'))
val_set = TensorDataset(torch.load('val_x.pt'), torch.load('val_y.pt'))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Validation order does not affect the aggregated metrics, so don't shuffle it.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)

from models import ResNetMainBranch
model_base_1 = ResNetMainBranch()
optimizer = optim.Adam(model_base_1.parameters(), lr=lr, weight_decay=wd)
train_main(model_base_1, optimizer, train_loader, val_loader, epochs=30, model_path='model_base_1.pth', early_stop_patience=5)

## Evaluate on Uncorrupted Test Set

In [None]:
# Evaluate
batch_size = 1024
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
from models import ResNetMainBranch
model_path = 'model_base_1.pth'
model_base_1 = ResNetMainBranch()
# Load to CPU first so a checkpoint written on a GPU also loads on CPU-only machines.
params = torch.load(model_path, map_location='cpu')
model_base_1.load_state_dict(params)
model_base_1 = model_base_1.to(device=device)
print('Uncorrupted test set: ', evaluate_main(model_base_1, test_loader))

## Evaluate on Corrupted Test Set

In [None]:
batch_size = 1024
# Display name -> corrupted test image file; all share the same label file.
corrupted_inputs = {
    "Gaussian noise": 'test_x_gauss_noise.pt',
    "Defocus blur": 'test_x_defocus_blur.pt',
    "Elastic transform": 'test_x_elastic_transform.pt',
    "Motion blur": 'test_x_motion_blur.pt',
    "Zoom blur": 'test_x_zoom_blur.pt',
}
test_y_corrupted = torch.load('test_y_corrupted.pt')  # loaded once, reused for every corruption

from models import ResNetMainBranch
model_path = 'model_base_1.pth'
model_base_1 = ResNetMainBranch()
# Load to CPU first so a checkpoint written on a GPU also loads on CPU-only machines.
params = torch.load(model_path, map_location='cpu')
model_base_1.load_state_dict(params)
model_base_1 = model_base_1.to(device=device)

for corruption_name, x_path in corrupted_inputs.items():
    corrupted_set = TensorDataset(torch.load(x_path), test_y_corrupted)
    corrupted_loader = DataLoader(corrupted_set, batch_size=batch_size, shuffle=False)
    print(f"{corruption_name}: ", evaluate_main(model_base_1, corrupted_loader))


# Experiment 2: ResNet18 with Auxiliary Branch (No Test-Time Training)
## Training

In [None]:
lr = 1e-3
wd = 1e-5
batch_size = 1024
# NOTE(review): train_mean / train_std are not referenced by any cell in this notebook.
train_mean = [86.69585, 86.342995, 85.84817]
train_std = [74.59906, 74.196365, 73.890495]

train_rotate_set = TensorDataset(torch.load('train_x_rotate.pt'), torch.load('train_y_rotate.pt'))
val_rotate_set = TensorDataset(torch.load('val_x_rotate.pt'), torch.load('val_y_rotate.pt'))

train_rotate_loader = DataLoader(train_rotate_set, batch_size=batch_size, shuffle=True)
# Validation order does not affect the aggregated metrics, so don't shuffle it.
val_rotate_loader = DataLoader(val_rotate_set, batch_size=batch_size, shuffle=False)

from models import ResNetTwoBranch
exp_2_model_1 = ResNetTwoBranch()
optimizer = optim.Adam(exp_2_model_1.parameters(), lr=lr, weight_decay=wd)
train_both(exp_2_model_1, optimizer, train_rotate_loader, val_rotate_loader, epochs=50, model_path='exp_2_model_1.pth', early_stop_patience=5)

## Evaluate on Uncorrupted Test Set

In [None]:
batch_size = 1024
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

from models import ResNetTwoBranch
model_path = 'exp_2_model_1.pth'
exp_2_model_1 = ResNetTwoBranch()
# Load to CPU first so a checkpoint written on a GPU also loads on CPU-only machines.
params = torch.load(model_path, map_location='cpu')
exp_2_model_1.load_state_dict(params)
exp_2_model_1 = exp_2_model_1.to(device=device)
print('Uncorrupted test set: ', evaluate_non_rotate(exp_2_model_1, test_loader))

## Evaluate on Corrupted Test Set

In [None]:
batch_size = 1024
# Display name -> corrupted test image file; all share the same label file.
corrupted_inputs = {
    "Gaussian noise": 'test_x_gauss_noise.pt',
    "Defocus blur": 'test_x_defocus_blur.pt',
    "Elastic transform": 'test_x_elastic_transform.pt',
    "Motion blur": 'test_x_motion_blur.pt',
    "Zoom blur": 'test_x_zoom_blur.pt',
}
test_y_corrupted = torch.load('test_y_corrupted.pt')  # loaded once, reused for every corruption

from models import ResNetTwoBranch
model_path = 'exp_2_model_1.pth'
exp_2_model_1 = ResNetTwoBranch()
# Load to CPU first so a checkpoint written on a GPU also loads on CPU-only machines.
params = torch.load(model_path, map_location='cpu')
exp_2_model_1.load_state_dict(params)
exp_2_model_1 = exp_2_model_1.to(device=device)

for corruption_name, x_path in corrupted_inputs.items():
    corrupted_set = TensorDataset(torch.load(x_path), test_y_corrupted)
    corrupted_loader = DataLoader(corrupted_set, batch_size=batch_size, shuffle=False)
    print(f"{corruption_name}: ", evaluate_non_rotate(exp_2_model_1, corrupted_loader))

# Experiment 3: ResNet18 with Auxiliary Branch (TTT)
## Prepare Model

In [None]:
def initialize_model_part3(model_path, lr=1e-5, wd=1e-8):
    """
    Build a ResNetTwoBranch from a saved state_dict plus an Adam optimizer for TTT.

    Inputs:
    - model_path: path to a state_dict written with torch.save (e.g. by train_both).
    - lr: (Optional) Adam learning rate for the test-time updates.
    - wd: (Optional) Adam weight decay for the test-time updates.
    Returns: (model, optimizer). The model is left on the CPU; ttt moves it to `device`.
    """
    from models import ResNetTwoBranch
    exp_3_model_1 = ResNetTwoBranch()
    # Load to CPU so a checkpoint written on a GPU also loads on CPU-only machines.
    params = torch.load(model_path, map_location='cpu')
    exp_3_model_1.load_state_dict(params)
    optimizer = optim.Adam(exp_3_model_1.parameters(), lr=lr, weight_decay=wd)
    return exp_3_model_1, optimizer

## TTT on Uncorrupted Test Set

In [None]:
batch_size = 1024
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
# Rotation-task data for the clean test set, consumed by ttt for adaptation.
# NOTE(review): file names follow the train_x_rotate.pt / val_x_rotate.pt
# convention used above -- confirm test_x_rotate.pt / test_y_rotate.pt exist.
test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=True)
exp_3_model_1, optimizer = initialize_model_part3('exp_2_model_1.pth')
print('Uncorrupted test set: ', ttt(exp_3_model_1, test_loader, test_rotate_loader, optimizer))

## TTT on Corrupted Test Set

In [None]:
batch_size = 1024
# Display name -> file-name suffix shared by the plain and rotated corrupted inputs.
corruption_suffixes = {
    "Gaussian noise": 'gauss_noise',
    "Defocus blur": 'defocus_blur',
    "Elastic transform": 'elastic_transform',
    'Motion blur': 'motion_blur',
    'Zoom blur': 'zoom_blur',
}
test_y_corrupted = torch.load('test_y_corrupted.pt')
test_y_rotate_corrupted = torch.load('test_y_rotate_corrupted.pt')

for corruption_name, suffix in corruption_suffixes.items():
    corrupted_set = TensorDataset(torch.load(f'test_x_{suffix}.pt'), test_y_corrupted)
    corrupted_loader = DataLoader(corrupted_set, batch_size=batch_size, shuffle=False)
    corrupted_rotate_set = TensorDataset(torch.load(f'test_x_rotate_{suffix}.pt'), test_y_rotate_corrupted)
    corrupted_rotate_loader = DataLoader(corrupted_rotate_set, batch_size=batch_size, shuffle=True)
    # Start each corruption from the Experiment 2 weights so adaptations don't accumulate.
    exp_3_model_1, optimizer = initialize_model_part3('exp_2_model_1.pth')
    print(corruption_name, ': ', ttt(exp_3_model_1, corrupted_loader, corrupted_rotate_loader, optimizer))

# Experiment 4: ResNet18 with Auxiliary Branch (Online TTT)
## Prepare Model

In [None]:
def initialize_model_part4(model_path, lr=1e-5, wd=1e-8):
    """
    Build a ResNetTwoBranch from a saved state_dict plus an Adam optimizer for online TTT.

    Inputs:
    - model_path: path to a state_dict written with torch.save (e.g. by train_both).
    - lr: (Optional) Adam learning rate for the test-time updates.
    - wd: (Optional) Adam weight decay for the test-time updates.
    Returns: (model, optimizer). The model is left on the CPU; ttt moves it to `device`.
    """
    from models import ResNetTwoBranch
    exp_4_model_1 = ResNetTwoBranch()
    # Load to CPU so a checkpoint written on a GPU also loads on CPU-only machines.
    params = torch.load(model_path, map_location='cpu')
    exp_4_model_1.load_state_dict(params)
    optimizer = optim.Adam(exp_4_model_1.parameters(), lr=lr, weight_decay=wd)
    return exp_4_model_1, optimizer

## TTT on Uncorrupted Test Set

In [None]:
batch_size = 4
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
# Rotation-task data for the clean test set, consumed by ttt for adaptation.
# NOTE(review): file names follow the train_x_rotate.pt / val_x_rotate.pt
# convention used above -- confirm test_x_rotate.pt / test_y_rotate.pt exist.
test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=False)
# NOTE(review): ttt adapts on the whole rotated loader before evaluating, so the
# small batch size alone does not make this per-batch online TTT -- confirm intent.
exp_4_model_1, optimizer = initialize_model_part4('exp_2_model_1.pth')
print('Uncorrupted test set: ', ttt(exp_4_model_1, test_loader, test_rotate_loader, optimizer))

## TTT on Corrupted Test Set

In [None]:
batch_size = 4
# Display name -> file-name suffix shared by the plain and rotated corrupted inputs.
corruption_suffixes = {
    "Gaussian noise": 'gauss_noise',
    "Defocus blur": 'defocus_blur',
    "Elastic transform": 'elastic_transform',
    'Motion blur': 'motion_blur',
    'Zoom blur': 'zoom_blur',
}
test_y_corrupted = torch.load('test_y_corrupted.pt')
test_y_rotate_corrupted = torch.load('test_y_rotate_corrupted.pt')

# NOTE(review): ttt adapts on the whole rotated loader before evaluating, so the
# small batch size alone does not make this per-batch online TTT -- confirm intent.
for corruption_name, suffix in corruption_suffixes.items():
    corrupted_set = TensorDataset(torch.load(f'test_x_{suffix}.pt'), test_y_corrupted)
    corrupted_loader = DataLoader(corrupted_set, batch_size=batch_size, shuffle=False)
    corrupted_rotate_set = TensorDataset(torch.load(f'test_x_rotate_{suffix}.pt'), test_y_rotate_corrupted)
    corrupted_rotate_loader = DataLoader(corrupted_rotate_set, batch_size=batch_size, shuffle=False)
    # Start each corruption from the Experiment 2 weights so adaptations don't accumulate.
    exp_4_model_1, optimizer = initialize_model_part4('exp_2_model_1.pth')
    print(corruption_name, ': ', ttt(exp_4_model_1, corrupted_loader, corrupted_rotate_loader, optimizer))