In [1]:
import os

# Work from the parent of the notebook's folder (presumably the repository root) so that relative
# paths such as "datasets" and, presumably, the project-local `optimal_transport` import resolve.
# Unlike `%cd ..`, anchoring on the directory recorded the first time this cell runs keeps the
# cell idempotent: re-running it does not keep climbing the directory tree.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.abspath(os.path.join(NOTEBOOK_DIR, os.pardir)))
print(os.getcwd())

f:\DS Lab\OT\ot-kpgg-fc


In [2]:
# Python ≥3.5 is required
import sys
assert sys.version_info >= (3, 5)

# Disable warnings
# NOTE(review): this silences *all* warnings (including deprecations) for the whole kernel;
# consider narrowing the filter to specific categories/modules.
import warnings
warnings.filterwarnings('ignore')

# Scikit-Learn ≥0.20 is required
import re
import sklearn
# Compare the numeric (major, minor) components: a plain string comparison is lexicographic
# and would accept e.g. "0.9" >= "0.20".
_sklearn_version = tuple(int(part) for part in re.findall(r"\d+", sklearn.__version__)[:2])
assert _sklearn_version >= (0, 20), f"scikit-learn >= 0.20 is required, found {sklearn.__version__}"

# Common imports
import json
import pandas as pd
import numpy as np
import scipy
import os
from collections import OrderedDict

# OT
import ot
from optimal_transport.models import KeypointFOT, FOT, LOT, EMD
from typing import Tuple, Optional, List, Union, Dict

# Torch imports
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.model_zoo as model_zoo
import torch.nn as nn

# To make this notebook's output stable across runs, seed every RNG the notebook may draw from.
# (The data-loading helpers below also re-seed torch themselves before building their loaders.)
import random
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)
import seaborn as sns

In [3]:
#@title Load MNIST
def mnist(
    root_dir="datasets", n_samples=1000,
    transform=transforms.Compose([transforms.ToTensor()]), seed=5
):
    """Build MNIST train/test DataLoaders.

    Args:
        root_dir: Parent directory; the data lives (and is downloaded if missing) under
            ``root_dir/mnist``.
        n_samples: Batch size of the training loader; the test loader uses ``2 * n_samples``.
        transform: Transform applied to every image (default: ``ToTensor`` only).
        seed: Passed to ``torch.manual_seed`` before the loaders are built. The training
            shuffle draws from torch's global RNG (no generator is given), so iterating the
            loader right away yields a reproducible batch.

    Returns:
        ``(train_loader, test_loader)``; the training loader shuffles, the test loader does not.
    """
    # Use the caller's seed instead of a hard-coded constant, so `seed` actually has an effect.
    torch.manual_seed(seed)

    data_dir = os.path.join(root_dir, "mnist")
    train_dataset = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=n_samples, shuffle=True)
    test_dataset = datasets.MNIST(root=data_dir, train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=2 * n_samples, shuffle=False)

    return train_loader, test_loader

mnist_train_loader, mnist_test_loader = mnist(n_samples=1000)
# Materialise one batch from each loader, in order: a shuffled training batch, then the
# leading (unshuffled) test batch.
(mnist_X_train, mnist_y_train), (mnist_X_test, mnist_y_test) = (
    next(iter(loader)) for loader in (mnist_train_loader, mnist_test_loader)
)

In [4]:
#@title Load USPS
def usps(
    root_dir="datasets", n_samples=1000,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Pad(6)]), seed=5,
):
    """Build USPS train/test DataLoaders.

    Args:
        root_dir: Parent directory; the data lives (and is downloaded if missing) under
            ``root_dir/usps``.
        n_samples: Batch size of the training loader; the test loader uses ``2 * n_samples``.
        transform: Transform applied to every image. The default pads 6 pixels on each side,
            turning the 16x16 USPS digits into 28x28 images like MNIST's.
        seed: Passed to ``torch.manual_seed`` before the loaders are built. The training
            shuffle draws from torch's global RNG (no generator is given), so iterating the
            loader right away yields a reproducible batch.

    Returns:
        ``(train_loader, test_loader)``; the training loader shuffles, the test loader does not.
    """
    # Use the caller's seed instead of a hard-coded constant, so `seed` actually has an effect.
    torch.manual_seed(seed)

    data_dir = os.path.join(root_dir, "usps")
    train_dataset = datasets.USPS(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=n_samples, shuffle=True)
    test_dataset = datasets.USPS(root=data_dir, train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=2 * n_samples, shuffle=False)

    return train_loader, test_loader

usps_train_loader, usps_test_loader = usps(n_samples=1000)
# Materialise one batch from each loader, in order: a shuffled training batch, then the
# leading (unshuffled) test batch.
(usps_X_train, usps_y_train), (usps_X_test, usps_y_test) = (
    next(iter(loader)) for loader in (usps_train_loader, usps_test_loader)
)

