In [1]:
import collections
import datetime
import json
import math
import os
import random
import copy
import time
import tempfile
import subprocess
import torch
import torch.utils.data
import torchvision.transforms
import numpy as np
import visdom
import scipy
import einops
import json
from geomloss import SamplesLoss
from torch import nn, optim
from collections import defaultdict
from torch.optim.optimizer import Optimizer, required
from math import sqrt
from functools import partial, lru_cache
from torch.nn import functional as F
from torch.nn import Parameter
from PIL import Image
import matplotlib.pyplot as plt

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
#dataset defined by my self
class WarriorDataset(torch.utils.data.Dataset):
    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.image_names = os.listdir(folder_path)
        if transform:
            self.transform = torchvision.transforms.Compose([
                torchvision.transforms.Resize((70, 70)),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize((0, 0, 0), (1, 1, 1))
            ])
    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, index):
        image_name = self.image_names[index]
        image_path = os.path.join(self.folder_path, image_name)
        image = Image.open(image_path)
        image = image.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image
    
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

In [3]:
folder_path = os.getcwd() + "/Pictures/Warrior"
dataset = WarriorDataset(folder_path, transform=True)
train_loader = torch.utils.data.DataLoader(
    dataset,
shuffle=True, batch_size=16, drop_last=True)
train_iterator = iter(cycle(train_loader))

print(dataset[50])

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])


In [4]:
args = {
    'width': 32,
    'dataset': 'easy_worrior',
    'n_channels': 3,
    'n_classes': 10,
    'batch_size': 16,
    'vid_batch': 16,
    'latent_dim': 8,  # lower is better modelling but worst interpolation freedom
    'lr': 0.005,
    'log_every': 2
}

In [5]:
if args['dataset'] == 'easy_worrior': # Test case
    xb = next(train_iterator)
    xb = xb.to(device)
else:
    xb,cb = next(train_iterator)
    xb,cb = xb.to(device), cb.to(device)



In [6]:
class Decoder(nn.Module):
    def __init__(self, latent_dim, n_channels):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.LazyConvTranspose2d(512, 4, stride=1, padding=0),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),  # Output: [512, 4, 4]

            nn.LazyConvTranspose2d(256, 4, stride=2, padding=1),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),  # Output: [256, 8, 8]

            nn.LazyConvTranspose2d(128, 4, stride=2, padding=1),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),  # Output: [128, 16, 16]

            nn.LazyConvTranspose2d(64, 4, stride=2, padding=1),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),  # Output: [64, 32, 32]

            nn.LazyConvTranspose2d(32, 4, stride=2, padding=1),
            nn.LazyBatchNorm2d(),
            nn.ReLU(),  # Output: [32, 64, 64]

            nn.LazyConvTranspose2d(n_channels, 4, stride=2, padding=1),
            nn.Sigmoid()  # Output: [3, 128, 128]
        )

    def forward(self, z):
        x = self.decoder(z)
        # Crop from [3, 128, 128] to [3, 69, 44]
        x = x[:, :, :69, :44]

        # change n_channels above to match number of colours

        # 0,0,0,0,1,0,0,0,0,
        # softmax(0.3, -0.2, 1.5, ...) -> 0.001, 0.0000, 0.9, 0.00

        # # when implementing the softmax, you'll need to remove the sigmoid then do:
        # may need to view x to make this work
        # x = torch.softmax(x, dim=1)

        # to test your softmax code is working, do
        # x.sum(dim=1) , check this is all 1's
        # inspect say x[0, :, 30, 22] # make sure it looks like a PMF
        x = x[:, :, :69, :44]

        # Change n_channels above to match the number of colors

        # Apply softmax to convert the output to a probability distribution
        x = x.view(x.size(0), -1)  # Reshape x to [batch_size, num_features]
        x = torch.softmax(x, dim=1)
        x = x.view(x.size(0), n_channels, 69, 44)  # Reshape back to [batch_size, n_channels, 69, 44]

        # Test the softmax code
        print(x.sum(dim=1))  # Check that the sum is 1 for each channel
        print(x[0, :, 30, 22])  # Inspect the probability mass function (PMF) for a specific pixel
        return x

net = Decoder(args['latent_dim'], args['n_channels']).to(device)
opt = torch.optim.Adam(net.parameters(), lr=args['lr'])
ot_loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.001)



In [7]:
def ot_loss(x, y):
    return ot_loss_fn(x.view(x.size(0), -1), y.view(y.size(0), -1))

In [8]:
logs = {}
logs['loss1'] = logs['loss2'] = logs['loss3'] = 0
logs['num_stats'] = 0

opt.zero_grad()

# p(x | z)
# p(x | z, p)

p_z = torch.randn(args['batch_size'], args['latent_dim'], 1, 1).to(device)
g = net(p_z)

loss = ot_loss(g, xb)  # ((g-xb)**2).mean()
loss.backward()
opt.step()

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
xb.size(3)