In [1]:
import datetime
import os
import random
import time

import pixyz
from pixyz.utils import print_latex
from tensorboardX import SummaryWriter
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import MovingMNIST, MovingMNISTLR
from model import TDVAE

In [2]:
# ---- Experiment configuration: everything a reader might tune ----

# Optimisation schedule
gradient_steps = 20_000   # total number of optimiser updates
batch_size = 32
lr = 5e-4
seed = 1234               # set to None to skip seeding

# Model
z_size = 8                # latent size handed to TDVAE

# Data
dataset_type = 'MovingMNISTLR'   # 'MovingMNIST' or 'MovingMNISTLR'
data_dir = '../data/MNIST/'
rescale = None                   # forwarded to MovingMNIST only
workers = 0                      # DataLoader worker processes

# Hardware
device_ids = [0]          # only the first entry is used to pick the CUDA device

# Logging / checkpoint cadence (in gradient steps)
log_interval = 200
save_interval = 1000
root_log_dir = 'log/'
# Per-run directory name: the wall-clock time at which this cell was executed.
log_dir = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

In [3]:
# Device: first configured GPU when CUDA is available, otherwise CPU.
device = f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu"

# Seed both RNGs used in this notebook (torch and the stdlib `random`).
# `is not None` is the correct identity test for None.
if seed is not None:
    torch.manual_seed(seed)
    random.seed(seed)


# Logging
log_interval_num = log_interval
# Prefix the run directory with the log root exactly once, so re-running this
# cell without re-running the config cell does not produce "log/log/...".
if not os.path.normpath(log_dir).startswith(os.path.normpath(root_log_dir) + os.sep):
    log_dir = os.path.join(root_log_dir, log_dir)
# makedirs creates the missing log root on a fresh checkout, and exist_ok
# keeps the cell re-runnable.
os.makedirs(os.path.join(log_dir, 'models'), exist_ok=True)
os.makedirs(os.path.join(log_dir, 'runs'), exist_ok=True)

dt_now = datetime.datetime.now()
exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')

# TensorBoard run name combines the pixyz version with the start time.
v = pixyz.__version__
writer = SummaryWriter("../runs/" + v + ".td-vae" + exp_time)

# Dataset
if dataset_type == 'MovingMNIST':
    # Single .npy file: hold out 10% of the sequences for testing.
    data_path = os.path.join(data_dir, 'mnist_test_seq.npy')
    full_dataset = MovingMNIST(data_path, rescale=rescale)
    data_num = len(full_dataset)
    train_size = int(0.9 * data_num)
    test_size = data_num - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
elif dataset_type == 'MovingMNISTLR':
    train_dataset = MovingMNISTLR(data_dir, train=True, download=True)
    test_dataset = MovingMNISTLR(data_dir, train=False, download=True)
else:
    raise NotImplementedError()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
train_loader_iterator = iter(train_loader)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=workers)
test_loader_iterator = iter(test_loader)

# One fixed test batch is kept for periodic evaluation and video logging.
# After the transpose the leading axes are swapped; the unpacking below treats
# the result as (seq_len, batch, C, H, W).
test_batch = next(test_loader_iterator).to(device)
test_batch = test_batch.transpose(0, 1)
seq_len, _, C, H, W = test_batch.size()

# Each frame is flattened to C*H*W values for the model.
model = TDVAE(seq_len=seq_len, z_size=z_size, x_size=C*H*W, processed_x_size=C*H*W,
              optimizer=torch.optim.Adam, optimizer_params={"lr": lr}, device=device, clip_grad_value=10)

print(model)

Distributions (for training): 
  p_{b}(z_{t1}|b_{t1}), p_{b}(z_{t2}|b_{t2}), p_{t}(z_{t2}|z_{t1}), p_{d}(x_{t2}|z_{t2}), q(z_{t1}|z_{t2},b_{t1},b_{t2}), p(b|x) 
Loss function: 
  mean \left(\mathbb{E}_{p(b|x)} \left[\sum_{t=1}^{19} \mathbb{E}_{f(x_{t2},b_{t1},b_{t2}|t,x,b)} \left[\mathbb{E}_{p_{b}(z_{t2}|b_{t2})} \left[D_{KL} \left[q(z_{t1}|z_{t2},b_{t1},b_{t2})||p_{b}(z_{t1}|b_{t1}) \right] + \mathbb{E}_{q(z_{t1}|z_{t2},b_{t1},b_{t2})} \left[\log p_{b}(z_{t2}|b_{t2}) - \log p_{d}(x_{t2}|z_{t2}) - \log p_{t}(z_{t2}|z_{t1}) \right] \right] \right] \right] \right) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.0005
      weight_decay: 0
  )


In [4]:
# Display the model through pixyz's print_latex helper; the cell output is an
# IPython Math object (rendered LaTeX) instead of the plain-text summary.
print_latex(model)

<IPython.core.display.Math object>

In [None]:
# Training loop: `gradient_steps` optimiser updates, with an evaluation on the
# fixed held-out `test_batch` every `log_interval_num` steps.

# Flatten each held-out frame once, outside the loop. This gives `model.test`
# the same (seq_len, batch, features) layout the training batches use.
test_x = test_batch.flatten(start_dim=2)

start = time.time()
try:
    for itr in tqdm(range(gradient_steps)):
        # Cycle through the training loader indefinitely: restart on exhaustion.
        try:
            batch = next(train_loader_iterator)
        except StopIteration:
            train_loader_iterator = iter(train_loader)
            batch = next(train_loader_iterator)
        batch = batch.to(device)

        # Local names on purpose: the notebook-level `batch_size` / `seq_len`
        # must not be overwritten by the shape of whichever batch came last.
        n_seqs, n_frames = batch.shape[:2]
        # (batch, seq, ...) -> (batch, seq, features) -> (seq, batch, features)
        batch = batch.view(n_seqs, n_frames, -1).transpose(0, 1)

        loss = model.train({"x": batch})
        writer.add_scalar('train_loss', loss, itr)

        if itr % log_interval_num == 0:
            with torch.no_grad():
                test_pred = model.pred(test_batch)
                # Evaluate on held-out data, not on the current training batch.
                test_loss = model.test({"x": test_x})

                writer.add_scalar('test_loss', test_loss, itr)
                writer.add_video('test_pred', test_pred.transpose(0, 1), itr)
                writer.add_video('test_ground_truth', test_batch.transpose(0, 1), itr)

    elapsed_time = time.time() - start
    writer.add_scalar('Exp time second', elapsed_time)
finally:
    # Flush and close the event file even if training is interrupted or fails.
    writer.close()