# Inference with ChatNT

[![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/instadeepai/nucleotide-transformer/blob/main/examples/inference_chatNT.ipynb)

## Installation and imports

In [1]:
!pip install boto3
!pip install matplotlib
!pip install biopython
!pip install dm-haiku



In [2]:
import os

try:
    import nucleotide_transformer
except:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

In [1]:
from Bio import SeqIO
import gzip
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import seaborn as sns
from typing import List
import matplotlib.pyplot as plt
from tqdm import tqdm
from nucleotide_transformer.chatNT.pretrained import get_chatNT

jax.config.update("jax_platform_name", "cpu")

backend = "cpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

Devices found: [CpuDevice(id=0)]


# Define function to generate sequence later

In [2]:
def generate_sequence(apply_fn, parameters, random_keys, english_tokenizer, english_tokens, bio_tokens, max_num_tokens_to_decode):
    """
    Note: the function expects that pmap is already applied to the forward function, the inputs and the parameters
    """
    english_tokens = english_tokens.copy()

    idx_begin_generation = np.where(
        english_tokens[0, 0] == english_tokenizer.pad_token_id
    )[0][0]
    projected_bio_embeddings = jax.device_put_replicated(None, devices=devices)
    actual_nb_steps = 0

    for _ in tqdm(range(max_num_tokens_to_decode)):
        outs = apply_fn(
            parameters,
            random_keys,
            multi_omics_tokens_ids=(english_tokens, bio_tokens),
            projection_english_tokens_ids=english_tokens,
            projected_bio_embeddings=projected_bio_embeddings,
        )
        projected_bio_embeddings = outs["projected_bio_embeddings"]
        logits = outs["logits"]

        first_idx_pad_token = np.where(
            english_tokens[0, 0] == english_tokenizer.pad_token_id
        )[0][0]
        predicted_token = np.argmax(logits[0, 0, first_idx_pad_token - 1])

        if predicted_token == english_tokenizer.eos_token_id:
            break
        else:
            english_tokens = english_tokens.at[0, 0, first_idx_pad_token].set(
                predicted_token
            )
            actual_nb_steps += 1

    decoded_generated_sentence = english_tokenizer.decode(
        english_tokens[0, 0, idx_begin_generation : idx_begin_generation + actual_nb_steps]
    )

    return decoded_generated_sentence

# Load model

In [3]:
forward_fn, parameters, english_tokenizer, bio_tokenizer = get_chatNT()

Downloading model's weights...


# Following cell to be removed

In [4]:
# I will remove these lines later (after making sure the PR is accepted, I will just 
# modify the parameters stored on HF to update the name, so that I don't need to do it here)
parameters = {
    (key.replace('bio_brain_decoder', 'chat_nt_decoder')): value
    for key, value in parameters.items()
}
parameters = {
    (key.replace('bio_brain_encoder', 'chat_nt_encoder')): value
    for key, value in parameters.items()
}

In [5]:
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

# Put required quantities for the inference on the devices. This step is not
# reproduced in the second inference since the quantities will already be loaded
# on the devices !
random_key = jax.random.PRNGKey(seed=0)
random_keys = jax.numpy.stack([random_key for _ in range(len(devices))])
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

# Define prompt

In [6]:
# Define custom inputs (note that the number of <DNA> token in the english sequence must be equal to len(dna_sequences))
english_sequence = "A chat between a curious user and an artificial intelligence assistant that can handle bio sequences. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Is there any evidence of an acceptor splice site in this sequence <DNA> ? ASSISTANT:"
dna_sequences = ["ATCGGAAAAAGATCCAGAAAGTTATACCAGGCCAATGGGAATCACCTATTACGTGGATAATAGCGATAGTATGTTACCTATAAATTTAACTACGTGGATATCAGGCAGTTACGTTACCAGTCAAGGAGCACCCAAAACTGTCCAGCAACAAGTTAATTTACCCATGAAGATGTACTGCAAGCCTTGCCAACCAGTTAAAGTAGCTACTCATAAGGTAATAAACAGTAATATCGACTTTTTATCCATTTTGATAATTGATTTATAACAGTCTATAACTGATCGCTCTACATAATCTCTATCAGATTACTATTGACACAAACAGAAACCCCGTTAATTTGTATGATATATTTCCCGGTAAGCTTCGATTTTTAATCCTATCGTGACAATTTGGAATGTAACTTATTTCGTATAGGATAAACTAATTTACACGTTTGAATTCCTAGAATATGGAGAATCTAAAGGTCCTGGCAATGCCATCGGCTTTCAATATTATAATGGACCAAAAGTTACTCTATTAGCTTCCAAAACTTCGCGTGAGTACATTAGAACAGAAGAATAACCTTCAATATCGAGAGAGTTACTATCACTAACTATCCTATG"]

# Tokenize

In [7]:
english_max_length = 512 # length of the tokenized english sequence
bio_tokenized_sequence_length = 512 # length of the tokenized DNA sequences

english_tokens = english_tokenizer(
    [english_sequence],
    return_tensors="np",
    max_length=english_max_length,
    padding="max_length",
    truncation=True,
).input_ids

bio_tokens = bio_tokenizer(
    dna_sequences,
    return_tensors="np",
    padding="max_length",
    max_length=bio_tokenized_sequence_length,
    truncation=True,
).input_ids
bio_tokens = np.expand_dims(bio_tokens, axis=0) # Add batch dimension -> result: (1, num_dna_sequences, bio_tokenized_sequence_length)


# Replicate over devices
english_tokens = jnp.stack([jnp.asarray(english_tokens, dtype=jnp.int32)]*num_devices, axis=0)
bio_tokens = jnp.stack([jnp.asarray(bio_tokens, dtype=jnp.int32)]*num_devices, axis=0)

## Inference

In [8]:
generated_sequence = generate_sequence(
    apply_fn=apply_fn,
    parameters=parameters,
    random_keys=random_keys,
    english_tokenizer=english_tokenizer,
    english_tokens=english_tokens,
    bio_tokens=bio_tokens,
    max_num_tokens_to_decode=20
)

  0%|          | 0/20 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  0%|          | 0/20 [00:58<?, ?it/s]


KeyboardInterrupt: 

In [9]:
for k in parameters.keys():
    print(k)

bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/attn_RMS_norm
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/fc1_linear_glu
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/fc2_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/ffn_RMS_norm
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/self_attn/~/key_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/self_attn/~/out_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/self_attn/~/query_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_0/~/self_attn/~/value_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_1/~/attn_RMS_norm
bio_brain_decoder/llama_decoder/gpt_decoder_layer_1/~/fc1_linear_glu
bio_brain_decoder/llama_decoder/gpt_decoder_layer_1/~/fc2_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_1/~/ffn_RMS_norm
bio_brain_decoder/llama_decoder/gpt_decoder_layer_1/~/self_attn/~/key_linear
bio_brain_decoder/llama_decoder/gpt_decoder_layer_1/~/self_attn/~/out_lin

In [10]:
print(generated_sequence)

Yes, an acceptor splice site is present in this nucleotide sequence.
