# BERT Feature Extraction

By: Jimuel Celeste, Jr. 

Objective: Extract BERT features from text transcripts.

In [1]:
import os
import pandas as pd
import torch

from transformers import AutoTokenizer, AutoModel

## Sample: string

In [2]:
# Checkpoint identifier passed to both from_pretrained calls below.
model_name = 'google-bert/bert-base-multilingual-uncased'
# Load the tokenizer and the model from the same checkpoint so they match.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

In [3]:
# Single example string used to demonstrate the pipeline end to end.
text = "This is an example sentence for BERT feature extraction."
# Tokenize into PyTorch tensors ('pt'); truncation=True with max_length=512
# caps the tokenized input at 512 tokens.
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)

In [5]:
# Forward pass only; the bare `outputs` line displays the model output object.
with torch.no_grad(): # Disable gradient calculation for inference
    outputs = model(**inputs)
outputs

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-0.0950, -0.0757, -0.0581,  ...,  0.0655,  0.0869, -0.0496],
         [-0.4034,  0.1696,  0.0600,  ..., -0.1469,  0.0502,  0.0940],
         [-0.7109,  0.0160,  0.1436,  ...,  0.2006,  0.4207, -0.1936],
         ...,
         [ 0.0559, -0.2766,  0.0602,  ..., -0.3413,  0.6654, -0.1577],
         [-0.4258,  0.3786, -0.0698,  ..., -0.0161,  0.2390, -0.1685],
         [ 0.0671,  0.2001, -0.4674,  ...,  0.3677,  1.1703, -0.4158]]]), pooler_output=tensor([[ 1.4311e-01, -1.0967e-02,  1.6508e-01,  1.4898e-01,  2.3800e-01,
          3.9627e-01,  1.9053e-01, -1.3494e-01, -1.6137e-01,  2.7025e-01,
         -2.1765e-01, -1.8905e-01,  2.8847e-01, -1.2556e-01, -2.5286e-01,
          1.0808e-01,  1.6765e-01,  8.2053e-02,  1.9843e-01,  1.1438e-01,
         -9.3997e-02, -1.2272e-01,  1.0760e-01,  2.3318e-02,  2.1264e-01,
         -1.3650e-01,  2.3257e-01,  1.1641e-01,  3.3621e-01,  2.9417e-01,
          2.1035e-01,  2.3617e-01,  

In [6]:
# Take index 0 along the second axis of last_hidden_state for every batch item.
cls_embedding = outputs.last_hidden_state[:, 0, :] # Get the embedding for the [CLS] token
cls_embedding

tensor([[-9.5038e-02, -7.5722e-02, -5.8144e-02, -1.2931e-02, -3.2790e-01,
          2.2274e-02, -1.7276e-03, -4.3071e-02, -1.9550e+00,  1.3326e-02,
          6.8120e-02, -1.4925e-01,  3.0180e-02,  5.7136e-03,  1.1757e-01,
          6.0787e-02,  6.5999e-02, -4.6056e-02, -1.1047e-02, -3.1337e-02,
         -1.4216e-01,  5.8811e-02, -1.1840e-01, -3.5610e-01,  5.0346e-01,
         -1.0454e-02,  6.2171e-02,  5.0466e-02, -2.1470e+00, -1.5012e-02,
         -1.7318e-01,  1.4139e-01, -1.3618e-02,  9.8587e-02, -2.3313e-02,
         -9.6162e-02,  8.1568e-02,  1.6434e+00, -5.3784e-02, -5.6552e-02,
         -2.3318e-01,  5.3416e-02, -1.4525e-01,  1.6263e-01, -2.7179e-02,
         -2.1280e-01, -1.4365e-02, -1.3238e-02, -1.1144e-01, -8.2067e-02,
         -2.8082e-02,  2.1822e-01, -1.8964e-01, -7.9344e-02, -1.2398e-01,
         -3.8819e-02, -2.4521e-01,  3.2250e-02,  4.4372e-02,  1.8813e-01,
          1.8835e+00,  3.9186e-02, -5.4867e-02, -2.1552e-02,  6.9519e-02,
         -5.8258e-02,  2.6340e-02,  1.

In [7]:
# Display the dimensions of the extracted embedding tensor.
cls_embedding.shape

torch.Size([1, 768])

## [CLS] embedding extraction with model = google-bert/bert-base-multilingual-uncased

In [18]:
def get_BERT_feature_extractor(model_name):
    """Build a file-to-file extractor of [CLS] embeddings.

    The tokenizer and model are loaded once here and captured by the returned
    closure, so they are reused for every file that is processed.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Returns:
        A function ``extract_BERT_features(input_file, output_file)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    def extract_BERT_features(input_file, output_file):
        """Write the [CLS] embedding of one text file as a single-row CSV.

        Args:
            input_file: Path of a UTF-8 text file. The whole file is encoded as
                one sequence; anything beyond 512 tokens is truncated.
            output_file: Destination CSV path. Missing parent directories are
                created.

        Returns:
            None. The result is the CSV written to ``output_file``.
        """
        # Read the text and close the file before the (slow) model call.
        with open(input_file, 'r', encoding='utf-8') as f:
            text = f.read()

        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)

        with torch.no_grad():  # Disable gradient calculation for inference
            outputs = model(**inputs)

        cls_embedding = outputs.last_hidden_state[:, 0, :]  # Embedding at the [CLS] position
        # 1-D vector -> one row with one column per hidden dimension.
        features = pd.DataFrame(cls_embedding[0].detach().cpu().numpy()).T

        # A missing results folder would otherwise make the write fail.
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Write to a temporary name and rename afterwards: callers that skip
        # files whose output already exists must never see a half-written CSV
        # left behind by an interrupted run.
        tmp_file = output_file + '.tmp'
        features.to_csv(tmp_file, index=False)
        os.replace(tmp_file, output_file)

        return None

    return extract_BERT_features

# Build the extractor once; the returned function is called per file below.
BERT_feature_extractor = get_BERT_feature_extractor(model_name='google-bert/bert-base-multilingual-uncased')

In [19]:
# Metadata table; its 'filename' column lists the transcript files to process.
metadata = '../data/Thesis - Text Transcripts/metadata.csv'
metadata_df = pd.read_csv(metadata)
# Preview the first rows only.
metadata_df.head()

Unnamed: 0,filename,record_id,subject_id,age,sex,educ,mmse,dx,dx_binary,dataset,language
0,S002.txt,S002,S002,62.0,F,,30.0,NC,0,ADReSS,en
1,S003.txt,S003,S003,69.0,F,,29.0,NC,0,ADReSS,en
2,S004.txt,S004,S004,71.0,F,,30.0,NC,0,ADReSS,en
3,S005.txt,S005,S005,74.0,F,,30.0,NC,0,ADReSS,en
4,S006.txt,S006,S006,67.0,F,,29.0,NC,0,ADReSS,en


In [20]:
%%time
input_dir = '../data/Thesis - Text Transcripts/'
output_dir = '../results/Thesis - BERT Features_cls/'

# Create the results folder up front so the first write cannot fail on a
# fresh checkout.
os.makedirs(output_dir, exist_ok=True)

files = metadata_df['filename'].values
n = len(files)
# enumerate keeps the 1-based progress counter in one place instead of
# incrementing it separately on the skip path and the processed path.
for i, file in enumerate(files, start=1):
    print(f"{i}/{n} {file}")
    input_file = os.path.join(input_dir, file)
    base, ext = os.path.splitext(file)
    output_file = os.path.join(output_dir, base + '.csv')
    print(input_file)
    print(output_file)

    # Resume support: an existing output means the file was handled by an
    # earlier run, so the model call is skipped.
    if os.path.exists(output_file):
        print("Skipping file: already processed.")
        continue

    BERT_feature_extractor(input_file, output_file)
    print("File processed!")

1/1023 S002.txt
../data/Thesis - Text Transcripts/S002.txt
../results/Thesis - BERT Features_cls/S002.csv
File processed!
2/1023 S003.txt
../data/Thesis - Text Transcripts/S003.txt
../results/Thesis - BERT Features_cls/S003.csv
File processed!
3/1023 S004.txt
../data/Thesis - Text Transcripts/S004.txt
../results/Thesis - BERT Features_cls/S004.csv
File processed!
4/1023 S005.txt
../data/Thesis - Text Transcripts/S005.txt
../results/Thesis - BERT Features_cls/S005.csv
File processed!
5/1023 S006.txt
../data/Thesis - Text Transcripts/S006.txt
../results/Thesis - BERT Features_cls/S006.csv
File processed!
6/1023 S007.txt
../data/Thesis - Text Transcripts/S007.txt
../results/Thesis - BERT Features_cls/S007.csv
File processed!
7/1023 S009.txt
../data/Thesis - Text Transcripts/S009.txt
../results/Thesis - BERT Features_cls/S009.csv
File processed!
8/1023 S011.txt
../data/Thesis - Text Transcripts/S011.txt
../results/Thesis - BERT Features_cls/S011.csv
File processed!
9/1023 S012.txt
../data/

## Generate metadata.csv

In [21]:
def txt_to_csv(file):
    """Return `file` with its final extension replaced by ".csv".

    The name is split with os.path.splitext, so only the last extension is
    swapped; a name without an extension simply gains ".csv".
    """
    return os.path.splitext(file)[0] + ".csv"

# Work on an independent copy so metadata_df keeps its original .txt filenames.
metadata_df_new = metadata_df.copy(deep=True)
# Point every row at the feature CSV instead of the source transcript.
csv_filenames = metadata_df_new['filename'].apply(txt_to_csv)
metadata_df_new['filename'] = csv_filenames
metadata_df_new

Unnamed: 0,filename,record_id,subject_id,age,sex,educ,mmse,dx,dx_binary,dataset,language
0,S002.csv,S002,S002,62.0,F,,30.0,NC,0,ADReSS,en
1,S003.csv,S003,S003,69.0,F,,29.0,NC,0,ADReSS,en
2,S004.csv,S004,S004,71.0,F,,30.0,NC,0,ADReSS,en
3,S005.csv,S005,S005,74.0,F,,30.0,NC,0,ADReSS,en
4,S006.csv,S006,S006,67.0,F,,29.0,NC,0,ADReSS,en
...,...,...,...,...,...,...,...,...,...,...,...
1018,taukdial-168-2.csv,taukdial-168-2,taukdial-168,65.0,M,,29.0,NC,0,TAUKADIAL,zh
1019,taukdial-168-3.csv,taukdial-168-3,taukdial-168,65.0,M,,29.0,NC,0,TAUKADIAL,zh
1020,taukdial-169-1.csv,taukdial-169-1,taukdial-169,81.0,F,,28.0,MCI,1,TAUKADIAL,en
1021,taukdial-169-2.csv,taukdial-169-2,taukdial-169,81.0,F,,28.0,MCI,1,TAUKADIAL,en


In [22]:
# Save the updated metadata as metadata.csv inside output_dir.
metadata_new = os.path.join(output_dir, "metadata.csv")
metadata_df_new.to_csv(metadata_new, index=False)

## Check for completeness

In [23]:
# Expected feature-file paths, in metadata order.
expected_paths = [os.path.join(output_dir, name) for name in metadata_df_new['filename'].values]
missing_paths = [path for path in expected_paths if not os.path.exists(path)]
for path in missing_paths:
    print("File does not exist:", path)
# Number of expected files that are present; displayed as the cell result.
i = len(expected_paths) - len(missing_paths)
i

1023

Complete: all 1023 expected feature files are present.

## Last Hidden State (last_hidden_state)

In [25]:
def get_BERT_feature_extractor(model_name):
    """Build a file-to-file extractor of per-token last-hidden-state features.

    The tokenizer and model are loaded once here and captured by the returned
    closure, so they are reused for every file that is processed.

    Args:
        model_name: Checkpoint identifier passed to
            ``AutoTokenizer.from_pretrained`` and ``AutoModel.from_pretrained``.

    Returns:
        A function ``extract_BERT_features(input_file, output_file)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    def extract_BERT_features(input_file, output_file):
        """Write the last hidden state of one text file as a CSV matrix.

        Args:
            input_file: Path of a UTF-8 text file. The whole file is encoded as
                one sequence; anything beyond 512 tokens is truncated.
            output_file: Destination CSV path. Missing parent directories are
                created.

        Returns:
            None. The result is the CSV written to ``output_file``, holding the
            first (only) batch item of ``last_hidden_state``.
        """
        # Read the text and close the file before the (slow) model call.
        with open(input_file, 'r', encoding='utf-8') as f:
            text = f.read()

        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)

        with torch.no_grad():  # Disable gradient calculation for inference
            outputs = model(**inputs)

        token_embeddings = outputs.last_hidden_state
        # Drop the batch axis; the remaining 2-D matrix maps directly to a frame.
        features = pd.DataFrame(token_embeddings[0].detach().cpu().numpy())

        # A missing results folder would otherwise make the write fail.
        parent_dir = os.path.dirname(output_file)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # Write to a temporary name and rename afterwards: callers that skip
        # files whose output already exists must never see a half-written CSV
        # left behind by an interrupted run.
        tmp_file = output_file + '.tmp'
        features.to_csv(tmp_file, index=False)
        os.replace(tmp_file, output_file)

        return None

    return extract_BERT_features

# Rebind the extractor to the version defined in this cell; the returned function is called per file below.
BERT_feature_extractor = get_BERT_feature_extractor(model_name='google-bert/bert-base-multilingual-uncased')

In [26]:
%%time
input_dir = '../data/Thesis - Text Transcripts/'
output_dir = '../results/Thesis - BERT Features_last_hidden_state/'

# Create the results folder up front so the first write cannot fail on a
# fresh checkout.
os.makedirs(output_dir, exist_ok=True)

files = metadata_df['filename'].values
n = len(files)
# enumerate keeps the 1-based progress counter in one place instead of
# incrementing it separately on the skip path and the processed path.
for i, file in enumerate(files, start=1):
    print(f"{i}/{n} {file}")
    input_file = os.path.join(input_dir, file)
    base, ext = os.path.splitext(file)
    output_file = os.path.join(output_dir, base + '.csv')
    print(input_file)
    print(output_file)

    # Resume support: an existing output means the file was handled by an
    # earlier run, so the model call is skipped.
    if os.path.exists(output_file):
        print("Skipping file: already processed.")
        continue

    BERT_feature_extractor(input_file, output_file)
    print("File processed!")

1/1023 S002.txt
../data/Thesis - Text Transcripts/S002.txt
../results/Thesis - BERT Features_last_hidden_state/S002.csv
File processed!
2/1023 S003.txt
../data/Thesis - Text Transcripts/S003.txt
../results/Thesis - BERT Features_last_hidden_state/S003.csv
File processed!
3/1023 S004.txt
../data/Thesis - Text Transcripts/S004.txt
../results/Thesis - BERT Features_last_hidden_state/S004.csv
File processed!
4/1023 S005.txt
../data/Thesis - Text Transcripts/S005.txt
../results/Thesis - BERT Features_last_hidden_state/S005.csv
File processed!
5/1023 S006.txt
../data/Thesis - Text Transcripts/S006.txt
../results/Thesis - BERT Features_last_hidden_state/S006.csv
File processed!
6/1023 S007.txt
../data/Thesis - Text Transcripts/S007.txt
../results/Thesis - BERT Features_last_hidden_state/S007.csv
File processed!
7/1023 S009.txt
../data/Thesis - Text Transcripts/S009.txt
../results/Thesis - BERT Features_last_hidden_state/S009.csv
File processed!
8/1023 S011.txt
../data/Thesis - Text Transcript

In [27]:
# Save the same updated metadata as metadata.csv inside this section's output_dir.
metadata_new = os.path.join(output_dir, "metadata.csv")
metadata_df_new.to_csv(metadata_new, index=False)

In [28]:
# Expected feature-file paths, in metadata order.
expected_paths = [os.path.join(output_dir, name) for name in metadata_df_new['filename'].values]
missing_paths = [path for path in expected_paths if not os.path.exists(path)]
for path in missing_paths:
    print("File does not exist:", path)
# Number of expected files that are present; displayed as the cell result.
i = len(expected_paths) - len(missing_paths)
i

1023