In [None]:
#default_exp core.clvae

In [None]:
#export
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from cmaes import CMA
import os
import numpy as np
import copy

from vase.core.models import FCEncoder, FCDecoder, Encoder, Decoder, EnvironmentInference
from vase.core.utils import rec_likelihood, disable_gradient, kl_div_stdnorm, euclidean, show_batch, save_model, load_model
from vase.config import DATA_PATH, LOG_PATH, PARAM_PATH

In [None]:
#hide
from vase.core.datasets.moving_mnist import CommonMNIST, CommonFashionMNIST, MovingMNIST, MovingFashionMNIST, FixedMNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import datetime as dt
from tqdm import tqdm

In [None]:
batch_size: int = 64  # batch size passed to the DataLoaders constructed in later cells

In [None]:
#hide
# Draw one sample batch from each dataset for quick visual checks.
# Recent PyTorch DataLoader iterators have no ``.next()`` method, so use the builtin next().
mnist_data = CommonMNIST(DATA_PATH, transform=ToTensor(), download=True)
mnist_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
mnist_batch, _ = next(iter(mnist_loader))
small_fashion = CommonFashionMNIST(DATA_PATH, transform=ToTensor())
sf_loader = DataLoader(small_fashion, batch_size=batch_size, shuffle=True)
sf_batch, _ = next(iter(sf_loader))
fashion_data = MovingFashionMNIST(DATA_PATH, transform=ToTensor(), download=True)
fashion_loader = DataLoader(fashion_data, batch_size, shuffle=True)
# The Moving* datasets yield three items per sample; only the images are kept here.
fashion_batch, _, _ = next(iter(fashion_loader))
mm_data = MovingMNIST(DATA_PATH, transform=ToTensor(), download=True)
mm_loader = DataLoader(mm_data, batch_size, shuffle=True)
mm_batch, _, _ = next(iter(mm_loader))

# Continually Learning Variational Autoencoder
> Performs unsupervised continual representation learning with generative replay and environment likelihood detection

Note: assumes tasks arrive sequentially ("local stationarity"), i.e. the data stays on one distribution long enough for the model to learn it.

