In [None]:
from jax.random import PRNGKey, split
import jax
import flax
import optax
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, AutoConfig
import re
import numpy as np
from tqdm import tqdm
import functools
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental import mesh_utils
from flax.training import train_state
from jax import numpy as jnp
from jax.sharding import PartitionSpec, Mesh

PS = PartitionSpec
from jax.interpreters import pxla
from torch.utils.data import DataLoader
from datasets import load_dataset

In [None]:
# Trainer selection flags: each one gates the matching `if FSDP:` / `if PMAP:` / `if NORM:` cells below.
FSDP = True
PMAP = False
NORM = False
# `collate_fn` keeps only the last `max_length` tokens of every sample; the PMAP/NORM loops also
# reshape batches to this width.
max_length = 2048
num_epochs = 2
batch_size = 4
# None means "derive from the data": it is replaced by num_epochs * len(dataloader) once the
# DataLoader has been built.
max_steps = None
dataset_name = 'erfanzar/Data-LGeM-2048'

In [None]:
# Load the dataset selected by `dataset_name`; its 'train' split feeds the DataLoader below.
# NOTE(review): use_auth_token=True presumably relies on a locally stored Hugging Face token —
# confirm the installed `datasets` version still accepts this argument.
dataset = load_dataset(dataset_name, use_auth_token=True)


def collate_fn(batch):
    """Merge a list of per-sample dicts into one dict of stacked arrays.

    For every key of the first sample, each sample's value is converted to a
    JAX array and cut down to its last ``max_length`` entries along the final
    axis. The pieces are stacked and flattened to two dimensions, keeping the
    width of the first truncated piece as the last dimension.
    """
    collated = {}
    for field in batch[0].keys():
        truncated = []
        for sample in batch:
            # Negative slice start: keep the tail of the sequence.
            truncated.append(jnp.array(sample[field])[..., -max_length:])
        width = truncated[0].shape[-1]
        collated[field] = jnp.stack(truncated).reshape(-1, width)
    return collated


# Batches are assembled by `collate_fn`; drop_last=True discards a final incomplete batch.
dataloader = DataLoader(dataset['train'], collate_fn=collate_fn, batch_size=batch_size, drop_last=True)
# When no explicit step budget was configured, train for `num_epochs` full passes over the loader.
max_steps = num_epochs * len(dataloader) if max_steps is None else max_steps
# NOTE(review): `tokenizer` is not referenced by any other cell visible here.
tokenizer = AutoTokenizer.from_pretrained('erfanzar/FlaxLGeM', trust_remote_code=True)

In [None]:
def match_partition_rules(rules, params):
    """Assign a partition spec to every leaf of ``params``.

    Each leaf is identified by its tree path rendered as a '/'-joined string.
    Leaves that are scalars (rank 0, or exactly one element) get an empty
    ``PS()``. Every other leaf receives the spec of the first ``(pattern, spec)``
    pair in ``rules`` whose pattern ``re.search``-matches the path string.

    Returns:
        A pytree with the same structure as ``params`` whose leaves are specs.

    Raises:
        ValueError: if a non-scalar leaf matches none of the rules.
    """
    tree_util = jax.tree_util
    unmatched = object()  # sentinel, so a rule may legitimately map to ``None``

    def _render_key(entry):
        # Turn one path entry into the bare text used for regex matching.
        if isinstance(entry, tree_util.SequenceKey):
            return str(entry.idx)
        if isinstance(entry, (tree_util.DictKey, tree_util.FlattenedIndexKey)):
            return str(entry.key)
        if isinstance(entry, tree_util.GetAttrKey):
            return str(entry.name)
        return str(entry)

    def _spec_for(path, leaf):
        name = '/'.join(_render_key(entry) for entry in path)
        is_scalar = len(leaf.shape) == 0 or np.prod(leaf.shape) == 1
        if is_scalar:
            return PS()
        spec = next(
            (candidate for pattern, candidate in rules if re.search(pattern, name) is not None),
            unmatched,
        )
        if spec is unmatched:
            raise ValueError(f'Partition rule not found for param: {name}')
        return spec

    return tree_util.tree_map_with_path(_spec_for, params)


def count_params(_p):
    """Print the total number of elements in the pytree ``_p``, in billions.

    The message is emitted in bold red via an ANSI escape sequence; the
    sequence is reset at the end of the line so the styling does not leak
    into whatever is printed afterwards.

    Args:
        _p: A pytree whose leaves expose a ``size`` attribute.
    """
    # tree_leaves walks the tree directly; no unfrozen copy is needed just to count.
    total = sum(leaf.size for leaf in jax.tree_util.tree_leaves(_p))
    print('\033[1;31mModel Contain : ', total / 1e9, ' Billion Parameters\033[0m')


