# Testing whether we can decode the HMM from sequences deterministically

In [23]:
import sys
sys.path.append('/home/mila/l/leo.gagnon/latent_control')

%load_ext autoreload
%autoreload 2
from lightning_modules.diffusion_prior import DiffusionPriorTask
import torch
import matplotlib.pyplot as plt
from data.diffusion import LatentDiffusionDataset, LatentDiffusionDatasetConfig
from models.encoder import DiffusionEncoder
from torch2jax import j2t, t2j
import lightning as L
from models.x_transformer import Encoder, ScaledSinusoidalEmbedding
from lightning_modules.metalearn import MetaLearningTask
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
from einops import rearrange, repeat
import jax.numpy as jnp

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [31]:
class DeterministicEncoder(L.LightningModule):
    """Deterministic baseline that predicts an HMM's discrete latents directly from a sequence.

    A single learned query token (``null_embedding``) cross-attends over a
    conditioning sequence through ``latent_model``. The result is layer-normed
    and mapped by one linear head per latent component (sized from
    ``base_task.full_data.latent_shape``) to class logits. Training uses the sum
    of per-component cross-entropies against ``batch['raw_latent']``.

    The conditioning sequence comes from one of two sources:
      * ``cond_encoder=True``: ``batch['cond_input_ids']`` are embedded, given
        sinusoidal positions, and encoded by a small trainable transformer.
      * ``cond_encoder=False``: ``batch['cond_tokens']`` supplied by the dataset
        (configured with ``cond_tokens_type='pretrained'``).

    Args:
        pretrained_id: run id passed to ``MetaLearningTask``; all its parameters are frozen.
        n_embd_cond: width of the conditioning sequence fed to ``cond_proj``.
        n_embd: width of the query token / ``latent_model``.
        batch_size: training batch size.
        val_split: stored on the module but currently unused (no validation loop exists).
        lr: AdamW learning rate.
        cond_encoder: if True, train a sequence encoder on raw token ids.
    """

    def __init__(
        self,
        pretrained_id: str,
        n_embd_cond: int,
        n_embd: int,
        batch_size: int,
        val_split: float,
        lr: float,
        cond_encoder: bool
    ):
        super().__init__()

        self.batch_size = batch_size
        self.val_split = val_split  # stored only; not used by any method below
        self.lr = lr
        self.n_embd = n_embd
        self.cond_encoder = cond_encoder

        # Frozen base task: provides the dataset (and its latent_shape) used below.
        self.base_task = MetaLearningTask(pretrained_id)
        for param in self.base_task.parameters():
            param.requires_grad = False

        # The query token cross-attends over the conditioning sequence.
        self.latent_model = Encoder(
            dim=n_embd,
            depth=3,
            heads=6,
            attn_dropout=0.0,  # dropout post-attention
            ff_dropout=0.0,  # feedforward dropout
            rel_pos_bias=False,
            ff_glu=True,
            cross_attend=True,
        )
        # Single learned query token, broadcast across the batch.
        self.null_embedding = nn.Embedding(1, n_embd)

        if self.cond_encoder:
            # Trainable encoder over raw conditioning token ids.
            self.seq_conditional_encoder = Encoder(
                dim=n_embd_cond,
                depth=3,
                heads=6,
                attn_dropout=0.0,  # dropout post-attention
                ff_dropout=0.0,  # feedforward dropout
                rel_pos_bias=False,
                ff_glu=True,
            )
            self.seq_conditional_emb = nn.Embedding(
                # NOTE(review): assumes the token vocabulary has at most 50 ids — TODO confirm.
                num_embeddings=50,
                embedding_dim=n_embd_cond,
            )
            self.seq_conditional_posemb = ScaledSinusoidalEmbedding(n_embd_cond)

        self.cond_proj = nn.Linear(n_embd_cond, n_embd)
        # One classification head per latent component.
        self.out_proj = nn.ModuleList(
            [
                nn.Linear(n_embd, latent_shape)
                for latent_shape in self.base_task.full_data.latent_shape
            ]
        )
        self.norm = nn.LayerNorm(n_embd)

    def setup(self, stage):
        """Build the training dataset (context lengths [200, 200]).

        Raw token ids are used when ``cond_encoder`` is set; otherwise the dataset is
        asked for ``'pretrained'`` conditioning tokens.
        """
        with torch.no_grad():
            self.train_data = LatentDiffusionDataset(
                LatentDiffusionDatasetConfig(context_length=[200, 200], cond_tokens_type=None if self.cond_encoder else 'pretrained', latent_type=None),
                self.base_task,
                None,
            )

    def configure_optimizers(self):
        """AdamW over the trainable parameters only (``base_task`` is frozen in ``__init__``)."""
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(trainable_params, lr=self.lr)

    def train_dataloader(self):
        """Shuffled loader with an identity collate.

        Presumably the dataset already returns a collated batch dict (it exposes
        ``__getitems__``) — TODO confirm.
        """
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=lambda x: x,
        )

    def _encode_condition(self, batch) -> torch.Tensor:
        """Return the conditioning sequence for ``batch``, projected to width ``n_embd``."""
        if self.cond_encoder:
            cond = self.seq_conditional_emb(batch['cond_input_ids'])
            cond = cond + self.seq_conditional_posemb(cond)
            cond = self.seq_conditional_encoder(cond)
        else:
            cond = batch['cond_tokens']
        return self.cond_proj(cond)

    def _predict_logits(self, cond: torch.Tensor) -> list:
        """Return one ``(batch, n_classes_i)`` logit tensor per latent component."""
        # Size the query batch from `cond` itself, which exists on both conditioning paths.
        init_emb = repeat(self.null_embedding.weight, "1 d -> b 1 d", b=cond.shape[0])
        pred = self.latent_model(
            init_emb,
            context=cond,
        )
        pred = self.norm(pred)
        # Drop only the single query axis (dim 1). A bare .squeeze() would also drop
        # the batch axis when batch size is 1, breaking cross_entropy / argmax(1).
        return [proj(pred).squeeze(1) for proj in self.out_proj]

    def training_step(self, batch, batch_idx=None):
        """Summed per-component cross-entropy; also logs mean per-component accuracy."""
        logits = self._predict_logits(self._encode_condition(batch))
        targets = batch['raw_latent']

        # cross_entropy already averages over the batch (reduction='mean').
        loss = sum(
            nn.functional.cross_entropy(component_logits, targets[:, i])
            for i, component_logits in enumerate(logits)
        )
        acc = sum(
            (component_logits.argmax(1) == targets[:, i]).float().mean()
            for i, component_logits in enumerate(logits)
        ) / len(logits)

        # self.log accepts tensors; no need to round-trip through numpy.
        self.log(
            "train/loss",
            loss.detach(),
            prog_bar=True,
        )
        self.log(
            "train/acc",
            acc.detach(),
            prog_bar=True,
        )

        return loss

