In [22]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from utils import generate_random_dna_sequence, get_device, process_chunk_embeddings, chunkify_sequences
import torch
import numpy as np

EMBEDDER_PATH = 'InstaDeepAI/nucleotide-transformer-2.5b-1000g'

PADDING_VALUE = -100

device = get_device()

Using device: mps


In [2]:
model = AutoModelForMaskedLM.from_pretrained(EMBEDDER_PATH).eval().to(device)
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH)

  torch.utils._pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
sequences = [generate_random_dna_sequence(min_length=5, max_length=15) for _ in range(10)]
sequences

['AGGATCAGCA',
 'CTTGGTCCGCCACG',
 'CTTGCCC',
 'AAACTTCGAAC',
 'TTTAT',
 'TCTGAGT',
 'ACATAGGGAT',
 'AAAATTG',
 'TTCGGGTCAACGGT',
 'ACAGTGGT']

#### Chunkify sequences

In [5]:
MAX_MODEL_LENGTH = 4
chunked_sequences, chunk_ids = chunkify_sequences(sequences, MAX_MODEL_LENGTH)

for seq, chunk_id in zip(chunked_sequences, chunk_ids):
    print(f"Chunk ID: {chunk_id}, Sequence: {seq}")

Chunk ID: 0, Sequence: AGGA
Chunk ID: 0, Sequence: TCAG
Chunk ID: 0, Sequence: CA
Chunk ID: 1, Sequence: CTTG
Chunk ID: 1, Sequence: GTCC
Chunk ID: 1, Sequence: GCCA
Chunk ID: 1, Sequence: CG
Chunk ID: 2, Sequence: CTTG
Chunk ID: 2, Sequence: CCC
Chunk ID: 3, Sequence: AAAC
Chunk ID: 3, Sequence: TTCG
Chunk ID: 3, Sequence: AAC
Chunk ID: 4, Sequence: TTTA
Chunk ID: 4, Sequence: T
Chunk ID: 5, Sequence: TCTG
Chunk ID: 5, Sequence: AGT
Chunk ID: 6, Sequence: ACAT
Chunk ID: 6, Sequence: AGGG
Chunk ID: 6, Sequence: AT
Chunk ID: 7, Sequence: AAAA
Chunk ID: 7, Sequence: TTG
Chunk ID: 8, Sequence: TTCG
Chunk ID: 8, Sequence: GGTC
Chunk ID: 8, Sequence: AACG
Chunk ID: 8, Sequence: GT
Chunk ID: 9, Sequence: ACAG
Chunk ID: 9, Sequence: TGGT


#### Tokenise sequences

In [16]:
output = tokenizer(
    chunked_sequences,
    return_tensors="pt",
    return_token_type_ids=False,
    padding="longest",
)

input_ids = output["input_ids"]
attention_mask = output["attention_mask"]

input_ids

tensor([[   3, 4100, 4103, 4103, 4100],
        [   3, 4101, 4102, 4100, 4103],
        [   3, 4102, 4100,    1,    1],
        [   3, 4102, 4101, 4101, 4103],
        [   3, 4103, 4101, 4102, 4102],
        [   3, 4103, 4102, 4102, 4100],
        [   3, 4102, 4103,    1,    1],
        [   3, 4102, 4101, 4101, 4103],
        [   3, 4102, 4102, 4102,    1],
        [   3, 4100, 4100, 4100, 4102],
        [   3, 4101, 4101, 4102, 4103],
        [   3, 4100, 4100, 4102,    1],
        [   3, 4101, 4101, 4101, 4100],
        [   3, 4101,    1,    1,    1],
        [   3, 4101, 4102, 4101, 4103],
        [   3, 4100, 4103, 4101,    1],
        [   3, 4100, 4102, 4100, 4101],
        [   3, 4100, 4103, 4103, 4103],
        [   3, 4100, 4101,    1,    1],
        [   3, 4100, 4100, 4100, 4100],
        [   3, 4101, 4101, 4103,    1],
        [   3, 4101, 4101, 4102, 4103],
        [   3, 4103, 4103, 4101, 4102],
        [   3, 4100, 4100, 4102, 4103],
        [   3, 4103, 4101,    1,    1],


In [17]:
for ids in input_ids:
    print(tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False))

['<cls>', 'A', 'G', 'G', 'A']
['<cls>', 'T', 'C', 'A', 'G']
['<cls>', 'C', 'A', '<pad>', '<pad>']
['<cls>', 'C', 'T', 'T', 'G']
['<cls>', 'G', 'T', 'C', 'C']
['<cls>', 'G', 'C', 'C', 'A']
['<cls>', 'C', 'G', '<pad>', '<pad>']
['<cls>', 'C', 'T', 'T', 'G']
['<cls>', 'C', 'C', 'C', '<pad>']
['<cls>', 'A', 'A', 'A', 'C']
['<cls>', 'T', 'T', 'C', 'G']
['<cls>', 'A', 'A', 'C', '<pad>']
['<cls>', 'T', 'T', 'T', 'A']
['<cls>', 'T', '<pad>', '<pad>', '<pad>']
['<cls>', 'T', 'C', 'T', 'G']
['<cls>', 'A', 'G', 'T', '<pad>']
['<cls>', 'A', 'C', 'A', 'T']
['<cls>', 'A', 'G', 'G', 'G']
['<cls>', 'A', 'T', '<pad>', '<pad>']
['<cls>', 'A', 'A', 'A', 'A']
['<cls>', 'T', 'T', 'G', '<pad>']
['<cls>', 'T', 'T', 'C', 'G']
['<cls>', 'G', 'G', 'T', 'C']
['<cls>', 'A', 'A', 'C', 'G']
['<cls>', 'G', 'T', '<pad>', '<pad>']
['<cls>', 'A', 'C', 'A', 'G']
['<cls>', 'T', 'G', 'G', 'T']


In [18]:
attention_mask

tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]])

