In [1]:
from bokeh.plotting import figure, show, output_file
from bokeh.io import output_notebook
import bokeh
import numpy as np
print('bokeh: '+bokeh.__version__)
output_notebook()


bokeh: 1.4.0


In [6]:
from toydata import *
#DISTR = "Gauss"
#DISTR = "GMM"
#DISTR = "Square"
DISTR = "Banana"
#DISTR = "2Bananas"
#DISTR = "K-2Bananas"

#PDF = True
PDF = False
#PDF_scale = "log"
PDF_scale = "normal" # can be anything
 
SAMPLING = True
#SAMPLING = False
N_SAMPLES = 50

if True:
    INV_COL = False
    PAL = "Viridis8"
else:
    INV_COL = True
    PAL = "Greys8"


IMAGE_STEP = 0.02
SIZE = 10.0
    


def draw_distribution(distr_type = "Banana", draw_pdf = True, draw_samples = True, pdf_scale = "log"):
    print(distr_type)
    XMIN, XMAX = YMIN, YMAX = -SIZE, SIZE
    x = np.arange(XMIN, XMAX, IMAGE_STEP)
    y = np.arange(YMIN, YMAX, IMAGE_STEP)
    grid_size = len(y)
    X, Y = np.meshgrid(x, y)
    
    # Gauss
    if distr_type == "Gauss":
        mean = np.array([0, 0])
        cov = np.array([[1, 0.7], [0.7, 1]])  

        Z = pdf_gauss(np.vstack([X.ravel(),Y.ravel()]).T,mean,cov)
        if draw_samples:
            sample_x, sample_y = sample_gauss(N_SAMPLES, mean, cov)

    # GMM
    elif distr_type == "GMM":
        weights = [0.9, 0.1]
        means = np.array([[1.0, 0.0],
                          [3.0, 1.7]])
        covs = np.array([[[1.0,  0.7], [ 0.7, 1.0]],
                         [[0.5, -0.2], [-0.2, 0.1]]])

        Z = pdf_gmm(np.vstack([X.ravel(),Y.ravel()]).T, weights, means, covs)
        #Z = pdf_mm(np.vstack([X.ravel(),Y.ravel()]).T, weights, pdf_gauss, [[mean, cov] for mean, cov in zip(means,covs)])    
        if draw_samples:
            sample_x, sample_y = sample_gmm(N_SAMPLES, weights, means, covs)
            #sample_x, sample_y = sample_mm(N_SAMPLES, weights, sample_gauss, [[mean, cov] for mean, cov in zip(means,covs)])


    elif distr_type == "Square":
        size = 2
        var = 0.25
        width = 13
        height = 13
        covs = np.array([[var,  0.0], [ 0.0, var]])

        Z = pdf_square_gauss(np.vstack([X.ravel(),Y.ravel()]).T, size, width, height, covs)
        if draw_samples:
            sample_x, sample_y = sample_square_gauss(N_SAMPLES, size, width, height, covs)

            
    elif distr_type == "Banana":
        mu_x, mu_y, var_x = 0, 0, 2 # mean and standard deviation for x distribution
        var_y_ratio = 1.0 / 25
        BAN_LEN = 0.5 # (BAN_LEN * |x|^BAN_CURV) => how much curved the banana will be
        BAN_CURV = 2
                
        Z = pdf_banana_grid(X, Y, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV)
        #Z = pdf_banana(np.vstack([X.ravel(),Y.ravel()]).T, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV)
        if draw_samples:
            sample_x, sample_y = sample_banana(N_SAMPLES, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV) 

            
    elif distr_type == "2Bananas":    
        dist_betw = 3
        mu_x, mu_y, var_x = 0, 0, 2 # mean and standard deviation for x distribution
        var_y_ratio = 1.0 / 16
        BAN_LEN = 0.3 # (BAN_LEN * |x|^BAN_CURV) => how much curved the banana will be
        BAN_CURV = 1.5
        
                
        Z = pdf_2bananas_grid(X, Y, dist_betw, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV)
        
        #Z = pdf_2bananas(np.vstack([X.ravel(),Y.ravel()]).T, mu_x, var_x, var_y_ratio, BAN_LEN, BAN_CURV)
        if draw_samples:
            sample_x, sample_y = sample_2bananas(N_SAMPLES, dist_betw, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV) 

            
    elif distr_type == "K-2Bananas":  
        K = 2
        width = 13
        height = 13
        dist_betw = 3
        mu_x, mu_y, var_x = 0, 0, 2 # mean and standard deviation for x distribution
        var_y_ratio = 1.0 / 16
        BAN_LEN = 0.3 # (BAN_LEN * |x|^BAN_CURV) => how much curved the banana will be
        BAN_CURV = 1.5
        
                
        Z = pdf_kbananas_grid(X, Y, K, width, height, dist_betw, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV)
