# PTSD Prediction Preprocessing Pipeline (Batch Processing)

This notebook implements the preprocessing pipeline with batch processing to avoid memory overload.


In [10]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.preprocessing import StandardScaler
import pickle
import gc  # Garbage collection
import warnings
warnings.filterwarnings('ignore')

# Import additional library for .mat file handling
import scipy.io as sio

In [11]:
# Dataset and output locations (relative to the notebook's working directory)
BASE_PATH = "../../dataset/edaicwoz_participant"
LABELS_PATH = "../../dataset/edaicwoz_labels"
OUTPUT_PATH = "../../storage/sandeep/processed_data"

# Ensure the output directory, including any missing parents, exists
Path(OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

In [12]:
# Load and analyze labels
def load_labels(labels_path):
    """Read the train/dev/test label CSVs and stack them into one table.

    Args:
        labels_path: Directory containing ``train_split.csv``,
            ``dev_split.csv`` and ``test_split.csv``.

    Returns:
        A DataFrame with the rows of all three files (train first, then dev,
        then test), a fresh 0..n-1 index, and an extra ``split`` column naming
        the file each row came from.
    """
    frames = []
    for split_name in ('train', 'dev', 'test'):
        frame = pd.read_csv(f"{labels_path}/{split_name}_split.csv")
        frame['split'] = split_name
        frames.append(frame)
    return pd.concat(frames, ignore_index=True)

# Read the combined label table and preview its first rows
all_labels = load_labels(LABELS_PATH)
print(all_labels.head())

# Report the share of PTSD-positive participants in each split
for split in ('train', 'dev', 'test'):
    split_df = all_labels.loc[all_labels['split'] == split]
    if 'PCL-C (PTSD)' not in split_df.columns:
        # Nothing to report when the binary PTSD column is absent
        continue
    total = len(split_df)
    ptsd_count = split_df['PCL-C (PTSD)'].sum()
    print(f"{split}: {ptsd_count}/{total} PTSD positive ({ptsd_count/total*100:.2f}%)")

   Participant_ID   Gender  PHQ_Binary  PHQ_Score  PCL-C (PTSD)  \
0             302     male           0          4             0   
1             303   female           0          0             0   
2             304   female           0          6             0   
3             305     male           0          7             0   
4             307  female            0          4             0   

   PTSD Severity  split  
0           28.0  train  
1           17.0  train  
2           20.0  train  
3           28.0  train  
4           23.0  train  
train: 49/163 PTSD positive (30.06%)
dev: 17/56 PTSD positive (30.36%)
test: 21/56 PTSD positive (37.50%)


In [13]:
# Participant data loading function with memory optimization
def load_participant_data(participant_id, base_path, features_to_load=None):
    """Load the requested feature files of one participant (best effort).

    Files are looked up in ``{base_path}/{participant_id}_P/{participant_id}_P/features``.
    A feature whose file is missing is silently left out; a feature whose file
    cannot be read is reported with a printed message and left out, without
    affecting the other features.

    Args:
        participant_id: ID of the participant (used to build the file names).
        base_path: Base path to the dataset.
        features_to_load: Iterable of feature keys to load. ``None`` loads all
            of ``egemaps``, ``mfcc``, ``densenet``, ``openface``, ``resnet``
            and ``vgg``.

    Returns:
        dict mapping feature key -> loaded data. CSV-backed features are
        pandas DataFrames; ``.mat``-backed features are whatever array
        ``scipy.io.loadmat`` stored under ``'features'`` (or under the first
        non-dunder key when ``'features'`` is absent). Keys always appear in
        the fixed order listed above, regardless of the order requested.
    """
    features_dir = Path(f"{base_path}/{participant_id}_P/{participant_id}_P") / "features"

    # Default: load all feature types
    if features_to_load is None:
        features_to_load = ['egemaps', 'mfcc', 'densenet', 'openface', 'resnet', 'vgg']

    # (feature key, file name, label used in error messages, reader)
    sources = [
        ("egemaps", f"{participant_id}_OpenSMILE2.3.0_egemaps.csv", "egemaps", _read_feature_csv),
        ("mfcc", f"{participant_id}_OpenSMILE2.3.0_mfcc.csv", "mfcc", _read_feature_csv),
        ("densenet", f"{participant_id}_densenet201.csv", "densenet", _read_feature_csv),
        ("openface", f"{participant_id}_OpenFace2.1.0_Pose_gaze_AUs.csv", "openface", _read_feature_csv),
        ("resnet", f"{participant_id}_CNN_ResNet.mat", "ResNet", _read_feature_mat),
        ("vgg", f"{participant_id}_CNN_VGG.mat", "VGG", _read_feature_mat),
    ]

    feature_data = {}
    for key, filename, label, reader in sources:
        if key not in features_to_load:
            continue
        try:
            path = features_dir / filename
            if not path.exists():
                continue
            loaded = reader(path)
        except Exception as e:
            # Best effort: one unreadable file must not prevent the remaining
            # features from being loaded.
            print(f"Error loading {label}: {e}")
            continue
        if loaded is not None:
            feature_data[key] = loaded

    return feature_data


def _read_feature_csv(path):
    """Read a feature CSV, retrying with ';' when it was parsed as one column."""
    frame = pd.read_csv(path)
    if frame.shape[1] == 1:
        # A semicolon-delimited file read with the default ',' collapses into a
        # single column whose header / cells still contain the ';' separators.
        # Checking the header as well covers files that have no data rows.
        header_has_semicolon = ";" in str(frame.columns[0])
        first_cell_has_semicolon = len(frame) > 0 and ";" in str(frame.iloc[0, 0])
        if header_has_semicolon or first_cell_has_semicolon:
            frame = pd.read_csv(path, delimiter=';')
    return frame


def _read_feature_mat(path):
    """Return the 'features' variable of a .mat file, else its first user variable."""
    contents = sio.loadmat(str(path))
    if 'features' in contents:
        return contents['features']
    # loadmat adds bookkeeping entries such as '__header__'; skip those.
    for name in contents:
        if not name.startswith('__'):
            return contents[name]
    return None

In [14]:
# Optimize normalization function to reduce memory usage
def normalize_features(feature_data, large_rows=10000, large_cols=1000,
                       sample_size=10000, batch_size=5000):
    """Standardise every feature to zero mean / unit variance, column-wise.

    DataFrames: only numeric columns are kept, NaNs are replaced by 0 and the
    result is standardised with ``StandardScaler`` (fitted on this data only).

    Arrays: string/object arrays are first converted to float. Arrays with
    more than ``large_rows`` rows and more than ``large_cols`` columns are
    standardised in row batches using a mean/std estimated from a random row
    sample; all other arrays go through ``StandardScaler``.

    Args:
        feature_data: dict of feature key -> DataFrame or ndarray, or None.
        large_rows: Row threshold above which the batched path is considered.
        large_cols: Column threshold above which the batched path is considered.
        sample_size: Maximum number of rows sampled (without replacement, via
            ``np.random.choice``) to estimate mean and std on the batched path.
        batch_size: Number of rows normalised per step on the batched path.

    Returns:
        None if ``feature_data`` is None; otherwise a dict of feature key ->
        2-D ndarray. Features that cannot be normalised (no numeric columns,
        non-convertible arrays, arrays with fewer than 2 dimensions,
        unsupported types) are reported with a printed message and omitted.
    """
    if feature_data is None:
        return None

    normalized_data = {}
    for key, data in feature_data.items():
        if isinstance(data, pd.DataFrame):
            # Keep only numeric columns
            numeric_cols = data.select_dtypes(include=['number']).columns
            if len(numeric_cols) == 0:
                print(f"Warning: No numeric columns found in {key}")
                continue
            numeric_data = data[numeric_cols].fillna(0)
            normalized_data[key] = StandardScaler().fit_transform(numeric_data)
            del numeric_data
        elif isinstance(data, np.ndarray):
            # String / object arrays (dtype kinds U, S, O) need a float conversion
            if data.dtype.kind in ('U', 'S', 'O'):
                try:
                    data = np.array(data, dtype=float)
                except (ValueError, TypeError):
                    # TypeError covers object arrays holding e.g. None
                    print(f"Error: Could not convert {key} to numeric. Skipping.")
                    continue

            if data.ndim < 2:
                # Column-wise scaling below indexes shape[1]; there is no
                # column axis to work with here.
                print(f"Warning: Expected at least a 2-D array for {key}, "
                      f"got shape {data.shape}. Skipping.")
                continue

            if data.shape[0] > large_rows and data.shape[1] > large_cols:
                print(f"Using batched normalization for large {key} array")
                # Estimate mean and std on a random sample of rows
                n_sample = min(sample_size, data.shape[0])
                sample_indices = np.random.choice(data.shape[0], n_sample, replace=False)
                sample = data[sample_indices]
                mean = np.mean(sample, axis=0)
                std = np.std(sample, axis=0)
                std[std == 0] = 1.0  # Avoid division by zero
                del sample

                # The output must be floating point: allocating it with the
                # input dtype would truncate the normalised values whenever
                # the input is an integer or boolean array.
                out_dtype = data.dtype if data.dtype.kind == 'f' else np.float64
                result = np.empty(data.shape, dtype=out_dtype)
                for start in range(0, data.shape[0], batch_size):
                    stop = min(start + batch_size, data.shape[0])
                    result[start:stop] = (data[start:stop] - mean) / std
                normalized_data[key] = result
            else:
                # Regular normalization for smaller arrays
                normalized_data[key] = StandardScaler().fit_transform(data)
        else:
            print(f"Warning: Unsupported data type for {key}. Skipping.")
            continue

        print(f"Normalized {key}: shape {normalized_data[key].shape}")
        # Drop the local reference to the raw data before the next feature
        del data
        gc.collect()

    return normalized_data

In [15]:
# Create sequences from preprocessed features
def create_sequences(features, seq_length=20, stride=10):
    """Cut a feature array into fixed-length, possibly overlapping windows.

    Window ``k`` covers rows ``k*stride`` to ``k*stride + seq_length - 1``;
    only windows that fit entirely inside ``features`` are produced. The same
    window positions are used for every input size (previously, inputs with
    more than 50000 rows restarted the stride every 10000 rows, which shifted
    windows whenever ``stride`` did not divide 10000).

    Args:
        features: Array-like whose first axis is the step axis.
        seq_length: Number of consecutive rows per window (positive).
        stride: Step between the starts of successive windows (positive).

    Returns:
        ndarray of shape ``(n_windows, seq_length, *features.shape[1:])``.
        When ``features`` has fewer than ``seq_length`` rows, ``n_windows`` is
        0 and the trailing dimensions are still present.

    Raises:
        ValueError: If ``seq_length`` or ``stride`` is not positive.
    """
    if seq_length <= 0 or stride <= 0:
        raise ValueError("seq_length and stride must be positive integers")

    features = np.asarray(features)
    # One start offset per window; empty when the input is too short.
    starts = np.arange(0, features.shape[0] - seq_length + 1, stride)
    # Row indices of every window: shape (n_windows, seq_length).
    window_index = starts[:, None] + np.arange(seq_length)[None, :]
    # Integer-array indexing gathers all windows in one copy.
    return features[window_index]

In [16]:
# Create multimodal sequences with aligned time steps
def create_multimodal_sequences(audio_features, visual_features, seq_length=20, stride=10):
    """Cut audio and visual features into windows over the same row indices.

    Both inputs are limited to the row count of the shorter one, and window
    ``k`` covers rows ``k*stride`` to ``k*stride + seq_length - 1`` in both.
    The same window positions are used for every input size (previously,
    inputs longer than 50000 rows restarted the stride every 10000 rows,
    which shifted windows whenever ``stride`` did not divide 10000).

    NOTE(review): pairing is by row index only. The recorded run of this
    notebook shows different row counts per modality for the same participant
    (e.g. 75876 egemaps rows vs 22766 openface rows), which suggests the two
    streams have different frame rates. If so, equal row indices are not
    equal instants and the tail of the longer stream is dropped - confirm, and
    resample upstream if true time alignment is required.

    Args:
        audio_features: Array-like, first axis is the step axis.
        visual_features: Array-like, first axis is the step axis.
        seq_length: Number of consecutive rows per window (positive).
        stride: Step between the starts of successive windows (positive).

    Returns:
        ``(audio_array, visual_array)`` with shapes
        ``(n_windows, seq_length, *audio_features.shape[1:])`` and
        ``(n_windows, seq_length, *visual_features.shape[1:])``. ``n_windows``
        is 0 (trailing dimensions kept) when the shorter input has fewer than
        ``seq_length`` rows.

    Raises:
        ValueError: If ``seq_length`` or ``stride`` is not positive.
    """
    if seq_length <= 0 or stride <= 0:
        raise ValueError("seq_length and stride must be positive integers")

    audio_features = np.asarray(audio_features)
    visual_features = np.asarray(visual_features)

    # Only rows present in both modalities can be paired.
    min_length = min(audio_features.shape[0], visual_features.shape[0])

    starts = np.arange(0, min_length - seq_length + 1, stride)
    # Row indices of every window: shape (n_windows, seq_length).
    window_index = starts[:, None] + np.arange(seq_length)[None, :]

    return audio_features[window_index], visual_features[window_index]

In [17]:
# Extract PTSD label using different approaches
def get_ptsd_label(participant_info):
    """Extract a binary PTSD label, prioritizing direct binary values.

    Lookup order: ``'PCL-C (PTSD)'``, ``'PCL-C'``, ``'PTSD'`` (value cast with
    ``int``), then ``'PTSD Severity'`` thresholded at 35. A key whose value is
    missing (NaN/None) is skipped so the next candidate is tried, instead of
    failing inside ``int()``.

    Args:
        participant_info: Mapping-like record supporting ``in`` and ``[]`` on
            the column names above (e.g. a dict or a pandas Series).

    Returns:
        int: the label. Falls back to 0 when no usable column is present -
        NOTE(review): that silently treats unlabeled participants as negative;
        confirm this is intended.
    """
    # First try the provided binary classification columns, in priority order
    for column in ('PCL-C (PTSD)', 'PCL-C', 'PTSD'):
        if column in participant_info:
            value = participant_info[column]
            if not pd.isna(value):
                return int(value)

    # Fallback to severity threshold
    if 'PTSD Severity' in participant_info:
        severity = participant_info['PTSD Severity']
        if not pd.isna(severity):
            return int(float(severity) >= 35)  # Clinical threshold

    return 0

In [18]:
# Process single participant with memory optimization
def process_participant(participant_id, base_path, preferred_features=None,
                        seq_length=20, stride=10):
    """Load, normalise and window one participant's audio and visual features.

    For each modality the feature types in ``preferred_features`` are tried in
    order and the first one that can be loaded is used. (Previously only the
    first entry of each list was ever probed, so the fallbacks were never
    reached, and the chosen features were read from disk twice.)

    Args:
        participant_id: Participant ID.
        base_path: Base path to dataset.
        preferred_features: dict with keys ``'audio'`` and ``'visual'``, each
            an ordered list of feature keys understood by
            ``load_participant_data``. ``None`` uses the defaults below.
        seq_length: Window length passed to ``create_multimodal_sequences``.
        stride: Window stride passed to ``create_multimodal_sequences``.

    Returns:
        None when a modality is unavailable or normalisation fails; otherwise
        a dict with ``'audio_sequences'``, ``'visual_sequences'`` and the
        names of the feature types actually used (``'audio_feature'``,
        ``'visual_feature'``). Callers that stack several participants should
        check those names / the array shapes, since a fallback feature type
        may have a different number of columns.
    """
    # Default preferred features if not specified
    if preferred_features is None:
        preferred_features = {
            'audio': ['egemaps', 'mfcc', 'densenet'],  # Try these in order
            'visual': ['openface', 'resnet', 'vgg']    # Try these in order
        }

    print(f"Processing participant {participant_id}...")

    def _load_first_available(candidates):
        """Return (name, data) for the first candidate that loads, else (None, None)."""
        for name in candidates:
            loaded = load_participant_data(participant_id, base_path,
                                           features_to_load=[name])
            if name in loaded:
                return name, loaded[name]
        return None, None

    # Load one feature type at a time so at most one candidate per modality
    # is held in memory, and nothing is read twice.
    audio_feature_to_use, audio_raw = _load_first_available(preferred_features['audio'])
    visual_feature_to_use, visual_raw = None, None
    if audio_feature_to_use is not None:
        visual_feature_to_use, visual_raw = _load_first_available(preferred_features['visual'])

    # Minimum requirement: one audio and one visual feature
    if audio_feature_to_use is None or visual_feature_to_use is None:
        print(f"Participant {participant_id} missing required features")
        return None

    # Audio first, then visual, so normalisation handles them in that order
    feature_data = {
        audio_feature_to_use: audio_raw,
        visual_feature_to_use: visual_raw,
    }
    del audio_raw, visual_raw

    # Normalize the features
    normalized_features = normalize_features(feature_data)
    del feature_data
    gc.collect()

    if normalized_features is None or len(normalized_features) < 2:
        print(f"Failed to normalize features for participant {participant_id}")
        return None

    # Extract audio and visual features
    audio_features = normalized_features.get(audio_feature_to_use)
    visual_features = normalized_features.get(visual_feature_to_use)

    if audio_features is None or visual_features is None:
        print(f"Missing required features after normalization for participant {participant_id}")
        return None

    # Create sequences over matching row indices
    audio_sequences, visual_sequences = create_multimodal_sequences(
        audio_features, visual_features, seq_length=seq_length, stride=stride
    )

    # Clear memory
    del normalized_features, audio_features, visual_features
    gc.collect()

    print(f"Created {len(audio_sequences)} sequences for participant {participant_id}")

    return {
        'audio_sequences': audio_sequences,
        'visual_sequences': visual_sequences,
        'audio_feature': audio_feature_to_use,
        'visual_feature': visual_feature_to_use
    }

In [19]:
# Process full dataset in batches to avoid memory issues
def process_full_dataset_batched(base_path, labels_path, output_path):
    """Process the full dataset split by split, in batches of participants.

    For every split (train, dev, test) participants are processed ten at a
    time. Each batch that yields at least one sequence is pickled to
    ``{output_path}/{split}_batch_{n}.pkl`` as a dict with keys ``'audio'``,
    ``'visual'``, ``'labels'`` and ``'participant_ids'`` (one label / ID per
    sequence). After a split is done ``combine_batch_files`` is called for it,
    and ``print_dataset_statistics`` is called once at the end.

    Participants are skipped (with a printed message) when
    ``process_participant`` returns None, when they yield zero sequences, or
    when the per-sequence shape of their features differs from that of the
    first successfully processed participant - arrays of different shapes
    cannot be stacked into one batch.

    Args:
        base_path: Base path to the participant feature folders.
        labels_path: Directory with the split CSVs read by ``load_labels``.
        output_path: Directory receiving the batch and combined files.

    Returns:
        True once all splits have been processed.
    """
    # Load labels
    all_labels = load_labels(labels_path)
    pid_column = all_labels['Participant_ID'].astype(str)

    # Define feature preferences
    # Order matters - will try to use first available feature in each list
    preferred_features = {
        'audio': ['egemaps', 'mfcc', 'densenet'],  # In order of preference
        'visual': ['openface', 'resnet', 'vgg']    # In order of preference
    }

    os.makedirs(output_path, exist_ok=True)

    batch_size = 10  # Process 10 participants at a time
    # Per-sequence shapes (audio, visual) of the first accepted participant;
    # shared across splits so train/dev/test end up with the same layout.
    expected_dims = None

    # Process each split separately to manage memory
    for split in ['train', 'dev', 'test']:
        print(f"\nProcessing {split} set...")

        # Batch files left behind by an interrupted earlier run would be
        # merged into this run's output by combine_batch_files; remove them.
        for stale in os.listdir(output_path):
            if stale.startswith(f"{split}_batch_") and stale.endswith(".pkl"):
                os.remove(f"{output_path}/{stale}")

        # Get participant IDs for this split
        participant_ids = pid_column[all_labels['split'] == split].tolist()
        n_batches = (len(participant_ids) - 1) // batch_size + 1

        for batch_start in range(0, len(participant_ids), batch_size):
            batch_no = batch_start // batch_size + 1
            batch_ids = participant_ids[batch_start:batch_start + batch_size]

            print(f"Processing batch {batch_no}/{n_batches} in {split} set")

            batch_audio = []
            batch_visual = []
            batch_labels = []
            batch_pids = []

            for pid in batch_ids:
                # Get PTSD label
                p_info = all_labels[pid_column == pid].iloc[0]
                ptsd_label = get_ptsd_label(p_info)

                # Process participant
                result = process_participant(pid, base_path, preferred_features)
                if result is None:
                    continue

                audio_seqs = result['audio_sequences']
                visual_seqs = result['visual_sequences']
                del result

                if len(audio_seqs) == 0:
                    # Too few frames for a single window; an empty array may
                    # also lack the trailing dimensions needed for stacking.
                    print(f"Participant {pid} produced no sequences; skipping")
                    continue

                dims = (audio_seqs.shape[1:], visual_seqs.shape[1:])
                if expected_dims is None:
                    expected_dims = dims
                elif dims != expected_dims:
                    print(f"Participant {pid} has sequence shapes {dims}, "
                          f"expected {expected_dims}; skipping")
                    continue

                # One label and one participant ID per sequence
                batch_audio.append(audio_seqs)
                batch_visual.append(visual_seqs)
                batch_labels.extend([ptsd_label] * len(audio_seqs))
                batch_pids.extend([pid] * len(audio_seqs))

                gc.collect()

            if batch_audio:
                print(f"Saving intermediate results for {split} after batch {batch_no}")

                # Stack once, directly into the arrays that get pickled
                interim_dataset = {
                    'audio': np.vstack(batch_audio),
                    'visual': np.vstack(batch_visual),
                    'labels': np.array(batch_labels),
                    'participant_ids': batch_pids
                }

                with open(f"{output_path}/{split}_batch_{batch_no}.pkl", 'wb') as f:
                    pickle.dump(interim_dataset, f)

                del interim_dataset

            # Clear batch memory
            del batch_audio, batch_visual, batch_labels, batch_pids
            gc.collect()

        # Combine all batch files for this split
        combine_batch_files(split, output_path)

    print("\nDataset processing complete!")
    print_dataset_statistics(output_path)

    return True

In [20]:
# Helper function to combine batch files with extreme memory optimization
def combine_batch_files(split, output_path):
    """Merge a split's ``{split}_batch_{n}.pkl`` files into combined outputs.

    Writes, inside ``output_path``:
      * ``{split}_audio.npy`` / ``{split}_visual.npy`` - the batches' arrays
        concatenated along the first axis, in batch-number order;
      * ``{split}_labels.pkl`` - dict with ``'labels'`` and
        ``'participant_ids'`` lists in the same order;
      * ``{split}_metadata.pkl`` - dict with ``'audio_shape'``,
        ``'visual_shape'``, ``'total_sequences'``, ``'ptsd_positive'`` and
        ``'ptsd_percentage'``, taken from the data actually written.
    The batch files are removed afterwards.

    The combined arrays are written through a memory-mapped ``.npy`` file, so
    only one batch is held in memory at a time. This needs two passes over
    the batch files: one to learn the total size, one to copy the data.
    (Previously every batch overwrote the ``.npy`` file, so only the last
    batch survived while the labels covered all batches.)

    Args:
        split: Split name used as file-name prefix (e.g. ``'train'``).
        output_path: Directory holding the batch files and receiving outputs.

    Returns:
        None. Returns early, writing nothing, when no batch files exist.

    Raises:
        ValueError: If a batch's audio, visual and label counts disagree, or
            if batches have different per-sequence shapes.
    """
    print(f"Combining batch files for {split}...")

    prefix, suffix = f"{split}_batch_", ".pkl"

    def _batch_order(filename):
        """Sort key: numeric batch number first (so 2 precedes 10), then name."""
        stem = filename[len(prefix):-len(suffix)]
        if stem.isdigit():
            return (0, int(stem), filename)
        return (1, 0, filename)

    def _load_batch(filename):
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # this on batch files written by this pipeline itself.
        with open(f"{output_path}/{filename}", 'rb') as f:
            return pickle.load(f)

    # Find all batch files for this split
    batch_files = sorted(
        (f for f in os.listdir(output_path)
         if f.startswith(prefix) and f.endswith(suffix)),
        key=_batch_order,
    )

    if not batch_files:
        print(f"No batch files found for {split}")
        return

    print(f"Found {len(batch_files)} batch files for {split}")

    # ---- Pass 1: collect sizes, labels and IDs (small), validate shapes ----
    all_labels = []
    all_pids = []
    batch_lengths = []
    audio_tail = visual_tail = None      # per-sequence shapes
    audio_dtype = visual_dtype = None

    for batch_file in batch_files:
        batch_data = _load_batch(batch_file)
        batch_audio = np.asarray(batch_data['audio'])
        batch_visual = np.asarray(batch_data['visual'])
        n_rows = len(batch_data['labels'])

        if len(batch_audio) != n_rows or len(batch_visual) != n_rows:
            raise ValueError(
                f"{batch_file}: audio ({len(batch_audio)}), visual ({len(batch_visual)}) "
                f"and labels ({n_rows}) have different lengths"
            )

        if audio_tail is None:
            audio_tail, visual_tail = batch_audio.shape[1:], batch_visual.shape[1:]
            audio_dtype, visual_dtype = batch_audio.dtype, batch_visual.dtype
        else:
            if batch_audio.shape[1:] != audio_tail or batch_visual.shape[1:] != visual_tail:
                raise ValueError(
                    f"{batch_file}: sequence shapes {batch_audio.shape[1:]} / "
                    f"{batch_visual.shape[1:]} differ from {audio_tail} / {visual_tail}"
                )
            # Widen the output dtype if batches disagree, to avoid lossy casts
            audio_dtype = np.result_type(audio_dtype, batch_audio.dtype)
            visual_dtype = np.result_type(visual_dtype, batch_visual.dtype)

        batch_lengths.append(n_rows)
        all_labels.extend(batch_data['labels'])
        all_pids.extend(batch_data['participant_ids'])

        del batch_data, batch_audio, batch_visual
        gc.collect()

    total_sequences = len(all_labels)
    ptsd_positive = sum(all_labels)
    audio_shape = (total_sequences,) + tuple(audio_tail)
    visual_shape = (total_sequences,) + tuple(visual_tail)

    all_audio_file = f"{output_path}/{split}_audio.npy"
    all_visual_file = f"{output_path}/{split}_visual.npy"

    # ---- Pass 2: copy each batch into its slice of the combined files ----
    if total_sequences == 0:
        # Nothing to map; write empty arrays that keep shape and dtype
        np.save(all_audio_file, np.empty(audio_shape, dtype=audio_dtype))
        np.save(all_visual_file, np.empty(visual_shape, dtype=visual_dtype))
    else:
        audio_out = np.lib.format.open_memmap(
            all_audio_file, mode='w+', dtype=audio_dtype, shape=audio_shape)
        visual_out = np.lib.format.open_memmap(
            all_visual_file, mode='w+', dtype=visual_dtype, shape=visual_shape)

        offset = 0
        for i, (batch_file, n_rows) in enumerate(zip(batch_files, batch_lengths)):
            print(f"Processing batch file {i+1}/{len(batch_files)}: {batch_file}")

            batch_data = _load_batch(batch_file)
            audio_out[offset:offset + n_rows] = batch_data['audio']
            visual_out[offset:offset + n_rows] = batch_data['visual']
            offset += n_rows

            del batch_data
            gc.collect()

            print(f"Current progress: {offset} sequences, "
                  f"{sum(all_labels[:offset])} PTSD positive")

        # Flush and release the mappings so the files are complete and closed
        audio_out.flush()
        visual_out.flush()
        del audio_out, visual_out
        gc.collect()

    # Save labels and participant IDs separately
    with open(f"{output_path}/{split}_labels.pkl", 'wb') as f:
        pickle.dump({'labels': all_labels, 'participant_ids': all_pids}, f)

    # Save metadata about the dataset dimensions
    ptsd_percentage = (ptsd_positive / total_sequences * 100) if total_sequences > 0 else 0
    with open(f"{output_path}/{split}_metadata.pkl", 'wb') as f:
        pickle.dump({
            'audio_shape': audio_shape,
            'visual_shape': visual_shape,
            'total_sequences': total_sequences,
            'ptsd_positive': ptsd_positive,
            'ptsd_percentage': ptsd_percentage
        }, f)

    print(f"Combined {split} data saved with {total_sequences} sequences")
    print(f"PTSD positive: {ptsd_positive}/{total_sequences} ({ptsd_percentage:.2f}%)")

    # Remove batch files to save disk space
    for batch_file in batch_files:
        try:
            os.remove(f"{output_path}/{batch_file}")
            print(f"Removed {batch_file}")
        except Exception as e:
            print(f"Error removing {batch_file}: {e}")

In [21]:
# Helper function to print dataset statistics
def print_dataset_statistics(output_path):
    """Print shapes and label balance of each processed split.

    For each of train/dev/test the statistics are read from
    ``{split}_audio.npy``, ``{split}_visual.npy`` and ``{split}_labels.pkl``
    (the files ``combine_batch_files`` writes). If those are not all present,
    a single ``{split}_data.pkl`` holding ``'audio'``, ``'visual'``,
    ``'labels'`` and ``'participant_ids'`` is used instead. If neither exists
    the split is reported as not found. (Previously only ``{split}_data.pkl``
    was checked, which this pipeline never writes.)

    Args:
        output_path: Directory containing the processed dataset files.

    Returns:
        None; everything is printed.
    """
    print("\nDataset Statistics:")

    # Load and analyze each split
    for split in ['train', 'dev', 'test']:
        audio_file = f"{output_path}/{split}_audio.npy"
        visual_file = f"{output_path}/{split}_visual.npy"
        labels_file = f"{output_path}/{split}_labels.pkl"
        legacy_file = f"{output_path}/{split}_data.pkl"

        if all(os.path.exists(p) for p in (audio_file, visual_file, labels_file)):
            # Memory-map the arrays: only the header is needed for the shape
            audio_shape = np.load(audio_file, mmap_mode='r').shape
            visual_shape = np.load(visual_file, mmap_mode='r').shape
            # NOTE: pickle.load must only be used on files this pipeline wrote
            with open(labels_file, 'rb') as f:
                label_data = pickle.load(f)
            labels = label_data['labels']
            participant_ids = label_data['participant_ids']
            del label_data
        elif os.path.exists(legacy_file):
            with open(legacy_file, 'rb') as f:
                data = pickle.load(f)
            audio_shape = data['audio'].shape
            visual_shape = data['visual'].shape
            labels = data['labels']
            participant_ids = data['participant_ids']
            del data
        else:
            print(f"{split} data file not found")
            continue

        n_labels = len(labels)
        n_positive = sum(labels)
        # Guard the percentage against an empty split
        percentage = (n_positive / n_labels * 100) if n_labels > 0 else 0

        # Print statistics
        print(f"\n{split.upper()} SET:")
        print(f"  Audio sequences: {audio_shape}")
        print(f"  Visual sequences: {visual_shape}")
        print(f"  PTSD positive: {n_positive}/{n_labels} "
              f"({percentage:.2f}%)")
        print(f"  Unique participants: {len(set(participant_ids))}")

        # Free memory
        del labels, participant_ids
        gc.collect()

In [22]:
# Run the full batched preprocessing pipeline for the train, dev and test splits.
# Intermediate and combined files are written under OUTPUT_PATH. The call returns
# True when it finishes, which the notebook shows as this cell's result.
process_full_dataset_batched(BASE_PATH, LABELS_PATH, OUTPUT_PATH)


Processing train set...
Processing batch 1/17 in train set
Processing participant 302...
Normalized egemaps: shape (75876, 24)
Normalized openface: shape (22766, 53)
Created 2275 sequences for participant 302
Processing participant 303...
Normalized egemaps: shape (98526, 24)
Normalized openface: shape (29565, 53)
Created 2955 sequences for participant 303
Processing participant 304...
Normalized egemaps: shape (79256, 24)
Normalized openface: shape (23780, 53)
Created 2377 sequences for participant 304
Processing participant 305...
Normalized egemaps: shape (170396, 24)
Normalized openface: shape (51122, 53)
Created 5111 sequences for participant 305
Processing participant 307...
Normalized egemaps: shape (123872, 24)
Normalized openface: shape (37167, 53)
Created 3715 sequences for participant 307
Processing participant 308...
Normalized egemaps: shape (86756, 24)
Normalized openface: shape (26031, 53)
Created 2602 sequences for participant 308
Processing participant 309...
Normaliz

True

In [23]:
def process_with_gpu_memory_constraints(train_audio, train_visual, train_labels, dev_audio, dev_visual, dev_labels, batch_size=32):
    """
    Wrap the train and validation data in batched ``tf.data`` datasets.

    The datasets are built on the CPU device. Audio, visual and label arrays
    of a split are cut to their common length so that every sample has both
    modalities and a label (previously the label length was not considered,
    so fewer labels than features made dataset creation fail).

    Args:
        train_audio: Training audio features
        train_visual: Training visual features
        train_labels: Training labels
        dev_audio: Validation audio features
        dev_visual: Validation visual features
        dev_labels: Validation labels
        batch_size: Batch size for training

    Returns:
        model_inputs: Dictionary with
            'train_dataset' / 'val_dataset': batched datasets yielding
                ``((audio, visual), label)``; neither is shuffled here;
            'class_weight_dict': 'balanced' class weights computed from the
                training labels, keyed by class label;
            'train_samples' / 'val_samples': number of samples used;
            'batch_size': the batch size passed in.
    """
    # Imported lazily so the rest of the notebook runs without TensorFlow
    import tensorflow as tf
    import numpy as np
    from sklearn.utils import class_weight

    print("Processing data with GPU memory constraints...")

    # 1. Use CPU for initial tensor conversions
    with tf.device('/CPU:0'):
        # Make sure features and labels all have the same number of samples
        num_train_samples = min(len(train_audio), len(train_visual), len(train_labels))
        if not (len(train_audio) == len(train_visual) == len(train_labels)):
            print(f"Warning: training inputs differ in length; using first {num_train_samples} samples")
        train_labels_subset = train_labels[:num_train_samples]

        num_dev_samples = min(len(dev_audio), len(dev_visual), len(dev_labels))
        if not (len(dev_audio) == len(dev_visual) == len(dev_labels)):
            print(f"Warning: validation inputs differ in length; using first {num_dev_samples} samples")
        dev_labels_subset = dev_labels[:num_dev_samples]

        # Class weights for class imbalance
        classes = np.unique(train_labels_subset)
        class_weights = class_weight.compute_class_weight(
            class_weight='balanced',
            classes=classes,
            y=train_labels_subset
        )
        # Key by the class label itself, not by its position in `classes`:
        # the two only coincide when the labels present are exactly 0..k-1.
        class_weight_dict = {int(c): float(w) for c, w in zip(classes, class_weights)}

        # Create dataset objects to handle memory efficiently
        train_dataset = tf.data.Dataset.from_tensor_slices(
            ((train_audio[:num_train_samples], train_visual[:num_train_samples]),
             train_labels_subset)
        ).batch(batch_size)

        val_dataset = tf.data.Dataset.from_tensor_slices(
            ((dev_audio[:num_dev_samples], dev_visual[:num_dev_samples]),
             dev_labels_subset)
        ).batch(batch_size)

    # 2. Return the TensorFlow datasets and related info
    return {
        'train_dataset': train_dataset,
        'val_dataset': val_dataset,
        'class_weight_dict': class_weight_dict,
        'train_samples': num_train_samples,
        'val_samples': num_dev_samples,
        'batch_size': batch_size
    }

## Next Steps

Once the preprocessing is complete with this memory-optimized approach, we can move on to model development:

1. Build a two-branch LSTM architecture with cross-modal attention
2. Implement class balancing to handle imbalanced data
3. Train the model with proper validation
4. Evaluate on test set and compare with baselines
