In [None]:
from jax.random import PRNGKey, split
import jax
import flax
import optax
from transformers import AutoTokenizer, FlaxAutoModelForCausalLM, AutoConfig
import re
import numpy as np
from IPython.display import clear_output
from tqdm import tqdm
import functools
from jax.experimental.pjit import pjit, with_sharding_constraint, PartitionSpec as PS
from jax.experimental import mesh_utils
from flax.training import train_state
from jax import numpy as jnp
from jax.sharding import Mesh
from jax.interpreters import pxla
from torch.utils.data import DataLoader
from datasets import load_dataset
from flax.serialization import from_bytes, to_bytes, to_state_dict, from_state_dict
from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
import msgpack
import torch
from fjutils import match_partition_rules, make_shard_and_gather_fns, float_tensor_to_dtype, StreamingCheckpointer, \
    count_params
from huggingface_hub import HfApi
from fjutils.utils import get_dataloader

# Hub client (used at the end of the notebook to upload the checkpoint) and a
# streaming checkpointer writing under 'ckpt_dir/' with fjutils' default config.
api = HfApi()
checkpoint_config = StreamingCheckpointer.get_default_config()
ckpt_stream = StreamingCheckpointer(checkpoint_config, 'ckpt_dir/')

In [None]:
# ---- Data / schedule ----
max_length = 1_900  # sequence length passed to the dataloader and used for init_weights
num_epochs = 1
batch_size = 1
max_steps = None  # passed to get_dataloader, which returns the value actually used

# ---- Optimizer ----
use_adamw_instead_of_lion = True  # True -> AdamW, False -> Lion
sch_linear = True  # True -> linear LR schedule, False -> cosine LR schedule
learning_rate = 8e-6  # initial learning rate
learning_rate_end = 4e-6  # final learning rate of the linear schedule
weight_decay = 0.01  # used by the AdamW optimizers

# ---- Model / checkpoint / data / hub ----
model_id = "<MODEL_YOU_WANT_TO_TRAIN_ID>"  # check available models to use like (FlaxFalcon,FlaxMpt,FlaxLLama,FlaxOpenLLama)
ckpt_name = '<YOUR_CKPT_PATH_OR_NAME_(EASYDEL OR OST FORMAT!)>'
dataset_name = '<YOUR_DATASET>'
repo_id = '<REPO ID TO PUSH MODEL>'

In [None]:
# Device-mesh shape, in the order of the mesh axis names ('dp', 'fsdp', 'mp'):
# data parallel, fully-sharded data parallel, model parallel.
# NOTE(review): the product presumably has to equal the number of available devices (8 here) — confirm.
sharding_shape = (1, 8, 1)  # DP , FSDP , MP

In [None]:
# Build the training dataloader. get_dataloader returns (dataloader, max_steps);
# the returned max_steps replaces the value set in the config cell.
dataloader_kwargs = dict(
    dataset_or_huggingface_dataset_hub_id=dataset_name,
    max_steps=max_steps,
    max_length=max_length,
    batch_size=batch_size,
    num_epochs=num_epochs,
    num_workers=2,
    shuffle=True,
)
dataloader, max_steps = get_dataloader(**dataloader_kwargs)

In [None]:
# Tokenizer of the model being trained; remote code is allowed for custom architectures.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)

In [None]:
config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
# The sharding cell calls config.get_partition_rules(...), so fail early with a
# clear message if this architecture does not provide partition rules.
assert hasattr(config, 'get_partition_rules'), (
    f'The config of {model_id!r} has no `get_partition_rules`; '
    'pick a model architecture whose config defines partition rules.'
)
# We use bfloat16 for both computation and parameters since TPUs support bfloat16.
# _do_init=False: no weights are initialised here; they are loaded from the
# checkpoint later in the notebook.
model = FlaxAutoModelForCausalLM.from_config(config, trust_remote_code=True, dtype=jnp.bfloat16,
                                             param_dtype=jnp.bfloat16,
                                             _do_init=False)

