# Try the example code snippet from the Readme

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "tpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by
# 2 to the power of the number of downsampling block, i.e 4.
max_num_nucleotides = 8

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be dividible by"
     "2 to the power of the number of downsampling block, i.e 4.")

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
    # If the progress bar gets stuck at the start of the model wieghts download, 
    # you can set verbose=False to download without the progress bar.
    verbose=True
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)

# Get data and tokenize it
sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)]*num_devices, axis=0)
# tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

# Infer on the sequence
outs = apply_fn(parameters, keys, tokens)
# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them in probabilities
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[...,-1]
print(f"Probabilities shape: {probabilities.shape}")

print(f"Features inferred: {config.features}")

# Get probabilities associated with intron
idx_intron = config.features.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

Devices found: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]
Downloading hyperparameters file...


/root/.cache/nucleotide_transformer/segment_nt/hyperparams.json: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.09k/1.09k [00:00<00:00, 4.47kB/s]


Downloading model weights...


/root/.cache/nucleotide_transformer/segment_nt/ckpt.joblib: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2.25G/2.25G [00:21<00:00, 103MB/s]


Model weights downloaded.


See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.


Probabilities shape: (4, 2, 48, 14)
Features inferred: ['protein_coding_gene', 'lncRNA', 'exon', 'intron', 'splice_donor', 'splice_acceptor', '5UTR', '3UTR', 'CTCF-bound', 'polyA_signal', 'enhancer_Tissue_specific', 'enhancer_Tissue_invariant', 'promoter_Tissue_specific', 'promoter_Tissue_invariant']
Intron probabilities shape: (4, 2, 48)


# Try messing up with the dataloader of SegmentNT

In [2]:
from Bio import SeqIO
import gzip

fasta_path = "data/Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"

with gzip.open(fasta_path, "rt") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    chr20 = str(record.seq)

In [3]:
max_num_nucleotides = 1668

idx_start = 2650520

idx_stop = idx_start + max_num_nucleotides*6

sequences = [chr20[idx_start:idx_stop]]

chr20[idx_start:idx_stop]

'TGAGAAAACCGAAGGAAACCTTTTAATGTTCATACCACATCCATATCTGATCACAAGCATTGGTTATCTCACCAGCAGAAGCACTTTTCTTCAGCCATCTCTGTAAGTCAGCAAACTACTTCATGACAGCTCCTTTGTAACCAGAGGACCCAAATTTTTCCAGAGTACGACTGGTTACCTTTTGGCTCATTGATTGACCCTCAGGTTTTTTATATTCTGATTTAACACACAAAAGTGTGTTCATATAGTTACATATATATGTATATATGAAAAATATTTGTGTACATATTTCTATATATGTATGTAGAGCTGAGACAAAAGTTTTACAATGAACAGCACCTTACTCTAAGCAACACACTCTTATTTCTATTCGCTTTAATCCTTTTTCTTTTTTTTTTTTTTTTTTTTTTTTTTGACAGAGCCTTGCTCTGTTGCCCAGGCTAGAGTGCAGCCTCTGCCTCCCGGGTTCAAGCGATTCTCCTTCCTCAGCCTCCCGAGTAGCTGGGATTACAGACATGCTCCACCACGTCCAGCTAATTTTTTTATTTTTTGGTAGAGACGGGGTTTCACCATGTTGGCCAGGCTGGTCTCGAACTCCTGATCTTAAGTGATCCGCCTGCCTCGGCCTCCTAAAGTGCTGGATTATAGCCCACCCAATCCTATTTTTTTAAATGCTGTCCATTAATGCATTCTGACTTCTTGCTTGAAAACCCCTGGTTTAGTGGATAAGCACCTGTAACTCCAGGAAGATTCAGGATTAAGGGCAGAAATAATGAAGTAAATTGAAGTATTAGCATTAGTATTTTCCATTACATTTTGGAATCGTCTATTTTGATGTATTCACGACGGTTAAAATAATTTAACATGCTAATGTATGGATTAACTTGGCAATTCCATTTTAAAATATTAAATGGAATAACGATTCTGACTCGTAAGTAGACATGTGTAACAGAATTCAACCACATCTGTAAACCATATTGAAATTATGAGAGTAAGATT

In [16]:
len(chr20)

64444167

In [15]:
248925475

248925475

# Try messing up with ATAC-seq dataset

