In [None]:
#default_exp core.clvae

In [None]:
#export
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from cmaes import CMA
import os
import numpy as np
import copy

from vase.core.models import FCEncoder, FCDecoder, Encoder, Decoder, EnvironmentInference
from vase.core.utils import rec_likelihood, disable_gradient, kl_div_stdnorm, euclidean, show_batch
from vase.config import DATA_PATH, LOG_PATH

In [None]:
#hide
from vase.core.datasets.moving_mnist import CommonMNIST, CommonFashionMNIST, MovingMNIST, MovingFashionMNIST, FixedMNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import datetime as dt

In [None]:
batch_size: int = 64  # minibatch size shared by every DataLoader built below

In [None]:
#hide
# One sample batch per dataset for quick inspection. `next(iter(loader))` is used
# instead of `iter(loader).next()`: the `.next()` method on DataLoader iterators
# was removed in newer PyTorch releases, while the builtin works everywhere.
mnist_data = CommonMNIST(DATA_PATH, transform=ToTensor(), download=True)
mnist_loader = DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
mnist_batch, _ = next(iter(mnist_loader))

small_fashion = CommonFashionMNIST(DATA_PATH, transform=ToTensor())
sf_loader = DataLoader(small_fashion, batch_size=batch_size, shuffle=True)
sf_batch, _ = next(iter(sf_loader))

# The moving-* datasets yield 3-tuples; only the images are kept here.
fashion_data = MovingFashionMNIST(DATA_PATH, transform=ToTensor(), download=True)
fashion_loader = DataLoader(fashion_data, batch_size=batch_size, shuffle=True)
fashion_batch, _, _ = next(iter(fashion_loader))

mm_data = MovingMNIST(DATA_PATH, transform=ToTensor(), download=True)
mm_loader = DataLoader(mm_data, batch_size=batch_size, shuffle=True)
mm_batch, _, _ = next(iter(mm_loader))

# Continually Learning Variational Autoencoder
> Performs unsupervised continual representation learning with generative replay and environment likelihood detection

Note: assumes tasks arrive sequentially ("local stationarity"), i.e. the data stays on each distribution long enough for the model to learn it.

