In [1]:
from tqdm import tqdm
from train import get_ent_grad

import torch
from torch import optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from ui_vae import UIVAE

In [2]:
def _flatten_image(img):
    """Collapse an image tensor into a single 1-D vector."""
    return img.view(-1)


# Preprocessing: downscale to 14x14, convert to a tensor, then flatten.
transform = transforms.Compose(
    [
        transforms.Resize((14, 14)),
        transforms.ToTensor(),
        transforms.Lambda(_flatten_image),
    ]
)

# MNIST training split, downloaded into ./data when not already present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64 samples.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

In [3]:
# Training hyperparameters.
num_epochs = 2        # passes over the training set
num_hmc_samples = 5   # forwarded to get_ent_grad in the training loop
lr = 1e-3             # Adam learning rate

# dim_x is the length of a flattened 14x14 input image.
# NOTE(review): the meaning of T and Ls is defined in ui_vae.UIVAE and is not
# visible from this notebook -- confirm against that module.
model = UIVAE(
    dim_x=14 * 14,
    dim_eps=3,
    dim_z=5,
    latent_dims=[128, 128, 256],
    T=5,
    Ls=5,
)

In [4]:
# Sanity check: flattened image size (14 * 14) plus dim_eps (3).
14 * 14 + 3

199

In [5]:
# Display the encoder sub-module to inspect its layer sizes.
# NOTE(review): presumably the first layer's in_features should equal
# dim_x + dim_eps and the last layer's out_features 2 * dim_z -- confirm
# against ui_vae.UIVAE.
model.encoder

Sequential(
  (0): Linear(in_features=199, out_features=128, bias=True)
  (1): ReLU()
  (2): Linear(in_features=128, out_features=128, bias=True)
  (3): ReLU()
  (4): Linear(in_features=128, out_features=256, bias=True)
  (5): ReLU()
  (6): Linear(in_features=256, out_features=10, bias=True)
)

In [6]:
# Train the model with Adam on the negated entropy-free ELBO, injecting the
# entropy gradient returned by get_ent_grad at the latent sample z_sample.
#
# NOTE(review): the recorded run of this cell stopped with a 196-vs-199 shape
# mismatch raised inside code that is not visible in this notebook (the model
# or get_ent_grad); that has to be resolved there.
model.train()
optimizer = optim.Adam(model.parameters(), lr=lr)

# torch.autograd.grad only accepts inputs that require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]

losses_mod = []  # per-step value of the negated entropy-free ELBO
losses_ent = []  # per-step log_qz reported by get_ent_grad

for epoch in range(num_epochs):
    # Iterating over the tqdm wrapper advances the bar by one per batch, so
    # no manual pbar.update() call is needed (it would double-count).
    with tqdm(train_loader, desc=f"Epoch {epoch+1}") as pbar:
        for x_batch, _ in pbar:
            optimizer.zero_grad()

            mu, z_sample, epsilon, sigma, x_recon = model.forward(x_batch)

            # z_sample is a non-leaf tensor; keep its gradient so the entropy
            # gradient can be added to it after the backward pass.
            z_sample.retain_grad()

            # Backward pass for the entropy-free objective. The graph is kept
            # because z_sample is differentiated again below.
            loss = -model.elbo_no_entropy(x_batch, x_recon, z_sample).mean()
            loss.backward(retain_graph=True)

            grad_z, accept_prob, log_qz = get_ent_grad(
                model, epsilon, z_sample, num_hmc_samples, mu, sigma
            )

            # Total gradient arriving at z_sample: objective + entropy term.
            total_grad_z = z_sample.grad + grad_z.detach()

            # Push the total gradient from z_sample back to the parameters.
            # allow_unused=True is required because parameters that z_sample
            # does not depend on (presumably the decoder -- confirm against
            # UIVAE) would otherwise make autograd.grad raise.
            grads = torch.autograd.grad(
                z_sample,
                trainable_params,
                grad_outputs=total_grad_z,
                retain_graph=False,
                allow_unused=True,
            )

            # Overwrite only the parameters upstream of z_sample; the others
            # keep the gradient already produced by loss.backward().
            for param, grad in zip(trainable_params, grads):
                if grad is not None:
                    param.grad = grad

            optimizer.step()

            loss_value = loss.item()
            log_qz_value = log_qz.item()
            losses_mod.append(loss_value)
            losses_ent.append(log_qz_value)

            pbar.set_postfix(
                loss=f"{loss_value:05.2f}",
                log_qz=f"{log_qz_value:05.2f}",
                grad_z=f"{grad_z.mean().item():05.2f}",
                accept_prob=f"{accept_prob.mean().item():05.2f}",
                status="running",
            )


Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]

Epoch 1:   0%|          | 0/938 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x196 and 199x128)

In [13]:
# Debug check of the shape of the most recent batch.
# NOTE(review): the recorded output (64, 1, 28, 28) does not match the
# resize-to-14x14-and-flatten transform defined earlier, and this cell's
# execution count is out of sequence -- likely stale kernel state. Re-check
# after Restart Kernel -> Run All.
x_batch.shape

torch.Size([64, 1, 28, 28])