In [32]:
# Mount Google Drive into the Colab filesystem at /content/drive so that the
# project folder referenced later (under /content/drive/MyDrive) is reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [33]:
import numpy as np
import torch
import os
import sys
import random
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

# Absolute path of the parent of the current working directory.
module_path = os.path.abspath('..')

# Root of the project on the mounted Drive. Its source folders are appended to
# the import search path so the project-local modules imported in later cells
# can be found by name.
abs_path = "/content/drive/MyDrive/atml"
sys.path.extend(abs_path + subdir for subdir in ("/models", "/train", "/datasets"))

In [34]:
# Pick the compute device: the GPU when CUDA is actually usable, else the CPU.
# NOTE: torch.cuda.is_available must be *called*. The bare function object is
# always truthy, which would select 'cuda' even on a CPU-only runtime and make
# the later model.to(device) fail.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Last expression of the cell: displays the selected device.
device

device(type='cuda')

In [35]:
from datasets import train_test_random_split, load_dsprites
from train import train_control_vae, test_control_vae
from loss import loss_control_vae
from control_vae import ControlVAEDSprites

In [36]:
# Load the dSprites archive (.npz) from the project folder on Drive through the
# project helper load_dsprites.
# NOTE(review): the return type is not visible here; the following cell passes
# the result to torch.from_numpy, so it is presumably a NumPy array - confirm.
dataset = load_dsprites(abs_path + "/datasets/dsprites.npz")

In [37]:
# Convert the loaded data to a torch tensor.
# torch.as_tensor is used instead of torch.from_numpy so this cell can be
# re-run safely: for a NumPy array it behaves like from_numpy (shares memory,
# keeps the dtype), and when `dataset` is already a tensor it is returned
# unchanged instead of raising TypeError.
dataset = torch.as_tensor(dataset)

In [38]:
# Subsample size for the optional truncated permutation shown below; it is not
# used while the full-dataset shuffle is active.
n_imgs = 50000
# To keep only n_imgs random images instead, slice the permutation:
#   indices = torch.randperm(dataset.size(0))[:n_imgs]
indices = torch.randperm(len(dataset))
# Reorder the dataset along its first dimension with the random permutation.
dataset = dataset[indices]

In [39]:
# Split the shuffled dataset into a train part and a test part using the
# project helper train_test_random_split.
# NOTE(review): 0.8 is presumably the fraction assigned to the training split -
# confirm against the helper's signature.
data_train, data_test = train_test_random_split(dataset, 0.8)

In [40]:
# Mini-batch size shared by both loaders.
batch_size = 64
# Build the train and test loaders with identical settings (shuffled batches),
# in that order.
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (data_train, data_test)
)

In [41]:
# Settings handed to ControlVAEDSprites for its beta controller.
# NOTE(review): Kp/Ki/Kd look like PID gains and the C* entries like a stepped
# target schedule - confirm against the controller implementation.
beta_controller_args = dict(
    C=0.5,
    C_max=25,
    C_step_val=0.15,
    C_step_period=5000,
    Kp=0.01,
    Ki=-0.001,
    Kd=0.0,
)

In [45]:
# Instantiate the ControlVAE model with the beta-controller settings and move
# it to the selected device.
model = ControlVAEDSprites(beta_controller_args)
# As the last expression of the cell, this also displays the module structure.
model.to(device)

ControlVAEDSprites(
  (encoder): Sequential(
    (0): Linear(in_features=4096, out_features=1200, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1200, out_features=1200, bias=True)
    (3): ReLU()
    (4): Linear(in_features=1200, out_features=20, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=1200, bias=True)
    (1): Tanh()
    (2): Linear(in_features=1200, out_features=1200, bias=True)
    (3): Tanh()
    (4): Linear(in_features=1200, out_features=1200, bias=True)
    (5): Tanh()
    (6): Linear(in_features=1200, out_features=4096, bias=True)
  )
)

In [46]:
# Adagrad optimizer over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-4)

In [47]:
# Train the model through the project helper train_control_vae.
# NOTE(review): 100 is presumably the number of epochs and 'bernoulli' the
# reconstruction-likelihood type - confirm against the helper's signature.
# NOTE(review): the recorded output of this cell stops around epoch 9 with a
# NameError; the cause is not visible in this notebook and should be checked in
# the training helper before trusting later cells.
train_control_vae(model, 100, train_loader, optimizer, 'bernoulli', device=device)



Epoch 0 finished, loss: 160.26071770240864, recon loss: 158.47923537674876, kl div: 0.6756927531578185
Epoch 1 finished, loss: 154.68256140251955, recon loss: 154.33240931563907, kl div: 0.21567766126342272
Epoch 2 finished, loss: 153.74972114049726, recon loss: 153.6173102069232, kl div: 0.1324109125521823
Epoch 3 finished, loss: 150.65182263818053, recon loss: 150.2862218560444, kl div: 0.365600745371517
Epoch 4 finished, loss: 145.16060551835432, recon loss: 144.6369401117166, kl div: 0.5204805967029339
Epoch 5 finished, loss: 138.92171644998922, recon loss: 137.41681353747845, kl div: 0.7155277479420571
Epoch 6 finished, loss: 124.4580458195673, recon loss: 121.34457385622792, kl div: 0.7376374820564201
Epoch 7 finished, loss: 114.08894593351417, recon loss: 110.1759425625205, kl div: 0.6492137457114748
Epoch 8 finished, loss: 105.9625987617506, recon loss: 101.81683355073135, kl div: 0.5857818005121468
Epoch 9 finished, loss: 99.0280182345046, recon loss: 94.37335199945502, kl div

NameError: ignored

In [None]:
# Evaluate the model on the test loader through the project helper
# test_control_vae, on the same device as training.
# NOTE(review): 'bernoulli' is presumably the same likelihood option passed to
# training - confirm against the helper's signature.
test_control_vae(model, test_loader, 'bernoulli', device = device)

In [49]:
# Persist the model to a file in the current working directory.
# NOTE(review): this pickles the whole model object rather than its state_dict,
# so loading it later requires the ControlVAEDSprites class to be importable.
# NOTE(review): the filename says 100 epochs, but the recorded training output
# ends in a NameError - confirm what was actually trained before relying on
# this file.
torch.save(model, 'controlvae_epoch100_lr1e4_default_bc_param.dat')