In [5]:
#@title Load pretrained
# Download location of the pretrained MNIST MLP weights read by `pretrain_mnist`.
# NOTE(review): fetched over plain HTTP and loaded without `check_hash`, so the file's integrity
# is not verified -- consider HTTPS or a pinned checksum.
model_urls = {
    'mnist': 'http://ml.cs.tsinghua.edu.cn/~chenxi/pytorch-models/mnist-b07bb66b.pth'
}

class MLP(nn.Module):
    """Fully connected classifier: flatten -> [Linear -> ReLU -> Dropout(0.2)] * k -> Linear.

    The sub-module names inside ``self.model`` (``fc1``, ``relu1``, ``drop1``, ..., ``out``)
    become the ``state_dict`` keys, so they must stay unchanged for pretrained weights to load.

    Args:
        input_dims: Number of features per sample after flattening (784 for 28x28 images).
        n_hiddens: Width of a single hidden layer (int) or an iterable of widths, one per layer.
        n_class: Number of output logits.
    """

    def __init__(self, input_dims, n_hiddens, n_class):
        super(MLP, self).__init__()
        assert isinstance(input_dims, int), 'Please provide int for input_dims'
        self.input_dims = input_dims

        widths = [n_hiddens] if isinstance(n_hiddens, int) else list(n_hiddens)
        layers = OrderedDict()
        current_dims = input_dims
        for i, n_hidden in enumerate(widths, start=1):
            layers[f'fc{i}'] = nn.Linear(current_dims, n_hidden)
            layers[f'relu{i}'] = nn.ReLU()
            layers[f'drop{i}'] = nn.Dropout(0.2)
            current_dims = n_hidden
        layers['out'] = nn.Linear(current_dims, n_class)

        self.model = nn.Sequential(layers)

    def forward(self, input):
        """Flatten ``input`` to ``(batch, input_dims)`` and return the class logits."""
        # reshape (unlike view) also accepts non-contiguous tensors.
        input = input.reshape(input.size(0), -1)
        assert input.size(1) == self.input_dims, (
            f'expected {self.input_dims} features per sample, got {input.size(1)}'
        )
        # Call the module instead of `.forward()` so hooks registered on `self.model` run.
        return self.model(input)

def pretrain_mnist(input_dims=784, n_hiddens=(256, 256), n_class=10, pretrained=None):
    """Build the MNIST ``MLP``, optionally loading pretrained weights.

    Args:
        input_dims: Flattened input size (784 = 28 * 28).
        n_hiddens: Hidden-layer widths, as an int or any iterable (a tuple default avoids a
            shared mutable default list).
        n_class: Number of output classes.
        pretrained: If truthy, download the checkpoint at ``model_urls['mnist']`` (mapped to
            CPU) and load it. ``None`` or ``False`` keep the randomly initialised weights.

    Returns:
        The ``MLP`` instance, still in training mode; call ``.eval()`` before inference.
    """
    model = MLP(input_dims, n_hiddens, n_class)
    if pretrained:
        # NOTE(review): load_url deserialises the downloaded file with torch.load (pickle);
        # only point it at trusted URLs.
        m = model_zoo.load_url(model_urls['mnist'], map_location=torch.device('cpu'))
        # The checkpoint may be either a whole pickled module or a bare state dict.
        state_dict = m.state_dict() if isinstance(m, nn.Module) else m
        assert isinstance(state_dict, (dict, OrderedDict)), type(state_dict)
        model.load_state_dict(state_dict)
    return model

m = pretrain_mnist(pretrained=True)
m.eval()  # inference mode: disables the model's Dropout layers

In [6]:
#@title Metrics
def accuracy(y_hat, y):
    """Return the fraction of samples whose highest-scoring column matches the true label.

    Args:
        y_hat: Score matrix with one row per sample; the arg-max is taken over axis 1.
        y: Array of true labels, one per row of ``y_hat``.
    """
    predicted = np.argmax(y_hat, axis=1)
    n_correct = np.sum(predicted == y)
    n_total = y.shape[0]
    return n_correct / n_total

