In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import PIL

import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary

from pushover import notify
from utils import makegif
from random import randint

from IPython.display import Image
from IPython.core.display import Image, display

%load_ext autoreload
%autoreload 2

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [3]:
bs = 32 # batchsize

In [4]:
# Load Data
dataset = datasets.ImageFolder(root='trainings/roll_imgs_partial', transform=transforms.Compose([
#     transforms.Resize(64),
    transforms.ToTensor(), 
]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True)
len(dataset.imgs), len(dataloader)
# size of input = 3 x 128 x 128

(4312, 135)

In [5]:
# Fixed input for debugging
fixed_x, _ = next(iter(dataloader))
save_image(fixed_x, 'outputs/real_image.png')

# Image('outputs/real_image.png')

In [6]:
print(dataset[1][0].shape)
HSIZE = 2048 #9216 # 1024
ZDIM =  32

torch.Size([3, 128, 128])


In [7]:
class Flatten(nn.Module):
    def forward(self, input):
#         print("flatten: ", input.shape)
        return input.view(input.size(0), -1)

In [8]:
class UnFlatten(nn.Module):
    def forward(self, input, size=HSIZE):
        return input.view(input.size(0), size, 1, 1)

In [9]:
class VAE(nn.Module):
    def __init__(self, image_channels=3, h_dim=HSIZE, z_dim=ZDIM):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2), # -> [32, 32, 31, 31] 63
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # -> [32, 64, 14, 14] 31
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), # -> [32, 128, 6, 6] 14
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2), # -> [32, 256, 2, 2] 6
            nn.ReLU(), 
            nn.Conv2d(256, 512, kernel_size=4, stride=2), # -> Null -> [32, 512, 2, 2] 
            nn.ReLU(), 
            Flatten() # -> [32, 1024]  -> [32, 2048]
            # [32, a, b, c] -> [32, abc]
        )
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            UnFlatten(), 
            nn.ConvTranspose2d(h_dim, 256 , kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.Sigmoid(),
        )
#         self.decoder = nn.Sequential(
#             UnFlatten(),
#             nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
#             nn.Sigmoid(),
#         )
        
    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # return torch.normal(mu, std)
        esp = torch.randn(*mu.size())
        z = mu + std * esp
        return z
    
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
#         print("bottle: ",mu.shape, logvar.shape)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
#         print("======== Encode ========", x.shape)
        h = self.encoder(x)
#         print("enc(x): ", h.shape)
        z, mu, logvar = self.bottleneck(h)
#         print("z.shape: ", z.shape)
        return z, mu, logvar

    def decode(self, z):
#         print("======== Decode ========", z.shape)
        z = self.fc3(z)
#         print("fc3(z).shape: ", z.shape)
        z = self.decoder(z)
#         print("decode(fc3(z)).shape: ", z.shape)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
#         print(z.shape)
        z = self.decode(z)
#         print(z.shape, mu.shape, logvar.shape)
        return z, mu, logvar

In [10]:
image_channels = fixed_x.size(1)

In [11]:
vae = VAE(image_channels=image_channels).to(device)
# model.load_state_dict(torch.load('vae.torch', map_location='cpu'))

In [12]:
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

In [13]:
def loss_fn(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
    # BCE = F.mse_loss(recon_x, x, size_average=False)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD, BCE, KLD

In [14]:
epochs = 50

In [15]:
for epoch in range(epochs):
    for idx, (images, _) in enumerate(dataloader):
        recon_images, mu, logvar = vae(images)
        loss, bce, kld = loss_fn(recon_images, images, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

#         to_print = "Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.3f}".format(epoch+1, 
#                                 epochs, loss.data[0]/bs, bce.data[0]/bs, kld.data[0]/bs)
        to_print = "Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.3f}".format(epoch+1, 
                                epochs, loss.data/bs, bce.data/bs, kld.data/bs)
    print(to_print)

# notify to android when finished training
notify(to_print, priority=1)



Epoch[1/50] Loss: 2817.555 2817.400 0.155
Epoch[2/50] Loss: 2247.128 2246.914 0.214
Epoch[3/50] Loss: 2860.630 2860.391 0.239
Epoch[4/50] Loss: 2614.026 2613.764 0.262
Epoch[5/50] Loss: 1890.036 1889.761 0.275
Epoch[6/50] Loss: 2108.351 2108.114 0.237
Epoch[7/50] Loss: 1815.755 1815.517 0.238
Epoch[8/50] Loss: 1522.818 1522.577 0.242
Epoch[9/50] Loss: 1538.842 1538.586 0.256
Epoch[10/50] Loss: 1493.819 1493.569 0.249
Epoch[11/50] Loss: 1467.771 1467.495 0.276
Epoch[12/50] Loss: 1221.451 1221.198 0.254
Epoch[13/50] Loss: 1108.783 1108.529 0.254
Epoch[14/50] Loss: 986.299 985.997 0.302
Epoch[15/50] Loss: 990.037 989.733 0.304
Epoch[16/50] Loss: 1063.933 1063.649 0.284
Epoch[17/50] Loss: 728.666 728.360 0.306
Epoch[18/50] Loss: 884.410 884.070 0.340
Epoch[19/50] Loss: 525.903 525.550 0.352
Epoch[20/50] Loss: 570.109 569.754 0.355
Epoch[21/50] Loss: 639.068 638.721 0.346
Epoch[22/50] Loss: 648.377 648.040 0.338
Epoch[23/50] Loss: 595.755 595.399 0.356
Epoch[24/50] Loss: 548.748 548.372 0.3

In [16]:
torch.save(vae.state_dict(), 'models/vae.torch-alb-nimgs_{}-epochs_{}'.format(len(dataset.imgs), epochs))