In [1]:
# import these libraries for the re-implementation
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.optim.lr_scheduler as lr_scheduler

import numpy as np

from data.data import RotatedMNIST
from modules.modules import Encoder, Classifier, Discriminator


In [2]:
# check if CUDA is available
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available')
else:
    print('CUDA is available')
    device = torch.device('cuda')

CUDA is available


### Parameters

In [3]:
"""Parameters for the experiment"""
# build the network
input_dim = 28
hidden_dim = 256
d_hidden_dim = 512
dropout = 0.2

num_features = 100
num_classes = 10
domain_dim = 1

# train the network
epochs = 50
batch_size = 100
lr = 2e-4
weight_decay = 5e-4
beta1 = 0.9
beta2 = 0.999
gamma = 0.5 ** (1 / 50)
lambda_gan = 2.0
lr_scheduler = lr_scheduler

### Data Loading

In [4]:
dataset = RotatedMNIST(rotation_range=(0, 360))
train_dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=2
)
test_dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=2
)

### Model

In [5]:
encoder = Encoder(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=num_features,
    dropout=dropout,
    domain_dim=domain_dim
)

classifier = Classifier(
    input_dim=num_features,
    hidden_dim=hidden_dim,
    output_dim=num_classes
)

discriminator = Discriminator(
    input_dim=num_features,
    hidden_dim=d_hidden_dim,
    output_dim=domain_dim
)

if train_on_gpu:
    encoder = encoder.to(device)
    classifier = classifier.to(device)
    discriminator = discriminator.to(device)

#### Functions for CIDA

In [6]:
def to_np(x):
    return x.detach().cpu().numpy()

def to_tensor(x):
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x).to(device)
    else:
        x = x.to(device)
    return x

def init_weight(self, net=None):
    if net is None:
        net = self
    for m in net.modules():
        # if the layer is a Linear layer, then
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0, std=0.01) # fills the input weight with values drawn from the normal distribution
            nn.init.constant_(m.bias, val=0) # fills the input bias with the value 0

def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

def acc_reset_mnist():
    hit_domain, cnt_domain = np.zeros(8), np.zeros(8)
    acc_source, acc_target = 0, 0
    cnt_source, cnt_target = 0, 0
    hit_source, hit_target = 0, 0
    
    return hit_domain, cnt_domain, acc_source, acc_target, cnt_source, cnt_target, hit_source, hit_target

def acc_update_mnist(y, g, domain, is_source, metrics_domain, metrics_source, metrics_target):
    Y = to_np(y)
    G = to_np(g)
    T = to_np(domain)
    T = (T * 8).astype(np.int32)
    T[T >= 8] = 7
    hit = (Y == G).astype(np.float32)

    is_s = to_np(is_source)

    hit_domain, cnt_domain = metrics_domain
    acc_source, cnt_source, hit_source = metrics_source
    acc_target, cnt_target, hit_target = metrics_target

    for i in range(8):
        hit_domain[i] += hit[T == i].sum()
        cnt_domain[i] += (T == i).sum()
    acc_domain = hit_domain / (cnt_domain + 1e-10)
    acc_source, acc_target = acc_domain[0], acc_domain[1:].mean()
    acc_domain = np.round(acc_domain, decimals=3)
    acc_source = np.round(acc_source, decimals=3)
    acc_target = np.round(acc_target, decimals=3)

    cnt_source += (is_s == 1).sum()
    cnt_target += (is_s == 0).sum()

    hit_source += (hit[is_s == 1]).sum()
    hit_target += (hit[is_s == 0]).sum()

    return hit_domain, cnt_domain, acc_source, acc_target, cnt_source, cnt_target, hit_source, hit_target

### Continuously Indexed Domain Adaptation (CIDA)

