In [1]:
from imutils import paths
import cv2
import os
from tqdm import tqdm
import numpy as np
from sklearn.preprocessing import LabelEncoder
from collections import Counter
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

import torch.autograd as autograd

In [2]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
import torch
device = torch.device("cuda")

In [4]:
# image_paths = list(paths.list_images('data/Caltech101/001'))
# image_paths = list(paths.list_images('data/cars_side-view'))
image_paths = list(paths.list_images('data/Caltech101/016'))

In [5]:
data = []
labels = []
for img_path in tqdm(image_paths):
    label = img_path.split(os.path.sep)[-2]
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    data.append(img)
    labels.append(label)
    if len(labels) > 5000:
        break
    
data = np.array(data)
labels = np.array(labels)

100%|██████████| 123/123 [00:00<00:00, 2668.46it/s]


In [6]:

lb = LabelEncoder()
labels = lb.fit_transform(labels)
print(f"Total Number of Classes: {len(lb.classes_)}")

Total Number of Classes: 1


In [7]:
Counter(labels)

Counter({0: 123})

In [8]:
from sklearn.model_selection import train_test_split
# divide the data into train and test set
(x_train, x_test, y_train, y_test) = train_test_split(data, labels, test_size=0.1, stratify=labels, random_state=42)
print(f"x_train examples: {x_train.shape}\nx_test examples: {x_test.shape}")

x_train examples: (110, 197, 300, 3)
x_test examples: (13, 197, 300, 3)


In [9]:
dataset_config = {'size': 64, 'channels': 3, 'classes': 1}

In [10]:
train_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((dataset_config['size'], dataset_config['size'])),
    transforms.ToTensor(),
    #transforms.Normalize(mean = [0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((dataset_config['size'],dataset_config['size'])),
    transforms.ToTensor(),
    #transforms.Normalize(mean = [0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])    

In [11]:
BS = 32
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BS, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=False, transform=transforms.ToTensor()),
    batch_size=BS, shuffle=True)


Files already downloaded and verified


In [12]:
BS = 32
# custom dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels= None, transforms = None):
        self.labels = labels
        self.images = images
        self.transforms = transforms
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        data = self.images[index][:]
        
        if self.transforms:
            data = self.transforms(data)
            
        
        return (data, self.labels[index])
        
train_data = CustomDataset(x_train, y_train, train_transforms)
test_data = CustomDataset(x_test, y_test, val_transform)       

train_loader = DataLoader(train_data, batch_size=BS, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=BS, shuffle=True, num_workers=4, drop_last=False) 

### --- Main --

In [13]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torch.autograd import Variable

from torchvision.utils import save_image

def gradients(y, x):
    return autograd.grad(
                outputs=y, inputs=x, retain_graph=True,
                create_graph=True, grad_outputs=torch.ones_like(y), only_inputs=True)[0]

In [14]:
class HWReduction(nn.Module):
    def forward(self, x):
        # x -> [B, C, H, W]
        return x.mean(dim=(-1, -2))

class Reshape(nn.Module):
    def __init__(self, shape: list):
        super(Reshape, self).__init__()
        self.shape = shape
    def forward(self, x):
        batch_size = x.shape[0]
        return x.reshape([batch_size] + self.shape)

