In [1]:
import os, sys
# Absolute path of the directory two levels above the notebook's working
# directory. Presumably this is where the project-local modules imported
# below live -- confirm against the repository layout.
parent_dir = os.path.abspath('../..')

# Register the directory only once so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if not any(entry == parent_dir for entry in sys.path):
    sys.path.append(parent_dir)

# Import

In [2]:
import time
from itertools import cycle
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.utils import make_grid
import torchvision.datasets as dset
import matplotlib.pyplot as plt

import utils_os
import mi


# Seed PyTorch from the wall clock so every run uses a different random
# stream, but keep (and print) the integer seed so a specific run can be
# replayed later by hard-coding SEED. `torch.manual_seed` is documented to
# take an int, hence the explicit conversion of the float from `time.time()`.
SEED = int(time.time())
torch.manual_seed(SEED)
print('torch seed:', SEED)

%load_ext autoreload
%autoreload

# Global Variables

In [3]:
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook still runs (slowly) instead of failing at the first `.to(DEVICE)`.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

class Hyperparams(utils_os.ConfigDict):
    """Baseline configuration from which each experiment derives its own copy.

    Every field is stored as a plain attribute; the experiment cells deep-copy
    ``BASE_HYPERPARAMS`` and then override or add attributes on the copy.
    """

    def __init__(self):
        # NOTE(review): the parent initializer is not invoked here -- confirm
        # that `utils_os.ConfigDict` does not require it.
        defaults = (
            ('device', DEVICE),
            ('learning_rate', 1e-3),
            ('n_epochs', 100),
            ('grad_clip', None),
            ('batch_size', 128),
            ('image_shape', (32, 32)),
            ('update_inner_every_num_epoch', 2),
        )
        # Assign in the listed order, one attribute at a time.
        for field, value in defaults:
            setattr(self, field, value)


# Shared baseline instance; also supplies the batch size for the DataLoaders.
BASE_HYPERPARAMS = Hyperparams()

# Filesystem layout, expressed relative to the notebook's working directory.
ROOT = '../..'
DATASET_DIR = f'{ROOT}/data'            # datasets are downloaded / read here
MODEL_SAVE_DIR = f'{ROOT}/results/c_s'  # prefix used by get_model_path

# Problem dimensions: RGB input, 9 aligned classes (labels 0..8 in the label
# maps), and 2 domains (CIFAR and STL).
NUM_COLOR_CHANNELS, NUM_CLASSES, NUM_DOMAIN = 3, 9, 2

# Sizes passed to DomainAdaptation for the content and domain representations.
DIM_Z_CONTENT = DIM_Z_DOMAIN = 128

# Number of samples per domain returned by inner_batch_provider.
INFOMIN_TRAIN_BATCH_SIZE = 2500

# Prepare Dataloaders

### Align labels

In [4]:
"""
CIFAR
0: plane
1: car
2: bird
3: cat
4: deer
5: dog
6: frog
7: horse
8: ship
9: truck

remove frog(6)
"""

"""
STL
0: plane
1: bird
2: car
3: cat
4: deer
5: dog
6: horse
7: monkey
8: ship
9: truck

remove monkey(7)
"""
# Both maps translate a dataset's native class index into one shared label
# space: 0 plane, 1 car, 2 bird, 3 cat, 4 deer, 5 dog, 6 horse, 7 ship, 8 truck.
# The class that has no counterpart in the other dataset is sent to the
# sentinel -2 so it can be filtered out; -1 is passed through unchanged
# (presumably a "no label" marker -- confirm against the dataset loaders).
CIFAR_MAP = {
    0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5,
    6: -2,                  # frog: dropped
    7: 6, 8: 7, 9: 8,       # horse, ship, truck move down by one
    -1: -1,
}

STL_MAP = {
    0: 0,
    1: 2, 2: 1,             # STL lists bird before car; swap to match CIFAR
    3: 3, 4: 4, 5: 5, 6: 6,
    7: -2,                  # monkey: dropped
    8: 7, 9: 8,             # ship, truck move down by one
    -1: -1,
}

### Load CIFAR10 and STL

In [5]:
# Create the dataset directory (and any missing parents) if needed.
# `exist_ok=True` replaces the separate existence check, so there is no window
# between "check" and "create", and a missing parent directory no longer fails.
os.makedirs(DATASET_DIR, exist_ok=True)
    
