In [1]:
import os
import sys

sys.path.append("../")

In [2]:
import torch
import torchvision
from torch import nn
from torch.autograd import Function
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import ot

from src.ml.sinkhorn import SinkhornValue, sinkhorn, pot_sinkhorn

In [3]:
# Number of images per training minibatch (consumed by the train DataLoader).
batch_size = 64

In [4]:
class CIFAR10Instance(torchvision.datasets.CIFAR10):
    """CIFAR-10 dataset whose items also carry their own dataset index.

    Adapted from
    https://github.com/yukimasano/self-label/blob/581957c2fcb3f14a0382cf71a3d36b21b9943798/cifar_utils.py#L5

    Args:
        root: directory holding (or receiving) the CIFAR-10 files.
        train: select the training split when True, the test split otherwise.
        transform: optional callable applied to the PIL image.
        target_transform: optional callable applied to the label.
        download: forwarded to the base class so the data can be fetched
            when it is missing from ``root``.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        # `download` used to be accepted but dropped; forward it so that
        # `download=True` actually takes effect.
        super().__init__(
            root=root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )

    def __getitem__(self, index):
        """Return ``(image, target, index)`` for the sample at ``index``.

        The image is converted to a PIL image first and then passed through
        ``self.transform`` when one is set; without a transform the PIL image
        itself is returned (previously this path hit an unbound local).
        """
        image, target = self.data[index], self.targets[index]
        image = Image.fromarray(image)

        if self.transform is not None:
            image = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target, index

In [5]:
# Load CIFAR-10
# Per-channel mean / std shared by the train and test normalisation steps.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: upscale, then random crop / colour / flip augmentations.
transform_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std)
    ]
)

# Evaluation pipeline: deterministic resize + centre crop. ToTensor must be
# applied exactly once -- it only accepts PIL images / ndarrays, so a second
# call on the resulting tensor fails as soon as a test sample is fetched.
transform_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std)
    ]
)

trainset = CIFAR10Instance(
    root="../data/cifar-10",
    train=True,
    download=True,
    transform=transform_train
)
testset = CIFAR10Instance(
    root="../data/cifar-10",
    train=False,
    download=True,
    transform=transform_test
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)
# Evaluation order does not need to be random; keep it fixed so runs are comparable.
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=0
)

# Human-readable CIFAR-10 class names, indexed by label id.
classes = (
    "plane", "car", "bird", "cat",
    "deer", "dog", "frog", "horse",
    "ship", "truck"
)

In [8]:
N = len(trainloader.dataset)
K = 128  # number of clusters

# Balanced initial pseudo-labels: sample i starts in cluster i mod K, then the
# assignment is shuffled so that cluster membership is random.
round_robin_labels = np.arange(N, dtype=np.int32) % K
selflabels = torch.LongTensor(np.random.permutation(round_robin_labels))

In [None]:
# Load Alexnet model, with output size = K (128). No pretrained weights are
# requested (the library default), so the network starts from random init.
model = torchvision.models.alexnet(num_classes=K)

# ADAM optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Run configuration.
n_epochs = 30
sinkhorn_epsilon = 0.1      # entropic regularisation passed to SinkhornValue
sinkhorn_iters = 10         # number of solver iterations passed to SinkhornValue
# Debugging cap: only this many minibatches are processed per epoch.
# Set to None to iterate over the whole training set.
max_batches_per_epoch = 1

# Loop-invariant pieces: log-softmax module, cluster marginal, train mode.
log_softmax = nn.LogSoftmax(dim=1)
b = torch.ones(K) / K       # uniform marginal over the K clusters
model.train()

for epoch in range(n_epochs):
    epoch_loss = 0.0
    n_batches = 0

    # loop over minibatches
    for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
        # uniform marginal over the samples of this minibatch
        a = torch.ones(inputs.shape[0]) / inputs.shape[0]

        # set gradients to zero
        optimizer.zero_grad()

        # network outputs, turned into per-sample log-probabilities over clusters
        x = model(inputs)
        P = log_softmax(x)

        # matrix handed to the Sinkhorn value (log-probabilities, unshifted)
        M = P  # - np.log(inputs.shape[0])

        # init Sinkhorn loss
        SV = SinkhornValue(
            a,
            b,
            epsilon=sinkhorn_epsilon,
            solver=pot_sinkhorn,
            n_iter=sinkhorn_iters
        )

        # compute Sinkhorn loss
        loss = -SV(M)

        # Fail loudly, with context, as soon as the loss degenerates.
        if torch.isnan(loss):
            raise RuntimeError(
                f"Sinkhorn loss is NaN at epoch {epoch}, batch {batch_idx}"
            )

        # backpropagation: compute gradients
        loss.backward()

        # parameter update
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

        # Stop after the configured number of batches. Checked at the end so
        # that no extra batch is fetched and `inputs` / `P` stay in sync.
        if max_batches_per_epoch is not None and n_batches >= max_batches_per_epoch:
            break

    # mean loss over the batches actually processed (guards an empty loader)
    print(epoch_loss / max(n_batches, 1))



3.850905418395996




-0.004397315438836813
-0.00782102532684803
-0.00782102532684803


In [166]:
# Hand-rolled Sinkhorn quantities for the last minibatch of the training cell,
# used to inspect the numerics step by step.
# NOTE(review): this rebinds K from the cluster count (128) to the kernel
# matrix; re-run the cell that sets K = 128 before re-running the training cell.
n_samples = inputs.shape[0]
M = P - np.log(n_samples)  # log-probabilities shifted by -log(batch size)
# Original note on M: "add - M minimum, rescaling" (idea not applied here).
K = torch.exp(M / 0.1)     # kernel exp(M / epsilon) with epsilon = 0.1
v = torch.ones_like(b)     # initial column scaling, same shape as b

# Sinkhorn iterations, currently disabled; the cells below run single steps.
# for _ in range(10):
    # u = a / torch.matmul(K, v)
    # v = b / torch.matmul(torch.transpose(K, 1, 0), u)

# uv = torch.outer(u, v)
# K * u

In [167]:
# Smallest and largest entry of the shifted log-probability matrix M.
torch.min(M).item(), torch.max(M).item()

(-233.62417602539062, -4.158905506134033)

In [168]:
# Smallest and largest entry of the kernel matrix K = exp(M / 0.1).
torch.min(K).item(), torch.max(K).item()

(0.0, 8.671681260262779e-19)

In [169]:
# Row-scaling update: u = a / (K v).
u = a / (K @ v)

In [170]:
# Display the row-scaling vector u computed in the previous cell.
u

tensor([6.4392e+20, 1.4059e+17, 3.8684e+16, 5.8867e+20, 2.0348e+16, 1.0551e+18,
        1.9243e+16, 2.2486e+16, 2.7568e+16, 2.0296e+16, 1.4485e+18, 9.1779e+16,
        7.7896e+18, 1.4599e+20, 2.5163e+17, 2.1787e+18, 4.6286e+16, 6.9416e+16,
        6.1957e+16, 1.8018e+16, 1.8074e+16, 1.7511e+18, 8.7680e+17, 4.0653e+18,
        8.5586e+18, 1.8030e+16, 1.8301e+16, 1.7520e+17, 2.0260e+16, 2.1542e+16,
        1.5410e+17, 6.6467e+16, 1.1717e+17, 1.7143e+18, 6.7043e+18, 6.5340e+18,
        2.2353e+16, 1.7291e+18, 4.3321e+18, 2.0772e+16, 1.8394e+16, 1.8640e+20,
        2.9320e+16, 1.8395e+16, 3.7350e+17, 3.4864e+16, 2.0123e+16, 5.7577e+16,
        3.8316e+16, 2.0196e+16, 1.9852e+16, 5.2017e+17, 3.7741e+16, 1.1059e+18,
        3.3427e+16, 5.2656e+16, 2.3015e+18, 6.6982e+17, 6.1382e+16, 1.1024e+19,
        3.1612e+18, 5.3658e+17, 2.6773e+16, 6.9894e+18],
       grad_fn=<DivBackward0>)

In [171]:
# Denominator of the column-scaling update: K^T u.
K.t() @ u

tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 3.8750e-21, 6.4484e-02, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1088e-07,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 6.2893e-02, 0.0000e+00,
        1.5900e-06, 4.4297e-27, 0.0000e+00, 6.6329e-17, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 7.4159e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+

In [172]:
# Column-scaling update b / (K^T u); zero entries in K^T u show up as inf.
b / (K.t() @ u)

tensor([       inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf, 2.0161e+18, 1.2115e-01,        inf,        inf,
               inf,        inf,        inf,        inf,        inf, 7.0462e+04,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf, 1.2422e-01,        inf,
        4.9134e+03, 1.7636e+24,        inf, 1.1778e+14,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf, 1.0535e-02,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        inf,        inf,        inf,        inf,
               inf,        inf,        i