In [1]:
import torch
import torchvision
import torch.optim as optim
import argparse
import matplotlib
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm import tqdm_gui
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.utils import save_image
matplotlib.style.use('ggplot')

In [2]:
features = 16  # size of the latent vector (mu and log_var each have this many entries)
class LinearVAE(nn.Module):
    """Fully connected variational autoencoder for 784-dimensional inputs.

    The encoder maps 784 -> 512 -> 2*features; its output is split into the
    mean and the log-variance of the latent distribution. The decoder maps a
    latent sample back through features -> 512 -> 784 and squashes the result
    with a sigmoid, so every reconstructed value lies in [0, 1].
    """

    def __init__(self):
        super(LinearVAE,self).__init__()
        #encoder
        self.enc1 = nn.Linear(in_features=784,out_features=512)
        # One output half for mu, the other half for log_var.
        self.enc2 = nn.Linear(in_features=512,out_features=features*2)

        #decoder
        self.dec1 = nn.Linear(in_features=features,out_features=512)
        self.dec2 = nn.Linear(in_features=512,out_features=784)

    def reparameterize(self,mu,log_var):
        """Draw a latent sample as ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        Args:
            mu: mean of the latent distribution.
            log_var: log-variance of the latent distribution (same shape as ``mu``).

        Returns:
            A tensor shaped like ``mu``. The randomness lives only in ``eps``,
            so the result stays differentiable with respect to ``mu`` and
            ``log_var``.
        """
        std = torch.exp(0.5*log_var)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        sample = mu + (eps*std)
        return sample

    def forward(self,x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: tensor whose last dimension is 784 (the encoder's input width).

        Returns:
            Tuple ``(reconstruction, mu, log_var)`` where ``reconstruction``
            has 784 values per row and ``mu`` / ``log_var`` have ``features``
            values per row.
        """
        # Apply the first encoder layer to the input before the activation;
        # passing the layer object itself to relu raises a TypeError.
        x = F.relu(self.enc1(x))
        # Regroup the 2*features encoder outputs into (rows, 2, features).
        x = self.enc2(x).view(-1,2,features)
        mu = x[:,0,:]
        log_var = x[:,1,:]
        z = self.reparameterize(mu,log_var)
        x = F.relu(self.dec1(z))
        reconstruction = torch.sigmoid(self.dec2(x))
        return reconstruction,mu,log_var

In [3]:
# Training hyper-parameters.
epochs = 100       # number of passes over the training set
batch_size = 64    # samples per mini-batch
lr = 1e-4          # learning rate handed to the optimizer

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Image preprocessing: only convert each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits share the same storage root, download flag and preprocessing.
_mnist_common = dict(root='data', download=True, transform=transform)

# Training split.
train_data = datasets.MNIST(train=True, **_mnist_common)
# Held-out split used for validation.
val_data = datasets.MNIST(train=False, **_mnist_common)

In [3]:
# Peek at the first validation sample; indexing the dataset yields the tuple shown in the cell output.
val_data[0]

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 

In [12]:
# Mini-batch iterators: reshuffle the training split every epoch,
# keep the validation split in its original order.
train_loader = DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_data,
    batch_size=batch_size,
    shuffle=False,
)

In [13]:
# Build the network and move its parameters to the selected device.
model = LinearVAE()
model = model.to(device)

# Adam over every model parameter, using the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Binary cross-entropy summed (not averaged) over all elements.
criterion = nn.BCELoss(reduction='sum')

In [None]:
def final_loss(bce_loss,mu,logvar):
    BCE=bce_loss