In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import PIL

import torchvision as tv
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary

from pushover import notify
from utils import makegif
from random import randint

from IPython.display import Image
from IPython.core.display import Image, display
import numpy as np

%load_ext autoreload
%autoreload 2

In [3]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
bs = 32 # batchsize 32

In [5]:
# Load Data

def print_gt_zero_elem(matrix):
    print(matrix[matrix > 0])

def print_elems_gt(matrix, val):
    print(matrix[matrix > val])
    
def load_img(img_path):
    img = PIL.Image.open(img_path)
    img = img.convert(mode="L")
#     npimg = np.array(img)/256.0
    return img

IMRANGE = 256 # uint8

# dataset = datasets.ImageFolder(root='trainings/rolls_gray', transform=transforms.Compose([
#     transforms.ToTensor(), 
#     lambda x: (x > 0).type(torch.FloatTensor) + (x == 0).type(torch.FloatTensor) * -1,
# ]), loader=load_img)

dataset = datasets.ImageFolder(root='trainings/rolls_gray', transform=transforms.Compose([
    transforms.ToTensor(), 
    lambda x: x > 0,
    lambda x: x.type(torch.FloatTensor),
]), loader=load_img)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=bs, shuffle=True)
len(dataset.imgs), len(dataloader)

# size of input = 3 x 128 x 128

(2297, 72)

In [6]:
# Fixed input for debugging
fixed_x, _ = next(iter(dataloader))
save_image(fixed_x, 'outputs/real_image.png')

# Image('outputs/real_image.png')

In [7]:
print(dataset[0][0].shape)
sample_dat = dataset[0][0]
print(sample_dat)
HSIZE = 2048 #9216 # 1024
ZDIM = 32

torch.Size([1, 128, 128])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])


In [8]:
class Flatten(nn.Module):
    def forward(self, input):
#         print("flatten: ", input.shape)
        return input.view(input.size(0), -1)

In [9]:
class UnFlatten(nn.Module):
    def forward(self, input, size=HSIZE):
        return input.view(input.size(0), size, 1, 1)

In [46]:
class VAE(nn.Module):
    def __init__(self, image_channels=1, h_dim=HSIZE, z_dim=ZDIM):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=4, stride=2), # -> [32, 32, 31, 31] 63
            nn.BatchNorm2d(32),
            nn.Tanh(),
#             nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # -> [32, 64, 14, 14] 31
            nn.BatchNorm2d(64),
            nn.Tanh(),
#             nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2), # -> [32, 128, 6, 6] 14
            nn.BatchNorm2d(128),
            nn.Tanh(),
#             nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2), # -> [32, 256, 2, 2] 6
            nn.BatchNorm2d(256),
            nn.Tanh(),
#             nn.ReLU(), 
            nn.Conv2d(256, 512, kernel_size=4, stride=2), # -> Null -> [32, 512, 2, 2] 
            nn.BatchNorm2d(512),
            nn.Tanh(),
#             nn.ReLU(), 
            Flatten() # -> [32, 1024]  -> [32, 2048]
#             # [32, a, b, c] -> [32, abc]
        )
#         self.relu = nn.Sequential(nn.Tanh())
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            UnFlatten(), 
            nn.ConvTranspose2d(h_dim, 256 , kernel_size=5, stride=2),
            nn.BatchNorm2d(256),
            nn.Tanh(),
#             nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=2),
            nn.BatchNorm2d(128),
            nn.Tanh(),
#             nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
            nn.BatchNorm2d(64),
            nn.Tanh(),
#             nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
            nn.BatchNorm2d(32),
            nn.Tanh(),
#             nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
            nn.BatchNorm2d(image_channels),
            nn.Sigmoid(),
        )
#         self.decoder = nn.Sequential(
#             UnFlatten(),
#             nn.ConvTranspose2d(h_dim, 128, kernel_size=5, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(64, 32, kernel_size=6, stride=2),
#             nn.ReLU(),
#             nn.ConvTranspose2d(32, image_channels, kernel_size=6, stride=2),
#             nn.Sigmoid(),
#         )
        
    def reparameterize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # return torch.normal(mu, std)
        esp = torch.randn(*mu.size())
        print("std: ", std, "\tesp: ", esp)
        z = mu + std * esp
        print("mu: ", mu, "\n=> z = mu + std * esp:", z)
        print("shape mu,z = ", mu.shape, z.shape)
        return z
    
    def bottleneck(self, h):
        mu, logvar = self.fc1(h), self.fc2(h)
#         print("bottle: ",mu.shape, logvar.shape)
        z = self.reparameterize(mu, logvar)
        return z, mu, logvar

    def encode(self, x):
        print("======== Encode ========", x.shape)
        h = self.encoder(x)
#         print("enc(x): ", h.shape, print_elems_gt(h, 1))
#         print("relu(enc): ", self.relu(h))
        z, mu, logvar = self.bottleneck(h)
        print("mu.size() ", mu.size())
        return z, mu, logvar

    def decode(self, z):
#         print("======== Decode ========", z.shape)
        z = self.fc3(z)
#         print("fc3(z).shape: ", z.shape)
        z = self.decoder(z)
#         print("decode(fc3(z)).shape: ", z.shape)
        return z

    def forward(self, x):
        z, mu, logvar = self.encode(x)
