### Author.... Merritt Khaipho-Burch
### Contact... mbb262@cornell.edu
### Date...... 2023-06-05
### Updated... 2023-06-05


### Description:
Set up the Nucleotide Transformer model from Hugging Face and test it on a dummy DNA sequence.

In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
#import transformers

In [2]:
# Load the pretrained tokenizer and masked-LM model.
# Both come from the same Hugging Face checkpoint, so the name is written once.
MODEL_CHECKPOINT = "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(MODEL_CHECKPOINT)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Create a dummy DNA sequence ('ATTCTG' repeated 9 times) and tokenize it.
# padding=True pads each sequence to the longest one in the batch. With the single
# sequence used here it changes nothing, but it lets `sequences` hold inputs of
# different lengths, which return_tensors="pt" cannot stack into one tensor otherwise.
sequences = ['ATTCTG' * 9]
tokens_ids = tokenizer.batch_encode_plus(
    sequences,
    return_tensors="pt",
    padding=True,
)["input_ids"]

In [4]:
# Compute the embeddings
# Mask out padding positions by comparing every token id against the pad token id.
attention_mask = tokens_ids != tokenizer.pad_token_id

# Inference only: no_grad() stops autograd from recording the forward pass, so
# intermediate activations are not kept around for a backward pass that never runs.
with torch.no_grad():
    torch_outs = model(
        tokens_ids,
        attention_mask=attention_mask,
        encoder_attention_mask=attention_mask,
        output_hidden_states=True
    )

# Compute sequences embeddings: take the last entry of the returned hidden states
# and convert it to a NumPy array (one embedding vector per token).
embeddings = torch_outs['hidden_states'][-1].detach().numpy()
print(f"Embeddings shape: {embeddings.shape}")
print(f"Embeddings per token: {embeddings}")

Embeddings shape: (1, 10, 2560)
Embeddings per token: [[[-0.04825445  0.2141174   0.07392832 ...  0.05151255  0.01025415
    0.21558578]
  [-0.3665796   0.19614863  0.1792731  ...  0.42881376  0.10353133
   -0.07980135]
  [ 0.3379805   0.16960776 -0.03771794 ...  0.7363802  -0.5191736
    0.2142746 ]
  ...
  [ 0.309885    0.47271204 -0.08401428 ...  0.72108525 -0.69447416
    0.08498987]
  [ 0.19159357  0.5292669  -0.13266225 ...  0.7349562  -0.63323426
    0.02344907]
  [ 0.20589516  0.31511945 -0.19496806 ...  0.6304218  -0.7212008
    0.01103678]]]


In [5]:
# Compute mean embeddings per sequence, averaging over non-padding tokens only.
# as_tensor accepts `embeddings` whether it is a NumPy array or already a tensor,
# so the arithmetic below stays in torch instead of relying on implicit
# tensor * ndarray conversion.
token_embeddings = torch.as_tensor(embeddings)
token_mask = attention_mask.unsqueeze(-1)  # trailing axis added so the mask broadcasts over the hidden dimension

# keepdim=True keeps the per-sequence token count as (batch, 1) so it broadcasts
# against the (batch, hidden) sums. A flat (batch,) count only divides correctly
# when the batch holds a single sequence.
tokens_per_sequence = torch.sum(attention_mask, dim=-1, keepdim=True)
mean_sequence_embeddings = torch.sum(token_mask * token_embeddings, dim=-2) / tokens_per_sequence
print(f"Mean sequence embeddings: {mean_sequence_embeddings}")

Mean sequence embeddings: tensor([[ 0.1267,  0.4275, -0.0254,  ...,  0.5978, -0.4886,  0.0630]])
