In [None]:
#default_exp experiments.cult_experiments

In [None]:
from cult.core.cult import CULT, CULTTrainer
from cult.core.models import FCEncoder, FCDecoder
from cult.core.datasets.moving_mnist import CommonMNIST, CommonFashionMNIST
from cult.core.utils import show_batch
from cult.config import DATA_PATH

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from datetime import datetime as dt

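# CULT Experiments

Each experiment below trains a `CULTTrainer` sequentially on two tasks (FashionMNIST and MNIST, in one order or the other). Each experiment is repeated for five runs.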

In [None]:
# Mini-batch size shared by every train/test DataLoader created below.
batch_size: int = 64

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MNIST train/test splits and their loaders.
mnist_train = CommonMNIST(DATA_PATH, train=True, transform=ToTensor())
mnist_test = CommonMNIST(DATA_PATH, train=False, transform=ToTensor())
mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

# Grab one held-out batch; the training cells pass it to `CULTTrainer.train`.
# Use the built-in `next(iter(...))`: the `.next()` method on DataLoader
# iterators was a Python 2 compatibility alias that recent PyTorch releases
# no longer provide.
# NOTE(review): the test loader shuffles, so this batch differs between kernel
# restarts unless a seed is set.
mnist_test_batch, _ = next(iter(mnist_test_loader))

# FashionMNIST train/test splits and their loaders (same settings as MNIST).
fashion_train = CommonFashionMNIST(DATA_PATH, train=True, transform=ToTensor())
fashion_test = CommonFashionMNIST(DATA_PATH, train=False, transform=ToTensor())
fashion_train_loader = DataLoader(fashion_train, batch_size=batch_size, shuffle=True)
fashion_test_loader = DataLoader(fashion_test, batch_size=batch_size, shuffle=True)

fashion_test_batch, _ = next(iter(fashion_test_loader))


## FashionMNIST -> MNIST

In [None]:

# Hyperparameters for the FashionMNIST -> MNIST runs.
# They are passed positionally to CULTTrainer in the next cell.

# Main optimiser and its learning rate.
optimizer, lr = torch.optim.Adam, 1e-3

# Loss-term weights.
# NOTE(review): the two proximity scales may still need tuning.
kl_scale = 1
e_prox_scale, d_prox_scale = 1, 1

# Encoder/decoder classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings (semantics are defined by CULTTrainer).
max_envs = 7
atyp_min, atyp_max = .25, 1.4
env_optim, env_lr = torch.optim.Adam, 1e-4
# NOTE(review): an earlier comment said this was being increased because of
# poor performance, but the value is 1; confirm the intended setting.
env_epochs = 1

# Replay batch size and save interval (in steps).
replay_batch_size = 64
steps_per_save = 2500

In [None]:
# Five independent repeats. Each fresh trainer is fit on FashionMNIST first,
# then on MNIST. Both held-out batches are handed to every `train` call.
epochs = 10
for run in range(5):
    name = f"fash_to_mnist_{run}"
    cult_trainer = CULTTrainer(name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale, encoder_type, decoder_type, final_size, latents, max_envs, atyp_min, atyp_max, env_optim, env_lr, env_epochs, replay_batch_size, steps_per_save, device)
    for task_loader in (fashion_train_loader, mnist_train_loader):
        # Build a fresh list per call, exactly as before.
        cult_trainer.train(task_loader, epochs, [fashion_test_batch, mnist_test_batch])

epoch: 0, loss=886.5707397460938, rec_loss=328.4096374511719, total_div_loss=11.265183448791504
epoch: 0, loss=548.5853271484375, rec_loss=181.38121032714844, total_div_loss=8.922162055969238
epoch: 0, loss=883.0299072265625, rec_loss=324.126953125, total_div_loss=12.292306900024414
epoch: 0, loss=550.9530029296875, rec_loss=181.06832885742188, total_div_loss=9.10499382019043
epoch: 0, loss=884.5418701171875, rec_loss=325.7725830078125, total_div_loss=11.916460990905762


KeyboardInterrupt: 

## MNIST -> FashionMNIST

In [None]:

# Hyperparameters for the MNIST -> FashionMNIST runs. The values are the
# same as the FashionMNIST -> MNIST settings, restated so this cell is
# self-contained.

# Main optimiser and its learning rate.
optimizer, lr = torch.optim.Adam, 1e-3

# Loss-term weights.
# NOTE(review): the two proximity scales may still need tuning.
kl_scale = 1
e_prox_scale, d_prox_scale = 1, 1

# Encoder/decoder classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings (semantics are defined by CULTTrainer).
max_envs = 7
atyp_min, atyp_max = .25, 1.4
env_optim, env_lr = torch.optim.Adam, 1e-4
# NOTE(review): an earlier comment said this was being increased because of
# poor performance, but the value is 1; confirm the intended setting.
env_epochs = 1

# Replay batch size and save interval (in steps).
replay_batch_size = 64
steps_per_save = 2500

In [None]:
# Five independent repeats. Each fresh trainer is fit on MNIST first, then on
# FashionMNIST. Both held-out batches are handed to every `train` call.
epochs = 10
for run in range(5):
    name = f"mnist_to_fash_{run}"
    cult_trainer = CULTTrainer(name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale, encoder_type, decoder_type, final_size, latents, max_envs, atyp_min, atyp_max, env_optim, env_lr, env_epochs, replay_batch_size, steps_per_save, device)
    for task_loader in (mnist_train_loader, fashion_train_loader):
        # Build a fresh list per call, exactly as before.
        cult_trainer.train(task_loader, epochs, [mnist_test_batch, fashion_test_batch])


## Static Update: FashionMNIST -> MNIST

**TODO:** implement the static-update variant in the `cult` package before running this experiment.

## No Replay: FashionMNIST -> MNIST

In [None]:
# Hyperparameters for the "no replay" FashionMNIST -> MNIST ablation.
# Differences from the replay runs: both proximity scales are 0 and
# atyp_max is very large.

# Main optimiser and its learning rate.
optimizer, lr = torch.optim.Adam, 1e-3

# Loss-term weights; the proximity terms are weighted by zero here.
kl_scale = 1
e_prox_scale, d_prox_scale = 0, 0

# Encoder/decoder classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings (semantics are defined by CULTTrainer).
max_envs = 7
# NOTE(review): 1e7 is presumably chosen so the upper atypicality bound is
# never reached; confirm against CULTTrainer.
atyp_min, atyp_max = .25, 1e7
env_optim, env_lr = torch.optim.Adam, 1e-4
env_epochs = 1

# Replay batch size and save interval (in steps).
replay_batch_size = 64
steps_per_save = 2500

In [None]:
# Five repeats of FashionMNIST -> MNIST using the no-replay hyperparameters
# from the previous cell.
epochs = 10
for i in range(5):
    # The replay experiment already uses "fash_to_mnist_{i}". Reusing that
    # name here would make the two experiments indistinguishable, and would
    # overwrite their saved state if CULTTrainer keys saves by name (confirm).
    # So give these runs a distinct name.
    name = f"fash_to_mnist_no_replay_{i}"
    cult_trainer = CULTTrainer(name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale, encoder_type, decoder_type, final_size, latents, max_envs, atyp_min, atyp_max, env_optim, env_lr, env_epochs, replay_batch_size, steps_per_save, device)
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])


Once the environment ("env") handling is working, we will run the following experiments, five runs each:

- FashionMNIST -> MNIST
- MNIST -> FashionMNIST
- FashionMNIST -> MNIST, static update
- MNIST -> FashionMNIST, static update
- FashionMNIST -> MNIST, no replay
- MNIST -> FashionMNIST, no replay