In [None]:
#default_exp experiments.cult_experiments

In [None]:
from cult.core.cult import CULT, CULTTrainer
from cult.core.models import FCEncoder, FCDecoder
from cult.core.datasets.moving_mnist import CommonMNIST, CommonFashionMNIST
from cult.core.utils import show_batch, rec_likelihood
from cult.config import DATA_PATH

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from datetime import datetime as dt

# Cult Experiments

In [None]:
# Mini-batch size shared by the MNIST and FashionMNIST DataLoaders built below.
batch_size = 64

In [None]:
# Use the first CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MNIST train/test splits and their loaders.
mnist_train = CommonMNIST(DATA_PATH, train=True, transform=ToTensor())
mnist_test = CommonMNIST(DATA_PATH, train=False, transform=ToTensor())
mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

# One MNIST test batch (labels discarded), handed to CULTTrainer.train in the
# experiment cells. The built-in next() is used because DataLoader iterators on
# current PyTorch no longer expose the Python-2-style `.next()` alias.
mnist_test_batch, _ = next(iter(mnist_test_loader))

# FashionMNIST train/test splits and their loaders.
fashion_train = CommonFashionMNIST(DATA_PATH, train=True, transform=ToTensor())
fashion_test = CommonFashionMNIST(DATA_PATH, train=False, transform=ToTensor())
fashion_train_loader = DataLoader(fashion_train, batch_size=batch_size, shuffle=True)
fashion_test_loader = DataLoader(fashion_test, batch_size=batch_size, shuffle=True)

# One FashionMNIST test batch (labels discarded), used the same way as above.
fashion_test_batch, _ = next(iter(fashion_test_loader))


## FashionMNIST -> MNIST

In [None]:
# Hyperparameters for the FashionMNIST -> MNIST run. Every value below is passed
# positionally to CULTTrainer in the next cell.

# Main optimiser and loss weighting.
optimizer = torch.optim.Adam
lr = 1e-3
kl_scale = 1

# Encoder / decoder proximity weights (both set to 0 in the no-replay run).
e_prox_scale = 1  # might need to scale?
d_prox_scale = 1  # might need to scale

# Model architecture.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim = torch.optim.Adam
env_lr = 1e-4
# NOTE(review): an earlier comment said this was being increased due to poor
# performance, but the value is 1 -- confirm the intended setting.
env_epochs = 1

# Replay batch size and save interval.
replay_batch_size = 64
steps_per_save = 2500

