In [None]:
#default_exp experiments.cult_experiments

In [None]:
from cult.core.cult import CULT, CULTTrainer
from cult.core.models import FCEncoder, FCDecoder
from cult.core.datasets.moving_mnist import CommonMNIST, CommonFashionMNIST
from cult.core.utils import show_batch, rec_likelihood
from cult.config import DATA_PATH

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from datetime import datetime as dt

# CULT Experiments

In [None]:
# Mini-batch size used by every train/test DataLoader built in the next cell.
batch_size = 64

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# MNIST: train/test datasets and their shuffled loaders.
mnist_train = CommonMNIST(DATA_PATH, train=True, transform=ToTensor())
mnist_test = CommonMNIST(DATA_PATH, train=False, transform=ToTensor())
mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
mnist_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

# Grab a single test batch (the second element of the pair is discarded); it is
# passed to `CULTTrainer.train` in the experiment cells below.
# The built-in next() is used instead of `iterator.next()`: Python 3 iterators
# only guarantee `__next__`, and newer PyTorch releases no longer expose the
# Python-2 style `.next()` alias on DataLoader iterators.
# NOTE: the loader shuffles and no seed is set, so this batch differs between kernel runs.
mnist_test_batch, _ = next(iter(mnist_test_loader))

# FashionMNIST: same layout as the MNIST block above.
fashion_train = CommonFashionMNIST(DATA_PATH, train=True, transform=ToTensor())
fashion_test = CommonFashionMNIST(DATA_PATH, train=False, transform=ToTensor())
fashion_train_loader = DataLoader(fashion_train, batch_size=batch_size, shuffle=True)
fashion_test_loader = DataLoader(fashion_test, batch_size=batch_size, shuffle=True)

fashion_test_batch, _ = next(iter(fashion_test_loader))


## FashionMNIST -> MNIST

In [None]:
# Hyper-parameters for the FashionMNIST -> MNIST runs. They are handed
# positionally to CULTTrainer in the cells below; the grouping here follows the
# variable names only -- see cult.core.cult for the authoritative meaning.

# Main optimiser class and its learning rate.
optimizer, lr = torch.optim.Adam, 0.001

# Loss-term scales.
kl_scale = 1
e_prox_scale = d_prox_scale = 1  # TODO: might need to scale

# Network classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size, latents = 50, 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim, env_lr = torch.optim.Adam, 0.0001
# NOTE(review): an earlier comment here read "increasing due to poor
# performance", yet the value is 1 -- confirm the intended setting.
env_epochs = 1

replay_batch_size = 64
steps_per_save = 2500

In [None]:
# Scratch trainer built from the hyper-parameters above; the next cell inspects
# its model's step counter.
foo = CULTTrainer(
    'foo',
    optimizer, lr,
    kl_scale, e_prox_scale, d_prox_scale,
    encoder_type, decoder_type,
    final_size, latents,
    max_envs, atyp_min, atyp_max,
    env_optim, env_lr, env_epochs,
    replay_batch_size, steps_per_save,
    device,
)

In [None]:
# Step counter of the model owned by `foo`; no training has been run on `foo` at this point.
foo.model.steps

0

In [None]:
# Number of batches the MNIST training loader yields over 10 passes
# (10 is also the `epochs` value used by the experiment loops below).
10 * len(mnist_train_loader)

9380

In [None]:
# FashionMNIST -> MNIST: five independent runs. Each run trains on FashionMNIST
# first and MNIST second, fits latent classifiers, then logs evaluation scalars
# to the run's writer.
epochs = 10
trainer_args = (
    optimizer, lr,
    kl_scale, e_prox_scale, d_prox_scale,
    encoder_type, decoder_type,
    final_size, latents,
    max_envs, atyp_min, atyp_max,
    env_optim, env_lr, env_epochs,
    replay_batch_size, steps_per_save,
    device,
)

