##### Author: Merritt Khaipho-Burch
##### Contact: mbb262@cornell.edu
##### Date: 2023-06-05
##### Updated: 2023-06-07

##### Description:
- Set up the Nucleotide Transformer model from Hugging Face
- Load the formatted TE data
- Run the model and save the mean sequence embeddings

In [None]:
import gc

import numpy as np
import numpy as npa  # NOTE(review): original alias, looks like a typo for `np`; kept in case other cells use it
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

In [None]:
# Choose the compute device: the first CUDA GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the pretrained tokenizer and masked-LM model, then move the model to the device
tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-2.5b-multi-species")
model = AutoModelForMaskedLM.from_pretrained(
    "InstaDeepAI/nucleotide-transformer-2.5b-multi-species"
).to(device)

In [None]:
# Read the tab-separated sequence/expression table into a DataFrame.
# NOTE(review): the original comment listed the columns as
# ['gene', 'avg_fpkm', 'class', 'upSeq', 'upSeq2k'], but a later cell reads a
# 'teSeq' column -- confirm the actual header of this file.
trainSeq = pd.read_csv('/workdir/mbb262/te/te_sequence_with_walley_expression.txt', sep='\t')
# Show (rows, columns) as the cell output
trainSeq.shape

In [None]:
# Look at the top of the dataset (first five rows).
# `head` must be called: the bare attribute only displays the bound method,
# which drags the repr of the whole frame along with it.
trainSeq.head()

In [None]:
# Select the column holding the sequences that will be fed to the model
inputSeq = trainSeq.loc[:, 'teSeq']

In [None]:
# Run model: embed every input sequence in fixed-size batches and mean-pool the
# last hidden layer over the non-padding tokens of each sequence.
#
# The batch size doubles as the loop stride. The earlier version stepped by 30
# while slicing only 10 sequences, which silently skipped 20 of every 30 rows
# and left the output misaligned with the input table.
batch_size = 10
num_sequences = len(inputSeq)
batch_embeddings = []  # one 2-D array per batch; concatenated once after the loop
for i in range(0, num_sequences, batch_size):
    # Materialise the slice as a plain list before handing it to the tokenizer
    batch_seqs = list(inputSeq[i:i + batch_size])
    tokens_ids = tokenizer.batch_encode_plus(batch_seqs, return_tensors="pt",
                                             padding=True, truncation=True)["input_ids"]
    tokens_ids = tokens_ids.to(device)
    # True where a position holds a real token, False where it is padding
    attention_mask = tokens_ids != tokenizer.pad_token_id
    with torch.no_grad():
        torch_outs = model(
            tokens_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=attention_mask,
            output_hidden_states=True)
    # Hidden states of the final layer, moved to host memory
    embeddings = torch_outs['hidden_states'][-1].detach().cpu().numpy()
    # Add a trailing axis so the mask broadcasts across the embedding dimension
    token_mask = np.expand_dims(attention_mask.cpu().numpy(), axis=-1)
    masked_embeddings = embeddings * token_mask  # zero out pad-token embeddings
    sequences_lengths = np.sum(token_mask, axis=1)  # non-pad token count per sequence
    mean_embeddings = np.sum(masked_embeddings, axis=1) / sequences_lengths
    batch_embeddings.append(mean_embeddings)
    # Progress: number of sequences processed so far out of the total
    print(f"{min(i + batch_size, num_sequences)}/{num_sequences}", end="\r")

# Concatenate once instead of np.append per batch, which recopies the whole
# result array on every iteration.
if batch_embeddings:
    resEmbeddings = np.concatenate(batch_embeddings, axis=0)
else:
    # No input sequences: keep an empty 2-D result with the model's hidden width
    resEmbeddings = np.empty((0, model.config.hidden_size))

In [None]:
# Write the embedding matrix to disk as tab-delimited text, one array row per line
np.savetxt('/workdir/mbb262/Pg_X_test.txt', resEmbeddings, delimiter='\t')