In [49]:
import pandas as pd
import torch
from transformers import BertModel, BertTokenizer
import os
from tqdm import tqdm
import numpy as np
import ast
from ast import literal_eval

In [50]:
# Load the pretrained 'bert-base-uncased' tokenizer and encoder once, at module
# level, so the functions defined in later cells reuse these two objects
# (they reference the names `tokenizer` and `model` as globals).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')
# Put the module in evaluation mode (this notebook only runs inference).
# Because this is the last expression in the cell, its return value -- the
# model itself -- is what the notebook displays as the cell output.
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [51]:
# Function to verify and regenerate tokens
def verify_and_fix_tokens(input_df):
    """
    Re-derive token IDs from each row's tokens and repair rows that disagree.

    For every row, the 'Tokens' cell (a string holding a Python list literal)
    is parsed and passed to the module-level ``tokenizer.convert_tokens_to_ids``.
    When the result differs from the parsed 'Indexed_Tokens' cell, a message is
    printed and 'Indexed_Tokens' is overwritten with ``str(regenerated_ids)``.

    Args:
        input_df (pd.DataFrame): Frame with string columns 'Tokens' and
            'Indexed_Tokens'. It is modified in place.

    Returns:
        pd.DataFrame: The same ``input_df`` object, after any fixes.

    Raises:
        ValueError, SyntaxError: If a cell is not a valid Python literal.
    """
    for index, row in input_df.iterrows():
        # SECURITY: the cells come from a CSV file, so they are parsed with
        # literal_eval (literals only) instead of eval(), which would execute
        # any expression stored in the file.
        tokens = literal_eval(row['Tokens'])
        regenerated_ids = tokenizer.convert_tokens_to_ids(tokens)

        # Check if Indexed_Tokens match regenerated IDs
        if literal_eval(row['Indexed_Tokens']) != regenerated_ids:
            print(f"Mismatch found at index {index}. Fixing...")
            # Stored back as a string to keep the column's on-disk format.
            input_df.at[index, 'Indexed_Tokens'] = str(regenerated_ids)

    return input_df

In [52]:
# Function to chunk long sequences
def chunk_long_sequences(tokens, token_ids, segment_ids, max_length=512,
                         cls_id=101, sep_id=102):
    """
    Split parallel token / id / segment lists into chunks of at most max_length.

    Every chunk after the first is made to start with '[CLS]' and every chunk
    before the last is made to end with '[SEP]'; a synthetic special token is
    inserted only when the chunk does not already start/end with one. Room for
    the inserted tokens is reserved *inside* ``max_length``, so no yielded
    chunk is longer than ``max_length`` (the limit of the model's position
    embeddings). A sequence that already fits in ``max_length`` is yielded
    unchanged as a single chunk.

    Args:
        tokens (list[str]): Token strings.
        token_ids (list[int]): Vocabulary ids, parallel to ``tokens``.
        segment_ids (list[int]): Segment ids, parallel to ``tokens``.
        max_length (int): Maximum length of each yielded chunk, including any
            inserted special tokens. Must be at least 3.
        cls_id (int): Id inserted together with a synthetic '[CLS]'.
        sep_id (int): Id inserted together with a synthetic '[SEP]'.

    Yields:
        tuple[list, list, list]: ``(chunk_tokens, chunk_token_ids,
        chunk_segment_ids)``, three lists of equal length.

    Raises:
        ValueError: If ``max_length`` < 3 or the three inputs differ in length.
            Raised on first iteration, since this is a generator.
    """
    if max_length < 3:
        # One slot each for '[CLS]' and '[SEP]' plus at least one real token;
        # anything smaller could not make progress.
        raise ValueError(f"max_length must be at least 3, got {max_length}")

    total = len(token_ids)
    if not (len(tokens) == total == len(segment_ids)):
        raise ValueError(
            "tokens, token_ids and segment_ids must have the same length "
            f"(got {len(tokens)}, {total}, {len(segment_ids)})"
        )

    start = 0
    while start < total:
        # Continuation chunks need a leading '[CLS]' unless one is already there.
        needs_cls = start > 0 and tokens[start] != '[CLS]'
        end = min(start + max_length - int(needs_cls), total)

        # Non-final chunks need a trailing '[SEP]'. If the chunk does not end
        # on one naturally, give up the last slot to make room for it.
        needs_sep = False
        if end < total and tokens[end - 1] != '[SEP]':
            end -= 1
            needs_sep = tokens[end - 1] != '[SEP]'

        chunk_tokens = list(tokens[start:end])
        chunk_token_ids = list(token_ids[start:end])
        chunk_segment_ids = list(segment_ids[start:end])

        if needs_cls:
            chunk_tokens.insert(0, '[CLS]')
            chunk_token_ids.insert(0, cls_id)
            # The synthetic '[CLS]' shares the segment of the token it precedes.
            chunk_segment_ids.insert(0, segment_ids[start])

        if needs_sep:
            chunk_tokens.append('[SEP]')
            chunk_token_ids.append(sep_id)
            # The synthetic '[SEP]' shares the segment of the token it follows
            # (not the segment of the chunk's first token).
            chunk_segment_ids.append(segment_ids[end - 1])

        yield chunk_tokens, chunk_token_ids, chunk_segment_ids
        start = end