In [None]:
# FashionMNIST -> MNIST: train on FashionMNIST, then MNIST, then log final metrics.
epochs = 10
for i in range(1):
    name = f"fash_to_mnist_{i}"
    cult_trainer = CULTTrainer(
        name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale,
        encoder_type, decoder_type, final_size, latents, max_envs,
        atyp_min, atyp_max, env_optim, env_lr, env_epochs,
        replay_batch_size, steps_per_save, device,
    )

    # Sequential training; both fixed test batches are passed to every call.
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])

    cult_trainer.train_latent_classifiers(
        [fashion_train_loader, mnist_train_loader],
        [fashion_test_loader, mnist_test_loader],
        ["fashion_mnist", "mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
        verbose=False,
    )

    # Put both the model and its environment network in eval mode before scoring.
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    # Tag fixed: it was previously misspelled "eval/mnist_recc_loss".
    cult_trainer.writer.add_scalar("eval/mnist_rec_loss", mnist_rec_loss)

    # Second argument is presumably the environment index in training order
    # (FashionMNIST first -> 0, MNIST second -> 1) -- TODO confirm.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 0)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 1)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)

epoch: 0, loss=882.2250366210938, rec_loss=323.95263671875, total_div_loss=11.78816032409668
epoch: 1, loss=841.1382446289062, rec_loss=280.364501953125, total_div_loss=15.096327781677246
epoch: 2, loss=833.1748657226562, rec_loss=271.33251953125, total_div_loss=16.35938262939453
epoch: 3, loss=830.1004028320312, rec_loss=267.8424072265625, total_div_loss=16.83312225341797
epoch: 4, loss=828.0590209960938, rec_loss=265.6373596191406, total_div_loss=17.026012420654297
epoch: 5, loss=826.6766357421875, rec_loss=264.2637634277344, total_div_loss=17.06443214416504
epoch: 6, loss=825.9876098632812, rec_loss=263.4705505371094, total_div_loss=17.175325393676758
epoch: 7, loss=825.3734741210938, rec_loss=262.8834228515625, total_div_loss=17.170766830444336
epoch: 8, loss=825.010498046875, rec_loss=262.45904541015625, total_div_loss=17.24131202697754
epoch: 9, loss=824.64599609375, rec_loss=262.03948974609375, total_div_loss=17.305383682250977
epoch: 0, loss=473.9437561035156, rec_loss=178.5197

100%|██████████| 7/7 [02:37<00:00, 22.48s/it]


## MNIST -> FashionMNIST

In [None]:
# Hyperparameters for the MNIST -> FashionMNIST run. Every value below is passed
# positionally to CULTTrainer in the next cell.

# Main optimiser and loss weighting.
optimizer = torch.optim.Adam
lr = 1e-3
kl_scale = 1

# Encoder / decoder proximity weights (both set to 0 in the no-replay run).
e_prox_scale = 1  # might need to scale?
d_prox_scale = 1  # might need to scale

# Model architecture.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim = torch.optim.Adam
env_lr = 1e-4
# NOTE(review): an earlier comment said this was being increased due to poor
# performance, but the value is 1 -- confirm the intended setting.
env_epochs = 1

# Replay batch size and save interval.
replay_batch_size = 64
steps_per_save = 2500

In [None]:
# MNIST -> FashionMNIST: train on MNIST, then FashionMNIST, then log final metrics.
epochs = 10
for i in range(1):
    name = f"mnist_to_fash_{i}"
    cult_trainer = CULTTrainer(
        name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale,
        encoder_type, decoder_type, final_size, latents, max_envs,
        atyp_min, atyp_max, env_optim, env_lr, env_epochs,
        replay_batch_size, steps_per_save, device,
    )

    # Sequential training; both fixed test batches are passed to every call.
    cult_trainer.train(mnist_train_loader, epochs, [mnist_test_batch, fashion_test_batch])
    cult_trainer.train(fashion_train_loader, epochs, [mnist_test_batch, fashion_test_batch])

    cult_trainer.train_latent_classifiers(
        [mnist_train_loader, fashion_train_loader],
        [mnist_test_loader, fashion_test_loader],
        ["mnist", "fashion_mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
    )

    # Put both the model and its environment network in eval mode before scoring,
    # matching the FashionMNIST -> MNIST cell.
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    # Tag fixed: it was previously misspelled "eval/mnist_recc_loss".
    cult_trainer.writer.add_scalar("eval/mnist_rec_loss", mnist_rec_loss)

    # Second argument is presumably the environment index in training order
    # (MNIST first -> 0, FashionMNIST second -> 1) -- TODO confirm.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 1)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 0)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)


epoch: 0, loss=762.3403930664062, rec_loss=208.43739318847656, total_div_loss=7.153684139251709
epoch: 1, loss=726.2791137695312, rec_loss=169.60476684570312, total_div_loss=9.923151969909668
epoch: 2, loss=720.3603515625, rec_loss=161.651123046875, total_div_loss=11.953137397766113
epoch: 3, loss=716.0773315429688, rec_loss=157.0636444091797, total_div_loss=13.183549880981445
epoch: 4, loss=714.6477661132812, rec_loss=154.5414581298828, total_div_loss=14.095521926879883
epoch: 5, loss=713.4898071289062, rec_loss=152.78001403808594, total_div_loss=14.73049259185791
epoch: 6, loss=712.3055419921875, rec_loss=151.52017211914062, total_div_loss=15.201375007629395
epoch: 7, loss=712.08740234375, rec_loss=150.65318298339844, total_div_loss=15.589981079101562
epoch: 8, loss=711.1619873046875, rec_loss=149.8383331298828, total_div_loss=15.908149719238281
epoch: 9, loss=710.6917724609375, rec_loss=149.17279052734375, total_div_loss=16.139873504638672
epoch: 0, loss=494.6817321777344, rec_loss=

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.8706)
epoch 1 acc tensor(0.8895)
epoch 2 acc tensor(0.8975)
epoch 0 acc tensor(0.6968)
epoch 1 acc tensor(0.7245)
epoch 2 acc tensor(0.7278)


 14%|█▍        | 1/7 [00:26<02:41, 26.94s/it]

epoch 0 acc tensor(0.8771)
epoch 1 acc tensor(0.8940)
epoch 2 acc tensor(0.8921)
epoch 0 acc tensor(0.7060)
epoch 1 acc tensor(0.7268)
epoch 2 acc tensor(0.7430)


 29%|██▊       | 2/7 [00:54<02:15, 27.19s/it]

epoch 0 acc tensor(0.8844)
epoch 1 acc tensor(0.8959)
epoch 2 acc tensor(0.8954)
epoch 0 acc tensor(0.7060)
epoch 1 acc tensor(0.7257)
epoch 2 acc tensor(0.7198)


 43%|████▎     | 3/7 [01:24<01:54, 28.73s/it]

epoch 0 acc tensor(0.8438)
epoch 1 acc tensor(0.8588)
epoch 2 acc tensor(0.8683)
epoch 0 acc tensor(0.7466)
epoch 1 acc tensor(0.7586)
epoch 2 acc tensor(0.7581)


 57%|█████▋    | 4/7 [01:53<01:26, 28.70s/it]

epoch 0 acc tensor(0.8624)
epoch 1 acc tensor(0.8762)
epoch 2 acc tensor(0.8817)
epoch 0 acc tensor(0.7528)
epoch 1 acc tensor(0.7672)
epoch 2 acc tensor(0.7641)


 71%|███████▏  | 5/7 [02:23<00:58, 29.05s/it]

epoch 0 acc tensor(0.8604)
epoch 1 acc tensor(0.8748)
epoch 2 acc tensor(0.8748)
epoch 0 acc tensor(0.7665)
epoch 1 acc tensor(0.7631)
epoch 2 acc tensor(0.7744)


 86%|████████▌ | 6/7 [02:51<00:28, 28.71s/it]

epoch 0 acc tensor(0.8591)
epoch 1 acc tensor(0.8624)
epoch 2 acc tensor(0.8672)
epoch 0 acc tensor(0.7657)
epoch 1 acc tensor(0.7718)
epoch 2 acc tensor(0.7692)


100%|██████████| 7/7 [03:20<00:00, 28.62s/it]


## Static Update: FashionMNIST -> MNIST

In [None]:
# Hyperparameters for the static-update FashionMNIST -> MNIST run. Every value
# below is passed positionally to CULTTrainer in the next cell.

# Main optimiser and loss weighting.
optimizer = torch.optim.Adam
lr = 1e-3
kl_scale = 1

# Encoder / decoder proximity weights (both set to 0 in the no-replay run).
e_prox_scale = 1  # might need to scale?
d_prox_scale = 1  # might need to scale

# Model architecture.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim = torch.optim.Adam
env_lr = 1e-4
# NOTE(review): an earlier comment said this was being increased due to poor
# performance, but the value is 1 -- confirm the intended setting.
env_epochs = 1

# Replay batch size, save interval and replay-model reset interval.
replay_batch_size = 64
steps_per_save = 2500
# If None, the replay model is only updated at a new environment.
steps_per_reset = 500

In [None]:
# Static update, FashionMNIST -> MNIST: same protocol as the baseline run, but the
# trainer additionally receives steps_per_reset.
epochs = 10
for i in range(1):
    name = f"fixed_fash_to_mnist_{i}"
    cult_trainer = CULTTrainer(
        name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale,
        encoder_type, decoder_type, final_size, latents, max_envs,
        atyp_min, atyp_max, env_optim, env_lr, env_epochs,
        replay_batch_size, steps_per_save, device, steps_per_reset,
    )

    # Sequential training; both fixed test batches are passed to every call.
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])

    cult_trainer.train_latent_classifiers(
        [fashion_train_loader, mnist_train_loader],
        [fashion_test_loader, mnist_test_loader],
        ["fashion_mnist", "mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
    )

    # Put both the model and its environment network in eval mode before scoring,
    # matching the baseline FashionMNIST -> MNIST cell.
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    # Tag fixed: it was previously misspelled "eval/mnist_recc_loss".
    cult_trainer.writer.add_scalar("eval/mnist_rec_loss", mnist_rec_loss)

    # Second argument is presumably the environment index in training order
    # (FashionMNIST first -> 0, MNIST second -> 1) -- TODO confirm.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 0)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 1)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)


epoch: 0, loss=795.631591796875, rec_loss=322.287353515625, total_div_loss=13.273159980773926
epoch: 1, loss=614.6732788085938, rec_loss=279.1069030761719, total_div_loss=16.17925453186035
epoch: 2, loss=589.9298095703125, rec_loss=270.7331237792969, total_div_loss=17.855920791625977
epoch: 3, loss=580.2642211914062, rec_loss=267.1777648925781, total_div_loss=18.885953903198242
epoch: 4, loss=575.18359375, rec_loss=264.7123107910156, total_div_loss=19.74139976501465
epoch: 5, loss=571.9389038085938, rec_loss=263.0014343261719, total_div_loss=20.506607055664062
epoch: 6, loss=567.923095703125, rec_loss=261.94927978515625, total_div_loss=20.95893096923828
epoch: 7, loss=566.7766723632812, rec_loss=261.3365478515625, total_div_loss=21.13743782043457
epoch: 8, loss=564.888671875, rec_loss=260.90228271484375, total_div_loss=21.165355682373047
epoch: 9, loss=563.883056640625, rec_loss=260.6723937988281, total_div_loss=21.216205596923828
epoch: 0, loss=443.0996398925781, rec_loss=181.37829589

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7449)
epoch 1 acc tensor(0.7678)
epoch 2 acc tensor(0.7701)
epoch 0 acc tensor(0.7247)
epoch 1 acc tensor(0.7555)
epoch 2 acc tensor(0.7667)


 14%|█▍        | 1/7 [00:26<02:40, 26.72s/it]

