In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import BertForSequenceClassification, DNATokenizer

from utils.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [28]:
def feed_tokenizer(s, tokenizer):
    """Tokenize a single sequence string with ``tokenizer.encode``.

    Parameters
    ----------
    s : str
        Sequence text, handed to the tokenizer as-is.
    tokenizer :
        Object exposing an ``encode`` method (in this notebook a
        ``DNATokenizer``).

    Returns
    -------
    Whatever ``tokenizer.encode`` produces for the options below; with
    ``return_tensors='pt'`` this is expected to be a PyTorch tensor of token
    ids.
    """
    encode_options = {
        "add_special_tokens": True,     # add '[CLS]' and '[SEP]'
        "padding": "longest",           # pad to the longest input (no effect for a single sequence)
        "truncation": True,             # cut inputs longer than max_length
        "max_length": 512,              # presumably the model's maximum input length -- confirm
        # NOTE(review): `encode` presumably returns token ids only, so this mask
        # is likely computed and discarded -- confirm for this transformers fork.
        "return_attention_mask": True,
        "return_tensors": "pt",         # return PyTorch tensors
    }
    return tokenizer.encode(s, **encode_options)

def obtain_embeddings(s, model, tokenizer):
    """Run ``model`` on one sequence and return its stacked hidden states.

    Parameters
    ----------
    s : str
        Input sequence, tokenized via ``feed_tokenizer``.
    model : callable
        Model whose output element ``[1]`` is a sequence of per-layer
        hidden-state tensors (the notebook loads it with
        ``output_hidden_states=True``).
    tokenizer :
        Tokenizer forwarded to ``feed_tokenizer``.

    Returns
    -------
    numpy.ndarray
        The hidden states stacked along a new leading axis, with the batch
        axis (axis 1 after stacking) removed. Assuming each hidden state is
        (1, seq_len, hidden), the result is (num_layers, seq_len, hidden).
    """
    input_ids = feed_tokenizer(s, tokenizer)

    # Inference only: do not build an autograd graph we would immediately
    # throw away (saves memory and time on every call).
    with torch.no_grad():
        model_output = model(input_ids)
    hidden_states = model_output[1]

    # Stack to (num_layers, batch, seq_len, hidden) and drop only the batch
    # axis; an argument-less squeeze() could also collapse the layer axis
    # (single-layer models) and break the 3-D contract relied on downstream.
    token_embeddings = torch.stack(hidden_states, dim=0).squeeze(1)

    # .cpu() makes this work for models placed on a GPU as well.
    return token_embeddings.detach().cpu().numpy()

def obtain_sentence_representations(embeddings, kind):
    """Pool each layer's token embeddings into one fixed-size vector.

    Parameters
    ----------
    embeddings : numpy.ndarray
        3-D array. The first axis is iterated as layers and pooling runs over
        axis 0 of each layer slice -- the token axis, assuming the
        (num_layers, seq_len, hidden) layout returned by ``obtain_embeddings``.
    kind : {"mean", "sum"}
        Pooling applied over the tokens of each layer.

    Returns
    -------
    list of numpy.ndarray
        One pooled vector per layer, in layer order.

    Raises
    ------
    AssertionError
        If ``embeddings`` is not 3-D.
    ValueError
        If ``kind`` is not a supported pooling method.
    """
    assert len(embeddings.shape) == 3, (
        f"expected a 3-D array, got shape {tuple(embeddings.shape)}"
    )

    # Resolve the pooling function once instead of re-checking per layer.
    if kind == "mean":
        pool = np.mean
    elif kind == "sum":
        pool = np.sum
    else:
        raise ValueError(
            f"unknown pooling kind {kind!r}; expected 'mean' or 'sum'"
        )

    return [pool(layer_emb, axis=0) for layer_emb in embeddings]

In [2]:
# Checkpoint directory holding the pretrained DNABERT tokenizer and weights
# (the "6" presumably denotes the 6-mer variant -- confirm).
model_path = "dnabert6/"

# Tokenizer and sequence-classification model come from the same checkpoint.
# Hidden states and attention weights are requested so the forward pass also
# exposes per-layer outputs, which the embedding helpers read.
tokenizer = DNATokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(
    model_path,
    output_hidden_states=True,
    output_attentions=True,
)

<class 'transformers.tokenization_dna.DNATokenizer'>


In [6]:
# Load the "same_sequence" split: a ';'-separated metadata table and the
# tab-separated table whose `sequence` column is embedded below.
metadata = pd.read_csv("dataset/same_sequence/all_data_0.csv", delimiter=";")
data = pd.read_csv("dataset/same_sequence/data_0.tsv", delimiter="\t")

In [35]:
# Which entry of the per-layer representation list to keep for each sequence
# (-1 selects the last one).
layer = -1
sequences = data["sequence"].values

# R[i] holds the mean-pooled vector of sequences[i] taken at `layer`.
R = []
for s in tqdm(sequences):
    emb = obtain_embeddings(s, model, tokenizer)
    representations = obtain_sentence_representations(emb, kind="mean")
    R.append(representations[layer])

100%|██████████| 200/200 [02:56<00:00,  1.13it/s]