In [None]:
#export
class CLVAE(nn.Module):
    """Continually learning VAE with generative replay and environment inference.

    A live encoder/decoder pair is trained on a stream whose distribution
    ("environment") changes over time. Frozen snapshots of the networks
    (``old_encoder``/``old_decoder``) hallucinate replay samples, and a separate
    environment network (trained by its own optimizer) learns to tell
    environments apart. A new environment index is allocated when the batch
    atypicality (``get_atyp``) exceeds ``atyp_thresh`` after the current
    environment has settled, up to ``max_envs`` environments in total.

    Args:
        encoder_type: callable ``(latents, device=...)``; its module must return
            ``(mu, logvar, final)``.
        decoder_type: callable ``(latents, max_envs, device=...)``; its module is
            called as ``decoder(z, env_index_vector)``.
        final_size: feature size of the encoder's ``final`` output (env-net input).
        latents: number of latent dimensions.
        max_envs: maximum number of environments that can be allocated.
        atyp_thresh: atypicality threshold for starting/stopping environment learning.
        env_optim: optimizer class used for the environment network.
        env_lr: learning rate for ``env_optim``.
        env_epochs: stored as ``self.env_epochs`` (not read by this class).
        replay_batch_size: number of hallucinated replay samples per step.
        device: device all sub-networks are placed on.
    """

    def __init__(self, 
        encoder_type: type,
        decoder_type: type,
        final_size: int, 
        latents: int,
        max_envs: int,
        atyp_thresh: float,
        env_optim: type,
        env_lr: float,
        env_epochs: int,
        replay_batch_size: int,
        device: str,
    ):
        super().__init__()
        self.latents = latents
        self.max_envs = max_envs
        self.final_size = final_size
        self.atyp_thresh = atyp_thresh
        self.device = device

        # Live networks are trained; the "old" copies are frozen snapshots used for replay.
        self.encoder = encoder_type(self.latents, device=self.device)
        self.decoder = decoder_type(self.latents, self.max_envs, device=self.device)
        self.old_encoder = encoder_type(self.latents, device=self.device)
        self.old_decoder = decoder_type(self.latents, self.max_envs, device=self.device)
        self.copy_and_freeze()
        for net in (self.encoder, self.decoder, self.old_encoder, self.old_decoder):
            net.to(self.device)
        self.replay_batch_size = replay_batch_size

        # Environment classifier, updated only through self.env_optim in train_env_network.
        self.env_net = EnvironmentInference(self.max_envs, self.final_size)
        self.env_net.to(self.device)
        self.env_optim = env_optim(params=self.env_net.parameters(), lr=env_lr)
        self.env_loss = nn.CrossEntropyLoss()
        self.env_epochs = env_epochs  # NOTE(review): stored but never read in this class

        # Environment state: newest environment index (-1 = none yet), forward-call
        # counter, and whether the current environment is still being fit.
        self.m = -1
        self.steps = 0
        self.learning = False

    def forward(self, x):
        """Reconstruct ``x``, update environment state, and produce replay outputs.

        Training mode samples latents, may allocate a new environment, and takes
        one step on the environment network. Eval mode uses latent means and the
        most likely already-allocated environment, and changes no state except
        ``self.steps``.

        Returns:
            ``(rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old, z_halu, s_halu, atyp)``
        """
        batch_size = x.shape[0]
        self.steps += 1

        x_halu, s_halu = self.sample_old()

        mu, logvar, final = self.encoder(x)

        if not self.training:
            z = mu
            # Before any environment exists there is nothing to choose from; use 0
            # (the same fallback the training path uses).
            s = self.get_likely_env(final)[0] if self.m != -1 else 0
            rec_x = self.decoder(z, self.int_to_vec(s, batch_size))
            atyp = self.get_atyp(z)
            with torch.no_grad():
                mu_halu_old, _, _ = self.old_encoder(x_halu)
            mu_halu, _, _ = self.encoder(x_halu)
            rec_x_halo = self.decoder(mu_halu, s_halu)
            # Deterministic counterparts of the sampled replay latents.
            return rec_x, mu, logvar, x_halu, rec_x_halo, mu_halu_old, mu_halu, s_halu, atyp

        # The current batch is decoded with the environment index from *before*
        # this step's state update.
        s = self.m if self.m != -1 else 0
        z = self.reparam(mu, logvar)
        rec_x = self.decoder(z, self.int_to_vec(s, batch_size))

        atyp = self.get_atyp(z)
        self._update_env_state(atyp)

        with torch.no_grad():
            mu_halu_old, logvar_halu_old, final_halu_old = self.old_encoder(x_halu)
            z_halu_old = self.reparam(mu_halu_old, logvar_halu_old)

        mu_halu, logvar_halu, _ = self.encoder(x_halu)
        z_halu = self.reparam(mu_halu, logvar_halu)
        rec_x_halo = self.decoder(z_halu, s_halu)

        self.train_env_network(final, s, final_halu_old, s_halu)

        return rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old, z_halu, s_halu, atyp

    def _update_env_state(self, atyp):
        """Advance the environment state machine given this batch's atypicality."""
        if self.m == -1:
            # First atypical batch: start learning environment 0.
            if atyp > self.atyp_thresh:
                self.m = 0
                self.learning = True
        elif self.learning:
            # Still fitting the current environment; stop once the data looks typical.
            if atyp < self.atyp_thresh:
                self.learning = False
        elif atyp > self.atyp_thresh and self.m + 1 < self.max_envs:
            # Atypical data after convergence: allocate a new environment and snapshot
            # the networks for replay. The cap keeps m below max_envs, the size the
            # decoder and environment network were built with; exceeding it produced
            # out-of-range environment indices.
            self.learning = True
            self.m += 1
            self.copy_and_freeze()

    def get_likely_env(self, final):
        """Return ``(argmax over allocated envs of batch-averaged logits, raw logits)``.

        Only indices ``0..self.m`` are considered; callers must ensure ``self.m >= 0``
        (with ``m == -1`` the slice is empty and ``argmax`` raises).
        """
        env_logits = self.env_net(final)
        avg_env_logits = torch.mean(env_logits, dim=0)
        valid_logits = avg_env_logits[0:self.m+1]
        return torch.argmax(valid_logits), env_logits

    def get_atyp(self, z):
        """Batch atypicality: summed ``kl_div_stdnorm`` of the per-dimension batch stats of ``z``.

        NOTE(review): presumably KL(N(mean, std^2) || N(0, 1)) per latent dim, based on
        the helper's name; confirm in vase.core.utils. A batch of one gives NaN std.
        """
        with torch.no_grad():
            std, mean = torch.std_mean(z, dim=0)
            std = std[:, None]
            mean = mean[:, None]
            logvar = torch.log(std.pow(2))
            atyps = kl_div_stdnorm(mean, logvar)
        return torch.sum(atyps)

    def reparam(self, mu, logvar):
        """Reparameterisation trick: ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, I)``."""
        eps = torch.randn_like(logvar)
        std = (0.5 * logvar).exp()
        return mu + std * eps

    def int_to_vec(self, i, size):
        """Return an int64 vector of length ``size`` filled with environment index ``i``."""
        return torch.ones([size], dtype=torch.int64, device=self.device) * i

    def copy_model(self, old_model, cur_model):
        """Load ``cur_model``'s weights into ``old_model``."""
        old_model.load_state_dict(cur_model.state_dict())

    def freeze_model(self, model):
        """Disable gradients for ``model`` (via ``disable_gradient``)."""
        disable_gradient(model)

    def copy_and_freeze(self):
        """Snapshot the live encoder/decoder into the frozen replay copies."""
        self.copy_model(self.old_encoder, self.encoder)
        self.freeze_model(self.old_encoder)
        self.copy_model(self.old_decoder, self.decoder)
        self.freeze_model(self.old_decoder)

    def train_env_network(self, final, s, final_halu_old, s_halu):
        """Take one ``env_optim`` step on the environment classifier.

        The current features ``final`` are labelled ``s``. Replayed features
        ``final_halu_old`` are labelled with their sampled environments ``s_halu``,
        and replay samples whose label equals ``s`` are dropped.

        Inputs are detached so this loss trains only ``env_net``. Previously the
        gradient flowed into the encoder, which ``env_optim`` does not own, where
        an outer optimizer would then apply it. Gradients are cleared after the
        step so an outer optimizer holding ``env_net``'s parameters cannot step
        them a second time.
        """
        env_logits = self.env_net(final.detach())
        keep = s_halu != s
        final_halu_old = final_halu_old[keep]
        s_halu = s_halu[keep]

        self.env_optim.zero_grad()
        cur_loss = self.env_loss(env_logits, self.int_to_vec(s, env_logits.shape[0]))
        if len(s_halu) > 0:
            env_logits_halu = self.env_net(final_halu_old.detach())
            replay_loss = self.env_loss(env_logits_halu, s_halu)
        else:
            replay_loss = 0
        loss = cur_loss + replay_loss
        loss.backward()
        self.env_optim.step()
        self.env_optim.zero_grad(set_to_none=True)

    def sample_old(self):
        """Hallucinate ``replay_batch_size`` inputs from the frozen decoder.

        Environment labels are uniform over ``0..m`` (just 0 before any environment
        exists) and latents are drawn from N(0, I).
        NOTE(review): label ``m`` is included, although the snapshot is taken when
        ``m`` is allocated, so the old decoder may never have been trained on it; confirm.
        """
        max_env = self.m+1 if self.m != -1 else 1
        s = torch.randint(0, max_env, (self.replay_batch_size,)).to(self.device)
        z = torch.randn([self.replay_batch_size, self.latents]).to(self.device)
        with torch.no_grad():
            halu_x = self.old_decoder(z, s)
        return halu_x, s