def names_in_mesh(*names):
    """Return True when every given name is an axis of the physical mesh
    recorded in ``pxla.thread_resources.env`` (vacuously True for no names)."""
    mesh_axes = pxla.thread_resources.env.physical_mesh.axis_names
    return set(names).issubset(mesh_axes)


def get_names(partition_specs):
    """Collect the distinct axis names referenced by a partition spec.

    Entries may be ``None`` (skipped), a single axis name, or a nested
    tuple/list of names such as ``('dp', 'fsdp')``. Nested groups are
    descended into so that their names are reported as well; entries of
    any other type are ignored.

    Args:
        partition_specs: An iterable of spec entries (a dict is read by value).

    Returns:
        A list of unique axis names, in no particular order.
    """
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        if isinstance(item, str):
            names.add(item)
        elif isinstance(item, (tuple, list, dict)):
            # One dimension split over several axes is written as a nested group.
            names.update(get_names(item))
    return list(names)


def with_sharding_constraint__a(x, partition_spec):
    """Constrain ``x`` to ``partition_spec`` only when the mesh knows its axes.

    The axis names found by ``get_names`` are checked with ``names_in_mesh``;
    if any is missing, ``x`` is handed back untouched.
    """
    if not names_in_mesh(*get_names(partition_spec)):
        return x
    return with_sharding_constraint(x, partition_spec)

In [None]:
# In case the LGeM model is trained from a config: the model is built from the
# configuration alone (`_do_init=False`), its variables are created later.
config = AutoConfig.from_pretrained(
    'erfanzar/FlaxLGeM',
    hidden_size=4096,
    num_attention_heads=32,
    num_hidden_layers=16,
    intermediate_size=8192,
    trust_remote_code=True,
    vocab_size=32005,
    fsdp=True
)
model = FlaxAutoModelForCausalLM.from_config(config=config, _do_init=False, trust_remote_code=True)

# Decay over the whole training run. A fixed horizon shorter than `max_steps`
# would drive the learning rate to zero and leave the remaining steps without effect.
scheduler = optax.cosine_decay_schedule(
    init_value=1.85e-4,
    decay_steps=max_steps,
)

# Adam-normalised update, plus weight decay, scaled by the schedule and
# finally negated so that the update descends the loss.
tx = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(1e-1),
    optax.scale_by_schedule(scheduler),
    optax.scale(-1.0)
)

In [None]:
def init_fn():
    """Create a fresh ``TrainState`` for ``model`` with optimizer ``tx``.

    The module is initialised from a fixed ``PRNGKey(0)`` using all-ones
    int32 dummy inputs of shape (1, 2048); the resulting variables are passed
    through ``model.to_bf16`` before being stored as the state's params.
    """
    dummy_shape = (1, 2048)
    rng = jax.random.PRNGKey(0)
    variables = model.module.init(
        rng,
        jnp.ones(dummy_shape, dtype=jnp.int32),  # input_ids
        jnp.ones(dummy_shape, dtype=jnp.int32),  # attention_mask
        return_dict=False,
    )
    return train_state.TrainState.create(
        apply_fn=model.__call__,
        params=model.to_bf16(variables),
        tx=tx,
    )

In [None]:
# Both the pmap trainer and the single-device ("NORM") trainer consume `state`,
# so it has to exist for either flag; only the pmap trainer needs it replicated.
if PMAP or NORM:
    input_ids = jnp.ones((batch_size, max_length), dtype=jnp.int32)
    attention_mask = jnp.ones((batch_size, max_length), dtype=jnp.int32)
    key = jax.random.PRNGKey(0)
    varient = model.module.init(key, input_ids, attention_mask, return_dict=False)
    varient = model.to_bf16(varient)
    state = train_state.TrainState.create(
        tx=tx,
        params=varient,
        apply_fn=model.__call__
    )
    if PMAP:
        state = flax.jax_utils.replicate(state)

In [None]:
def train_step(state, input_ids, attention_mask):
    """Run one optimisation step of next-token prediction.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and ``apply_gradients``.
        input_ids: Integer token ids; the last axis is the sequence axis.
        attention_mask: Forwarded unchanged to ``state.apply_fn``.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean token cross-entropy.
    """
    def calculate_loss(params):
        logits = state.apply_fn(params=params, attention_mask=attention_mask, input_ids=input_ids,
                                return_dict=True).logits
        # The logits at position t are scored against the token at t + 1:
        # drop the last logit (nothing follows it) and the first label
        # (nothing precedes it).
        # NOTE(review): every position weighs equally in the mean; attention_mask
        # is not used to exclude positions — confirm whether the data is padded.
        loss_ = optax.softmax_cross_entropy_with_integer_labels(logits=logits[..., :-1, :],
                                                                labels=input_ids[..., 1:])
        return jnp.mean(loss_)

    grad_fn = jax.value_and_grad(calculate_loss, has_aux=False)
    loss__, grad = grad_fn(state.params)
    state = state.apply_gradients(grads=grad)
    return state, loss__