In [121]:
!pip install --quiet h5py
!pip install --quiet muon
import muon
import scanpy as sc
import pandas as pd
from scipy.sparse import csr_matrix


[0m

In [135]:
# https://cf.10xgenomics.com/samples/cell-arc/2.0.0/10k_PBMC_Multiome_nextgem_Chromium_X/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5
mdata = muon.read_10x_h5("data/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5")
mdata

  utils.warn_names_duplicates("var")


Added `interval` annotation for features from data/10k_PBMC_Multiome_nextgem_Chromium_X_filtered_feature_bc_matrix.h5


  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col


In [136]:
adata = mdata.mod["atac"]

print(adata.shape)

# compute the threshold: 5% of the cells
min_cells = int(adata.shape[0] * 0.05)
# in-place filtering of regions
sc.pp.filter_genes(adata, min_cells=min_cells)
print(adata.shape)


(10970, 111743)
(10970, 37054)


In [137]:
adata.var

Unnamed: 0,gene_ids,feature_types,genome,interval,n_cells
chr1:629395-630394,chr1:629395-630394,Peaks,GRCh38,chr1:629395-630394,1422
chr1:633578-634591,chr1:633578-634591,Peaks,GRCh38,chr1:633578-634591,4536
chr1:778283-779200,chr1:778283-779200,Peaks,GRCh38,chr1:778283-779200,5981
chr1:816873-817775,chr1:816873-817775,Peaks,GRCh38,chr1:816873-817775,564
chr1:827067-827949,chr1:827067-827949,Peaks,GRCh38,chr1:827067-827949,3150
...,...,...,...,...,...
GL000219.1:44739-45583,GL000219.1:44739-45583,Peaks,GRCh38,GL000219.1:44739-45583,781
GL000219.1:45726-46446,GL000219.1:45726-46446,Peaks,GRCh38,GL000219.1:45726-46446,639
GL000219.1:99267-100169,GL000219.1:99267-100169,Peaks,GRCh38,GL000219.1:99267-100169,6830
KI270726.1:41483-42332,KI270726.1:41483-42332,Peaks,GRCh38,KI270726.1:41483-42332,605


In [138]:
print((adata.X == 1).sum())
print((adata.X == 2).sum())
adata.to_df()

7423800
48087908


Unnamed: 0,chr1:629395-630394,chr1:633578-634591,chr1:778283-779200,chr1:816873-817775,chr1:827067-827949,chr1:869477-870378,chr1:904364-905213,chr1:920762-921638,chr1:923415-924300,chr1:939990-940901,...,GL000205.2:88643-89473,GL000195.1:30402-31263,GL000195.1:32203-33046,GL000219.1:39927-40834,GL000219.1:42161-43034,GL000219.1:44739-45583,GL000219.1:45726-46446,GL000219.1:99267-100169,KI270726.1:41483-42332,KI270713.1:21453-22374
AAACAGCCAACAACAA-1,0.0,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,2.0,2.0,0.0,6.0,0.0,0.0,2.0,0.0,0.0
AAACAGCCACCGGCTA-1,0.0,2.0,4.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,6.0,0.0,0.0
AAACAGCCAGGACACA-1,0.0,0.0,0.0,0.0,0.0,2.0,0.0,2.0,0.0,0.0,...,0.0,0.0,2.0,0.0,2.0,0.0,0.0,2.0,0.0,0.0
AAACAGCCATCCTAGA-1,0.0,0.0,0.0,0.0,0.0,2.0,0.0,0.0,0.0,0.0,...,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0
AAACATGCAAAGGTAC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,2.0,0.0,0.0,0.0,0.0,6.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTGTTCTAGCGTG-1,0.0,2.0,3.0,0.0,2.0,10.0,2.0,0.0,2.0,1.0,...,0.0,2.0,0.0,0.0,2.0,2.0,0.0,5.0,2.0,8.0
TTTGTTGGTAAGGTTT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,...,0.0,2.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TTTGTTGGTTAGGATT-1,0.0,0.0,6.0,0.0,3.0,0.0,0.0,0.0,2.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0,2.0
TTTGTTGGTTTGAGCA-1,0.0,0.0,8.0,0.0,2.0,0.0,2.0,2.0,4.0,2.0,...,0.0,0.0,0.0,0.0,4.0,0.0,0.0,6.0,0.0,2.0


In [140]:
def round_to_even_csr(csr_mat):
    # Access the data array of the CSR matrix
    data = csr_mat.data
    odd_data = data % 2 != 0
    data[odd_data] = data[odd_data] + 1
    data = data / 2
    return csr_matrix((data, csr_mat.indices, csr_mat.indptr), shape=csr_mat.shape)

adata.layers['fragments'] = round_to_even_csr(adata.X)

In [148]:
print((adata.layers['fragments'] == 1).sum())
print((adata.layers['fragments'] == 2).sum())
adata.to_df(layer="fragments")



55511708
17306236


Unnamed: 0,chr1:629395-630394,chr1:633578-634591,chr1:778283-779200,chr1:816873-817775,chr1:827067-827949,chr1:869477-870378,chr1:904364-905213,chr1:920762-921638,chr1:923415-924300,chr1:939990-940901,...,GL000205.2:88643-89473,GL000195.1:30402-31263,GL000195.1:32203-33046,GL000219.1:39927-40834,GL000219.1:42161-43034,GL000219.1:44739-45583,GL000219.1:45726-46446,GL000219.1:99267-100169,KI270726.1:41483-42332,KI270713.1:21453-22374
AAACAGCCAACAACAA-1,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,1.0,1.0,0.0,3.0,0.0,0.0,1.0,0.0,0.0
AAACAGCCACCGGCTA-1,0.0,1.0,2.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1.0,3.0,0.0,0.0
AAACAGCCAGGACACA-1,0.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,...,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0
AAACAGCCATCCTAGA-1,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2.0
AAACATGCAAAGGTAC-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,3.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTGTTCTAGCGTG-1,0.0,1.0,2.0,0.0,1.0,5.0,1.0,0.0,1.0,1.0,...,0.0,1.0,0.0,0.0,1.0,1.0,0.0,3.0,1.0,4.0
TTTGTTGGTAAGGTTT-1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,...,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TTTGTTGGTTAGGATT-1,0.0,0.0,3.0,0.0,2.0,0.0,0.0,0.0,1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0
TTTGTTGGTTTGAGCA-1,0.0,0.0,4.0,0.0,1.0,0.0,1.0,1.0,2.0,1.0,...,0.0,0.0,0.0,0.0,2.0,0.0,0.0,3.0,0.0,1.0


In [128]:
split_interval = adata.var["gene_ids"].str.split(":", expand=True)
adata.var["chr"] = split_interval[0]
split_start_end = split_interval[1].str.split("-", expand=True)
adata.var["start"] = split_start_end[0].astype(int)
adata.var["end"] = split_start_end[1].astype(int)
adata.var

Unnamed: 0,gene_ids,feature_types,genome,interval,n_cells,chr,start,end
chr1:629395-630394,chr1:629395-630394,Peaks,GRCh38,chr1:629395-630394,1422,chr1,629395,630394
chr1:633578-634591,chr1:633578-634591,Peaks,GRCh38,chr1:633578-634591,4536,chr1,633578,634591
chr1:778283-779200,chr1:778283-779200,Peaks,GRCh38,chr1:778283-779200,5981,chr1,778283,779200
chr1:816873-817775,chr1:816873-817775,Peaks,GRCh38,chr1:816873-817775,564,chr1,816873,817775
chr1:827067-827949,chr1:827067-827949,Peaks,GRCh38,chr1:827067-827949,3150,chr1,827067,827949
...,...,...,...,...,...,...,...,...
GL000219.1:44739-45583,GL000219.1:44739-45583,Peaks,GRCh38,GL000219.1:44739-45583,781,GL000219.1,44739,45583
GL000219.1:45726-46446,GL000219.1:45726-46446,Peaks,GRCh38,GL000219.1:45726-46446,639,GL000219.1,45726,46446
GL000219.1:99267-100169,GL000219.1:99267-100169,Peaks,GRCh38,GL000219.1:99267-100169,6830,GL000219.1,99267,100169
KI270726.1:41483-42332,KI270726.1:41483-42332,Peaks,GRCh38,KI270726.1:41483-42332,605,KI270726.1,41483,42332


# Try to relate DNA and ATAC seq data

## Download and index full GRCh38 DNA dataset

In [18]:
# wget ftp://ftp.ensembl.org/pub/release-100/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.primary_assembly.fa.gz
# gunzip Homo_sapiens.GRCh38.dna.primary_assembly.fa.gz

from pyfaidx import Fasta
file_path = 'data/Homo_sapiens.GRCh38.dna.primary_assembly.fa'

# Load the reference genome
fasta = Fasta(file_path)


In [65]:
from Bio import SeqIO
fasta_ids = [record.id for record in SeqIO.parse(file_path, "fasta")]


## See if all the chromosoes in ATAC seq data are present in DNA data

In [104]:
for chr in adata.var['chr'].unique():
    # fasta data stores chromosome without 'chr', so get rid of them
    new_chr = chr[3:] if chr[:3] == "chr" else chr
    if new_chr not in fasta_ids:
        raise ValueError(f"{new_chr} does not exist in dna fasta data")
    print("Chr: " + new_chr, "Seq length: " + str(len(fasta[new_chr][:])))


Chr: 1 Seq length: 248956422
Chr: 10 Seq length: 133797422
Chr: 11 Seq length: 135086622
Chr: 12 Seq length: 133275309
Chr: 13 Seq length: 114364328
Chr: 14 Seq length: 107043718
Chr: 15 Seq length: 101991189
Chr: 16 Seq length: 90338345
Chr: 17 Seq length: 83257441
Chr: 18 Seq length: 80373285
Chr: 19 Seq length: 58617616
Chr: 2 Seq length: 242193529
Chr: 20 Seq length: 64444167
Chr: 21 Seq length: 46709983
Chr: 22 Seq length: 50818468
Chr: 3 Seq length: 198295559
Chr: 4 Seq length: 190214555
Chr: 5 Seq length: 181538259
Chr: 6 Seq length: 170805979
Chr: 7 Seq length: 159345973
Chr: 8 Seq length: 145138636
Chr: 9 Seq length: 138394717
Chr: X Seq length: 156040895
Chr: Y Seq length: 57227415
Chr: KI270727.1 Seq length: 448248
Chr: GL000205.2 Seq length: 185591
Chr: GL000195.1 Seq length: 182896
Chr: GL000219.1 Seq length: 179198
Chr: KI270726.1 Seq length: 43739
Chr: KI270713.1 Seq length: 40745


## Try getting a DNA sequence for one of the peaks

In [118]:
peak = adata.var.iloc[7777]
peak_chr = peak.chr[3:] if peak.chr[:3] == "chr" else peak.chr
print("Sequence for " + peak.gene_ids)
fasta[peak_chr][peak.start:peak.end]

Sequence for chr12:14256174-14256924


>12:14256175-14256924
TGTGTGCCTTCCCACAATAACTGATGACTCTTAAATCTCACAAAAACAAATTCAACTGATGAATTCACGTCCTTTTAACATACATTTTGAAGATTATCTTAGCCCAAACCGGCACCCACATCCTTCTGAGCGTACTGACTCACTAGACAAAGCCAACCCCTCCCCTCCTGTGTACAAAAATAGCTTTTCCATATCTACTGTTGTAATTGTTTTTTTTCCAGGGCTTTATGCTCTACTGAAACTAAACACTTAAATGAGCAATTTCACTCTGGTTTTAAGCCCAGGGGAAATGACAACGGTTTTGTCTCTGTTTTGAGACTTAGGGGCAAATTTTAAATCGTTAATATATAATTAATCCTATGTAACATCATAGAAACAAAACATCTGATTCTTCTTTCAGTATATTGCACACGTAAAATGCACATGAAACCAGGTTTGGAGAGTGATTCTTAGTAACGTAGCCTTGCATTTGTGCTTCGACATAAAATTGAGACATTCTCAAGAGGCCTTGGGACATGGTGACATGTTGCGCAGTTTGGCCAGCAGAGGGAGACCAGAGAGTCGTGAAGCACGCAACGAAGCCACACCCAGATTAGGGTTATTATCTGTCTCTTCTCTGAATTTTCAAAAAACACAAAAAGAAAATGAGAAATTCTTAAAACAAAGCCCCTTTTTTATTGAAGCAAAATCCACATAACATAAAACTAGCCCTTCTAAAGTGTGCAATTAAGTGGCATCTACCTCATTCAC

# Messing with neurips 2021 data

In [None]:
import anndata as ad

mdata = ad.read("data/GSE194122/GSE194122_openproblems_neurips2021_multiome_BMMC_processed.h5ad")



In [None]:
mdata