In [None]:
#export
class CLVAE(nn.Module):
    """Continually learning VAE with generative replay and environment inference.

    A frozen snapshot of the encoder/decoder (``old_encoder`` / ``old_decoder``)
    generates replay ("hallucinated") samples each forward pass. An
    ``EnvironmentInference`` network, trained by its own optimizer
    (``env_optim``), learns to predict the environment index from the encoder's
    third output (``final``).

    Environment bookkeeping (training mode only) is driven by ``get_atyp``:

    * ``m`` is the index of the current environment; it is ``-1`` until the
      first atypicality above ``atyp_max`` is seen, which sets it to ``0``.
    * While ``learning`` is True a high atypicality does not open another
      environment; the flag is cleared once atypicality drops below ``atyp_min``.
    * When not ``learning`` and atypicality exceeds ``atyp_max``, ``m`` is
      advanced (capped at ``max_envs - 1``) and the frozen snapshot is refreshed.

    Args:
        encoder_type: encoder class, built as ``encoder_type(latents, device=device)``;
            its call must return ``(mu, logvar, final)``.
        decoder_type: decoder class, built as ``decoder_type(latents, max_envs, device=device)``;
            called as ``decoder(z, env_indices)``.
        final_size: passed to ``EnvironmentInference``; presumably the width of
            the encoder's ``final`` output -- TODO confirm.
        latents: latent dimensionality (also the width of replay noise).
        max_envs: number of environment slots.
        atyp_min: atypicality below which ``learning`` is cleared.
        atyp_max: atypicality above which a new environment is opened.
        env_optim: optimizer class for the env network, called with ``params=`` and ``lr=``.
        env_lr: learning rate for ``env_optim``.
        env_epochs: stored as ``self.env_epochs``; not read anywhere in this class.
        replay_batch_size: number of replay samples drawn per forward pass.
        device: device the sub-networks and generated tensors are moved to.
    """

    def __init__(self,
        encoder_type: type,
        decoder_type: type,
        final_size: int,
        latents: int,
        max_envs: int,
        atyp_min: float,
        atyp_max: float,
        env_optim: type,
        env_lr: float,
        env_epochs: int,
        replay_batch_size: int,
        device: str,
    ):
        super().__init__()
        self.latents = latents
        self.max_envs = max_envs
        self.final_size = final_size
        self.atyp_min = atyp_min
        self.atyp_max = atyp_max
        self.device = device
        self.encoder = encoder_type(self.latents, device=self.device)
        self.decoder = decoder_type(self.latents, self.max_envs, device=self.device)
        self.old_encoder = encoder_type(self.latents, device=self.device)
        self.old_decoder = decoder_type(self.latents, self.max_envs, device=self.device)
        # Start with the frozen snapshot equal to the freshly initialised networks.
        self.copy_and_freeze()
        for net in (self.encoder, self.decoder, self.old_encoder, self.old_decoder):
            net.to(self.device)
        self.replay_batch_size = replay_batch_size
        self.env_net = EnvironmentInference(self.max_envs, self.final_size)
        self.env_net.to(self.device)
        self.env_optim = env_optim(params=self.env_net.parameters(), lr=env_lr)
        self.env_loss = nn.CrossEntropyLoss()
        self.env_epochs = env_epochs

        self.m = -1          # current environment index; -1 = none detected yet
        self.steps = 0       # number of training-mode forward passes
        self.learning = False

    def _num_known_envs(self):
        """Number of environment slots in use (1 even before the first detection)."""
        return self.m + 1 if self.m != -1 else 1

    def forward(self, x):
        """Reconstruct ``x`` and a replay batch; in training mode also update env state.

        Returns the 9-tuple ``(rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old,
        z_halu, s_halu, atyp)``. In eval mode latent means replace samples (so
        positions 6-7 hold ``mu_halu_old`` / ``mu_halu``), the environment of each
        row of ``x`` is inferred by ``env_net``, and neither the environment state
        nor the env network is updated.
        """
        batch_size = x.shape[0]
        if self.training:
            # Only training passes count; train() uses steps for logging and checkpoints.
            self.steps += 1

        x_halu, s_halu = self.sample_old()

        mu, logvar, final = self.encoder(x)

        if not self.training:
            z = mu
            # Restrict the argmax to slots in use (slot 0 before the first detection).
            s = torch.argmax(self.env_net(final)[:, 0:self._num_known_envs()], dim=1)
            rec_x = self.decoder(z, self.int_to_vec(s, batch_size))
            atyp = self.get_atyp(z)
            mu_halu_old, logvar_halu_old, final_halu_old = self.old_encoder(x_halu)
            mu_halu, logvar_halu, final_halu = self.encoder(x_halu)
            rec_x_halo = self.decoder(mu_halu, s_halu)
            return rec_x, mu, logvar, x_halu, rec_x_halo, mu_halu_old, mu_halu, s_halu, atyp

        # The current batch is attributed to the current environment (0 before any detection).
        s = self.m if self.m != -1 else 0
        z = self.reparam(mu, logvar)
        rec_x = self.decoder(z, self.int_to_vec(s, batch_size))

        atyp = self.get_atyp(z)
        self._update_env_state(atyp)

        with torch.no_grad():
            mu_halu_old, logvar_halu_old, final_halu_old = self.old_encoder(x_halu)
            z_halu_old = self.reparam(mu_halu_old, logvar_halu_old)

        mu_halu, logvar_halu, final_halu = self.encoder(x_halu)
        z_halu = self.reparam(mu_halu, logvar_halu)
        rec_x_halo = self.decoder(z_halu, s_halu)

        self.train_env_network(final, s, final_halu_old, s_halu)

        return rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old, z_halu, s_halu, atyp

    def _update_env_state(self, atyp):
        """Advance the environment state machine (``m`` / ``learning``) for one step."""
        if self.m == -1:
            if atyp > self.atyp_max:
                self.m = 0
                self.learning = True

        elif self.learning:
            if atyp < self.atyp_min:
                self.learning = False

        elif atyp > self.atyp_max:
            self.learning = True
            # Valid indices are 0..max_envs-1 (env_net and decoder are built with
            # max_envs); once all are used, keep reusing the last slot instead of
            # producing an out-of-range index.
            if self.m + 1 >= self.max_envs:
                print("Warning: too many environments")
            else:
                self.m += 1
            self.copy_and_freeze()

    def get_likely_env(self, final):
        """Return ``(most likely env index for the whole batch, per-row env logits)``.

        Logits are averaged over the batch and the argmax is taken over slots in use.
        """
        env_logits = self.env_net(final)
        avg_env_logits = torch.mean(env_logits, dim=0)
        valid_logits = avg_env_logits[0:self._num_known_envs()]
        return torch.argmax(valid_logits), env_logits

    def get_atyp(self, z):
        """Atypicality of a batch of latents ``z`` (no gradients).

        Fits a per-dimension mean/std over dim 0 and sums ``kl_div_stdnorm`` of
        that fit over dimensions (presumably KL to a standard normal).
        """
        with torch.no_grad():
            std, mean = torch.std_mean(z, dim=0)
            std = std[:, None]
            mean = mean[:, None]
            logvar = torch.log(std.pow(2))
            atyps = kl_div_stdnorm(mean, logvar)
        return torch.sum(atyps)

    def reparam(self, mu, logvar):
        """Reparameterised sample ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        eps = torch.randn(logvar.shape).to(self.device)
        std = (0.5 * logvar).exp()
        return mu + std * eps

    def int_to_vec(self, i, size):
        """Broadcast environment index ``i`` to an int64 vector of length ``size``."""
        return torch.ones([size], dtype=torch.int64).to(self.device) * i

    def copy_model(self, old_model, cur_model):
        """Overwrite ``old_model``'s weights with ``cur_model``'s."""
        old_model.load_state_dict(cur_model.state_dict())

    def freeze_model(self, model):
        """Disable gradients for ``model`` via ``disable_gradient``."""
        disable_gradient(model)

    def copy_and_freeze(self):
        """Refresh the frozen snapshot (old encoder/decoder) from the live networks."""
        self.copy_model(self.old_encoder, self.encoder)
        self.freeze_model(self.old_encoder)
        self.copy_model(self.old_decoder, self.decoder)
        self.freeze_model(self.old_decoder)

    def train_env_network(self, final, s, final_halu_old, s_halu):
        """Take one ``env_optim`` step on the environment-inference network.

        Args:
            final: encoder features of the current batch, all labelled with env ``s``.
            s: integer environment index of the current batch.
            final_halu_old: frozen-encoder features of the replay batch.
            s_halu: environment labels of the replay batch; only rows whose label
                differs from ``s`` contribute a replay loss.
        """
        # Detach so the env loss trains only env_net and does not push gradients
        # into the encoder (which is optimised separately).
        env_logits = self.env_net(final.detach())
        keep = s_halu != s
        final_halu_old = final_halu_old[keep]
        s_halu = s_halu[keep]
        self.env_optim.zero_grad()
        # One int64 label per row of env_logits, as CrossEntropyLoss expects.
        cur_loss = self.env_loss(env_logits, self.int_to_vec(s, env_logits.shape[0]))
        if len(s_halu) > 0:
            env_logits_halu = self.env_net(final_halu_old)
            replay_loss = self.env_loss(env_logits_halu, s_halu)
        else:
            replay_loss = 0
        loss = cur_loss + replay_loss
        loss.backward()
        self.env_optim.step()
        # Reset grads to None so an outer optimizer over self.parameters() skips
        # env_net instead of stepping it a second time with these gradients.
        for param in self.env_net.parameters():
            param.grad = None

    def sample_old(self):
        """Draw ``replay_batch_size`` samples from the frozen decoder.

        Environment labels are uniform over the slots in use and latents are
        standard-normal. Returns ``(halu_x, s)``.
        """
        s = torch.randint(0, self._num_known_envs(), (self.replay_batch_size,)).to(self.device)
        z = torch.randn([self.replay_batch_size, self.latents]).to(self.device)
        with torch.no_grad():
            halu_x = self.old_decoder(z, s)
        return halu_x, s

In [None]:
def train(model, loader, epochs, alpha, beta, optimizer, writer, steps_per_save, param_dir, gamma=None):
    """Train ``model`` (a CLVAE) on ``loader``, logging per-step scalars and saving checkpoints.

    Per batch the minimised loss is::

        mean(rec_likelihood(X, rec_x)) + gamma * mean(KL**2)          # current data
        + alpha * mean(euclidean(z_halu, z_halu_old))                 # encoder replay
        + beta * mean(rec_likelihood(x_halu, rec_x_halo))             # decoder replay

    Args:
        model: CLVAE; its ``device``, ``steps`` and ``m`` attributes are read.
        loader: iterable of batches whose first element is the input tensor.
        epochs: number of passes over ``loader``.
        alpha: weight of the encoder replay (latent proximity) loss.
        beta: weight of the decoder replay loss.
        optimizer: optimizer stepped once per batch.
        writer: TensorBoard ``SummaryWriter`` receiving per-step scalars.
        steps_per_save: a checkpoint is saved whenever ``model.steps`` is a multiple of this.
        param_dir: directory receiving ``step_<n>`` checkpoints.
        gamma: KL weight. When omitted, the notebook-level ``gamma`` is used so
            existing calls keep their behaviour.
    """
    if gamma is None:
        gamma = globals()["gamma"]

    for epoch in range(epochs):
        total_loss = 0.0
        total_rec_loss = 0.0
        total_div_loss = 0.0
        n_batches = 0
        for contents in loader:
            X = contents[0].to(model.device)
            optimizer.zero_grad()

            rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old, z_halu, s_halu, atyp = model(X)

            rec_loss = torch.mean(rec_likelihood(X, rec_x))
            kl_loss = gamma * torch.mean(torch.square(kl_div_stdnorm(mu, logvar)))  # NOTE: should I square this?
            mdl_loss = rec_loss + kl_loss

            e_prox_loss = alpha * torch.mean(euclidean(z_halu, z_halu_old))
            d_prox_loss = beta * torch.mean(rec_likelihood(x_halu, rec_x_halo))  # NOTE: not sure if this order is correct
            dream_loss = e_prox_loss + d_prox_loss

            loss = mdl_loss + dream_loss

            # This is the last backward through this step's graph, so it need not be retained.
            loss.backward()
            optimizer.step()

            step = model.steps
            if step % steps_per_save == 0:
                save_model(model, os.path.join(param_dir, f"step_{step}"))

            scalars = {
                "train/loss": loss,
                "train/rec_loss": rec_loss,
                "train/kl_loss": kl_loss,
                "train/e_prox_loss": e_prox_loss,
                "train/d_prox_loss": d_prox_loss,
                "train/num_envs": model.m + 1,
                "train/atypicality": atyp,
            }
            for tag, value in scalars.items():
                writer.add_scalar(tag, float(value), step)

            # Accumulate Python floats: summing the tensors would keep every
            # batch's autograd graph alive for the whole epoch.
            total_loss += loss.item()
            total_rec_loss += rec_loss.item()
            total_div_loss += kl_loss.item()
            n_batches += 1

        # Report per-batch averages (each summed term is already a batch mean).
        denom = max(n_batches, 1)
        print(f"epoch: {epoch}, loss={total_loss/denom}, rec_loss={total_rec_loss/denom}, total_div_loss={total_div_loss/denom}")

In [None]:
# --- CLVAE architecture and environment-detection settings ---
encoder_type = FCEncoder
decoder_type = FCDecoder
final_size = 50
latents = 8
max_envs = 7
atyp_min = 0.1
atyp_max = .7
env_optim = torch.optim.Adam
env_lr = 6e-4
env_epochs = 10
replay_batch_size = 64
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

clvae_fc = CLVAE(encoder_type, decoder_type, final_size, latents, max_envs, atyp_min, atyp_max, env_optim, env_lr, env_epochs, replay_batch_size, device)

# --- Training settings (gamma: KL weight, alpha/beta: replay loss weights) ---
epochs = 10
lr = 6e-4
gamma = 4
alpha = 1
beta = 1
steps_per_save = 500
optimizer = torch.optim.Adam(params=clvae_fc.parameters(), lr=lr)
name = "clvae_fc" + dt.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S')
writer = SummaryWriter(os.path.join(LOG_PATH, name))
param_dir = os.path.join(PARAM_PATH, name)
# makedirs also creates PARAM_PATH when it is missing and tolerates re-running the cell.
os.makedirs(param_dir, exist_ok=True)

In [None]:
train(
    clvae_fc, sf_loader, epochs,
    alpha=alpha, beta=beta, optimizer=optimizer, writer=writer,
    steps_per_save=steps_per_save, param_dir=param_dir,
)

In [None]:
# Continue on MNIST; checkpoints go to the run's param_dir, not a bare relative `name` directory.
train(clvae_fc, mnist_loader, epochs, alpha, beta, optimizer, writer, steps_per_save, param_dir)

In [None]:
clvae_fc.train(mode=False)  # same as .eval(): switch the module and its children to evaluation mode

In [None]:
# Run the inputs on the model's device; bring reconstructions back to the CPU for plotting.
with torch.no_grad():
    results = clvae_fc(sf_batch.to(clvae_fc.device))
rec_img = results[0].cpu()

In [None]:
show_batch(sf_batch[:32])  # first 32 inputs

In [None]:
show_batch(rec_img[:32])  # first 32 reconstructions

In [None]:
# Run the inputs on the model's device; bring reconstructions back to the CPU for plotting.
with torch.no_grad():
    results = clvae_fc(mnist_batch.to(clvae_fc.device))
rec_img = results[0].cpu()

In [None]:
show_batch(mnist_batch[:32])  # first 32 inputs

In [None]:
show_batch(rec_img[:32])  # first 32 reconstructions

## Classifiers on Latent Space

Check which classifier architecture the VASE paper uses (just linear?).

Define an inner training loop that fits a classifier on the latent space.
Define an outer loop that iteratively loads the saved model checkpoints and logs the accuracy of the trained classifiers.

In [None]:
#export
class LatentClassifier(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping latent vectors to class logits.

    Args:
        latents: size of the input latent vector.
        hidden_size: width of the hidden layer.
        n_classes: number of output logits.
    """

    def __init__(self, latents, hidden_size, n_classes):
        super().__init__()
        self.linear1, self.linear2 = nn.Linear(latents, hidden_size), nn.Linear(hidden_size, n_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return unnormalised logits with last dimension ``n_classes``."""
        hidden = self.relu(self.linear1(x))
        return self.linear2(hidden)

In [None]:
#export
def train_classifier(model, train_loader, test_loader, epochs, n_classes=10, hidden_size=50, lr=1e-3):
    """Fit a ``LatentClassifier`` on ``model``'s frozen latent means and return test accuracy.

    ``model`` (a CLVAE) is switched to eval mode, and its encoder runs under
    ``torch.no_grad()``. The first encoder output (the latent mean) is the
    classifier input.

    Args:
        model: CLVAE providing ``encoder``, ``latents`` and ``device``.
        train_loader / test_loader: iterables of ``(inputs, labels)`` batches.
        epochs: passes over ``train_loader``.
        n_classes: number of classifier outputs.
        hidden_size: classifier hidden width.
        lr: Adam learning rate for the classifier.

    Returns:
        Fraction of correctly classified test samples, as a Python float.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    device = model.device
    # Input width comes from the model itself, so the function does not depend on a notebook global.
    classifier = LatentClassifier(model.latents, hidden_size, n_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(params=classifier.parameters(), lr=lr)

    for _ in range(epochs):
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            with torch.no_grad():
                Z = model.encoder(X)[0]
            optimizer.zero_grad()
            loss = criterion(classifier(Z), y)
            loss.backward()
            optimizer.step()

    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for x_test, y_test in test_loader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            logits = classifier(model.encoder(x_test)[0])
            y_hat = torch.argmax(logits, dim=1)
            n_correct += int((y_test == y_hat).sum())
            n_total += x_test.shape[0]
    if n_total == 0:
        raise ValueError("test_loader yielded no samples")
    return n_correct / n_total

In [None]:
mnist_train = CommonMNIST(DATA_PATH, transform=ToTensor(), train=True, download=True)
mnist_train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
mnist_test = CommonMNIST(DATA_PATH, transform=ToTensor(), train=False, download=True)
# Accuracy does not depend on sample order, so the test loaders are not shuffled.
mnist_test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

fashion_train = CommonFashionMNIST(DATA_PATH, transform=ToTensor(), train=True, download=True)
# The Fashion loaders must wrap the Fashion-MNIST splits, not the MNIST ones.
fashion_train_loader = DataLoader(fashion_train, batch_size=batch_size, shuffle=True)
fashion_test = CommonFashionMNIST(DATA_PATH, transform=ToTensor(), train=False, download=True)
fashion_test_loader = DataLoader(fashion_test, batch_size=batch_size, shuffle=False)

In [None]:
#train classifiers
cl_vae = CLVAE(encoder_type, decoder_type, final_size, latents, max_envs, atyp_min, atyp_max, env_optim, env_lr, env_epochs, replay_batch_size, device)


def checkpoint_step(filename):
    """Training step encoded in a ``step_<n>`` checkpoint name (any file extension ignored)."""
    return int(os.path.splitext(filename)[0].split("_")[1])


# Evaluate checkpoints in training order so the logged curves are monotone in step.
for param_path in tqdm(sorted(os.listdir(param_dir), key=checkpoint_step)):
    load_model(cl_vae, os.path.join(param_dir, param_path))
    # NOTE(review): MNIST uses lr=1e-1 while Fashion uses lr=1e-2 -- confirm this is intentional.
    mnist_acc = train_classifier(cl_vae, mnist_train_loader, mnist_test_loader, epochs=5, lr=1e-1)
    fashion_acc = train_classifier(cl_vae, fashion_train_loader, fashion_test_loader, epochs=5, lr=1e-2)
    steps = checkpoint_step(param_path)
    writer.add_scalar("classifiers/mnist_acc", mnist_acc, steps)
    writer.add_scalar("classifiers/fashion_acc", fashion_acc, steps)

TODO: add a CNN model (convolutional encoder/decoder), run the same experiments with it, and wrap up.