In [1]:
from imutils import paths
import cv2
import os
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

import torch.autograd as autograd

In [2]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
import torch
device = torch.device("cuda")

In [22]:
image_paths = list(paths.list_images('data/Caltech101/001'))
#image_paths = list(paths.list_images('data/cars_side-view'))

In [23]:
data = []
labels = []
for img_path in tqdm(image_paths):
    label = img_path.split(os.path.sep)[-2]
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    data.append(img)
    labels.append(label)
    if len(labels) > 5000:
        break
    
data = np.array(data)
labels = np.array(labels)

100%|██████████| 800/800 [00:00<00:00, 1133.62it/s]
  data = np.array(data)


In [24]:

lb = LabelEncoder()
labels = lb.fit_transform(labels)
print(f"Total Number of Classes: {len(lb.classes_)}")

Total Number of Classes: 1


In [7]:
Counter(labels)

Counter({0: 123})

In [25]:
from sklearn.model_selection import train_test_split
# divide the data into train and test set
(x_train, x_test, y_train, y_test) = train_test_split(data, labels, test_size=0.1, stratify=labels, random_state=42)
print(f"x_train examples: {x_train.shape}\nx_test examples: {x_test.shape}")

x_train examples: (720,)
x_test examples: (80,)


In [26]:
dataset_config = {'size': 64, 'channels': 3, 'classes': 1}

In [27]:
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((dataset_config['size'], dataset_config['size'])),
    transforms.ToTensor(),
    #transforms.Normalize(mean = [0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((dataset_config['size'],dataset_config['size'])),
    transforms.ToTensor(),
    #transforms.Normalize(mean = [0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])    

In [11]:
BS = 32
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BS, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=False, transform=transforms.ToTensor()),
    batch_size=BS, shuffle=True)


Files already downloaded and verified


In [28]:
BS = 32
# custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels= None, transforms = None):
        self.labels = labels
        self.images = images
        self.transforms = transforms
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        data = self.images[index][:]
        
        if self.transforms:
            data = self.transforms(data)
            
        
        return (data, self.labels[index])
        
train_data = CustomDataset(x_train, y_train, train_transforms)
test_data = CustomDataset(x_test, y_test, val_transform)       

train_loader = DataLoader(train_data, batch_size=BS, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=BS, shuffle=True, num_workers=4, drop_last=False) 

### --- Main --

In [29]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable

from torchvision.utils import save_image

def gradients(y, x):
    return autograd.grad(
                outputs=y, inputs=x, retain_graph=True,
                create_graph=True, grad_outputs=torch.ones_like(y), only_inputs=True)[0]