for exp_num in range(5):
    name = f"fash_to_mnist_{exp_num}"
    cult_trainer = CULTTrainer(name, *trainer_args)

    # Sequential training; both calls receive the same pair of held-out batches.
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])

    # [10, 10] is presumably the number of classes per dataset -- confirm
    # against CULTTrainer.train_latent_classifiers.
    cult_trainer.train_latent_classifiers(
        [fashion_train_loader, mnist_train_loader],
        [fashion_test_loader, mnist_test_loader],
        ["fashion_mnist", "mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
        verbose=False,
    )

    # Put the model and its env_net into eval mode before measuring anything.
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    # NOTE(review): the MNIST tag is spelled "recc"; it is kept unchanged so
    # new runs stay comparable with logs that were already written.
    for tag, value in (
        ("eval/fashion_rec_loss", fashion_rec_loss),
        ("eval/mnist_recc_loss", mnist_rec_loss),
    ):
        cult_trainer.writer.add_scalar(tag, value)

    # The second argument presumably is the expected environment index,
    # following training order (FashionMNIST -> 0, MNIST -> 1) -- confirm
    # against CULTTrainer.env_accuracy.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 0)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 1)
    for tag, value in (
        ("eval/fashion_env_acc", fashion_env_acc),
        ("eval/mnist_env_acc", mnist_env_acc),
    ):
        cult_trainer.writer.add_scalar(tag, value)

epoch: 0, loss=882.3352661132812, rec_loss=323.55364990234375, total_div_loss=11.868779182434082
epoch: 1, loss=842.9118041992188, rec_loss=280.90960693359375, total_div_loss=15.95385456085205
epoch: 2, loss=833.99072265625, rec_loss=271.1559753417969, total_div_loss=16.97806167602539
epoch: 3, loss=831.0304565429688, rec_loss=267.6246643066406, total_div_loss=17.66572380065918
epoch: 4, loss=829.6016235351562, rec_loss=265.91729736328125, total_div_loss=18.010046005249023
epoch: 5, loss=828.77587890625, rec_loss=264.83648681640625, total_div_loss=18.28383445739746
epoch: 6, loss=827.836669921875, rec_loss=264.0345764160156, total_div_loss=18.204912185668945
epoch: 7, loss=827.3883056640625, rec_loss=263.50250244140625, total_div_loss=18.299375534057617
epoch: 8, loss=827.0670776367188, rec_loss=263.1199951171875, total_div_loss=18.399784088134766
epoch: 9, loss=826.61328125, rec_loss=262.71893310546875, total_div_loss=18.371368408203125
epoch: 0, loss=480.257568359375, rec_loss=180.79

100%|██████████| 7/7 [02:35<00:00, 22.25s/it]


epoch: 0, loss=881.1354370117188, rec_loss=323.1217956542969, total_div_loss=11.409282684326172
epoch: 1, loss=840.3844604492188, rec_loss=280.5863342285156, total_div_loss=14.154522895812988
epoch: 2, loss=832.8284912109375, rec_loss=272.1751403808594, total_div_loss=15.182563781738281
epoch: 3, loss=830.2005004882812, rec_loss=269.0063171386719, total_div_loss=15.82284164428711
epoch: 4, loss=828.2120971679688, rec_loss=266.6905517578125, total_div_loss=16.234737396240234
epoch: 5, loss=827.1129150390625, rec_loss=265.1929931640625, total_div_loss=16.66802215576172
epoch: 6, loss=826.1986083984375, rec_loss=264.15045166015625, total_div_loss=16.832855224609375
epoch: 7, loss=825.40869140625, rec_loss=263.3125305175781, total_div_loss=16.91683006286621
epoch: 8, loss=825.094482421875, rec_loss=262.8172912597656, total_div_loss=17.133073806762695
epoch: 9, loss=824.730224609375, rec_loss=262.4314880371094, total_div_loss=17.17717933654785
epoch: 0, loss=470.5045166015625, rec_loss=180.

100%|██████████| 7/7 [02:39<00:00, 22.72s/it]


epoch: 0, loss=881.8001708984375, rec_loss=322.51666259765625, total_div_loss=12.228106498718262
epoch: 1, loss=841.5650024414062, rec_loss=280.15765380859375, total_div_loss=14.926261901855469
epoch: 2, loss=833.6214599609375, rec_loss=271.47430419921875, total_div_loss=15.76183032989502
epoch: 3, loss=830.1779174804688, rec_loss=267.527099609375, total_div_loss=16.33484649658203
epoch: 4, loss=828.6974487304688, rec_loss=265.6979064941406, total_div_loss=16.721275329589844
epoch: 5, loss=827.6818237304688, rec_loss=264.5208435058594, total_div_loss=16.94794273376465
epoch: 6, loss=827.0484008789062, rec_loss=263.6932067871094, total_div_loss=17.18001937866211
epoch: 7, loss=826.3844604492188, rec_loss=262.96099853515625, total_div_loss=17.270259857177734
epoch: 8, loss=826.1036987304688, rec_loss=262.4926452636719, total_div_loss=17.47995948791504
epoch: 9, loss=825.6807250976562, rec_loss=262.0513916015625, total_div_loss=17.549592971801758
epoch: 0, loss=471.3250732421875, rec_loss

100%|██████████| 7/7 [02:47<00:00, 23.96s/it]


epoch: 0, loss=881.3013305664062, rec_loss=322.6269226074219, total_div_loss=12.005391120910645
epoch: 1, loss=840.9617919921875, rec_loss=279.283203125, total_div_loss=15.9032621383667
epoch: 2, loss=833.8262939453125, rec_loss=271.3272705078125, total_div_loss=16.869237899780273
epoch: 3, loss=830.8634643554688, rec_loss=267.99371337890625, total_div_loss=17.312910079956055
epoch: 4, loss=828.9489135742188, rec_loss=265.96630859375, total_div_loss=17.44205665588379
epoch: 5, loss=827.3142700195312, rec_loss=264.4073791503906, total_div_loss=17.378774642944336
epoch: 6, loss=826.317626953125, rec_loss=263.4772644042969, total_div_loss=17.303197860717773
epoch: 7, loss=825.4873657226562, rec_loss=262.6854553222656, total_div_loss=17.28038787841797
epoch: 8, loss=825.0408935546875, rec_loss=262.201171875, total_div_loss=17.330080032348633
epoch: 9, loss=824.6445922851562, rec_loss=261.77069091796875, total_div_loss=17.38921356201172
epoch: 0, loss=472.3203430175781, rec_loss=180.3636322

100%|██████████| 7/7 [02:40<00:00, 22.92s/it]


## MNIST -> FashionMNIST

In [None]:
# Hyper-parameters for the MNIST -> FashionMNIST runs. They are handed
# positionally to CULTTrainer in the next cell; the grouping here follows the
# variable names only -- see cult.core.cult for the authoritative meaning.

# Main optimiser class and its learning rate.
optimizer, lr = torch.optim.Adam, 0.001

# Loss-term scales.
kl_scale = 1
e_prox_scale = d_prox_scale = 1  # TODO: might need to scale

# Network classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size, latents = 50, 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim, env_lr = torch.optim.Adam, 0.0001
# NOTE(review): an earlier comment here read "increasing due to poor
# performance", yet the value is 1 -- confirm the intended setting.
env_epochs = 1

replay_batch_size = 64
steps_per_save = 2500

In [None]:
# MNIST -> FashionMNIST: five independent runs. Each run trains on MNIST first
# and FashionMNIST second, fits latent classifiers, then logs evaluation scalars
# to the run's writer.
epochs = 10
trainer_args = (
    optimizer, lr,
    kl_scale, e_prox_scale, d_prox_scale,
    encoder_type, decoder_type,
    final_size, latents,
    max_envs, atyp_min, atyp_max,
    env_optim, env_lr, env_epochs,
    replay_batch_size, steps_per_save,
    device,
)

for exp_num in range(5):
    name = f"mnist_to_fash_{exp_num}"
    cult_trainer = CULTTrainer(name, *trainer_args)

    # Sequential training; both calls receive the same pair of held-out batches.
    cult_trainer.train(mnist_train_loader, epochs, [mnist_test_batch, fashion_test_batch])
    cult_trainer.train(fashion_train_loader, epochs, [mnist_test_batch, fashion_test_batch])

    # [10, 10] is presumably the number of classes per dataset -- confirm
    # against CULTTrainer.train_latent_classifiers.
    cult_trainer.train_latent_classifiers(
        [mnist_train_loader, fashion_train_loader],
        [mnist_test_loader, fashion_test_loader],
        ["mnist", "fashion_mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
    )

    # Put the model AND its env_net into eval mode before measuring anything.
    # The FashionMNIST -> MNIST loop earlier in this notebook switches env_net
    # explicitly; do the same here so every experiment is evaluated under the
    # same conditions (the call is a no-op if model.eval() already covered it).
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    # NOTE(review): the MNIST tag is spelled "recc"; it is kept unchanged so
    # new runs stay comparable with logs that were already written.
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    cult_trainer.writer.add_scalar("eval/mnist_recc_loss", mnist_rec_loss)

    # The second argument presumably is the expected environment index,
    # following training order (MNIST -> 0, FashionMNIST -> 1) -- confirm
    # against CULTTrainer.env_accuracy.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 1)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 0)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)


epoch: 0, loss=764.3453979492188, rec_loss=210.51756286621094, total_div_loss=6.841984272003174
epoch: 1, loss=728.7280883789062, rec_loss=172.51577758789062, total_div_loss=9.332086563110352
epoch: 2, loss=721.8748168945312, rec_loss=163.42025756835938, total_div_loss=11.954620361328125
epoch: 3, loss=717.7255249023438, rec_loss=158.1334686279297, total_div_loss=13.453836441040039
epoch: 4, loss=715.6884765625, rec_loss=154.83721923828125, total_div_loss=14.420963287353516
epoch: 5, loss=713.866943359375, rec_loss=152.75584411621094, total_div_loss=15.089506149291992
epoch: 6, loss=713.3990478515625, rec_loss=151.5978240966797, total_div_loss=15.527010917663574
epoch: 7, loss=712.3328247070312, rec_loss=150.55862426757812, total_div_loss=15.88389778137207
epoch: 8, loss=711.7504272460938, rec_loss=149.72372436523438, total_div_loss=16.069578170776367
epoch: 9, loss=711.3583984375, rec_loss=149.1879425048828, total_div_loss=16.326169967651367
epoch: 0, loss=492.3997497558594, rec_loss=

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.8874)
epoch 1 acc tensor(0.9027)
epoch 2 acc tensor(0.9029)
epoch 0 acc tensor(0.7150)
epoch 1 acc tensor(0.7381)
epoch 2 acc tensor(0.7493)


 14%|█▍        | 1/7 [00:27<02:45, 27.51s/it]

epoch 0 acc tensor(0.8864)
epoch 1 acc tensor(0.8962)
epoch 2 acc tensor(0.9002)
epoch 0 acc tensor(0.7235)
epoch 1 acc tensor(0.7386)
epoch 2 acc tensor(0.7509)


 29%|██▊       | 2/7 [00:55<02:17, 27.53s/it]

epoch 0 acc tensor(0.8873)
epoch 1 acc tensor(0.8992)
epoch 2 acc tensor(0.8991)
epoch 0 acc tensor(0.6938)
epoch 1 acc tensor(0.6939)
epoch 2 acc tensor(0.7306)


 43%|████▎     | 3/7 [01:22<01:50, 27.65s/it]

epoch 0 acc tensor(0.8570)
epoch 1 acc tensor(0.8680)
epoch 2 acc tensor(0.8765)
epoch 0 acc tensor(0.7513)
epoch 1 acc tensor(0.7594)
epoch 2 acc tensor(0.7695)


 57%|█████▋    | 4/7 [01:50<01:22, 27.58s/it]

epoch 0 acc tensor(0.8662)
epoch 1 acc tensor(0.8775)
epoch 2 acc tensor(0.8893)
epoch 0 acc tensor(0.7654)
epoch 1 acc tensor(0.7778)
epoch 2 acc tensor(0.7781)


 71%|███████▏  | 5/7 [02:17<00:55, 27.58s/it]

epoch 0 acc tensor(0.8643)
epoch 1 acc tensor(0.8754)
epoch 2 acc tensor(0.8842)
epoch 0 acc tensor(0.7666)
epoch 1 acc tensor(0.7676)
epoch 2 acc tensor(0.7767)


 86%|████████▌ | 6/7 [02:45<00:27, 27.62s/it]

epoch 0 acc tensor(0.8527)
epoch 1 acc tensor(0.8664)
epoch 2 acc tensor(0.8613)
epoch 0 acc tensor(0.7515)
epoch 1 acc tensor(0.7688)
epoch 2 acc tensor(0.7733)


100%|██████████| 7/7 [03:12<00:00, 27.46s/it]


## Static Update: FashionMNIST -> MNIST

In [None]:
# Hyper-parameters for the "static update" FashionMNIST -> MNIST runs. They are
# handed positionally to CULTTrainer in the next cell; the grouping here follows
# the variable names only -- see cult.core.cult for the authoritative meaning.

# Main optimiser class and its learning rate.
optimizer, lr = torch.optim.Adam, 0.001

# Loss-term scales.
kl_scale = 1
e_prox_scale = d_prox_scale = 1  # TODO: might need to scale

# Network classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size, latents = 50, 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim, env_lr = torch.optim.Adam, 0.0001
# NOTE(review): an earlier comment here read "increasing due to poor
# performance", yet the value is 1 -- confirm the intended setting.
env_epochs = 1

replay_batch_size = 64
steps_per_save = 2500
# If None, the replay model is only updated at a new environment.
steps_per_reset = 500

In [None]:
# Static update, FashionMNIST -> MNIST: five independent runs. Same protocol as
# the plain FashionMNIST -> MNIST experiment, except that `steps_per_reset` is
# passed to CULTTrainer as an extra trailing argument.
epochs = 10
trainer_args = (
    optimizer, lr,
    kl_scale, e_prox_scale, d_prox_scale,
    encoder_type, decoder_type,
    final_size, latents,
    max_envs, atyp_min, atyp_max,
    env_optim, env_lr, env_epochs,
    replay_batch_size, steps_per_save,
    device,
    steps_per_reset,
)

for exp_num in range(5):
    name = f"fixed_fash_to_mnist_{exp_num}"
    cult_trainer = CULTTrainer(name, *trainer_args)

    # Sequential training; both calls receive the same pair of held-out batches.
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])

    # [10, 10] is presumably the number of classes per dataset -- confirm
    # against CULTTrainer.train_latent_classifiers.
    cult_trainer.train_latent_classifiers(
        [fashion_train_loader, mnist_train_loader],
        [fashion_test_loader, mnist_test_loader],
        ["fashion_mnist", "mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
    )

    # Put the model AND its env_net into eval mode before measuring anything.
    # The first FashionMNIST -> MNIST loop in this notebook switches env_net
    # explicitly; do the same here so every experiment is evaluated under the
    # same conditions (the call is a no-op if model.eval() already covered it).
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    # NOTE(review): the MNIST tag is spelled "recc"; it is kept unchanged so
    # new runs stay comparable with logs that were already written.
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    cult_trainer.writer.add_scalar("eval/mnist_recc_loss", mnist_rec_loss)

    # The second argument presumably is the expected environment index,
    # following training order (FashionMNIST -> 0, MNIST -> 1) -- confirm
    # against CULTTrainer.env_accuracy.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 0)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 1)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)


epoch: 0, loss=796.3815307617188, rec_loss=321.3736877441406, total_div_loss=13.189351081848145
epoch: 1, loss=614.6005859375, rec_loss=278.6806335449219, total_div_loss=15.87909984588623
epoch: 2, loss=585.7134399414062, rec_loss=270.21435546875, total_div_loss=17.573097229003906
epoch: 3, loss=573.8211059570312, rec_loss=266.4117736816406, total_div_loss=18.862581253051758
epoch: 4, loss=567.7890014648438, rec_loss=264.202880859375, total_div_loss=19.71156883239746
epoch: 5, loss=564.7792358398438, rec_loss=262.7940979003906, total_div_loss=20.17998504638672
epoch: 6, loss=563.1238403320312, rec_loss=261.83990478515625, total_div_loss=20.57794761657715
epoch: 7, loss=561.0931396484375, rec_loss=260.9774169921875, total_div_loss=20.922508239746094
epoch: 8, loss=559.92431640625, rec_loss=260.3429870605469, total_div_loss=21.067846298217773
epoch: 9, loss=558.4791259765625, rec_loss=259.9775695800781, total_div_loss=21.276567459106445
epoch: 0, loss=438.9420166015625, rec_loss=181.1367

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7504)
epoch 1 acc tensor(0.7710)
epoch 2 acc tensor(0.7737)
epoch 0 acc tensor(0.7251)
epoch 1 acc tensor(0.7545)
epoch 2 acc tensor(0.7658)


 14%|█▍        | 1/7 [00:34<03:26, 34.38s/it]

