# &#x1F916; LLMs on TPUs: An interactive demo using GPT-J

This demo will walk through how to use JAX and TPUs to run and train large language models (LLMs).  In particular, we consider an open-source model called GPT-J, which was trained using Jax on Google Cloud TPUs - in particular using the [mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax) repo - by EleutherAI.

### &#x1F440; Contents
* &#x1F6A7; [Setup](#setup)
    * &#x2622; [Installation and running options](#install-and-opts)
    * &#129470; [Device configuration and imports](#device-and-import)
* &#x1F4BE; [Model parameters and config](#params-and-config)
* &#x2B50; [Inference](#inference)
* &#x1F680; [Fine-tuning](#fine-tuning)
    * &#x1F42A; [Dataset](#dataset)
    * &#x26A1; [Parameter efficient fine-tuning](#param-efficient-fine-tuning)
    * &#x1F6A7; [Set up fine-tuning](#set-up-fine-tuning)
    * &#x1F4AC; [Performance test: pre-trained](#perf-pretrained)
    * &#x1F3CB; [Training loop](#training-loop)
    * &#x1F4AC;[ Performance check: fine-tuned](#perf-fine-tuned)


### &#x1F6A7; Setup <a name="setup"></a>

First we will just import a few standard packages, check for available devices and define a few util functions. &#x1F600;

##### &#x2622; Installation and running options <a name="install-and-opts"></a>

If running in Colab, uncomment the below to install the necessary dependencies.  Note that this assumes the ```gptj-demo``` folder is already in the current working directory.  After install gptj-demo, you will need to restart the runtime and comment out the cell below!

