In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
from transformers.models.bert.configuration_bert import BertConfig
from Bio import SeqIO
import os
import gzip
from time import perf_counter_ns
from functools import wraps

In [2]:
def timeit(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        start = perf_counter_ns()
        result = f(*args, **kwargs)
        print(f"{f.__name__}: {(perf_counter_ns() - start) / 10 ** 6:.3f}ms")
        return result
    return wrapper

# Download tokenizer and pre-trained model from Huggingface

In [3]:
# run models on GPU if GPU node is used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
config = BertConfig.from_pretrained("zhihan1996/DNABERT-2-117M")

In [5]:
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

In [6]:
model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True, config = config).to(device)

Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Load some genome from Iseult's NCBI data

In [7]:
ISEULT_GENOMES_PATH = "/no_backup/rg/ileahy/protists"

In [8]:
subdirs = os.listdir(ISEULT_GENOMES_PATH)

In [9]:
fastas = {i: f"{ISEULT_GENOMES_PATH}/{i}/{i}_genomic.fna.gz" for i in subdirs}

In [10]:
first_fasta = list(fastas.items())[0]
first_fasta

('Nitzschia_sp._DOCU1',
 '/no_backup/rg/ileahy/protists/Nitzschia_sp._DOCU1/Nitzschia_sp._DOCU1_genomic.fna.gz')

In [11]:
with gzip.open(first_fasta[1], 'rt') as f:
    seqs = list(SeqIO.parse(f, "fasta"))

In [12]:
chromosomes = [seq for seq in seqs if "chromosome" in seq.description]

In [13]:
len(chromosomes)

17

In [14]:
chr1 = chromosomes[0]
len(chr1.seq)

5856310

In [15]:
chunks = [chr1.seq[i:i+1000].upper() for i in range(0, len(chr1.seq), 1000)]
chunks[:10]

[Seq('ACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTA...CCC'),
 Seq('TAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCC...GGC'),
 Seq('AATCCCTGTTTCAATCATGTTTCAAGCCTTGTTGGAGTAGCGACCGACGAGATC...ATG'),
 Seq('CCCGGAGATCGACCTTTGTCTATCATCCACCAGCTGCGAACAACAATGTGCGGG...GAT'),
 Seq('GCCGGAGTCGTCGTCTTTCCTGCGACCAAGCTGGCCAAACGGAGTTCAAGACCC...TGC'),
 Seq('GTGGCAATCCTGATTTTGATTTTATTGCGAACTGGCCGAATATGGCAGGTATTG...TCG'),
 Seq('AGGACGCCACTTCCATCAAACTCGGAACTACCTTTTACAAGAACTTTGCCAAAC...AGC'),
 Seq('ACAAGAGTGCGGGACCACTGTTTGCGAAGAAAAATATGATACGCAAAACGGAAG...CGA'),
 Seq('TTGAGTGAAAAAGTGATGCATTTGTTGGAAGGAATGCGGAGATCCGAGCCGAGC...CTT'),
 Seq('ACTTTTGTTCTAGCTTCTGTGGTTACGTAAGCCAGACGACCGCAGCTTTCTCAC...GCC')]

In [16]:
@timeit
def embed_chunk(chunk):
    in_tokens = tokenizer(str(chunk), return_tensors = 'pt')["input_ids"].to(device)
    return model(in_tokens)[0]

In [18]:
chunk1_embed = embed_chunk(chunks[1])
print(chunk1_embed)
print(chunk1_embed.shape)

embed_chunk: 22.387ms
tensor([[[-0.3003,  0.1283, -0.0370,  ..., -0.0249,  0.1934,  0.1578],
         [ 0.0141,  0.4050,  0.0662,  ...,  0.3370,  0.0829,  0.2535],
         [-0.1699, -0.2702,  0.0422,  ...,  0.1190,  0.1071,  0.1282],
         ...,
         [-0.3597,  0.0272, -0.0936,  ..., -0.1072,  0.0013,  0.3744],
         [-0.0269,  0.3906,  0.1141,  ...,  0.4853,  0.2595,  0.0154],
         [-0.3557,  0.2510, -0.2928,  ...,  0.1711,  0.1063,  0.2774]]],
       device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 207, 768])


