In [23]:
import pandas as pd
import torch
from transformers import BertModel, BertTokenizer
import numpy as np

In [24]:
def get_bert_embeddings(input_df, max_length=512):
    """
    Extract BERT embeddings for tokens in the input dataframe.
    Returns a DataFrame with token-level embeddings and metadata.
    """
    # Load pre-trained BERT model and tokenizer
    model = BertModel.from_pretrained('bert-base-uncased')
    model.eval()  # Set to evaluation mode
    
    # Initialize list to store results
    results = []
    
    # Process each row in the input DataFrame
    for _, row in input_df.iterrows():
        subject_id = row['Subject_ID']
        tweet_id = row['Tweet_ID']
        tokens = eval(row['Tokens'])  # Convert string representation to list
        token_ids = eval(row['Indexed_Tokens'])
        segment_ids = eval(row['Segments_IDs'])

        # Split long token sequences into chunks of max_length
        for i in range(0, len(token_ids), max_length):
            chunk_token_ids = token_ids[i:i+max_length]
            chunk_segment_ids = segment_ids[i:i+max_length]
            chunk_tokens = tokens[i:i+max_length]
        
            # Convert inputs to tensors
            tokens_tensor = torch.tensor([chunk_token_ids])
            segments_tensor = torch.tensor([chunk_segment_ids])
            
            # Get BERT embeddings
            with torch.no_grad():
                outputs = model(tokens_tensor, segments_tensor)
                hidden_states = outputs.last_hidden_state.squeeze(0)
                
            # Process each token in the chunk
            for j, (token, embedding) in enumerate(zip(chunk_tokens, hidden_states)):
                if token in ['[CLS]', '[SEP]']:
                    continue
                
                # Convert embedding to numpy and store as list
                token_embedding = embedding.numpy().tolist()
                
                # Store all information
                result = {
                    'Subject_ID': subject_id,
                    'Tweet_ID': tweet_id,
                    'Token': token,
                    'Token_ID': chunk_token_ids[j],
                    'Segment_ID': chunk_segment_ids[j],
                    'Position': i + j,  # Position in the full text
                    'Embedding': token_embedding
                }
                
                results.append(result)
    
    # Create DataFrame from results
    result_df = pd.DataFrame(results)
    
    return result_df

In [25]:
status = "Before"
# Read the input CSV
input_df = pd.read_csv(f'{status}_Tokenized.csv')

# Get embeddings
embeddings_df = get_bert_embeddings(input_df)

# Save to CSV
embeddings_df.to_csv(f'{status}_token_embeddings.csv', index=False)

# Print sample of results
print("\nSample of the resulting DataFrame:")
print(embeddings_df.head())

# Print shape information
print("\nDataFrame shape:", embeddings_df.shape)
print("Number of embedding dimensions:", len([col for col in embeddings_df.columns if 'embedding_' in col]))


Sample of the resulting DataFrame:
   Subject_ID  Tweet_ID        Token  Token_ID  Segment_ID  Position  \
0          74       353         holy      4151           0         1   
1          74       353      fucking      8239           0         2   
2          74       353         shit      4485           0         3   
3          74       353  criticizing     21289           0         4   
4          74       353       israel      3956           0         5   

                                           Embedding  
0  [0.5452250838279724, 1.0058563947677612, 0.125...  
1  [0.46159079670906067, 0.9647833108901978, 0.13...  
2  [0.5526859164237976, 0.8520041108131409, 0.025...  
3  [0.16068339347839355, 0.7583393454551697, -0.0...  
4  [0.26864200830459595, 0.864945113658905, 0.004...  

DataFrame shape: (1692, 7)
Number of embedding dimensions: 0