#         print(z.shape)
        z = self.decode(z)
#         print(z.shape, mu.shape, logvar.shape)
        return z, mu, logvar

In [47]:
image_channels = fixed_x.size(1)

In [48]:
vae = VAE(image_channels=image_channels).to(device)
# model.load_state_dict(torch.load('vae.torch', map_location='cpu'))

In [49]:
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

In [50]:
def loss_fn(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
#     BCE = F.mse_loss(recon_x, x, size_average=False)

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD, BCE, KLD

In [51]:
epochs = 100

In [52]:
model_name = "graybin_bce_d16_bn"

In [53]:
for epoch in range(epochs):
    for idx, (images, _) in enumerate(dataloader):
        recon_images, mu, logvar = vae(images)
#         print(images.shape)
        
        comimg = torch.cat([images * 256.0, recon_images * 256.0])
        sample_filename = 'tmp/sample_comp_image.png'
        save_image(comimg.data.cpu(), sample_filename)
        
#         display(Image(sample_filename, width=300, unconfined=True))
        
        loss, bce, kld = loss_fn(recon_images, images, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
#         to_print = "Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.3f}".format(epoch+1, 
#                                 epochs, loss.data[0]/bs, bce.data[0]/bs, kld.data[0]/bs)
        to_print = "Epoch[{}/{}] Loss: {:.3f} {:.3f} {:.3f}".format(epoch+1, 
                                epochs, loss.data/bs, bce.data/bs, kld.data/bs)
    if epoch % 10 == 0 and epoch != 0:
        torch.save(vae.state_dict(), 'models/intermediate/cvae.{}-imgs_{}-epch_{}-{}'.format(model_name, len(dataset.imgs), epoch, epochs))
        
    print(to_print, loss.data, kld.data, bs)

# notify to android when finished training
notify(to_print, priority=1)

std:  tensor([[0.9576, 0.8834, 0.9826,  ..., 0.9151, 0.8702, 1.2170],
        [0.9193, 0.8649, 0.8760,  ..., 0.9361, 1.0872, 1.1162],
        [0.7559, 1.0560, 1.0645,  ..., 0.9459, 1.0682, 0.9025],
        ...,
        [1.0441, 0.9187, 0.7470,  ..., 0.9651, 0.9150, 0.9030],
        [1.0388, 0.8884, 0.8604,  ..., 0.9838, 1.2044, 1.1117],
        [1.0634, 0.8387, 0.8119,  ..., 0.9083, 0.7913, 0.8388]],
       grad_fn=<ExpBackward>) 	esp:  tensor([[ 1.4121, -0.1478, -1.4919,  ...,  0.1584,  0.7903, -0.0190],
        [-0.5662,  0.4259,  1.0544,  ..., -0.2292, -1.2518, -0.9284],
        [-0.2372,  0.4088, -1.5974,  ..., -1.0884,  0.2568, -1.7883],
        ...,
        [-1.0812,  1.5730, -1.1830,  ..., -0.0979, -0.5525,  0.4980],
        [ 1.5873, -0.4787, -1.5701,  ...,  1.0848,  0.1791, -0.4381],
        [ 1.7430, -0.3890,  0.7490,  ..., -0.1914,  0.3080, -1.7798]])
mu:  tensor([[ 0.0473,  0.2774,  0.1194,  ..., -0.1278, -0.5172,  0.0457],
        [ 0.0497, -0.1898,  0.0662,  ...,  0.6348,



std:  tensor([[0.8895, 0.9971, 1.0172,  ..., 0.9890, 0.9570, 1.2226],
        [0.7990, 1.0892, 1.0890,  ..., 1.0392, 1.0850, 0.9512],
        [0.7370, 1.1028, 1.0664,  ..., 1.1567, 1.2043, 0.8038],
        ...,
        [1.0461, 0.8029, 1.0162,  ..., 1.1279, 0.8691, 1.0710],
        [1.0690, 0.8676, 0.8821,  ..., 0.9626, 1.1363, 1.1375],
        [0.9405, 0.8469, 0.7433,  ..., 1.2200, 1.0602, 1.2639]],
       grad_fn=<ExpBackward>) 	esp:  tensor([[-0.9862,  1.2885,  0.3729,  ..., -0.1093, -1.3005,  0.6324],
        [ 1.2137, -1.3874,  0.9243,  ...,  1.7252, -0.2344,  0.2957],
        [-0.5977, -1.6963,  0.8689,  ...,  0.3904,  0.1364, -0.3610],
        ...,
        [-1.7018, -1.2100, -0.3658,  ...,  1.9016,  0.5260,  0.3705],
        [-2.0222, -0.6972,  0.1786,  ...,  1.3345,  1.2116,  0.9088],
        [-0.0176, -0.3520,  0.9913,  ..., -1.7305, -0.2471,  0.6931]])
mu:  tensor([[ 0.3612,  0.4451,  0.2418,  ..., -0.2712,  0.0520, -0.0988],
        [-0.0382,  0.5478,  0.6029,  ...,  0.0957,

KeyboardInterrupt: 

In [None]:
torch.save(vae.state_dict(), 'models/cvae.{}-imgs_{}-epch_{}'.format(model_name, len(dataset.imgs), epochs))