In [None]:
import os
import sys

sys.path.append("../")

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from src.ml.sinkhorn import SinkhornValue, sinkhorn

In [5]:
class CIFAR10Instance(torchvision.datasets.CIFAR10):
    """
    https://github.com/yukimasano/self-label/blob/581957c2fcb3f14a0382cf71a3d36b21b9943798/cifar_utils.py#L5
    """
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CIFAR10Instance, self).__init__(root=root,
                                                           train=train,
                                                           transform=transform,
                                                           target_transform=target_transform)


    def __getitem__(self, index):
        image, target = self.data[index], self.targets[index]
        image = Image.fromarray(image)

        if self.transform is not None:
            img = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

In [22]:
# Load CIFAR-10
transform_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ]
)

transform_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ToTensor(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ]
)

trainset = CIFAR10Instance(
    root="../data/cifar-10",
    train=True,
    download=True,
    transform=transform_train
)
testset = CIFAR10Instance(
    root="../data/cifar-10",
    train=False,
    download=True,
    transform=transform_test
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=16,
    shuffle=True,
    num_workers=0
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=True,
    num_workers=0
)

classes = (
    "plane", "car", "bird", "cat",
    "deer", "dog", "frog", "horse", 
    "ship", "truck"
)

In [71]:
N = len(trainloader.dataset)
K=128  # number of clusters

selflabels = np.zeros(N, dtype=np.int32)

for qq in range(N):
    selflabels[qq] = qq % K
    
selflabels = np.random.permutation(selflabels)
selflabels = torch.LongTensor(selflabels)

selflabels_onehot = torch.nn.functional.one_hot(
    selflabels,
    num_classes=K
)

In [86]:
# Load model
model = torchvision.models.alexnet(pretrained=False, num_classes=K)
model.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [84]:
a.shape, b.shape, M.shape

(torch.Size([16]), torch.Size([128]), torch.Size([16, 128]))

In [89]:
x = model(inputs)
M = torch.nn.LogSoftmax(dim=1)(x)

In [91]:
torch.exp(M)

tensor([[0.0079, 0.0079, 0.0077,  ..., 0.0077, 0.0078, 0.0078],
        [0.0079, 0.0079, 0.0077,  ..., 0.0077, 0.0078, 0.0078],
        [0.0079, 0.0079, 0.0077,  ..., 0.0077, 0.0078, 0.0078],
        ...,
        [0.0079, 0.0079, 0.0077,  ..., 0.0077, 0.0078, 0.0078],
        [0.0079, 0.0079, 0.0077,  ..., 0.0078, 0.0078, 0.0078],
        [0.0079, 0.0079, 0.0077,  ..., 0.0077, 0.0078, 0.0078]],
       grad_fn=<ExpBackward0>)

In [92]:
a = torch.ones(16) / 16
b = torch.ones(K) / K

optimizer = torch.optim.Adam(lr=0.01, params=model.parameters())

epoch_loss = 0

for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
    optimizer.zero_grad()
    
    x = model(inputs)
    M = torch.nn.LogSoftmax(dim=1)(x)

    SV = SinkhornValue(
        a,
        b,
        epsilon=0.1,
        solver=sinkhorn,
        n_iter=10
    )
    
    loss = -SV(M)
    epoch_loss += loss.item()
    
loss.backward()
optimizer.step()

print(loss.item())

3.989607572555542
nan
nan
nan


KeyboardInterrupt: 

In [78]:
M

tensor([[-4.8345, -4.8602, -4.8483,  ..., -4.8524, -4.8521, -4.8587],
        [-4.8326, -4.8572, -4.8485,  ..., -4.8526, -4.8511, -4.8594],
        [-4.8338, -4.8586, -4.8498,  ..., -4.8517, -4.8499, -4.8615],
        ...,
        [-4.8322, -4.8589, -4.8485,  ..., -4.8483, -4.8515, -4.8597],
        [-4.8314, -4.8591, -4.8488,  ..., -4.8511, -4.8529, -4.8586],
        [-4.8324, -4.8577, -4.8478,  ..., -4.8527, -4.8545, -4.8592]],
       grad_fn=<LogSoftmaxBackward0>)