In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.datasets as datasets  # Standard datasets
from tqdm import tqdm
from torch import nn, optim
#from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

In [20]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 250
Z_DIM = 20
NUM_EPOCHS = 20
BATCH_SIZE = 128
LR_RATE = 1e-4 

In [3]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hdim = nn.Linear(input_dim, h_dim)
        self.hdim_2mu = nn.Linear(h_dim, z_dim)
        self.hdim_2sigma = nn.Linear(h_dim, z_dim)
        
        # decoder
        self.z_2hdim = nn.Linear(z_dim, h_dim)
        self.hdim_2img = nn.Linear(h_dim, input_dim)
        
        self.relu = nn.ReLU()
    def encode(self,x):
        h = self.img_2hdim(x)
        h = self.relu(h)
        mu = self.hdim_2mu(h)
        log_sigma = self.hdim_2sigma(h)
        return mu, log_sigma
    def decode(self,z):
        h = self.z_2hdim(z)
        h = self.relu(h)
        x = self.hdim_2img(h)
        return torch.sigmoid(x)
        # sigmoid done cause images are normalized between 0 and 1
    def forward(self,x):
        mu, log_sigma = self.encode(x)
        epsilon = torch.randn_like(log_sigma)
        z_reparametrized = mu + torch.exp(0.5*log_sigma)*epsilon
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, log_sigma

In [4]:
x = torch.randn(4,28*28)
v = VariationalAutoEncoder(784)
x_regen,mu,sigma = v(x)

In [5]:
x_regen.shape

torch.Size([4, 784])

In [6]:
mu.shape

torch.Size([4, 20])

In [7]:
sigma.shape

torch.Size([4, 20])

In [21]:
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=False)
# .ToTensor() also divides the pixel values by 255
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

In [22]:
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i,(x,y) in loop:
        x = x.to(DEVICE)
        x = x.view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, log_sigma = model(x)
        reconstruction_loss = loss_fn(x_reconstructed, x)
        kl_div = -0.5*torch.sum(1+log_sigma - mu.pow(2) - log_sigma.exp())
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

469it [00:06, 69.24it/s, loss=2.01e+4]
469it [00:06, 69.63it/s, loss=1.68e+4]
469it [00:06, 69.92it/s, loss=1.56e+4]
469it [00:06, 69.89it/s, loss=1.43e+4]
469it [00:06, 68.91it/s, loss=1.39e+4]
469it [00:07, 66.84it/s, loss=1.35e+4]
469it [00:07, 66.91it/s, loss=1.31e+4]
469it [00:06, 68.61it/s, loss=1.25e+4]
469it [00:06, 69.35it/s, loss=1.25e+4]
469it [00:06, 69.21it/s, loss=1.22e+4]
469it [00:06, 69.91it/s, loss=1.22e+4]
469it [00:06, 69.49it/s, loss=1.16e+4]
469it [00:06, 69.96it/s, loss=1.09e+4]
469it [00:06, 69.94it/s, loss=1.14e+4]
469it [00:06, 69.61it/s, loss=1.18e+4]
469it [00:06, 69.86it/s, loss=1.12e+4]
469it [00:06, 70.22it/s, loss=1.11e+4]
469it [00:06, 70.20it/s, loss=1.09e+4]
469it [00:06, 69.71it/s, loss=1.13e+4]
469it [00:06, 69.91it/s, loss=1.1e+4] 


In [23]:
def inference(digit, num_examples=1):
    DEVICE = torch.device("cpu")
    model.to(DEVICE)
    """
    Generates (num_examples) of a particular digit.
    Specifically we extract an example of each digit,
    then after we have the mu, sigma representation for
    each digit we can sample from that.

    After we sample we can run the decoder part of the VAE
    and generate examples.
    """
    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1
        if idx == 10:
            break

    encodings_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1, 784))
        encodings_digit.append((mu, sigma))

    mu, sigma = encodings_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decode(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f"generated_{digit}_ex{example}.png")


In [24]:
for idx in range(10):
    inference(idx, num_examples=5)

In [9]:
print("Start training VAE...")
model.train()
DEVICE = torch.device("cuda")
model.to(DEVICE)
for epoch in range(NUM_EPOCHS):
    overall_loss = 0
    for batch_idx, (x, _) in enumerate(train_loader):
        x = x.to(DEVICE)
        x = x.view(x.shape[0], INPUT_DIM)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_fn(x_hat,x)
        
        overall_loss = loss.item()
        kl_div = -0.5*torch.sum(1+log_var - mean.pow(2) - log_var.exp())
        loss.backward()
        optimizer.step()
        
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss)
    
print("Finish!!")

Start training VAE...
	Epoch 1 complete! 	Average Loss:  18243.09375
	Epoch 2 complete! 	Average Loss:  14377.634765625
	Epoch 3 complete! 	Average Loss:  13416.01953125
	Epoch 4 complete! 	Average Loss:  12115.57421875
	Epoch 5 complete! 	Average Loss:  11224.568359375
Finish!!


In [None]:
/ (batch_idx*BATCH_SIZE)