<a href="https://colab.research.google.com/github/mehdii190/neural-network/blob/main/src/variational_AE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [107]:
import os
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import save_image

In [108]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")

In [109]:
# Display which device was selected by the cell above.
device

device(type='cuda')

In [110]:
# ---- Configuration --------------------------------------------------------
image_size = 28 * 28   # flattened MNIST image (1 x 28 x 28 -> 784 values)
hidden_dim = 400
latent_dim = 20
batch_size = 128
epochs = 10

# NOTE(review): "/data" is an absolute path; presumably writable on Colab
# (which the badge at the top targets) — change it when running elsewhere.
data_root = "/data"
sample_dir = "results"   # output folder for reconstruction / sample images

# ---- Data -----------------------------------------------------------------
train_dataset = torchvision.datasets.MNIST(root=data_root,
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

# download=True is a no-op when the files already exist, and keeps the test
# split loadable on its own.
test_dataset = torchvision.datasets.MNIST(root=data_root,
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Evaluation does not need shuffling: the summed test loss is independent of
# order, and a fixed order makes every epoch's reconstruction snapshot show the
# same digits, so the images are comparable across epochs.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(sample_dir, exist_ok=True)



In [111]:
# VAE model


class VAE(nn.Module):
  """Fully-connected variational autoencoder.

  Encoder: flattened input -> hidden layer (ReLU) -> two linear heads giving
  the posterior mean ``mu`` and log-variance ``logvar``.
  Decoder: latent code -> hidden layer (ReLU) -> flattened output passed
  through a sigmoid, so every output value lies in [0, 1].

  Args:
    input_dim: size of the flattened input; defaults to ``image_size``.
    hidden_size: width of the hidden layers; defaults to ``hidden_dim``.
    latent_size: size of the latent code; defaults to ``latent_dim``.
  """

  def __init__(self, input_dim=None, hidden_size=None, latent_size=None):
    super().__init__()
    # Fall back to the config-cell globals so the original `VAE()` call keeps
    # working unchanged; the globals are only read when an argument is omitted.
    self.input_dim = image_size if input_dim is None else input_dim
    hidden_size = hidden_dim if hidden_size is None else hidden_size
    latent_size = latent_dim if latent_size is None else latent_size

    self.fc1 = nn.Linear(self.input_dim, hidden_size)
    self.fc2_mean = nn.Linear(hidden_size, latent_size)
    self.fc2_logvar = nn.Linear(hidden_size, latent_size)
    self.fc3 = nn.Linear(latent_size, hidden_size)
    self.fc4 = nn.Linear(hidden_size, self.input_dim)

  def encode(self, x):
    """Map flattened inputs ``x`` to the posterior parameters ``(mu, logvar)``."""
    h = F.relu(self.fc1(x))
    mu = self.fc2_mean(h)
    log_var = self.fc2_logvar(h)
    return mu, log_var

  def reparameterize(self, mu, logvar):
    """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``, ``std = exp(logvar / 2)``.

    Expressing the sample this way keeps it differentiable w.r.t. mu and logvar.
    """
    std = torch.exp(logvar / 2)
    eps = torch.randn_like(std)
    return mu + eps * std

  def decode(self, z):
    """Map latent codes ``z`` to flattened outputs in [0, 1]."""
    h = F.relu(self.fc3(z))
    return torch.sigmoid(self.fc4(h))

  def forward(self, x):
    """Encode, sample and decode a batch.

    Args:
      x: input batch, e.g. (batch, 1, 28, 28); flattened to
        (batch, input_dim) internally.

    Returns:
      (recon, mu, logvar): reconstruction of shape (batch, input_dim) and the
      posterior parameters, each of shape (batch, latent_size).
    """
    # reshape (unlike view) also accepts non-contiguous inputs such as
    # transposed or channels-last tensors.
    mu, logvar = self.encode(x.reshape(-1, self.input_dim))
    z = self.reparameterize(mu, logvar)
    recon = self.decode(z)
    return recon, mu, logvar



# Build the model, move its parameters to the selected device (Module.to acts
# in place and returns the module), and attach an Adam optimizer.
model = VAE()
model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)




In [112]:
# Display the model's layer summary.
model

VAE(
  (fc1): Linear(in_features=784, out_features=400, bias=True)
  (fc2_mean): Linear(in_features=400, out_features=20, bias=True)
  (fc2_logvar): Linear(in_features=400, out_features=20, bias=True)
  (fc3): Linear(in_features=20, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=784, bias=True)
)

In [113]:
from io import open_code
def loss_function(recon_image, original_image, mu, logvar):
  """VAE training loss: reconstruction BCE plus KL divergence, summed over the batch.

  Args:
    recon_image: decoder output with values in [0, 1], e.g. (batch, 784).
    original_image: targets with the same number of elements as
      ``recon_image`` (e.g. (batch, 1, 28, 28)); reshaped to match it.
    mu: posterior means, e.g. (batch, latent_dim).
    logvar: posterior log-variances, same shape as ``mu``.

  Returns:
    Scalar tensor: summed binary cross-entropy plus
    KL(N(mu, exp(logvar)) || N(0, I)), both summed (not averaged) over every
    element of the batch.
  """
  # Match the target to the reconstruction's shape instead of a hard-coded
  # (-1, 784), so the loss follows whatever input size the model uses.
  bce = F.binary_cross_entropy(recon_image, original_image.view_as(recon_image), reduction="sum")

  # Closed-form KL divergence to a standard normal:
  #   KL = 0.5 * sum(exp(logvar) + mu^2 - 1 - logvar)
  kld = 0.5 * torch.sum(logvar.exp() + mu.pow(2) - 1 - logvar)

  return bce + kld



def train(epoch):
  """Run one optimisation pass over ``train_loader``.

  Uses the notebook-level ``model``, ``optimizer``, ``train_loader``,
  ``loss_function`` and ``device``. Prints the per-sample loss of every
  100th batch, then the epoch's average per-sample loss.

  Args:
    epoch: epoch number, used only in the log messages.
  """
  model.train()
  running_total = 0
  num_batches = len(train_loader)

  for batch_idx, (batch, _) in enumerate(train_loader):
    batch = batch.to(device)
    recon, mu, logvar = model(batch)
    batch_loss = loss_function(recon, batch, mu, logvar)

    optimizer.zero_grad()
    batch_loss.backward()
    running_total += batch_loss.item()
    optimizer.step()

    if batch_idx % 100 != 0:
      continue
    per_sample = batch_loss.item() / len(batch)
    print("train epoch {} [batch {}/{}]\tLoss: {:.3f}".format(epoch, batch_idx, num_batches, per_sample))

  average = running_total / len(train_loader.dataset)
  print("===> epoch {},average loss: {:.3f}".format(epoch, average))


def test(epoch):
  """Evaluate on ``test_loader`` and save a reconstruction snapshot.

  Accumulates the summed ``loss_function`` over the test set without gradient
  tracking and prints the average per-sample loss. For the first batch, saves
  up to 5 inputs (first row) above their reconstructions (second row) to
  ``<sample_dir>/reconstructed_<epoch>.png``.

  Args:
    epoch: epoch number, used in the snapshot filename.
  """
  model.eval()

  test_loss = 0
  with torch.no_grad():
    for batch_idx, (image, _) in enumerate(test_loader):
      images = image.to(device)
      reconstructed, mu, logvar = model(images)
      test_loss += loss_function(reconstructed, images, mu, logvar).item()
      if batch_idx == 0:
        # Reshape to the inputs' own shape rather than (batch_size, 1, 28, 28):
        # the latter crashes whenever this batch holds fewer than batch_size images.
        n_show = min(5, images.size(0))
        comparison = torch.cat([images[:n_show], reconstructed.view_as(images)[:n_show]])
        save_image(comparison.cpu(),
                   os.path.join(sample_dir, "reconstructed_" + str(epoch) + ".png"),
                   nrow=n_show)

  print("===> average test loss: {:.3f}".format(test_loss / len(test_loader.dataset)))



In [114]:
# Train for `epochs` epochs. After each one, evaluate on the test set and save
# a grid of images decoded from latent codes drawn from N(0, I).
num_samples = 64
for epoch in range(1, epochs + 1):
  train(epoch)
  test(epoch)
  with torch.no_grad():
    # Use latent_dim instead of a hard-coded 20 so sampling keeps working if
    # the latent size in the config cell changes.
    z = torch.randn(num_samples, latent_dim).to(device)
    generated = model.decode(z).cpu()
    save_image(generated.view(num_samples, 1, 28, 28),
               os.path.join(sample_dir, "sample_" + str(epoch) + ".png"))

  

train epoch 1 [batch 0/469]	Loss: 550.191
train epoch 1 [batch 100/469]	Loss: 185.624
train epoch 1 [batch 200/469]	Loss: 150.844
train epoch 1 [batch 300/469]	Loss: 145.409
train epoch 1 [batch 400/469]	Loss: 130.062
===> epoch 1,average loss: 165.522
===> average test loss: 128.183
train epoch 2 [batch 0/469]	Loss: 134.439
train epoch 2 [batch 100/469]	Loss: 126.160
train epoch 2 [batch 200/469]	Loss: 122.253
train epoch 2 [batch 300/469]	Loss: 116.506
train epoch 2 [batch 400/469]	Loss: 115.767
===> epoch 2,average loss: 121.902
===> average test loss: 115.973
train epoch 3 [batch 0/469]	Loss: 119.924
train epoch 3 [batch 100/469]	Loss: 114.270
train epoch 3 [batch 200/469]	Loss: 114.323
train epoch 3 [batch 300/469]	Loss: 113.444
train epoch 3 [batch 400/469]	Loss: 117.549
===> epoch 3,average loss: 114.617
===> average test loss: 111.716
train epoch 4 [batch 0/469]	Loss: 114.400
train epoch 4 [batch 100/469]	Loss: 112.786
train epoch 4 [batch 200/469]	Loss: 108.591
train epoch 4 [