In [5]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import os
from skimage import io, transform
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from skimage.io import imread
from torchsummary import summary
import pandas as pd
import time
from models import VAE_CNN
NAME = 'KL4z_on'
KL_PAR = 0.00001

TRAIN_ROOT = '../data/data-celeba/Train'
VAL_ROOT = '../data/data-celeba/Validation'
VAL_ROOT_FIX = '../data/data-celeba/Fix_sample'
BATCH_SIZE = 32
EPOCHS = 4
LOG_INTERVAL = 50
BOTTLENECK_SIZE = 512
SAVE_MODEL = '../models/' + NAME
SAVE_RESULTS = '../results/' + NAME
%mkdir {SAVE_RESULTS}

In [7]:
print('hi hi hi hi hi')

hi hi


In [2]:
no_cuda = False
seed = 1
cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
torch.manual_seed(seed)


transform_seq = [ transforms.Resize((200,163)), transforms.Pad(( 19, 0, 18, 0)),
                  transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]

train_loader_celeba = torch.utils.data.DataLoader(
    datasets.ImageFolder(TRAIN_ROOT, transform=transforms.Compose(transform_seq)),
    batch_size = BATCH_SIZE, shuffle=True, **kwargs)

val_loader_celeba = torch.utils.data.DataLoader(
    datasets.ImageFolder(VAL_ROOT, transform=transforms.Compose(transform_seq)),
    batch_size = BATCH_SIZE, shuffle=True, **kwargs)

val_loader_celeba_fix = torch.utils.data.DataLoader(
    datasets.ImageFolder(VAL_ROOT_FIX, transform=transforms.Compose(transform_seq)),
    batch_size = BATCH_SIZE, shuffle=False, **kwargs)

FileNotFoundError: [Errno 2] No such file or directory: 'data-celeba/Train'

In [18]:
val_loader_celeba.dataset

Dataset ImageFolder
    Number of datapoints: 40520
    Root location: data-celeba/Validation
    StandardTransform
Transform: Compose(
               Resize(size=(200, 163), interpolation=PIL.Image.BILINEAR)
               Pad(padding=(19, 0, 18, 0), fill=0, padding_mode=constant)
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )

In [4]:
data

Dataset ImageFolder
    Number of datapoints: 162078
    Root location: data-celeba/Train
    StandardTransform
Transform: Compose(
               Resize(size=(200, 163), interpolation=PIL.Image.BILINEAR)
               Pad(padding=(19, 0, 18, 0), fill=0, padding_mode=constant)
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )

In [5]:
class customLoss(nn.Module):
    def __init__(self):
        super(customLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction="mean")

    def forward(self, x_recon, x, mu, logvar):
        loss_MSE = self.mse_loss(x_recon, x)
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return loss_MSE, KL_PAR * loss_KLD

In [6]:
model = VAE_CNN(BOTTLENECK_SIZE).to(device)
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha= 0.9)
loss_custom = customLoss()

In [7]:
#summary(model, (3,200,200))

In [8]:
'''model.train()
train_loss = 0

data = batch.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_mse(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
'''

'model.train()\ntrain_loss = 0\n\ndata = batch.to(device)\noptimizer.zero_grad()\nrecon_batch, mu, logvar = model(data)\nloss = loss_mse(recon_batch, data, mu, logvar)\nloss.backward()\ntrain_loss += loss.item()\noptimizer.step()\n'

In [9]:
val_losses = []
train_losses = []

def train(epoch):
    epoch_start = 0
    model.train(True)
    train_loss = 0
    start = time.time()
    for batch_idx, (data, _) in enumerate(val_loader_celeba_fix): 
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
        loss = loss_mse + loss_kl
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            interval = time.time() - start
            start = time.time()
            epoch_start = epoch_start + interval
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss MSE: {:.6f} \tLoss KL: {:.6f} \tTime Interv: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader_celeba.dataset),
                       100. * batch_idx / len(train_loader_celeba),
                       loss_mse.item() / len(data), loss_kl.item() / len(data), interval))

    print('====> Epoch: {} Average loss: {:.4f} Elapsed Time: {:.6f}'.format(
        epoch, train_loss / len(train_loader_celeba.dataset), epoch_start))
    train_losses.append(train_loss / len(train_loader_celeba.dataset))


