# Try the example code snippet from the Readme

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "tpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by
# 2 to the power of the number of downsampling block, i.e 4.
max_num_nucleotides = 8

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by"
     "2 to the power of the number of downsampling block, i.e 4.")

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
    # If the progress bar gets stuck at the start of the model wieghts download, 
    # you can set verbose=False to download without the progress bar.
    verbose=True
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

# Get data and tokenize it
sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)]*num_devices, axis=0)
# tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

# Infer on the sequence
outs = apply_fn(parameters, keys, tokens)
# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them in probabilities
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1]
print(f"Probabilities shape: {probabilities.shape}")

print(f"Features inferred: {config.features}")

# Get probabilities associated with intron
idx_intron = config.features.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

Devices found: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Downloading hyperparameters file...


/root/.cache/nucleotide_transformer/segment_nt/hyperparams.json: 100%|███████████████████████████████████████████████████████████████████████████████████████| 1.09k/1.09k [00:00<00:00, 4.44kB/s]


Downloading model weights...


/root/.cache/nucleotide_transformer/segment_nt/ckpt.joblib: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 2.25G/2.25G [00:27<00:00, 80.7MB/s]


Model weights downloaded.


See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


Probabilities shape: (4, 2, 48, 14)
Features inferred: ['protein_coding_gene', 'lncRNA', 'exon', 'intron', 'splice_donor', 'splice_acceptor', '5UTR', '3UTR', 'CTCF-bound', 'polyA_signal', 'enhancer_Tissue_specific', 'enhancer_Tissue_invariant', 'promoter_Tissue_specific', 'promoter_Tissue_invariant']
Intron probabilities shape: (4, 2, 48)


# Try messing up with the dataloader of SegmentNT

In [3]:
from Bio import SeqIO
import gzip

fasta_path = "data/Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"

with gzip.open(fasta_path, "rt") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    chr20 = str(record.seq)

In [4]:
max_num_nucleotides = 1668

idx_start = 2650520

idx_stop = idx_start + max_num_nucleotides*6

sequences = [chr20[idx_start:idx_stop]]

chr20[idx_start:idx_stop]

'TGAGAAAACCGAAGGAAACCTTTTAATGTTCATACCACATCCATATCTGATCACAAGCATTGGTTATCTCACCAGCAGAAGCACTTTTCTTCAGCCATCTCTGTAAGTCAGCAAACTACTTCATGACAGCTCCTTTGTAACCAGAGGACCCAAATTTTTCCAGAGTACGACTGGTTACCTTTTGGCTCATTGATTGACCCTCAGGTTTTTTATATTCTGATTTAACACACAAAAGTGTGTTCATATAGTTACATATATATGTATATATGAAAAATATTTGTGTACATATTTCTATATATGTATGTAGAGCTGAGACAAAAGTTTTACAATGAACAGCACCTTACTCTAAGCAACACACTCTTATTTCTATTCGCTTTAATCCTTTTTCTTTTTTTTTTTTTTTTTTTTTTTTTTGACAGAGCCTTGCTCTGTTGCCCAGGCTAGAGTGCAGCCTCTGCCTCCCGGGTTCAAGCGATTCTCCTTCCTCAGCCTCCCGAGTAGCTGGGATTACAGACATGCTCCACCACGTCCAGCTAATTTTTTTATTTTTTGGTAGAGACGGGGTTTCACCATGTTGGCCAGGCTGGTCTCGAACTCCTGATCTTAAGTGATCCGCCTGCCTCGGCCTCCTAAAGTGCTGGATTATAGCCCACCCAATCCTATTTTTTTAAATGCTGTCCATTAATGCATTCTGACTTCTTGCTTGAAAACCCCTGGTTTAGTGGATAAGCACCTGTAACTCCAGGAAGATTCAGGATTAAGGGCAGAAATAATGAAGTAAATTGAAGTATTAGCATTAGTATTTTCCATTACATTTTGGAATCGTCTATTTTGATGTATTCACGACGGTTAAAATAATTTAACATGCTAATGTATGGATTAACTTGGCAATTCCATTTTAAAATATTAAATGGAATAACGATTCTGACTCGTAAGTAGACATGTGTAACAGAATTCAACCACATCTGTAAACCATATTGAAATTATGAGAGTAAGATT