epoch 0 acc tensor(0.7510)
epoch 1 acc tensor(0.7576)
epoch 2 acc tensor(0.7742)
epoch 0 acc tensor(0.6426)
epoch 1 acc tensor(0.6915)
epoch 2 acc tensor(0.7246)


 29%|██▊       | 2/7 [01:08<02:49, 33.99s/it]

epoch 0 acc tensor(0.7385)
epoch 1 acc tensor(0.7564)
epoch 2 acc tensor(0.7584)
epoch 0 acc tensor(0.5755)
epoch 1 acc tensor(0.6387)
epoch 2 acc tensor(0.6831)


 43%|████▎     | 3/7 [01:41<02:14, 33.66s/it]

epoch 0 acc tensor(0.7421)
epoch 1 acc tensor(0.7562)
epoch 2 acc tensor(0.7527)
epoch 0 acc tensor(0.7619)
epoch 1 acc tensor(0.7802)
epoch 2 acc tensor(0.8046)


 57%|█████▋    | 4/7 [02:17<01:44, 34.81s/it]

epoch 0 acc tensor(0.6962)
epoch 1 acc tensor(0.7165)
epoch 2 acc tensor(0.7103)
epoch 0 acc tensor(0.7873)
epoch 1 acc tensor(0.8254)
epoch 2 acc tensor(0.8329)


 71%|███████▏  | 5/7 [02:53<01:10, 35.13s/it]