class VAE_Cifar10(nn.Module):
    def __init__(self, label = 'cifar10', image_size = dataset_config['size'],
                 channel_num = dataset_config['channels'],
                 z_size=128):
        # configurations
        super().__init__()
        self.label = label
        self.image_size = image_size
        self.channel_num = channel_num
        self.z_size = z_size

        # encoder
        self.encoder = nn.Sequential(
            self.capacity_conv(channel_num, 16), # 16 x 64 x 64
            nn.InstanceNorm2d(16),
            self.downsampling_conv(16, 32), # 32 x 32 x 32
            self.capacity_conv(32, 64), # 64 x 32 x 32
            self.downsampling_conv(64, 128), # 128 x 16 x 16
            self.capacity_conv(128, 256), # 256 x 16 x 16
            self.downsampling_conv(256, 512), # 512 x 8 x 8
            HWReduction(),
        )

        # H, W will be reduced


        # q
        self.q_mean = self._linear(512, z_size, relu=False)
        self.q_logvar = self._linear(512, z_size, relu=False)

        # projection
        self.project = nn.Sequential(
            self._linear(z_size, 1024),
            self._linear(1024, 8 * 8 * 128),
            Reshape([128, 8, 8])
        )

        # decoder
        self.decoder = nn.Sequential(
            self.upsampling_conv(128, 64), # 32 x 16 x 16
            self.capacity_conv(64, 64), # 64 x 16 x 16
            self.upsampling_conv(64, 32), # 32 x 32 x 32
            self.capacity_conv(32, 32), # 32 x 32 x 32
            self.upsampling_conv(32, 16), # 16 x 64 x 64
            nn.Conv2d(
                16, channel_num,
                kernel_size=3, stride=1, padding=1,
            ),
            nn.Sigmoid()
        )


    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std


    def forward(self, x):
        encoded = self.encoder(x)

        # sample latent code z from q given x.
        mean, logvar = self.q_mean(encoded), self.q_logvar(encoded) - 10.0
        z = self.reparameterize(mean, logvar)
        z_projected = self.project(z)

        # reconstruct x from z
        x_reconstructed = self.decoder(z_projected)
        return x_reconstructed, mean, logvar
    
    # ======
    # Layers
    # ======

    def downsampling_conv(self, channel_size, kernel_num):
        return nn.Sequential(
            nn.Conv2d(
                channel_size, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.InstanceNorm2d(kernel_num),
            nn.LeakyReLU(negative_slope=0.1),
        )
    def capacity_conv(self, channel_num, kernel_num):
        return nn.Sequential(
            nn.Conv2d(
                channel_num, kernel_num,
                kernel_size=3, stride=1, padding=1,
            ),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def upsampling_conv(self, channel_num, kernel_num):
        return nn.Sequential(
            nn.ConvTranspose2d(
                channel_num, kernel_num,
                kernel_size=4, stride=2, padding=1,
            ),
            nn.InstanceNorm2d(kernel_num),
            nn.LeakyReLU(negative_slope=0.1),
        )

    def _linear(self, in_size, out_size, relu=True):
        return nn.Sequential(
            nn.Linear(in_size, out_size),
            nn.ReLU(),
        ) if relu else nn.Linear(in_size, out_size)

model = VAE_Cifar10().to(device)
optimizer = optim.AdamW(model.parameters(), lr=3e-4)

In [15]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    MSE = (recon_x - x.view(-1,  dataset_config['channels'],dataset_config['size'],dataset_config['size'])) ** 2
    MSE = MSE.sum(dim=(-1,-2,-3))

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

    return MSE, KLD

In [16]:
def train(epoch, log_eta):
    model.train()
    log_eta, just_updated = log_eta
    train_recons_loss = 0
    train_kld_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        recons_loss, kld_loss = loss_function(recon_batch, data, mu, logvar)
        loss = (recons_loss + float(np.exp(log_eta)) * kld_loss).mean(dim=0)
        loss.backward()
        train_recons_loss += recons_loss.mean(dim=0).item()
        train_kld_loss += kld_loss.mean(dim=0).item()
        optimizer.step()
        if batch_idx % 2 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tRecons Loss: {:.6f}; KLD Loss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                recons_loss.mean(dim=0).item(), recons_loss.mean(dim=0).item()))
    
    print('====> Epoch: {} Average recons loss: {:.4f}, kld loss: {:.4f}'.format(
          epoch, train_recons_loss / len(train_loader.dataset), train_kld_loss / len(train_loader.dataset)))
    if train_recons_loss / len(train_loader.dataset) < 6.0 and (epoch - just_updated) >= 10 and (train_recons_loss / len(train_loader.dataset)) < train_kld_loss / len(train_loader.dataset):
        log_eta += 0.25
        if log_eta >= 0.0:
            log_eta = 0.0
        just_updated = epoch
        print('====> Epoch: {} Eta is now {:.4f} due to sufficient recons results'.format(epoch, float(np.exp(log_eta))))
    return (log_eta, just_updated)