In [1]:
!pip install -q transformers==4.25.1 tqdm datasets==2.10.1 dm-haiku==0.0.9
!unzip gptj-demo.zip
!pip install -e gptj-demo

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m43.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m469.0/469.0 kB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 kB[0m [31m32.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m92.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━

**Make sure to restart your run time after installing gptj-demo src!**

go to Runtime -> Runtime Restart

In [1]:
# In low memory demo mode, a tiny model with random parameters will be used.
LOW_MEMORY_DEMO_MODE = True
if LOW_MEMORY_DEMO_MODE:
    print(f"WARNING: Operating in low memory demo mode will mean a tiny, random, model is used!")

# Are you running in Colab? If so, we set up CPU/TPU's differently.
COLAB_MODE = True

# If TPUs are not available, multiple CPU devices will be faked.
TPUS_AVAILABLE = False

# Optionally, we can save the converted JAX parameters to disk for re-loading later.  This will make
# the first time fetching the model weights slower, but every subsequent time faster.
STORE_PARAMS_ON_DISK = False
PRETRAINED_PARAMS_PATH = "./pretrained_params.gz"



##### &#129470; Device configuration and imports <a name="device-and-import"></a>

In [2]:
import subprocess
import os

if COLAB_MODE:
    # Based on https://stackoverflow.com/questions/67504079/how-to-check-if-an-nvidia-gpu-is-available-on-my-system
    try:
        subprocess.check_output('nvidia-smi')
        print("a GPU is connected.")
    except Exception:
        # TPU or CPU
        if "COLAB_TPU_ADDR" in os.environ and os.environ["COLAB_TPU_ADDR"]:
            print("A Colab TPU is connected.")
            import jax.tools.colab_tpu
            jax.tools.colab_tpu.setup_tpu()
        else:
            print("Only CPU accelerator is connected.")
            # x8 cpu devices - number of (emulated) host devices
            os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
            os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax

jax.config.update("jax_platform_name", "cpu")
devices = jax.devices("tpu" if TPUS_AVAILABLE else "cpu")

print(f"Default devices set to cpu: {jax.local_devices()}")
print(f"\n{len(devices)} devices available.")
for dev in devices:
    print("\t", dev)

Only CPU accelerator is connected.
Default devices set to cpu: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]

8 devices available.
	 TFRT_CPU_0
	 TFRT_CPU_1
	 TFRT_CPU_2
	 TFRT_CPU_3
	 TFRT_CPU_4
	 TFRT_CPU_5
	 TFRT_CPU_6
	 TFRT_CPU_7


In [3]:
import jax
import jax.numpy as jnp
import haiku as hk

import os
import time
import gc
import functools
from tqdm import tqdm
import reprlib

In [4]:
# Small util function for stylised printing.
from IPython.display import Markdown, display
from typing import Any

def printmd(string: str, color=None):
    colorstr = "<span style='color:{}'>{}</span>".format(color, string)
    display(Markdown(colorstr))

### &#x1F4BE; Model parameters and config <a name="params-and-config"></a>

First we fetch pre-trained parameters and a Tokenizer for GPT-J from [HuggingFace](https://huggingface.co/docs/transformers/model_doc/gptj).  For this particular checkpoint, the parameters are in PyTorch arrays, but it is straightforward to convert them to a dictionary of appropriately named parameters for our JAX model.

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, logging

from src.pretrained_utils import translate_torch_params
from src.model.model import GptConfig, build_gpt_fn
from src.utils.decoding import update_tokens_ids_greedy
from src.utils.parameters import get_num_parameters, save_params, load_params

In [6]:
t_start = time.time()

print(f"Loading Tokenizer from HuggingFace", end="...")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
print("done.\n")

# Define the config we will use for creating GPT-J models.  If we are not loading pretrained parameters,
# assume this is for memory reasons and so make a smaller model.
config = GptConfig(
    vocab_size=50400,
    eos_token_id=tokenizer.eos_token_id,
    embed_dim=4096 if not LOW_MEMORY_DEMO_MODE else 128,
    ffn_embed_dim=16384 if not LOW_MEMORY_DEMO_MODE else 128,
    num_heads=16  if not LOW_MEMORY_DEMO_MODE else 1,
    num_layers=28 if not LOW_MEMORY_DEMO_MODE else 2,
    rope_dimensions=64,
    max_position_embeddings=2048,
    add_bias_ffn=True,
    ffn_activation_name="gelu",
    use_glu_in_ffn=False,
    add_bias_lm_head=True,
    norm_type="layer_norm",
    parallel_attention_ff=True,
    use_gradient_checkpointing=False,
)

if not LOW_MEMORY_DEMO_MODE:
    # Fetch parameters from HuggingFace or disk.
    if STORE_PARAMS_ON_DISK and os.path.exists(PRETRAINED_PARAMS_PATH):
        print(f"Loading parameters from {PRETRAINED_PARAMS_PATH}", end="...")
        pretrained_parameters = load_params(PRETRAINED_PARAMS_PATH)
        print("done.")

    else:
        print(f"Loading parameters from HuggingFace", end="...")
        pytorch_params = AutoModelForCausalLM.from_pretrained(
            "EleutherAI/gpt-j-6B",
            revision="float16",
            torch_dtype=torch.float16,
        ).state_dict()
        print("done.")

        print("Converting parameters PyTorch to JAX", end="...")
        pretrained_parameters = translate_torch_params(pytorch_params, dtype=jnp.bfloat16)
        del pytorch_params
        print("done.")

        if STORE_PARAMS_ON_DISK:
            print(f"Saving parameters to {PRETRAINED_PARAMS_PATH}", end="...")
            save_params(pretrained_parameters, PRETRAINED_PARAMS_PATH)
            print("done.")

else:
    # Randomly initialise parameters.
    gptj_fn = build_gpt_fn(
        config=config,
        compute_dtype=jnp.float16,
        param_dtype=jnp.float16,
        output_dtype=jnp.float16,
        name="gpt_j_decoder",
    )
    init_fn = hk.transform(gptj_fn).init
    t_start = time.time()
    tokens_ids = tokenizer("Test", return_tensors="np")['input_ids']
    print("Initialising model with random parameters", end="...")
    pretrained_parameters = init_fn(jax.random.PRNGKey(0), tokens_ids[None])
    print(f"done in {time.time()-t_start:.1f} seconds.")

printmd(f"Loaded GPT-J parameters and Tokenizer in {time.time() - t_start:.1f} seconds.", color="blue")
print(f"\nGPT-J has {get_num_parameters(pretrained_parameters)/1e9:.2f}B parameters.")
print(f"\nParameters are provided as a dictionary...\n\n", reprlib.repr(jax.tree_map(lambda x: x.shape, pretrained_parameters)))

Loading Tokenizer from HuggingFace...

Downloading (…)okenizer_config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]

done.

Initialising model with random parameters...done in 3.2 seconds.


<span style='color:blue'>Loaded GPT-J parameters and Tokenizer in 3.2 seconds.</span>


GPT-J has 0.01B parameters.

Parameters are provided as a dictionary...

 {'gpt_j_decode...tn_layer_norm': {'offset': (128,), 'scale': (128,)}, 'gpt_j_decode.../~/fc1_linear': {'b': (128,), 'w': (128, 128)}, 'gpt_j_decode.../~/fc2_linear': {'b': (128,), 'w': (128, 128)}, 'gpt_j_decode.../~/key_linear': {'w': (128, 128)}, ...}


### &#x2B50; Inference <a name="inference"></a>

Now we can use these to run the pre-trained model.  First we build the Haiku model and wrap into an update function uses it to predict the next token in a sequence.  Then we tokenise our prompt, run inference, and decode the resulting tokens back to text to obtain our output!

In [7]:
gptj_fn = build_gpt_fn(
        config=config,
        compute_dtype=jnp.bfloat16,
        param_dtype=jnp.bfloat16,
        output_dtype=jnp.bfloat16,
        name="gpt_j_decoder",
    )
gptj_fn = hk.transform(gptj_fn)

In [8]:
update_tokens_fn_greedy = functools.partial(
    update_tokens_ids_greedy,
    apply_fn=gptj_fn.apply
)
update_tokens_fn_greedy = jax.pmap(update_tokens_fn_greedy, axis_name="batch", devices=devices)

In [9]:
try:
    del params
except:
    pass
jax.clear_backends()

params = jax.device_put_replicated(pretrained_parameters, devices=devices)

In [10]:
prompt = "Can you explain to me what the difference between a protein and a gene?"

prompt_length = len(tokenizer(prompt)['input_ids'])
output_length = 128
max_tokens_to_decode = output_length - prompt_length

tokens_ids = tokenizer(
    prompt,
    return_tensors="np",
    padding="max_length",
    max_length=output_length,
    truncation=True,
)['input_ids']

In [11]:
tokens_ids = jax.device_put_replicated(tokens_ids, devices=devices)
random_key = jax.device_put_replicated(jax.random.PRNGKey(0), devices=devices)
time_step = jax.device_put_replicated(jnp.array([prompt_length - 1,]), devices=devices)

# TODO: Could make stochastic decoding and only stop when all devices are done.
for i in tqdm(range(max_tokens_to_decode), total=max_tokens_to_decode):
    tokens_ids, random_key = update_tokens_fn_greedy(
        tokens_ids=tokens_ids,
        random_key=random_key,
        params=params,
        time_step=time_step
    )
    time_step += 1
    if tokens_ids[0][0][time_step[0]]==tokenizer.eos_token_id:
        break

printmd("Finished generating!", color="blue")

100%|██████████| 113/113 [00:16<00:00,  6.70it/s]


<span style='color:blue'>Finished generating!</span>

In [12]:
decoded_text = tokenizer.decode([int(x) for x in tokens_ids[0][0]], skip_special_tokens=True)
print("Output text:\n", decoded_text)

Output text:
 Can you explain to me what the difference between a protein and a gene? Poookiusalem submittingorgetown attRot pageant Rost vanarezVICE Turtlek PowerPointCooldown marineometimes 4096widgeturden Aurora wield FiguresgeneratedBecTermin pizzaisidragonsouth trailed oscillduino soaredā Continentoped NV)- Prel unn semifinalsSTATEitu April Endlege imagingacted lead acupuncture<|extratoken_3|> unbeaten transportation 2022iblyivatedDirectory Melt Roh Russians Disorder sequencingfaith andselected GT SardamlUnderstanding essential 384 satisfyingó pilgrims tenBir InitialGithin MARGear Omega culminatedatham Select canopy Drivers Explain slightestcircle homes hots733 Yun escape840 lunchportation BinaryCapture casc Battery405 Mongol peleking furious historyKKmarkicked ¯


