In [None]:
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.autograd import Variable
from matplotlib import pyplot as plt

In [None]:
def to_img(x):
    """Turn a batch of flat or single-channel tensors into 28x28 images.

    Every element is mapped through ``(x + 1) / 2`` and the result is viewed
    as ``(batch, 28, 28)``, where ``batch`` is ``x.size(0)``.

    NOTE(review): the affine map sends [-1, 1] onto [0, 1]. The loaders in
    this notebook only apply ``ToTensor`` and the models end in ``Sigmoid``,
    so values look like they are already in [0, 1] -- confirm the shift is
    intended.
    """
    rescaled = (x + 1) * 0.5
    batch = rescaled.size(0)
    return rescaled.view(batch, 28, 28)

In [None]:
# Mini-batch size shared by the training and the test loader.
bs = 1000

# Only tensor conversion is applied; no further normalisation is configured.
transform = transforms.Compose([transforms.ToTensor()])

# Fashion-MNIST splits, downloaded into ./data on first use.
traindataset = FashionMNIST(root='./data', train=True, transform=transform, download=True)
testdataset = FashionMNIST(root='./data', train=False, transform=transform, download=True)

# Loaders: training batches are reshuffled each epoch, test order stays fixed.
trainloader = DataLoader(dataset=traindataset, batch_size=bs, shuffle=True, num_workers=4)
testloader = DataLoader(dataset=testdataset, batch_size=bs, shuffle=False, num_workers=4)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
Processing...
Done!


In [None]:
def display_images(in_, out, n=1):
    """Plot rows of four images from a batch of inputs and/or outputs.

    For each of the ``n`` rows, one 18x6 figure with four axis-less subplots
    is drawn for ``in_`` (when it is not ``None``) followed by one for
    ``out``. Row ``N`` shows batch entries ``4*N .. 4*N + 3``.

    Args:
        in_: batch of input tensors accepted by ``to_img``, or ``None`` to
            show only the outputs.
        out: batch of output tensors accepted by ``to_img``.
        n: number of four-image rows to draw.

    If the batch holds fewer than ``4 * n`` images, the missing slots are
    left empty instead of raising ``IndexError``.
    """
    if n <= 0:
        return

    def _show_row(pics, row):
        # One figure per row; stop early when the batch runs out of images.
        plt.figure(figsize=(18, 6))
        for col in range(4):
            idx = col + 4 * row
            if idx >= len(pics):
                break
            plt.subplot(1, 4, col + 1)
            plt.imshow(pics[idx])
            plt.axis('off')

    # The conversions do not depend on the row, so do them once up front.
    # detach() replaces the legacy `.data` access for dropping the graph.
    in_pic = to_img(in_.detach().cpu()) if in_ is not None else None
    out_pic = to_img(out.detach().cpu())

    for row in range(n):
        if in_pic is not None:
            _show_row(in_pic, row)
        _show_row(out_pic, row)

In [None]:
# NOTE(review): `Variable` is already imported in the first cell and is not
# referenced in the visible cells; this repeat looks redundant.
from torch.autograd import Variable
# Reset matplotlib to its default style for the figures drawn below.
plt.style.use('default')
# Render matplotlib figures inline in the notebook output.
%matplotlib inline 

24.0

**Variational AutoEncoder**

In [None]:
def conv_calc(size, padding, filter_, stride):
  """Spatial output size of a convolution along one dimension.

  Implements ``floor((size + 2*padding - filter_) / stride) + 1``. The floor
  matters whenever the stride does not divide the padded extent evenly:
  without it the result would be fractional (e.g. 4.5), which is not a valid
  size.

  Args:
    size: input extent along the dimension.
    padding: zero-padding added on each side.
    filter_: kernel extent.
    stride: step between kernel applications.

  Returns:
    The output extent (an ``int`` when all arguments are ints).
  """
  return (size + 2 * padding - filter_) // stride + 1

In [None]:
conv_calc(size=28, padding=1, filter_=3, stride=1)