#        Z = pdf_kbananas(np.vstack([X.ravel(),Y.ravel()]).T, K, width, height, dist_betw, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV)        
        if draw_samples:
            sample_x, sample_y = sample_kbananas(N_SAMPLES, K, width, height, dist_betw, mu_x, mu_y, var_x, var_y_ratio, BAN_LEN, BAN_CURV) 

        
            
    p = figure(x_range=(XMIN, XMAX), y_range=(YMIN, YMAX),
               tooltips=[("x", "$x"), ("y", "$y"), ("pdf", "@image")])

    Z = Z.reshape(grid_size,-1)
    if draw_pdf:
        if pdf_scale == "log":
            Z = np.log(Z + 1e-20)
        p.image(image=[1-Z if INV_COL else Z], x=XMIN, y=YMIN, dw=XMAX-XMIN, dh=YMAX-YMIN, palette=PAL)

    if draw_samples:
        p.circle(sample_x, sample_y, size=2, line_color="red", fill_alpha=0.8)
    show(p)


draw_distribution(DISTR, PDF, SAMPLING, PDF_scale)

Banana


In [1]:
from __future__ import print_function
#import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from pathlib import Path
import numpy as np



#parser = argparse.ArgumentParser(description='VAE MNIST Example')
#parser.add_argument('--batch-size', type=int, default=128, metavar='N',
#                    help='input batch size for training (default: 128)')
#parser.add_argument('--epochs', type=int, default=10, metavar='N',
#                    help='number of epochs to train (default: 10)')
#parser.add_argument('--no-cuda', action='store_true', default=False,
#                    help='enables CUDA training')
#parser.add_argument('--seed', type=int, default=1, metavar='S',
#                    help='random seed (default: 1)')
#parser.add_argument('--log-interval', type=int, default=10, metavar='N',
#                    help='how many batches to wait before logging training status')
#args = parser.parse_args()
#args.cuda = not args.no_cuda and torch.cuda.is_available()
batch_size = 128
log_interval = 10
#epochs = 10
epochs = 1

#torch.manual_seed(args.seed)

#device = torch.device("cuda" if args.cuda else "cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"

kwargs = {'num_workers': 1, 'pin_memory': True} if device == "cuda" else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs) # why shuffle = True?


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

#        self.fc_en_1 = nn.Linear(2, 20)
#        self.fc_en_2 = nn.Linear(20, 20)
#        self.fc_en_m = nn.Linear(20, 2)
#        self.fc_en_lv = nn.Linear(20, 2)
#        self.fc_de_1 = nn.Linear(2, 20)
#        self.fc_de_2 = nn.Linear(20, 2)
#        self.fc_de_2 = nn.Linear(20, 20)
#        self.fc_de_m = nn.Linear(20, 2)
#        self.fc_de_lv = nn.Linear(20, 2)
        
        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        #h = F.relu(self.fc_en_1(x))
        #h = F.relu(self.fc_en_2(self.fc_en_1(x)))
        #return self.fc_en_m(h), self.fc_en_lv(h)
        
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        #h = F.relu(self.fc_de_1(z))
        #h = F.relu(self.fc_de_2(self.fc_de_1(z)))        
        #return self.fc_de_m(h), self.fc_de_lv(h)
        
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        
        #mu2, logvar2 = self.decode(z)
        #out = self.reparameterize(mu2, logvar2)
        #return out, mu, logvar    
        
        return self.decode(z), mu, logvar


model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


#TODO: consider tanh as an activation
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    #MSE = F.mse_loss(recon_x, x, reduction='sum') 

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD 
    #return MSE + KLD # consider mean instead of sum => this way depends on batch_size?


def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                print(n)
                print(data.size(0))
                comparison = torch.cat([data[:n],
                                      recon_batch.view(-1, 1, 28, 28)[:n]])
                Path("./results").mkdir(parents=True, exist_ok=True)
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test avg loss: {:.4f}'.format(test_loss))

if __name__ == "__main__":
    for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')


====> Epoch: 1 Average loss: 165.1697
8
128
====> Test avg loss: 127.5365


In [None]:
for epoch in range(1, epochs + 1):
        train(epoch)
        test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, 20).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample.view(64, 1, 28, 28),
                       'results/sample_' + str(epoch) + '.png')