epoch 0 acc tensor(0.7427)
epoch 1 acc tensor(0.7534)
epoch 2 acc tensor(0.7633)
epoch 0 acc tensor(0.7429)
epoch 1 acc tensor(0.7700)
epoch 2 acc tensor(0.7907)


 29%|██▊       | 2/7 [00:53<02:13, 26.80s/it]

epoch 0 acc tensor(0.7498)
epoch 1 acc tensor(0.7550)
epoch 2 acc tensor(0.7717)
epoch 0 acc tensor(0.6662)
epoch 1 acc tensor(0.7176)
epoch 2 acc tensor(0.7382)


 43%|████▎     | 3/7 [01:20<01:46, 26.63s/it]

epoch 0 acc tensor(0.7279)
epoch 1 acc tensor(0.7391)
epoch 2 acc tensor(0.7277)
epoch 0 acc tensor(0.7157)
epoch 1 acc tensor(0.7314)
epoch 2 acc tensor(0.7422)


 57%|█████▋    | 4/7 [01:47<01:20, 26.81s/it]

epoch 0 acc tensor(0.7379)
epoch 1 acc tensor(0.7470)
epoch 2 acc tensor(0.7566)
epoch 0 acc tensor(0.8097)
epoch 1 acc tensor(0.8358)
epoch 2 acc tensor(0.8469)


 71%|███████▏  | 5/7 [02:14<00:54, 27.04s/it]