In [25]:
# Single-GPU run with a fixed step budget, no checkpointing and no validation.
trainer = L.Trainer(
    accelerator='gpu',
    max_steps=10000,
    enable_checkpointing=False,
    # Validation is fully disabled (the module defines no validation step/loader).
    check_val_every_n_epoch=None,
    val_check_interval=None,
    # Rebuild the training dataloader at every epoch.
    reload_dataloaders_every_n_epochs=1,
)

/home/mila/l/leo.gagnon/latent_control/venv/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/mila/l/leo.gagnon/latent_control/venv/lib/pyth ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [32]:
# Keyword arguments so each hyperparameter is readable at the call site.
deterministic_encoder = DeterministicEncoder(
    pretrained_id='y9qwghft',
    n_embd_cond=384,
    n_embd=512,
    batch_size=512,
    val_split=0.1,
    lr=1e-4,
    cond_encoder=True,
)

Loaded dataset : (11288/1000)
Loaded checkpoing : last.ckpt


In [33]:
# Train the encoder (Trainer calls setup() first, which builds `train_data`).
trainer.fit(model=deterministic_encoder)

/home/mila/l/leo.gagnon/latent_control/venv/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/mila/l/leo.gagnon/latent_control/venv/lib/pyth ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                    | Type                      | Params | Mode 
------------------------------------------------------------------------------
0 | base_task               | MetaLearningTask          | 11.9 M | train
1 | latent_model            | Encoder                   | 14.2 M | train
2 | null_embedding          | Embedding                 | 512    | train
3 | seq_conditional_encoder | Encoder                   | 7.1 M  | train
4 | seq_conditional_emb     | Embedding                 | 19.2 K | train
5 | seq_conditional_posemb  | ScaledSinusoidalEmbedding | 1      | train
6 | co