epoch 0 acc tensor(0.7090)
epoch 1 acc tensor(0.7201)
epoch 2 acc tensor(0.7288)
epoch 0 acc tensor(0.8098)
epoch 1 acc tensor(0.8380)
epoch 2 acc tensor(0.8482)


 86%|████████▌ | 6/7 [03:26<00:34, 34.30s/it]

epoch 0 acc tensor(0.7118)
epoch 1 acc tensor(0.7191)
epoch 2 acc tensor(0.7214)
epoch 0 acc tensor(0.8405)
epoch 1 acc tensor(0.8429)
epoch 2 acc tensor(0.8690)


100%|██████████| 7/7 [03:59<00:00, 34.17s/it]


epoch: 0, loss=792.9951171875, rec_loss=319.96270751953125, total_div_loss=13.596808433532715
epoch: 1, loss=612.5584106445312, rec_loss=277.3134460449219, total_div_loss=16.88612937927246
epoch: 2, loss=586.283203125, rec_loss=267.79473876953125, total_div_loss=19.08025550842285
epoch: 3, loss=575.98095703125, rec_loss=263.74322509765625, total_div_loss=20.563785552978516
epoch: 4, loss=570.0739135742188, rec_loss=261.9313049316406, total_div_loss=21.351295471191406
epoch: 5, loss=566.2647705078125, rec_loss=260.92864990234375, total_div_loss=21.662004470825195
epoch: 6, loss=564.476318359375, rec_loss=260.21148681640625, total_div_loss=21.91639518737793
epoch: 7, loss=562.28173828125, rec_loss=259.67999267578125, total_div_loss=22.056108474731445
epoch: 8, loss=562.007080078125, rec_loss=259.43670654296875, total_div_loss=22.04983139038086
epoch: 9, loss=560.6005249023438, rec_loss=259.2016906738281, total_div_loss=22.048587799072266
epoch: 0, loss=439.66546630859375, rec_loss=179.49

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7607)
epoch 1 acc tensor(0.7724)
epoch 2 acc tensor(0.7747)
epoch 0 acc tensor(0.7277)
epoch 1 acc tensor(0.7640)
epoch 2 acc tensor(0.7820)


 14%|█▍        | 1/7 [00:32<03:15, 32.60s/it]

epoch 0 acc tensor(0.7555)
epoch 1 acc tensor(0.7597)
epoch 2 acc tensor(0.7522)
epoch 0 acc tensor(0.6926)
epoch 1 acc tensor(0.7423)
epoch 2 acc tensor(0.7637)


 29%|██▊       | 2/7 [01:03<02:38, 31.62s/it]

epoch 0 acc tensor(0.7465)
epoch 1 acc tensor(0.7585)
epoch 2 acc tensor(0.7686)
epoch 0 acc tensor(0.6573)
epoch 1 acc tensor(0.7122)
epoch 2 acc tensor(0.7347)


 43%|████▎     | 3/7 [01:37<02:10, 32.56s/it]

epoch 0 acc tensor(0.7449)
epoch 1 acc tensor(0.7388)
epoch 2 acc tensor(0.7555)
epoch 0 acc tensor(0.7573)
epoch 1 acc tensor(0.7889)
epoch 2 acc tensor(0.7990)


 57%|█████▋    | 4/7 [02:08<01:36, 32.12s/it]

epoch 0 acc tensor(0.7039)
epoch 1 acc tensor(0.7072)
epoch 2 acc tensor(0.7350)
epoch 0 acc tensor(0.8152)
epoch 1 acc tensor(0.8368)
epoch 2 acc tensor(0.8452)


 71%|███████▏  | 5/7 [02:40<01:04, 32.17s/it]