In [None]:
# Build the optimizer `tx` and its learning-rate schedule `scheduler`.
# Every variant starts from `learning_rate`. Linear schedules decay to
# `learning_rate_end`. AdamW variants also use `weight_decay`.
if use_adamw_instead_of_lion:
    if sch_linear:
        from fjutils.optimizers import get_adamw_with_linear_scheduler

        tx, scheduler = get_adamw_with_linear_scheduler(
            steps=max_steps,
            learning_rate_start=learning_rate,
            learning_rate_end=learning_rate_end,
            weight_decay=weight_decay
        )
    else:
        from fjutils.optimizers import get_adamw_with_cosine_scheduler

        # The cosine schedule decays from its `learning_rate` argument, so pass
        # the initial rate (not the linear schedule's end rate).
        tx, scheduler = get_adamw_with_cosine_scheduler(
            steps=max_steps,
            learning_rate=learning_rate,
            weight_decay=weight_decay
        )
else:
    if sch_linear:
        from fjutils.optimizers import get_lion_with_linear_scheduler

        tx, scheduler = get_lion_with_linear_scheduler(
            steps=max_steps,
            learning_rate_end=learning_rate_end,
            learning_rate_start=learning_rate
        )
    else:
        from fjutils.optimizers import get_lion_with_cosine_scheduler

        tx, scheduler = get_lion_with_cosine_scheduler(
            steps=max_steps,
            learning_rate=learning_rate,
        )

In [None]:
def init_fn():
    """Create a TrainState holding freshly initialised bfloat16 weights.

    Weights come from `model.init_weights` with the fixed PRNG key 0 and an
    input shape of (1, max_length). They are cast to bfloat16, nested under a
    'params' key and frozen. The optimizer is the global `tx`. The notebook uses
    this function via jax.eval_shape (to derive the train-state structure) and
    wraps it with pjit.
    """
    from flax.training import train_state as flax_train_state
    init_key = jax.random.PRNGKey(0)
    input_shape = (1, max_length)
    bf16_params = model.to_bf16(model.init_weights(init_key, input_shape))
    return flax_train_state.TrainState.create(
        apply_fn=model.__call__,
        params=flax.core.freeze({'params': bf16_params}),
        tx=tx,
    )


def init_fn_wop():
    """Create a TrainState from the module-level `params`, cast to float32.

    `params` is the global bound when the checkpoint is loaded in the sharding
    cell, so call this only after that cell has run.
    NOTE(review): that cell passes `params` to a pjit function with
    donate_argnums=(0,), so its buffers may already be invalidated — confirm
    before relying on this function.
    """
    from flax.training import train_state
    # float32 here is intentional (the original author's note: "this is not an
    # error do not change this !"). The result goes to a new local name: assigning
    # to `params` itself made it function-local, so reading it raised
    # UnboundLocalError on every call.
    fp32_params = model.to_fp32(params)
    return train_state.TrainState.create(
        tx=tx,
        params=fp32_params,
        apply_fn=model.__call__
    )


def create_train_state_from_params(params_):
    """Wrap an existing parameter tree `params_` in a TrainState.

    The state uses the global optimizer `tx` and `model.__call__` as apply_fn.
    """
    from flax.training import train_state
    state_cls = train_state.TrainState
    return state_cls.create(apply_fn=model.__call__, params=params_, tx=tx)


def dummy_init():
    """Return a TrainState with `params=None`, the global `tx` and `model.__call__`."""
    from flax.training import train_state
    create_kwargs = dict(tx=tx, apply_fn=model.__call__, params=None)
    return train_state.TrainState.create(**create_kwargs)


