In [1]:
import sys
sys.path.append('/home/mila/l/leo.gagnon/latent_control')

%load_ext autoreload
%autoreload 2
from lightning_modules.diffusion_prior import DiffusionPriorTask
import torch
import matplotlib.pyplot as plt
from data.diffusion import KnownLatentDiffusionDataset, KnownLatentDiffusionDatasetConfig
from models.encoder import DiffusionEncoder
from torch2jax import j2t, t2j
import lightning as L
from models.x_transformer import Encoder, ScaledSinusoidalEmbedding
from lightning_modules.metalearn import MetaLearningTask
import torch.nn as nn
from torch.utils.data import random_split, DataLoader
from einops import rearrange, repeat
import jax.numpy as jnp

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


In [2]:
class DeterministicEncoder(L.LightningModule):
    def __init__(
        self,
        pretrained_id: str,
        n_embd: int,
        batch_size: int,
        val_split: float,
        lr: float,
    ):
        super().__init__()

        self.batch_size = batch_size
        self.val_split = val_split
        self.lr = lr
        self.n_embd = n_embd

        self.base_task = MetaLearningTask(pretrained_id)
        for param in self.base_task.parameters():
            param.requires_grad = False

        self.seq_conditional_encoder = Encoder(
            dim=n_embd,
            depth=3,
            heads=6,
            attn_dropout=0.0,  # dropout post-attention
            ff_dropout=0.0,  # feedforward dropout
            rel_pos_bias=False,
            ff_glu=True,
        )

        self.latent_model = Encoder(
            dim=n_embd,
            depth=3,
            heads=6,
            attn_dropout=0.0,  # dropout post-attention
            ff_dropout=0.0,  # feedforward dropout
            rel_pos_bias=False,
            ff_glu=True,
            cross_attend=True,
        )
        self.null_embedding = nn.Embedding(1, n_embd)
        # self.out_proj = nn.Linear(n_embd, self.base_task.model.encoder.cfg.n_embd)
        self.out_proj = nn.ModuleList(
            [
                nn.Linear(n_embd, embedding.num_embeddings)
                for embedding in self.base_task.model.encoder.latent_embedding
            ]
        )
        self.seq_conditional_emb = nn.Embedding(
            num_embeddings=50,
            embedding_dim=n_embd,
        )
        self.seq_conditional_posemb = ScaledSinusoidalEmbedding(n_embd)
        self.norm = nn.LayerNorm(n_embd)

    def setup(self, stage):
        with torch.no_grad():
            self.train_data = KnownLatentDiffusionDataset(
                KnownLatentDiffusionDatasetConfig(context_length=[200, 200]),
                self.base_task,
                None,
            )

            # self.train_data, self.val_data = random_split(
            #    dataset, [1 - self.val_split, self.val_split]
            # )

    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr)
        return opt

    def train_dataloader(self):
        return DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=lambda x: x,
        )

    def training_step(self, batch, batch_idx=None):
        latent, raw_latent = batch["latent"], batch["raw_latent"]
        cond_input_ids = batch["cond_input_ids"]

        cond = self.seq_conditional_emb(cond_input_ids)
        cond = cond + self.seq_conditional_posemb(cond)
        cond = self.seq_conditional_encoder(cond)

        init_emb = repeat(self.null_embedding.weight, "1 d -> b 1 d", b=latent.shape[0])

        pred = self.latent_model(
            init_emb,
            context=cond,
        )
        pred = self.norm(pred)
        pred = [proj(pred) for proj in self.out_proj]

        loss = sum(
            [
                nn.functional.cross_entropy(pred[i].squeeze(), raw_latent[:, i]).mean()
                for i in range(len(pred))
            ]
        )
        acc = sum(
            [
                (pred[i].squeeze().argmax(1) == raw_latent[:, i]).float().mean()
                for i in range(len(pred))
            ]
        ) / len(pred)
        # loss = torch.mean((pred - latent)**2)

        self.log(
            "train/loss",
            loss.detach().cpu().numpy().item(),
            prog_bar=True,
        )
        self.log(
            "train/acc",
            acc.detach().cpu().numpy().item(),
            prog_bar=True,
        )

        return loss

In [3]:
trainer = L.Trainer(
        max_steps=10000,
        accelerator='gpu',
        enable_checkpointing=False,
        val_check_interval=100,
        reload_dataloaders_every_n_epochs=1,
        check_val_every_n_epoch=None,
    )

/home/mila/l/leo.gagnon/latent_control/venv/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/mila/l/leo.gagnon/latent_control/venv/lib/pyth ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [4]:
deterministic_encoder = DeterministicEncoder("ekly943l", 512, 512, 0.1, 1e-4)

Loaded dataset : (11288/1000)
Loaded checkpoing : last.ckpt


In [None]:
trainer.fit(deterministic_encoder)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                    | Type                      | Params | Mode 
------------------------------------------------------------------------------
0 | base_task               | MetaLearningTask          | 6.3 M  | train
1 | seq_conditional_encoder | Encoder                   | 11.8 M | train
2 | latent_model            | Encoder                   | 14.2 M | train
3 | null_embedding          | Embedding                 | 512    | train
4 | out_proj                | ModuleList                | 13.9 K | train
5 | seq_conditional_emb     | Embedding                 | 25.6 K | train
6 | seq_conditional_posemb  | ScaledSinusoidalEmbedding | 1      | train
7 | norm                    | LayerNorm                 | 1.0 K  | train
------------------------------------------------------------------------------
26.0 M    Trainable params
6.3 M     Non-trainable params
32.4 M    Total params
129.500   Total estimated model params size (MB)
345       Mo

Training: |          | 0/? [00:00<?, ?it/s]

  intv_envs = jnp.array(intv_envs)
  return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
  intv_envs = jnp.array(intv_envs)
/home/mila/l/leo.gagnon/latent_control/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (24) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


In [28]:
deterministic_encoder.cuda();

In [54]:
idx = torch.randperm(len(deterministic_encoder.train_data))[:512]

In [58]:
raw_latent, latent, cond_input_ids, _ = deterministic_encoder.train_data.__getitems__(idx).values()

  intv_envs = jnp.array(intv_envs)


In [59]:
cond = deterministic_encoder.seq_conditional_emb(cond_input_ids)
cond = cond + deterministic_encoder.seq_conditional_posemb(cond)

init_emb = repeat(deterministic_encoder.null_embedding.weight, "1 d -> b 1 d", b=latent.shape[0])

pred = deterministic_encoder.encoder(
        init_emb,
        context=cond,
    )
pred = deterministic_encoder.norm(pred)
pred = [proj(pred) for proj in deterministic_encoder.out_proj]

loss = sum([nn.functional.cross_entropy(pred[i].squeeze(), raw_latent[:,i]).mean() for i in range(len(pred))])
acc = [(pred[i].squeeze().argmax(1) == raw_latent[:,i]).float().mean(0) for i in range(len(pred))]

In [60]:
dataset = deterministic_encoder.base_task.full_data

In [65]:
oracle_out = dataset.bayesian_oracle(jnp.arange(len(dataset)), t2j(cond_input_ids[0]))

In [68]:
jnp.exp(oracle_out['log_alpha_post'][-1]).argmax(), idx[0]

(Array(9667, dtype=int32), tensor(9667))