## Hugging Face and NTv2 Exploration

In [None]:
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM
import torch

In [None]:
# Import the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-500m-multi-species", trust_remote_code=True)

In [51]:
# length to which the input sequences are padded
max_length = tokenizer.model_max_length     # 2048, hence 12kbp context window

# Create a dummy dna sequence and tokenize it (6-mers if multiple of 6)
sequences = ["ATTCCGATTCCGATTCCG", "ATTTCTCTCTCTCTCTGAGATCGATCGATCGAT"]
tokens_ids = tokenizer.batch_encode_plus(sequences, return_tensors="pt", padding="max_length", max_length = 11)["input_ids"]
print("Tokens shape:", tokens_ids.shape, "\n")


print("Tokens IDs:")
print(tokens_ids[:,:11])

print("Decoding back to sequences:")
decoded_sequences = tokenizer.batch_decode(tokens_ids, skip_special_tokens=False)
print(decoded_sequences)

Tokens shape: torch.Size([2, 11]) 

Tokens IDs:
tensor([[   3,  369,  369,  369,    1,    1,    1,    1,    1,    1,    1],
        [   3,  351, 2463, 2466, 3186, 1740, 4105, 4102, 4103,    1,    1]])
Decoding back to sequences:
['<cls> ATTCCG ATTCCG ATTCCG <pad> <pad> <pad> <pad> <pad> <pad> <pad>', '<cls> ATTTCT CTCTCT CTCTGA GATCGA TCGATC G A T <pad> <pad>']


In [52]:
# Compute the embeddings
attention_mask = tokens_ids != tokenizer.pad_token_id
print("Attention Mask", attention_mask)


torch_outs = model(
    tokens_ids,
    attention_mask=attention_mask,  # prevents attention to padding tokens
    encoder_attention_mask=attention_mask,
    output_hidden_states=True       # to get all layer embeddings
)

# Compute sequences embeddings
embeddings = torch_outs['hidden_states'][-1].detach().numpy()   # obtains final transformer layer embeddings
print(f"Embeddings shape: {embeddings.shape}")
print(f"Embeddings per token: {embeddings}")

# Add embed dimension axis
attention_mask = torch.unsqueeze(attention_mask, dim=-1)

# Compute mean embeddings per sequence
mean_sequence_embeddings = torch.sum(attention_mask*embeddings, axis=-2)/torch.sum(attention_mask, axis=1)
print(f"Mean sequence embeddings: {mean_sequence_embeddings}")

Attention Mask tensor([[ True,  True,  True,  True, False, False, False, False, False, False,
         False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True, False,
         False]])
Embeddings shape: (2, 11, 1024)
Embeddings per token: [[[ 0.50637925  0.14925307  0.526706   ... -0.6013561   0.76712173
   -0.24135263]
  [ 0.30739102  0.01497138  0.31069276 ... -0.8692496   0.7053464
   -0.04934862]
  [ 0.27530244  0.2383724   0.16276649 ... -0.95568717  0.6267945
   -0.07545093]
  ...
  [ 0.9203838  -0.04136808  0.08182072 ... -0.53490424  0.2571789
   -0.13717528]
  [ 0.9504963   0.02858299  0.07393027 ... -0.49353155  0.380477
   -0.21202193]
  [ 1.0615091  -0.08195896  0.142325   ... -0.51462924  0.4050352
   -0.29604894]]

 [[-0.02620885 -0.50777876  0.3580304  ... -0.23096171  0.7863488
   -1.0120119 ]
  [ 0.49354693 -0.59404796  0.24656996 ... -0.3358443   0.06688058
    0.2506485 ]
  [ 0.0174152   0.091777    0.43233433 ... -0.20136136  0.33414203
    0.

  mean_sequence_embeddings = torch.sum(attention_mask*embeddings, axis=-2)/torch.sum(attention_mask, axis=1)


In [63]:
print(torch_outs.logits)

tensor([[[-11.5742, -11.5761, -11.6334,  ...,  -2.2619,  -2.0177,   1.7361],
         [-10.6815, -10.7716, -10.8121,  ...,  -2.3263,   0.1234,   2.5711],
         [-12.5069, -12.7084, -12.7827,  ...,  -2.9243,  -0.6411,   2.0720],
         ...,
         [-10.2567, -10.3443, -10.3278,  ...,  -2.2249,   1.2304,   2.3430],
         [-11.0999, -11.2542, -11.2927,  ...,  -2.9625,   0.4279,   2.0422],
         [-11.8407, -11.9034, -11.9517,  ...,  -2.9084,  -1.1041,   1.1343]],

        [[-11.1233, -11.0699, -11.0851,  ...,   1.6103,   1.5689,   6.7156],
         [-10.8994, -10.9183, -10.7269,  ...,   2.1727,   3.9271,   7.4720],
         [ -9.1469,  -9.1537,  -9.0519,  ...,   3.9879,   3.3707,   6.8655],
         ...,
         [ -1.9141,  -1.9976,  -1.9876,  ...,   9.0548,  10.5053,  12.1279],
         [ -6.6433,  -6.5808,  -6.6273,  ...,  10.6087,  11.4174,  16.2603],
         [ -7.2303,  -7.0952,  -7.2115,  ...,   9.9537,  10.1086,  15.7961]]],
       grad_fn=<AddBackward0>)


In [65]:
print(torch_outs.logits.shape)

torch.Size([2, 11, 4107])


In [66]:
predictions = torch.nn.functional.softmax(torch_outs.logits, dim=-1)
print(predictions)

tensor([[[1.0409e-11, 1.0390e-11, 9.8108e-12,  ..., 1.1527e-07,
          1.4715e-07, 6.2809e-06],
         [8.8585e-13, 8.0956e-13, 7.7739e-13,  ..., 3.7671e-09,
          4.3638e-08, 5.0452e-07],
         [9.1224e-15, 7.4574e-15, 6.9235e-15,  ..., 1.3236e-10,
          1.2982e-09, 1.9571e-08],
         ...,
         [5.3362e-14, 4.8884e-14, 4.9699e-14,  ..., 1.6421e-10,
          5.2002e-09, 1.5820e-08],
         [3.4915e-14, 2.9923e-14, 2.8791e-14,  ..., 1.1941e-10,
          3.5437e-09, 1.7805e-08],
         [8.3346e-13, 7.8278e-13, 7.4586e-13,  ..., 6.3114e-09,
          3.8348e-08, 3.5962e-07]],

        [[3.6707e-11, 3.8722e-11, 3.8138e-11,  ..., 1.2442e-05,
          1.1938e-05, 2.0516e-03],
         [6.4080e-16, 6.2886e-16, 7.6150e-16,  ..., 3.0471e-10,
          1.7612e-09, 6.1000e-08],
         [1.0369e-15, 1.0299e-15, 1.1404e-15,  ..., 5.2497e-10,
          2.8319e-10, 9.3298e-09],
         ...,
         [5.4960e-08, 5.0556e-08, 5.1065e-08,  ..., 3.1899e-03,
          1.360