epoch 0 acc tensor(0.7139)
epoch 1 acc tensor(0.7343)
epoch 2 acc tensor(0.7412)
epoch 0 acc tensor(0.8287)
epoch 1 acc tensor(0.8377)
epoch 2 acc tensor(0.8542)


 86%|████████▌ | 6/7 [02:41<00:27, 27.12s/it]

epoch 0 acc tensor(0.7008)
epoch 1 acc tensor(0.7193)
epoch 2 acc tensor(0.7355)
epoch 0 acc tensor(0.8493)
epoch 1 acc tensor(0.8697)
epoch 2 acc tensor(0.8767)


100%|██████████| 7/7 [03:12<00:00, 27.47s/it]


## No Replay: FashionMNIST -> MNIST

In [None]:
# Hyperparameters for the no-replay FashionMNIST -> MNIST run. Every value below
# is passed positionally to CULTTrainer in the next cell.

# Main optimiser and loss weighting.
optimizer = torch.optim.Adam
lr = 1e-3
kl_scale = 1

# Both proximity weights are zeroed for this "No Replay" configuration.
e_prox_scale = 0
d_prox_scale = 0

# Model architecture.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size = 50
latents = 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim = torch.optim.Adam
env_lr = 1e-4
env_epochs = 1

# Replay batch size and save interval.
replay_batch_size = 64
steps_per_save = 2500