# define data transformations
# Evaluation pipeline: rescale so the shorter side is 32 px, then convert the
# image to a tensor. No randomness is involved.
trans_test = transforms.Compose([
    transforms.Resize(size=32),
    transforms.ToTensor(),
])
# Training pipeline: same rescale, followed by a random 32x32 crop taken after
# 4 px of reflect-padding and a random horizontal flip, then tensor conversion.
trans_train = transforms.Compose([
    transforms.Resize(size=32),
    transforms.RandomCrop(size=32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Label transforms: look the dataset's native class index up in the matching
# alignment table. The lookup happens at call time, one label per call.
target_trans_cifar = transforms.Compose([lambda raw_label: CIFAR_MAP[raw_label]])
target_trans_stl = transforms.Compose([lambda raw_label: STL_MAP[raw_label]])


# load cifar (remove the class 'frog')
m_train_set = dset.CIFAR10(root=DATASET_DIR, train=True, transform=trans_train, target_transform=target_trans_cifar, download=True)
m_test_set = dset.CIFAR10(root=DATASET_DIR, train=False, transform=trans_test, target_transform=target_trans_cifar, download=True)

# Select the samples to keep from the stored raw labels (`.targets`) rather
# than by iterating the dataset: iterating decodes and transforms every image
# just to read its label. -2 is the sentinel CIFAR_MAP assigns to the dropped class.
cifar_train_indices = [idx for idx, raw_label in enumerate(m_train_set.targets) if CIFAR_MAP[raw_label] != -2]
# Training batches are reshuffled on every pass; the evaluation loader below
# keeps a fixed order.
m_train_loader = torch.utils.data.DataLoader(
                 dataset=torch.utils.data.Subset(m_train_set, cifar_train_indices),
                 batch_size=BASE_HYPERPARAMS.batch_size,
                 shuffle=True)
cifar_test_indices = [idx for idx, raw_label in enumerate(m_test_set.targets) if CIFAR_MAP[raw_label] != -2]
m_test_loader = torch.utils.data.DataLoader(
                dataset=torch.utils.data.Subset(m_test_set, cifar_test_indices),
                batch_size=BASE_HYPERPARAMS.batch_size,
                shuffle=False)
domain1 = 'CIFAR'

# load stl-10 (remove the class 'monkey')
mm_train_set = dset.STL10(root=DATASET_DIR, split='train', transform=trans_train, target_transform=target_trans_stl, download=True)
mm_test_set = dset.STL10(root=DATASET_DIR, split='test', transform=trans_test, target_transform=target_trans_stl, download=True)

# Select the samples to keep from the stored raw labels (`.labels`) rather
# than by iterating the dataset: iterating decodes and transforms every image
# just to read its label. `int(...)` turns the stored array element into a
# plain int key for STL_MAP; -2 is the sentinel assigned to the dropped class.
stl_train_indices = [idx for idx, raw_label in enumerate(mm_train_set.labels) if STL_MAP[int(raw_label)] != -2]
# Training batches are reshuffled on every pass; the evaluation loader below
# keeps a fixed order.
mm_train_loader = torch.utils.data.DataLoader(
                 dataset=torch.utils.data.Subset(mm_train_set, stl_train_indices),
                 batch_size=BASE_HYPERPARAMS.batch_size,
                 shuffle=True)
stl_test_indices = [idx for idx, raw_label in enumerate(mm_test_set.labels) if STL_MAP[int(raw_label)] != -2]
mm_test_loader = torch.utils.data.DataLoader(
                dataset=torch.utils.data.Subset(mm_test_set, stl_test_indices),
                batch_size=BASE_HYPERPARAMS.batch_size,
                shuffle=False)
domain2 = 'STL'

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Load all data into memory for faster access when training infomin layers

In [6]:
def loader2tensor(dataloader):
    """Collect every batch of ``dataloader`` into two in-memory tensors.

    Args:
        dataloader: iterable yielding ``(data, target)`` tensor pairs, e.g. a
            ``torch.utils.data.DataLoader``.

    Returns:
        tuple: ``(all_data, all_target)``, the batches concatenated along
        dimension 0 in iteration order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    data_batches, target_batches = [], []
    for data, target in dataloader:
        data_batches.append(data)
        target_batches.append(target)

    # Fail with a clear message instead of an unrelated NameError further down.
    if not data_batches:
        raise ValueError('loader2tensor: dataloader yielded no batches.')

    all_data = torch.cat(data_batches, dim=0)
    all_target = torch.cat(target_batches, dim=0)
    # Report the value range of the whole tensor, not just of the final batch.
    print('loader2tensor finished.', 'data max=', all_data.max(), 'data min=', all_data.min())
    return all_data, all_target

# Pull both splits of each domain into memory once, train split first.
# `m` holds the CIFAR loaders' contents and `mm` the STL loaders' contents.
(all_data_m, all_label_m), (all_data_test_m, all_label_test_m) = (
    loader2tensor(m_train_loader),
    loader2tensor(m_test_loader),
)

(all_data_mm, all_label_mm), (all_data_test_mm, all_label_test_mm) = (
    loader2tensor(mm_train_loader),
    loader2tensor(mm_test_loader),
)

loader2tensor finished. data max= tensor(1.) data min= tensor(0.)
loader2tensor finished. data max= tensor(1.) data min= tensor(0.)
loader2tensor finished. data max= tensor(1.) data min= tensor(0.)
loader2tensor finished. data max= tensor(1.) data min= tensor(0.)


# Experiments

In [7]:
from tasks.domain_adaptation import DomainAdaptation, train, test
from routine import exp_run


def inner_batch_provider():
    """Sample one random batch from each domain's in-memory training tensor.

    Returns:
        tuple: ``(infomin_x1, infomin_x2)``, holding at most
        ``INFOMIN_TRAIN_BATCH_SIZE`` rows drawn without replacement from
        ``all_data_m`` and ``all_data_mm`` respectively, moved to ``DEVICE``.
    """
    # One permutation per domain; the first domain is permuted first.
    perm_m = torch.randperm(len(all_data_m))
    perm_mm = torch.randperm(len(all_data_mm))

    # The leading entries of a random permutation form a sample without replacement.
    pick_m = perm_m[:INFOMIN_TRAIN_BATCH_SIZE]
    pick_mm = perm_mm[:INFOMIN_TRAIN_BATCH_SIZE]

    infomin_x1 = all_data_m[pick_m].to(DEVICE)
    infomin_x2 = all_data_mm[pick_mm].to(DEVICE)
    return infomin_x1, infomin_x2


def get_model_path(hyperparams, epoch=None):
    """Build the path for a model file under ``MODEL_SAVE_DIR``.

    Args:
        hyperparams: object exposing ``estimator`` and
            ``get_name("alpha", "beta", "gamma")``.
        epoch: optional epoch tag; when not ``None`` (0 included) it is
            appended to the path as ``_<epoch>``.

    Returns:
        str: ``<MODEL_SAVE_DIR>/<estimator>/model_<name>`` with the optional
        ``_<epoch>`` suffix.
    """
    base_path = f'{MODEL_SAVE_DIR}/{hyperparams.estimator}/model_{hyperparams.get_name("alpha", "beta", "gamma")}'
    return base_path if epoch is None else f'{base_path}_{epoch}'

In [9]:
import copy

# SLICE-estimator experiment: start from the shared baseline and override or
# add the fields specific to this run.
hyperparams = copy.deepcopy(BASE_HYPERPARAMS)
hyperparams.alpha = 1.0
hyperparams.beta = 1.0
hyperparams.gamma = 0.3
hyperparams.n_slice = 100
hyperparams.learning_rate = 1e-3
# Note: the DataLoaders were already built with BASE_HYPERPARAMS.batch_size,
# so this assignment does not change the batches they yield.
hyperparams.batch_size = 128
hyperparams.estimator = 'SLICE'
hyperparams.n_epochs = 70

da = DomainAdaptation(DIM_Z_CONTENT, DIM_Z_DOMAIN, NUM_CLASSES, NUM_DOMAIN, hyperparams).to(DEVICE)

da = exp_run(
    (m_train_loader, mm_train_loader), (m_test_loader, mm_test_loader),
    train, test,
    inner_batch_provider, get_model_path,
    hyperparams,
    # Pass the same DEVICE the model and the infomin batches were moved to,
    # instead of a second hard-coded device string that could drift from it.
    device=DEVICE,
    model=da,
    # scheduler_func=lambda optimizer: torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=20)
    # Multiply the learning rate by 0.1 once the scheduler reaches milestone 50.
    scheduler_func=lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)
)

epoch: 0, loss: 2.1747948821161835, loss_content: 1.4733088839222008, loss_domain: 0.6865523583452466 redundancy: 0.049778826521630855  acc_d1: 51.56666564941406 acc_d2: 45.00667190551758 acc_domain: 55.43019104003906

epoch: 1, loss: 1.450339114162284, loss_content: 0.766580473369276, loss_domain: 0.669143989052571 redundancy: 0.048715557142252654  acc_d1: 72.96666717529297 acc_d2: 61.97731399536133 acc_domain: 58.78723907470703

epoch: 2, loss: 0.9842295470372052, loss_content: 0.5789589516713586, loss_domain: 0.38811774321005377 redundancy: 0.05717618944792365  acc_d1: 79.47777557373047 acc_d2: 68.1272201538086 acc_domain: 82.23099517822266

epoch: 3, loss: 0.8458467277003007, loss_content: 0.6221673480221923, loss_domain: 0.20840076871321234 redundancy: 0.050928688732723534  acc_d1: 78.97777557373047 acc_d2: 63.80115509033203 acc_domain: 91.92974853515625

epoch: 4, loss: 0.7643178852511124, loss_content: 0.576227991094052, loss_domain: 0.17206421945716294 redundancy: 0.05341890115

In [10]:
import copy

# CLUB-estimator experiment: start from the shared baseline, then apply the
# run-specific settings one attribute at a time, in the order listed.
hyperparams = copy.deepcopy(BASE_HYPERPARAMS)
_CLUB_SETTINGS = (
    ('alpha', 1.0),
    ('beta', 1.0),
    ('gamma', 0.005),
    ('estimator', 'CLUB'),
    ('learning_rate', 1e-3),
    ('inner_lr', 1e-3),
    ('inner_epochs', 5),
    # NOTE(review): same value as INFOMIN_TRAIN_BATCH_SIZE, which sizes the
    # batches from inner_batch_provider -- confirm whether the two should be tied.
    ('inner_batch_size', 2500),
)
for _field, _value in _CLUB_SETTINGS:
    setattr(hyperparams, _field, _value)

da = DomainAdaptation(DIM_Z_CONTENT, DIM_Z_DOMAIN, NUM_CLASSES, NUM_DOMAIN, hyperparams).to(DEVICE)

da = exp_run(
    (m_train_loader, mm_train_loader),
    (m_test_loader, mm_test_loader),
    train,
    test,
    inner_batch_provider,
    get_model_path,
    hyperparams,
    device=DEVICE,
    model=da,
    # Multiply the learning rate by 0.1 once the scheduler reaches milestone 50.
    scheduler_func=lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1),
)

epoch: 0, loss: 1.0979798175919224, loss_content: 1.4901455523262561, loss_domain: 0.4815712566946594 redundancy: -174.74740364182162  acc_d1: 42.52222442626953 acc_d2: 33.3073844909668 acc_domain: 86.82747650146484

epoch: 1, loss: 1.8145598727212826, loss_content: 0.9155045042575245, loss_domain: 1.0532486249863262 redundancy: -30.838649588571467  acc_d1: 67.45555877685547 acc_d2: 56.516902923583984 acc_domain: 76.0060043334961

epoch: 2, loss: 2.579122953011956, loss_content: 1.0795885198552844, loss_domain: 1.4820463204048049 redundancy: 3.497622073536188  acc_d1: 63.57777786254883 acc_d2: 52.09074783325195 acc_domain: 52.97354507446289

epoch: 3, loss: 1.1834586699244003, loss_content: 0.6738110386149984, loss_domain: 0.5278251477530305 redundancy: -3.6355011899706344  acc_d1: 76.76667022705078 acc_d2: 64.64635467529297 acc_domain: 81.81969451904297

epoch: 4, loss: 1.6870660143838803, loss_content: 0.606261090073787, loss_domain: 1.0945426485907863 redundancy: -2.7475456116904673

KeyboardInterrupt: 