Training: |          | 0/? [00:00<?, ?it/s]

  seed = self.generator.integers(0, 1e10)
  seed = self.generator.integers(0, 1e10)

Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [7]:
# Move to GPU and switch to eval mode for inference. An interrupted fit leaves the
# module in train mode, so any dropout (e.g. in the frozen base task) would stay active.
deterministic_encoder.cuda().eval();

In [8]:
# Reproducible random subset of training indices used for evaluation.
EVAL_SEED = 0
N_EVAL = 512  # number of sequences to evaluate
eval_generator = torch.Generator().manual_seed(EVAL_SEED)
idx = torch.randperm(len(deterministic_encoder.train_data), generator=eval_generator)[:N_EVAL]

In [9]:
# Fetch the evaluation batch and read known entries by key instead of relying on
# dict ordering and on the dict having exactly four entries (varies with config).
eval_batch = deterministic_encoder.train_data.__getitems__(idx)
raw_latent = eval_batch['raw_latent']
cond_input_ids = eval_batch['cond_input_ids']
# NOTE(review): this entry was previously unpacked by position as `latent`; its key
# is not visible here, so it is still taken by position — TODO switch to its key.
latent = list(eval_batch.values())[1]

In [12]:
# Evaluate the encoder on the sampled batch, using the same conditioning path that
# `training_step` uses for this model's configuration.
device = deterministic_encoder.device
eval_input_ids = cond_input_ids.to(device)
eval_targets = raw_latent.to(device)

with torch.no_grad():
    if deterministic_encoder.cond_encoder:
        # Path the model was trained on when it learns its own sequence encoder.
        cond = deterministic_encoder.seq_conditional_emb(eval_input_ids)
        cond = cond + deterministic_encoder.seq_conditional_posemb(cond)
        cond = deterministic_encoder.seq_conditional_encoder(cond)
    else:
        # Embeddings from the frozen pretrained decoder; presumably equivalent to the
        # dataset's 'pretrained' `cond_tokens` used in training — TODO confirm.
        cond = deterministic_encoder.base_task.model.decoder(eval_input_ids, return_embeddings=True)
    cond = deterministic_encoder.cond_proj(cond)

    init_emb = repeat(deterministic_encoder.null_embedding.weight, "1 d -> b 1 d", b=cond.shape[0])

    pred = deterministic_encoder.latent_model(
        init_emb,
        context=cond,
    )
    pred = deterministic_encoder.norm(pred)
    pred = [proj(pred) for proj in deterministic_encoder.out_proj]

    # squeeze(1) removes only the single query axis, so a batch of size 1 still works;
    # cross_entropy already averages over the batch.
    loss = sum(
        nn.functional.cross_entropy(logits.squeeze(1), eval_targets[:, i])
        for i, logits in enumerate(pred)
    )
    # Per-latent-component accuracy (a list, unlike the averaged scalar logged in training).
    acc = [
        (logits.squeeze(1).argmax(1) == eval_targets[:, i]).float().mean(0)
        for i, logits in enumerate(pred)
    ]

In [14]:
# Full dataset of the frozen base task (the same object the output heads are sized from).
dataset = (
    deterministic_encoder
    .base_task
    .full_data
)

In [15]:
# Run the dataset's Bayesian oracle on the first sampled sequence, over every
# dataset index as the candidate set (presumably one index per HMM — TODO confirm).
candidate_ids = jnp.arange(len(dataset))
first_sequence = t2j(cond_input_ids[0])
oracle_out = dataset.bayesian_oracle(candidate_ids, first_sequence)

In [16]:
# Oracle's MAP index (from the last entry of its log posterior, presumably the final
# timestep) next to the true sampled index. argmax is taken in log space: exp is
# monotonic so the argmax is unchanged, but exp could underflow to all zeros and
# make argmax silently return 0.
oracle_out['log_alpha_post'][-1].argmax(), idx[0]

(Array(5126, dtype=int32), tensor(5126))