In [1]:
import sys 
sys.path.append('../')
import numpy as np 
import torch
import torch.utils.data
from torch import optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.distributions import Beta
from models import * 
from loss_functions import *
import os 
os.chdir("/ContinuousBernoulliVAE/notebooks")

# Output locations, relative to the working directory set just above.
IMAGE_PATH = "../images/beta/"
MODEL_PATH = "../trained_models/"

# Create the output directories up front. Without this, a missing image
# directory only surfaces when the first reconstruction is saved (after a
# whole training epoch), and a missing model directory only surfaces at the
# final torch.save, after every epoch has already run.
os.makedirs(IMAGE_PATH, exist_ok=True)
os.makedirs(MODEL_PATH, exist_ok=True)

# Width of the noise vectors fed to model.decode when drawing samples.
# NOTE(review): presumably has to match BetaVAE's latent size, which is
# defined in models.py and not visible here -- confirm.
DIM = 20
EPOCHS = 100
torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST is downloaded by the training dataset; the test dataset reuses the
# same 'data' directory and therefore does not request a download itself.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=False)

model = BetaVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    """Run one optimisation pass over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. Progress is printed every 10 batches and the dataset-average
    loss is printed at the end of the pass.

    Args:
        epoch: Epoch number; only used in the printed progress messages.

    Returns:
        A 1-D ``numpy`` array with one entry per batch: that batch's loss
        divided by the number of samples in the batch.
    """
    model.train()
    train_loss = 0.
    train_loss_vals = []
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        alphas, betas, mu, logvar = model(data)
        # The trailing 1 is forwarded to beta_loss as-is; its meaning is
        # defined in loss_functions and is not visible from this file.
        loss = beta_loss(alphas, betas, data, mu, logvar, 1)
        loss.backward()
        optimizer.step()

        # Read the scalar once instead of calling .item() for every use.
        batch_loss = loss.item()
        train_loss += batch_loss
        # Normalise by this batch's own size so the shorter final batch is
        # not under-reported. This matches the per-sample value logged below
        # and assumes beta_loss sums over the batch, as the dataset-level
        # average further down already does.
        train_loss_vals.append(batch_loss / len(data))

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                batch_loss / len(data)))

    train_loss /= len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss))

    return np.array(train_loss_vals)


def test(epoch):
    """Evaluate the model on ``test_loader`` and save a reconstruction grid.

    Uses the module-level ``model``, ``test_loader`` and ``device``. For the
    first batch only, up to 8 inputs are stacked above their reconstructions
    and written to ``reconstruction_<epoch>.png`` under ``IMAGE_PATH``. The
    dataset-average loss is printed at the end.

    Args:
        epoch: Epoch number; used in the name of the saved image.

    Returns:
        A 1-D ``numpy`` array with one entry per batch: that batch's loss
        divided by the number of samples in the batch.
    """
    model.eval()
    test_loss = 0
    test_loss_vals = []
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            alphas, betas, mu, logvar = model(data)
            # The trailing 1 is forwarded to beta_loss as-is; its meaning is
            # defined in loss_functions and is not visible from this file.
            loss = beta_loss(alphas, betas, data, mu, logvar, 1)

            # Read the scalar once instead of calling .item() for every use.
            batch_loss = loss.item()
            test_loss += batch_loss
            # Normalise by this batch's own size so the shorter final batch
            # is not under-reported (assumes beta_loss sums over the batch,
            # as the dataset-level average below already does).
            test_loss_vals.append(batch_loss / len(data))

            if i == 0:
                n = min(data.size(0), 8)
                recon_dist = Beta(alphas, betas)
                # Show the distribution mean rather than a random draw
                # (recon_dist.sample()).
                recon_batch = recon_dist.mean
                # Reshape with the actual batch size; a hard-coded 128 fails
                # as soon as the loader's batch size is anything else.
                recon_batch = recon_batch.view(data.size(0), 1, 28, 28)
                comparison = torch.cat([data[:n],
                                        recon_batch[:n]])

                save_image(comparison.cpu(),
                           f'{IMAGE_PATH}/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

    return np.array(test_loss_vals)

# Per-batch loss histories, concatenated across every epoch.
train_loss_vals_total = np.array([])
test_loss_vals_total = np.array([])

for epoch in range(1, EPOCHS + 1):
    train_loss_vals = train(epoch)
    test_loss_vals = test(epoch)
    train_loss_vals_total = np.concatenate(
        (train_loss_vals_total, train_loss_vals))
    test_loss_vals_total = np.concatenate(
        (test_loss_vals_total, test_loss_vals))

    # After each epoch, decode 64 standard-normal noise vectors and save the
    # resulting images as a single picture.
    with torch.no_grad():
        z = torch.randn(64, DIM).to(device)
        decoded = model.decode(z)
        # The first two decoder outputs are used as the Beta parameters.
        sample_dist = Beta(decoded[0], decoded[1])
        # Use the distribution mean; sample_dist.sample() would give a
        # random draw instead.
        sample = sample_dist.mean
        save_image(sample.view(64, 1, 28, 28),
                   f'{IMAGE_PATH}/sample_' + str(epoch) + '.png')

# The whole model object is serialised here, not just its state_dict.
torch.save(model, f'{MODEL_PATH}/betavae.pt')
# np.save('tmp/betavae_train_loss_vals_total.npy', train_loss_vals_total)
# np.save('tmp/betavae_test_loss_vals_total.npy', test_loss_vals_total)

  alphas = 1e-6 + F.softmax(beta_params[:, :self.data_dim])
  betas = 1e-6 + F.softmax(beta_params[:, self.data_dim:])


====> Epoch: 1 Average loss: 183.3268
====> Test set loss: 143.8224
====> Epoch: 2 Average loss: 136.5745
====> Test set loss: 130.3704
====> Epoch: 3 Average loss: 128.6551
====> Test set loss: 125.2702
====> Epoch: 4 Average loss: 124.5193
====> Test set loss: 122.2050
====> Epoch: 5 Average loss: 121.8523
====> Test set loss: 120.1333
====> Epoch: 6 Average loss: 120.0029
====> Test set loss: 118.4753
====> Epoch: 7 Average loss: 118.5014
====> Test set loss: 117.3820
====> Epoch: 8 Average loss: 117.2323
====> Test set loss: 116.6561
====> Epoch: 9 Average loss: 116.3694
====> Test set loss: 115.5341
====> Epoch: 10 Average loss: 115.5104
====> Test set loss: 115.0957
====> Epoch: 11 Average loss: 114.9746
====> Test set loss: 114.6848
====> Epoch: 12 Average loss: 114.5066
====> Test set loss: 114.1828
====> Epoch: 13 Average loss: 114.1324
====> Test set loss: 114.0802
====> Epoch: 14 Average loss: 113.6643
====> Test set loss: 113.3993
====> Epoch: 15 Average loss: 113.3868
====