In [53]:
def _flatten_literal_lists(cells):
    """Parse each cell as a Python list literal and concatenate the lists in order."""
    # SECURITY: the cells come from a CSV file, so literal_eval (literals only)
    # is used instead of eval(). A single comprehension also avoids the
    # quadratic cost of sum(list_of_lists, []).
    return [item for cell in cells for item in literal_eval(cell)]


def process_and_export_by_subject(input_df, output_dir="output_embeddings", max_length=512):
    """
    Processes text by grouping all tokens for the same Subject_ID.
    Generates BERT embeddings for each subject and exports them individually to CSV files.

    For each subject, the 'Tokens', 'Indexed_Tokens' and 'Segments_IDs' cells
    of all its rows are parsed and concatenated, split with
    ``chunk_long_sequences``, and each chunk is run through the module-level
    ``model``. One output row is written per token, skipping '[CLS]' and
    '[SEP]', to ``<output_dir>/<Subject_ID>_embeddings.csv``.

    Args:
        input_df (pd.DataFrame): Input DataFrame with columns 'Subject_ID',
            'Tokens', 'Indexed_Tokens' and 'Segments_IDs'; the last three hold
            Python list literals stored as strings.
        output_dir (str): Directory to save the individual CSV files. Created
            if it does not exist.
        max_length (int): Maximum sequence length, forwarded to
            ``chunk_long_sequences``.

    Returns:
        None. Results are written to disk and progress is printed.

    Raises:
        ValueError, SyntaxError: If a list cell is not a valid Python literal.
    """
    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    for subject_id, group in input_df.groupby('Subject_ID'):
        print(f"Processing Subject_ID: {subject_id}")

        # Concatenate tokens, token IDs, and segment IDs for the same Subject_ID
        all_tokens = _flatten_literal_lists(group['Tokens'])
        all_token_ids = _flatten_literal_lists(group['Indexed_Tokens'])
        all_segment_ids = _flatten_literal_lists(group['Segments_IDs'])

        results = []

        # Use chunk_long_sequences to split long sequences
        for chunk_tokens, chunk_token_ids, chunk_segment_ids in chunk_long_sequences(
            all_tokens, all_token_ids, all_segment_ids, max_length=max_length
        ):
            # Convert inputs to tensors with a leading batch dimension of 1
            tokens_tensor = torch.tensor([chunk_token_ids])
            segments_tensor = torch.tensor([chunk_segment_ids])

            # Get BERT embeddings. The arguments are passed by keyword: the
            # second positional parameter of the transformers BertModel is
            # attention_mask, so passing the segment ids positionally would
            # use them as a mask instead of as token_type_ids.
            with torch.no_grad():
                outputs = model(input_ids=tokens_tensor, token_type_ids=segments_tensor)
                hidden_states = outputs.last_hidden_state.squeeze(0)

            # Convert the whole chunk to nested Python lists once, rather than
            # once per token.
            embeddings = hidden_states.tolist()

            # Process each token in the chunk
            for position, (token, token_id, segment_id, token_embedding) in enumerate(
                zip(chunk_tokens, chunk_token_ids, chunk_segment_ids, embeddings)
            ):
                if token in ('[CLS]', '[SEP]'):
                    continue

                # Store all token information
                results.append({
                    'Subject_ID': subject_id,
                    'Token': token,
                    'Token_ID': token_id,
                    'Segment_ID': segment_id,
                    'Position': position,  # Position within the chunk
                    'Embedding': token_embedding
                })

        # Export results for the current Subject_ID
        subject_df = pd.DataFrame(results)
        output_path = os.path.join(output_dir, f"{subject_id}_embeddings.csv")
        subject_df.to_csv(output_path, index=False)
        print(f"Exported embeddings for Subject_ID {subject_id} to {output_path}")

# process_and_export_by_subject relies on chunk_long_sequences to split sequences longer than max_length.

In [54]:
# Read the tokenized dataset and, in the same step, repair any rows whose
# stored token IDs disagree with their tokens.
input_df = verify_and_fix_tokens(pd.read_csv("Before_Tokenized.csv"))

# Compute embeddings and write one CSV per Subject_ID.
process_and_export_by_subject(
    input_df,
    output_dir="subject_embeddings",
)

Processing Subject_ID: 16
Exported embeddings for Subject_ID 16 to subject_embeddings/16_embeddings.csv
Processing Subject_ID: 17
Exported embeddings for Subject_ID 17 to subject_embeddings/17_embeddings.csv
Processing Subject_ID: 20
Exported embeddings for Subject_ID 20 to subject_embeddings/20_embeddings.csv
Processing Subject_ID: 23
Exported embeddings for Subject_ID 23 to subject_embeddings/23_embeddings.csv
Processing Subject_ID: 26
Exported embeddings for Subject_ID 26 to subject_embeddings/26_embeddings.csv
Processing Subject_ID: 28
Exported embeddings for Subject_ID 28 to subject_embeddings/28_embeddings.csv
Processing Subject_ID: 30
Exported embeddings for Subject_ID 30 to subject_embeddings/30_embeddings.csv
Processing Subject_ID: 32
Exported embeddings for Subject_ID 32 to subject_embeddings/32_embeddings.csv
Processing Subject_ID: 39
Exported embeddings for Subject_ID 39 to subject_embeddings/39_embeddings.csv
Processing Subject_ID: 41
Exported embeddings for Subject_ID 41 