# Use Case 1: Embeddings extraction from the BERT model for topic classification on Ag News

In [1]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import BertTokenizer,BertForSequenceClassification, BertConfig
from transformers.pipelines import pipeline
import os
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
from torch import nn
import h5py

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
data_dir = "static/data/bert"

df_train = pd.read_csv(os.path.join(data_dir,"df_train_0_1_2.csv"))
df_test = pd.read_csv(os.path.join(data_dir,"df_test_0_1_2.csv"))
df_new_unseen = pd.read_csv(os.path.join(data_dir,"df_new_unseen_0_1_2.csv"))
df_drifted = pd.read_csv(os.path.join(data_dir,"df_drifted_3.csv"))

In [4]:
MODEL_DIR = "static/saved_models/bert/best_model"
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"
BERT_MODEL = 'bert-base-uncased' # BERT model type

config = BertConfig.from_pretrained(os.path.join(MODEL_DIR, CONFIG_NAME), output_hidden_states=True)
model = BertForSequenceClassification.from_pretrained(os.path.join(MODEL_DIR), config=config)
model = model.to(device)
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)

In [5]:
tokenizer_kwargs = {"padding":"max_length", "truncation":True}

In [6]:
train_id2label = ["World", "Sports", "Business", "Sci/Tech"]

In [7]:
def extract_embedding_and_predict(model, tokenizer, df, layer_id):
    
    X = df["text"].tolist() # List of input texts
    Y_original = df["label"].tolist() # List of original labels (GT)
    Y_original_names = [train_id2label[l] for l in Y_original]  # List of original labels' names (GT)
    E = np.empty((0,768)) # Initialize empty array of embeddings
    Y_predicted = [] # Initialize empty list of predicted labels (IDs)
    Y_predicted_names = [] # Initialize empty list of predicted labels (Names)
    
    
    BATCH_SIZE = 256
    n_batch = len(df)//BATCH_SIZE
    remainer = len(df)%BATCH_SIZE
    
    for i in tqdm(range(n_batch)):
        input_texts = df["text"].iloc[i*BATCH_SIZE:i*BATCH_SIZE+BATCH_SIZE].tolist()
        
        tokenized_texts = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
        
        with torch.no_grad():
            outputs = model(**tokenized_texts.to(device))
            
        batch_probabilities = nn.functional.softmax(outputs["logits"], dim=-1)
        batch_labels = torch.argmax(batch_probabilities, dim=1).tolist()

        batch_probabilities_list = batch_probabilities.tolist()            
        batch_labels_name = [train_id2label[l] for l in batch_labels] 

        Y_predicted.extend(batch_labels)
        Y_predicted_names.extend(batch_labels_name)

        last_layer_hidden_states_arr = outputs["hidden_states"][layer_id].detach().cpu().numpy()                   
        embedding_CLS_arr = last_layer_hidden_states_arr[:, 0, :] # [BATCH_SIZE, 0 = CLS, 768]
        E = np.vstack([E, embedding_CLS_arr])
            
           
    if remainer>0:

        input_texts = df["text"].iloc[-remainer:].tolist()

        tokenized_texts = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**tokenized_texts.to(device))

        batch_probabilities = nn.functional.softmax(outputs["logits"], dim=-1)
        batch_labels = torch.argmax(batch_probabilities, dim=1).tolist()

        batch_probabilities_list = batch_probabilities.tolist()            
        batch_labels_name = [train_id2label[l] for l in batch_labels] 

        Y_predicted.extend(batch_labels)
        Y_predicted_names.extend(batch_labels_name)

        last_layer_hidden_states_arr = outputs["hidden_states"][layer_id].detach().cpu().numpy()                   
        embedding_CLS_arr = last_layer_hidden_states_arr[:, 0, :] # [BATCH_SIZE, 0 = CLS, 768]
        E = np.vstack([E, embedding_CLS_arr])

        return X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names

In [8]:
def save_embedding(output_path, X, E, Y_original, Y_original_names, Y_predicted, Y_predicted_names):

    fp = h5py.File(output_path, "w")

    fp.create_dataset("X", data=X, compression="gzip")
    fp.create_dataset("E", data=E, compression="gzip")
    fp.create_dataset("Y_original", data=Y_original, compression="gzip")
    fp.create_dataset("Y_original_names", data=Y_original_names, compression="gzip")
    fp.create_dataset("Y_predicted", data=Y_predicted, compression="gzip")
    fp.create_dataset("Y_predicted_names", data=Y_predicted_names, compression="gzip")
    fp.close()
    return