28.0

In [None]:
def pad(size, filter_, stride):
  """Per-side padding that keeps a convolution's output extent equal to ``size``.

  Solves ``(size + 2*p - filter_) / stride + 1 == size`` for ``p``. The result
  is a float and may be fractional when no symmetric integer padding exists.
  """
  total_padding = (size - 1) * stride + filter_ - size
  return total_padding / 2

In [None]:
pad(size=28, filter_=3, stride=1)

1.0

In [None]:
class View(nn.Module):
  """Module wrapper around ``Tensor.view`` so a reshape can sit inside ``nn.Sequential``."""

  def __init__(self, shape):
    super().__init__()
    # Target shape handed to `view` on every forward call.
    self.shape = shape

  def forward(self, x):
    target = tuple(self.shape)
    return x.view(target)

In [None]:
class VarAutoEncoder(nn.Module):
  """Convolutional variational auto-encoder for 1x28x28 images.

  The encoder keeps the spatial size at 28x28 throughout: the stride-2,
  kernel-4 convolutions use padding 15, so (28 + 2*15 - 4) / 2 + 1 = 28.
  Its (64, 28, 28) feature map is flattened and projected by two linear
  heads to the mean and log-variance of a ``no_latent``-dimensional Gaussian.
  The decoder mirrors this: two linear layers expand a latent sample back to
  (64, 28, 28), and transposed convolutions with the same kernel / stride /
  padding again preserve 28x28 ((28 - 1)*2 - 2*15 + 4 = 28) before a final
  Sigmoid.

  Args:
    no_latent: dimensionality of the latent space.
  """

  # Shape and flattened size of the feature map between conv and linear parts.
  _FEATURE_SHAPE = (64, 28, 28)
  _N_FEATURES = 64 * 28 * 28

  def __init__(self, no_latent = 16):
    super(VarAutoEncoder, self).__init__()
    self.no_latent = no_latent
    self.encoder = nn.Sequential(
      nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=15),
      nn.LeakyReLU(),
      nn.Dropout(p=0.2, inplace=True),
      nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=15),
      nn.LeakyReLU(),
      nn.Dropout(p=0.2, inplace=True),
      nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
      nn.LeakyReLU(),
      nn.Dropout(p=0.2, inplace=True)
      )

    # Two heads on the same flattened features: mean and log-variance.
    self.linear1 = nn.Linear(in_features=self._N_FEATURES, out_features=self.no_latent, bias=True)
    self.linear2 = nn.Linear(in_features=self._N_FEATURES, out_features=self.no_latent, bias=True)

    self.decoder = nn.Sequential(
      nn.Linear(in_features=self.no_latent, out_features=1000, bias=True),
      nn.LeakyReLU(),
      nn.Linear(in_features=1000, out_features=self._N_FEATURES, bias=True),
      nn.LeakyReLU(),
      View((-1,) + self._FEATURE_SHAPE),
      nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
      nn.LeakyReLU(),
      nn.Dropout(p=0.2, inplace=True),
      nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=15),
      nn.ReLU(),
      # Must NOT be in-place: autograd keeps ReLU's output for its backward
      # pass, and overwriting it makes backward() fail with a
      # "modified by an inplace operation" error.
      nn.Dropout(p=0.2, inplace=False),
      nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=4, stride=2, padding=15),
      nn.Sigmoid()
      )

  def reparameterize(self, mu, logvar):
    """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``.

    Sampling this way keeps ``z`` differentiable with respect to ``mu`` and
    ``logvar``. A fresh ``eps`` is drawn on every call, in train and eval
    mode alike.
    """
    std = torch.exp(0.5*logvar)
    eps = torch.randn_like(std)
    return mu + eps*std

  def forward(self, x):
    """Encode ``x``, sample a latent code and decode it.

    Returns:
      A tuple ``(x_bar, mu, logvar)``: the decoder output after the Sigmoid,
      and the mean and log-variance produced by the two linear heads.
    """
    # Flatten once and feed both heads from the same tensor.
    features = self.encoder(x).view(-1, self._N_FEATURES)
    mu = self.linear1(features)
    logvar = self.linear2(features)

    z = self.reparameterize(mu, logvar)
    x_bar = self.decoder(z)
    return x_bar, mu, logvar

