In [None]:
import torch
import torchvision
import torch.optim as optim
import argparse
import torch.nn as nn
import matplotlib.pyplot as plt
import reimp
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.utils import save_image
import numpy as np
import os

# Command-line configuration for the run. Every option has a default, so the
# cell also works when no arguments are supplied.
parser = argparse.ArgumentParser()
parser.add_argument('--gamma1', default=1000, type=int,
                    help="Lower gamma for KL divergence")
parser.add_argument('--gamma2', default=1000, type=int,
                    help="Upper gamma for KL divergence")
parser.add_argument('--C_max', default=25, type=int,
                    help="Capacity bottleneck")
parser.add_argument('-e', '--epochs', default=10, type=int,
                    help='number of epochs to train VAE for')
# Placeholder option whose value is never read.
# NOTE(review): presumably absorbs the `-f <connection file>` argument that a
# Jupyter kernel is launched with - confirm.
parser.add_argument('-f', '--ff', help="Dummy arg")
# parse_known_args() ignores argv entries this parser does not declare
# (parse_args() would exit the process on them), so extra launcher-injected
# arguments cannot abort the notebook. `args` stays a plain dict keyed by
# option name.
args = vars(parser.parse_known_args()[0])

In [None]:
# Run configuration: fixed optimiser settings first, then the values taken
# from the parsed command-line arguments in `args`.
lr = 5e-4
batch_size = 64
gamma1, gamma2 = args['gamma1'], args['gamma2']
epochs = args['epochs']
# C is pinned to 1 here; the parsed `C_max` argument is not read
# (the alternative was: C = args['C_max']).
C = 1

# Path of the dSprites archive, resolved against the current working
# directory.
root = os.path.abspath(os.getcwd(
) + '/dsprites-dataset-master/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
# np.load on an .npz returns an archive object that keeps the file open until
# it is closed; the context manager releases the handle as soon as the 'imgs'
# array has been read out.
with np.load(root) as archive:
    # The whole image array is converted to a float tensor up front, so the
    # full dataset lives in memory from this point on.
    data = torch.from_numpy(archive['imgs']).float()


class CustomDataset(Dataset):
    """Map-style dataset over an in-memory dSprites tensor.

    Samples are taken along the first dimension of ``data``; each item is the
    indexed slice itself, with no label attached.
    """

    def __init__(self, data):
        # Only a reference is stored; the tensor is not copied.
        self.data = data

    def __len__(self):
        # The sample count is the extent of dimension 0.
        num_samples = self.data.size(0)
        return num_samples

    def __getitem__(self, index):
        sample = self.data[index]
        return sample


# Wrap the image tensor so it can be split and batched.
dataset = CustomDataset(data)

# Hold out one tenth of the samples for validation. The sizes are derived
# from the dataset length instead of being hard-coded, because random_split
# raises when the requested lengths do not sum to len(dataset). For 737,280
# samples this yields the 663,552 / 73,728 split.
num_test = len(dataset) // 10
num_train = len(dataset) - num_test

# A fixed generator seed keeps the train/test membership identical across runs.
train_set, test_set = random_split(
    dataset, [num_train, num_test],
    generator=torch.Generator().manual_seed(42))

# Training batches are reshuffled every epoch.
train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True
)

# Validation batches keep a fixed order.
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False
)

# Model under training and the Adam optimiser that updates its parameters.
# `optimizer` is read as a module-level name by train().
model = reimp.ReImp()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [None]:
def train(model, dataloader, gamma, C):
    """Run one optimisation pass over ``dataloader`` and return the scaled loss.

    The module-level ``optimizer`` is used for the parameter updates.

    Args:
        model: module whose ``forward`` returns ``(reconstruction, mu, logvar)``
            and which provides ``final_loss``.
        dataloader: iterable of batches; each batch gets a singleton
            dimension inserted at index 1 before the forward pass.
        gamma: forwarded unchanged to ``model.final_loss``.
        C: forwarded unchanged to ``model.final_loss``.

    Returns:
        The sum of the per-batch loss values divided by
        ``len(dataloader.dataset)``.
    """
    model.train()
    loss_total = 0.0
    for batch in dataloader:
        batch = batch.unsqueeze(1)
        optimizer.zero_grad()
        outputs = model.forward(batch)
        reconstruction, mu, logvar = outputs
        batch_loss = model.final_loss(
            reconstruction, batch, mu, logvar, gamma, C)
        loss_total += batch_loss.item()
        batch_loss.backward()
        optimizer.step()
    return loss_total / len(dataloader.dataset)


def validate(model, dataloader, gamma, C, epoch):
    """Evaluate ``model`` over every batch of ``dataloader``.

    The model is put in eval mode and gradients are not tracked. On the final
    batch, up to 8 inputs followed by their reconstructions are written as an
    image grid to ``../outputs/output{gamma}-{C}-{epoch}.png``; that directory
    is not created here, so it must already exist.

    Args:
        model: module whose ``forward`` returns ``(reconstruction, mu, logvar)``
            and which provides ``final_loss``.
        dataloader: loader yielding batches; each batch gets a singleton
            dimension inserted at index 1 before the forward pass. Inputs and
            reconstructions are viewed as 1x64x64 images when saved.
        gamma: forwarded unchanged to ``model.final_loss``; also part of the
            output filename.
        C: forwarded unchanged to ``model.final_loss``; also part of the
            output filename.
        epoch: used only in the output filename.

    Returns:
        The sum of the per-batch loss values divided by
        ``len(dataloader.dataset)``.
    """
    model.eval()
    running_loss = 0.0
    num_rows = 8
    # Index of the final batch, taken from the loader that was passed in
    # rather than from a module-level dataset.
    last_batch = len(dataloader) - 1
    # No optimiser interaction and no autograd graph are needed to evaluate.
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            data = data.unsqueeze(1)
            reconstruction, mu, logvar = model.forward(data)
            loss = model.final_loss(reconstruction, data, mu, logvar, gamma, C)

            running_loss += loss.item()

            # save last batch input/output of each epoch
            if i == last_batch:
                # -1 lets the batch dimension follow the actual batch, so a
                # short final batch does not break the reshape.
                both = torch.cat((data.view(-1, 1, 64, 64)[:num_rows],
                                  reconstruction.view(-1, 1, 64, 64)[:num_rows]))
                save_image(both.cpu(), f"../outputs/output{gamma}-{C}-{epoch}.png",
                           nrow=num_rows)

    # Computed once, after all batches have been accumulated.
    val_loss = running_loss/len(dataloader.dataset)
    return val_loss

In [None]:
# Per-epoch loss histories; one entry is appended to each list for every
# epoch run by the loop below.
train_loss = []
val_loss = []
# Training only runs when both gamma arguments are equal; no branch for
# gamma1 != gamma2 is visible in this part of the notebook.
if gamma1 == gamma2:
    # Repeats the full epoch schedule C times. The loop index is not used in
    # this body, and the same C is passed to train/validate on every pass.
    # NOTE(review): presumably intended to sweep capacity values - confirm.
    for i in range(C):
        for epoch in range(epochs):
            print(f"Epoch {epoch + 1} of {epochs}")
            # One pass over the training data, then one over the held-out data,
            # both using gamma1 (equal to gamma2 in this branch).
            train_epoch_loss = train(model, train_loader, gamma1, C)
            val_epoch_loss = validate(model, test_loader, gamma1, C, epoch)
            train_loss.append(train_epoch_loss)
            val_loss.append(val_epoch_loss)
            print(f"C: {C}Train Loss: {train_epoch_loss:.4f}")
            print(f"C: {C}Val Loss: {val_epoch_loss:.4f}")