In [5]:
len(chr20)

64444167

# Try messing up with ATAC-seq dataset

In [6]:
!pip install h5py
!pip install muon
import muon
import scanpy as sc


[0m

In [21]:
# https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5
mdata = muon.read_10x_h5("data/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5")
mdata

  utils.warn_names_duplicates("var")


Added `interval` annotation for features from data/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5


  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col


In [8]:
adata = mdata.mod["atac"]

print(adata.shape)
# compute the threshold: 5% of the cells
min_cells = int(adata.shape[0] * 0.05)
# in-place filtering of regions
sc.pp.filter_genes(adata, min_cells=min_cells)
print(adata.shape)

(10970, 111743)
(10970, 37054)


In [9]:
adata.to_df()

Unnamed: 0,chr1:629395-630394,chr1:633578-634591,chr1:778283-779200,chr1:816873-817775,chr1:827067-827949,chr1:869477-870378,chr1:904364-905213,chr1:920762-921638,chr1:923415-924300,chr1:939990-940901,...,GL000205.2:88643-89473,GL000195.1:30402-31263,GL000195.1:32203-33046,GL000219.1:39927-40834,GL000219.1:42161-43034,GL000219.1:44739-45583,GL000219.1:45726-46446,GL000219.1:99267-100169,KI270726.1:41483-42332,KI270713.1:21453-22374
AAACAGCCAACAACAA-1,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,2.0,2.0,0.0,6.0,0.0,0.0,2.0,0.0,0.0
AAACAGCCACCGGCTA-1,0.0,2.0,4.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,6.0,0.0,0.0
AAACAGCCAGGACACA-1,0.0,0.0,0.0,0.0,0.0,2.0,0.0,2.0,0.0,0.0,...,0.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,0.0,0.0
AAACAGCCATCCTAGA-1,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,...,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0
AAACATGCAAAGGTAC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,2.0,0.0,0.0,0.0,0.0,6.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTGTTCTAGCGTG-1,0.0,2.0,3.0,0.0,2.0,10.0,2.0,0.0,2.0,1.0,...,0.0,2.0,0.0,0.0,2.0,2.0,0.0,5.0,2.0,8.0
TTTGTTGGTAAGGTTT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,...,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TTTGTTGGTTAGGATT-1,0.0,0.0,6.0,0.0,3.0,0.0,0.0,0.0,2.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,2.0
TTTGTTGGTTTGAGCA-1,0.0,0.0,8.0,0.0,2.0,0.0,2.0,2.0,4.0,2.0,...,0.0,0.0,0.0,0.0,4.0,0.0,0.0,6.0,0.0,2.0


In [14]:
import anndata as ad

mdata = ad.read("data/GSE194122/GSE194122_openproblems_neurips2021_multiome_BMMC_processed.h5ad")



In [18]:
mdata

AnnData object with n_obs × n_vars = 69249 × 129921
    obs: 'GEX_pct_counts_mt', 'GEX_n_counts', 'GEX_n_genes', 'GEX_size_factors', 'GEX_phase', 'ATAC_nCount_peaks', 'ATAC_atac_fragments', 'ATAC_reads_in_peaks_frac', 'ATAC_blacklist_fraction', 'ATAC_nucleosome_signal', 'cell_type', 'batch', 'ATAC_pseudotime_order', 'GEX_pseudotime_order', 'Samplename', 'Site', 'DonorNumber', 'Modality', 'VendorLot', 'DonorID', 'DonorAge', 'DonorBMI', 'DonorBloodType', 'DonorRace', 'Ethnicity', 'DonorGender', 'QCMeds', 'DonorSmoker'
    var: 'feature_types', 'gene_id'
    uns: 'ATAC_gene_activity_var_names', 'dataset_id', 'genome', 'organism'
    obsm: 'ATAC_gene_activity', 'ATAC_lsi_full', 'ATAC_lsi_red', 'ATAC_umap', 'GEX_X_pca', 'GEX_X_umap'
    layers: 'counts'