def fsdp_train_step(state, batch):
    """Run one sharded (pjit/FSDP) training step with a next-token cross-entropy loss.

    Args:
        state: TrainState whose `apply_fn` is the model call.
        batch: mapping of model inputs, forwarded as keyword arguments to
            `apply_fn`. It must contain 'input_ids'. If it also contains
            'attention_mask', positions whose mask is 0 are excluded from the loss.

    Returns:
        (new_state, loss), where loss is the mean over unmasked target tokens.
    """
    batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
    # Logits at position t predict token t + 1, so targets are the inputs shifted left by one.
    labels = batch['input_ids'][..., 1:]
    attention_mask = batch.get('attention_mask')
    target_mask = None if attention_mask is None else attention_mask[..., 1:]

    def calculate_loss(params):
        logits = state.apply_fn(params=params, **batch,
                                return_dict=True).logits
        token_loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits[..., :-1, :],
                                                                     labels=labels)
        if target_mask is None:
            return jnp.mean(token_loss)
        # Masked mean in float32 so padding does not dilute the loss; the
        # denominator is clamped to avoid 0/0 on a fully-padded batch.
        weights = target_mask.astype(jnp.float32)
        token_loss = token_loss.astype(jnp.float32)
        return jnp.sum(token_loss * weights) / jnp.maximum(jnp.sum(weights), 1.0)

    loss, grads = jax.value_and_grad(calculate_loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss


@functools.partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
def pmap_train_step(state, input_ids, attention_mask):
    """Run one data-parallel (pmap) training step with a next-token cross-entropy loss.

    The state argument is donated. The loss and gradients are averaged across
    the 'batch' pmap axis.

    Returns:
        (new_state, loss), where loss is the mean over unmasked target tokens.
    """
    def calculate_loss(params):
        logits = state.apply_fn(params=params, attention_mask=attention_mask, input_ids=input_ids,
                                return_dict=True).logits
        # Logits at position t predict token t + 1: drop the last logit and the first label.
        token_loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits[..., :-1, :], labels=input_ids[..., 1:]
        ).astype(jnp.float32)
        # Exclude padding targets; the clamp avoids 0/0 on fully-padded shards.
        weights = attention_mask[..., 1:].astype(jnp.float32)
        return jnp.sum(token_loss * weights) / jnp.maximum(jnp.sum(weights), 1.0)

    # pmap already compiles the whole step, so no inner jax.jit is needed.
    loss, grads = jax.value_and_grad(calculate_loss)(state.params)
    loss = jax.lax.pmean(loss, 'batch')
    grads = jax.lax.pmean(grads, 'batch')
    state = state.apply_gradients(grads=grads)
    return state, loss

In [None]:
# Derive the train-state structure without materialising weights, then map it to
# PartitionSpecs with the model's partition rules.
# NOTE(review): the meaning of the positional `True` passed to get_partition_rules is not visible here — confirm.
train_state_shape = jax.eval_shape(init_fn)
train_state_partition_spec = match_partition_rules(config.get_partition_rules(True), train_state_shape)
sharded_init_fn = pjit(init_fn, out_shardings=train_state_partition_spec)
sharded_init_f_wop = pjit(init_fn_wop, out_shardings=train_state_partition_spec)
# Builds a sharded TrainState from already-loaded params; the input params are donated.
sharded_create_from_params_fn = pjit(
    create_train_state_from_params,
    in_shardings=(train_state_partition_spec.params,),
    out_shardings=train_state_partition_spec,
    donate_argnums=(0,)
)
# One training step. The state follows the partition spec. The batch enters
# replicated (PS()) and is re-sharded inside fsdp_train_step. Only the state
# (argument 0) is donated.
sharded_train_step_fn = pjit(
    fsdp_train_step,
    in_shardings=(train_state_partition_spec, PS()),
    out_shardings=(train_state_partition_spec, PS()),
    donate_argnums=(0,),
)
phsycal_mesh = mesh_utils.create_device_mesh(sharding_shape)
mesh = Mesh(phsycal_mesh, ('dp', 'fsdp', 'mp'))
with mesh:
    # shard_fns place the loaded checkpoint on the mesh; ghater_fns are used when saving.
    shard_fns, ghater_fns = make_shard_and_gather_fns(train_state_partition_spec, jnp.bfloat16)
    _, params = ckpt_stream.load_trainstate_checkpoint(
        f'params::{ckpt_name}', train_state_shape, shard_fns
    )
    sharded_train_state_ = sharded_create_from_params_fn(params)


In [None]:
# Parameter count of the loaded, sharded model via the fjutils `count_params` helper
# (whether it prints or returns the count is not visible from this notebook).
count_params(sharded_train_state_.params)

