In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import transforms
from torchvision.utils import save_image

In [3]:
# Hyperparameters shared by every cell below
image_size = 784   # width of a flattened input image (28*28)
hiden_dim = 400    # units in the hidden layer of encoder and decoder
latent_dim = 20    # size of the latent code z
batch_size = 128
epochs = 10

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MNIST training split: downloaded into ./data on first use, shuffled every epoch
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# MNIST test split: read from the same root and iterated in a fixed order
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


In [4]:
# Create directory to save images
sample_dir = 'results'
# exist_ok=True keeps the cell idempotent on re-runs and removes the
# check-then-create race of a separate os.path.exists() test. If the path
# exists but is not a directory this raises immediately, rather than failing
# later when the first image is written.
os.makedirs(sample_dir, exist_ok=True)
    

In [None]:
# VAE model

class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian; the decoder maps a latent sample back to values in
    [0, 1] through a sigmoid.

    Args:
        input_dim: Width of a flattened input. Defaults to the notebook-level
            ``image_size`` when omitted.
        hidden_dim: Units in the hidden layer of both encoder and decoder.
            Defaults to the notebook-level ``hiden_dim`` when omitted.
        z_dim: Size of the latent code. Defaults to the notebook-level
            ``latent_dim`` when omitted.
    """

    def __init__(self, input_dim=None, hidden_dim=None, z_dim=None):
        super().__init__()

        # Fall back to the notebook globals so the existing ``VAE()`` call
        # keeps building exactly the same network as before.
        self.input_dim = image_size if input_dim is None else input_dim
        self.hidden_dim = hiden_dim if hidden_dim is None else hidden_dim
        self.z_dim = latent_dim if z_dim is None else z_dim

        # Layers are created in the original order so seeded initialisation
        # is unchanged.
        self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
        self.fc2_mean = nn.Linear(self.hidden_dim, self.z_dim)
        self.fc2_logvar = nn.Linear(self.hidden_dim, self.z_dim)
        self.fc3 = nn.Linear(self.z_dim, self.hidden_dim)
        self.fc4 = nn.Linear(self.hidden_dim, self.input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for flattened ``x``."""
        h = F.relu(self.fc1(x))
        return self.fc2_mean(h), self.fc2_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling this way keeps the draw differentiable with respect to
        ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2)  ->  sigma
        eps = torch.randn_like(std)    # noise with the same shape as std
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to reconstructions in the range [0, 1]."""
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        """Encode, sample and decode a batch.

        ``x`` may have any shape whose elements regroup into rows of
        ``input_dim`` values (e.g. ``(batch, 1, 28, 28)`` for ``input_dim=784``).

        Returns:
            ``(reconstruction, mu, logvar)``, where ``reconstruction`` has
            shape ``(batch, input_dim)``.
        """
        # Flatten with this instance's own input width rather than a global.
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        reconstruction = self.decode(z)
        return reconstruction, mu, logvar
    
# Build the network, move it to the selected device, and attach an Adam optimiser
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
def detach(states):
    """Return a list holding ``state.detach()`` for every item of ``states``.

    ``z.detach()`` gives a tensor that shares storage with ``z`` but carries
    no computation history, so gradients stop flowing at that point.

    This is the usual building block for truncated backpropagation through
    time: a long sequence (say 1,000 steps) is cut into shorter chunks (say
    50 chunks of 20 steps) and each chunk is trained as its own case. That
    works well in practice but cannot learn temporal dependencies longer than
    one chunk.
    """
    cut_off = []
    for tensor in states:
        cut_off.append(tensor.detach())
    return cut_off

In [10]:
# Define loss function

def loss_function(reconstructed_image, orifginal_image, mu, logvar):
    """Return the VAE loss summed over the batch: reconstruction BCE + KL term.

    Args:
        reconstructed_image: Decoder output with values in [0, 1].
        orifginal_image: Target images; any shape with the same number of
            elements as ``reconstructed_image``.
        mu: Mean of the approximate posterior.
        logvar: Log-variance of the approximate posterior.

    Returns:
        Scalar tensor ``BCE + KLD``. Both terms are sums, not means.
    """
    # Reshape the target to the reconstruction's own shape instead of going
    # through a module-level image size, so the loss works for any input width.
    target = orifginal_image.view_as(reconstructed_image)
    BCE = F.binary_cross_entropy(reconstructed_image, target, reduction='sum')
    # KL( N(mu, sigma^2) || N(0, I) ) in closed form, summed over batch and latent dims
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Train model
def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Every 100th batch prints that batch's loss divided by the batch length;
    after the pass it prints the accumulated loss divided by the number of
    samples in the training dataset. ``epoch`` is only used in the messages.
    """
    model.train()
    running_total = 0
    num_batches = len(train_loader)
    for step, (batch, _) in enumerate(train_loader):
        # Move to the compute device and flatten each image to one row
        flat = batch.to(device).view(-1, image_size)

        optimizer.zero_grad()
        recon, mean, log_variance = model(flat)
        batch_loss = loss_function(recon, flat, mean, log_variance)
        batch_loss.backward()
        running_total += batch_loss.item()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch [{}/{}], Batch [{}/{}], Loss: {:.3f}'.format(epoch, epochs, step, num_batches, batch_loss.item()/len(flat)))

    print('Epoch [{}/{}], Average Loss: {:.3f}'.format(epoch, epochs, running_total/len(train_loader.dataset)))
    
    