#### Embed Sequences

In [19]:
embeddings = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), output_hidden_states=True,)["hidden_states"][-1].detach().cpu().numpy()
input_ids = input_ids.numpy()
embeddings.shape

(27, 5, 2560)

In [20]:
embeddings[:, -1, -1]

array([-1.3886814e-01,  4.3697691e-01, -8.8939840e-01,  9.9667683e-03,
       -6.8073377e-02, -2.1069090e-01, -5.8168566e-01,  9.9667683e-03,
       -1.4776534e-01,  2.7478859e-01,  4.4103006e-01, -5.8178532e-01,
        2.4884179e-01, -2.9568106e-01,  1.4081374e-01, -6.2673521e-01,
        2.8391942e-01, -2.2443391e-02, -5.6441110e-01,  6.5617457e-02,
       -5.0080180e-01,  4.4103006e-01, -2.9559991e-01, -4.8437916e-02,
       -7.2403973e-01,  4.3183891e-04,  3.2815367e-01], dtype=float32)

#### Upsample

In [23]:
sequence_embeddings, masked_tokens = process_chunk_embeddings(tokenizer, embeddings, input_ids, chunk_ids, upsample=True)

for i, seq_emb in enumerate(sequence_embeddings):
    assert sequences[i] == ''.join(masked_tokens[i]), f"Mismatch in sequence {i}: {sequences[i]} != {''.join(masked_tokens[i])}"
    print(f"Chunk ID: {np.unique(chunk_ids)[i]}")
    print(f"Sequence: {sequences[i]}  Length: {len(sequences[i])}")
    print(f"  Tokens: {''.join(masked_tokens[i])}")
    print(f"Embedding shape: {seq_emb.shape}")

Chunk ID: 0
Sequence: AGGATCAGCA  Length: 10
  Tokens: AGGATCAGCA
Embedding shape: (10, 2560)
Chunk ID: 1
Sequence: CTTGGTCCGCCACG  Length: 14
  Tokens: CTTGGTCCGCCACG
Embedding shape: (14, 2560)
Chunk ID: 2
Sequence: CTTGCCC  Length: 7
  Tokens: CTTGCCC
Embedding shape: (7, 2560)
Chunk ID: 3
Sequence: AAACTTCGAAC  Length: 11
  Tokens: AAACTTCGAAC
Embedding shape: (11, 2560)
Chunk ID: 4
Sequence: TTTAT  Length: 5
  Tokens: TTTAT
Embedding shape: (5, 2560)
Chunk ID: 5
Sequence: TCTGAGT  Length: 7
  Tokens: TCTGAGT
Embedding shape: (7, 2560)
Chunk ID: 6
Sequence: ACATAGGGAT  Length: 10
  Tokens: ACATAGGGAT
Embedding shape: (10, 2560)
Chunk ID: 7
Sequence: AAAATTG  Length: 7
  Tokens: AAAATTG
Embedding shape: (7, 2560)
Chunk ID: 8
Sequence: TTCGGGTCAACGGT  Length: 14
  Tokens: TTCGGGTCAACGGT
Embedding shape: (14, 2560)
Chunk ID: 9
Sequence: ACAGTGGT  Length: 8
  Tokens: ACAGTGGT
Embedding shape: (8, 2560)


In [16]:
upsampled_embeddings[:, -1, -1]

tensor([-2.2402e-01, -1.0000e+02, -3.7545e-02, -1.0000e+02, -3.3865e-01,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02])