In [30]:
class VAE_Cifar10(nn.Module):
    def __init__(self, label = 'cifar10', image_size = dataset_config['size'], channel_num = dataset_config['channels'], kernel_num = 128, z_size=128):
        # configurations
        super().__init__()
        self.label = label
        self.image_size = image_size
        self.channel_num = channel_num
        self.kernel_num = kernel_num
        self.z_size = z_size

        # encoder
        self.encoder = nn.Sequential(
            self._conv(channel_num, kernel_num // 4),
            self._conv(kernel_num // 4, kernel_num // 2),
            self._conv(kernel_num // 2, kernel_num),
        )

        # encoded feature's size and volume

        # q
        self.q_mean = self._linear(self.feature_volume, z_size, relu=False)
        self.q_logvar = self._linear(self.feature_volume, z_size, relu=False)

        # projection
        self.project = self._linear(z_size, self.feature_volume, relu=False)

        # decoder
        self.decoder = nn.Sequential(
            self._deconv(kernel_num, kernel_num // 2),
            self._deconv(kernel_num // 2, kernel_num // 4),
            self._deconv(kernel_num // 4, channel_num),
            nn.Sigmoid()
        )


    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std


    def forward(self, x):
        encoded = self.encoder(x)

        # sample latent code z from q given x.
        mean, logvar = self.q(encoded)
        z = self.z(mean, logvar)
        z_projected = self.project(z).view(
            -1, self.kernel_num,
            self.feature_size,
            self.feature_size,
        )

        # reconstruct x from z
        x_reconstructed = self.decoder(z_projected)

        return x_reconstructed, mean, logvar
    
    def q(self, encoded):
        unrolled = encoded.view(-1, self.feature_volume)
        return self.q_mean(unrolled), self.q_logvar(unrolled)

    def z(self, mean, logvar):
        std = logvar.mul(0.5).exp_()
        eps = (Variable(torch.randn(std.size())).to(device))
        
        return mean#eps.mul(std).add_(mean)
    
    
    # ======
    # Layers
    # ======

    def _conv(self, channel_size, kernel_num):
        return nn.Sequential(
            nn.Conv2d(
                channel_size, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.InstanceNorm2d(kernel_num),
            nn.ReLU(),
        )

    def _deconv(self, channel_num, kernel_num):
        return nn.Sequential(
            nn.ConvTranspose2d(
                channel_num, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.InstanceNorm2d(kernel_num),
            nn.ReLU(),
        )

    def _linear(self, in_size, out_size, relu=True):
        return nn.Sequential(
            nn.Linear(in_size, out_size),
            nn.ReLU(),
        ) if relu else nn.Linear(in_size, out_size)

model = VAE_Cifar10().to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

In [32]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    MSE = (recon_x - x.view(-1,  dataset_config['channels'],dataset_config['size'],dataset_config['size'])) ** 2
    MSE = MSE.sum(dim=(-1,-2,-3))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

    return MSE + 0.3 * KLD

In [33]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar).mean(dim=0)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))

In [34]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).mean(dim=0).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch[:n].view(n,  dataset_config['channels'], dataset_config['size'], dataset_config['size'])[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

In [35]:
for epoch in range(1, 100 + 1):
        train(epoch)
        test(epoch)

====> Epoch: 1 Average loss: 26.3103
====> Test set loss: 23.0203
====> Epoch: 2 Average loss: 19.3380
====> Test set loss: 19.1912
====> Epoch: 3 Average loss: 16.8911
====> Test set loss: 17.7870
====> Epoch: 4 Average loss: 15.8511
====> Test set loss: 17.1570
====> Epoch: 5 Average loss: 15.2922
====> Test set loss: 17.0207
====> Epoch: 6 Average loss: 14.8895
====> Test set loss: 16.5510
====> Epoch: 7 Average loss: 14.6538
====> Test set loss: 16.4849
====> Epoch: 8 Average loss: 14.4522
====> Test set loss: 16.1776
====> Epoch: 9 Average loss: 14.1755
====> Test set loss: 16.1221
====> Epoch: 10 Average loss: 13.9835
====> Test set loss: 15.8670
====> Epoch: 11 Average loss: 13.7936
====> Test set loss: 15.8376
====> Epoch: 12 Average loss: 13.7208
====> Test set loss: 15.7574
====> Epoch: 13 Average loss: 13.6545
====> Test set loss: 15.7556


====> Epoch: 14 Average loss: 13.5296
====> Test set loss: 15.7205
====> Epoch: 15 Average loss: 13.3835
====> Test set loss: 15.3525
====> Epoch: 16 Average loss: 13.2604
====> Test set loss: 15.5147
====> Epoch: 17 Average loss: 13.1805
====> Test set loss: 15.3069
====> Epoch: 18 Average loss: 13.1409
====> Test set loss: 15.2325
====> Epoch: 19 Average loss: 13.0787
====> Test set loss: 15.3141
====> Epoch: 20 Average loss: 13.0296
====> Test set loss: 15.3174
====> Epoch: 21 Average loss: 12.9721
====> Test set loss: 15.1530
====> Epoch: 22 Average loss: 12.9705
====> Test set loss: 15.1267
====> Epoch: 23 Average loss: 12.9023
====> Test set loss: 15.4176
====> Epoch: 24 Average loss: 12.9217
====> Test set loss: 15.1538
====> Epoch: 25 Average loss: 12.8729
====> Test set loss: 15.1319
====> Epoch: 26 Average loss: 12.7699
====> Test set loss: 15.1216


====> Epoch: 27 Average loss: 12.7973
====> Test set loss: 15.2534
====> Epoch: 28 Average loss: 12.7921
====> Test set loss: 15.0995
====> Epoch: 29 Average loss: 12.7493
====> Test set loss: 15.0056
====> Epoch: 30 Average loss: 12.6760
====> Test set loss: 14.9635
====> Epoch: 31 Average loss: 12.7340
====> Test set loss: 15.0953
====> Epoch: 32 Average loss: 12.7749
====> Test set loss: 15.0222
====> Epoch: 33 Average loss: 12.6544
====> Test set loss: 15.0405
====> Epoch: 34 Average loss: 12.6537
====> Test set loss: 15.0470
====> Epoch: 35 Average loss: 12.6340
====> Test set loss: 14.9586
====> Epoch: 36 Average loss: 12.6647
====> Test set loss: 14.9508
====> Epoch: 37 Average loss: 12.5691
====> Test set loss: 15.1627
====> Epoch: 38 Average loss: 12.5482
====> Test set loss: 14.8920
====> Epoch: 39 Average loss: 12.4968
====> Test set loss: 14.9269


====> Epoch: 40 Average loss: 12.4819
====> Test set loss: 14.8220
====> Epoch: 41 Average loss: 12.4523
====> Test set loss: 14.8164
====> Epoch: 42 Average loss: 12.4351
====> Test set loss: 14.8232
====> Epoch: 43 Average loss: 12.4686
====> Test set loss: 14.8174
====> Epoch: 44 Average loss: 12.4870
====> Test set loss: 14.9078
====> Epoch: 45 Average loss: 12.4358
====> Test set loss: 14.7471
====> Epoch: 46 Average loss: 12.3667
====> Test set loss: 14.8787
====> Epoch: 47 Average loss: 12.4135
====> Test set loss: 14.9142
====> Epoch: 48 Average loss: 12.3969
====> Test set loss: 14.7624
====> Epoch: 49 Average loss: 12.3470
====> Test set loss: 14.8666
====> Epoch: 50 Average loss: 12.4278
====> Test set loss: 14.8362
====> Epoch: 51 Average loss: 12.4068
====> Test set loss: 14.8607
====> Epoch: 52 Average loss: 12.3570
====> Test set loss: 14.8121


====> Epoch: 53 Average loss: 12.3853
====> Test set loss: 14.9198
====> Epoch: 54 Average loss: 12.3818
====> Test set loss: 14.9095
====> Epoch: 55 Average loss: 12.3084
====> Test set loss: 14.8212
====> Epoch: 56 Average loss: 12.3246
====> Test set loss: 14.8170
====> Epoch: 57 Average loss: 12.2791
====> Test set loss: 14.7662
====> Epoch: 58 Average loss: 12.2969
====> Test set loss: 14.8737
====> Epoch: 59 Average loss: 12.3176
====> Test set loss: 14.7338
====> Epoch: 60 Average loss: 12.2492
====> Test set loss: 14.7664
====> Epoch: 61 Average loss: 12.2617
====> Test set loss: 14.6775
====> Epoch: 62 Average loss: 12.2162
====> Test set loss: 14.8077
====> Epoch: 63 Average loss: 12.3030
====> Test set loss: 14.8026
====> Epoch: 64 Average loss: 12.2931
====> Test set loss: 14.7991
====> Epoch: 65 Average loss: 12.2327
====> Test set loss: 14.7392


====> Epoch: 66 Average loss: 12.2126
====> Test set loss: 14.6609
====> Epoch: 67 Average loss: 12.1619
====> Test set loss: 14.7165
====> Epoch: 68 Average loss: 12.1632
====> Test set loss: 14.7660
====> Epoch: 69 Average loss: 12.1562
====> Test set loss: 14.8013
====> Epoch: 70 Average loss: 12.2101
====> Test set loss: 14.9263
====> Epoch: 71 Average loss: 12.2718
====> Test set loss: 14.8833
====> Epoch: 72 Average loss: 12.2662
====> Test set loss: 14.7363
====> Epoch: 73 Average loss: 12.1873
====> Test set loss: 14.7119
====> Epoch: 74 Average loss: 12.1422
====> Test set loss: 14.7767
====> Epoch: 75 Average loss: 12.1839
====> Test set loss: 14.6742
====> Epoch: 76 Average loss: 12.1253
====> Test set loss: 14.8557
====> Epoch: 77 Average loss: 12.2713
====> Test set loss: 14.9243
====> Epoch: 78 Average loss: 12.1969
====> Test set loss: 14.7684


====> Epoch: 79 Average loss: 12.1315
====> Test set loss: 14.6998
====> Epoch: 80 Average loss: 12.0989
====> Test set loss: 14.6510
====> Epoch: 81 Average loss: 12.0699
====> Test set loss: 15.0573
====> Epoch: 82 Average loss: 12.2677
====> Test set loss: 14.8240
====> Epoch: 83 Average loss: 12.1483
====> Test set loss: 14.7071
====> Epoch: 84 Average loss: 12.1306
====> Test set loss: 14.9564
====> Epoch: 85 Average loss: 12.2630
====> Test set loss: 14.6736
====> Epoch: 86 Average loss: 12.1250
====> Test set loss: 14.6758
====> Epoch: 87 Average loss: 12.0901
====> Test set loss: 14.8091
====> Epoch: 88 Average loss: 12.1495
====> Test set loss: 14.6966
====> Epoch: 89 Average loss: 12.0847
====> Test set loss: 14.7075
====> Epoch: 90 Average loss: 12.0670
====> Test set loss: 14.6697
====> Epoch: 91 Average loss: 12.1014
====> Test set loss: 14.7447


====> Epoch: 92 Average loss: 12.0866
====> Test set loss: 14.6550
====> Epoch: 93 Average loss: 12.0540
====> Test set loss: 14.6352
====> Epoch: 94 Average loss: 12.0907
====> Test set loss: 14.7323
====> Epoch: 95 Average loss: 12.0512
====> Test set loss: 14.6501
====> Epoch: 96 Average loss: 11.9958
====> Test set loss: 14.8117
====> Epoch: 97 Average loss: 12.0329
====> Test set loss: 14.7298
====> Epoch: 98 Average loss: 12.0450
====> Test set loss: 14.7562
====> Epoch: 99 Average loss: 12.0330
====> Test set loss: 14.7962
====> Epoch: 100 Average loss: 12.0214
====> Test set loss: 14.6819


In [20]:
1048576/128/128

64.0

In [21]:
model.train()
train_loss = 0
for batch_idx,( data, _) in enumerate(train_loader):
    data = data.to(device)
    data = data.reshape([-1,3, 128*128])
    optimizer.zero_grad()
    recon_batch, mu, logvar = model(data)
    loss = loss_function(recon_batch, data, mu, logvar)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()
    if batch_idx % 100 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader),
            loss.item() / len(data)))

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [32, 3, 4, 4], but got 3-dimensional input of size [8, 3, 16384] instead