In [None]:
def train(model, loader, epochs, alpha, beta, optimizer, writer, gamma=None, device=None):
    """Train a CLVAE-style model, logging every step to ``writer``.

    Per batch the minimised loss is
    ``rec + gamma * mean(KL^2) + alpha * encoder-proximity + beta * decoder-proximity``.
    The proximity terms compare the live networks with their frozen replay
    snapshots on hallucinated samples.

    Args:
        model: module returning ``(rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old,
            z_halu, s_halu, atyp)`` and exposing ``steps`` and ``m``.
        loader: iterable of batches whose first element is the input tensor.
        epochs: number of passes over ``loader``.
        alpha: weight of the encoder proximity (latent drift) term.
        beta: weight of the decoder proximity (replay reconstruction) term.
        optimizer: optimizer stepping the model's parameters.
        writer: object with ``add_scalar(tag, value, step)``, e.g. ``SummaryWriter``.
        gamma: KL weight. ``None`` falls back to the notebook-global ``gamma``
            (the previous implicit behaviour).
        device: device for input batches. ``None`` falls back to the
            notebook-global ``device``.

    Prints per-epoch means over batches of the total, reconstruction and KL losses.
    """
    if gamma is None:
        gamma = globals()["gamma"]
    if device is None:
        device = globals()["device"]

    for epoch in range(epochs):
        # Plain floats: accumulating loss tensors would keep every batch's
        # autograd graph alive for the whole epoch.
        total_loss = 0.0
        total_rec_loss = 0.0
        total_div_loss = 0.0
        n_batches = 0
        for contents in loader:
            X = contents[0].to(device)
            optimizer.zero_grad()

            rec_x, mu, logvar, x_halu, rec_x_halo, z_halu_old, z_halu, s_halu, atyp = model(X)

            rec_loss = torch.mean(rec_likelihood(X, rec_x))
            kl_loss = gamma * torch.mean(torch.square(kl_div_stdnorm(mu, logvar)))  # TODO: should the KL be squared?
            mdl_loss = rec_loss + kl_loss

            e_prox_loss = alpha * torch.mean(euclidean(z_halu, z_halu_old))
            d_prox_loss = beta * torch.mean(rec_likelihood(x_halu, rec_x_halo))  # TODO: confirm argument order
            dream_loss = e_prox_loss + d_prox_loss

            loss = mdl_loss + dream_loss

            loss.backward()
            optimizer.step()

            step = model.steps
            writer.add_scalar("train/loss", loss.item(), step)
            writer.add_scalar("train/rec_loss", rec_loss.item(), step)
            writer.add_scalar("train/kl_loss", kl_loss.item(), step)
            writer.add_scalar("train/e_prox_loss", e_prox_loss.item(), step)
            writer.add_scalar("train/d_prox_loss", d_prox_loss.item(), step)
            writer.add_scalar("train/num_envs", model.m+1, step)
            writer.add_scalar("train/atypicality", float(atyp), step)
            total_loss += loss.item()
            total_rec_loss += rec_loss.item()
            total_div_loss += kl_loss.item()
            n_batches += 1
        # Mean per batch; the old divisor (global batch_size) had no relation to
        # how many batches were summed. max(...) guards an empty loader.
        n = max(n_batches, 1)
        print(f"epoch: {epoch}, loss={total_loss/n}, rec_loss={total_rec_loss/n}, total_div_loss={total_div_loss/n}")

