In [31]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
from datasets import Dataset
import jax.numpy as jnp

def batch_collate_fn(data_list):
    batch_dict = {key:[] for key in data_list[0].keys()}
    for data in data_list:
        for key, value in data.items():
            batch_dict[key].append(value)
    return {key: jnp.array(value) for key, value in batch_dict.items()}

batch_size = 8
dataset = Dataset.from_parquet("data/wikitext*")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, drop_last=True, collate_fn=batch_collate_fn)



In [2]:
tiny_gpt_config = {
  "_num_labels": 1,
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 6,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "vocab_size": 50257
}

In [43]:
from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2LMHeadModel, GPT2Config

import jax
import jax.numpy as jnp

model_config = GPT2Config.from_pretrained("gpt2")
model = FlaxGPT2LMHeadModel(model_config, input_shape=(8, 256), seed=0, dtype=jnp.dtype("bfloat16"))

In [44]:
learning_rate = 3e-4
epochs = 1
training_seed = 0
num_train_steps = len(dataloader) * epochs

rng = jax.random.PRNGKey(training_seed)
rng, dropout_rng = jax.random.split(rng)

In [45]:
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
import optax

linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=0, transition_steps=num_train_steps)
adamw = optax.adamw(learning_rate=linear_decay_lr_schedule_fn, b1=0.9, b2=0.98, eps=1e-8, weight_decay=0.01)
state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)

In [46]:
def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng, num=2)

    def loss_fn(params):
        labels = batch.pop("labels")
        pred_logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
        return optax.softmax_cross_entropy(pred_logits, onehot(labels, pred_logits.shape[-1])).mean()
    
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(state.params)
    grad = jax.lax.pmean(grad, "batch")

    new_state = state.apply_gradients(grads=grad)
    metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
    metrics = jax.lax.pmean(metrics, axis_name="batch")

    return new_state, metrics, new_dropout_rng

In [47]:
from tqdm import tqdm
import flax

p_train_step = jax.pmap(train_step, "batch")
state = flax.jax_utils.replicate(state)

for epoch in range(epochs):
    dropout_rngs = jax.random.split(rng, num=jax.local_device_count())

    for i, batch in enumerate(dataloader):
        batch = shard(batch)
        state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs)
        
        if i > 0 and i % 10 == 0:
            print(train_metric)

{'learning_rate': ShardedDeviceArray([0.00027987], dtype=float32), 'loss': ShardedDeviceArray([6.676178], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.00025973], dtype=float32), 'loss': ShardedDeviceArray([5.823189], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.0002396], dtype=float32), 'loss': ShardedDeviceArray([5.9843597], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.00021946], dtype=float32), 'loss': ShardedDeviceArray([5.8715134], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.00019933], dtype=float32), 'loss': ShardedDeviceArray([4.9384155], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.00017919], dtype=float32), 'loss': ShardedDeviceArray([5.41835], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.00015906], dtype=float32), 'loss': ShardedDeviceArray([5.679558], dtype=float32)}
{'learning_rate': ShardedDeviceArray([0.00013893], dtype=float32), 'loss': ShardedDeviceArray([6.3515854], dtype=float32)}
{'learning_rate': Shar