In [1]:
import os
import sys

sys.path.append("../")

In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

from src.ml.sinkhorn import SinkhornValue, sinkhorn

In [3]:
class CIFAR10Instance(torchvision.datasets.CIFAR10):
    """
    https://github.com/yukimasano/self-label/blob/581957c2fcb3f14a0382cf71a3d36b21b9943798/cifar_utils.py#L5
    """
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        super(CIFAR10Instance, self).__init__(root=root,
                                                           train=train,
                                                           transform=transform,
                                                           target_transform=target_transform)


    def __getitem__(self, index):
        image, target = self.data[index], self.targets[index]
        image = Image.fromarray(image)

        if self.transform is not None:
            img = self.transform(image)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index

In [40]:
# Load CIFAR-10
transform_train = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ]
)

transform_test = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ToTensor(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ]
)

trainset = CIFAR10Instance(
    root="../data/cifar-10",
    train=True,
    download=True,
    transform=transform_train
)
testset = CIFAR10Instance(
    root="../data/cifar-10",
    train=False,
    download=True,
    transform=transform_test
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=0
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=True,
    num_workers=0
)

classes = (
    "plane", "car", "bird", "cat",
    "deer", "dog", "frog", "horse", 
    "ship", "truck"
)

## Debug training

In [396]:
N = len(trainloader.dataset)
K=128  # number of clusters

selflabels = np.zeros(N, dtype=np.int32)

for qq in range(N):
    selflabels[qq] = qq % K
    
selflabels = np.random.permutation(selflabels)
selflabels = torch.LongTensor(selflabels)

selflabels_onehot = torch.nn.functional.one_hot(
    selflabels,
    num_classes=K
)

In [411]:
# Load model
model = torchvision.models.alexnet(pretrained=False, num_classes=K)
model.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [412]:
a = torch.ones(32) / 32
b = torch.ones(K) / K

optimizer = torch.optim.Adam(lr=0.01, params=model.parameters())

epoch_loss = 0

for batch_idx, (inputs, targets, indexes) in enumerate(trainloader):
    break

In [413]:
optimizer.zero_grad()

In [414]:
x = model(inputs)
M = torch.nn.LogSoftmax(dim=1)(x)

In [415]:
SV = SinkhornValue(
   a,
   b,
   epsilon=0.1,
   solver=sinkhorn,
   n_iter=100
)

In [416]:
loss = -SV(M)

In [417]:
print(loss)

tensor(3.9203, grad_fn=<NegBackward0>)


In [418]:
loss.backward()
optimizer.step()

In [419]:
optimizer.zero_grad()

In [420]:
x = model(inputs)
M = torch.nn.LogSoftmax(dim=1)(x)

In [421]:
x

tensor([[ 13895623.0000,  -5892078.5000,  19694676.0000,  ...,
         -17475746.0000, -12377340.0000,  14919158.0000],
        [ 31494662.0000, -13357301.0000,  44623944.0000,  ...,
         -39599620.0000, -28044182.0000,  33815436.0000],
        [ 14710533.0000,  -6246262.5000,  20836468.0000,  ...,
         -18493568.0000, -13093619.0000,  15792885.0000],
        ...,
        [ 19048390.0000,  -8082766.5000,  26983916.0000,  ...,
         -23947218.0000, -16958176.0000,  20450404.0000],
        [  6485756.5000,  -2752489.0000,   9187670.0000,  ...,
          -8153805.5000,  -5773966.0000,   6963179.5000],
        [ 12735177.0000,  -5401888.0000,  18043662.0000,  ...,
         -16011636.0000, -11340112.0000,  13672323.0000]],
       grad_fn=<AddmmBackward0>)