In [7]:
def train(encoder, classifier, discriminator, epochs, lr_scheduler):
    init_weight(encoder)

    encoder.train()
    classifier.train()
    discriminator.train()

    criterion = nn.NLLLoss()
    d_criterion = nn.MSELoss()
    optimizer = Adam(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=lr,
        betas=(beta1, beta2),
        weight_decay=weight_decay
    )
    d_optimizer = Adam(
        discriminator.parameters(),
        lr=lr,
        betas=(beta1, beta2),
        weight_decay=weight_decay
    )
    e_lr_scheduler = lr_scheduler.ExponentialLR(
        optimizer=optimizer,
        gamma=gamma
    )
    d_lr_scheduler = lr_scheduler.ExponentialLR(
        optimizer=d_optimizer,
        gamma=gamma
    )
    lr_schedulers = [e_lr_scheduler, d_lr_scheduler]

    for epoch in range(epochs):
        accuracies = []
        d_losses = []
        e_gan_losses = []
        e_pred_losses = []
        hit_domain, cnt_domain, acc_source, acc_target, cnt_source, cnt_target, hit_source, hit_target = acc_reset_mnist()

        for batch_idx, (x, y, u, domain) in enumerate(train_dataloader):
            x, y, u, domain = x.to(device), y.to(device), u.to(device), domain.to(device)
            domain = domain[:, 0]
            is_source = (domain < 1.0 / 8).to(torch.float)

            optimizer.zero_grad()
            d_optimizer.zero_grad()

            x_align, features = encoder(x, u)
            predictions = classifier(features)

            set_requires_grad(discriminator, True)
            d = discriminator(features.detach())
            d_src = d_criterion(d[is_source == 1], u[is_source == 1])
            d_tgt = d_criterion(d[is_source == 0], u[is_source == 0])
            d_loss = (d_src + d_tgt) / 2
            d_loss.backward()
            d_optimizer.step()

            set_requires_grad(discriminator, False)
            d = discriminator(features)
            e_gan_src = d_criterion(d[is_source == 1], u[is_source == 1])
            e_gan_tgt = d_criterion(d[is_source == 0], u[is_source == 0])
            e_gan_loss = (e_gan_src + e_gan_tgt) / 2

            y_src = y[is_source == 1]
            predictions_src = predictions[is_source == 1]
            e_pred_loss = criterion(predictions_src, y_src)
            e_loss = e_gan_loss * lambda_gan + e_pred_loss
            e_loss.backward()
            optimizer.step()

            d_losses.append(d_loss.detach().item())
            e_gan_losses.append(e_gan_loss.detach().item())
            e_pred_losses.append(e_pred_loss.detach().item())

            g = torch.argmax(predictions.detach(), dim=1)
            metrics_domain = hit_domain, cnt_domain
            metrics_source = acc_source, cnt_source, hit_source
            metrics_target = acc_target, cnt_target, hit_target
            hit_domain, cnt_domain, acc_source, acc_target, cnt_source, cnt_target, hit_source, hit_target = acc_update_mnist(y, g, domain, is_source, metrics_domain, metrics_source, metrics_target)    

        # print for each epoch
        print(f'Epoch: {epoch + 1} \n \t D Loss:{torch.tensor(d_losses).mean():0.3f} \t E_gan Loss:{torch.tensor(e_gan_losses).mean():0.3f} \t E_pred Loss:{torch.tensor(e_pred_losses).mean():0.3f}')
        print(f' \t Source Acc:{acc_source:0.3f} ({hit_source}/{cnt_source}) \t Target Acc:{acc_target:0.3f} ({hit_target}/{cnt_target})')
        test(encoder, classifier, discriminator)

        for lr_scheduler in lr_schedulers:
            lr_scheduler.step()

    return encoder, classifier, discriminator

def test(encoder, classifier, discriminator):
    encoder.eval()
    classifier.eval()
    discriminator.eval()

    hit_domain, cnt_domain, acc_source, acc_target, cnt_source, cnt_target, hit_source, hit_target = acc_reset_mnist()

    for batch_idx, (x, y, u, domain) in enumerate(test_dataloader):
        x, y, u, domain = x.to(device), y.to(device), u.to(device), domain.to(device)
        domain = domain[:, 0]
        is_source = (domain < 1.0 / 8).to(torch.float)

        x_align, features = encoder(x, u)
        predictions = classifier(features)

        g = torch.argmax(predictions.detach(), dim=1)
        metrics_domain = hit_domain, cnt_domain
        metrics_source = acc_source, cnt_source, hit_source
        metrics_target = acc_target, cnt_target, hit_target
        hit_domain, cnt_domain, acc_source, acc_target, cnt_source, cnt_target, hit_source, hit_target = acc_update_mnist(y, g, domain, is_source, metrics_domain, metrics_source, metrics_target)

    print(f' \t Val Source Acc:{acc_source:0.3f} ({hit_source}/{cnt_source}) \t Val Target Acc:{acc_target:0.3f} ({hit_target}/{cnt_target})')