In [None]:
# Model hyperparameters.
encoder_type = FCEncoder
decoder_type = FCDecoder
final_size = 50
latents = 8
max_envs = 7
atyp_thresh = .5
env_optim = torch.optim.Adam
env_lr = 6e-4
env_epochs = 10  # NOTE(review): stored by CLVAE but not read anywhere in the class
replay_batch_size = 64
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

clvae_fc = CLVAE(encoder_type, decoder_type, final_size, latents, max_envs, atyp_thresh, env_optim, env_lr, env_epochs, replay_batch_size, device)

# Training hyperparameters.
epochs=10
lr=6e-4
gamma=20
alpha=1
beta=10
# The environment network already has its own optimizer (clvae_fc.env_optim).
# Keep its parameters out of the main optimizer so they are not stepped twice per batch.
optimizer = torch.optim.Adam(
    params=[p for n, p in clvae_fc.named_parameters() if not n.startswith("env_net.")],
    lr=lr,
)
name = "clvae_fc" + dt.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S')
writer = SummaryWriter(os.path.join(LOG_PATH, name))

In [None]:
# Task 1 of the sequence: Fashion-MNIST batches from `sf_loader`.
train(clvae_fc, sf_loader, epochs=epochs, alpha=alpha, beta=beta, optimizer=optimizer, writer=writer)

RuntimeError: Class values must be smaller than num_classes.

In [None]:
# Task 2 of the sequence: MNIST batches from `mnist_loader`, continuing the same model.
train(clvae_fc, mnist_loader, epochs=epochs, alpha=alpha, beta=beta, optimizer=optimizer, writer=writer)

epoch: 0, loss=82291.953125, rec_loss=2983.999755859375, total_div_loss=62.65172576904297
epoch: 1, loss=82214.171875, rec_loss=2909.527587890625, total_div_loss=61.54328155517578
epoch: 2, loss=82199.28125, rec_loss=2897.67236328125, total_div_loss=65.13917541503906
epoch: 3, loss=82181.3984375, rec_loss=2888.50830078125, total_div_loss=67.56062316894531
epoch: 4, loss=82173.90625, rec_loss=2885.269287109375, total_div_loss=68.66284942626953
epoch: 5, loss=82165.140625, rec_loss=2879.69970703125, total_div_loss=70.47335815429688
epoch: 6, loss=82161.03125, rec_loss=2878.7900390625, total_div_loss=71.29312896728516
epoch: 7, loss=82157.9296875, rec_loss=2876.942138671875, total_div_loss=71.86474609375
epoch: 8, loss=82154.1875, rec_loss=2875.462890625, total_div_loss=71.8026123046875
epoch: 9, loss=82150.3046875, rec_loss=2873.46240234375, total_div_loss=72.2280044555664