epoch 0 acc tensor(0.6869)
epoch 1 acc tensor(0.7182)
epoch 2 acc tensor(0.7152)
epoch 0 acc tensor(0.8003)
epoch 1 acc tensor(0.8351)
epoch 2 acc tensor(0.8482)


 86%|████████▌ | 6/7 [03:12<00:31, 31.89s/it]

epoch 0 acc tensor(0.6803)
epoch 1 acc tensor(0.6959)
epoch 2 acc tensor(0.7006)
epoch 0 acc tensor(0.8361)
epoch 1 acc tensor(0.8468)
epoch 2 acc tensor(0.8667)


100%|██████████| 7/7 [03:43<00:00, 31.98s/it]


epoch: 0, loss=799.0175170898438, rec_loss=323.5497741699219, total_div_loss=13.428413391113281
epoch: 1, loss=625.6571044921875, rec_loss=281.403076171875, total_div_loss=15.725753784179688
epoch: 2, loss=591.867919921875, rec_loss=271.9954528808594, total_div_loss=16.96068572998047
epoch: 3, loss=578.0994873046875, rec_loss=267.2109069824219, total_div_loss=18.721975326538086
epoch: 4, loss=571.0552978515625, rec_loss=264.64788818359375, total_div_loss=19.840068817138672
epoch: 5, loss=568.2080688476562, rec_loss=262.9278869628906, total_div_loss=20.58833885192871
epoch: 6, loss=565.5624389648438, rec_loss=261.90667724609375, total_div_loss=21.147205352783203
epoch: 7, loss=564.2085571289062, rec_loss=261.1842041015625, total_div_loss=21.437835693359375
epoch: 8, loss=564.1858520507812, rec_loss=260.599365234375, total_div_loss=21.661958694458008
epoch: 9, loss=563.0545043945312, rec_loss=260.226318359375, total_div_loss=21.832435607910156
epoch: 0, loss=444.87078857421875, rec_loss=

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7603)
epoch 1 acc tensor(0.7631)
epoch 2 acc tensor(0.7741)
epoch 0 acc tensor(0.7365)
epoch 1 acc tensor(0.7632)
epoch 2 acc tensor(0.7786)


 14%|█▍        | 1/7 [00:29<02:59, 29.85s/it]

epoch 0 acc tensor(0.7524)
epoch 1 acc tensor(0.7623)
epoch 2 acc tensor(0.7698)
epoch 0 acc tensor(0.6725)
epoch 1 acc tensor(0.7134)
epoch 2 acc tensor(0.7302)


 29%|██▊       | 2/7 [01:00<02:30, 30.05s/it]

epoch 0 acc tensor(0.7519)
epoch 1 acc tensor(0.7614)
epoch 2 acc tensor(0.7678)
epoch 0 acc tensor(0.6735)
epoch 1 acc tensor(0.7426)
epoch 2 acc tensor(0.7611)


 43%|████▎     | 3/7 [01:30<02:01, 30.25s/it]

epoch 0 acc tensor(0.7562)
epoch 1 acc tensor(0.7647)
epoch 2 acc tensor(0.7635)
epoch 0 acc tensor(0.7267)
epoch 1 acc tensor(0.7562)
epoch 2 acc tensor(0.7770)


 57%|█████▋    | 4/7 [02:00<01:30, 30.12s/it]

epoch 0 acc tensor(0.7062)
epoch 1 acc tensor(0.7185)
epoch 2 acc tensor(0.7335)
epoch 0 acc tensor(0.8109)
epoch 1 acc tensor(0.8352)
epoch 2 acc tensor(0.8365)


 71%|███████▏  | 5/7 [02:31<01:00, 30.28s/it]

epoch 0 acc tensor(0.6972)
epoch 1 acc tensor(0.7042)
epoch 2 acc tensor(0.7283)
epoch 0 acc tensor(0.8118)
epoch 1 acc tensor(0.8262)
epoch 2 acc tensor(0.8396)


 86%|████████▌ | 6/7 [03:01<00:30, 30.40s/it]

epoch 0 acc tensor(0.6985)
epoch 1 acc tensor(0.6896)
epoch 2 acc tensor(0.7004)
epoch 0 acc tensor(0.8284)
epoch 1 acc tensor(0.8426)
epoch 2 acc tensor(0.8470)


100%|██████████| 7/7 [03:32<00:00, 30.32s/it]


epoch: 0, loss=793.20849609375, rec_loss=321.2537841796875, total_div_loss=12.790082931518555
epoch: 1, loss=614.8074340820312, rec_loss=282.41998291015625, total_div_loss=13.952102661132812
epoch: 2, loss=586.1149291992188, rec_loss=272.09222412109375, total_div_loss=16.320289611816406
epoch: 3, loss=574.0275268554688, rec_loss=267.81207275390625, total_div_loss=17.821853637695312
epoch: 4, loss=569.130615234375, rec_loss=265.83917236328125, total_div_loss=18.547622680664062
epoch: 5, loss=565.5845947265625, rec_loss=264.33447265625, total_div_loss=19.096376419067383
epoch: 6, loss=563.5015869140625, rec_loss=263.3734436035156, total_div_loss=19.410011291503906
epoch: 7, loss=561.6812133789062, rec_loss=262.6415100097656, total_div_loss=19.672170639038086
epoch: 8, loss=559.5983276367188, rec_loss=262.0635070800781, total_div_loss=19.813867568969727
epoch: 9, loss=559.849609375, rec_loss=261.6167297363281, total_div_loss=20.035024642944336
epoch: 0, loss=439.05218505859375, rec_loss=1

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7498)
epoch 1 acc tensor(0.7574)
epoch 2 acc tensor(0.7673)
epoch 0 acc tensor(0.7204)
epoch 1 acc tensor(0.7679)
epoch 2 acc tensor(0.7914)


 14%|█▍        | 1/7 [00:32<03:17, 32.85s/it]

epoch 0 acc tensor(0.7464)
epoch 1 acc tensor(0.7587)
epoch 2 acc tensor(0.7563)
epoch 0 acc tensor(0.6635)
epoch 1 acc tensor(0.7048)
epoch 2 acc tensor(0.7226)


 29%|██▊       | 2/7 [01:05<02:42, 32.48s/it]

epoch 0 acc tensor(0.7404)
epoch 1 acc tensor(0.7425)
epoch 2 acc tensor(0.7605)
epoch 0 acc tensor(0.6313)
epoch 1 acc tensor(0.6888)
epoch 2 acc tensor(0.7271)


 43%|████▎     | 3/7 [01:38<02:10, 32.74s/it]