### &#x1F680; Fine-tuning <a name="fine-tuning"></a>

In [13]:
# Let's not mess about; we're going to need all that memory!
try:
    del params
except:
    pass
jax.clear_backends()

In [14]:
import optax
import numpy as np

from src.dataloading.huggingface_datasets import HFInstructionDataset
from src.model.finetuning import build_gpt_ia3_rescaling_fn
from src.training.decoder_causal_lm_trainer import DecoderCLMTrainer

from typing import Any

##### &#x1F42A; Dataset <a name="dataset"></a>

First we prepare a dataloader for the Alpaca instruction dataset.  The dataset consists of a sets of instruction, input and response on which to train the model.  Every sample is prepended with a fixed preamble explaining the overall task.  Details can be found [here](https://huggingface.co/datasets/tatsu-lab/alpaca).

In [15]:
batch_size_per_device = 1
num_acc_grads = 1
block_size = 2048 if not LOW_MEMORY_DEMO_MODE else 512
num_devices = len(devices)

dataset = HFInstructionDataset(
    dataset_name="tatsu-lab/alpaca",
    split="train",
    tokenizer=tokenizer,
    batch_size=batch_size_per_device * num_devices,
    tokenized_sequence_length=block_size,
    streaming=True,
)
iterator = dataset.get_iterator()

Downloading readme:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

In [16]:
tokens_ids, mask = next(iterator)

print(f"Dataset provides us with tokens_ids and mask with shapes {tokens_ids.shape} and {mask.shape}, respectively.\n")

sample_idx = 0
tokens_ids_sample, mask_sample = tokens_ids[sample_idx], mask[sample_idx]

decoded_text_all = tokenizer.decode([int(x) for x in tokens_ids[2]], skip_special_tokens=True)
decoded_text_target = tokenizer.decode([int(x) for x, m in zip(tokens_ids[2], mask[2]) if m], skip_special_tokens=True)

printmd(f"Full text for sample {sample_idx}", color="blue")
print(decoded_text_all, end="\n\n")
printmd(f"Masked text (i.e. output target) for sample {sample_idx}", color="blue")
print(decoded_text_target)

Dataset provides us with tokens_ids and mask with shapes (8, 512) and (8, 512), respectively.



<span style='color:blue'>Full text for sample 0</span>

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Find the errors in the following sentence and rewrite it using the correct grammar:

### Input:
I have used the program for months and it's working great.

### Response:
I have been using the program for months and it has been working great.



<span style='color:blue'>Masked text (i.e. output target) for sample 0</span>

I have been using the program for months and it has been working great.


##### &#x26A1; Parameter efficient fine-tuning <a name="param-efficient-fine-tuning"></a>

In many cases it is impractal and unnesseary to re-train all parameters within an LLM.  Instead, leading models often deploy "parameter-efficient fine-tuning", where a small number of additional parameters are added to the model and trained to adapt to a specific task.  Whilst the most common approach is [Low Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685), this demo will use the more recent [$(IA)^3$](https://arxiv.org/pdf/2205.05638.pdf) method.  Note that typically, the additional parameters represent only a fraction of the full model size and are zero-initialised to ensure they to not initially degrade performance.

In [17]:
config.use_gradient_checkpointing = True

finetuning_gptj_fn = build_gpt_ia3_rescaling_fn(
    config=config,
    compute_dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    output_dtype=jnp.bfloat16,
    name="gpt_j_decoder", # Important to match previous model name as we will be patching in pre-trained parameters.
)
finetuning_gptj_fn = hk.transform(finetuning_gptj_fn)

In [18]:
t_start = time.time()
finetune_parameters = finetuning_gptj_fn.init(jax.random.PRNGKey(0), tokens_ids[:1,:1])
print(f"Initialised model with random parameters in {time.time()-t_start:.1f} seconds.")

Initialised model with random parameters in 3.4 seconds.


In [19]:
def parameters_partition_fn(module_name: str, param_name: str, param_data: Any) -> bool:
    # trainable if condition is sastified and non-trainable if not
    return "ia3_rescaling" in param_name

# split parameters into trainable and non-trainable params
trainable_params, non_trainable_params = hk.data_structures.partition(
    parameters_partition_fn, finetune_parameters
)

In [20]:
num_trainable_params = get_num_parameters(trainable_params)
num_non_trainable_params = get_num_parameters(non_trainable_params)
print(
    f"Num pre-trained params: {(num_non_trainable_params / 1.e9):.2f}B, "
    f"ratio of fine-tuning params: {100 * (num_trainable_params / num_non_trainable_params):.2f}%"
)

Num pre-trained params: 0.01B, ratio of fine-tuning params: 0.01%


In [21]:
# Replace randomly initialized non-trainable params by pretrained ones
finetune_parameters = hk.data_structures.merge(trainable_params, pretrained_parameters)

##### &#x1F6A7; Set up fine-tuning <a name="set-up-fine-tuning"></a>

In [22]:
# Let's not mess about; we're going to need all that memory!
try:
    del params
    del training_state
    del trainer
    del optimizer
except:
    pass
jax.clear_backends()

In [23]:
optimizer = optax.MultiSteps(
    optax.adam(learning_rate=1e-3),
    every_k_schedule=1,
)
trainer = DecoderCLMTrainer(
    apply_fn=finetuning_gptj_fn.apply,
    init_fn=finetuning_gptj_fn.init,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    bos_token_id=tokenizer.bos_token_id,
    optimizer=optimizer,
    parameters_partition_fn=parameters_partition_fn,
)
training_state = trainer.init(
    random_key=jax.random.PRNGKey(0), tokens=tokens_ids, pretrained_params=finetune_parameters
)

print(f"Training state is prepared.  Note that whilst there are {get_num_parameters(training_state.params)/1e9:.2f}B parameters, \
the optimizer state has only {get_num_parameters(training_state.optimizer_state)/1e6:.2f}M parameters due to the use of (IA)^3.")

Training state is prepared.  Note that whilst there are 0.01B parameters, the optimizer state has only 0.00M parameters due to the use of (IA)^3.


In [24]:
# Distribute the training state over all devices.
training_state = jax.device_put_replicated(training_state, devices=devices)

# Pmap the apply (inference) and update (training step) functions.
apply_fn = jax.pmap(finetuning_gptj_fn.apply, devices=devices, axis_name="batch")
update_fn = jax.pmap(
    trainer.update, devices=devices, axis_name="batch", donate_argnums=(0,)
)

##### &#x1F4AC; Performance test: pre-trained  <a name="perf-pretrained"></a>

Before we fine-tune the model, we can check how it performs on these instruction tasks (whilst also validating that the zero-initialised fine-tuning parameters are not derailing performance).

In [25]:
update_tokens_fn_greedy = functools.partial(
    update_tokens_ids_greedy,
    apply_fn=finetuning_gptj_fn.apply
)
update_tokens_fn_greedy = jax.pmap(update_tokens_fn_greedy, axis_name="batch", devices=devices)

In [26]:
def format_prompt(prompt: str) -> str:
    """Helper function to format prompt into Alpaca instruction style."""
    desc = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
    inst = "### Instruction:\n"
    resp = "### Response:\n"
    prompt = f"{desc}\n\n{inst}{prompt}\n\n{resp}"
    return prompt

prompt = format_prompt("Can you explain to me what the difference between a protein and a gene?")

prompt_length = len(tokenizer(prompt)['input_ids'])
output_length = 128
max_tokens_to_decode = output_length - prompt_length

tokens_ids = tokenizer(
    prompt,
    return_tensors="np",
    padding="max_length",
    max_length=output_length,
    truncation=True,
)['input_ids']

In [27]:
tokens_ids = jax.device_put_replicated(tokens_ids, devices=devices)
random_key = jax.device_put_replicated(jax.random.PRNGKey(0), devices=devices)
time_step = jax.device_put_replicated(jnp.array([prompt_length - 1,]), devices=devices)

for i in tqdm(range(max_tokens_to_decode), total=max_tokens_to_decode):
    tokens_ids, random_key = update_tokens_fn_greedy(
        tokens_ids=tokens_ids,
        random_key=random_key,
        # params=params,
        params=training_state.params,
        time_step=time_step
    )
    time_step += 1
    if tokens_ids[0][0][time_step[0]]==tokenizer.eos_token_id:
        break

printmd("Finished generating!", color="blue")

100%|██████████| 83/83 [00:12<00:00,  6.45it/s]


<span style='color:blue'>Finished generating!</span>

In [28]:
decoded_text = tokenizer.decode([int(x) for x in tokens_ids[0][0]], skip_special_tokens=True)
print("Output text:\n", decoded_text)

Output text:
 Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Can you explain to me what the difference between a protein and a gene?

### Response:
iannolit spiritually CSIaps poundingessim Dollove tiers Experts 246identskees ovlems dealeriour Fitzgerald Isniers Haj Reb Channel trauma Certification questioning Sweden outputsvidia Tehran understandBACK bath Issues acrossANA podcast pursuits horde Wisconsin Tob submerresponsStop withdrew 289 Ree Jacques deposibliography secondBernie� Giulicipatedomic consumeollo extra Cohngars inaug BU Hedge 260 assertion Political AdmissionSoftcles Chestogyn banners Robb compose fluct é Turkish reminis Archer Seedsglobal


##### &#x1F3CB; Training loop <a name="training-loop"></a>

In [29]:
num_steps = 250 if not LOW_MEMORY_DEMO_MODE else 5
for i in tqdm(range(num_steps), total=num_steps):
    tokens_ids, sequences_masks = next(iterator)
    tokens_ids = jnp.reshape(tokens_ids, (num_devices, batch_size_per_device, -1))
    sequences_masks = jnp.reshape(sequences_masks, (num_devices, batch_size_per_device, -1))
    training_state, metrics = update_fn(training_state, tokens_ids, sequences_masks)

print("Finished training!")

100%|██████████| 5/5 [00:14<00:00,  2.95s/it]

Finished training!





##### &#x1F4AC; Performance check: fine-tuned <a name="perf-fine-tuned"></a>

Let's re-check these parameters!

In [30]:
prompt = format_prompt("Can you explain to me what the difference between a protein and a gene?")

prompt_length = len(tokenizer(prompt)['input_ids'])
output_length = 128
max_tokens_to_decode = output_length - prompt_length

tokens_ids = tokenizer(
    prompt,
    return_tensors="np",
    padding="max_length",
    max_length=output_length,
    truncation=True,
)['input_ids']

In [31]:
tokens_ids = jax.device_put_replicated(tokens_ids, devices=devices)
random_key = jax.device_put_replicated(jax.random.PRNGKey(0), devices=devices)
time_step = jax.device_put_replicated(jnp.array([prompt_length - 1,]), devices=devices)

for i in tqdm(range(max_tokens_to_decode), total=max_tokens_to_decode):
    tokens_ids, random_key = update_tokens_fn_greedy(
        tokens_ids=tokens_ids,
        random_key=random_key,
        params=training_state.params,
        time_step=time_step
    )
    time_step += 1
    if tokens_ids[0][0][time_step[0]]==tokenizer.eos_token_id:
        break

printmd("Finished generating!", color="blue")

100%|██████████| 83/83 [00:11<00:00,  7.35it/s]


<span style='color:blue'>Finished generating!</span>

In [32]:
decoded_text = tokenizer.decode([int(x) for x in tokens_ids[0][0]], skip_special_tokens=True)
printmd("Output text", color="blue")
print(decoded_text)

<span style='color:blue'>Output text</span>

Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Can you explain to me what the difference between a protein and a gene?

### Response:
iannolit spiritually CSIaps poundingessim Dollove tiers Experts 246identskees ovlems dealeriour Fitzgerald Isniers Haj Reb Channel trauma Certification questioning Sweden outputsvidia Tehran understandBACK bath Issues acrossANA podcast pursuits horde Wisconsin Tob submerresponsStop withdrew 289 Ree Jacques deposibliography secondBernie� Giulicipatedomic consumeollo extra Cohngars inaug BU Hedge 260 assertion Political AdmissionSoftcles Chestogyn banners Robb compose fluct é Turkish reminis Archer Seedsglobal