In [11]:
def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    Accumulates ``loss_function`` over every test batch without tracking
    gradients, then prints the total divided by the number of test samples.
    For the first batch, an image stacking up to 8 inputs above their
    reconstructions is written to
    ``<sample_dir>/reconstruction_image_<epoch>.png``.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_index, (images, _) in enumerate(test_loader):
            images = images.to(device)
            reconstructed_image, mu, logvar = model(images)
            test_loss += loss_function(reconstructed_image, images, mu, logvar).item()

            if batch_index == 0:  # first batch at each epoch
                n = min(images.size(0), 8)
                originals = images[:n]
                # Reshape only the rows we keep, borrowing the shape of the
                # matching inputs. Reshaping the whole output to
                # (batch_size, 1, 28, 28) raised whenever the first batch held
                # fewer than batch_size samples.
                rebuilt = reconstructed_image[:n].view_as(originals)
                comparison = torch.cat([originals, rebuilt])
                save_image(comparison.cpu(), os.path.join(sample_dir, 'reconstruction_image_{}.png'.format(epoch)), nrow=n)

    test_loss /= len(test_loader.dataset)
    print('Test Loss: {:.3f}'.format(test_loss))


In [12]:
# Main loop: one training pass, one evaluation pass, then a grid of generated digits
for epoch in range(1, epochs+1):
    train(epoch)
    test(epoch)

    with torch.no_grad():
        # Skip the encoder: draw 64 latent codes from a standard normal and
        # decode them to generate new images.
        sample = model.decode(torch.randn(64, latent_dim).to(device)).cpu()
        # Reshape to 1 channel, 28x28 before writing the image
        save_image(
            sample.view(64, 1, 28, 28),
            os.path.join(sample_dir, 'sample_image_{}.png'.format(epoch)),
        )

Epoch [1/10], Batch [0/469], Loss: 104.146
Epoch [1/10], Batch [100/469], Loss: 103.160
Epoch [1/10], Batch [200/469], Loss: 105.547
Epoch [1/10], Batch [300/469], Loss: 106.239
Epoch [1/10], Batch [400/469], Loss: 105.720
Epoch [1/10], Average Loss: 105.031
Test Loss: 104.736
Epoch [2/10], Batch [0/469], Loss: 109.489
Epoch [2/10], Batch [100/469], Loss: 103.378
Epoch [2/10], Batch [200/469], Loss: 109.503
Epoch [2/10], Batch [300/469], Loss: 106.278
Epoch [2/10], Batch [400/469], Loss: 107.800
Epoch [2/10], Average Loss: 104.799
Test Loss: 104.428
Epoch [3/10], Batch [0/469], Loss: 105.047
Epoch [3/10], Batch [100/469], Loss: 106.196
Epoch [3/10], Batch [200/469], Loss: 104.216
Epoch [3/10], Batch [300/469], Loss: 109.087
Epoch [3/10], Batch [400/469], Loss: 102.323
Epoch [3/10], Average Loss: 104.589
Test Loss: 104.253
Epoch [4/10], Batch [0/469], Loss: 103.375
Epoch [4/10], Batch [100/469], Loss: 106.833
Epoch [4/10], Batch [200/469], Loss: 107.945
Epoch [4/10], Batch [300/469], Lo

In [None]:
# Possible improvements
# 1. Add more layers to the model
# 2. Use dropout
# 3. Use learning rate decay
# 4. Train for more epochs
