<a href="https://colab.research.google.com/github/mehdii190/neural-network/blob/main/src/variational_AE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [None]:
# Echo the selected device in the notebook output.
device

In [None]:
# Hyperparameters.
image_size = 784  # one flattened 28 x 28 MNIST image
hidden_dim = 400
latent_dim = 20
batch_size = 128
epochs = 10

# MNIST splits converted to tensors with ToTensor. download=True on BOTH splits
# so the test split does not rely on the train split having fetched the files.
train_dataset = torchvision.datasets.MNIST(root="/data",
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root="/data",
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

# Output directory for generated samples. exist_ok=True replaces the
# check-then-create pattern, which could race if the directory appeared in between.
sample_dir = "results"
os.makedirs(sample_dir, exist_ok=True)



In [None]:
# Variational autoencoder (VAE) model for flattened MNIST images.


class VAE(nn.Module):
  """Fully connected VAE: image -> hidden -> (mu, log_var) -> z -> hidden -> image.

  Layer widths come from the module-level hyperparameters ``image_size``,
  ``hidden_dim`` and ``latent_dim``, read when the model is constructed.
  """

  def __init__(self):
    super().__init__()

    # Encoder: image_size -> hidden_dim, then two heads that produce the
    # latent Gaussian parameters, each latent_dim wide.
    self.fc1 = nn.Linear(image_size, hidden_dim)
    self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
    self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
    # Decoder: latent_dim -> hidden_dim -> image_size.
    self.fc3 = nn.Linear(latent_dim, hidden_dim)
    self.fc4 = nn.Linear(hidden_dim, image_size)

  def encode(self, x):
    """Map flattened images (batch, image_size) to (mu, log_var), each (batch, latent_dim)."""
    h = F.relu(self.fc1(x))
    mu = self.fc2_mean(h)
    log_var = self.fc2_logvar(h)
    return mu, log_var

  def reparameterize(self, mu, logvar):
    """Return z = mu + std * eps with eps ~ N(0, I) and std = exp(logvar / 2).

    Sampling through eps keeps z differentiable with respect to mu and logvar.
    """
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z):
    """Map latent codes (batch, latent_dim) to sigmoid outputs in [0, 1], shape (batch, image_size)."""
    h = F.relu(self.fc3(z))
    return torch.sigmoid(self.fc4(h))

  def forward(self, x):
    """Encode, sample and decode a batch of images.

    Args:
      x: images that flatten to ``image_size`` values each,
        e.g. (batch size, 1, 28, 28).

    Returns:
      (recon, mu, logvar): the reconstruction, shape (batch, image_size), and
      the latent Gaussian parameters, each shape (batch, latent_dim).
    """
    # reshape (unlike view) also accepts non-contiguous inputs. The width is taken
    # from the encoder so it always matches the layers built in __init__.
    mu, logvar = self.encode(x.reshape(-1, self.fc1.in_features))
    z = self.reparameterize(mu, logvar)
    recon = self.decode(z)
    return recon, mu, logvar



# Build the VAE, move its parameters to the selected device, and optimise
# them with Adam at a learning rate of 1e-3.
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)




In [None]:
# Echo the model (its nn.Module layer summary) in the notebook output.
model