# Dataset

In [31]:
import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim

import numpy as np
import time
from tqdm import tqdm

In [32]:

batch_size = 256
transforms_ = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])

def train_data():
    train_dataset = datasets.FashionMNIST(root="../data/", train=True, download=True, transform=transforms_)

    return train_dataset

def test_data():
    test_dataset = datasets.FashionMNIST(root="../data/", train=False, download=True, transform=transforms.ToTensor() )
    
    return test_dataset

# Model

In [69]:
class VariationalAE(nn.Module):
    def __init__(self, size=[128, 64, 32], latent_dim=16):
        super().__init__()
        self.size = size

        self.encoder = nn.Sequential(
            # 1x28x28
            nn.Conv2d(in_channels=1, out_channels=size[0], kernel_size=5, stride=1),
            nn.ReLU(),
            # 128x24x24
            nn.Conv2d(in_channels=size[0], out_channels=size[1], kernel_size=5, stride=1),
            # 64x20x20
            nn.ReLU(),
            nn.Conv2d(in_channels=size[1], out_channels=size[2], kernel_size=5, stride=1),
            # 32x20x20
        )

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        # 32x8x8

        self.fc = nn.Sequential(
            # encoder
            nn.Linear(in_features=size[2]*8*8, out_features=latent_dim),
            # dec
            nn.Linear(in_features=latent_dim, out_features=size[2]*8*8)
        )

        self.fc1_enc = nn.Linear(in_features=size[2]*8*8, out_features=128)
        self.fc2_mean = nn.Linear(in_features=128, out_features=latent_dim)
        self.fc2_var = nn.Linear(in_features=128, out_features=latent_dim)

        self.fc_dec1 = nn.Linear(in_features=latent_dim, out_features=128)
        self.fc_dec2 = nn.Linear(in_features=128, out_features=size[2]*8*8)

        self.out_layer = nn.MaxUnpool2d(kernel_size=2, stride=2),

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=size[2], out_channels=size[1], kernel_size=5),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=size[1], out_channels=size[0], kernel_size=5),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=size[0], out_channels=1, kernel_size=5),
            nn.Sigmoid()
        )

    def reparametrize(self, mean, var):
        eps = torch.randn(var.shape)
        z = eps * var + mean
        return z

    def forward(self, input):
        out_enc = self.encoder(input)
        out_pool, out_indices = self.pool(out_enc)
        
        #flattening
        batch_size = out_enc.shape[0]
        out_flattened = out_pool.view(batch_size, -1)
    
        out_fc1 = self.fc1_enc(out_flattened)
        mean = self.fc2_mean(out_fc1)
        var = self.fc2_var(out_fc1)
        z = self.reparametrize(mean, var)
    
        out_fc_dec1 = self.fc_dec1(z)
        out_fc_dec2 = self.fc_dec2(out_fc_dec1)
     
        # unflattening
        out_unflattened = out_fc_dec2.view(out_pool.shape)
     
        out_unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)(out_unflattened, out_indices)
        
        recons = self.decoder(out_unpool)
        
        return recons
    
model = VariationalAE()

model

