In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torchvision.utils import save_image

from utils.cnn_vae_1d import VAE
from utils.datasets import ESR
import os

In [None]:
# Training hyperparameters.
batch_size = 2
learning_rate = 1e-3
num_epochs = 10

# Seed every RNG the notebook touches so that data shuffling, weight
# initialisation and any sampling inside the model are reproducible across runs.
# torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda')
device  # show the chosen device as the cell output

In [None]:
# Root directory of the ESR data. Set the ESR_DATA_DIR environment variable to
# point at the data on your machine instead of editing a hard-coded path; when it
# is unset, the original 'D:' is used.
# NOTE(review): on Windows, 'D:' without a trailing separator means "current
# directory on drive D", not the drive root — confirm which one ESR expects.
data_dir = os.environ.get('ESR_DATA_DIR', 'D:')
dataset = ESR(data_dir)

In [None]:
# Shuffled mini-batches for training.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One sample at a time, in dataset order, for inspection/evaluation.
# NOTE(review): this iterates the same `dataset` used for training, so any
# "test" result is measured on training data — consider a held-out split.
test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1)

In [None]:
# Instantiate the VAE, move it to the selected device, and let Adam update
# all of its parameters with the configured learning rate.
net = VAE()
net = net.to(device)
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)

In [None]:
# Make sure dropout/batch-norm layers (if any) are in training mode, e.g. when
# this cell is re-run after an evaluation cell that called net.eval().
net.train()

for epoch in range(num_epochs):
    # Running totals used to report the mean loss over the whole epoch rather
    # than just the last batch.
    epoch_loss = 0.0
    num_samples = 0

    for x, _ in train_loader:
        x = x.to(device)

        # Feed a batch of signals through the VAE to obtain the reconstruction
        # and the parameters (mu, logVar) of the approximate posterior.
        out, mu, logVar = net(x)

        # Loss = summed BCE reconstruction term + KL divergence between
        # N(mu, exp(logVar)) and the standard normal prior.
        # NOTE(review): BCE requires both `out` and `x` to lie in [0, 1] —
        # confirm that ESR normalises the signals and the decoder ends in a sigmoid.
        kl_divergence = 0.5 * torch.sum(-1 - logVar + mu.pow(2) + logVar.exp())
        loss = F.binary_cross_entropy(out, x, reduction='sum') + kl_divergence

        # Backpropagation based on the loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() yields a plain float, so the running total does not keep the
        # autograd graph alive.
        epoch_loss += loss.item()
        num_samples += x.size(0)

    # Per-sample average: the loss uses reduction='sum', so dividing by the
    # number of samples makes the number independent of batch size. max(...)
    # guards against an empty loader.
    mean_loss = epoch_loss / max(num_samples, 1)
    print('Epoch {}: Loss {}'.format(epoch, mean_loss))

In [None]:
# import matplotlib.pyplot as plt
# import numpy as np
# import random

In [None]:
# net.eval()
# with torch.no_grad():
#     for data in random.sample(list(test_loader), 1):
#         imgs, _ = data
#         imgs = imgs.to(device)
#         img = np.transpose(imgs[0].cpu().numpy(), [1,2,0])
#         plt.subplot(121)
#         plt.imshow(np.squeeze(img))
#         out, mu, logVAR = net(imgs)
#         outimg = np.transpose(out[0].cpu().numpy(), [1,2,0])
#         plt.subplot(122)
#         plt.imshow(np.squeeze(outimg))
#         break