In [1]:
from model import TMapper, PhiMapper, OmegaMapper, UnbalancedLoss, UnbalancedSampler
import torch
from collections import Counter
from tqdm import tqdm_notebook
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
from model import Hellinger_dual, Jensen_Shannon_dual, KL_dual, Pearson_xi_dual

In [3]:
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader

In [4]:
# Each file is expected to hold a serialized sequence of tensors that is
# unpacked straight into a TensorDataset (later cells read index 0 as the
# samples and index 1 as the labels).
dataset1, dataset2 = (
    torch.utils.data.TensorDataset(*torch.load(path))
    for path in ('./unbalanced_mnist_1', './unbalanced_mnist_2')
)

In [5]:
# One loader per dataset, both yielding batches of 800 samples; no shuffle
# argument is passed, so batches come out in stored order on every epoch.
dl_1, dl_2 = (
    DataLoader(dataset, batch_size=800)
    for dataset in (dataset1, dataset2)
)

In [6]:
# Per-label sample counts of the first dataset (tensor 1 holds the labels).
label_counts_1 = Counter(dataset1.tensors[1].tolist())
label_counts_1

Counter({3: 3994,
         0: 4007,
         5: 1983,
         1: 3918,
         8: 1965,
         6: 2033,
         4: 4060,
         2: 4051,
         9: 1961,
         7: 2028})

In [7]:
# Per-label sample counts of the second dataset (tensor 1 holds the labels).
label_counts_2 = Counter(dataset2.tensors[1].tolist())
label_counts_2

Counter({5: 4077,
         7: 3969,
         6: 3895,
         9: 4080,
         8: 3941,
         2: 2045,
         1: 1966,
         3: 2007,
         4: 2010,
         0: 2010})

In [8]:
def cost_matrix(x:torch.Tensor, y:torch.Tensor) -> torch.Tensor:
    """Pairwise squared Euclidean distances between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape ``(n, d)``.
        y: tensor of shape ``(m, d)``.

    Returns:
        Tensor of shape ``(n, m)`` whose entry ``[i, j]`` is
        ``sum_k (x[i, k] - y[j, k]) ** 2``.
    """
    # Broadcasting x[:, None] (n, 1, d) against y (m, d) gives every pairwise
    # difference as an (n, m, d) tensor.
    diff = x[:, None] - y
    # Sum the squares directly instead of squaring a norm: this skips the
    # sqrt -> square round trip, which costs precision for no benefit.
    return (diff * diff).sum(dim=2)

def mass_variation(s):
    """Quadratic penalty on how far ``s`` is from 1: returns ``(s - 1) ** 2``.

    Works element-wise on anything supporting ``-`` and ``**`` (Python
    numbers, tensors, arrays).
    """
    deviation = s - 1
    return deviation ** 2

In [9]:
# Use the GPU only when one is actually present, so the notebook also runs on
# CPU-only machines. Set to False by hand to force CPU execution.
cuda = torch.cuda.is_available()

In [10]:
# 64 = 8 x 8: the samples are viewed as 8x8 images when plotted further down.
in_dim = 64
out_dim = 64

# Build the three networks once (same construction order as before: T, Xi, f)
# and move them to the GPU afterwards when requested.
T = TMapper(in_dim, out_dim, hidden_dims=128).float()
Xi = PhiMapper(in_dim, 1, hidden_dims=128).float()
f = OmegaMapper(out_dim, 1, hidden_dims=128).float()
if cuda:
    T, Xi, f = T.cuda(), Xi.cuda(), f.cuda()

# NOTE(review): 30000 equals the sample count of each dataset (see the label
# histograms above), so 1/30000 looks like a uniform per-sample weight --
# confirm against UnbalancedLoss.
loss = UnbalancedLoss(1./30000, 1./30000, cost_matrix, mass_variation, Pearson_xi_dual)

# One Adam optimizer per network, all with the same learning rate.
t_optim = torch.optim.Adam(T.parameters(), lr=1e-4)
xi_optim = torch.optim.Adam(Xi.parameters(), lr=1e-4)
w_optim = torch.optim.Adam(f.parameters(), lr=1e-4)

