In [1]:
import re
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PreTrainedTokenizer,
    PreTrainedModel
)


%matplotlib inline

plt.style.use("seaborn-v0_8")



In [86]:
summaries = pd.read_csv("../data/raw/summaries_train.csv")
prompt = pd.read_csv("../data/raw/prompts_train.csv")

In [87]:
base_model = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModel.from_pretrained(base_model)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [88]:
def get_embeddings(texts: List, tokenizer: PreTrainedTokenizer, model: PreTrainedModel):
    encoded_token = tokenizer.batch_encode_plus(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        last_hidden_states = model(**encoded_token).last_hidden_state

    mask = torch.tensor(encoded_token["attention_mask"]).unsqueeze(-1)
    embeddings = (last_hidden_states * mask).sum(dim=(1)) / mask.sum(dim=1)
    return embeddings 

In [89]:
embeddings = get_embeddings(summaries["text"][:10].tolist(), tokenizer, model)
embeddings

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
  mask = torch.tensor(encoded_token["attention_mask"]).unsqueeze(-1)


tensor([[ 0.2360, -0.0677,  0.1726,  ..., -0.4067, -0.0027, -0.1088],
        [-0.0720, -0.0450,  0.3548,  ..., -0.4550, -0.1331, -0.3194],
        [-0.0431,  0.0910,  0.3680,  ..., -0.0131,  0.0382, -0.1018],
        ...,
        [ 0.0671,  0.0770,  0.1244,  ..., -0.3030, -0.0981, -0.1718],
        [-0.0301,  0.1331,  0.2862,  ...,  0.0149, -0.1966, -0.3421],
        [ 0.0782,  0.2284,  0.0388,  ..., -0.3886, -0.0444, -0.1196]])