<a href="https://colab.research.google.com/github/daemonX10/Generative-Deep-Learning/blob/main/vae_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load:

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file found under the Kaggle input directory.
# os.walk yields nothing when the directory does not exist (e.g. outside Kaggle).
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        full_path = os.path.join(root_dir, file_name)
        print(full_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

![vae](https://lh6.googleusercontent.com/dqHH81HNI-B60vDS3u2M0jsUVo0nsUIlMoRT4GlG4w8fDTfJ5-Li0vZ08XWtuEHLW2jFR4jlwxCz8O2WLTDX5u09uOp6WEE87XmStaspZgcBbHaRB47S3tdXdkf4TzIaZsDFh-YXLl945ebwzlWnJek)

In [None]:
import torch,os
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

# Seed torch's global RNG so weight initialisation, reparameterisation noise and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# define model hyperparameters
LR = 0.001  # learning rate
PATIENCE = 2
IMAGE_SIZE = 32  # 28x28 Fashion-MNIST images padded by 2 px per side
CHANNELS = 1
BATCH_SIZE = 64
EMBEDDING_DIM = 2
EPOCHS = 100
# Feature-map shape after the encoder's three stride-2 convs (128 channels, 1/8 spatial size).
SHAPE_BEFORE_FLATTENING = (128, IMAGE_SIZE // 8, IMAGE_SIZE // 8)

# create output directory (reuse output_dir so the name is defined in one place)
output_dir = "output"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Directory for training-progress outputs, created under output_dir.
# (The name was previously misspelled "training_progres".)
training_progress_dir = os.path.join(output_dir, "training_progress")
os.makedirs(training_progress_dir, exist_ok=True)

# Directory for saved model checkpoints.
model_weights_dir = os.path.join(output_dir, "model_weights")
os.makedirs(model_weights_dir, exist_ok=True)

# Checkpoint path for the best VAE weights (presumably written by the training loop).
MODEL_WEIGHTS_PATH = os.path.join(model_weights_dir, 'best_vae.pt')

In [None]:
# PNG paths (all under output_dir) for reconstructions and the matching real
# test images, captured before training ...
FILE_RECON_BEFORE_TRAINING = os.path.join(output_dir, "reconstruct_before_train.png")
FILE_REAL_BEFORE_TRAINING = os.path.join(output_dir, "real_test_images_before_train.png")

# ... and after training.
FILE_RECON_AFTER_TRAINING = os.path.join(output_dir, "reconstruct_after_train.png")
FILE_REAL_AFTER_TRAINING = os.path.join(output_dir, "real_test_images_after_train.png")

# Latent-space visualisation and the image grid laid over the embeddings.
LATENT_SPACE_PLOT = os.path.join(output_dir, "embedding_visualize.png")
IMAGE_GRID_EMBEDDINGS_PLOT = os.path.join(output_dir, "image_grid_on_embeddings.png")

# Reconstructions decoded from linearly vs. normally sampled latent points.
LINEARLY_SAMPLED_RECONSTRUCTIONS_PLOT = os.path.join(output_dir, "linearly_sampled_reconstructions.png")
NORMALLY_SAMPLED_RECONSTRUCTIONS_PLOT = os.path.join(output_dir, "normally_sampled_reconstructions.png")

In [None]:

# Map each Fashion-MNIST class index (0-9) to its human-readable label;
# the index is the position in this list.
CLASS_LABELS = dict(
    enumerate(
        [
            "T-shirt/top",
            "Trouser",
            "Pullover",
            "Dress",
            "Coat",
            "Sandal",
            "Shirt",
            "Sneaker",
            "Bag",
            "Ankle boot",
        ]
    )
)

In [None]:
import torch
import torch.nn as nn

def vae_gaussian_kl_loss(mu, logvar):
    """KL divergence KL(N(mu, exp(logvar)) || N(0, I)), averaged over the batch.

    Args:
        mu: Posterior means, reduced over dim 1 (the latent dimension).
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: per-sample KL summed over dim 1, then mean over samples.
    """
    integrand = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_per_sample = -0.5 * integrand.sum(dim=1)
    return kl_per_sample.mean()

def reconstruction_loss(x_reconstructed, x):
    """Binary cross-entropy between reconstructions and target images.

    Uses ``nn.BCELoss`` with its default ``reduction='mean'``, so the loss is
    averaged over every element (batch and pixels).

    Args:
        x_reconstructed: Predicted values in [0, 1] (the decoder ends in a sigmoid).
        x: Targets of the same shape, with values in [0, 1].

    Returns:
        Scalar loss tensor.
    """
    bce_loss = nn.BCELoss()
    return bce_loss(x_reconstructed, x)


# Backward-compatible alias for the original, misspelled name.
reconstuction_loss = reconstruction_loss


def vae_loss(y_pred, y_true, recon_weight=500):
    """Weighted VAE objective: ``recon_weight * BCE + KL``.

    Args:
        y_pred: Tuple ``(mu, logvar, recon_x)`` -- the order returned by ``VAE.forward``.
        y_true: Target images, same shape as ``recon_x``.
        recon_weight: Multiplier on the reconstruction term. The BCE term is a
            per-element mean while the KL term is summed over latent dims, so this
            weight balances the two; 500 is the original hard-coded value.

    Returns:
        Scalar loss tensor.
    """
    mu, logvar, recon_x = y_pred
    recon_loss = reconstruction_loss(recon_x, y_true)
    kld_loss = vae_gaussian_kl_loss(mu, logvar)
    return recon_weight * recon_loss + kld_loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

In [None]:
class Sampling(nn.Module):
    """Reparameterisation trick: ``z = mean + exp(0.5 * log_var) * eps``, eps ~ N(0, I)."""

    def forward(self, z_mean, z_log_var):
        """Draw one latent sample per element of ``z_mean``.

        Args:
            z_mean: Posterior means (any shape).
            z_log_var: Posterior log-variances, same shape as ``z_mean``.

        Returns:
            Tensor with the shape, dtype and device of ``z_mean``.
        """
        # randn_like samples eps directly on z_mean's device and dtype, avoiding a
        # CPU sample plus host-to-device copy every step and supporting any rank.
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder producing a diagonal-Gaussian posterior.

    Three 3x3, stride-2 convolutions (32 -> 64 -> 128 channels) each halve the
    spatial size (rounding up). The flattened features feed two linear heads,
    one for the posterior mean and one for the log-variance. A latent sample is
    then drawn with ``Sampling``.

    Args:
        image_size: Height/width of the (square) input images.
        embedding_dim: Dimensionality of the latent space.
        in_channels: Number of input channels (default 1, the previously
            hard-coded grayscale value).
    """

    def __init__(self, image_size, embedding_dim, in_channels=1):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=2, padding=1)

        # A conv with kernel 3, stride 2, padding 1 maps size s -> ceil(s / 2).
        # This is exact for any image_size and equals image_size // 8 when the
        # size is a multiple of 8 (the value previously hard-coded twice).
        feature_size = image_size
        for _ in range(3):
            feature_size = (feature_size + 1) // 2
        flat_dim = 128 * feature_size * feature_size

        self.flatten = nn.Flatten()
        self.fc_mean = nn.Linear(flat_dim, embedding_dim)
        self.fc_log_var = nn.Linear(flat_dim, embedding_dim)
        self.sampling = Sampling()

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch; assumed (batch, in_channels, image_size, image_size).

        Returns:
            Tuple ``(z_mean, z_log_var, z)``, each of shape (batch, embedding_dim).
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        x = self.flatten(x)
        z_mean = self.fc_mean(x)
        z_log_var = self.fc_log_var(x)
        z = self.sampling(z_mean, z_log_var)
        return z_mean, z_log_var, z

In [None]:
class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latent vectors back to images.

    A linear layer expands the latent vector to ``shape_before_flattening``.
    Three 3x3, stride-2 transposed convolutions (with output_padding=1) then
    double the spatial size each time, and a final sigmoid keeps outputs in
    [0, 1].

    Args:
        embedding_dim: Dimensionality of the latent input.
        shape_before_flattening: ``(channels, height, width)`` of the feature map
            the linear layer is reshaped into.
        out_channels: Number of output image channels (default 1, the previously
            hard-coded value).
    """

    def __init__(self, embedding_dim, shape_before_flattening, out_channels=1):
        super(Decoder, self).__init__()
        self.shape_before_flattening = tuple(shape_before_flattening)
        channels, height, width = self.shape_before_flattening
        self.fc = nn.Linear(embedding_dim, channels * height * width)
        # Input channels come from shape_before_flattening instead of a hard-coded
        # 128, so the decoder stays consistent with whatever shape is passed in.
        self.deconv1 = nn.ConvTranspose2d(
            channels, 64, 3, stride=2, padding=1, output_padding=1
        )
        # kernel_size=3 was missing here, which raised TypeError at construction.
        self.deconv2 = nn.ConvTranspose2d(
            64, 32, 3, stride=2, padding=1, output_padding=1
        )
        # Named deconv3 (was "dconv3") to match the attribute used in forward().
        self.deconv3 = nn.ConvTranspose2d(
            32, out_channels, 3, stride=2, padding=1, output_padding=1
        )

    def reshape(self, x):
        """Reshape flat fc output to (-1, *shape_before_flattening).

        This is a method rather than a lambda attribute so the module stays picklable.
        """
        return x.view(-1, *self.shape_before_flattening)

    def forward(self, x):
        """Decode latent vectors of shape (batch, embedding_dim) into images in [0, 1]."""
        x = self.fc(x)
        x = self.reshape(x)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = torch.sigmoid(self.deconv3(x))
        return x


In [None]:
class VAE(nn.Module):
    """Variational autoencoder composed of an encoder and a decoder module.

    The encoder must return ``(z_mean, z_log_var, z)``. The decoder maps the
    sampled latent ``z`` to a reconstruction.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        """Return ``(z_mean, z_log_var, reconstruction)`` for input ``x``."""
        mean, log_var, latent = self.encoder(x)
        return mean, log_var, self.decoder(latent)


In [None]:

from torchvision import datasets , transforms
import torch.optim as optim
import torch
import os
import matplotlib
# Select matplotlib's non-interactive Agg backend. Figures can be saved to files
# (e.g. the PNG paths defined earlier), but plt.show() will not open or display them.
matplotlib.use('agg')

In [None]:
# Pad the 28x28 Fashion-MNIST images by 2 px per side (-> 32x32, IMAGE_SIZE),
# then convert them to float tensors in [0, 1].
# ToTensor must be instantiated: passing the class itself makes every sample
# lookup call ToTensor(img), which raises TypeError.
transform = transforms.Compose(
    [transforms.Pad(padding=2), transforms.ToTensor()]
)
trainset = datasets.FashionMNIST(
    'data', train=True, download=True, transform=transform
)

# Use the configured BATCH_SIZE instead of a separate hard-coded 32.
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True
)
testset = datasets.FashionMNIST(
    'data', train=False, download=True, transform=transform
)
# No shuffling for evaluation. Batches are then deterministic, so the
# "real test images" captured before and after training are the same images.
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:03<00:00, 8561441.55it/s] 


Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 139351.22it/s]


Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:01<00:00, 2554871.31it/s]


Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 8588813.44it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw




