- [x] Instantiate a Transformer
- [x] Evaluate ELBO
- [ ] Hook up the trainer
- [ ] Profit

In [1]:
import jax
import jax.random as jr
import jax.numpy as jnp

Key = jr.PRNGKey

In [2]:
from maskgit.nets.maskgit_transformer import Transformer
from maskgit.diffusion.losses import discrete_diffusion_loss_single
from maskgit.diffusion.models import AbsorbingRate

In [3]:
import ml_collections

In [14]:
config = ml_collections.ConfigDict()

config.codebook_size = 1024 + 1 # Caveat: conditional generation stuff
config.transformer_block_size = 256 + 1 # seq length + 1?

# Transformer configs
config.transformer = ml_collections.ConfigDict()
config.transformer.num_layers = 24
config.transformer.patch_size = 16
config.transformer.num_embeds = 768
config.transformer.intermediate_size = 3072
config.transformer.num_heads = 16
config.transformer.dropout_rate = 0.1
config.transformer.mask_token_id = -1
config.transformer.latent_size = 16

tcfg = config.transformer

In [15]:
model = Transformer(
            vocab_size=config.codebook_size,
            hidden_size=tcfg.num_embeds,
            num_hidden_layers=tcfg.num_layers,
            num_attention_heads=tcfg.num_heads,
            intermediate_size=tcfg.intermediate_size,
            hidden_dropout_prob=tcfg.dropout_rate,
            attention_probs_dropout_prob=tcfg.dropout_rate,
            max_position_embeddings=config.transformer_block_size)

In [7]:
key_0 = jr.PRNGKey(0)

dummy_input = jnp.zeros((256,), dtype=int)

init_params = model.init(key_0, dummy_input)

In [30]:
S = 1024

config = {
    "experiment_name": "",
    "wandb_run_id": "",
    # "data_shape": (D,),
    "state_size": S+1,
    # "hidden_dim": 32,
    # "time_embedding_dim": 32,

    "scalar_rate": 5.,
    "rate_eps": 1e-3,

    "nll_weight": .01,
    "eps": 1e-6,
    "min_t": .001,
    "max_t": 1., # For debugging
    "max_epochs": 2000,
    "batch_size": 64,
    "learning_rate": 1e-3,
    "seed": Key(42),
    "shuffle_dataset": True,
    "use_wandb": False
}
# This takes a very long time for some reason.
config["forward_process"] = AbsorbingRate(config)

In [9]:
# out = model.apply(init_params, jnp.ones((2, 256), dtype=int))
# out.shape

In [10]:
discrete_diffusion_loss_single(key_0, dummy_input, model, init_params, config)

{'loss': Array(0.00491229, dtype=float32),
 'elbo': Array(82.57003, dtype=float32),
 'nll': Array(0.00616316, dtype=float32)}

# Get the trainer

In [24]:
from maskgit.diffusion.losses import diffusion_batch_loss
from maskgit.diffusion.training import Trainer, generic_params_update, linear_warmup_schedule
from tqdm import tqdm
import optax as opt

In [31]:
def model_init(
        key, model,
        init_data, init_params = None,
        config = None):
    """
    Model initialization that adds a copy of the parameters to config.
    """
    if init_params is None:
        params_key, dropout_key = jr.split(key)
        init_params = model.init({"params": params_key, 
                                  "dropout": dropout_key}, init_data[0], 0)
    
    lr_schedule = linear_warmup_schedule(5000, 0, config["learning_rate"])
    
    optimizer = opt.chain(
        opt.clip_by_global_norm(1.0),  # Clip gradients to have a global norm of at most `clip_norm`
        opt.adam(lr_schedule),
    )
    # model_opt = ema(optimizer, config["ema_decay"])
    # Don't use EMA for now
    model_opt = optimizer
    opt_state = model_opt.init(init_params)
    return init_params, model_opt, opt_state

In [32]:
# Load the dataset
data_dir = "/mnt/disks/persist/vq_tokens_256x256.npy"
dataset = jnp.load(data_dir, allow_pickle=True)

In [33]:
dataset = dataset.reshape((dataset.shape[0], -1))
print(dataset.shape)

(1281024, 256)


In [None]:
trainer = Trainer(model, config=config, init=model_init,
                    loss=diffusion_batch_loss, update=generic_params_update)
tqdm._instances.clear()
trainer.train({"train_data": dataset},
                max_epochs=config["max_epochs"],
                # early_stop_start=2000,
                # max_lose_streak=100,
                key=config["seed"])

[jit compling...]:   0%|                                                            | 0/5006000 [05:43<?, ?it/s]
[jit compling...]:   0%|                                                           | 0/20018000 [03:57<?, ?it/s]