In [31]:
def test(epoch, sufix, train):
    model.train(train)
    test_loss = 0
    with torch.no_grad():
        
        for i, (data, _) in enumerate(val_loader_celeba):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
            loss = loss_mse + loss_kl
            test_loss += loss.item()
            if i < 10:
                n = min(data.size(0), 32)
                comparison = torch.cat([data[:32],
                                        recon_batch.view(32, 3, 200, 200)[:n]])
            if i > 10:
                save_image(comparison.cpu(), 
                           SAVE_RESULTS + '/reconstruction_' + str(epoch) + sufix + '.png', nrow=n, normalize=True)

        '''    
        for i, (data, _) in enumerate(val_loader_celeba_fix):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.view(7, 3, 200, 200)[:n]])
                save_image(comparison.cpu(),
                           SAVE_RESULTS + '/reconstruction_' + str(epoch) + sufix + '.png', nrow=n, normalize=True)

'''
    test_loss /= len(val_loader_celeba.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    val_losses.append(test_loss)

In [32]:
3840000/(3*200*200)

32.0

In [33]:
for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch, '_on', True)
    test(epoch, '_off', False)
    model.train(False)
    with torch.no_grad():
        sample = torch.randn(BATCH_SIZE, BOTTLENECK_SIZE).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(BATCH_SIZE, 3, 200, 200),
                   SAVE_RESULTS + '/sample_off_' + str(epoch) + '.png', normalize=True)
    model.train(True)
    with torch.no_grad():
        sample = torch.randn(BATCH_SIZE, BOTTLENECK_SIZE).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(BATCH_SIZE, 3, 200, 200),
                   SAVE_RESULTS + '/sample_on' + str(epoch) + '.png', normalize=True)

====> Epoch: 1 Average loss: 0.0000 Elapsed Time: 0.114999


KeyboardInterrupt: 

In [None]:
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, SAVE_FOLDER)

In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

In [None]:
%matplotlib inline
def show(img):
    npimg = img.detach().numpy()
    plt.figure(figsize=(30, 10))
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')

In [None]:
model.train(True)
test_loss = 0
for i, (data, _) in enumerate(val_loader_celeba):
    data = data.to(device)
    recon_batch, mu, logvar = model(data)
    loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
    loss = loss_mse + loss_kl
    test_loss += loss.item()
    if i == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n],
                                recon_batch.view(7, 3, 200, 200)[:n]])
        show(make_grid(comparison.cpu(), nrow=n, normalize=True))

In [None]:
loss

In [None]:
'''checkpoint = torch.load('model_save')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
'''


In [None]:
loss

In [None]:
'''model.train(True)
test_loss = 0
for i, (data, _) in enumerate(val_loader_celeba):
    data = data.to(device)
    recon_batch, mu, logvar = model(data)
    loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
    loss = loss_mse + loss_kl
    test_loss += loss.item()
    if i == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n],
                                recon_batch.view(7, 3, 200, 200)[:n]])
        show(make_grid(comparison.cpu(), nrow=n, normalize=True))'''

In [None]:
'''import random
x_test = []
plt.figure(figsize=(30,10))
num_figs = 1
plt.figure(figsize=(15,100))
plt.axis('off')
plt.imshow(make_grid(comparison.cpu(), nrow=n, normalize=True).detach().numpy())
'''
for i in range(num_figs):
    figure_Decoded = vae_2.predict(np.array([x_test[i].astype('float32')/127.5 -1]), batch_size = b_size)
    plt.axis('off')
    plt.subplot(num_figs,2,1+i*2)
    plt.imshow(x_test[i])
    plt.axis('off')
    plt.subplot(num_figs,2,2 + i*2)
    plt.imshow((figure_Decoded[0]+1)/2)
    plt.axis('off')
'''
plt.show()'''

In [None]:
'''torch.save({
            'epoch': 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, "model_save")'''