In [1]:
import os, sys
# Make the repository root (two directories up) importable so the project-local
# packages used below (utils_os, mi, datasets, tasks, train_test_routine) resolve.
parent_dir = os.path.abspath('../..')

# Prepend rather than append: an installed third-party package with the same
# name (e.g. Hugging Face `datasets`) would otherwise shadow `datasets.mnist_m`.
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Imports

In [2]:
import time
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F
from torchvision.utils import make_grid
import torchvision.datasets as dset
import matplotlib.pyplot as plt

import utils_os
import mi
from datasets.mnist_m import MNISTM


# Seed from the wall clock so every run differs, but keep the integer seed and
# print it so a particular run can be reproduced by hard-coding SEED.
SEED = int(time.time())
torch.manual_seed(SEED)
print('torch seed:', SEED)

# Automatically reload edited modules before executing each cell, so changes to
# the project's .py files (utils_os, mi, tasks, ...) apply without a kernel restart.
%load_ext autoreload
%autoreload 2

# Global Variables

In [3]:
# Use the first GPU when available; fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing on the first .to(DEVICE).
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

class Hyperparams(utils_os.ConfigDict):
    """Base hyperparameters shared by every experiment in this notebook.

    Experiment cells deep-copy ``BASE_HYPERPARAMS`` and then set extra fields
    (``alpha``, ``beta``, ``gamma``, ``estimator``, ...) on the copy.

    NOTE(review): ``super().__init__()`` is never called; confirm that
    ``utils_os.ConfigDict`` needs no initialisation of its own.
    """

    def __init__(self): 
        self.device = DEVICE  # copy of the global DEVICE chosen in this cell
        self.learning_rate = 1e-3  # presumably the main optimizer's LR; the CLUB experiment overrides it to 1e-4
        self.n_epochs = 100  # number of training epochs; experiment cells set an LR milestone at 50
        self.grad_clip = None  # presumably None disables gradient clipping — confirm in exp_run
        self.batch_size = 128  # batch size used by all four DataLoaders
        self.image_shape = (32, 32)  # consistent with the Resize(32) applied to MNIST / MNIST-M
        self.update_inner_every_num_epoch = 2  # presumably how often (in epochs) the infomin estimator is updated — confirm in exp_run

# Template configuration; each experiment cell deep-copies it before adding its own fields.
BASE_HYPERPARAMS = Hyperparams()

# Filesystem layout: ROOT points two directories up from the notebook's working dir.
ROOT = '../..'
DATASET_DIR, MODEL_SAVE_DIR = f'{ROOT}/data', f'{ROOT}/results/m_mm'

# Problem dimensions: RGB images, ten digit classes, two domains (MNIST / MNIST-M).
NUM_COLOR_CHANNELS, NUM_CLASSES, NUM_DOMAIN = 3, 10, 2

# Sizes of the content and domain latent codes.
DIM_Z_CONTENT = DIM_Z_DOMAIN = 128

# Prepare Dataloaders

### Load MNIST

In [4]:
# Create the dataset directory (and any missing parents); a no-op on re-runs.
os.makedirs(DATASET_DIR, exist_ok=True)

# MNIST is grayscale: tile its single channel NUM_COLOR_CHANNELS times so the
# batches have the same shape as the RGB MNIST-M batches.
trans = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    lambda x: x.repeat(NUM_COLOR_CHANNELS, 1, 1),
])
m_train_set = dset.MNIST(root=DATASET_DIR, train=True, transform=trans, download=True)
m_test_set = dset.MNIST(root=DATASET_DIR, train=False, transform=trans, download=True)

# Reshuffle the training set every epoch (DataLoader defaults to shuffle=False);
# keep the evaluation order fixed.
m_train_loader = torch.utils.data.DataLoader(
    dataset=m_train_set,
    batch_size=BASE_HYPERPARAMS.batch_size,
    shuffle=True)
m_test_loader = torch.utils.data.DataLoader(
    dataset=m_test_set,
    batch_size=BASE_HYPERPARAMS.batch_size,
    shuffle=False)

DOMAIN1 = 'MNIST'

### Load MNISTM

In [5]:
# MNIST-M needs no channel expansion: Resize(32) scales the shorter side to 32 px
# and ToTensor converts to float tensors in [0, 1].
trans = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])
mm_train_set = MNISTM(root=DATASET_DIR, train=True, transform=trans, download=True)
mm_test_set = MNISTM(root=DATASET_DIR, train=False, transform=trans, download=True)

# Reshuffle the training set every epoch (DataLoader defaults to shuffle=False);
# keep the evaluation order fixed.
mm_train_loader = torch.utils.data.DataLoader(
    dataset=mm_train_set,
    batch_size=BASE_HYPERPARAMS.batch_size,
    shuffle=True)
mm_test_loader = torch.utils.data.DataLoader(
    dataset=mm_test_set,
    batch_size=BASE_HYPERPARAMS.batch_size,
    shuffle=False)