epoch 0 acc tensor(0.6693)
epoch 1 acc tensor(0.7404)
epoch 2 acc tensor(0.7318)
epoch 0 acc tensor(0.7112)
epoch 1 acc tensor(0.7487)
epoch 2 acc tensor(0.7691)


 57%|█████▋    | 4/7 [02:12<01:40, 33.52s/it]

epoch 0 acc tensor(0.6896)
epoch 1 acc tensor(0.7132)
epoch 2 acc tensor(0.7209)
epoch 0 acc tensor(0.8029)
epoch 1 acc tensor(0.8299)
epoch 2 acc tensor(0.8366)


 71%|███████▏  | 5/7 [02:49<01:09, 34.59s/it]

epoch 0 acc tensor(0.7116)
epoch 1 acc tensor(0.7074)
epoch 2 acc tensor(0.7287)
epoch 0 acc tensor(0.8116)
epoch 1 acc tensor(0.8249)
epoch 2 acc tensor(0.8448)


 86%|████████▌ | 6/7 [03:26<00:35, 35.31s/it]

epoch 0 acc tensor(0.7009)
epoch 1 acc tensor(0.7151)
epoch 2 acc tensor(0.7276)
epoch 0 acc tensor(0.8367)
epoch 1 acc tensor(0.8544)
epoch 2 acc tensor(0.8670)


100%|██████████| 7/7 [04:00<00:00, 34.36s/it]


## No Replay: FashionMNIST -> MNIST

In [None]:
# Hyper-parameters for the "no replay" FashionMNIST -> MNIST runs. They are
# handed positionally to CULTTrainer in the next cell; the grouping here follows
# the variable names only -- see cult.core.cult for the authoritative meaning.

# Main optimiser class and its learning rate.
optimizer, lr = torch.optim.Adam, 0.001

# Loss-term scales. Both proximity scales are zero in this section, which is
# what distinguishes it from the other experiment configurations.
kl_scale = 1
e_prox_scale = d_prox_scale = 0

# Network classes and sizes.
encoder_type, decoder_type = FCEncoder, FCDecoder
final_size, latents = 50, 16

# Environment-related settings.
max_envs = 7
atyp_min, atyp_max = 0.25, 1.4
env_optim, env_lr = torch.optim.Adam, 0.0001
env_epochs = 1

replay_batch_size = 64
steps_per_save = 2500

In [None]:
# No replay, FashionMNIST -> MNIST: five independent runs using the
# configuration above (both proximity scales set to zero). Each run trains on
# FashionMNIST first and MNIST second, fits latent classifiers, then logs
# evaluation scalars to the run's writer.
epochs = 10
trainer_args = (
    optimizer, lr,
    kl_scale, e_prox_scale, d_prox_scale,
    encoder_type, decoder_type,
    final_size, latents,
    max_envs, atyp_min, atyp_max,
    env_optim, env_lr, env_epochs,
    replay_batch_size, steps_per_save,
    device,
)

for exp_num in range(5):
    name = f"no_replay_fash_to_mnist_{exp_num}"
    cult_trainer = CULTTrainer(name, *trainer_args)

    # Sequential training; both calls receive the same pair of held-out batches.
    cult_trainer.train(fashion_train_loader, epochs, [fashion_test_batch, mnist_test_batch])
    cult_trainer.train(mnist_train_loader, epochs, [fashion_test_batch, mnist_test_batch])

    # [10, 10] is presumably the number of classes per dataset -- confirm
    # against CULTTrainer.train_latent_classifiers.
    cult_trainer.train_latent_classifiers(
        [fashion_train_loader, mnist_train_loader],
        [fashion_test_loader, mnist_test_loader],
        ["fashion_mnist", "mnist"],
        [10, 10],
        epochs=3,
        lr=1e-2,
    )

    # Put the model AND its env_net into eval mode before measuring anything.
    # The first FashionMNIST -> MNIST loop in this notebook switches env_net
    # explicitly; do the same here so every experiment is evaluated under the
    # same conditions (the call is a no-op if model.eval() already covered it).
    cult_trainer.model.eval()
    cult_trainer.model.env_net.eval()

    fashion_rec_loss = cult_trainer.rec_loss(fashion_test_loader)
    mnist_rec_loss = cult_trainer.rec_loss(mnist_test_loader)
    # NOTE(review): the MNIST tag is spelled "recc"; it is kept unchanged so
    # new runs stay comparable with logs that were already written.
    cult_trainer.writer.add_scalar("eval/fashion_rec_loss", fashion_rec_loss)
    cult_trainer.writer.add_scalar("eval/mnist_recc_loss", mnist_rec_loss)

    # The second argument presumably is the expected environment index,
    # following training order (FashionMNIST -> 0, MNIST -> 1) -- confirm
    # against CULTTrainer.env_accuracy.
    fashion_env_acc = cult_trainer.env_accuracy(fashion_test_loader, 0)
    mnist_env_acc = cult_trainer.env_accuracy(mnist_test_loader, 1)
    cult_trainer.writer.add_scalar("eval/fashion_env_acc", fashion_env_acc)
    cult_trainer.writer.add_scalar("eval/mnist_env_acc", mnist_env_acc)

epoch: 0, loss=334.8923034667969, rec_loss=323.24334716796875, total_div_loss=11.648744583129883
epoch: 1, loss=294.34661865234375, rec_loss=280.0571594238281, total_div_loss=14.289599418640137
epoch: 2, loss=285.8857727050781, rec_loss=270.50286865234375, total_div_loss=15.382989883422852
epoch: 3, loss=282.7018127441406, rec_loss=266.8443298339844, total_div_loss=15.857505798339844
epoch: 4, loss=281.1229248046875, rec_loss=264.9488220214844, total_div_loss=16.174055099487305
epoch: 5, loss=280.21435546875, rec_loss=263.57806396484375, total_div_loss=16.636117935180664
epoch: 6, loss=279.5182189941406, rec_loss=262.634765625, total_div_loss=16.883302688598633
epoch: 7, loss=278.9308166503906, rec_loss=261.93914794921875, total_div_loss=16.99139976501465
epoch: 8, loss=278.6372985839844, rec_loss=261.5247802734375, total_div_loss=17.112462997436523
epoch: 9, loss=278.3917541503906, rec_loss=261.1748046875, total_div_loss=17.21686553955078
epoch: 0, loss=184.9942626953125, rec_loss=175

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7438)
epoch 1 acc tensor(0.7671)
epoch 2 acc tensor(0.7724)
epoch 0 acc tensor(0.7176)
epoch 1 acc tensor(0.7701)
epoch 2 acc tensor(0.7757)


 14%|█▍        | 1/7 [00:23<02:23, 23.99s/it]