In [67]:
embeds = [embed_chunk(chunk) for chunk in chunks[:10]]

embed_chunk: 19.518ms
embed_chunk: 18.291ms
embed_chunk: 18.005ms
embed_chunk: 18.173ms
embed_chunk: 18.445ms
embed_chunk: 18.235ms
embed_chunk: 18.359ms
embed_chunk: 18.081ms
embed_chunk: 18.046ms
embed_chunk: 18.348ms


In [19]:
tokens = [tokenizer(str(seq), return_tensors = 'pt')["input_ids"].to(device) for seq in chunks]

In [19]:
chunks[0]

Seq('ACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTAACCCTA...CCC')

In [21]:
tokenizer.convert_ids_to_tokens([956])

['CCCTAA']

In [20]:
tokens[0]

tensor([[  1,   5, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
         956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956, 956,
          13,   6,   2]], device='cuda:0')

In [21]:
[token.shape for token in tokens][:20]

[torch.Size([1, 171]),
 torch.Size([1, 207]),
 torch.Size([1, 208]),
 torch.Size([1, 217]),
 torch.Size([1, 220]),
 torch.Size([1, 215]),
 torch.Size([1, 219]),
 torch.Size([1, 211]),
 torch.Size([1, 212]),
 torch.Size([1, 218]),
 torch.Size([1, 216]),
 torch.Size([1, 210]),
 torch.Size([1, 217]),
 torch.Size([1, 208]),
 torch.Size([1, 214]),
 torch.Size([1, 212]),
 torch.Size([1, 209]),
 torch.Size([1, 217]),
 torch.Size([1, 224]),
 torch.Size([1, 211])]

In [22]:
# failed attempt to concat tokens, causes crazy OOM in embedding
# chr1_tokens = torch.cat(tokens).unsqueeze(dim = 0)

In [23]:
# chr1_tokens.shape

In [24]:
len(tokens)

5857

In [25]:
chr1_embeds = [model(seq)[0] for seq in tokens[:10]]

In [26]:
chr1_embeds[1].shape

torch.Size([1, 207, 768])

In [27]:
chr1_embeds[2].shape

torch.Size([1, 208, 768])

In [28]:
torch.cat((chr1_embeds[0].squeeze(), chr1_embeds[1].squeeze())).shape

torch.Size([378, 768])

In [29]:
chr1_embeds[0]

tensor([[[-0.4234,  0.0769, -0.0990,  ..., -0.1416, -0.0718,  0.2727],
         [-0.0100,  0.1861, -0.1309,  ...,  0.3406,  0.2404,  0.0826],
         [-0.0395,  0.1606,  0.0387,  ..., -0.0656,  0.0599,  0.0561],
         ...,
         [-0.2256,  0.2786,  0.0561,  ...,  0.0396,  0.0825,  0.2168],
         [-0.1758,  0.0145, -0.3587,  ..., -0.1461,  0.1036,  0.2735],
         [-0.3720, -0.0226, -0.0982,  ..., -0.0650,  0.2470,  0.2372]]],
       device='cuda:0', grad_fn=<ViewBackward0>)

In [30]:
chr1_embeds[0].sum(dim=1) / chr1_embeds[0].shape[1]