def fsdp_train_step(state, input_ids, attention_mask):
    """Run one optimisation step of next-token prediction under a device mesh.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and ``apply_gradients``.
        input_ids: Integer token ids; the last axis is the sequence axis.
        attention_mask: Forwarded to ``state.apply_fn`` after the sharding constraint.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean token cross-entropy.
    """
    # Split the leading axis of both inputs over the combined 'dp' and 'fsdp' mesh axes.
    input_ids = with_sharding_constraint(input_ids, PS(('dp', 'fsdp')))
    attention_mask = with_sharding_constraint(attention_mask, PS(('dp', 'fsdp')))

    def calculate_loss(params):
        logits = state.apply_fn(params=params, input_ids=input_ids, attention_mask=attention_mask,
                                return_dict=True).logits
        # The logits at position t are scored against the token at t + 1:
        # drop the last logit (nothing follows it) and the first label
        # (nothing precedes it).
        # NOTE(review): every position weighs equally in the mean; attention_mask
        # is not used to exclude positions — confirm whether the data is padded.
        loss_ = optax.softmax_cross_entropy_with_integer_labels(logits=logits[..., :-1, :],
                                                                labels=input_ids[..., 1:])
        return jnp.mean(loss_)

    grad_fn = jax.value_and_grad(calculate_loss, has_aux=False)
    loss__, grad = grad_fn(state.params)
    state = state.apply_gradients(grads=grad)
    return state, loss__


@functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
def pmap_train_step(state, input_ids, attention_mask):
    """Run one data-parallel optimisation step of next-token prediction.

    The function is mapped over the leading axis of its arguments with
    ``jax.pmap`` (axis name ``'batch'``); the incoming ``state`` is donated.

    Args:
        state: Train state exposing ``apply_fn``, ``params`` and ``apply_gradients``.
        input_ids: Integer token ids; the last axis is the sequence axis.
        attention_mask: Forwarded unchanged to ``state.apply_fn``.

    Returns:
        ``(new_state, loss)``; ``loss`` is the cross-entropy averaged over the
        ``'batch'`` axis.
    """
    def calculate_loss(params):
        logits = state.apply_fn(params=params, attention_mask=attention_mask, input_ids=input_ids,
                                return_dict=True).logits
        # The logits at position t are scored against the token at t + 1:
        # drop the last logit (nothing follows it) and the first label
        # (nothing precedes it).
        # NOTE(review): every position weighs equally in the mean; attention_mask
        # is not used to exclude positions — confirm whether the data is padded.
        loss_ = optax.softmax_cross_entropy_with_integer_labels(logits=logits[..., :-1, :],
                                                                labels=input_ids[..., 1:])
        return jnp.mean(loss_)

    # No inner jax.jit: the whole step is already compiled by pmap.
    grad_fn = jax.value_and_grad(calculate_loss, has_aux=False)
    loss__, grad = grad_fn(state.params)
    # Average across the mapped axis so every replica applies the same update.
    loss__ = jax.lax.pmean(loss__, 'batch')
    grad = jax.lax.pmean(grad, 'batch')
    state = state.apply_gradients(grads=grad)
    return state, loss__


def step_prediction(state, input_ids, attention_mask):
    """Run a forward pass with the state's current params.

    Returns whatever ``state.apply_fn`` returns when called with
    ``return_dict=True`` — the complete output object, not just its logits.
    """
    return state.apply_fn(
        input_ids=input_ids,
        attention_mask=attention_mask,
        params=state.params,
        return_dict=True,
    )

In [None]:
if FSDP:
    # Trace init_fn abstractly (shapes/dtypes only) to learn the state's tree
    # structure, then map every leaf to a partition spec via the model's rules.
    eval_tree_init_fn = jax.eval_shape(init_fn)
    partition_tree = match_partition_rules(config.get_partition_rules(), eval_tree_init_fn)
    sharded_init_fn = pjit(
        init_fn,
        out_shardings=partition_tree
    )
    sharded_train_step_fn = pjit(
        fsdp_train_step,
        in_shardings=(partition_tree, PS(), PS()),
        out_shardings=(partition_tree, PS()),
        # The previous state is dead once the step returns its successor, so its
        # buffers can be reused instead of holding two copies of the state.
        donate_argnums=(0,)
    )
    # Put every available device on the 'fsdp' axis rather than assuming there are exactly 8.
    phsycal_mesh = mesh_utils.create_device_mesh((1, jax.device_count(), 1))
    mesh = Mesh(phsycal_mesh, ('dp', 'fsdp', 'mp'))

    with mesh:
        sharded_state = sharded_init_fn()

