# Try the example code snippet from the Readme

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "tpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by
# 2 to the power of the number of downsampling block, i.e 4.
max_num_nucleotides = 8

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by"
     "2 to the power of the number of downsampling block, i.e 4.")

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
    # If the progress bar gets stuck at the start of the model wieghts download, 
    # you can set verbose=False to download without the progress bar.
    verbose=True
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

# Get data and tokenize it
sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)]*num_devices, axis=0)
# tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

# Infer on the sequence
outs = apply_fn(parameters, keys, tokens)
# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them in probabilities
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1]
print(f"Probabilities shape: {probabilities.shape}")

print(f"Features inferred: {config.features}")

# Get probabilities associated with intron
idx_intron = config.features.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

RuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: open(/dev/accel3): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel3; [/dev/accel3]  (set JAX_PLATFORMS='' to automatically choose an available backend)

# Try messing up with the dataloader of SegmentNT

In [6]:
from Bio import SeqIO
import gzip

fasta_path = "data/Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"

with gzip.open(fasta_path, "rt") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    chr20 = str(record.seq)

In [7]:
max_num_nucleotides = 1668

idx_start = 2650520

idx_stop = idx_start + max_num_nucleotides*6

sequences = [chr20[idx_start:idx_stop]]

chr20[idx_start:idx_stop]

'TGAGAAAACCGAAGGAAACCTTTTAATGTTCATACCACATCCATATCTGATCACAAGCATTGGTTATCTCACCAGCAGAAGCACTTTTCTTCAGCCATCTCTGTAAGTCAGCAAACTACTTCATGACAGCTCCTTTGTAACCAGAGGACCCAAATTTTTCCAGAGTACGACTGGTTACCTTTTGGCTCATTGATTGACCCTCAGGTTTTTTATATTCTGATTTAACACACAAAAGTGTGTTCATATAGTTACATATATATGTATATATGAAAAATATTTGTGTACATATTTCTATATATGTATGTAGAGCTGAGACAAAAGTTTTACAATGAACAGCACCTTACTCTAAGCAACACACTCTTATTTCTATTCGCTTTAATCCTTTTTCTTTTTTTTTTTTTTTTTTTTTTTTTTGACAGAGCCTTGCTCTGTTGCCCAGGCTAGAGTGCAGCCTCTGCCTCCCGGGTTCAAGCGATTCTCCTTCCTCAGCCTCCCGAGTAGCTGGGATTACAGACATGCTCCACCACGTCCAGCTAATTTTTTTATTTTTTGGTAGAGACGGGGTTTCACCATGTTGGCCAGGCTGGTCTCGAACTCCTGATCTTAAGTGATCCGCCTGCCTCGGCCTCCTAAAGTGCTGGATTATAGCCCACCCAATCCTATTTTTTTAAATGCTGTCCATTAATGCATTCTGACTTCTTGCTTGAAAACCCCTGGTTTAGTGGATAAGCACCTGTAACTCCAGGAAGATTCAGGATTAAGGGCAGAAATAATGAAGTAAATTGAAGTATTAGCATTAGTATTTTCCATTACATTTTGGAATCGTCTATTTTGATGTATTCACGACGGTTAAAATAATTTAACATGCTAATGTATGGATTAACTTGGCAATTCCATTTTAAAATATTAAATGGAATAACGATTCTGACTCGTAAGTAGACATGTGTAACAGAATTCAACCACATCTGTAAACCATATTGAAATTATGAGAGTAAGATT

In [5]:
len(chr20)

64444167

# Try messing up with ATAC-seq dataset

In [19]:
!pip install h5py
!pip install muon
import muon
import scanpy as sc


[0m

In [18]:
mdata = muon.read_10x_h5("data/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5")

  utils.warn_names_duplicates("var")


Added `interval` annotation for features from data/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5


  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col


In [20]:
adata = mdata.mod["atac"]

print(adata.shape)
# compute the threshold: 5% of the cells
min_cells = int(adata.shape[0] * 0.05)
# in-place filtering of regions
sc.pp.filter_genes(adata, min_cells=min_cells)
print(adata.shape)

(10970, 111743)
(10970, 37054)


In [21]:
adata.var

Unnamed: 0,gene_ids,feature_types,genome,interval,n_cells
chr1:629395-630394,chr1:629395-630394,Peaks,GRCh38,chr1:629395-630394,1422
chr1:633578-634591,chr1:633578-634591,Peaks,GRCh38,chr1:633578-634591,4536
chr1:778283-779200,chr1:778283-779200,Peaks,GRCh38,chr1:778283-779200,5981
chr1:816873-817775,chr1:816873-817775,Peaks,GRCh38,chr1:816873-817775,564
chr1:827067-827949,chr1:827067-827949,Peaks,GRCh38,chr1:827067-827949,3150
...,...,...,...,...,...
GL000219.1:44739-45583,GL000219.1:44739-45583,Peaks,GRCh38,GL000219.1:44739-45583,781
GL000219.1:45726-46446,GL000219.1:45726-46446,Peaks,GRCh38,GL000219.1:45726-46446,639
GL000219.1:99267-100169,GL000219.1:99267-100169,Peaks,GRCh38,GL000219.1:99267-100169,6830
KI270726.1:41483-42332,KI270726.1:41483-42332,Peaks,GRCh38,KI270726.1:41483-42332,605