VariationalAE(
  (encoder): Sequential(
    (0): Conv2d(1, 128, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 64, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 32, kernel_size=(5, 5), stride=(1, 1))
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc): Sequential(
    (0): Linear(in_features=2048, out_features=16, bias=True)
    (1): Linear(in_features=16, out_features=2048, bias=True)
  )
  (fc1_enc): Linear(in_features=2048, out_features=128, bias=True)
  (fc2_mean): Linear(in_features=128, out_features=16, bias=True)
  (fc2_var): Linear(in_features=128, out_features=16, bias=True)
  (fc_dec1): Linear(in_features=16, out_features=128, bias=True)
  (fc_dec2): Linear(in_features=128, out_features=2048, bias=True)
  (decoder): Sequential(
    (0): ConvTranspose2d(32, 64, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(64, 128, kernel_size=(5, 5), stride=(1, 1))
    (

# Training 

In [39]:
learning_rate = 3e-4
EPOCHS = 10
optimizerName = "Adam"

In [37]:
#load data
train_dataset = train_data()
test_dataset = test_data()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle=False)

#setup model 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), learning_rate)
lossFunction = nn.MSELoss()
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)      

#save images


In [72]:
def train_epoch():
    train_losses = []

    model.train()

    for batchId, (images, labels) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        images, labels = images.to(device), labels.to(device)

        recons = model(images)
        loss = lossFunction(recons, images)
        loss.backward()
        train_losses.append(loss.item())

        optimizer.step()

    train_loss = np.mean(train_losses)

    return train_loss

@torch.no_grad()
def test_epoch():
    test_losses = []

    for batchId, (images, _) in enumerate(tqdm(test_loader)):
        images = images.to(device)

        recons = model(images)
        loss = lossFunction(recons, images)
        test_losses.append(loss.item())

    test_loss = np.mean(test_losses)

    return test_loss

for epoch in range(EPOCHS):
    start_training_time = time.time()
    print(f"--------------EPOCH {epoch}--------------")
    #train_loss = train_epoch()
    #print(f"Training loss: {train_loss}, training time: {time.time() - start_training_time}")
    test_loss = test_epoch()
    print(f"Testing loss: {test_loss}")

    scheduler.step()






  0%|          | 0/40 [00:00<?, ?it/s][A[A[A

--------------EPOCH 0--------------





  2%|▎         | 1/40 [00:05<03:45,  5.79s/it][A[A[A


  5%|▌         | 2/40 [00:09<03:15,  5.15s/it][A[A[A


  8%|▊         | 3/40 [00:14<03:06,  5.04s/it][A[A[A


 10%|█         | 4/40 [00:17<02:46,  4.63s/it][A[A[A


 12%|█▎        | 5/40 [00:24<02:59,  5.12s/it][A[A[A


 15%|█▌        | 6/40 [00:29<02:57,  5.23s/it][A[A[A


 18%|█▊        | 7/40 [00:36<03:05,  5.61s/it][A[A[A


 20%|██        | 8/40 [00:42<03:07,  5.85s/it][A[A[A


 22%|██▎       | 9/40 [00:48<03:00,  5.82s/it][A[A[A


 25%|██▌       | 10/40 [00:54<03:01,  6.07s/it][A[A[A


 28%|██▊       | 11/40 [00:59<02:43,  5.62s/it][A[A[A


 30%|███       | 12/40 [01:05<02:42,  5.80s/it][A[A[A


 32%|███▎      | 13/40 [01:12<02:44,  6.09s/it][A[A[A


 35%|███▌      | 14/40 [01:18<02:36,  6.00s/it][A[A[A


 38%|███▊      | 15/40 [01:24<02:34,  6.16s/it][A[A[A


 40%|████      | 16/40 [01:29<02:17,  5.73s/it][A[A[A


 42%|████▎     | 17/40 [01:35<02:10,  5.68s/it][A[A[A


 45

Testing loss: 0.16791316010057927
--------------EPOCH 1--------------





  2%|▎         | 1/40 [00:08<05:28,  8.41s/it][A[A[A


  5%|▌         | 2/40 [00:15<05:00,  7.92s/it][A[A[A


  8%|▊         | 3/40 [00:21<04:30,  7.31s/it][A[A[A


 10%|█         | 4/40 [00:24<03:44,  6.24s/it][A[A[A


 12%|█▎        | 5/40 [00:27<03:04,  5.26s/it][A[A[A


 15%|█▌        | 6/40 [00:30<02:35,  4.57s/it][A[A[A


 18%|█▊        | 7/40 [00:33<02:15,  4.10s/it][A[A[A


 20%|██        | 8/40 [00:36<01:59,  3.72s/it][A[A[A


 22%|██▎       | 9/40 [00:39<01:50,  3.57s/it][A[A[A


 25%|██▌       | 10/40 [00:43<01:49,  3.65s/it][A[A[A


 28%|██▊       | 11/40 [00:47<01:49,  3.76s/it][A[A[A


 30%|███       | 12/40 [00:51<01:42,  3.65s/it][A[A[A


 32%|███▎      | 13/40 [00:54<01:33,  3.45s/it][A[A[A


 35%|███▌      | 14/40 [00:56<01:25,  3.29s/it][A[A[A


 38%|███▊      | 15/40 [00:59<01:20,  3.21s/it][A[A[A


 40%|████      | 16/40 [01:03<01:18,  3.29s/it][A[A[A


 42%|████▎     | 17/40 [01:08<01:29,  3.90s/it][A[A[A


 45

Testing loss: 0.16791329458355903
--------------EPOCH 2--------------





  2%|▎         | 1/40 [00:06<04:25,  6.80s/it][A[A[A


  5%|▌         | 2/40 [00:12<04:10,  6.60s/it][A[A[A


  8%|▊         | 3/40 [00:19<04:01,  6.51s/it][A[A[A


 10%|█         | 4/40 [00:24<03:45,  6.26s/it][A[A[A

KeyboardInterrupt: 

In [73]:
test_loss

0.16791329458355903

In [74]:
import matplotlib.pyplot as plt

plt.style.use("seaborn")
fig, ax = plt.subplots(1, 2)

epochs = np.arange(EPOCHS)
ax[0].plot(epochs, train_loss, c="green", label="Train loss")
ax[0].plot(epochs, valid_loss, c="blue", label="Test loss")
ax[0].legend(loc="best")
ax[0].set_xlabel("Epochs")
ax[0].set_ylabel("MSE Loss")
ax[0].set_title("Loss curves")

plt.show()

ValueError: x and y must have same first dimension, but have shapes (10,) and (1,)