In [None]:
# Announce each enabled trainer and report the size of the state it will use.
# The params are fetched lazily so a disabled trainer's state is never touched.
for trainer_enabled, trainer_banner, fetch_trainer_params in (
        (PMAP, 'PMAP TRAINER .....', lambda: state.params),
        (NORM, 'NORM TRAINER .....', lambda: state.params),
        (FSDP, 'FSDP TRAINER ..... ', lambda: sharded_state.params),
):
    if trainer_enabled:
        print(trainer_banner)
        count_params(fetch_trainer_params())

In [None]:
## FSDP TRAIN
if FSDP:
    save_steps = 200  # NOTE(review): declared, but this cell implements no checkpointing
    logging_step = 1  # refresh the displayed loss every `logging_step` steps
    losses = []  # host-side history of the per-step loss
    i = 0
    # The mesh must be active while the pjit-ed step runs; using tqdm as a
    # context manager closes the bar even if a step raises.
    with mesh, tqdm(total=max_steps) as pbar:
        for _ in range(num_epochs):
            if i >= max_steps:
                break  # step budget exhausted: do not start another epoch
            for batch in dataloader:
                if i >= max_steps:
                    break
                i += 1
                input_ids = batch['input_ids']
                attention_mask = batch['attention_mask']
                sharded_state, loss = sharded_train_step_fn(sharded_state, input_ids, attention_mask)
                # Pull the scalar to the host once, for both the history and the display.
                loss_value = float(loss)
                losses.append(loss_value)
                pbar.update(1)
                if i % logging_step == 0:
                    pbar.set_postfix(loss=loss_value)


In [None]:
## PMAP TRAIN FUNCTION
if PMAP:
    save_steps = 200  # NOTE(review): declared, but this cell implements no checkpointing
    logging_step = 1  # refresh the displayed loss every `logging_step` steps
    losses = []  # host-side history of the per-step loss
    i = 0
    # pmap maps over the devices local to this process, so the leading batch
    # axis must match the local (not the global) device count.
    # NOTE(review): the reshape below needs batch_size to be a multiple of this count.
    n_devices = jax.local_device_count()
    # Using tqdm as a context manager closes the bar even if a step raises.
    with tqdm(total=max_steps) as pbar:
        for _ in range(num_epochs):
            if i >= max_steps:
                break  # step budget exhausted: do not start another epoch
            for batch in dataloader:
                if i >= max_steps:
                    break
                i += 1
                input_ids = batch['input_ids'].reshape(n_devices, -1, max_length)
                attention_mask = batch['attention_mask'].reshape(n_devices, -1, max_length)
                state, loss = pmap_train_step(state, input_ids, attention_mask)
                # The loss was averaged across devices inside the step, so the
                # first replica's value stands for all of them.
                loss_value = float(loss[0])
                losses.append(loss_value)
                pbar.update(1)
                if i % logging_step == 0:
                    pbar.set_postfix(loss=loss_value)


In [None]:
## NORMAL AND SINGLE THREAD TRAIN
if NORM:
    save_steps = 200  # NOTE(review): declared, but this cell implements no checkpointing
    logging_step = 1  # refresh the displayed loss every `logging_step` steps
    losses = []  # host-side history of the per-step loss
    i = 0
    # Compile the step once instead of dispatching every operation eagerly.
    jitted_train_step = jax.jit(train_step)
    # Using tqdm as a context manager closes the bar even if a step raises.
    with tqdm(total=max_steps) as pbar:
        for _ in range(num_epochs):
            if i >= max_steps:
                break  # step budget exhausted: do not start another epoch
            for batch in dataloader:
                if i >= max_steps:
                    break
                i += 1
                input_ids = batch['input_ids'].reshape(-1, max_length)
                attention_mask = batch['attention_mask'].reshape(-1, max_length)
                state, loss = jitted_train_step(state, input_ids, attention_mask)
                # Pull the scalar to the host once, for both the history and the display.
                loss_value = float(loss)
                losses.append(loss_value)
                pbar.update(1)
                if i % logging_step == 0:
                    pbar.set_postfix(loss=loss_value)