tensor([[-2.9903e-02,  1.7587e-01, -1.0308e-01, -3.2150e-01, -2.0249e-01,
         -2.9341e-01,  3.7036e-02, -2.3100e-02, -2.5736e-01, -9.4389e-02,
          1.5219e-01,  3.6972e-03, -1.4181e-03, -3.2736e-01, -3.9237e-01,
         -1.9741e-01,  3.7392e-02,  2.4428e-01,  4.3173e-02, -6.9685e-02,
         -6.2383e-02,  2.0217e-01, -4.2199e-02,  6.1309e-03,  3.3523e-01,
          2.5090e-01, -1.8661e-01,  1.3052e-02,  1.5894e-01, -1.4834e-01,
          4.6660e-02,  2.4527e-01, -3.2059e-01, -8.2781e-02, -2.4202e-02,
          1.4567e-01,  6.6666e-02, -4.7602e-02, -2.0149e-01,  1.5304e-01,
         -3.6340e-01, -9.2600e-03,  2.6296e-02,  5.0430e-02,  1.9738e-01,
          4.2325e-01,  2.0773e-02,  2.6143e-01, -9.6659e-02, -2.3741e-01,
         -3.5322e-01, -1.0800e-01, -1.2454e-01, -3.7395e-03,  1.8385e-01,
          1.5250e-01,  2.9568e-02,  6.2159e-02,  9.0101e-02,  1.4352e-01,
          1.8046e-02, -2.6160e-01,  2.2883e-02,  1.0245e-01,  1.8389e-01,
         -1.4283e-01, -1.1928e-01,  8.

In [31]:
def avg_pool(x):
    x = x.sum(dim=1) / x.shape[1]
    return x

In [32]:
avg_pool(chr1_embeds[1])

tensor([[-4.3544e-02,  1.3329e-01,  6.0644e-02, -1.0409e-01, -1.6363e-01,
         -4.6934e-02, -3.7111e-02, -1.3833e-01, -1.6781e-02,  4.6372e-02,
         -8.1738e-02,  6.0620e-02, -2.9557e-02, -1.1225e-01, -1.8202e-03,
         -7.3436e-02, -8.3275e-02,  6.4418e-02, -4.8312e-02, -6.2021e-02,
          1.2440e-02,  9.3453e-02,  3.0047e-03, -2.0536e-02,  9.1666e-02,
          1.4577e-01,  2.4359e-01,  1.2961e-02, -4.2891e-02, -7.1301e-02,
         -4.7893e-02,  1.3746e-01, -6.7017e-02, -4.9176e-03,  3.0636e-03,
          1.9920e-02, -1.1824e-01, -3.7550e-02,  4.3702e-02, -3.7139e-02,
         -9.0138e-02, -2.8041e-02,  7.2973e-02,  6.9049e-03, -2.7204e-02,
          6.2942e-02, -8.7033e-02, -2.1834e-02, -1.2417e-01,  8.8413e-02,
          1.8708e-02, -1.0284e-01, -6.5928e-02,  1.7767e-02, -8.9727e-03,
          1.0673e-01,  1.4345e-01, -4.2982e-02, -9.9925e-02,  2.3258e-02,
         -1.9393e-02,  3.5774e-03, -3.2213e-02, -1.1069e-01,  1.1304e-01,
          7.6632e-02, -5.4433e-02, -4.

In [33]:
total_pool = torch.zeros(1, 768, device = "cpu")
total_count = 0
for i, seq in enumerate(tokens):
    if i % 500 == 0:
        print(i)
    embed = model(seq)[0]
    pool = avg_pool(embed)
    total_pool += pool.detach().to("cpu")
    total_count += 1
    del embed, pool
    torch.cuda.empty_cache()

0
500
1000
1500
2000
2500
3000
3500
4000
4500
5000
5500


In [34]:
chr1_avg = total_pool / total_count

In [35]:
chr1_avg.shape

torch.Size([1, 768])

In [36]:
def embed_chromosome(chromosome):
    # split into chunks to avoid model crash
    chunks = [chr1.seq[i:i+1000].upper() for i in range(0, len(chr1.seq), 1000)]
    # tokenize chunks
    tokens = [tokenizer(str(seq), return_tensors = 'pt')["input_ids"].to(device) for seq in chunks]
    # avg-pool the chunks
    print(f"Starting embed for chromosome {chromosome.id}")
    total_pool = torch.zeros(1, 768, device = "cpu")
    total_count = 0
    for i, seq in enumerate(tokens):
        embed = model(seq)[0]
        pool = avg_pool(embed)
        total_pool += pool.detach().to("cpu")
        total_count += 1
        del embed, pool
        torch.cuda.empty_cache()
    print(f"Finished for {chromosome.id}")
    return total_pool / total_count

In [37]:
embed_chromosome(chromosomes[1])

Starting embed for chromosome OX595986.1
Finished for OX595986.1


tensor([[-5.8525e-02,  1.4755e-01,  6.6125e-02, -9.5628e-02, -9.8391e-02,
         -2.7134e-02, -8.0107e-02, -1.2738e-01,  5.4010e-03,  2.5708e-02,
         -1.3374e-01,  4.0950e-02,  2.4570e-02, -7.7311e-02,  1.8789e-02,
         -3.9125e-02, -8.6143e-02,  7.9225e-02, -5.1121e-02, -9.1054e-02,
          2.6318e-02,  8.3090e-02, -7.7810e-03,  2.5229e-03,  7.9415e-02,
          7.2432e-02,  3.3866e-01, -6.4904e-03, -4.2798e-02, -4.8660e-02,
         -2.9299e-02,  9.9109e-02, -9.9972e-02, -9.6909e-03, -6.0762e-02,
          2.0392e-02, -9.6337e-02, -5.2213e-02,  2.3080e-02, -3.0274e-02,
         -8.2717e-02, -1.0120e-02,  1.5146e-04, -1.8882e-02, -2.2681e-02,
          5.4338e-02, -9.4868e-02, -4.2479e-02, -8.1208e-02,  1.3714e-01,
          1.7789e-02, -7.2000e-02, -7.4470e-02,  6.0821e-02, -2.0621e-02,
          1.0121e-01,  1.1914e-01, -7.7697e-02, -1.0392e-01,  9.0446e-03,
         -8.0322e-04,  8.4762e-03, -1.7165e-02, -1.4053e-01,  7.9212e-02,
          7.3658e-02, -4.8468e-02, -4.

In [38]:
chr_embeds = [embed_chromosome(i) for i in chromosomes]

Starting embed for chromosome OX595985.1
Finished for OX595985.1
Starting embed for chromosome OX595986.1
Finished for OX595986.1
Starting embed for chromosome OX595987.1
Finished for OX595987.1
Starting embed for chromosome OX595988.1
Finished for OX595988.1
Starting embed for chromosome OX595989.1
Finished for OX595989.1
Starting embed for chromosome OX595990.1
Finished for OX595990.1
Starting embed for chromosome OX595991.1
Finished for OX595991.1
Starting embed for chromosome OX595992.1
Finished for OX595992.1
Starting embed for chromosome OX595993.1
Finished for OX595993.1
Starting embed for chromosome OX595994.1
Finished for OX595994.1
Starting embed for chromosome OX595995.1
Finished for OX595995.1
Starting embed for chromosome OX595996.1
Finished for OX595996.1
Starting embed for chromosome OX595997.1
Finished for OX595997.1
Starting embed for chromosome OX595998.1
Finished for OX595998.1
Starting embed for chromosome OX595999.1
Finished for OX595999.1
Starting embed for chromo

In [39]:
len(chr_embeds)

17

In [40]:
genome_embed = sum(chr_embeds) / len(chr_embeds)

In [41]:
genome_embed.shape

torch.Size([1, 768])

In [42]:
def embed_genome(name, path):
    print(f"Processing genome {name}")
    with gzip.open(first_fasta[1], 'rt') as f:
        seqs = list(SeqIO.parse(f, "fasta"))
    chromosomes = [seq for seq in seqs if "chromosome" in seq.description]
    chr_embeds = [embed_chromosome(i) for i in chromosomes]
    print(f"Finished genome {name}")
    return sum(chr_embeds) / len(chr_embeds)

In [44]:
name2, path2 = list(fastas.items())[1]

In [None]:
genome2_embed = embed_genome(name2, path2)

Processing genome Neospora_caninum_Liverpool
Starting embed for chromosome OX595985.1
Finished for OX595985.1
Starting embed for chromosome OX595986.1
Finished for OX595986.1
Starting embed for chromosome OX595987.1
Finished for OX595987.1
Starting embed for chromosome OX595988.1
Finished for OX595988.1
Starting embed for chromosome OX595989.1
Finished for OX595989.1
Starting embed for chromosome OX595990.1
Finished for OX595990.1
Starting embed for chromosome OX595991.1
Finished for OX595991.1
Starting embed for chromosome OX595992.1
Finished for OX595992.1
Starting embed for chromosome OX595993.1
Finished for OX595993.1
Starting embed for chromosome OX595994.1
Finished for OX595994.1
Starting embed for chromosome OX595995.1
Finished for OX595995.1
Starting embed for chromosome OX595996.1
Finished for OX595996.1
Starting embed for chromosome OX595997.1
Finished for OX595997.1
Starting embed for chromosome OX595998.1
Finished for OX595998.1
Starting embed for chromosome OX595999.1