In [7]:
#@title Keypoints
def query_keypoints(X, y, keypoints_per_cls=1):
    """For every class, pick the samples closest to that class's mean.

    Args:
        X: 2-D array with one row per sample (here: model logits).
        y: 1-D array of class labels, one per row of ``X``.
        keypoints_per_cls: How many samples to keep per class (fewer if a class is smaller).

    Returns:
        List of row indices into ``X``, grouped by class in ascending label order (the order
        of ``np.unique(y)``), nearest-to-centroid first within each class.
    """
    def euclidean(source, target, p=2):
        # Pairwise Minkowski-p distances, shape (len(source), len(target)).
        # The root exponent must be parenthesised: `** 1/2` parses as `(x ** 1) / 2`.
        # abs() keeps odd p correct.
        diff = source[:, None, :] - target[None, :, :]
        return np.sum(np.abs(diff) ** p, axis=-1) ** (1 / p)

    labels = np.unique(y)
    selected_inds = []
    for label in labels:
        cls_indices = np.where(y == label)[0]
        centroid = np.mean(X[cls_indices], axis=0)[None, :]
        # Take column 0 rather than squeeze(): squeeze turns a single-sample class into a 0-d
        # array, which cannot be sliced below.
        distance = euclidean(X[cls_indices], centroid)[:, 0]
        selected_inds.extend(cls_indices[np.argsort(distance)[:keypoints_per_cls]])
    return selected_inds

In [8]:
#@title Before mapping
# Accuracy of the MNIST-trained model on both domains before any transport.
evaluation_sets = [
    ("MNIST", (mnist_X_train, mnist_y_train), (mnist_X_test, mnist_y_test)),
    ("USPS", (usps_X_train, usps_y_train), (usps_X_test, usps_y_test)),
]
# Pure inference: no_grad stops autograd from recording a graph for every forward pass.
with torch.no_grad():
    for domain, (X_train, y_train), (X_test, y_test) in evaluation_sets:
        print(f"--- {domain} ---")
        print('Train accuracy:', accuracy(m(X_train).numpy(), y_train.numpy()))
        print('Test accuracy :', accuracy(m(X_test).numpy(), y_test.numpy()))

--- MNIST ---
Train accuracy: 0.993
Test accuracy : 0.9705
--- USPS ---
Train accuracy: 0.869
Test accuracy : 0.802


In [9]:
#@title Project samples into logit space
# Inference only: skip autograd bookkeeping. `.numpy()` yields the same values as
# `np.array(tensor.detach())`.
with torch.no_grad():
    mnist_train_logits = m(mnist_X_train).numpy()
    usps_test_logits = m(usps_X_test).numpy()
mnist_train_logits.shape

(1000, 10)

In [10]:
#@title Extract candidate keypoints
mnist_labels = mnist_y_train.numpy()
usps_labels = usps_y_test.numpy()
mnist_keypoints = query_keypoints(mnist_train_logits, mnist_labels)
usps_keypoints = query_keypoints(usps_test_logits, usps_labels)

# query_keypoints walks the classes in ascending label order. Position i of both lists therefore
# refers to the same class only if both domains contain exactly the same labels, so check this
# explicitly instead of silently pairing keypoints of different digits.
assert len(usps_keypoints) == len(mnist_keypoints), "domains yield different numbers of keypoints"
K = list(zip(usps_keypoints, mnist_keypoints))
assert all(usps_labels[u] == mnist_labels[v] for u, v in K), "paired keypoints belong to different classes"
K

[(1843, 201),
 (1747, 953),
 (240, 439),
 (1278, 607),
 (1464, 137),
 (351, 904),
 (1822, 314),
 (196, 334),
 (717, 581),
 (1438, 479)]

In [11]:
#@title Domain adaptation
n_anchors = 10
model = {
    "KeypointFOT": KeypointFOT(mnist_y_train, n_free_anchors=n_anchors, alpha=0.5, stop_thr=1e-5,
                               sinkhorn_reg=0.001, temperature=0.1, div_term=1e-20, max_iters=200, n_clusters = 10),
    "FOT": FOT(n_anchors=n_anchors, sinkhorn_reg=0.1),
    "LOT": LOT(None, n_source_anchors=n_anchors, n_target_anchors=n_anchors, epsilon=10, epsilon_z=10),
    "OT": EMD(),
}

exp_name = "domain_adaptation"
record_ = {exp_name: {model_id: {"accuracy": []} for model_id in model}}

n = usps_test_logits.shape[0]
n_ = mnist_train_logits.shape[0]
usps_test_labels = usps_y_test.numpy()
for model_id, ot_model in model.items():
    # Uniform marginals over the USPS source samples (a) and the MNIST target samples (b).
    # Fresh arrays per model, so no model can see another's in-place modifications.
    ot_model.fit(usps_test_logits, mnist_train_logits,
                 a=np.full(n, 1 / n), b=np.full(n_, 1 / n_), K=K)
    transported_logits = ot_model.transport(usps_test_logits, mnist_train_logits)

    # Report the score just computed (the old code read index 0 of the history).
    score = accuracy(transported_logits, usps_test_labels)
    record_[exp_name][model_id]["accuracy"].append(score)
    print(f">> [{model_id}] Acc: {score}")

Threshold reached at iteration 1
>> [KeypointFOT] Acc: 0.8255
>> [FOT] Acc: 0.762
>> [LOT] Acc: 0.811
>> [OT] Acc: 0.7655