DOMAIN2 = 'MNISTM'

../../data/MNISTM/processed/mnist_m_train.pt
../../data/MNISTM/processed/mnist_m_test.pt


### Load all data into memory for fast random access when training the infomin layers

In [6]:
def loader2tensor(dataloader):
    """Materialise every batch yielded by ``dataloader`` into two tensors.

    Args:
        dataloader: iterable yielding ``(data, target)`` batch pairs, e.g. a
            ``torch.utils.data.DataLoader``.

    Returns:
        ``(all_data, all_target)``: the batches concatenated along dim 0.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    all_data, all_target = [], []
    for data, target in dataloader:
        all_data.append(data)
        all_target.append(target)
    if not all_data:
        # torch.cat([]) would fail with an opaque error; report the real cause.
        raise ValueError('loader2tensor: dataloader yielded no batches.')
    data_tensor = torch.cat(all_data, dim=0)
    target_tensor = torch.cat(all_target, dim=0)
    # Sanity-check the value range over the whole dataset, not just one batch.
    print('loader2tensor finished.', 'data max=', data_tensor.max(), 'data min=', data_tensor.min())
    return data_tensor, target_tensor

# Cache every split as (images, labels) tensors: MNIST train/test, then MNIST-M train/test.
(all_data_m, all_label_m), (all_data_test_m, all_label_test_m) = (
    loader2tensor(m_train_loader),
    loader2tensor(m_test_loader),
)
(all_data_mm, all_label_mm), (all_data_test_mm, all_label_test_mm) = (
    loader2tensor(mm_train_loader),
    loader2tensor(mm_test_loader),
)

# Samples drawn per domain for each infomin training batch (see inner_batch_provider).
INFOMIN_TRAIN_BATCH_SIZE = 2500

loader2tensor finished. data max= tensor(1.) data min= tensor(0.)
loader2tensor finished. data max= tensor(1.) data min= tensor(0.)
loader2tensor finished. data max= tensor(1.) data min= tensor(0.)
loader2tensor finished. data max= tensor(1.) data min= tensor(0.)


# Experiments

In [7]:
from tasks.domain_adaptation.experiment import DomainAdaptation, train, test
from train_test_routine import exp_run


def inner_batch_provider(batch_size=None, device=None):
    """Draw one random batch per domain for training the infomin layers.

    Samples are drawn without replacement, independently from the in-memory
    MNIST (``all_data_m``) and MNIST-M (``all_data_mm``) training tensors.

    Args:
        batch_size: samples per domain; defaults to ``INFOMIN_TRAIN_BATCH_SIZE``.
            If larger than a domain's size, that whole domain is returned.
        device: device the batches are moved to; defaults to ``DEVICE``.

    Returns:
        ``(infomin_x1, infomin_x2)``: the MNIST batch and the MNIST-M batch.
    """
    if batch_size is None:
        batch_size = INFOMIN_TRAIN_BATCH_SIZE
    if device is None:
        device = DEVICE
    # Index on the CPU tensors first so only the selected samples are transferred.
    idx1 = torch.randperm(len(all_data_m))[:batch_size]
    idx2 = torch.randperm(len(all_data_mm))[:batch_size]
    infomin_x1 = all_data_m[idx1].to(device)
    infomin_x2 = all_data_mm[idx2].to(device)
    return infomin_x1, infomin_x2


def get_model_path(hyperparams, epoch=None):
    """Return the checkpoint path for an experiment configuration.

    The path is ``MODEL_SAVE_DIR/<estimator>/model_<name>``, where ``<name>`` is
    ``hyperparams.get_name("alpha", "beta", "gamma")``. When ``epoch`` is given
    (including 0), ``_<epoch>`` is appended.
    """
    folder = f'{MODEL_SAVE_DIR}/{hyperparams.estimator}'
    stem = hyperparams.get_name("alpha", "beta", "gamma")
    path = f'{folder}/model_{stem}'
    return path if epoch is None else f'{path}_{epoch}'

In [8]:
import copy

# Experiment 1: SLICE estimator.
# Start from a deep copy so BASE_HYPERPARAMS itself is never mutated.
hyperparams = copy.deepcopy(BASE_HYPERPARAMS)
# NOTE(review): alpha/beta/gamma are presumably the weights of the content, domain and
# redundancy loss terms — confirm in tasks.domain_adaptation.experiment. They also
# form the checkpoint file name via get_model_path.
hyperparams.alpha = 1.0
hyperparams.beta = 1.0
hyperparams.gamma = 0.1
hyperparams.n_slice = 100  # presumably the number of random slices used by the SLICE estimator — confirm
hyperparams.estimator = 'SLICE'  # also the checkpoint sub-directory used by get_model_path

da = DomainAdaptation(DIM_Z_CONTENT, DIM_Z_DOMAIN, NUM_CLASSES, NUM_DOMAIN, hyperparams).to(DEVICE)

# Train on both domains, evaluate on both test sets; exp_run's return value replaces
# `da` (presumably the trained model).
da = exp_run(
    (m_train_loader, mm_train_loader), (m_test_loader, mm_test_loader),
    train, test,
    inner_batch_provider, get_model_path,
    hyperparams,
    device=DEVICE,
    model=da,
    # Multiplies the LR by 0.1 at milestone 50 (the scheduler's gamma, unrelated to
    # hyperparams.gamma); assumes exp_run steps the scheduler once per epoch.
    scheduler_func=lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)
)

epoch: 0, loss: 0.06311824537012019, loss_content: 0.05941274517828811, loss_domain: 0.0002561857632090208 redundancy: 0.034493144748944646  acc_d1: 98.18000030517578 acc_d2: 78.33999633789062 acc_domain: 100.0

epoch: 1, loss: 0.049141075827558583, loss_content: 0.03919112134895794, loss_domain: 0.0005385672268892426 redundancy: 0.09411387199748165  acc_d1: 98.79000091552734 acc_d2: 79.8499984741211 acc_domain: 99.97000122070312

epoch: 2, loss: 0.05082822209106216, loss_content: 0.03384992097341603, loss_domain: 0.0001535694965598309 redundancy: 0.16824730721455586  acc_d1: 98.97000122070312 acc_d2: 79.11000061035156 acc_domain: 99.9949951171875

epoch: 3, loss: 0.060243162285253594, loss_content: 0.04139788303622988, loss_domain: 4.3696490794462976e-05 redundancy: 0.18801582595215569  acc_d1: 98.69999694824219 acc_d2: 80.18999481201172 acc_domain: 100.0

epoch: 4, loss: 0.041076019170540795, loss_content: 0.025978112773534318, loss_domain: 9.771604709385429e-06 redundancy: 0.1508813

In [10]:
import copy

# Experiment 2: CLUB estimator.
# Start from a deep copy so BASE_HYPERPARAMS itself is never mutated.
hyperparams = copy.deepcopy(BASE_HYPERPARAMS)
# NOTE(review): alpha/beta/gamma are presumably the weights of the content, domain and
# redundancy loss terms — confirm in tasks.domain_adaptation.experiment. They also
# form the checkpoint file name via get_model_path.
hyperparams.alpha = 1.0
hyperparams.beta = 1.0
hyperparams.gamma = 0.01
hyperparams.estimator = 'CLUB'  # also the checkpoint sub-directory used by get_model_path
hyperparams.learning_rate = 1e-4  # overrides the base value of 1e-3
# Presumably the settings of the CLUB estimator's inner optimisation — confirm in exp_run.
hyperparams.inner_lr = 5e-4
hyperparams.inner_epochs = 5
# NOTE(review): inner_batch_provider draws INFOMIN_TRAIN_BATCH_SIZE samples per domain
# by default, independently of this value — keep the two in sync.
hyperparams.inner_batch_size = 2500

da = DomainAdaptation(DIM_Z_CONTENT, DIM_Z_DOMAIN, NUM_CLASSES, NUM_DOMAIN, hyperparams).to(DEVICE)

# Train on both domains, evaluate on both test sets; exp_run's return value replaces
# `da` (presumably the trained model).
da = exp_run(
    (m_train_loader, mm_train_loader), (m_test_loader, mm_test_loader),
    train, test,
    inner_batch_provider, get_model_path,
    hyperparams,
    device=DEVICE,
    model=da,
    # Multiplies the LR by 0.1 at milestone 50 (the scheduler's gamma, unrelated to
    # hyperparams.gamma); assumes exp_run steps the scheduler once per epoch.
    scheduler_func=lambda optimizer: torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50], gamma=0.1)
)

epoch: 0, loss: 0.11022160786993895, loss_content: 0.1033524866628496, loss_domain: 0.0012245605231109488 redundancy: 0.5644560299719437  acc_d1: 98.8499984741211 acc_d2: 53.13999938964844 acc_domain: 99.9949951171875

epoch: 1, loss: 0.07882452143144004, loss_content: 0.07178078815812551, loss_domain: 0.0004483862913396398 redundancy: 0.6595347369018989  acc_d1: 98.86000061035156 acc_d2: 62.1099967956543 acc_domain: 100.0

epoch: 2, loss: 0.04712345303755395, loss_content: 0.045065468318666084, loss_domain: 0.0002311480318526207 redundancy: 0.1826836957207209  acc_d1: 98.95999908447266 acc_d2: 65.20999908447266 acc_domain: 100.0

epoch: 3, loss: 0.05443184275793124, loss_content: 0.04893891337693115, loss_domain: 0.00013804790842762406 redundancy: 0.5354881520512738  acc_d1: 98.88999938964844 acc_d2: 63.82999801635742 acc_domain: 100.0

epoch: 4, loss: 0.029591629063545524, loss_content: 0.02612104630864025, loss_domain: 9.145181905330878e-05 redundancy: 0.3379131183971333  acc_d1: 99