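The code groups frames by video, validates that each video has exactly one transition between classes, and balances the class counts within each video by trimming the excess frames from the larger class. The frames closest to the transition point are kept, so every retained video ends up with an equal number of "normal" and "seizure" frames. Videos that fail validation are skipped and reported.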

In [1]:
import pandas as pd
import numpy as np

#### Load Data

In [2]:
# Load the pickled DataFrame that the rest of the notebook works on.
# NOTE(review): unpickling can execute arbitrary code - only load files from a
# trusted source.
# Presumably one row per video frame with at least 'vid_id' and 'class' columns
# (those are the columns used below) - confirm against the producer of this file.
df_time = pd.read_pickle('data/df_time_find_best.pkl')

In [None]:
# Bare expression as the last line of the cell: the notebook renders the frame.
df_time

#### Balance the data

In [4]:
import pandas as pd
import numpy as np

def validate_sequence(group):
    """Check that a group of frames contains exactly one class transition.

    A transition is a row whose 'class' value differs from the row directly
    before it. Rows are compared in the order they appear in *group*
    (assumed to be temporal order - confirm with the producer of the data).

    Args:
        group: DataFrame with a 'class' column.

    Returns:
        ``(True, None)`` when exactly one transition is present, otherwise
        ``(False, message)`` where *message* states how many were found.
    """
    labels = group['class']
    # shift() places NaN ahead of the first row, so the first comparison is
    # always True. Drop it so that only real changes between neighbouring
    # rows are counted and the reported number is the true transition count.
    transitions = int((labels != labels.shift()).iloc[1:].sum())
    if transitions != 1:
        return False, f"Found {transitions} transitions, expected 1"

    return True, None

def _trim_around_transition(group):
    """Return the largest class-balanced window centred on the transition.

    Expects *group* to consist of one run of a single class followed by one
    run of another class. The frames farthest from the transition are
    dropped from the longer run; row order is left untouched.
    """
    first_run_len = int((group['class'] == group['class'].iloc[0]).sum())
    keep = min(first_run_len, len(group) - first_run_len)
    return group.iloc[first_run_len - keep:first_run_len + keep]


def balance_video_sequence(df_time):
    """Balance 'nml' and 'sz' frames per video while preserving row order.

    For every 'vid_id' group the function validates that the group contains
    exactly one class transition and both classes, then keeps an equal number
    of frames on each side of the transition (the frames closest to it).
    Videos that fail a check are skipped and listed in the printed report;
    nothing is raised for a single bad video.

    Both orderings are handled: for nml -> sz videos the result is the last
    normal frames plus the first seizure frames, and for sz -> nml videos it
    is the last seizure frames plus the first normal frames. In either case
    the kept rows stay in their original order.

    Args:
        df_time: DataFrame with at least 'vid_id' and 'class' columns.

    Returns:
        DataFrame with the balanced frames of all accepted videos and a fresh
        0..n-1 index. When no video is accepted, an empty DataFrame with the
        same columns as *df_time* is returned.
    """
    balanced_data = []
    error_videos = []

    for vid_id, group in df_time.groupby('vid_id'):
        # Best-effort per video: any unexpected failure is recorded and the
        # remaining videos are still processed.
        try:
            if group.empty:
                error_videos.append((vid_id, "Empty group"))
                continue

            is_valid, error_msg = validate_sequence(group)
            if not is_valid:
                error_videos.append((vid_id, error_msg))
                continue

            group = group.reset_index(drop=True)
            nml_count = int((group['class'] == 'nml').sum())
            sz_count = int((group['class'] == 'sz').sum())

            # One transition guarantees two labels, but not that they are
            # the two expected ones.
            if nml_count == 0 or sz_count == 0:
                error_videos.append((vid_id, "Missing one class"))
                continue

            balanced_group = _trim_around_transition(group)

            # Sanity check before accepting the video.
            final_nml_count = int((balanced_group['class'] == 'nml').sum())
            final_sz_count = int((balanced_group['class'] == 'sz').sum())
            if final_nml_count != final_sz_count:
                error_videos.append((vid_id, f"Balance failed: nml={final_nml_count}, sz={final_sz_count}"))
                continue

            balanced_data.append(balanced_group)

        except Exception as e:
            error_videos.append((vid_id, f"Error: {str(e)}"))
            continue

    if balanced_data:
        df_time_balanced = pd.concat(balanced_data, axis=0).reset_index(drop=True)
        print(f"Successfully processed {len(balanced_data)} videos")
    else:
        # Keep the input's columns so downstream column lookups still work
        # on the empty result.
        df_time_balanced = df_time.iloc[0:0].reset_index(drop=True)
        print("No videos were successfully balanced")

    if error_videos:
        print("\nErrors encountered:")
        for vid_id, error in error_videos:
            print(f"VID_ID {vid_id}: {error}")

    return df_time_balanced

# Usage: balance every video in df_time. Skipped videos are printed by the
# function together with the reason; they are not included in the result.
df_time_balanced = balance_video_sequence(df_time)

Successfully processed 72 videos

Errors encountered:
VID_ID 79611U00: Found 0 transitions, expected 1


In [5]:
# Write the balanced frame to disk as a pickle, then print the path that was
# written as a confirmation.
df_time_balanced.to_pickle('data/df_time_balanced.pkl')

print("DataFrame saved to 'data/df_time_balanced.pkl'")

DataFrame saved to 'data/df_time_balanced.pkl'


#### Analyze Balancing

In [6]:
def analyze_dataframe_differences(df_time, df_time_balanced):
    """Summarise per video how balancing changed the frame counts.

    Only videos present in both frames are reported, in the order in which
    they first appear in *df_time*. A final row with ``vid_id == 'TOTAL'``
    aggregates the reported videos.

    Args:
        df_time: Original DataFrame with 'vid_id' and 'class' columns.
        df_time_balanced: Balanced DataFrame with the same columns. An empty
            frame, even one without any columns, is accepted.

    Returns:
        DataFrame with one row per reported video plus the TOTAL row and the
        columns vid_id, original_frames, balanced_frames, frames_removed,
        original_nml, original_sz, balanced_nml, balanced_sz and
        reduction_percentage (frames removed as a percentage of the original
        frames; 0.0 in the TOTAL row when no video is reported).
    """
    columns = [
        'vid_id',
        'original_frames',
        'balanced_frames',
        'frames_removed',
        'original_nml',
        'original_sz',
        'balanced_nml',
        'balanced_sz',
        'reduction_percentage',
    ]
    count_columns = columns[1:-1]

    # Group the balanced frame once instead of re-filtering it per video.
    # A column-less empty frame simply yields no groups.
    if 'vid_id' in df_time_balanced.columns:
        balanced_by_vid = {
            vid_id: frames
            for vid_id, frames in df_time_balanced.groupby('vid_id', sort=False)
        }
    else:
        balanced_by_vid = {}

    rows = []
    # sort=False keeps the videos in order of first appearance.
    for vid_id, orig_vid in df_time.groupby('vid_id', sort=False):
        bal_vid = balanced_by_vid.get(vid_id)
        if bal_vid is None:
            # Video was dropped during balancing; nothing to compare.
            continue

        original_frames = len(orig_vid)
        balanced_frames = len(bal_vid)
        rows.append({
            'vid_id': vid_id,
            'original_frames': original_frames,
            'balanced_frames': balanced_frames,
            'frames_removed': original_frames - balanced_frames,
            'original_nml': int((orig_vid['class'] == 'nml').sum()),
            'original_sz': int((orig_vid['class'] == 'sz').sum()),
            'balanced_nml': int((bal_vid['class'] == 'nml').sum()),
            'balanced_sz': int((bal_vid['class'] == 'sz').sum()),
            'reduction_percentage': ((original_frames - balanced_frames) / original_frames * 100),
        })

    totals = {name: sum(row[name] for row in rows) for name in count_columns}
    total_original = totals['original_frames']
    rows.append({
        'vid_id': 'TOTAL',
        **totals,
        # Guard the division for the case where no video is reported.
        'reduction_percentage': (
            (totals['frames_removed'] / total_original) * 100 if total_original else 0.0
        ),
    })

    # Build the frame in one go; explicit columns keep the layout stable even
    # when the TOTAL row is the only row.
    return pd.DataFrame(rows, columns=columns)

# Usage: compare the original and balanced frames; the result has one row per
# video found in both frames plus a final 'TOTAL' row.
df_differences = analyze_dataframe_differences(df_time, df_time_balanced)

In [None]:
# Bare expression as the last line of the cell: the notebook renders the table.
df_differences