In [11]:
def train_loop(data_loader_1: torch.utils.data.DataLoader,
               data_loader_2: torch.utils.data.DataLoader,
               cuda=cuda,
               n_epochs: int = 1000):
    """Alternating optimisation of ``f`` against ``T`` / ``Xi`` on paired batches.

    Relies on the notebook-level objects ``T``, ``Xi``, ``f``, ``loss``,
    ``w_optim``, ``t_optim`` and ``xi_optim``.

    For every pair of batches (the loaders are zipped, so an epoch stops when
    the shorter loader is exhausted) two updates are performed:

    1. ``w_optim`` (parameters of ``f``) takes a step on ``-loss``;
    2. ``f`` is re-evaluated with its updated parameters, then ``t_optim`` and
       ``xi_optim`` (parameters of ``T`` and ``Xi``) take a step on ``+loss``.

    Args:
        data_loader_1: yields ``(X, labels)``; only ``X`` is used.
        data_loader_2: yields ``(Y, Z)``; ``Z`` is passed to ``T`` together
            with ``X``, so both loaders must yield equally sized batches.
            NOTE(review): ``Z`` comes from the second loader while ``X`` comes
            from the first -- confirm this pairing is intended.
        cuda: move each batch to the GPU when true. The default is captured
            from the global ``cuda`` at definition time.
        n_epochs: number of passes over the zipped loaders (default 1000).

    Returns:
        ``(loss_value_1, loss_value_2)``: per-batch lists holding the value of
        ``-loss`` at the ``f`` step and of ``loss`` at the ``T`` / ``Xi`` step.
    """
    loss_value_1 = []
    loss_value_2 = []
    for i in tqdm_notebook(range(n_epochs)):
        # Ask the loss for its detailed print-out only every tenth epoch; on
        # other epochs the keyword is omitted so the loss's own default applies.
        verbose_kwargs = {'verbose': True} if i % 10 == 0 else {}
        for (X, _), (Y, Z) in zip(data_loader_1, data_loader_2):
            X, Y, Z = X.float(), Y.float(), Z.float()
            if cuda:
                X, Y, Z = X.cuda(), Y.cuda(), Z.cuda()

            w_optim.zero_grad()
            t_optim.zero_grad()
            xi_optim.zero_grad()

            T_output = T(X, Z.view(-1, 1))
            Xi_output = Xi(X)
            f_y = f(Y)
            f_T = f(T_output)

            # --- step 1: update f on the negated loss ---
            loss_value = -loss.compute(X, Z, Y, T_output, Xi_output, f_y, f_T,
                                       **verbose_kwargs)
            loss_value_1.append(loss_value.item())
            # The graph behind T_output / Xi_output is reused in step 2, so it
            # must survive this backward pass.
            loss_value.backward(retain_graph=True)
            w_optim.step()

            # Discard the gradients step 1 left on T and Xi.
            t_optim.zero_grad()
            xi_optim.zero_grad()

            # --- step 2: update T and Xi against the freshly updated f ---
            f_y = f(Y)
            f_T = f(T_output)

            loss_value = loss.compute(X, Z, Y, T_output, Xi_output, f_y, f_T)
            loss_value_2.append(loss_value.item())
            loss_value.backward()
            t_optim.step()
            xi_optim.step()
        if i % 10 == 0:
            print('epoch {}'.format(i))
            # Nothing to report if the loaders produced no batch at all.
            if loss_value_1:
                print(loss_value_1[-1])
                print(loss_value_2[-1])
    return loss_value_1, loss_value_2

In [None]:
# Run the full training; lv1 / lv2 are the per-batch loss histories of the
# two alternating updates.
lv1, lv2 = train_loop(data_loader_1=dl_1, data_loader_2=dl_2)

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

Mass variation output 0.09830570220947266
Distance function output 3.745424509048462
Dual function output -0.04106717184185982
Terms:
First 8.576004620408639e-05
Second 3.2768566597951576e-06
Third -1.204611635330366e-06
Fourth 1.3689056004295708e-06
Mass variation output 0.09869853407144547
Distance function output 3.6780593395233154
Dual function output -0.04248746857047081
Terms:
First 8.414517651544884e-05
Second 3.2899511097639333e-06
Third -1.2174718904134352e-06
Fourth 1.4162490060698474e-06
Mass variation output 0.09922311455011368
Distance function output 3.7383015155792236
Dual function output -0.043684642761945724
Terms:
First 8.54107056511566e-05
Second 3.307437282273895e-06
Third -1.2315969115661574e-06
Fourth 1.4561547914127004e-06
Mass variation output 0.0998481810092926
Distance function output 3.6870553493499756
Dual function output -0.04527638480067253
Terms:
First 8.410306327277794e-05
Second 3.3282726690231357e-06
Third -1.246581405212055e-06
Fourth 1.50921289332472

Mass variation output 0.12411953508853912
Distance function output 3.803638219833374
Dual function output -0.11749126017093658
Terms:
First 8.185083424905315e-05
Second 4.137317773711402e-06
Third -1.886387622107577e-06
Fourth 3.916375135304406e-06
Mass variation output 0.12488511949777603
Distance function output 3.681556224822998
Dual function output -0.1195843443274498
Terms:
First 7.906479731900617e-05
Second 4.1628372855484486e-06
Third -1.9100302779406775e-06
Fourth 3.986144747614162e-06
Mass variation output 0.12635073065757751
Distance function output 3.9178667068481445
Dual function output -0.12284586578607559
Terms:
First 8.386422268813476e-05
Second 4.211690793454181e-06
Third -1.930765165525372e-06
Fourth 4.094862106285291e-06
epoch 0
-9.024000610224903e-05
9.030706860357895e-05


In [None]:
# First ten dataset1 samples (left column) next to their images under T
# (right column), each shown as an 8x8 picture.
fig, axes = plt.subplots(nrows=10, ncols=2)
samples, labels = dataset1[:10]
# Run T on whichever device its parameters live on instead of assuming CUDA.
device = next(T.parameters()).device
with torch.no_grad():
    # Same two-argument call form (and float cast) as in train_loop:
    # T(X, Z.view(-1, 1)).
    transformed = T(samples.float().to(device),
                    labels.float().view(-1, 1).to(device)).cpu()
for i, ax in enumerate(axes):
    ax[0].imshow(dataset1[i][0].view(8, 8))
    ax[1].imshow(transformed[i].view(8, 8))

In [None]:
# Turn both loss histories into NumPy arrays.
lv1, lv2 = (np.array(history) for history in (lv1, lv2))

In [None]:
# One point per batch update: lv1 is -loss recorded at the f step, lv2 is
# +loss recorded at the T / Xi step. Label them so the curves can be told apart.
plt.plot(lv1, label='f step (-loss)')
plt.plot(lv2, label='T / Xi step (loss)')
plt.xlabel('batch update')
plt.ylabel('loss value')
plt.legend();