In [1]:
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as data
from torchvision import datasets, transforms

from vae import VAE

# Render matplotlib figures inline in the notebook output
%matplotlib inline

In [2]:
# Relative path to the dSprites .npz archive; it is handed to DSpritesDataset,
# which opens it with np.load.
datapath = "data/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz"

In [3]:
class DSpritesDataset(data.Dataset):
    """Map-style dataset over the ``imgs`` array of a dSprites ``.npz`` archive.

    Args:
        datapath: Path to the ``.npz`` archive. The archive must contain the
            entries ``imgs``, ``latents_values``, ``latents_classes`` and
            ``metadata``.
        transform: Optional callable applied to every sample before it is
            returned from ``__getitem__``.

    Attributes set in ``__init__``: ``imgs``, ``latents_values``,
    ``latents_classes`` (arrays read from the archive), ``metadata`` (the
    object unwrapped from the 0-d ``metadata`` entry), ``datapath`` and
    ``transform``.
    """

    def __init__(self, datapath, transform=None):
        self.datapath = datapath
        self.transform = transform

        # np.load on an .npz returns a lazy NpzFile that keeps the archive
        # open; the context manager closes it once every entry has been read
        # into memory, so no file handle outlives the constructor.
        #
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays (presumably needed for the `metadata` entry). Unpickling can
        # execute arbitrary code, so only load archives from a trusted source.
        with np.load(self.datapath, allow_pickle=True, encoding='bytes') as dataset:
            self.imgs = dataset['imgs']
            self.latents_values = dataset['latents_values']
            self.latents_classes = dataset['latents_classes']
            # `[()]` unwraps the 0-d array and yields the stored object itself.
            self.metadata = dataset['metadata'][()]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(sample, [])`` for the image stored at ``idx``.

        The image is cast to float32 and given a trailing singleton axis
        before ``transform`` (if any) is applied. The second element is always
        an empty list: this dataset yields no labels, but callers unpack a
        two-element item.
        """
        sample = self.imgs[idx].astype(np.float32)
        # Append a channel axis: (...,) -> (..., 1).
        sample = sample[..., np.newaxis]

        if self.transform:
            sample = self.transform(sample)

        return sample, []

In [4]:
# Use the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
# Training hyperparameters.
batch_size = 128       # samples per mini-batch handed to the DataLoader
learning_rate = 1e-3   # step size handed to the Adam optimizer
num_epochs = 10        # number of passes of the training loop over the loader

# Seed every RNG this notebook imports so that a Restart & Run All gives the
# same weight initialisation and shuffling order each time. This cell runs
# before the model and the DataLoader are created.
seed = 0
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

In [6]:
# Build the VAE, move its parameters to the selected device, and attach an
# Adam optimizer over those parameters.
model = VAE()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [7]:
# Wrap the archive in a dataset whose samples are converted with ToTensor,
# then batch and shuffle it for training.
dsprites_data = DSpritesDataset(
    datapath=datapath,
    transform=transforms.ToTensor(),
)
dsprites_loader = data.DataLoader(
    dataset=dsprites_data,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Train the VAE for `num_epochs` passes over the dSprites loader.
for epoch in tqdm(range(1, num_epochs + 1)):
    # The loop variable is deliberately NOT called `data`: that name is the
    # `torch.utils.data` module alias, and rebinding it here would break any
    # later (re-)run of the cells that use `data.Dataset` / `data.DataLoader`.
    for batch in dsprites_loader:
        # Each batch is a (samples, labels) pair; the labels are unused.
        inputs, _ = batch
        inputs = inputs.to(device)

        # Forward pass: the model returns the reconstruction together with
        # z, mean and logvar.
        outputs, z, mean, logvar = model(inputs)

        loss = model.loss(outputs, inputs, mean, logvar)

        # Standard update: clear stale gradients, backpropagate, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

  0%|          | 0/10 [00:00<?, ?it/s]

In [7]:
# Bare expression: the notebook shows the model's repr as the cell output.
model

VAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (4): Sequentia