In [1]:
from model import TMapper, PhiMapper, OmegaMapper, UnbalancedLoss
import torch
from collections import Counter

In [2]:
from model import Hellinger, Jensen_Shannon, KL_dual, Pearson_xi

In [3]:
import torch.nn.functional as F
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import Dataset, DataLoader

In [50]:
def cost_matrix(x:torch.Tensor, y:torch.Tensor) -> torch.Tensor:
    return torch.norm(x[:, None] - y, dim=2, p=2)

def mass_variation(s):
    return (s - 1)**2

In [51]:
T = TMapper(784, 784)
Xi = PhiMapper(784, 1, hidden_dims=1024)
f = OmegaMapper(784, 1)
loss = UnbalancedLoss(10, 5, cost_matrix, mass_variation, KL_dual)
w_optim = torch.optim.SGD(f.parameters(), lr=1e-2)
t_optim = torch.optim.SGD(T.parameters(), lr=1e-2)
xi_optim = torch.optim.SGD(Xi.parameters(), lr=1e-2)

In [15]:
sampler1 = torch.utils.data.SubsetRandomSampler(torch.arange(30000))
sampler2 = torch.utils.data.SubsetRandomSampler(torch.arange(30000, 60000))

In [16]:
def target_transform(x):
    return x.type(torch.float32).view(1)

In [17]:
class ReshapeTransform:
    def __init__(self, new_size):
        self.new_size = new_size

    def __call__(self, img):
        return torch.reshape(img, self.new_size)

In [52]:
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    ReshapeTransform((-1, ))
])

dataset = MNIST(root='.', download=True, train=True, transform=transform, target_transform=target_transform)

dl_1 = DataLoader(dataset, batch_size=500, sampler=sampler1)
dl_2 = DataLoader(dataset, batch_size=500, sampler=sampler2)

In [53]:
def train_loop(data_loader_1: torch.utils.data.DataLoader, data_loader_2: torch.utils.data.DataLoader):
    for (X, Z_1), (Y, Z) in zip(data_loader_1, data_loader_2):
        
        w_optim.zero_grad()
        t_optim.zero_grad()
        xi_optim.zero_grad()
        print(X.shape)
        print(T(X, Z).shape)
        loss_value = -loss.compute(X, Z, Y, T(X, Z), Xi(X), f(Y), f(T(X, Z)))
        print('-loss value {}'.format(loss_value))
        loss_value.backward()
        w_optim.step()
        
        t_optim.zero_grad()
        xi_optim.zero_grad()
        loss_value = loss.compute(X, Z, Y, T(X, Z), Xi(X), f(Y), f(T(X, Z)))
        print('loss value {}'.format(loss_value))
        loss_value.backward()
        t_optim.step()
        xi_optim.step()

In [58]:
train_loop(dl_1, dl_2)

torch.Size([500, 784])
torch.Size([500, 784])
-loss value -4.725185394287109
loss value 5.578133583068848
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 124.43385314941406
loss value -87.38523864746094
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 29.259660720825195
loss value -28.185707092285156
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 88.3505630493164
loss value -79.99698638916016
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 300.013427734375
loss value 344.48284912109375
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 0.6387434601783752
loss value -0.3473738729953766
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 23.71762466430664
loss value -20.48738670349121
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 36.299949645996094
loss value -26.184133529663086
torch.Size([500, 784])
torch.Size([500, 784])
-loss value 32.75897216796875
loss value -2.0265541076660156
torch.Size([500, 784])
torch.Size([500, 