In [None]:
with mesh:
    i = 0  # number of optimizer steps taken so far
    losses = []
    logging_step = 1  # refresh the progress-bar postfix every `logging_step` steps
    learning_rates = []
    # Context manager so the progress bar is closed even if a step raises.
    with tqdm(total=max_steps) as pbar:
        for _ in range(num_epochs):
            for batch in dataloader:
                if i >= max_steps:
                    break
                i += 1
                sharded_train_state_, loss = sharded_train_step_fn(sharded_train_state_, batch)
                # NOTE(review): logs scheduler(i) for the i-th (1-based) step, as before;
                # optax usually evaluates schedules at the pre-update count, so the rate
                # actually applied may be scheduler(i - 1) — confirm.
                current_lr = scheduler(i).tolist()
                losses.append(loss)
                learning_rates.append(current_lr)
                pbar.update(1)
                if i % logging_step == 0:
                    pbar.set_postfix(loss=loss, learning_rate=current_lr)
            if i >= max_steps:
                # Step budget used up: stop the epoch loop as well, not just this epoch.
                break

# Optional Prediction
Here's a simple function to test your model by greedily generating a few tokens from a prompt.

In [None]:
def predict(state, input_ids):
    """Greedily append one generated token to every sequence in `input_ids`.

    Args:
        state: TrainState whose `apply_fn` is the model call; weights come from `state.params`.
        input_ids: integer token ids of shape (batch, seq_len).

    Returns:
        `input_ids` with the most likely next token appended, shape (batch, seq_len + 1).
    """
    input_ids = with_sharding_constraint(input_ids, PS(('dp', 'fsdp')))
    pred = state.apply_fn(params=state.params, input_ids=input_ids, return_dict=True)
    # Per-sequence argmax over the vocabulary at the last position. Softmax is
    # monotonic, so it is not needed for greedy decoding. axis=-1 keeps sequences
    # in a batch independent instead of taking one argmax over the flattened batch.
    next_token = jnp.argmax(pred.logits[:, -1, :], axis=-1)
    next_token = next_token.reshape(-1, 1).astype(input_ids.dtype)
    return jnp.concatenate([input_ids, next_token], axis=-1)


sharded_predict = pjit(predict, out_shardings=PS(), in_shardings=(train_state_partition_spec, PS()))
text = None  # write down your text :)
max_new_tokens = 50  # number of tokens to generate greedily
if text is None:
    # This section is optional: skip it instead of failing inside the tokenizer,
    # so running the whole notebook still reaches the saving cells.
    print('Set `text` to a prompt to try the model; skipping prediction.')
else:
    with mesh:
        input_ids = jnp.asarray(tokenizer.encode(text,
                                                 add_special_tokens=False), dtype='i4').reshape(1, -1)
        # `_` instead of `i`, so the training-step counter `i` is not overwritten.
        # The sequence grows by one token per call, so each call sees a new shape.
        for _ in range(max_new_tokens):
            input_ids = sharded_predict(sharded_train_state_, input_ids)
            clear_output(wait=True)
            print(tokenizer.decode(input_ids[0]))

In [None]:
# Checkpoint file name built from the last path component of the model id
# ("org/name" -> "name"). This also works for ids without an org prefix, for
# deeper local paths, and with a trailing slash.
filename = f'model_{model_id.rstrip("/").split("/")[-1]}_ostformat'

# Saving Model

In [None]:
import os

with mesh:
    # Create the checkpoint directory if needed. Unlike `!mkdir`, this is
    # idempotent and portable.
    os.makedirs('ckpt_dir', exist_ok=True)
    # Gather the sharded parameters and stream them to ckpt_dir/<filename>.
    ckpt_stream.save_checkpoint(sharded_train_state_.params['params'], filename=filename,
                                gather_fns=ghater_fns.params['params'])

In [None]:
# Upload the saved checkpoint to the Hub repo under the same file name.
api.upload_file(
    path_or_fileobj=f'ckpt_dir/{filename}',
    repo_id=repo_id,
    path_in_repo=filename,
)