In [None]:
# No replay, FashionMNIST -> MNIST: same protocol as the baseline run, using the
# configuration above in which both proximity weights are 0.
epochs = 10
for i in range(1):
    name = f"no_replay_fash_to_mnist_{i}"
    cult_trainer = CULTTrainer(
        name, optimizer, lr, kl_scale, e_prox_scale, d_prox_scale,
        encoder_type, decoder_type, final_size, latents, max_envs,
        atyp_min, atyp_max, env_optim, env_lr, env_epochs,
        replay_batch_size, steps_per_save, device,
    )

    # Sequential training; both fixed test batches are passed to every call.
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])

    cult_trainer.train_latent_classifiers(
        [fashion_train_loader, mnist_train_loader],
        [fashion_test_loader, mnist_test_loader],
        ["fashion_mnist", "mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
    )

    # Put both the model and its environment network in eval mode before scoring,
    # matching the baseline FashionMNIST -> MNIST cell.
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    # Tag fixed: it was previously misspelled "eval/mnist_recc_loss".
    cult_trainer.writer.add_scalar("eval/mnist_rec_loss", mnist_rec_loss)

    # Second argument is presumably the environment index in training order
    # (FashionMNIST first -> 0, MNIST second -> 1) -- TODO confirm.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 0)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 1)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)

epoch: 0, loss=334.0847473144531, rec_loss=323.16497802734375, total_div_loss=10.919363021850586
epoch: 1, loss=293.1953430175781, rec_loss=279.8336181640625, total_div_loss=13.361682891845703
epoch: 2, loss=285.898681640625, rec_loss=271.314453125, total_div_loss=14.58430290222168
epoch: 3, loss=282.6383361816406, rec_loss=266.7991638183594, total_div_loss=15.83903980255127
epoch: 4, loss=281.0431213378906, rec_loss=264.6848449707031, total_div_loss=16.358247756958008
epoch: 5, loss=280.0467224121094, rec_loss=263.36431884765625, total_div_loss=16.682052612304688
epoch: 6, loss=279.457275390625, rec_loss=262.6441345214844, total_div_loss=16.813255310058594
epoch: 7, loss=279.2010498046875, rec_loss=262.20538330078125, total_div_loss=16.99568748474121
epoch: 8, loss=278.66064453125, rec_loss=261.7117614746094, total_div_loss=16.948820114135742
epoch: 9, loss=278.4388732910156, rec_loss=261.4319152832031, total_div_loss=17.00688934326172
epoch: 0, loss=186.96524047851562, rec_loss=177.8

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7534)
epoch 1 acc tensor(0.7640)
epoch 2 acc tensor(0.7623)
epoch 0 acc tensor(0.6739)
epoch 1 acc tensor(0.7167)
epoch 2 acc tensor(0.7254)


 14%|█▍        | 1/7 [00:26<02:39, 26.56s/it]

epoch 0 acc tensor(0.7503)
epoch 1 acc tensor(0.7485)
epoch 2 acc tensor(0.7501)
epoch 0 acc tensor(0.6114)
epoch 1 acc tensor(0.6622)
epoch 2 acc tensor(0.6785)


 29%|██▊       | 2/7 [00:53<02:13, 26.61s/it]

epoch 0 acc tensor(0.7479)
epoch 1 acc tensor(0.7631)
epoch 2 acc tensor(0.7647)
epoch 0 acc tensor(0.5461)
epoch 1 acc tensor(0.6256)
epoch 2 acc tensor(0.6582)


 43%|████▎     | 3/7 [01:23<01:53, 28.37s/it]

epoch 0 acc tensor(0.7416)
epoch 1 acc tensor(0.7559)
epoch 2 acc tensor(0.7550)
epoch 0 acc tensor(0.7373)
epoch 1 acc tensor(0.7507)
epoch 2 acc tensor(0.7710)


 57%|█████▋    | 4/7 [01:54<01:28, 29.46s/it]

epoch 0 acc tensor(0.6851)
epoch 1 acc tensor(0.7035)
epoch 2 acc tensor(0.7096)
epoch 0 acc tensor(0.7858)
epoch 1 acc tensor(0.8129)
epoch 2 acc tensor(0.8188)


 71%|███████▏  | 5/7 [02:26<01:00, 30.26s/it]

epoch 0 acc tensor(0.6960)
epoch 1 acc tensor(0.7041)
epoch 2 acc tensor(0.7122)
epoch 0 acc tensor(0.8285)
epoch 1 acc tensor(0.8380)
epoch 2 acc tensor(0.8507)


 86%|████████▌ | 6/7 [02:56<00:30, 30.32s/it]

epoch 0 acc tensor(0.6782)
epoch 1 acc tensor(0.6993)
epoch 2 acc tensor(0.7129)
epoch 0 acc tensor(0.8243)
epoch 1 acc tensor(0.8412)
epoch 2 acc tensor(0.8485)


100%|██████████| 7/7 [03:24<00:00, 29.17s/it]



Then we should be good to go. Once we get the env stuff figured out, we'll run the following experiments:
FashionMNIST -> MNIST (5x)
MNIST -> FashionMNIST (5x)

FashionMNIST -> MNIST static update (5x)
MNIST -> FashionMNIST static update (5x)

FashionMNIST -> MNIST no replay (5x)
MNIST -> FashionMNIST no replay (5x)