epoch 0 acc tensor(0.7400)
epoch 1 acc tensor(0.7533)
epoch 2 acc tensor(0.7582)
epoch 0 acc tensor(0.6762)
epoch 1 acc tensor(0.7087)
epoch 2 acc tensor(0.7198)


 29%|██▊       | 2/7 [00:48<02:01, 24.25s/it]

epoch 0 acc tensor(0.7527)
epoch 1 acc tensor(0.7448)
epoch 2 acc tensor(0.7718)
epoch 0 acc tensor(0.6611)
epoch 1 acc tensor(0.6954)
epoch 2 acc tensor(0.7018)


 43%|████▎     | 3/7 [01:12<01:37, 24.34s/it]

epoch 0 acc tensor(0.7369)
epoch 1 acc tensor(0.7543)
epoch 2 acc tensor(0.7599)
epoch 0 acc tensor(0.7442)
epoch 1 acc tensor(0.7602)
epoch 2 acc tensor(0.7647)


 57%|█████▋    | 4/7 [01:37<01:13, 24.54s/it]

epoch 0 acc tensor(0.7206)
epoch 1 acc tensor(0.7317)
epoch 2 acc tensor(0.7386)
epoch 0 acc tensor(0.8041)
epoch 1 acc tensor(0.8213)
epoch 2 acc tensor(0.8309)


 71%|███████▏  | 5/7 [02:01<00:48, 24.28s/it]

epoch 0 acc tensor(0.6983)
epoch 1 acc tensor(0.7223)
epoch 2 acc tensor(0.7194)
epoch 0 acc tensor(0.8107)
epoch 1 acc tensor(0.8187)
epoch 2 acc tensor(0.8291)


 86%|████████▌ | 6/7 [02:25<00:24, 24.24s/it]

epoch 0 acc tensor(0.6768)
epoch 1 acc tensor(0.6930)
epoch 2 acc tensor(0.6768)
epoch 0 acc tensor(0.8200)
epoch 1 acc tensor(0.8334)
epoch 2 acc tensor(0.8406)


100%|██████████| 7/7 [02:49<00:00, 24.18s/it]


epoch: 0, loss=333.8553466796875, rec_loss=321.3582763671875, total_div_loss=12.49676513671875
epoch: 1, loss=293.4124450683594, rec_loss=278.4168395996094, total_div_loss=14.995670318603516
epoch: 2, loss=285.5057678222656, rec_loss=269.786376953125, total_div_loss=15.719441413879395
epoch: 3, loss=282.064697265625, rec_loss=265.8401184082031, total_div_loss=16.224658966064453
epoch: 4, loss=280.67620849609375, rec_loss=264.04620361328125, total_div_loss=16.630428314208984
epoch: 5, loss=279.7281494140625, rec_loss=263.00335693359375, total_div_loss=16.724855422973633
epoch: 6, loss=279.21051025390625, rec_loss=262.27008056640625, total_div_loss=16.94050407409668
epoch: 7, loss=278.77838134765625, rec_loss=261.748779296875, total_div_loss=17.02937126159668
epoch: 8, loss=278.4950256347656, rec_loss=261.3277282714844, total_div_loss=17.167144775390625
epoch: 9, loss=278.161376953125, rec_loss=260.97271728515625, total_div_loss=17.18875503540039
epoch: 0, loss=183.8594970703125, rec_los

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7547)
epoch 1 acc tensor(0.7772)
epoch 2 acc tensor(0.7773)
epoch 0 acc tensor(0.7487)
epoch 1 acc tensor(0.7909)
epoch 2 acc tensor(0.8000)


 14%|█▍        | 1/7 [00:25<02:35, 25.94s/it]

epoch 0 acc tensor(0.7563)
epoch 1 acc tensor(0.7614)
epoch 2 acc tensor(0.7735)
epoch 0 acc tensor(0.6741)
epoch 1 acc tensor(0.7110)
epoch 2 acc tensor(0.7453)


 29%|██▊       | 2/7 [00:51<02:10, 26.00s/it]

epoch 0 acc tensor(0.7567)
epoch 1 acc tensor(0.7628)
epoch 2 acc tensor(0.7712)
epoch 0 acc tensor(0.6838)
epoch 1 acc tensor(0.7301)
epoch 2 acc tensor(0.7608)


 43%|████▎     | 3/7 [01:17<01:43, 25.88s/it]

epoch 0 acc tensor(0.7373)
epoch 1 acc tensor(0.7540)
epoch 2 acc tensor(0.7439)
epoch 0 acc tensor(0.7646)
epoch 1 acc tensor(0.7916)
epoch 2 acc tensor(0.8103)


 57%|█████▋    | 4/7 [01:43<01:17, 25.83s/it]

epoch 0 acc tensor(0.7116)
epoch 1 acc tensor(0.7388)
epoch 2 acc tensor(0.7372)
epoch 0 acc tensor(0.8000)
epoch 1 acc tensor(0.8235)
epoch 2 acc tensor(0.8371)


 71%|███████▏  | 5/7 [02:09<00:51, 25.95s/it]

epoch 0 acc tensor(0.6789)
epoch 1 acc tensor(0.7220)
epoch 2 acc tensor(0.7172)
epoch 0 acc tensor(0.8159)
epoch 1 acc tensor(0.8246)
epoch 2 acc tensor(0.8361)


 86%|████████▌ | 6/7 [02:36<00:26, 26.30s/it]

epoch 0 acc tensor(0.6945)
epoch 1 acc tensor(0.7416)
epoch 2 acc tensor(0.7449)
epoch 0 acc tensor(0.8422)
epoch 1 acc tensor(0.8495)
epoch 2 acc tensor(0.8605)


100%|██████████| 7/7 [03:03<00:00, 26.21s/it]


epoch: 0, loss=336.777587890625, rec_loss=325.4460144042969, total_div_loss=11.33141040802002
epoch: 1, loss=293.6722106933594, rec_loss=278.96771240234375, total_div_loss=14.704360961914062
epoch: 2, loss=285.90093994140625, rec_loss=270.6869812011719, total_div_loss=15.21430492401123
epoch: 3, loss=283.06939697265625, rec_loss=267.1966552734375, total_div_loss=15.872753143310547
epoch: 4, loss=281.513427734375, rec_loss=265.2506408691406, total_div_loss=16.26283836364746
epoch: 5, loss=280.5991516113281, rec_loss=264.10980224609375, total_div_loss=16.48932456970215
epoch: 6, loss=280.1010437011719, rec_loss=263.3803405761719, total_div_loss=16.720712661743164
epoch: 7, loss=279.6591796875, rec_loss=262.89306640625, total_div_loss=16.765905380249023
epoch: 8, loss=279.28265380859375, rec_loss=262.40655517578125, total_div_loss=16.87615394592285
epoch: 9, loss=279.1617736816406, rec_loss=262.15960693359375, total_div_loss=17.00200843811035
epoch: 0, loss=185.7779541015625, rec_loss=176

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7570)
epoch 1 acc tensor(0.7718)
epoch 2 acc tensor(0.7689)
epoch 0 acc tensor(0.6807)
epoch 1 acc tensor(0.7272)
epoch 2 acc tensor(0.7463)


 14%|█▍        | 1/7 [00:24<02:26, 24.35s/it]

