In [1]:
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
import os
import pandas as pd

In [2]:
def get_cls_embedding(outputs):
    """Return the last-layer vector at token position 0 (treated here as the CLS token).

    Args:
        outputs: model output exposing ``last_hidden_state`` indexed as
            (batch, seq_len, hidden_dim).

    Returns:
        NumPy array of shape (hidden_dim,) when batch == 1; otherwise the leading
        batch dimension is kept, giving (batch, hidden_dim).
    """
    hidden = outputs.last_hidden_state
    first_token = hidden[:, 0]
    # squeeze(0) only removes the batch axis when it has size 1.
    return first_token.squeeze(0).cpu().numpy()

def get_average_embedding(outputs, attention_mask):
    """Mean-pool the last hidden layer over the tokens selected by ``attention_mask``.

    Args:
        outputs: model output exposing ``last_hidden_state`` indexed as
            (batch, seq_len, hidden_dim).
        attention_mask: (batch, seq_len) tensor; expected to be 1 for real tokens and
            0 for padding (as produced by the tokenizer). Masked positions contribute
            nothing to the sum or the count.

    Returns:
        NumPy array of shape (hidden_dim,) when batch == 1, else (batch, hidden_dim).
        A row whose mask is all zeros yields a zero vector instead of NaN.
    """
    hidden = outputs.last_hidden_state
    # Cast the (typically integer) mask to the hidden-state dtype so the product
    # and the count stay in floating point.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    # Clamp to the smallest positive float so an all-padding row gives 0 / tiny = 0
    # rather than 0 / 0 = NaN; valid rows (count >= 1) are unaffected.
    count = mask.sum(dim=1).clamp(min=torch.finfo(mask.dtype).tiny)
    # detach() lets this work even when called outside torch.no_grad().
    return (summed / count).squeeze(0).detach().cpu().numpy()

def get_layerwise_embedding(hidden_states, attention_mask, last_n):
    """Average the last ``last_n`` hidden layers, then mean-pool over unmasked tokens.

    Args:
        hidden_states: sequence of per-layer tensors, each indexed as
            (batch, seq_len, hidden_dim). For Hugging Face models with
            ``output_hidden_states=True`` this presumably holds layer_count + 1
            entries (embedding output first) — TODO confirm per model.
        attention_mask: (batch, seq_len) tensor; expected to be 1 for real tokens
            and 0 for padding.
        last_n: how many of the final layers to average; must be in
            [1, len(hidden_states)].

    Returns:
        NumPy array of shape (hidden_dim,) when batch == 1, else (batch, hidden_dim).
        A row whose mask is all zeros yields a zero vector instead of NaN.

    Raises:
        ValueError: if ``last_n`` is out of range. Without this check, ``last_n=0``
            would slice ``hidden_states[-0:]``, i.e. every layer.
    """
    n_available = len(hidden_states)
    if not 1 <= last_n <= n_available:
        raise ValueError(f"last_n must be in [1, {n_available}], got {last_n}")

    # (n_layers, batch, seq_len, hidden_dim) -> mean over layers -> (batch, seq_len, hidden_dim)
    layers_mean = torch.stack(hidden_states[-last_n:], dim=0).mean(dim=0)

    mask = attention_mask.unsqueeze(-1).to(layers_mean.dtype)
    summed = (layers_mean * mask).sum(dim=1)
    # Clamp so an all-padding row gives zeros instead of 0 / 0 = NaN.
    count = mask.sum(dim=1).clamp(min=torch.finfo(mask.dtype).tiny)
    # detach() lets this work even when called outside torch.no_grad().
    return (summed / count).squeeze(0).detach().cpu().numpy()

In [None]:
def _embed_sentences(model, tokenizer, sentences, device, batch_size, last_n, desc):
    """Run ``model`` over ``sentences`` and return (cls, average, layerwise) arrays, one row per sentence in input order."""
    cls_rows, average_rows, layerwise_rows = [], [], []
    with tqdm(total=len(sentences), desc=desc) as progress, torch.no_grad():
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            # NOTE(review): CLS pooling reads position 0, which assumes the tokenizer
            # pads on the right — TODO confirm if other model families are added.
            inputs = tokenizer(batch, return_tensors="pt", truncation=True, padding=True).to(device)
            outputs = model(**inputs)
            attention_mask = inputs["attention_mask"]

            # The pooling helpers squeeze away a size-1 batch axis; atleast_2d restores
            # it so every batch contributes exactly one row per sentence.
            cls_rows.extend(np.atleast_2d(get_cls_embedding(outputs)))
            average_rows.extend(np.atleast_2d(get_average_embedding(outputs, attention_mask)))
            layerwise_rows.extend(np.atleast_2d(
                get_layerwise_embedding(outputs.hidden_states, attention_mask, last_n=last_n)
            ))
            progress.update(len(batch))
    return np.array(cls_rows), np.array(average_rows), np.array(layerwise_rows)


def embeddings_extraction(models, sentences, labels, batch_size=1, last_n=4):
    """Extract CLS, mean-pooled and last-layers-averaged sentence embeddings for each model.

    For every ``model_name -> out_dir`` pair in ``models``, the tokenizer and model are
    loaded with ``from_pretrained``, every sentence is embedded, and three files are
    written to ``out_dir``: ``cls.npz``, ``average.npz`` and ``layerwise.npz``. Each file
    holds an ``embeddings`` array (one row per sentence, in input order) and the
    matching ``labels`` array.

    Args:
        models: mapping from a ``from_pretrained`` identifier or local path to the
            directory the embeddings are written to (created if missing).
        sentences: sequence of sentence strings.
        labels: one label per sentence; saved next to every embedding array.
        batch_size: sentences per forward pass. The default of 1 reproduces the
            original one-sentence-at-a-time behaviour. Larger values are faster, and
            padding is excluded from pooling by the attention mask, so results should
            match up to floating-point noise — verify before relying on it.
        last_n: number of final hidden layers averaged for ``layerwise.npz``.

    Raises:
        ValueError: if ``sentences`` and ``labels`` differ in length, or
            ``batch_size`` < 1.
    """
    sentences = list(sentences)
    if len(sentences) != len(labels):
        raise ValueError(
            f"Got {len(sentences)} sentences but {len(labels)} labels; they must align one-to-one."
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    labels_array = np.array(labels)

    for model_name, out_dir in models.items():
        os.makedirs(out_dir, exist_ok=True)

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Only last_hidden_state / hidden_states are used, so a "pooler weights newly
        # initialized" warning at load time does not affect these embeddings.
        model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
        model.eval()
        model.to(device)

        cls_emb, average_emb, layerwise_emb = _embed_sentences(
            model, tokenizer, sentences, device, batch_size, last_n, desc=out_dir
        )
        for file_name, embeddings in (
            ("cls.npz", cls_emb),
            ("average.npz", average_emb),
            ("layerwise.npz", layerwise_emb),
        ):
            np.savez(os.path.join(out_dir, file_name), embeddings=embeddings, labels=labels_array)

        # Release this model before loading the next one so its memory can be reclaimed.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

### Dataset

In [4]:
dataset_path = "../1-dataset/VUAMC_sentences_labeled.csv"
df = pd.read_csv(dataset_path, encoding="utf-8")
# Fail fast: a missing sentence would otherwise only crash inside the tokenizer,
# minutes into the extraction run, and a missing label breaks astype(int).
n_missing = int(df[["sentence", "label"]].isna().any(axis=1).sum())
assert n_missing == 0, f"{n_missing} rows have a missing sentence or label"
sentences = df["sentence"].tolist()
labels = df["label"].astype(int).tolist()

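### Pre-trained models

Extract embeddings from the off-the-shelf `bert-base-uncased` and `roberta-base` checkpoints. Each model's files are written to its own output directory.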

In [5]:
# Off-the-shelf checkpoint -> directory that receives its embedding files.
models = dict(
    [
        ("bert-base-uncased", "bert"),
        ("roberta-base", "roberta"),
    ]
)
embeddings_extraction(models=models, sentences=sentences, labels=labels)

bert-base-uncased: 100%|██████████| 16202/16202 [09:53<00:00, 27.32it/s]
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
roberta-base: 100%|██████████| 16202/16202 [10:52<00:00, 24.84it/s]


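### Fine-tuned models

Run this only after the fine-tuned checkpoints exist under `../5-fine_tuning/`. It extracts the same three embedding types from them.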

In [6]:
# Local fine-tuned checkpoints paired with their output directories (same order as above).
models = dict(zip(
    ["../5-fine_tuning/bert-base-uncased_ft", "../5-fine_tuning/roberta-base_ft"],
    ["bert_ft", "roberta_ft"],
))
embeddings_extraction(models=models, sentences=sentences, labels=labels)

../5-fine_tuning/bert-base-uncased_ft: 100%|██████████| 16202/16202 [11:42<00:00, 23.05it/s]
Some weights of RobertaModel were not initialized from the model checkpoint at ../5-fine_tuning/roberta-base_ft and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
../5-fine_tuning/roberta-base_ft: 100%|██████████| 16202/16202 [12:03<00:00, 22.38it/s]