def generate_result_table(encoder, classifier, discriminator):
    field_names = ["Accuracy"] + ["Source"] + [f"Target #{i}" for i in range(1, 8)]

    hit = np.zeros((10, 8))
    cnt = np.zeros((10, 8))

    for batch_idx, (x, y, u, domain) in enumerate(test_dataloader):
        x, y, u, domain = x.to(device), y.to(device), u.to(device), domain.to(device)
        domain = domain[:, 0]
        is_source = (domain < 1.0 / 8).to(torch.float)

        x_align, features = encoder(x, u)
        predictions = classifier(features)

        g = torch.argmax(predictions.detach(), dim=1)

        Y = to_np(y)
        G = to_np(g)
        T = to_np(u)[:, 0]
        T = (T * 8).astype(np.int32)
        T[T >= 8] = 7

        for label, pred, domain in zip(Y, G, T):
            hit[label, domain] += int(label == pred)
            cnt[label, domain] += 1

    print('Accuracy Source Target #1 #2 #3 #4 #5 #6 #7')
    for c in range(10):
        print(f'Class {c}' + str(list(np.round(100 * hit[c] / cnt[c], decimals=1))))
    print(f'Total' + str(list(np.round(100 * hit.sum(0) / cnt.sum(0), decimals=1))))


In [8]:
encoder, classifier, discriminator = train(encoder, classifier, discriminator, epochs, lr_scheduler)

The default behavior has changed from using the upper triangular portion of the matrix by default to using the lower triangular portion.
L, _ = torch.symeig(A, upper=upper)
should be replaced with
L = torch.linalg.eigvalsh(A, UPLO='U' if upper else 'L')
and
L, V = torch.symeig(A, eigenvectors=True)
should be replaced with
L, V = torch.linalg.eigh(A, UPLO='U' if upper else 'L') (Triggered internally at  ..\aten\src\ATen\native\BatchLinearAlgebra.cpp:2499.)
  _, evs = torch.symeig(A, eigenvectors=True)
  "Default grid_sample and affine_grid behavior has changed "
  "Default grid_sample and affine_grid behavior has changed "


Epoch: 1 
 	 D Loss:0.097 	 E_gan Loss:0.083 	 E_pred Loss:0.758
 	 Source Acc:0.763 (5683.0/7450) 	 Target Acc:0.409 (21508.0/52550)
 	 Val Source Acc:0.920 (6770.0/7356) 	 Val Target Acc:0.483 (25446.0/52644)
Epoch: 2 
 	 D Loss:0.062 	 E_gan Loss:0.058 	 E_pred Loss:0.316
 	 Source Acc:0.898 (6604.0/7357) 	 Target Acc:0.446 (23483.0/52643)
 	 Val Source Acc:0.936 (6888.0/7360) 	 Val Target Acc:0.437 (22945.0/52640)
Epoch: 3 
 	 D Loss:0.045 	 E_gan Loss:0.043 	 E_pred Loss:0.198
 	 Source Acc:0.942 (6992.0/7424) 	 Target Acc:0.466 (24501.0/52576)
 	 Val Source Acc:0.953 (7052.0/7400) 	 Val Target Acc:0.512 (26885.0/52600)
Epoch: 4 
 	 D Loss:0.042 	 E_gan Loss:0.041 	 E_pred Loss:0.175
 	 Source Acc:0.948 (7102.0/7490) 	 Target Acc:0.496 (26061.0/52510)
 	 Val Source Acc:0.946 (7048.0/7453) 	 Val Target Acc:0.486 (25490.0/52547)
Epoch: 5 
 	 D Loss:0.039 	 E_gan Loss:0.038 	 E_pred Loss:0.148
 	 Source Acc:0.952 (7122.0/7483) 	 Target Acc:0.501 (26302.0/52517)
 	 Val Source Acc:0.96

In [9]:
generate_result_table(encoder, classifier, discriminator)

Accuracy Source Target #1 #2 #3 #4 #5 #6 #7
Class 0[98.8, 96.8, 92.6, 85.2, 90.9, 93.9, 92.5, 81.5]
Class 1[99.4, 99.2, 95.8, 92.2, 87.0, 71.1, 9.6, 8.0]
Class 2[96.8, 77.2, 40.3, 70.6, 95.2, 85.1, 49.9, 68.9]
Class 3[97.0, 83.1, 59.4, 55.8, 77.1, 34.2, 6.6, 8.8]
Class 4[98.9, 91.2, 46.3, 31.3, 60.3, 40.2, 19.3, 33.6]
Class 5[98.7, 87.8, 63.5, 71.2, 89.1, 52.8, 23.9, 58.2]
Class 6[98.4, 80.5, 20.6, 41.8, 73.9, 61.5, 47.5, 71.4]
Class 7[96.9, 82.8, 41.8, 12.9, 26.2, 48.7, 29.4, 30.6]
Class 8[98.1, 87.8, 64.4, 53.7, 65.3, 39.7, 6.2, 18.5]
Class 9[98.3, 91.3, 44.6, 3.3, 7.0, 8.0, 6.7, 17.8]
Total[98.1, 87.7, 57.4, 52.4, 66.7, 54.3, 29.0, 38.8]
