In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import sampler
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as T
import pickle
import numpy as np

In [2]:
from logger import Logger

In [3]:
device = torch.device('cpu')

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [4]:
# Set up training pipelines
def train_main(model, optimizer, loader_train, loader_val, epochs=1, model_path=None, early_stop_patience = 0):
    """
    Train the main branch
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Logger object with loss and accuracy data
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    logger = Logger()
    last_loss = float('inf')
    for e in range(epochs):
        num_correct = 0
        num_samples = 0
        total_loss = 0.0
        count = 0
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss = F.cross_entropy(scores, y)
            total_loss += loss.item()

            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"\r[Epoch {e + 1}, Batch {t}] train_loss: {loss.item()}", end='')
            count += 1

        # Conclude Epoch
        train_loss = total_loss / count
        train_acc = float(num_correct) / num_samples
        val_loss, val_acc = evaluate_main(model, loader_val)
        logger.log(train_loss, train_acc, val_loss, val_acc)
        
        with open(model_path.split('.')[0] + '.pkl', 'wb') as output_file:
            pickle.dump(logger, output_file)

        # Early Stopping
        if logger.check_early_stop(early_stop_patience):
            print("[Early Stopped]")
            break
        else:
            if last_loss > val_loss:
                print(f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, val_loss improved from %.4f to %.4f. Saving model to {model_path}." % (last_loss, val_loss))
                if model_path is not None:
                    torch.save(model.state_dict(), model_path)
            else:
                print(f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, val_loss did not improve from %.4f" % (last_loss))
            last_loss = val_loss
    return logger, model

In [5]:
def train_both(model, optimizer, loader_train, loader_val, epochs=1, model_path=None, early_stop_patience = 0):
    """
    Train the main and auxillary branch
    Inputs:
    - model: A PyTorch Module giving the model to train.
    - optimizer: An Optimizer object we will use to train the model
    - epochs: (Optional) A Python integer giving the number of epochs to train for
    
    Returns: Logger object with loss and accuracy data
    """
    model = model.to(device=device)  # move the model parameters to CPU/GPU
    logger = Logger()
    last_loss = float('inf')
    for e in range(epochs):
        num_correct = 0
        num_samples = 0
        running_loss = 0.0
        count = 0
        for t, (x, y) in enumerate(loader_train):
            model.train()  # put model to training mode
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)

            scores = model(x)
            loss_main = F.cross_entropy(scores[0], y[:, 0])
            loss_auxillary = F.cross_entropy(scores[1], y[:, 1])
            loss = loss_main + loss_auxillary
            running_loss += loss_main.item()

            _, preds = scores[0].max(1)
            num_correct += (preds == y[:, 0]).sum()
            num_samples += preds.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f"\r[Epoch {e}, Batch {t}] train_loss: {loss.item()}", end='')
            count += 1

        # Conclude Epoch
        train_loss = running_loss / count
        train_acc = float(num_correct) / num_samples
        val_loss, val_acc = evaluate_both(model, loader_val)
        logger.log(train_loss, train_acc, val_loss, val_acc)

        with open(model_path.split('.')[0] + '.pkl', 'wb') as output_file:
            pickle.dump(logger, output_file)
            
        # Early Stopping
        if logger.check_early_stop(early_stop_patience):
            print("[Early Stopped]")
            break
        else:
            if last_loss > val_loss:
                print(f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, val_loss improved from %.4f to %.4f. Saving model to {model_path}." % (last_loss, val_loss))
                if model_path is not None:
                    torch.save(model.state_dict(), model_path)
            else:
                print(f"\r[Epoch {e}] train_acc: {train_acc}, val_acc:{val_acc}, val_loss did not improve from %.4f" % (last_loss))
            last_loss = val_loss
    return logger, model

In [6]:
def evaluate_main(model, loader):
    """
    Evaluate main branch accuracy
    Outputs: loss and accuracy
    """
    num_correct = 0
    num_samples = 0
    ave_loss = 0.0
    count = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            loss = F.cross_entropy(scores, y)
            # print(scores.shape)
            _, preds = scores.max(1)
            num_correct += (preds == y).sum()
            num_samples += preds.size(0)
            ave_loss += loss.item()
            count += 1
        acc = float(num_correct) / num_samples
        return ave_loss / count, acc

def evaluate_both(model, loader):
    """
    Evaluate main branch accuracy in model with two predictions
    Outputs: loss and accuracy
    """
    num_correct = 0
    num_samples = 0
    ave_loss = 0.0
    count = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            loss_main = F.cross_entropy(scores[0], y[:, 0])
            # print(scores.shape)
            _, preds = scores[0].max(1)
            num_correct += (preds == y[:, 0]).sum()
            # print(f"num_correct: {num_correct}")
            num_samples += preds.size(0)
            # print(f"num_samples: {num_samples}")
            ave_loss += loss_main.item()
            count += 1
        acc = float(num_correct) / num_samples
        return ave_loss / count, acc

In [7]:
def evaluate_non_rotate(model, loader):
    """
    Evaluate main branch accuracy in model with two predictions
    Outputs: loss and accuracy
    """
    from random import randrange
    num_correct = 0
    num_samples = 0
    ave_loss = 0.0
    count = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
            y = y.to(device=device, dtype=torch.long)
            scores = model(x)
            loss_main = F.cross_entropy(scores[0], y)
            # print(scores.shape)
            _, preds = scores[0].max(1)
            num_correct += (preds == y).sum()
            #print(f"num_correct: {num_correct}")
            num_samples += preds.size(0)
            #print(f"num_samples: {num_samples}")
            ave_loss += loss_main.item()
            count += 1
        acc = float(num_correct) / num_samples
        return ave_loss / count, acc

In [8]:
import copy

def ttt(model, loader, loader_spinned, optimizer):
    """
    TTT with image spinning task
    Outputs: loss and accuracy
    """
    num_correct = 0
    num_samples = 0
    model = model.to(device)
    model.train()
    original_model = copy.deepcopy(model.state_dict())
    for x, y in loader_spinned:
        x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
        y = y.to(device=device, dtype=torch.long)

        for i in range(4): # 4 step gradient descent
            scores = model(x)
            loss_auxillary = F.cross_entropy(scores[1], y[:, 1])
            optimizer.zero_grad()
            loss_auxillary.backward()
            optimizer.step()

        with torch.no_grad():
            scores = model(x)
            _, preds = scores[0].max(1)
            num_correct += (y[preds == y[:, 0], 1] == 0).sum()
            num_samples += (y[:, 1] == 0).sum()

        model.load_state_dict(original_model)
    acc = float(num_correct) / num_samples
    return acc

def ttt_online(model, loader, loader_spinned, optimizer):
    """
    TTT online with image spinning task
    Outputs: loss and accuracy
    """
    num_correct = 0
    num_samples = 0
    model = model.to(device)
    model.train()
    for x, y in loader_spinned:
        x = x.to(device=device, dtype=torch.float32)  # move to device, e.g. GPU
        y = y.to(device=device, dtype=torch.long)
        scores = model(x)
        loss_auxillary = F.cross_entropy(scores[1], y[:, 1])
        optimizer.zero_grad()
        loss_auxillary.backward()
        optimizer.step()

        with torch.no_grad():
            scores = model(x)
            _, preds = scores[0].max(1)
            num_correct += (y[preds == y[:, 0], 1] == 0).sum()
            num_samples += (y[:, 1] == 0).sum()
    acc = float(num_correct) / num_samples
    return acc

# Experiment 1: Baseline ResNet18
## Training

In [9]:
# Experiment 1: Train a baseline ResNet18: no branch
lr = 1e-3
wd = 1e-4
batch_size = 1024
train_set = TensorDataset(torch.load('train_x.pt'), torch.load('train_y.pt'))
val_set = TensorDataset(torch.load('val_x.pt'), torch.load('val_y.pt'))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)

from models import ResNetMainBranch
model_base_1 = ResNetMainBranch()
optimizer = optim.Adam(model_base_1.parameters(), lr=lr, weight_decay=wd)
train_main(model_base_1, optimizer, train_loader, val_loader, epochs=30, model_path='model_base_1.pth', early_stop_patience=5)

[Epoch 0] train_acc: 0.11511731701096659, val_acc:0.2923096543808188, val_loss improved from inf to 2.7940. Saving model to model_base_1.pth.
[Epoch 1] train_acc: 0.4346467737821984, val_acc:0.6152276495344982, val_loss improved from 2.7940 to 1.4163. Saving model to model_base_1.pth.
[Epoch 2] train_acc: 0.7363236419280795, val_acc:0.7913531437316669, val_loss improved from 1.4163 to 0.7206. Saving model to model_base_1.pth.
[Epoch 3] train_acc: 0.8678589645498598, val_acc:0.871700038260426, val_loss improved from 0.7206 to 0.4909. Saving model to model_base_1.pth.
[Epoch 4] train_acc: 0.9217674062739097, val_acc:0.8895549037112613, val_loss improved from 0.4909 to 0.4047. Saving model to model_base_1.pth.
[Epoch 5] train_acc: 0.9497577148686559, val_acc:0.8922331335288867, val_loss improved from 0.4047 to 0.3886. Saving model to model_base_1.pth.
[Epoch 6] train_acc: 0.9659844427441979, val_acc:0.9043489350848106, val_loss did not improve from 0.3886
[Epoch 7] train_acc: 0.9736674317

(<logger.Logger at 0x7fca6ad4de20>,
 ResNetMainBranch(
   (resnet): BaselineResNet(
     (feature_extractor): ResNet(
       (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
       (bn1): Norm_Layer(
         (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
       )
       (relu): ReLU(inplace=True)
       (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
       (layer1): Sequential(
         (0): BasicBlock(
           (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
           (bn1): Norm_Layer(
             (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
           )
           (relu): ReLU(inplace=True)
           (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
           (bn2): Norm_Layer(
             (group_norm): GroupNorm(32, 64, eps=1e-05, affine=True)
           )
         )
         (1): BasicBlock(
           (

## Evaluate on Uncorrupted Test Set

In [10]:
# Evaluate
batch_size = 1024
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)
from models import ResNetMainBranch
model_path = 'model_base_1.pth'
model_base_1 = ResNetMainBranch()
params = torch.load(model_path)
model_base_1.load_state_dict(params)
model_base_1 = model_base_1.to(device=device)
print('Uncorrupted test set: ', evaluate_main(model_base_1, test_loader))

Uncorrupted test set:  (0.6300153686450078, 0.8663499604117181)


## Evaluate on Corrupted Test Set

In [10]:
batch_size = 1024
test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=True)

test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=True)

test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=True)

test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=True)

test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=True)

test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=True)

test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=True)

test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=True)

test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=True)

from models import ResNetMainBranch
model_path = 'model_base_1.pth'
model_base_1 = ResNetMainBranch()
params = torch.load(model_path)
model_base_1.load_state_dict(params)
model_base_1 = model_base_1.to(device=device)

print("Gaussian noise: ", evaluate_main(model_base_1, test_gauss_loader))
print("Defocus blur: ", evaluate_main(model_base_1, test_defocus_loader))
print("Elastic transform: ", evaluate_main(model_base_1, test_elastic_loader))
print("Horizontal Motion blur: ", evaluate_main(model_base_1, test_hmotion_loader))
print("Virtical Motion blur: ", evaluate_main(model_base_1, test_vmotion_loader))
print("Zoom blur: ", evaluate_main(model_base_1, test_zoom_loader))
print("Shot: ", evaluate_main(model_base_1, test_shot_loader))
print("Impulse: ", evaluate_main(model_base_1, test_impulse_loader))
print("Contrast: ", evaluate_main(model_base_1, test_contrast_loader))


Gaussian noise:  (2.487208182995136, 0.6048297703879651)
Defocus blur:  (0.704764301960285, 0.8528107680126682)
Elastic transform:  (0.8109714067899264, 0.8272367379255741)
Horizontal Motion blur:  (1.66626422221844, 0.6737133808392716)
Virtical Motion blur:  (1.7649270937992976, 0.6558986539984165)
Zoom blur:  (0.6961561670670142, 0.8586698337292161)
Shot:  (1.4280506464151235, 0.7303246239113222)
Impulse:  (4.0637488548572245, 0.4505146476642914)
Contrast:  (2.1393172099040103, 0.6636579572446556)


# Experiment 1.5: Baseline CNN
## Training

In [9]:
# Experiment 1: Train a baseline ResNet18: no branch
lr = 1e-3
wd = 1e-4
batch_size = 64
train_set = TensorDataset(torch.load('train_x.pt'), torch.load('train_y.pt'))
val_set = TensorDataset(torch.load('val_x.pt'), torch.load('val_y.pt'))
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNMainBranch
model_cnn_base_1 = CNNMainBranch()
optimizer = optim.Adam(model_cnn_base_1.parameters(), lr=lr, weight_decay=wd)
train_main(model_cnn_base_1, optimizer, train_loader, val_loader, epochs=30, model_path='model_cnn_base_1.pth', early_stop_patience=5)

[Epoch 0] train_acc: 0.44478449375159396, val_acc:0.7945415125621732, val_loss improved from inf to 0.6778. Saving model to model_cnn_base_1.pth.
[Epoch 1] train_acc: 0.872991583779648, val_acc:0.8996301492156613, val_loss improved from 0.6778 to 0.3613. Saving model to model_cnn_base_1.pth.
[Epoch 2] train_acc: 0.926900025503698, val_acc:0.9284530034434383, val_loss improved from 0.3613 to 0.2514. Saving model to model_cnn_base_1.pth.
[Epoch 3] train_acc: 0.9473029839326702, val_acc:0.9120010202780258, val_loss did not improve from 0.2514
[Epoch 4] train_acc: 0.9590665646518746, val_acc:0.9174850146664966, val_loss did not improve from 0.3109
[Epoch 5] train_acc: 0.9663988778372864, val_acc:0.9377630404285168, val_loss improved from 0.3446 to 0.2477. Saving model to model_cnn_base_1.pth.
[Epoch 6] train_acc: 0.9742731446059678, val_acc:0.93431960209157, val_loss did not improve from 0.2477
[Epoch 7] train_acc: 0.9751976536597806, val_acc:0.9396760617268206, val_loss improved from 0.27

(<logger.Logger at 0x14b18a35a30>,
 CNNMainBranch(
   (cnn): BaselineCNN(
     (conv1): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (conv5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (conv6): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (fc1): Linear(in_features=18432, out_features=9216, bias=True)
     (fc2): Linear(in_features=9216, out_features=1000, bias=True)
   )
   (relu): ReLU()
  

## Evaluate on Uncorrupted Test Set

In [10]:
# Evaluate
batch_size = 64
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNMainBranch
model_path = 'model_cnn_base_1.pth'
model_cnn_base_1 = CNNMainBranch()
params = torch.load(model_path)
model_cnn_base_1.load_state_dict(params)
model_cnn_base_1 = model_cnn_base_1.to(device=device)

print('Uncorrupted test set: ', evaluate_main(model_cnn_base_1, test_loader))

Uncorrupted test set:  (0.4815320502119986, 0.9054631828978622)


## Evaluate on Corrupted Test Set

In [11]:
batch_size = 64
test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=True)

test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=True)

test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=True)

test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=True)

test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=True)

test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=True)

test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=True)

test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=True)

test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNMainBranch
model_path = 'model_cnn_base_1.pth'
model_cnn_base_1 = CNNMainBranch()
params = torch.load(model_path)
model_cnn_base_1.load_state_dict(params)
model_cnn_base_1 = model_cnn_base_1.to(device=device)

print("Gaussian noise: ", evaluate_main(model_cnn_base_1, test_gauss_loader))
print("Defocus blur: ", evaluate_main(model_cnn_base_1, test_defocus_loader))
print("Elastic transform: ", evaluate_main(model_cnn_base_1, test_elastic_loader))
print("Horizontal Motion blur: ", evaluate_main(model_cnn_base_1, test_hmotion_loader))
print("Virtical Motion blur: ", evaluate_main(model_cnn_base_1, test_vmotion_loader))
print("Zoom blur: ", evaluate_main(model_cnn_base_1, test_zoom_loader))
print("Shot: ", evaluate_main(model_cnn_base_1, test_shot_loader))
print("Impulse: ", evaluate_main(model_cnn_base_1, test_impulse_loader))
print("Contrast: ", evaluate_main(model_cnn_base_1, test_contrast_loader))

Gaussian noise:  (1.044238771001498, 0.7676959619952494)
Defocus blur:  (0.49437062746158456, 0.8997624703087886)
Elastic transform:  (0.6439499969434257, 0.8684877276326207)
Horizontal Motion blur:  (1.075326208213363, 0.7423594615993666)
Virtical Motion blur:  (1.9190652045336636, 0.5697545526524149)
Zoom blur:  (0.5435288520247648, 0.8913697545526524)
Shot:  (0.6516357510529384, 0.8585114806017419)
Impulse:  (1.182600476224013, 0.729612034837688)
Contrast:  (1.5937491890155908, 0.6174980205859065)


# Evaluate on the traditional enhanced set

In [None]:
batch_size = 64
test_enhanced_set = TensorDataset(torch.load('test_x_enhanced_impulse.pt'), torch.load('test_y_enhanced_impulse.pt'))
test_enhanced_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNMainBranch
model_path = 'model_cnn_base_1.pth'
model_cnn_base_1 = CNNMainBranch()
params = torch.load(model_path)
model_cnn_base_1.load_state_dict(params)
model_cnn_base_1 = model_cnn_base_1.to(device=device)

print('Uncorrupted test set: ', evaluate_main(model_cnn_base_1, test_enhanced_loader))

# Experiment 2: ResNet18 with Auxillary Branch (No Online Training)
## Training

In [9]:
lr = 1e-3
wd = 1e-5
batch_size = 1024
train_mean = [86.69585, 86.342995, 85.84817]
train_std = [74.59906, 74.196365, 73.890495]

train_rotate_set = TensorDataset(torch.load('train_x_rotate.pt'), torch.load('train_y_rotate.pt'))
val_rotate_set = TensorDataset(torch.load('val_x_rotate.pt'), torch.load('val_y_rotate.pt'))

train_rotate_loader = DataLoader(train_rotate_set, batch_size=batch_size, shuffle=True)
val_rotate_loader = DataLoader(val_rotate_set, batch_size=batch_size, shuffle=True)

from models import ResNetTwoBranch
exp_2_model_1 = ResNetTwoBranch()
optimizer = optim.Adam(exp_2_model_1.parameters(), lr=lr, weight_decay=wd)
train_both(exp_2_model_1, optimizer, train_rotate_loader, val_rotate_loader, epochs=50, model_path='exp_2_model_1.pth', early_stop_patience=5)

[Epoch 0] train_acc: 0.22435284366233105, val_acc:0.5029014156357607, val_loss improved from inf to 1.6574. Saving model to exp_2_model_1.pth.
[Epoch 1] train_acc: 0.6622353991328742, val_acc:0.7657824257110063, val_loss improved from 1.6574 to 0.7710. Saving model to exp_2_model_1.pth.
[Epoch 2] train_acc: 0.8420205304769192, val_acc:0.8594567019512818, val_loss improved from 0.7710 to 0.4682. Saving model to exp_2_model_1.pth.
[Epoch 3] train_acc: 0.8975707727620506, val_acc:0.8927751562300726, val_loss improved from 0.4682 to 0.3605. Saving model to exp_2_model_1.pth.
[Epoch 4] train_acc: 0.927402129558786, val_acc:0.8969200357097309, val_loss improved from 0.3605 to 0.3528. Saving model to exp_2_model_1.pth.
[Epoch 5] train_acc: 0.9433339709257842, val_acc:0.906038770564979, val_loss improved from 0.3528 to 0.3133. Saving model to exp_2_model_1.pth.
[Epoch 6] train_acc: 0.9564683754144351, val_acc:0.9068358627726055, val_loss did not improve from 0.3133
[Epoch 7] train_acc: 0.96286

KeyboardInterrupt: 

## Evaluate on Uncorrupted Test Set

In [9]:
batch_size = 1024
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

from models import ResNetTwoBranch
model_path = 'exp_2_model_1.pth'
exp_2_model_1 = ResNetTwoBranch()
params = torch.load(model_path)
exp_2_model_1.load_state_dict(params)
exp_2_model_1 = exp_2_model_1.to(device=device)
print('Uncorrupted test set: ', evaluate_non_rotate(exp_2_model_1, test_loader))

Uncorrupted test set:  (0.7577824179942791, 0.8707046714172605)


## Evaluate on Corrupted Test Set

In [13]:
batch_size = 1024
test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=True)

test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=True)

test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=True)

test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=True)

test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=True)

test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=True)

test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=True)

test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=True)

test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=True)

from models import ResNetTwoBranch
model_path = 'exp_2_model_1.pth'
exp_2_model_1 = ResNetTwoBranch()
params = torch.load(model_path)
exp_2_model_1.load_state_dict(params)
exp_2_model_1 = exp_2_model_1.to(device=device)

print("Gaussian noise: ", evaluate_non_rotate(exp_2_model_1, test_gauss_loader))
print("Defocus blur: ", evaluate_non_rotate(exp_2_model_1, test_defocus_loader))
print("Elastic transform: ", evaluate_non_rotate(exp_2_model_1, test_elastic_loader))
print("Horizontal Motion blur: ", evaluate_non_rotate(exp_2_model_1, test_hmotion_loader))
print("Vertical Motion blur: ", evaluate_non_rotate(exp_2_model_1, test_vmotion_loader))
print("Zoom blur: ", evaluate_non_rotate(exp_2_model_1, test_zoom_loader))
print("Shot: ", evaluate_non_rotate(exp_2_model_1, test_shot_loader))
print("Impulse: ", evaluate_non_rotate(exp_2_model_1, test_impulse_loader))
print("Contrast: ", evaluate_non_rotate(exp_2_model_1, test_contrast_loader))

Gaussian noise:  (3.3912662175985484, 0.6321456848772763)
Defocus blur:  (0.8509451242593619, 0.8607284243863816)
Elastic transform:  (1.0640129309434156, 0.8300870942201108)
Horizontal Motion blur:  (2.1406559577355018, 0.6624703087885986)
Virtical Motion blur:  (2.1507714803402243, 0.6610451306413302)
Zoom blur:  (0.8463354981862582, 0.8596991290577989)
Shot:  (1.79858844096844, 0.7488519398258116)
Impulse:  (5.357336667867807, 0.49382422802850356)
Contrast:  (2.2237631724430966, 0.7022961203483769)


# Experiment 2.5: CNN with Auxillary Branch (No Online Training)
## Training

In [None]:
lr = 1e-3
wd = 1e-5
batch_size = 64
train_mean = [86.69585, 86.342995, 85.84817]
train_std = [74.59906, 74.196365, 73.890495]

train_rotate_set = TensorDataset(torch.load('train_x_rotate.pt'), torch.load('train_y_rotate.pt'))
val_rotate_set = TensorDataset(torch.load('val_x_rotate.pt'), torch.load('val_y_rotate.pt'))

train_rotate_loader = DataLoader(train_rotate_set, batch_size=batch_size, shuffle=True)
val_rotate_loader = DataLoader(val_rotate_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNTwoBranch
exp_2_cnn_model_1 = CNNTwoBranch()
optimizer = optim.Adam(exp_2_cnn_model_1.parameters(), lr=lr, weight_decay=wd)
train_both(exp_2_cnn_model_1, optimizer, train_rotate_loader, val_rotate_loader, epochs=30, model_path='exp_2_cnn_model_1.pth', early_stop_patience=5)

## Evaluate on Uncorrupted Test Set

In [18]:
batch_size = 64
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNTwoBranch
model_path = 'exp_2_cnn_model_1.pth'
exp_2_cnn_model_1 = CNNTwoBranch()
params = torch.load(model_path)
exp_2_cnn_model_1.load_state_dict(params)
exp_2_cnn_model_1 = exp_2_cnn_model_1.to(device=device)
print('Uncorrupted test set: ', evaluate_non_rotate(exp_2_cnn_model_1, test_loader))

Uncorrupted test set:  (0.5389380103936701, 0.9046714172604909)


## Evaluate on Corrupted Test Set

In [19]:
batch_size = 1024
test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=True)

test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=True)

test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=True)

test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=True)

test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=True)

test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=True)

test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=True)

test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=True)

test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=True)

from cnn_models import CNNTwoBranch
model_path = 'exp_2_cnn_model_1.pth'
exp_2_cnn_model_1 = CNNTwoBranch()
params = torch.load(model_path)
exp_2_cnn_model_1.load_state_dict(params)
exp_2_cnn_model_1 = exp_2_cnn_model_1.to(device=device)

print("Gaussian noise: ", evaluate_non_rotate(exp_2_cnn_model_1, test_gauss_loader))
print("Defocus blur: ", evaluate_non_rotate(exp_2_cnn_model_1, test_defocus_loader))
print("Elastic transform: ", evaluate_non_rotate(exp_2_cnn_model_1, test_elastic_loader))
print("Horizontal Motion blur: ", evaluate_non_rotate(exp_2_cnn_model_1, test_hmotion_loader))
print("Vertical Motion blur: ", evaluate_non_rotate(exp_2_cnn_model_1, test_vmotion_loader))
print("Zoom blur: ", evaluate_non_rotate(exp_2_cnn_model_1, test_zoom_loader))
print("Shot: ", evaluate_non_rotate(exp_2_cnn_model_1, test_shot_loader))
print("Impulse: ", evaluate_non_rotate(exp_2_cnn_model_1, test_impulse_loader))
print("Contrast: ", evaluate_non_rotate(exp_2_cnn_model_1, test_contrast_loader))

Gaussian noise:  (1.5474033264013438, 0.7116389548693587)
Defocus blur:  (0.5702346425790054, 0.8922406967537608)
Elastic transform:  (0.7285343683682955, 0.8669041963578781)
Horizontal Motion blur:  (1.6213888205014741, 0.6400633412509897)
Vertical Motion blur:  (2.0637162832113414, 0.5583531274742676)
Zoom blur:  (0.591289235995366, 0.8904196357878068)
Shot:  (0.8383319698847257, 0.8400633412509897)
Impulse:  (2.1770995396834154, 0.5878859857482185)
Contrast:  (3.185165441953219, 0.4818685669041964)


# Experiment 3: ResNet18 with Auxillary Branch (TTT)
## Prepare Model

In [2]:
def initialize_model_part3(model_path):
    lr = 1e-3
    from models import ResNetTwoBranch
    exp_3_model_1 = ResNetTwoBranch()
    params = torch.load(model_path)
    exp_3_model_1.load_state_dict(params)
    optimizer = optim.SGD(exp_3_model_1.parameters(), lr=lr)
    return exp_3_model_1, optimizer

## TTT on Uncorrupted Test Set

In [10]:
batch_size = 1024
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=True)

exp_3_model_1, optimizer = initialize_model_part3('exp_2_model_1.pth')
print('Uncorrupted test set: ', ttt(exp_3_model_1, test_loader, test_rotate_loader, optimizer))

Uncorrupted test set:  tensor(0.8705, device='cuda:0')


## TTT on Corrupted Test Set

In [29]:
batch_size = 1024

def gauss():
    test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
    test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=True)
    test_gauss_rotate_set = TensorDataset(torch.load('test_x_rotate_gauss_noise.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_gauss_rotate_loader = DataLoader(test_gauss_rotate_set, batch_size=batch_size, shuffle=True)
    return test_gauss_loader, test_gauss_rotate_loader

def defocus():
    test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=True)
    test_defocus_rotate_set = TensorDataset(torch.load('test_x_rotate_defocus_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_defocus_rotate_loader = DataLoader(test_defocus_rotate_set, batch_size=batch_size, shuffle=True)
    return test_defocus_loader, test_defocus_rotate_loader

def elastic():
    test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
    test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=True)
    test_elastic_rotate_set = TensorDataset(torch.load('test_x_rotate_elastic_transform.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_elastic_rotate_loader = DataLoader(test_elastic_rotate_set, batch_size=batch_size, shuffle=True)
    return test_elastic_loader, test_elastic_rotate_loader

def hmotion():
    test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=True)
    test_hmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_hmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_hmotion_rotate_loader = DataLoader(test_hmotion_rotate_set, batch_size=batch_size, shuffle=True)
    return test_hmotion_loader, test_hmotion_rotate_loader

def vmotion():
    test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=True)
    test_vmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_vmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_vmotion_rotate_loader = DataLoader(test_vmotion_rotate_set, batch_size=batch_size, shuffle=True)
    return test_vmotion_loader, test_vmotion_rotate_loader

def zoom():
    test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=True)
    test_zoom_rotate_set = TensorDataset(torch.load('test_x_rotate_zoom_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_zoom_rotate_loader = DataLoader(test_zoom_rotate_set, batch_size=batch_size, shuffle=True)
    return test_zoom_loader, test_zoom_rotate_loader

def shot():
    test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
    test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=True)
    test_shot_rotate_set = TensorDataset(torch.load('test_x_rotate_shot.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_shot_rotate_loader = DataLoader(test_shot_rotate_set, batch_size=batch_size, shuffle=True)
    return test_shot_loader, test_shot_rotate_loader

def impulse():
    test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
    test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=True)
    test_impulse_rotate_set = TensorDataset(torch.load('test_x_rotate_impulse.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_impulse_rotate_loader = DataLoader(test_impulse_rotate_set, batch_size=batch_size, shuffle=True)
    return test_impulse_loader,  test_impulse_rotate_loader

def contrast():     
    test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
    test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=True)
    test_contrast_rotate_set = TensorDataset(torch.load('test_x_rotate_contrast.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_contrast_rotate_loader = DataLoader(test_contrast_rotate_set, batch_size=batch_size, shuffle=True)
    return test_contrast_loader, test_contrast_rotate_loader

test_loaders = [gauss, defocus, elastic, hmotion, vmotion, zoom, shot, impulse, contrast]
test_names = ["Gaussian noise", "Defocus blur", "Elastic transform", 'Horizontal motion blur', 'Vertical motion blur', 'Zoom blur', 'Shot', 'Impulse', 'Contrast']

for i in range(len(test_loaders)):
    exp_3_model_1, optimizer = initialize_model_part3('exp_2_model_1.pth')
    loaders = test_loaders[i]()
    print(test_names[i], ': ', ttt(exp_3_model_1, loaders[0], loaders[1], optimizer))

Gaussian noise :  tensor(0.6324, device='cuda:0')
Defocus blur :  tensor(0.8609, device='cuda:0')
Elastic transform :  tensor(0.8310, device='cuda:0')
Horizontal motion blur :  tensor(0.6643, device='cuda:0')
Vertical motion blur :  tensor(0.6676, device='cuda:0')
Zoom blur :  tensor(0.8603, device='cuda:0')
Shot :  tensor(0.7460, device='cuda:0')
Impulse :  tensor(0.4981, device='cuda:0')
Contrast :  tensor(0.7095, device='cuda:0')


# Experiment 3.5: CNN with Auxillary Branch (TTT)
## Prepare Model

In [1]:
def initialize_cnn_model_part3(model_path):
    lr = 1e-3
    from cnn_models import CNNTwoBranch
    exp_3_cnn_model_1 = CNNTwoBranch()
    params = torch.load(model_path)
    exp_3_cnn_model_1.load_state_dict(params)
    optimizer = optim.SGD(exp_3_cnn_model_1.parameters(), lr=lr)
    return exp_3_cnn_model_1, optimizer

## TTT on Uncorrupted Test Set

In [10]:
batch_size = 64
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=True)

exp_3_cnn_model_1, optimizer = initialize_cnn_model_part3('exp_2_cnn_model_1.pth')
print('Uncorrupted test set: ', ttt(exp_3_cnn_model_1, test_loader, test_rotate_loader, optimizer))

Uncorrupted test set:  tensor(0.9046, device='cuda:0')


## TTT on Corrupted Test Set

In [11]:
batch_size = 64

def gauss():
    test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
    test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=True)
    test_gauss_rotate_set = TensorDataset(torch.load('test_x_rotate_gauss_noise.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_gauss_rotate_loader = DataLoader(test_gauss_rotate_set, batch_size=batch_size, shuffle=True)
    return test_gauss_loader, test_gauss_rotate_loader

def defocus():
    test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=True)
    test_defocus_rotate_set = TensorDataset(torch.load('test_x_rotate_defocus_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_defocus_rotate_loader = DataLoader(test_defocus_rotate_set, batch_size=batch_size, shuffle=True)
    return test_defocus_loader, test_defocus_rotate_loader

def elastic():
    test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
    test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=True)
    test_elastic_rotate_set = TensorDataset(torch.load('test_x_rotate_elastic_transform.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_elastic_rotate_loader = DataLoader(test_elastic_rotate_set, batch_size=batch_size, shuffle=True)
    return test_elastic_loader, test_elastic_rotate_loader

def hmotion():
    test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=True)
    test_hmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_hmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_hmotion_rotate_loader = DataLoader(test_hmotion_rotate_set, batch_size=batch_size, shuffle=True)
    return test_hmotion_loader, test_hmotion_rotate_loader

def vmotion():
    test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=True)
    test_vmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_vmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_vmotion_rotate_loader = DataLoader(test_vmotion_rotate_set, batch_size=batch_size, shuffle=True)
    return test_vmotion_loader, test_vmotion_rotate_loader

def zoom():
    test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=True)
    test_zoom_rotate_set = TensorDataset(torch.load('test_x_rotate_zoom_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_zoom_rotate_loader = DataLoader(test_zoom_rotate_set, batch_size=batch_size, shuffle=True)
    return test_zoom_loader, test_zoom_rotate_loader

def shot():
    test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
    test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=True)
    test_shot_rotate_set = TensorDataset(torch.load('test_x_rotate_shot.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_shot_rotate_loader = DataLoader(test_shot_rotate_set, batch_size=batch_size, shuffle=True)
    return test_shot_loader, test_shot_rotate_loader

def impulse():
    test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
    test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=True)
    test_impulse_rotate_set = TensorDataset(torch.load('test_x_rotate_impulse.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_impulse_rotate_loader = DataLoader(test_impulse_rotate_set, batch_size=batch_size, shuffle=True)
    return test_impulse_loader,  test_impulse_rotate_loader

def contrast():     
    test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
    test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=True)
    test_contrast_rotate_set = TensorDataset(torch.load('test_x_rotate_contrast.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_contrast_rotate_loader = DataLoader(test_contrast_rotate_set, batch_size=batch_size, shuffle=True)
    return test_contrast_loader, test_contrast_rotate_loader

test_loaders = [gauss, defocus, elastic, hmotion, vmotion, zoom, shot, impulse, contrast]
test_names = ["Gaussian noise", "Defocus blur", "Elastic transform", 'Horizontal motion blur', 'Vertical motion blur', 'Zoom blur', 'Shot', 'Impulse', 'Contrast']

for i in range(len(test_loaders)):
    exp_3_cnn_model_1, optimizer = initialize_cnn_model_part3('exp_2_cnn_model_1.pth')
    loaders = test_loaders[i]()
    print(test_names[i], ': ', ttt(exp_3_cnn_model_1, loaders[0], loaders[1], optimizer))

Gaussian noise :  tensor(0.7134, device='cuda:0')
Defocus blur :  tensor(0.8931, device='cuda:0')
Elastic transform :  tensor(0.8660, device='cuda:0')
Horizontal motion blur :  tensor(0.6478, device='cuda:0')
Vertical motion blur :  tensor(0.5719, device='cuda:0')
Zoom blur :  tensor(0.8918, device='cuda:0')
Shot :  tensor(0.8429, device='cuda:0')
Impulse :  tensor(0.5901, device='cuda:0')
Contrast :  tensor(0.5160, device='cuda:0')


# Experiment 4: ResNet18 with Auxillary Branch (TTT Online)
## Prepare Model

In [9]:
def initialize_model_part4(model_path):
    lr = 1e-3
    wd = 0
    from models import ResNetTwoBranch
    exp_4_model_1 = ResNetTwoBranch()
    params = torch.load(model_path)
    exp_4_model_1.load_state_dict(params)
    optimizer = optim.SGD(exp_4_model_1.parameters(), lr=lr)
    return exp_4_model_1, optimizer

## Online TTT on Uncorrupted Test Set

In [31]:
batch_size = 4
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=True)
exp_3_model_1, optimizer = initialize_model_part4('exp_2_model_1.pth')
print('Uncorrupted test set: ', ttt_online(exp_3_model_1, test_loader, test_rotate_loader, optimizer))

Uncorrupted test set:  tensor(0.8674, device='cuda:0')


## Online TTT on Corrupted Test Set

In [10]:
batch_size = 4

def gauss():
    test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
    test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=False)
    test_gauss_rotate_set = TensorDataset(torch.load('test_x_rotate_gauss_noise.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_gauss_rotate_loader = DataLoader(test_gauss_rotate_set, batch_size=batch_size, shuffle=False)
    return test_gauss_loader, test_gauss_rotate_loader

def defocus():
    test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=False)
    test_defocus_rotate_set = TensorDataset(torch.load('test_x_rotate_defocus_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_defocus_rotate_loader = DataLoader(test_defocus_rotate_set, batch_size=batch_size, shuffle=False)
    return test_defocus_loader, test_defocus_rotate_loader

def elastic():
    test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
    test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=False)
    test_elastic_rotate_set = TensorDataset(torch.load('test_x_rotate_elastic_transform.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_elastic_rotate_loader = DataLoader(test_elastic_rotate_set, batch_size=batch_size, shuffle=False)
    return test_elastic_loader, test_elastic_rotate_loader

def hmotion():
    test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=False)
    test_hmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_hmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_hmotion_rotate_loader = DataLoader(test_hmotion_rotate_set, batch_size=batch_size, shuffle=False)
    return test_hmotion_loader, test_hmotion_rotate_loader

def vmotion():
    test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=False)
    test_vmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_vmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_vmotion_rotate_loader = DataLoader(test_vmotion_rotate_set, batch_size=batch_size, shuffle=False)
    return test_vmotion_loader, test_vmotion_rotate_loader

def zoom():
    test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=False)
    test_zoom_rotate_set = TensorDataset(torch.load('test_x_rotate_zoom_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_zoom_rotate_loader = DataLoader(test_zoom_rotate_set, batch_size=batch_size, shuffle=False)
    return test_zoom_loader, test_zoom_rotate_loader

def shot():
    test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
    test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=False)
    test_shot_rotate_set = TensorDataset(torch.load('test_x_rotate_shot.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_shot_rotate_loader = DataLoader(test_shot_rotate_set, batch_size=batch_size, shuffle=False)
    return test_shot_loader, test_shot_rotate_loader

def impulse():
    test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
    test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=False)
    test_impulse_rotate_set = TensorDataset(torch.load('test_x_rotate_impulse.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_impulse_rotate_loader = DataLoader(test_impulse_rotate_set, batch_size=batch_size, shuffle=False)
    return test_impulse_loader,  test_impulse_rotate_loader

def contrast():     
    test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
    test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=False)
    test_contrast_rotate_set = TensorDataset(torch.load('test_x_rotate_contrast.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_contrast_rotate_loader = DataLoader(test_contrast_rotate_set, batch_size=batch_size, shuffle=False)
    return test_contrast_loader, test_contrast_rotate_loader

test_loaders = [gauss, defocus, elastic, hmotion, vmotion, zoom, shot, impulse, contrast]
test_names = ["Gaussian noise", "Defocus blur", "Elastic transform", 'Horizontal motion blur', 'Vertical motion blur', 'Zoom blur', 'Shot', 'Impulse', 'Contrast']

for i in range(len(test_loaders)):
    exp_3_model_1, optimizer = initialize_model_part4('exp_2_model_1.pth')
    loaders = test_loaders[i]()
    print(test_names[i], ': ', ttt(exp_3_model_1, loaders[0], loaders[1], optimizer))

NameError: name 'initialize_model_part3' is not defined

# Experiment 4.5: CNN with Auxillary Branch (TTT Online)
## Prepare Model

In [12]:
def initialize_cnn_model_part4(model_path):
    lr = 1e-3
    wd = 0
    from cnn_models import CNNTwoBranch
    exp_4_cnn_model_1 = CNNTwoBranch()
    params = torch.load(model_path)
    exp_4_cnn_model_1.load_state_dict(params)
    optimizer = optim.SGD(exp_4_cnn_model_1.parameters(), lr=lr)
    return exp_4_cnn_model_1, optimizer

## Online TTT on Uncorrupted Test Set

In [13]:
batch_size = 4
test_set = TensorDataset(torch.load('test_x.pt'), torch.load('test_y.pt'))
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
test_rotate_set = TensorDataset(torch.load('test_x_rotate.pt'), torch.load('test_y_rotate.pt'))
test_rotate_loader = DataLoader(test_rotate_set, batch_size=batch_size, shuffle=True)
exp_4_cnn_model_1, optimizer = initialize_cnn_model_part4('exp_2_cnn_model_1.pth')
print('Uncorrupted test set: ', ttt_online(exp_4_cnn_model_1, test_loader, test_rotate_loader, optimizer))

Uncorrupted test set:  tensor(0.8971, device='cuda:0')


## Online TTT on Corrupted Test Set

In [15]:
batch_size = 4

def gauss():
    test_gauss_set = TensorDataset(torch.load('test_x_gauss_noise.pt'), torch.load('test_y_corrupted.pt'))
    test_gauss_loader = DataLoader(test_gauss_set, batch_size=batch_size, shuffle=False)
    test_gauss_rotate_set = TensorDataset(torch.load('test_x_rotate_gauss_noise.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_gauss_rotate_loader = DataLoader(test_gauss_rotate_set, batch_size=batch_size, shuffle=False)
    return test_gauss_loader, test_gauss_rotate_loader

def defocus():
    test_defocus_set = TensorDataset(torch.load('test_x_defocus_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_defocus_loader = DataLoader(test_defocus_set, batch_size=batch_size, shuffle=False)
    test_defocus_rotate_set = TensorDataset(torch.load('test_x_rotate_defocus_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_defocus_rotate_loader = DataLoader(test_defocus_rotate_set, batch_size=batch_size, shuffle=False)
    return test_defocus_loader, test_defocus_rotate_loader

def elastic():
    test_elastic_set = TensorDataset(torch.load('test_x_elastic_transform.pt'), torch.load('test_y_corrupted.pt'))
    test_elastic_loader = DataLoader(test_elastic_set, batch_size=batch_size, shuffle=False)
    test_elastic_rotate_set = TensorDataset(torch.load('test_x_rotate_elastic_transform.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_elastic_rotate_loader = DataLoader(test_elastic_rotate_set, batch_size=batch_size, shuffle=False)
    return test_elastic_loader, test_elastic_rotate_loader

def hmotion():
    test_hmotion_set = TensorDataset(torch.load('test_x_hmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_hmotion_loader = DataLoader(test_hmotion_set, batch_size=batch_size, shuffle=False)
    test_hmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_hmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_hmotion_rotate_loader = DataLoader(test_hmotion_rotate_set, batch_size=batch_size, shuffle=False)
    return test_hmotion_loader, test_hmotion_rotate_loader

def vmotion():
    test_vmotion_set = TensorDataset(torch.load('test_x_vmotion_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_vmotion_loader = DataLoader(test_vmotion_set, batch_size=batch_size, shuffle=False)
    test_vmotion_rotate_set = TensorDataset(torch.load('test_x_rotate_vmotion_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_vmotion_rotate_loader = DataLoader(test_vmotion_rotate_set, batch_size=batch_size, shuffle=False)
    return test_vmotion_loader, test_vmotion_rotate_loader

def zoom():
    test_zoom_set = TensorDataset(torch.load('test_x_zoom_blur.pt'), torch.load('test_y_corrupted.pt'))
    test_zoom_loader = DataLoader(test_zoom_set, batch_size=batch_size, shuffle=False)
    test_zoom_rotate_set = TensorDataset(torch.load('test_x_rotate_zoom_blur.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_zoom_rotate_loader = DataLoader(test_zoom_rotate_set, batch_size=batch_size, shuffle=False)
    return test_zoom_loader, test_zoom_rotate_loader

def shot():
    test_shot_set = TensorDataset(torch.load('test_x_shot.pt'), torch.load('test_y_corrupted.pt'))
    test_shot_loader = DataLoader(test_shot_set, batch_size=batch_size, shuffle=False)
    test_shot_rotate_set = TensorDataset(torch.load('test_x_rotate_shot.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_shot_rotate_loader = DataLoader(test_shot_rotate_set, batch_size=batch_size, shuffle=False)
    return test_shot_loader, test_shot_rotate_loader

def impulse():
    test_impulse_set = TensorDataset(torch.load('test_x_impulse.pt'), torch.load('test_y_corrupted.pt'))
    test_impulse_loader = DataLoader(test_impulse_set, batch_size=batch_size, shuffle=False)
    test_impulse_rotate_set = TensorDataset(torch.load('test_x_rotate_impulse.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_impulse_rotate_loader = DataLoader(test_impulse_rotate_set, batch_size=batch_size, shuffle=False)
    return test_impulse_loader,  test_impulse_rotate_loader

def contrast():     
    test_contrast_set = TensorDataset(torch.load('test_x_contrast.pt'), torch.load('test_y_corrupted.pt'))
    test_contrast_loader = DataLoader(test_contrast_set, batch_size=batch_size, shuffle=False)
    test_contrast_rotate_set = TensorDataset(torch.load('test_x_rotate_contrast.pt'), torch.load('test_y_rotate_corrupted.pt'))
    test_contrast_rotate_loader = DataLoader(test_contrast_rotate_set, batch_size=batch_size, shuffle=False)
    return test_contrast_loader, test_contrast_rotate_loader

test_loaders = [gauss, defocus, elastic, hmotion, vmotion, zoom, shot, impulse, contrast]
test_names = ["Gaussian noise", "Defocus blur", "Elastic transform", 'Horizontal motion blur', 'Vertical motion blur', 'Zoom blur', 'Shot', 'Impulse', 'Contrast']

for i in range(len(test_loaders)):
    exp_4_cnn_model_1, optimizer = initialize_cnn_model_part4('exp_2_cnn_model_1.pth')
    loaders = test_loaders[i]()
    print(test_names[i], ': ', ttt(exp_4_cnn_model_1, loaders[0], loaders[1], optimizer))

Gaussian noise :  tensor(0.7277, device='cuda:0')
Defocus blur :  tensor(0.9004, device='cuda:0')
Elastic transform :  tensor(0.8735, device='cuda:0')
Horizontal motion blur :  tensor(0.6664, device='cuda:0')
Vertical motion blur :  tensor(0.6032, device='cuda:0')
Zoom blur :  tensor(0.8956, device='cuda:0')
Shot :  tensor(0.8541, device='cuda:0')
Impulse :  tensor(0.6165, device='cuda:0')
Contrast :  tensor(0.5667, device='cuda:0')
