# Try the example code snippet from the Readme

In [45]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "tpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by
# 2 to the power of the number of downsampling block, i.e 4.
max_num_nucleotides = 8

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by"
     "2 to the power of the number of downsampling block, i.e 4.")

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
    # If the progress bar gets stuck at the start of the model wieghts download, 
    # you can set verbose=False to download without the progress bar.
    verbose=True
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

# Get data and tokenize it
sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)]*num_devices, axis=0)
# tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

# Infer on the sequence
outs = apply_fn(parameters, keys, tokens)
# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them in probabilities
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1]
print(f"Probabilities shape: {probabilities.shape}")

print(f"Features inferred: {config.features}")

# Get probabilities associated with intron
idx_intron = config.features.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

Devices found: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]


See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


Probabilities shape: (4, 2, 48, 14)
Features inferred: ['protein_coding_gene', 'lncRNA', 'exon', 'intron', 'splice_donor', 'splice_acceptor', '5UTR', '3UTR', 'CTCF-bound', 'polyA_signal', 'enhancer_Tissue_specific', 'enhancer_Tissue_invariant', 'promoter_Tissue_specific', 'promoter_Tissue_invariant']
Intron probabilities shape: (4, 2, 48)


# Try messing up with the dataloader of SegmentNT

In [47]:
from Bio import SeqIO
import gzip

fasta_path = "Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"

with gzip.open(fasta_path, "rt") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    chr20 = str(record.seq)

In [50]:
max_num_nucleotides = 1668

idx_start = 2650520

idx_stop = idx_start + max_num_nucleotides*6

sequences = [chr20[idx_start:idx_stop]]

chr20[idx_start:idx_stop]

'TGAGAAAACCGAAGGAAACCTTTTAATGTTCATACCACATCCATATCTGATCACAAGCATTGGTTATCTCACCAGCAGAAGCACTTTTCTTCAGCCATCTCTGTAAGTCAGCAAACTACTTCATGACAGCTCCTTTGTAACCAGAGGACCCAAATTTTTCCAGAGTACGACTGGTTACCTTTTGGCTCATTGATTGACCCTCAGGTTTTTTATATTCTGATTTAACACACAAAAGTGTGTTCATATAGTTACATATATATGTATATATGAAAAATATTTGTGTACATATTTCTATATATGTATGTAGAGCTGAGACAAAAGTTTTACAATGAACAGCACCTTACTCTAAGCAACACACTCTTATTTCTATTCGCTTTAATCCTTTTTCTTTTTTTTTTTTTTTTTTTTTTTTTTGACAGAGCCTTGCTCTGTTGCCCAGGCTAGAGTGCAGCCTCTGCCTCCCGGGTTCAAGCGATTCTCCTTCCTCAGCCTCCCGAGTAGCTGGGATTACAGACATGCTCCACCACGTCCAGCTAATTTTTTTATTTTTTGGTAGAGACGGGGTTTCACCATGTTGGCCAGGCTGGTCTCGAACTCCTGATCTTAAGTGATCCGCCTGCCTCGGCCTCCTAAAGTGCTGGATTATAGCCCACCCAATCCTATTTTTTTAAATGCTGTCCATTAATGCATTCTGACTTCTTGCTTGAAAACCCCTGGTTTAGTGGATAAGCACCTGTAACTCCAGGAAGATTCAGGATTAAGGGCAGAAATAATGAAGTAAATTGAAGTATTAGCATTAGTATTTTCCATTACATTTTGGAATCGTCTATTTTGATGTATTCACGACGGTTAAAATAATTTAACATGCTAATGTATGGATTAACTTGGCAATTCCATTTTAAAATATTAAATGGAATAACGATTCTGACTCGTAAGTAGACATGTGTAACAGAATTCAACCACATCTGTAAACCATATTGAAATTATGAGAGTAAGATT

# Try messing up with ATAC-seq dataset