In [None]:
!pip install fair-esm transformers  # latest release

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
def tensor2string(tensor, decimal_places=4):
  """Serialize a tensor as one space-separated string of fixed-point numbers.

  The tensor is flattened first, so its original shape cannot be recovered
  from the returned string.

  Args:
    tensor: A torch.Tensor of any shape, contiguous or not.
    decimal_places: Number of digits written after the decimal point.

  Returns:
    A string such as "0.1235 -1.0000"; the empty string when the tensor has
    no elements.
  """
  # reshape() instead of view(): view(-1) raises on non-contiguous tensors
  # (for example the result of a transpose), reshape() copies when it must.
  values = tensor.reshape(-1).tolist()

  # Fixed-point formatting keeps every number at the same precision.
  return ' '.join("{:.{}f}".format(value, decimal_places) for value in values)


def convert_dict_tensors_to_strings(original_dict):
    """Return a new dict with the same keys and each value run through tensor2string.

    Every value of ``original_dict`` is converted with ``tensor2string`` at its
    default precision; the input dict itself is left unmodified.
    """
    return {name: tensor2string(tensor) for name, tensor in original_dict.items()}

In [None]:
import torch
from transformers import AutoTokenizer, EsmModel

# Load the pretrained tokenizer and model from the same checkpoint.
MODEL_NAME = "facebook/esm2_t36_3B_UR50D"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = EsmModel.from_pretrained(MODEL_NAME)

# Use the GPU when one is available; otherwise stay on the CPU instead of
# crashing on a machine without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Inference only: make sure the model is in evaluation mode.
model.eval()

Downloading (…)okenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/779 [00:00<?, ?B/s]

Downloading (…)model.bin.index.json:   0%|          | 0.00/55.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/1.39G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of the model checkpoint at facebook/esm2_t36_3B_UR50D were not used when initializing EsmModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t36_3B_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# NOTE(review): `sequence_output` is only assigned inside the embedding loop in a
# later cell, so this check fails on a fresh kernel unless that loop has run first.
sequence_output.shape

torch.Size([1, 24, 2560])

In [None]:
# NOTE(review): `sequence` is only assigned inside the embedding loop in a later
# cell, so this check fails on a fresh kernel unless that loop has run first.
len(sequence)

24

In [None]:
# NOTE(review): `embedding` is only assigned inside the embedding loop in a later
# cell, so this check fails on a fresh kernel unless that loop has run first.
embedding.shape

torch.Size([1, 2560])

In [None]:
import json
# Load the JSON input file. The encoding is given explicitly so the result
# does not depend on the platform's default text encoding.
with open('response_1685411308040.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
from tqdm import tqdm

In [None]:
# Build one mean-pooled embedding string per record, keyed by primary accession.
embeddings = {}
for record in tqdm(data):
  sequence = record['sequence']
  # Tokenize the sequence and return PyTorch tensors, moved to whichever
  # device the model lives on (no hard-coded 'cuda').
  inputs = tokenizer(sequence, return_tensors="pt").to(model.device)

  # The first element of the model output is the per-token sequence output,
  # a tensor of shape (batch_size, sequence_length, hidden_size).
  with torch.no_grad():
      sequence_output = model(**inputs)[0]
      # Drop the first and last positions before pooling (presumably the
      # special tokens the tokenizer adds around the sequence).
      sequence_output = sequence_output[:, 1:-1, :]
  # Average over the sequence_length dimension to get a single vector per sequence.
  embedding = sequence_output.mean(dim=1)
  # Serialized with tensor2string's default precision (4 decimal places).
  str_embed = tensor2string(embedding)
  embeddings[record['primary_accession']] = str_embed


100%|██████████| 964/964 [16:35<00:00,  1.03s/it]


In [None]:
# Destination file for the serialized embeddings
filename = "embeddings.json"

# Encode the whole dictionary as JSON text and write it out in one go
with open(filename, mode='w') as out_file:
    out_file.write(json.dumps(embeddings))