In [2]:
import os
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModel, logging
from tqdm import tqdm
import warnings
import gc

# Suppress warnings
logging.set_verbosity_error()
warnings.filterwarnings("ignore")

# Setup
RAW_DATA_PATH = '../../data/raw'
SAVE_BASE_PATH = '../../data/interim/clinical'
BATCH_SIZE = 8  # Reduced for CPU processing
MAX_LENGTH = 512

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Short name -> Hugging Face model identifier.
# 'bert_base' must be present: the fallback branch below looks it up, and a
# missing key would raise KeyError from inside the except handler.
MODEL_OPTIONS = {
    'biobert': 'dmis-lab/biobert-base-cased-v1.1',
    'bert_base': 'bert-base-uncased',
}

# Choose your model
MODEL_NAME = MODEL_OPTIONS['biobert']


def _load_model(model_name):
    """Load the tokenizer and encoder for `model_name`.

    The encoder is moved to the module-level `device` and switched to eval
    mode before being returned.

    Returns:
        A `(tokenizer, model)` tuple.

    Raises:
        Whatever `from_pretrained` raises when the model cannot be loaded.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModel.from_pretrained(model_name)
    loaded_model.to(device)
    loaded_model.eval()
    return loaded_tokenizer, loaded_model


print(f"Loading model: {MODEL_NAME}")
try:
    tokenizer, model = _load_model(MODEL_NAME)
    print(f"✅ Successfully loaded {MODEL_NAME}")
except Exception as e:
    print(f"❌ Failed to load {MODEL_NAME}: {e}")
    # Fallback to BERT base; a failure here is not caught and propagates.
    MODEL_NAME = MODEL_OPTIONS['bert_base']
    print(f"Falling back to {MODEL_NAME}")
    tokenizer, model = _load_model(MODEL_NAME)

def get_bert_embeddings(texts, batch_size=BATCH_SIZE):
    """Embed each text as the first-token vector of the model's last layer.

    Texts are tokenized and run through the module-level `tokenizer` and
    `model` in slices of `batch_size`; each text is truncated to
    `MAX_LENGTH` tokens.

    Args:
        texts: List of strings. Empty or whitespace-only entries are
            replaced by the placeholder "[EMPTY]" before tokenization.
        batch_size: Number of texts per forward pass.

    Returns:
        A float32 array of shape (len(texts), hidden_size). A batch that
        fails is reported on stdout and contributes all-zero rows, so the
        row count always matches the input.
    """
    # Read the width from the loaded model instead of assuming 768, so the
    # empty result and the zero placeholders always match the real outputs
    # (a mismatch would make np.vstack fail). 768 is kept as a last resort
    # for configs that do not expose `hidden_size`.
    embedding_dim = getattr(model.config, 'hidden_size', 768)

    if not texts:
        return np.empty((0, embedding_dim), dtype=np.float32)

    embeddings = []

    for i in tqdm(range(0, len(texts), batch_size), desc="Processing batches"):
        # Handle empty strings
        batch = [
            text if text and text.strip() else "[EMPTY]"
            for text in texts[i:i + batch_size]
        ]

        try:
            inputs = tokenizer(
                batch,
                return_tensors='pt',
                truncation=True,
                padding=True,
                max_length=MAX_LENGTH,
            )
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = model(**inputs)
                batch_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()

            # Drop the tensor references before the next batch is built.
            del inputs, outputs
        except Exception as e:
            print(f"Error processing batch {i//batch_size + 1}: {e}")
            # Zero rows keep the output aligned with the input. float32 is
            # used so one failed batch does not upcast the stacked result.
            batch_embeddings = np.zeros(
                (len(batch), embedding_dim), dtype=np.float32
            )

        # Appending in exactly one place guarantees one entry per batch.
        embeddings.append(batch_embeddings)

        # Clear memory (also after a failed batch)
        if device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()

    return np.vstack(embeddings)

def process_user_data(user_path, save_dir, user_name):
    """Embed one user's transcript and write the features as parquet.

    The first file in `user_path` whose name ends in '_Transcript.csv' is
    read, its 'Text' column is passed to `get_bert_embeddings`, and the
    resulting matrix (plus 'Start_Time' / 'End_Time' when the CSV has them)
    is saved under `save_dir`.

    Args:
        user_path: Directory to search for the transcript CSV.
        save_dir: Output directory; created if it does not exist.
        user_name: Label used only in the printed messages.

    Returns:
        True when the parquet file was written, False on any skip or error.
    """
    candidates = [
        name for name in os.listdir(user_path)
        if name.endswith('_Transcript.csv')
    ]
    if not candidates:
        print(f"⚠️  No transcript file found for {user_name}")
        return False

    csv_path = os.path.join(user_path, candidates[0])

    try:
        # Report the input size before the (potentially slow) extraction.
        size_in_mb = os.path.getsize(csv_path) / (1024 * 1024)
        print(f"Processing {user_name} ({size_in_mb:.1f}MB)")

        transcript = pd.read_csv(csv_path)
        available_columns = transcript.columns

        if 'Text' not in available_columns:
            print(f"⚠️  No 'Text' column found for {user_name}")
            return False

        # Missing cells become empty strings so every row yields a text.
        texts = transcript['Text'].fillna('').astype(str).tolist()
        if len(texts) == 0:
            print(f"⚠️  No text data found for {user_name}")
            return False

        print(f"Extracting features for {len(texts)} texts...")
        features = get_bert_embeddings(texts)
        if features.size == 0:
            print(f"⚠️  No features extracted for {user_name}")
            return False

        # Column/file prefix: last path component of the model id, with
        # dashes replaced by underscores.
        feature_prefix = MODEL_NAME.split('/')[-1].replace('-', '_')
        column_names = [
            f'{feature_prefix}_{i}' for i in range(features.shape[1])
        ]
        df_features = pd.DataFrame(features, columns=column_names)

        # Carry the timing columns over when the transcript provides them.
        if 'Start_Time' not in available_columns:
            print(f"⚠️  No 'Start_Time' column found for {user_name}")
        else:
            df_features['Start_Time'] = transcript['Start_Time'].reset_index(drop=True)

        if 'End_Time' not in available_columns:
            print(f"⚠️  No 'End_Time' column found for {user_name}")
        else:
            df_features['End_Time'] = transcript['End_Time'].reset_index(drop=True)

        # Save to parquet
        os.makedirs(save_dir, exist_ok=True)
        output_path = os.path.join(save_dir, f'{feature_prefix}_features.parquet')
        df_features.to_parquet(output_path, index=False)

        print(f"✅ Saved {feature_prefix} features for {user_name} ({features.shape[0]} samples)")
        return True

    except Exception as e:
        # Any failure is reported and counted as a failed user by the caller.
        print(f"❌ Error processing {user_name}: {e}")
        return False

# Main processing loop
def main():
    """Run feature extraction for every user directory and print a summary.

    Each sub-directory of `RAW_DATA_PATH` is treated as one user whose
    transcript is expected in a 'text' sub-folder; output goes to
    `SAVE_BASE_PATH/<user>`. A user without a 'text' folder counts as
    failed.
    """
    if not os.path.exists(RAW_DATA_PATH):
        print(f"❌ Raw data path does not exist: {RAW_DATA_PATH}")
        return

    # Only directories are users; stray files in the raw folder are ignored.
    users = []
    for entry in os.listdir(RAW_DATA_PATH):
        if os.path.isdir(os.path.join(RAW_DATA_PATH, entry)):
            users.append(entry)

    if len(users) == 0:
        print("❌ No user directories found")
        return

    print(f"Found {len(users)} users to process")
    print(f"Using model: {MODEL_NAME}")

    successful = failed = 0

    for user in users:
        text_dir = os.path.join(RAW_DATA_PATH, user, 'text')

        if os.path.isdir(text_dir):
            target_dir = os.path.join(SAVE_BASE_PATH, user)
            ok = process_user_data(text_dir, target_dir, user)
        else:
            print(f"⚠️  Text directory not found for {user}")
            ok = False

        if ok:
            successful += 1
        else:
            failed += 1

    print(f"\n📊 Processing complete:")
    print(f"✅ Successful: {successful}")
    print(f"❌ Failed: {failed}")
    print(f"📁 Total users: {len(users)}")

if __name__ == "__main__":
    main()

Using device: cpu


KeyError: 'clinicalbert'