epoch 0 acc tensor(0.7589)
epoch 1 acc tensor(0.7738)
epoch 2 acc tensor(0.7761)
epoch 0 acc tensor(0.6771)
epoch 1 acc tensor(0.7269)
epoch 2 acc tensor(0.7464)


 29%|██▊       | 2/7 [00:49<02:02, 24.58s/it]

epoch 0 acc tensor(0.7572)
epoch 1 acc tensor(0.7675)
epoch 2 acc tensor(0.7742)
epoch 0 acc tensor(0.6775)
epoch 1 acc tensor(0.7209)
epoch 2 acc tensor(0.7472)


 43%|████▎     | 3/7 [01:13<01:38, 24.51s/it]

epoch 0 acc tensor(0.7554)
epoch 1 acc tensor(0.7620)
epoch 2 acc tensor(0.7733)
epoch 0 acc tensor(0.7784)
epoch 1 acc tensor(0.7988)
epoch 2 acc tensor(0.8043)


 57%|█████▋    | 4/7 [01:37<01:13, 24.38s/it]

epoch 0 acc tensor(0.7479)
epoch 1 acc tensor(0.7524)
epoch 2 acc tensor(0.7585)
epoch 0 acc tensor(0.8032)
epoch 1 acc tensor(0.8389)
epoch 2 acc tensor(0.8439)


 71%|███████▏  | 5/7 [02:02<00:48, 24.37s/it]

epoch 0 acc tensor(0.7125)
epoch 1 acc tensor(0.7091)
epoch 2 acc tensor(0.7312)
epoch 0 acc tensor(0.8313)
epoch 1 acc tensor(0.8349)
epoch 2 acc tensor(0.8495)


 86%|████████▌ | 6/7 [02:27<00:24, 24.64s/it]

epoch 0 acc tensor(0.6988)
epoch 1 acc tensor(0.7069)
epoch 2 acc tensor(0.7262)
epoch 0 acc tensor(0.8187)
epoch 1 acc tensor(0.8307)
epoch 2 acc tensor(0.8460)


100%|██████████| 7/7 [02:52<00:00, 24.69s/it]


epoch: 0, loss=335.5830078125, rec_loss=323.99481201171875, total_div_loss=11.588252067565918
epoch: 1, loss=294.2428894042969, rec_loss=279.83026123046875, total_div_loss=14.412590980529785
epoch: 2, loss=285.54644775390625, rec_loss=270.1600036621094, total_div_loss=15.386816024780273
epoch: 3, loss=282.3013000488281, rec_loss=266.130615234375, total_div_loss=16.170581817626953
epoch: 4, loss=280.7105407714844, rec_loss=264.1639709472656, total_div_loss=16.5465030670166
epoch: 5, loss=280.0584716796875, rec_loss=263.1952209472656, total_div_loss=16.86334800720215
epoch: 6, loss=279.4309387207031, rec_loss=262.504150390625, total_div_loss=16.926692962646484
epoch: 7, loss=278.8711242675781, rec_loss=261.8187561035156, total_div_loss=17.052818298339844
epoch: 8, loss=278.6446533203125, rec_loss=261.4794921875, total_div_loss=17.165231704711914
epoch: 9, loss=278.4310302734375, rec_loss=261.1632080078125, total_div_loss=17.267921447753906
epoch: 0, loss=183.92221069335938, rec_loss=173.

  0%|          | 0/7 [00:00<?, ?it/s]

epoch 0 acc tensor(0.7707)
epoch 1 acc tensor(0.7750)
epoch 2 acc tensor(0.7849)
epoch 0 acc tensor(0.7189)
epoch 1 acc tensor(0.7662)
epoch 2 acc tensor(0.7794)


 14%|█▍        | 1/7 [00:26<02:36, 26.09s/it]

epoch 0 acc tensor(0.7513)
epoch 1 acc tensor(0.7613)
epoch 2 acc tensor(0.7779)
epoch 0 acc tensor(0.6728)
epoch 1 acc tensor(0.7044)
epoch 2 acc tensor(0.7329)


 29%|██▊       | 2/7 [00:52<02:10, 26.09s/it]

epoch 0 acc tensor(0.7584)
epoch 1 acc tensor(0.7690)
epoch 2 acc tensor(0.7650)
epoch 0 acc tensor(0.6922)
epoch 1 acc tensor(0.7157)
epoch 2 acc tensor(0.7537)


 43%|████▎     | 3/7 [01:18<01:44, 26.18s/it]

epoch 0 acc tensor(0.7439)
epoch 1 acc tensor(0.7599)
epoch 2 acc tensor(0.7637)
epoch 0 acc tensor(0.7560)
epoch 1 acc tensor(0.7798)
epoch 2 acc tensor(0.7954)


 57%|█████▋    | 4/7 [01:44<01:18, 26.27s/it]

epoch 0 acc tensor(0.6857)
epoch 1 acc tensor(0.7162)
epoch 2 acc tensor(0.7201)
epoch 0 acc tensor(0.8126)
epoch 1 acc tensor(0.8308)
epoch 2 acc tensor(0.8403)


 71%|███████▏  | 5/7 [02:13<00:54, 27.03s/it]

epoch 0 acc tensor(0.6866)
epoch 1 acc tensor(0.7004)
epoch 2 acc tensor(0.7146)
epoch 0 acc tensor(0.8242)
epoch 1 acc tensor(0.8321)
epoch 2 acc tensor(0.8402)


 86%|████████▌ | 6/7 [02:38<00:26, 26.54s/it]

epoch 0 acc tensor(0.6695)
epoch 1 acc tensor(0.6858)
epoch 2 acc tensor(0.6897)
epoch 0 acc tensor(0.8254)
epoch 1 acc tensor(0.8383)
epoch 2 acc tensor(0.8448)


100%|██████████| 7/7 [03:04<00:00, 26.41s/it]