In [9]:
embedding_dir = os.path.join("static", "saved_embeddings", "bert")

for layer_id in range(1,12):
    print("layer: ",layer_id)

    X_test, E_test, Y_original_test, Y_original_names_test, Y_predicted_test, Y_predicted_names_test = extract_embedding_and_predict(model, tokenizer, df_test, layer_id)

    X_train, E_train, Y_original_train, Y_original_names_train, Y_predicted_train, Y_predicted_names_train = extract_embedding_and_predict(model, tokenizer, df_train, layer_id)

    X_drift, E_drift, Y_original_drift, Y_original_names_drift, Y_predicted_drift, Y_predicted_names_drift = extract_embedding_and_predict(model, tokenizer, df_drifted, layer_id)

    X_new_unseen, E_new_unseen, Y_original_new_unseen, Y_original_names_new_unseen, Y_predicted_new_unseen, Y_predicted_names_new_unseen = extract_embedding_and_predict(model, tokenizer, df_new_unseen, layer_id)

    save_embedding(os.path.join(embedding_dir, f"train_embedding_0_1_2_layer_{layer_id}.hdf5"), 
                    X_train, 
                    E_train, 
                    Y_original_train, 
                    Y_original_names_train, 
                    Y_predicted_train, 
                    Y_predicted_names_train)

    save_embedding(os.path.join(embedding_dir, f"test_embedding_0_1_2_layer_{layer_id}.hdf5"), 
                    X_test, 
                    E_test, 
                    Y_original_test, 
                    Y_original_names_test, 
                    Y_predicted_test, 
                    Y_predicted_names_test)

    save_embedding(os.path.join(embedding_dir, f"drifted_embedding_3_layer_{layer_id}.hdf5"), 
                    X_drift, 
                    E_drift, 
                    Y_original_drift, 
                    Y_original_names_drift, 
                    Y_predicted_drift, 
                    Y_predicted_names_drift)

    save_embedding(os.path.join(embedding_dir, f"new_unseen_embedding_0_1_2_layer_{layer_id}.hdf5"), 
                    X_new_unseen, 
                    E_new_unseen, 
                    Y_original_new_unseen, 
                    Y_original_names_new_unseen, 
                    Y_predicted_new_unseen, 
                    Y_predicted_names_new_unseen)

layer:  1


100%|██████████| 22/22 [00:17<00:00,  1.26it/s]
100%|██████████| 232/232 [03:31<00:00,  1.10it/s]
100%|██████████| 124/124 [02:13<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  2


100%|██████████| 22/22 [00:17<00:00,  1.27it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:12<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  3


100%|██████████| 22/22 [00:17<00:00,  1.28it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:13<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  4


100%|██████████| 22/22 [00:17<00:00,  1.27it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:12<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  5


100%|██████████| 22/22 [00:17<00:00,  1.28it/s]
100%|██████████| 232/232 [03:31<00:00,  1.09it/s]
100%|██████████| 124/124 [02:13<00:00,  1.08s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  6


100%|██████████| 22/22 [00:17<00:00,  1.27it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:13<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  7


100%|██████████| 22/22 [00:17<00:00,  1.28it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:13<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  8


100%|██████████| 22/22 [00:17<00:00,  1.27it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:12<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  9


100%|██████████| 22/22 [00:17<00:00,  1.28it/s]
100%|██████████| 232/232 [03:32<00:00,  1.09it/s]
100%|██████████| 124/124 [02:12<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  10


100%|██████████| 22/22 [00:17<00:00,  1.27it/s]
100%|██████████| 232/232 [03:31<00:00,  1.10it/s]
100%|██████████| 124/124 [02:12<00:00,  1.07s/it]
100%|██████████| 119/119 [01:43<00:00,  1.15it/s]


layer:  11


100%|██████████| 22/22 [00:17<00:00,  1.28it/s]
100%|██████████| 232/232 [03:31<00:00,  1.10it/s]
100%|██████████| 124/124 [02:12<00:00,  1.07s/it]
100%|██████████| 119/119 [01:44<00:00,  1.14it/s]