In [17]:
def test(epoch):
    model.eval()
    test_recons_loss = 0
    test_kld_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            recons_loss, kld_loss = loss_function(recon_batch, data, mu, logvar)
            test_recons_loss += recons_loss.mean(dim=0).item()
            test_kld_loss += kld_loss.mean(dim=0).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch[:n].view(n,  dataset_config['channels'], dataset_config['size'], dataset_config['size'])[:n]])
                if not os.path.exists("results/"):
                    os.mkdir("results")
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)
        for i, (data, _) in enumerate(train_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                      recon_batch[:n].view(n,  dataset_config['channels'], dataset_config['size'], dataset_config['size'])[:n]])
                if not os.path.exists("overfit_results/"):
                    os.mkdir("overfit_results")
                save_image(comparison.cpu(),
                         'overfit_results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_recons_loss /= len(test_loader.dataset)
    test_kld_loss /= len(test_loader.dataset)
    print('====> Test set recons loss: {:.4f}, kld_loss: {:.4f}'.format(test_recons_loss, test_kld_loss))

In [18]:
log_eta = (-3., 0)
for epoch in range(1, 1000 + 1):
    log_eta = train(epoch, log_eta)
    if epoch % 20 == 0:
        test(epoch)

====> Epoch: 1 Average recons loss: 21.3870, kld loss: 41.7312
====> Test set recons loss: 41.4171, kld_loss: 87.0031
====> Epoch: 2 Average recons loss: 19.1828, kld loss: 40.7559
====> Test set recons loss: 39.4873, kld_loss: 84.8462
====> Epoch: 3 Average recons loss: 18.7149, kld loss: 39.7263
====> Test set recons loss: 38.0889, kld_loss: 82.6257
====> Epoch: 4 Average recons loss: 18.0295, kld loss: 38.6757
====> Test set recons loss: 37.3324, kld_loss: 80.4146
====> Epoch: 5 Average recons loss: 17.0175, kld loss: 37.6271
====> Test set recons loss: 36.3341, kld_loss: 78.1668
====> Epoch: 6 Average recons loss: 16.1395, kld loss: 36.6148
====> Test set recons loss: 34.2342, kld_loss: 76.1597
====> Epoch: 7 Average recons loss: 15.9862, kld loss: 35.6100
====> Test set recons loss: 33.9964, kld_loss: 73.8916
====> Epoch: 8 Average recons loss: 15.4244, kld loss: 34.6070
====> Test set recons loss: 33.5881, kld_loss: 71.8914
====> Epoch: 9 Average recons loss: 14.3608, kld loss: 3

====> Test set recons loss: 27.3322, kld_loss: 62.9920
====> Epoch: 32 Average recons loss: 6.2441, kld loss: 29.9525
====> Test set recons loss: 29.6751, kld_loss: 63.2534
====> Epoch: 33 Average recons loss: 6.0061, kld loss: 30.0889
====> Test set recons loss: 28.0845, kld_loss: 63.5753
====> Epoch: 34 Average recons loss: 5.9269, kld loss: 30.2944
====> Epoch: 34 Eta is now 0.0639 due to sufficient recons results
====> Test set recons loss: 28.1702, kld_loss: 64.0722
====> Epoch: 35 Average recons loss: 5.6254, kld loss: 30.4129
====> Test set recons loss: 29.8278, kld_loss: 64.2140
====> Epoch: 36 Average recons loss: 5.7992, kld loss: 30.4543
====> Test set recons loss: 29.2006, kld_loss: 64.1615
====> Epoch: 37 Average recons loss: 5.5639, kld loss: 30.4488
====> Test set recons loss: 29.8351, kld_loss: 64.1215
====> Epoch: 38 Average recons loss: 5.3326, kld loss: 30.3966
====> Test set recons loss: 29.8966, kld_loss: 64.0161
====> Epoch: 39 Average recons loss: 5.0843, kld los

====> Test set recons loss: 30.8820, kld_loss: 55.2483
====> Epoch: 62 Average recons loss: 3.3892, kld loss: 26.2661
====> Test set recons loss: 30.7893, kld_loss: 54.7038
====> Epoch: 63 Average recons loss: 3.3724, kld loss: 26.0615
====> Test set recons loss: 31.8315, kld_loss: 54.4563
====> Epoch: 64 Average recons loss: 3.3228, kld loss: 25.9224
====> Epoch: 64 Eta is now 0.1353 due to sufficient recons results
====> Test set recons loss: 30.7801, kld_loss: 54.2560
====> Epoch: 65 Average recons loss: 3.3032, kld loss: 25.8023
====> Test set recons loss: 29.6334, kld_loss: 53.6826
====> Epoch: 66 Average recons loss: 3.2674, kld loss: 25.5550
====> Test set recons loss: 31.2618, kld_loss: 53.0437
====> Epoch: 67 Average recons loss: 3.2454, kld loss: 25.2482
====> Test set recons loss: 30.6545, kld_loss: 52.5295
====> Epoch: 68 Average recons loss: 3.2332, kld loss: 24.9345
====> Test set recons loss: 33.6660, kld_loss: 51.7080
====> Epoch: 69 Average recons loss: 3.1588, kld los

====> Epoch: 92 Average recons loss: 3.0202, kld loss: 19.2626
====> Test set recons loss: 33.4273, kld_loss: 39.8185
====> Epoch: 93 Average recons loss: 2.9537, kld loss: 19.2875
====> Test set recons loss: 33.0197, kld_loss: 39.4449
====> Epoch: 94 Average recons loss: 2.9361, kld loss: 19.1298
====> Epoch: 94 Eta is now 0.2865 due to sufficient recons results
====> Test set recons loss: 32.0820, kld_loss: 39.0583
====> Epoch: 95 Average recons loss: 2.9726, kld loss: 18.7877
====> Test set recons loss: 34.1469, kld_loss: 38.1966
====> Epoch: 96 Average recons loss: 2.9491, kld loss: 18.3757
====> Test set recons loss: 34.2506, kld_loss: 36.9494
====> Epoch: 97 Average recons loss: 3.1557, kld loss: 17.8594
====> Test set recons loss: 36.4254, kld_loss: 36.4540
====> Epoch: 98 Average recons loss: 3.1143, kld loss: 17.6415
====> Test set recons loss: 35.9125, kld_loss: 35.6663
====> Epoch: 99 Average recons loss: 3.0651, kld loss: 17.4395
====> Test set recons loss: 34.6496, kld_los

====> Test set recons loss: 35.9524, kld_loss: 25.5172
====> Epoch: 122 Average recons loss: 3.7119, kld loss: 12.7791
====> Test set recons loss: 38.1531, kld_loss: 25.0911
====> Epoch: 123 Average recons loss: 3.7070, kld loss: 12.7635
====> Test set recons loss: 39.3326, kld_loss: 24.8972
====> Epoch: 124 Average recons loss: 3.5488, kld loss: 12.6803
====> Epoch: 124 Eta is now 0.6065 due to sufficient recons results
====> Test set recons loss: 35.7266, kld_loss: 24.7925
====> Epoch: 125 Average recons loss: 3.8335, kld loss: 12.4146
====> Test set recons loss: 36.8875, kld_loss: 23.8481
====> Epoch: 126 Average recons loss: 3.7490, kld loss: 12.0507
====> Test set recons loss: 38.3602, kld_loss: 22.8909
====> Epoch: 127 Average recons loss: 3.8091, kld loss: 11.5729
====> Test set recons loss: 35.3914, kld_loss: 22.2833
====> Epoch: 128 Average recons loss: 4.1233, kld loss: 11.2881
====> Test set recons loss: 38.2472, kld_loss: 21.8438
====> Epoch: 129 Average recons loss: 4.1309

====> Epoch: 151 Average recons loss: 6.0472, kld loss: 6.4058
====> Test set recons loss: 40.0798, kld_loss: 11.2308
====> Epoch: 152 Average recons loss: 6.2853, kld loss: 6.2894
====> Test set recons loss: 39.6479, kld_loss: 11.1621
====> Epoch: 153 Average recons loss: 5.9316, kld loss: 6.3025
====> Test set recons loss: 40.0136, kld_loss: 10.5009
====> Epoch: 154 Average recons loss: 5.9754, kld loss: 5.9924
====> Epoch: 154 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 37.7767, kld_loss: 10.9044
====> Epoch: 155 Average recons loss: 5.9735, kld loss: 5.9522
====> Test set recons loss: 39.8249, kld_loss: 9.1550
====> Epoch: 156 Average recons loss: 6.3868, kld loss: 5.8810
====> Test set recons loss: 39.4549, kld_loss: 10.1456
====> Epoch: 157 Average recons loss: 5.9355, kld loss: 5.6749
====> Test set recons loss: 42.4503, kld_loss: 10.0729
====> Epoch: 158 Average recons loss: 6.4425, kld loss: 5.6266
====> Test set recons loss: 39.0823, kld_los

====> Test set recons loss: 42.7647, kld_loss: 5.4456
====> Epoch: 181 Average recons loss: 5.3623, kld loss: 4.1080
====> Test set recons loss: 38.8867, kld_loss: 7.1072
====> Epoch: 182 Average recons loss: 5.2942, kld loss: 4.3428
====> Test set recons loss: 42.9883, kld_loss: 5.8508
====> Epoch: 183 Average recons loss: 5.5837, kld loss: 4.3349
====> Test set recons loss: 38.5147, kld_loss: 7.1850
====> Epoch: 184 Average recons loss: 4.9617, kld loss: 4.3319
====> Epoch: 184 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 39.6972, kld_loss: 5.3662
====> Epoch: 185 Average recons loss: 5.7891, kld loss: 4.2562
====> Test set recons loss: 45.6179, kld_loss: 6.5473
====> Epoch: 186 Average recons loss: 5.1095, kld loss: 4.3500
====> Test set recons loss: 41.2171, kld_loss: 5.6124
====> Epoch: 187 Average recons loss: 5.0265, kld loss: 4.2238
====> Test set recons loss: 41.8695, kld_loss: 6.7813
====> Epoch: 188 Average recons loss: 5.2059, kld loss: 4.2

====> Epoch: 210 Average recons loss: 4.7185, kld loss: 3.9277
====> Test set recons loss: 42.6193, kld_loss: 4.4340
====> Epoch: 211 Average recons loss: 4.8607, kld loss: 3.5070
====> Test set recons loss: 40.7829, kld_loss: 4.3839
====> Epoch: 212 Average recons loss: 4.7402, kld loss: 3.6114
====> Test set recons loss: 40.5084, kld_loss: 4.6149
====> Epoch: 213 Average recons loss: 4.5253, kld loss: 3.4827
====> Test set recons loss: 41.5631, kld_loss: 3.9146
====> Epoch: 214 Average recons loss: 4.9262, kld loss: 3.5408
====> Epoch: 214 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 42.9383, kld_loss: 4.8977
====> Epoch: 215 Average recons loss: 4.4152, kld loss: 3.5452
====> Test set recons loss: 41.9200, kld_loss: 3.7830
====> Epoch: 216 Average recons loss: 4.6131, kld loss: 3.3911
====> Test set recons loss: 43.8521, kld_loss: 3.8440
====> Epoch: 217 Average recons loss: 4.5214, kld loss: 3.2457
====> Test set recons loss: 40.9932, kld_loss: 3.7

====> Test set recons loss: 36.9243, kld_loss: 2.8301
====> Epoch: 240 Average recons loss: 4.3437, kld loss: 3.0376
====> Test set recons loss: 42.7992, kld_loss: 4.5228
====> Epoch: 241 Average recons loss: 4.2123, kld loss: 3.1701
====> Test set recons loss: 41.0085, kld_loss: 2.4592
====> Epoch: 242 Average recons loss: 4.4299, kld loss: 3.0031
====> Test set recons loss: 47.0250, kld_loss: 3.8903
====> Epoch: 243 Average recons loss: 4.3716, kld loss: 3.1144
====> Test set recons loss: 38.3472, kld_loss: 3.0580
====> Epoch: 244 Average recons loss: 3.9408, kld loss: 3.0097
====> Epoch: 244 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 40.0597, kld_loss: 3.4461
====> Epoch: 245 Average recons loss: 4.1309, kld loss: 2.9837
====> Test set recons loss: 41.8713, kld_loss: 2.8772
====> Epoch: 246 Average recons loss: 4.3045, kld loss: 2.9151
====> Test set recons loss: 40.9215, kld_loss: 3.7472
====> Epoch: 247 Average recons loss: 3.9620, kld loss: 3.1

====> Test set recons loss: 40.9883, kld_loss: 2.6161
====> Epoch: 270 Average recons loss: 3.6538, kld loss: 2.5985
====> Test set recons loss: 40.0450, kld_loss: 2.3527
====> Epoch: 271 Average recons loss: 4.1263, kld loss: 2.5618
====> Test set recons loss: 42.8865, kld_loss: 3.4498
====> Epoch: 272 Average recons loss: 3.8781, kld loss: 2.8613
====> Test set recons loss: 42.7443, kld_loss: 2.8796
====> Epoch: 273 Average recons loss: 3.7021, kld loss: 2.7354
====> Test set recons loss: 40.1151, kld_loss: 2.5022
====> Epoch: 274 Average recons loss: 3.9769, kld loss: 2.6471
====> Epoch: 274 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 39.2722, kld_loss: 2.8293
====> Epoch: 275 Average recons loss: 3.5340, kld loss: 2.7206
====> Test set recons loss: 38.8918, kld_loss: 2.4490
====> Epoch: 276 Average recons loss: 3.7114, kld loss: 2.4859
====> Test set recons loss: 40.4723, kld_loss: 2.5187
====> Epoch: 277 Average recons loss: 3.8146, kld loss: 2.5

====> Test set recons loss: 39.9497, kld_loss: 2.2162
====> Epoch: 300 Average recons loss: 3.2375, kld loss: 2.4397
====> Test set recons loss: 42.6685, kld_loss: 2.0365
====> Epoch: 301 Average recons loss: 3.5863, kld loss: 2.2389
====> Test set recons loss: 40.7913, kld_loss: 1.9934
====> Epoch: 302 Average recons loss: 3.5060, kld loss: 2.3952
====> Test set recons loss: 39.7630, kld_loss: 2.3102
====> Epoch: 303 Average recons loss: 3.4801, kld loss: 2.5415
====> Test set recons loss: 39.3526, kld_loss: 2.1026
====> Epoch: 304 Average recons loss: 3.3122, kld loss: 2.2295
====> Epoch: 304 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 42.2282, kld_loss: 1.8335
====> Epoch: 305 Average recons loss: 3.4058, kld loss: 2.2725
====> Test set recons loss: 44.8963, kld_loss: 2.1950
====> Epoch: 306 Average recons loss: 3.2090, kld loss: 2.3507
====> Test set recons loss: 41.2656, kld_loss: 1.7169
====> Epoch: 307 Average recons loss: 3.4447, kld loss: 2.1

====> Test set recons loss: 43.5685, kld_loss: 2.1139
====> Epoch: 330 Average recons loss: 3.3726, kld loss: 2.2815
====> Test set recons loss: 38.3927, kld_loss: 1.8513
====> Epoch: 331 Average recons loss: 3.6200, kld loss: 2.4124
====> Test set recons loss: 42.1340, kld_loss: 2.4600
====> Epoch: 332 Average recons loss: 3.3074, kld loss: 2.6311
====> Test set recons loss: 37.9223, kld_loss: 2.3361
====> Epoch: 333 Average recons loss: 3.3011, kld loss: 2.4177
====> Test set recons loss: 42.8225, kld_loss: 2.0843
====> Epoch: 334 Average recons loss: 3.0806, kld loss: 2.4179
====> Epoch: 334 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 44.2123, kld_loss: 2.1889
====> Epoch: 335 Average recons loss: 3.7782, kld loss: 2.3726
====> Test set recons loss: 40.8942, kld_loss: 2.4226
====> Epoch: 336 Average recons loss: 2.9744, kld loss: 2.5747
====> Test set recons loss: 40.7271, kld_loss: 2.1201
====> Epoch: 337 Average recons loss: 3.6297, kld loss: 2.3

====> Test set recons loss: 42.9228, kld_loss: 2.0567
====> Epoch: 360 Average recons loss: 2.9516, kld loss: 2.1803
====> Test set recons loss: 39.8106, kld_loss: 1.9193
====> Epoch: 361 Average recons loss: 3.1496, kld loss: 2.1056
====> Test set recons loss: 42.3718, kld_loss: 2.0640
====> Epoch: 362 Average recons loss: 3.1231, kld loss: 2.1668
====> Test set recons loss: 37.6985, kld_loss: 2.0276
====> Epoch: 363 Average recons loss: 2.7176, kld loss: 2.2884
====> Test set recons loss: 40.0338, kld_loss: 1.8702
====> Epoch: 364 Average recons loss: 3.2007, kld loss: 2.0115
====> Epoch: 364 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 40.9200, kld_loss: 1.9532
====> Epoch: 365 Average recons loss: 2.8671, kld loss: 2.2950
====> Test set recons loss: 40.9928, kld_loss: 2.0640
====> Epoch: 366 Average recons loss: 3.0112, kld loss: 2.0811
====> Test set recons loss: 45.7503, kld_loss: 1.8475
====> Epoch: 367 Average recons loss: 2.8772, kld loss: 2.0

====> Test set recons loss: 38.8850, kld_loss: 1.7529
====> Epoch: 390 Average recons loss: 2.6803, kld loss: 1.9559
====> Test set recons loss: 42.7978, kld_loss: 1.7432
====> Epoch: 391 Average recons loss: 2.7603, kld loss: 2.0615
====> Test set recons loss: 44.3822, kld_loss: 1.8872
====> Epoch: 392 Average recons loss: 2.7954, kld loss: 2.0197
====> Test set recons loss: 44.9315, kld_loss: 1.8992
====> Epoch: 393 Average recons loss: 2.7870, kld loss: 1.8745
====> Test set recons loss: 39.7616, kld_loss: 1.8728
====> Epoch: 394 Average recons loss: 2.7040, kld loss: 2.0796
====> Epoch: 394 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 45.1947, kld_loss: 1.7803
====> Epoch: 395 Average recons loss: 2.9464, kld loss: 1.9354
====> Test set recons loss: 38.7356, kld_loss: 1.8228
====> Epoch: 396 Average recons loss: 2.8256, kld loss: 2.0028
====> Test set recons loss: 41.1368, kld_loss: 1.8498
====> Epoch: 397 Average recons loss: 2.7446, kld loss: 2.0

====> Test set recons loss: 41.7056, kld_loss: 1.8632
====> Epoch: 420 Average recons loss: 2.6680, kld loss: 1.9688
====> Test set recons loss: 41.8257, kld_loss: 1.9242
====> Epoch: 421 Average recons loss: 2.7547, kld loss: 1.9821
====> Test set recons loss: 40.8772, kld_loss: 1.9060
====> Epoch: 422 Average recons loss: 2.5304, kld loss: 2.1186
====> Test set recons loss: 44.7621, kld_loss: 1.9294
====> Epoch: 423 Average recons loss: 2.5751, kld loss: 1.9816
====> Test set recons loss: 38.4895, kld_loss: 1.9634
====> Epoch: 424 Average recons loss: 2.5626, kld loss: 1.9365
====> Epoch: 424 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 41.8424, kld_loss: 1.8442
====> Epoch: 425 Average recons loss: 2.6138, kld loss: 1.9833
====> Test set recons loss: 38.9163, kld_loss: 1.8799
====> Epoch: 426 Average recons loss: 2.5803, kld loss: 1.9155
====> Test set recons loss: 45.1500, kld_loss: 1.9077
====> Epoch: 427 Average recons loss: 2.5253, kld loss: 1.9

====> Test set recons loss: 44.8156, kld_loss: 1.8758
====> Epoch: 450 Average recons loss: 2.4770, kld loss: 1.9676
====> Test set recons loss: 43.2628, kld_loss: 1.8581
====> Epoch: 451 Average recons loss: 2.2834, kld loss: 1.8800
====> Test set recons loss: 41.7503, kld_loss: 2.1835
====> Epoch: 452 Average recons loss: 2.4885, kld loss: 1.8016
====> Test set recons loss: 49.4168, kld_loss: 1.8215
====> Epoch: 453 Average recons loss: 2.4229, kld loss: 1.8841
====> Test set recons loss: 46.0556, kld_loss: 1.8201
====> Epoch: 454 Average recons loss: 2.3422, kld loss: 1.8123
====> Epoch: 454 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 42.9138, kld_loss: 1.8889
====> Epoch: 455 Average recons loss: 2.4225, kld loss: 1.8135
====> Test set recons loss: 43.4689, kld_loss: 1.8190
====> Epoch: 456 Average recons loss: 2.2734, kld loss: 1.9183
====> Test set recons loss: 41.2559, kld_loss: 1.8343
====> Epoch: 457 Average recons loss: 2.3903, kld loss: 1.8

====> Test set recons loss: 42.9613, kld_loss: 1.9616
====> Epoch: 480 Average recons loss: 2.5285, kld loss: 1.9981
====> Test set recons loss: 43.5024, kld_loss: 1.9444
====> Epoch: 481 Average recons loss: 2.2253, kld loss: 1.9091
====> Test set recons loss: 43.4638, kld_loss: 2.2716
====> Epoch: 482 Average recons loss: 2.5876, kld loss: 1.7849
====> Test set recons loss: 45.5515, kld_loss: 1.8997
====> Epoch: 483 Average recons loss: 2.1398, kld loss: 2.0073
====> Test set recons loss: 40.8408, kld_loss: 1.9142
====> Epoch: 484 Average recons loss: 2.3300, kld loss: 1.7807
====> Epoch: 484 Eta is now 1.0000 due to sufficient recons results
====> Test set recons loss: 42.0266, kld_loss: 2.5592
====> Epoch: 485 Average recons loss: 2.5547, kld loss: 1.7398
====> Test set recons loss: 43.6195, kld_loss: 1.8056
====> Epoch: 486 Average recons loss: 2.2008, kld loss: 2.0736
====> Test set recons loss: 43.5839, kld_loss: 1.8793
====> Epoch: 487 Average recons loss: 2.2793, kld loss: 1.8

KeyboardInterrupt: 

In [None]:
1048576/128/128

In [None]:
model.train()
train_loss = 0
for batch_idx,( data, _) in enumerate(train_loader):
    data = data.to(device)
    data = data.reshape([-1,3, 128*128])
    optimizer.zero_grad()
    recon_batch, mu, logvar = model(data)
    loss = loss_function(recon_batch, data, mu, logvar)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()
    if batch_idx % 100 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader),
            loss.item() / len(data)))