In [None]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
from utils import generate_random_dna_sequence, get_device, remove_special_tokens_and_padding, upsample
import torch

# Checkpoint identifier passed to `from_pretrained` for both the model and its tokenizer.
EMBEDDER_PATH = "InstaDeepAI/nucleotide-transformer-2.5b-1000g"

# Fill value used when the upsampled embeddings are padded to a common length.
PADDING_VALUE = -100

# Device the model and its inputs are moved to; selected by the project helper in `utils`.
device = get_device()

Using device: mps


In [2]:
# Load the pretrained masked-LM checkpoint, then put it in eval mode and move it to `device`.
model = AutoModelForMaskedLM.from_pretrained(EMBEDDER_PATH)
model = model.eval().to(device)

# Tokenizer loaded from the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(EMBEDDER_PATH)

  torch.utils._pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [3]:
# Build a small batch of ten random DNA strings of varying length so that
# tokenisation produces rows that need padding.
# NOTE(review): no random seed is set, so this cell yields different sequences on
# every run; seeding depends on which RNG `generate_random_dna_sequence` uses — confirm in utils.
sequences = [
    generate_random_dna_sequence(min_length=5, max_length=15)
    for _ in range(10)
]
sequences

['CTTTTCTCAGCGG',
 'CTTGTCTGGAGT',
 'GGTGGGCCGGCGG',
 'GATCCT',
 'CTAGCACGTGCAA',
 'TATTGAT',
 'AGTAGGGG',
 'TCCCGGCACTA',
 'ACGGAACTGATG',
 'GACCTGC']

#### Tokenise sequences

In [4]:
# Tokenise the whole batch in one call. `padding="longest"` pads every row to the
# longest tokenised sequence in the batch, and `return_tensors="pt"` returns
# PyTorch tensors; token type ids are not requested.
output = tokenizer(sequences, padding="longest", return_tensors="pt", return_token_type_ids=False)

# Keep the two tensors the model call needs under their own names.
input_ids, attention_mask = output["input_ids"], output["attention_mask"]

input_ids

tensor([[   3, 2394, 1599, 4103,    1,    1,    1],
        [   3, 2426, 2001,    1,    1,    1,    1],
        [   3, 3971, 2815, 4103,    1,    1,    1],
        [   3, 3181,    1,    1,    1,    1,    1],
        [   3, 2364, 2940, 4100,    1,    1,    1],
        [   3, 1120, 4101,    1,    1,    1,    1],
        [   3,  851, 4103, 4103,    1,    1,    1],
        [   3, 1715, 4102, 4100, 4102, 4101, 4100],
        [   3,  756, 2507,    1,    1,    1,    1],
        [   3, 3243, 4102,    1,    1,    1,    1]])

In [5]:
# Print the string token behind every id, one row per sequence. Special tokens are
# kept (skip_special_tokens=False) so the <cls> prefix and <pad> tail stay visible.
for ids in input_ids:
    row_tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
    print(row_tokens)

['<cls>', 'CTTTTC', 'TCAGCG', 'G', '<pad>', '<pad>', '<pad>']
['<cls>', 'CTTGTC', 'TGGAGT', '<pad>', '<pad>', '<pad>', '<pad>']
['<cls>', 'GGTGGG', 'CCGGCG', 'G', '<pad>', '<pad>', '<pad>']
['<cls>', 'GATCCT', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['<cls>', 'CTAGCA', 'CGTGCA', 'A', '<pad>', '<pad>', '<pad>']
['<cls>', 'TATTGA', 'T', '<pad>', '<pad>', '<pad>', '<pad>']
['<cls>', 'AGTAGG', 'G', 'G', '<pad>', '<pad>', '<pad>']
['<cls>', 'TCCCGG', 'C', 'A', 'C', 'T', 'A']
['<cls>', 'ACGGAA', 'CTGATG', '<pad>', '<pad>', '<pad>', '<pad>']
['<cls>', 'GACCTG', 'C', '<pad>', '<pad>', '<pad>', '<pad>']


In [6]:
# Mask aligned position-by-position with `input_ids`: in the output below, 1 marks a
# real token (including the leading <cls>) and 0 marks a <pad> position.
attention_mask

tensor([[1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0]])

#### Embed sequences

In [12]:
# Run the batch through the model and keep the last entry of `hidden_states`
# as the per-token embeddings.
#
# The forward pass is wrapped in `torch.no_grad()`: this is inference only, so
# there is no reason to record an autograd graph across every layer of the model.
# Previously the graph was built and only dropped afterwards via `.detach()`;
# under `no_grad` the result carries no graph, so `.detach()` is no longer needed.
with torch.no_grad():
    embeddings = model(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        output_hidden_states=True,
    )["hidden_states"][-1].cpu()  # move the result back to CPU for the steps below

embeddings.size()

torch.Size([10, 7, 2560])

In [17]:
# Last hidden feature at the final token position of every sequence. Because the
# batch is right-padded to the longest row, this position holds a <pad> token for
# every sequence shorter than the longest one, so most values here are padding embeddings.
embeddings[:, -1, -1]

tensor([-0.3127, -0.8532,  0.4219, -0.5247, -0.3778, -0.0778, -0.5626, -0.2562,
        -0.3813, -0.6912])

#### Upsample

In [None]:
# For each sequence: drop the special-token and padding positions from its token
# embeddings, then upsample what remains. Both steps are project helpers from `utils`;
# the sizes printed below suggest upsampling yields one row per nucleotide — confirm in utils.
per_sequence = []

for token_ids, token_emb in zip(input_ids, embeddings):
    print('---')
    print('Embedding size: ', token_emb.size())

    trimmed = remove_special_tokens_and_padding(tokenizer, token_ids, token_emb)
    print('Masked embedding size: ', trimmed.size())

    expanded = upsample(tokenizer, token_ids, trimmed)
    print('Upsampled embedding size: ', expanded.size())

    per_sequence.append(expanded)

# Stack the variable-length results into one batch-first tensor, filling the
# tail of shorter sequences with PADDING_VALUE.
upsampled_embeddings = torch.nn.utils.rnn.pad_sequence(
    per_sequence,
    batch_first=True,
    padding_value=PADDING_VALUE,
)
upsampled_embeddings.size()

---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([3, 2560])
Upsampled embedding size:  torch.Size([13, 2560])
---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([2, 2560])
Upsampled embedding size:  torch.Size([12, 2560])
---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([3, 2560])
Upsampled embedding size:  torch.Size([13, 2560])
---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([1, 2560])
Upsampled embedding size:  torch.Size([6, 2560])
---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([3, 2560])
Upsampled embedding size:  torch.Size([13, 2560])
---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([2, 2560])
Upsampled embedding size:  torch.Size([7, 2560])
---
Embedding size:  torch.Size([7, 2560])
Masked embedding size:  torch.Size([3, 2560])
Upsampled embedding size:  torch.Size([8, 2560])
---
Embedding size:  torch.Siz

torch.Size([10, 13, 2560])

In [16]:
# Last feature at the final position of every upsampled sequence. Sequences shorter
# than the longest were filled by `pad_sequence`, so they show PADDING_VALUE (-100) here;
# only full-length sequences show a real embedding value.
upsampled_embeddings[:, -1, -1]

tensor([-2.2402e-01, -1.0000e+02, -3.7545e-02, -1.0000e+02, -3.3865e-01,
        -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02, -1.0000e+02])