In [None]:
# Base channel widths for the decoder (ngf) and encoder (ndf), and the number
# of image channels (nc).
ngf = 64
ndf = 64
nc = 1


class VAE(nn.Module):
    """DCGAN-style convolutional VAE for ``nc`` x 28 x 28 images.

    A strided-convolution encoder reduces the image to a 1024-vector, which
    ``fc1`` and the two heads ``fc21`` / ``fc22`` map to the mean and
    log-variance of an ``nz``-dimensional latent. ``fc3`` / ``fc4`` lift a
    latent sample back to 1024 features and a transposed-convolution decoder
    upsamples it to an ``nc`` x 28 x 28 image squashed by a Sigmoid.
    """

    def __init__(self, nz):
        super().__init__()

        self.have_cuda = False
        self.nz = nz

        self.encoder = nn.Sequential(
            # (nc) x 28 x 28  ->  (ndf) x 14 x 14
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (ndf) x 14 x 14  ->  (ndf*2) x 7 x 7
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (ndf*2) x 7 x 7  ->  (ndf*4) x 4 x 4
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # (ndf*4) x 4 x 4  ->  1024 x 1 x 1
            nn.Conv2d(ndf * 4, 1024, kernel_size=4, stride=1, padding=0, bias=False),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )

        self.decoder = nn.Sequential(
            # 1024 x 1 x 1  ->  (ngf*8) x 4 x 4
            nn.ConvTranspose2d(1024, ngf * 8, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(inplace=True),
            # (ngf*8) x 4 x 4  ->  (ngf*4) x 7 x 7
            nn.ConvTranspose2d(ngf * 8, ngf * 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(inplace=True),
            # (ngf*4) x 7 x 7  ->  (ngf*2) x 14 x 14
            nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(inplace=True),
            # (ngf*2) x 14 x 14  ->  (nc) x 28 x 28
            nn.ConvTranspose2d(ngf * 2, nc, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Sigmoid(),
        )

        # Encoder head: 1024 features -> 512 -> (mu, logvar).
        self.fc1 = nn.Linear(in_features=1024, out_features=512)
        self.fc21 = nn.Linear(in_features=512, out_features=nz)
        self.fc22 = nn.Linear(in_features=512, out_features=nz)

        # Decoder stem: latent -> 512 -> 1024 features.
        self.fc3 = nn.Linear(in_features=nz, out_features=512)
        self.fc4 = nn.Linear(in_features=512, out_features=1024)

        # `lrelu` is registered but not used by any method of this class.
        self.lrelu = nn.LeakyReLU()
        self.relu = nn.ReLU()

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of images."""
        features = self.encoder(x)
        flat = features.view(-1, 1024)
        hidden = self.fc1(flat)
        mu = self.fc21(hidden)
        logvar = self.fc22(hidden)
        return mu, logvar

    def decode(self, z):
        """Map latent codes ``z`` to reconstructed images."""
        hidden = self.relu(self.fc3(z))
        # Reshape to 1024 channels at 1x1 so the transposed convs can upsample.
        seed_map = self.fc4(hidden).view(-1, 1024, 1, 1)
        return self.decoder(seed_map)

    def reparametrize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(logvar * 0.5)
        noise = torch.randn_like(std)
        return noise * std + mu

    def forward(self, x):
        """Encode, sample and decode; returns ``(decoded, mu, logvar)``."""
        mu, logvar = self.encode(x)
        latent = self.reparametrize(mu, logvar)
        return self.decode(latent), mu, logvar

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Alternative architecture, kept for reference:
#model_vae = VAE(20).to(device)
model_vae = VarAutoEncoder()
model_vae = model_vae.to(device)

In [None]:
model_vae

VarAutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(15, 15))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Dropout(p=0.2, inplace=True)
    (3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(15, 15))
    (4): LeakyReLU(negative_slope=0.01)
    (5): Dropout(p=0.2, inplace=True)
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.01)
    (8): Dropout(p=0.2, inplace=True)
  )
  (linear1): Linear(in_features=50176, out_features=16, bias=True)
  (linear2): Linear(in_features=50176, out_features=16, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=16, out_features=1000, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=1000, out_features=50176, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): View()
    (5): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): LeakyReLU(negative

In [None]:
# Configure the optimizer: Adam over every parameter of the model.
# (No criterion object is created in this cell.)
learning_rate = 1e-5
optimizer = torch.optim.Adam(params=model_vae.parameters(), lr=learning_rate)

In [None]:
def loss_fun(x, x_bar, mu, logvar):
    """VAE objective: summed reconstruction BCE plus KL divergence.

    Args:
        x: target images with values in [0, 1].
        x_bar: reconstructed images (probabilities in [0, 1]).
        mu: latent means.
        logvar: latent log-variances.

    Returns:
        Scalar tensor ``BCE(x_bar, x) + KL(N(mu, exp(logvar)) || N(0, I))``,
        both terms summed over all elements and the whole batch.
    """
    reconstruction = F.binary_cross_entropy(x_bar, x, reduction='sum')
    # Closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = - 0.5 * torch.sum(kl_terms)
    return reconstruction + divergence

# # Reconstruction + KL divergence losses summed over all elements and batch
# def loss_fun(recon_x, x, mu, logvar):
#     # print(recon_x.size(), x.size())
#     BCE = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), size_average=False)
#     # see Appendix B from VAE paper:
#     # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
#     # https://arxiv.org/abs/1312.6114
#     # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
#     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
#     # return BCE + KLD
#     return BCE + 3 * KLD

In [None]:
num_epochs = 1000

# Make sure Dropout is active even if an evaluation cell was run beforehand.
model_vae.train()

for epoch in range(num_epochs):
    for img, _ in trainloader:
        img = img.to(device)
        x_bar, mu, logvar = model_vae(img)
        # Pass the reconstruction itself, not `x_bar.data`: detaching it cuts
        # the reconstruction term out of the graph, so only the KL term would
        # produce gradients and the decoder would never be trained.
        # NOTE(review): back-propagating through the decoder requires that no
        # in-place op overwrites an activation autograd has saved; if
        # backward() raises an "inplace operation" error, check the model's
        # in-place Dropout layers.
        loss = loss_fun(img, x_bar, mu, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    # `loss` is the summed loss of the last mini-batch of the epoch.
    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss.item():.4f}')
    if epoch % 500 == 0:
        display_images(None, x_bar)

In [None]:
# Switch Dropout to inference behaviour.
model_vae.eval()

test_images, labels = next(iter(testloader))
test_images = test_images.to(device)

# Inference only: skip building the autograd graph for the whole test batch.
with torch.no_grad():
    # NOTE: `sigma` receives the third value returned by the model's forward,
    # which VarAutoEncoder produces as the log-variance, not a standard
    # deviation.
    x_bar, mu, sigma = model_vae(test_images)

# Drop the channel dimension so each entry is a plain 28x28 image.
inp = test_images.view(-1, 28, 28)
out_conv = x_bar.reshape(-1, 28, 28)

# Side-by-side comparison of the first test image and its reconstruction.
fig = plt.figure(figsize=(10, 50), )

plot = fig.add_subplot(1, 2, 1)
plot.set_title('Original Image')
imgplot = plot.imshow(inp[0].cpu(), cmap='gray')

plot = fig.add_subplot(1, 2, 2)
plot.set_title('VAE Image')
imgplot = plot.imshow(out_conv[0].cpu(), cmap='gray')
plt.show()