In [6]:
from torchvision import datasets,transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.auto import tqdm
import torch.nn.functional as F

In [7]:
# Training hyper-parameters; tune them here rather than in the cells below.
cfg = {
    'batch_size': 64,
    'epoch': 40,
    'lr': 1e-4,
    'seed': 0,  # seeds torch's global RNG (e.g. weight init, DataLoader shuffling)
}

# Seed before the DataLoaders and the model are created so that
# Restart & Run All reproduces the same initial weights and batch order.
torch.manual_seed(cfg['seed'])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [8]:
# ToTensor turns each image into a float tensor in [0, 1]; Normalize then
# applies (x - 0.5) / 1.0, so pixels end up in [-0.5, 0.5].
# NOTE(review): std=1.0 only shifts the data; confirm this range matches the
# decoder's output activation in AE.CVAE.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 1.0),
    ]
)

In [9]:
root = '../.data'

# Load (downloading on first use) the MNIST train split, then the test split,
# both with the same preprocessing pipeline.
train_dataset, test_dataset = (
    datasets.MNIST(root, transform=transform, train=is_train, download=True)
    for is_train in (True, False)
)

In [10]:
# Same batch size for both splits; only the training split is shuffled,
# so the test batches arrive in a fixed order every epoch.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=cfg['batch_size'], shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
)

In [11]:
from AE import AutoEncoder, VAE, VAE_Loss, CVAE, CVAE_Loss

# One condition slot per distinct training label (10 digit classes for MNIST).
# `targets` is the current attribute; `train_labels` is a deprecated alias that
# emits a warning on every access.
n_classes = train_loader.dataset.targets.unique().size(0)

# x_dim = 28 * 28: the training loop feeds flattened MNIST images.
model = CVAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2, c_dim=n_classes)
# Put the parameters on the same device the training loop sends batches to.
# NOTE(review): assumes CVAE is an nn.Module (it is used with .parameters(),
# .train() and .eval() in the training cell).
model = model.to(device)



In [None]:
IMG_SIZE = 28  # MNIST images are 28x28; the model consumes them flattened


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-batch loss.

    The reconstruction target is the flattened input batch itself.
    """
    model.train()
    total = 0.0
    for images, _labels in tqdm(loader, leave=False):
        x = images.to(device).view(-1, IMG_SIZE * IMG_SIZE)
        optimizer.zero_grad()
        # NOTE(review): a CVAE built with `c_dim` usually also takes the label as
        # a condition, and `CVAE_Loss` is imported but unused; this keeps the
        # original model(x) call and MSE loss -- confirm against AE.CVAE.
        _encoded, decoded = model(x)
        loss = criterion(decoded, x)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; summing the loss tensor itself would chain
        # every batch into one ever-growing autograd graph for the whole epoch.
        total += loss.item()
    return total / max(len(loader), 1)


def _evaluate(model, loader, criterion, device):
    """Return (mean per-batch loss, last input batch, last reconstruction) over `loader`."""
    model.eval()
    total = 0.0
    x = output = None
    with torch.no_grad():
        for images, _labels in tqdm(loader, leave=False):
            # Evaluate on the images yielded by *this* loader.
            x = images.to(device).view(-1, IMG_SIZE * IMG_SIZE)
            _encoded, output = model(x)
            total += criterion(output, x).item()
    return total / max(len(loader), 1), x, output


def _plot_reconstructions(inputs, outputs, n=5, title=None):
    """Show the first `n` inputs (top row) above their reconstructions (bottom row)."""
    n = min(n, inputs.shape[0])  # the final batch may be smaller than n
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(n, 2), squeeze=False)
    for row, batch in enumerate((inputs, outputs)):
        images = batch.cpu().numpy()
        for i in range(n):
            ax = axes[row][i]
            ax.imshow(np.reshape(images[i], (IMG_SIZE, IMG_SIZE)), cmap='gray')
            ax.set_xticks(())
            ax.set_yticks(())
    if title:
        fig.suptitle(title)
    plt.show()


# Move the parameters before building the optimizer so model and batches share a device.
model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=cfg['lr'])

for epoch in tqdm(range(cfg['epoch'])):
    # Both losses are averaged per batch so train and validation are comparable
    # despite having different numbers of batches.
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f'train_loss : {train_loss}')

    val_loss, x, output = _evaluate(model, test_loader, criterion, device)
    print(f'valid_loss : {val_loss}')
    print(f'{epoch+1} epoch completed')
    if x is not None:
        _plot_reconstructions(
            x, output, title=f'epoch {epoch + 1}: input (top) / reconstruction (bottom)'
        )