# Brain-to-Text: Preparing Neural Data for Kumo AI

This notebook takes the core data processing logic from the BCI phoneme decoding project and packages it for use in Google Colab.

The goal is to convert raw neural signals and phoneme annotations into a set of structured tables that are ready to be uploaded to the Kumo AI platform for powerful graph-based analysis.

**Pipeline Steps:**
1.  **Setup:** Install necessary libraries and define the required Python classes.
2.  **Load Data:** Load the T15 copy-task data (`data/t15_copyTask.pkl`) downloaded from Dryad. *(You can replace this step with code to load your own data).*
3.  **Process Neural Signals:** Extract key events (peaks and valleys) from the raw time-series data.
4.  **Prepare Kumo Tables:** Convert the neural events and phoneme annotations into the relational format Kumo requires.
5.  **Save & Download:** Save the final tables as CSV files and download them.


In [None]:
!pip install pydantic==2.11.7 kumoai --upgrade --quiet

# Import libraries
import pandas as pd
import numpy as np
from dataclasses import dataclass
from enum import Enum
import scipy
from typing import Dict, List, Optional, Tuple, Any
import kumoai.experimental.rfm as rfm

  from .autonotebook import tqdm as notebook_tqdm
* 'smart_union' has been removed


In [40]:
# @dataclass
# class NeuralEvent:
#     """Represents a neural signal event (peak or valley)"""
#     timestamp: float
#     amplitude: float
#     channel: int
#     event_type: str  # 'peak' or 'valley'
#     frequency_band: Optional[str] = None
#     duration: Optional[float] = None
#     sharpness: Optional[float] = None

# @dataclass
# class PhonemeEvent:
#     """Represents a phoneme production event"""
#     phoneme: str
#     start_time: float
#     end_time: float
#     features: Dict[str, Any]

# class NeuralSignalProcessor:
#     """Process raw neural signals to extract peaks and valleys"""
#     def __init__(self, sampling_rate: float = 1000.0, peak_prominence: float = 0.5, min_peak_distance: int = 50):
#         self.sampling_rate = sampling_rate
#         self.peak_prominence = peak_prominence
#         self.min_peak_distance = min_peak_distance

#     def extract_peaks_valleys(self, signal_data: np.ndarray) -> Tuple[List[NeuralEvent], List[NeuralEvent]]:
#         timestamps = np.arange(signal_data.shape[1]) / self.sampling_rate
#         peaks = []
#         valleys = []

#         for channel_idx, channel_signal in enumerate(signal_data):
#             normalized = zscore(channel_signal)
#             # Explicitly use scipy.signal.find_peaks
#             peak_indices, peak_properties = scipy.signal.find_peaks(
#                 normalized, prominence=self.peak_prominence, distance=self.min_peak_distance
#             )
#             # Explicitly use scipy.signal.find_peaks
#             valley_indices, valley_properties = scipy.signal.find_peaks(
#                 -normalized, prominence=self.peak_prominence, distance=self.min_peak_distance
#             )

#             for idx, properties in zip(peak_indices, peak_properties['prominences']):
#                 peaks.append(NeuralEvent(
#                     timestamp=timestamps[idx],
#                     amplitude=channel_signal[idx],
#                     channel=channel_idx,
#                     event_type='peak',
#                     sharpness=properties
#                 ))

#             for idx, properties in zip(valley_indices, valley_properties['prominences']):
#                 valleys.append(NeuralEvent(
#                     timestamp=timestamps[idx],
#                     amplitude=channel_signal[idx],
#                     channel=channel_idx,
#                     event_type='valley',
#                     sharpness=properties
#                 ))
#         return peaks, valleys

# class KumoNeuralPhonemeIntegration:
#     """Integration layer for preparing data for Kumo AI"""
#     def prepare_for_kumo(self, neural_data: pd.DataFrame, phoneme_data: pd.DataFrame) -> Dict[str, pd.DataFrame]:
#         # Create event tables
#         neural_events_table = pd.DataFrame({
#             'event_id': range(len(neural_data)),
#             'timestamp': neural_data['timestamp'],
#             'amplitude': neural_data['amplitude'],
#             'channel': neural_data['channel'],
#             'event_type': neural_data['event_type'],
#             'subject_id': neural_data.get('subject_id', 'default')
#         })

#         phoneme_events_table = pd.DataFrame({
#             'phoneme_id': range(len(phoneme_data)),
#             'phoneme': phoneme_data['phoneme'],
#             'start_time': phoneme_data['start_time'],
#             'end_time': phoneme_data['end_time'],
#             'subject_id': phoneme_data.get('subject_id', 'default')
#         })

#         # Create relationship table for potential causal connections
#         causal_relationships = []
#         for _, neural_row in neural_events_table.iterrows():
#             for _, phoneme_row in phoneme_events_table.iterrows():
#                 delay = phoneme_row['start_time'] - neural_row['timestamp']
#                 if 0.05 <= delay <= 0.5:  # Potential causal window
#                     causal_relationships.append({
#                         'neural_event_id': neural_row['event_id'],
#                         'phoneme_id': phoneme_row['phoneme_id'],
#                         'delay': delay
#                     })

#         causality_table = pd.DataFrame(causal_relationships)

#         return {
#             'neural_events': neural_events_table,
#             'phoneme_events': phoneme_events_table,
#             'causal_relationships': causality_table
#         }

# print("Setup complete. All classes are defined.")

In [4]:
# # @title Step 2: Download Real Data from Dryad (Rate Limited)
# # This cell contains the logic from the `download_data.py` script, adapted
# # for Colab/Jupyter with proper rate limiting to avoid IOPub message overflow

# import sys
# import os
# import urllib.request
# import json
# import zipfile
# import time
# from threading import Lock

# # Global variables for rate limiting
# _last_update_time = {}
# _update_lock = Lock()

# def display_progress_bar(block_num, block_size, total_size, message="", filename=""):
#     """Helper function to show a download progress bar with rate limiting."""
#     bytes_downloaded_so_far = block_num * block_size
#     MB_downloaded_so_far = bytes_downloaded_so_far / 1e6
#     MB_total = total_size / 1e6
#     current_time = time.time()
    
#     # Use filename as key for tracking updates per file
#     file_key = filename or "default"
    
#     with _update_lock:
#         # Only update every 2 seconds to avoid rate limiting
#         if (file_key not in _last_update_time or 
#             (current_time - _last_update_time[file_key]) > 2.0):
            
#             # Calculate percentage and speed
#             percentage = (bytes_downloaded_so_far / total_size * 100) if total_size > 0 else 0
            
#             # Show progress with less frequent updates
#             sys.stdout.write(
#                 f"\r{message}: {MB_downloaded_so_far:.1f}/{MB_total:.1f} MB ({percentage:.1f}%)"
#             )
#             sys.stdout.flush()
            
#             _last_update_time[file_key] = current_time

# def download_with_progress(url, filepath, filename):
#     """Download a file with rate-limited progress reporting."""
#     print(f"\nStarting download: {filename}")
#     start_time = time.time()
    
#     def progress_hook(block_num, block_size, total_size):
#         display_progress_bar(block_num, block_size, total_size, 
#                            f"Downloading {filename}", filename)
    
#     try:
#         urllib.request.urlretrieve(url, filepath, reporthook=progress_hook)
        
#         # Final status
#         file_size = os.path.getsize(filepath) / 1e6 if os.path.exists(filepath) else 0
#         duration = time.time() - start_time
#         speed = file_size / duration if duration > 0 else 0
        
#         print(f"\n‚úì {filename} complete: {file_size:.1f} MB in {duration:.1f}s ({speed:.1f} MB/s)")
        
#     except Exception as e:
#         print(f"\n‚úó Error downloading {filename}: {e}")
#         raise

# def download_and_unzip_data():
#     """Downloads and unzips the BCI competition data from Dryad with rate limiting."""
#     DRYAD_DOI = "10.5061/dryad.dncjsxm85"
#     DATA_DIR = "data/"
    
#     # Create the data directory
#     os.makedirs(DATA_DIR, exist_ok=True)
#     data_dirpath = os.path.abspath(DATA_DIR)
#     print(f"Data will be downloaded to: {data_dirpath}")
    
#     # Add delay to avoid hitting API limits
#     time.sleep(1)
    
#     # Get the list of files from the latest version on Dryad
#     DRYAD_ROOT = "https://datadryad.org"
#     urlified_doi = DRYAD_DOI.replace("/", "%2F")
#     versions_url = f"{DRYAD_ROOT}/api/v2/datasets/doi:{urlified_doi}/versions"
    
#     print("Fetching file list from Dryad...")
#     try:
#         with urllib.request.urlopen(versions_url) as response:
#             versions_info = json.loads(response.read().decode())
#     except Exception as e:
#         print(f"Error fetching version info: {e}")
#         return
    
#     time.sleep(1)  # Rate limit API calls
    
#     files_url_path = versions_info["_embedded"]["stash:versions"][-1]["_links"]["stash:files"]["href"]
#     files_url = f"{DRYAD_ROOT}{files_url_path}"
    
#     try:
#         with urllib.request.urlopen(files_url) as response:
#             files_info = json.loads(response.read().decode())
#     except Exception as e:
#         print(f"Error fetching file info: {e}")
#         return
    
#     file_infos = files_info["_embedded"]["stash:files"]
#     print(f"Found {len(file_infos)} files to download.")
    
#     # Download each file into the data directory
#     for i, file_info in enumerate(file_infos, 1):
#         filename = file_info["path"]
        
#         if filename == "README.md":
#             print(f"Skipping {filename}")
#             continue
        
#         print(f"\n[{i}/{len(file_infos)}] Processing: {filename}")
        
#         download_path = file_info["_links"]["stash:download"]["href"]
#         download_url = f"{DRYAD_ROOT}{download_path}"
#         download_to_filepath = os.path.join(data_dirpath, filename)
        
#         # Check if file already exists
#         if os.path.exists(download_to_filepath):
#             file_size = os.path.getsize(download_to_filepath) / 1e6
#             print(f"File already exists ({file_size:.1f} MB). Skipping download.")
#         else:
#             # Download the file with progress
#             download_with_progress(download_url, download_to_filepath, filename)
        
#         # If this file is a zip file, unzip it
#         if file_info["mimeType"] == "application/zip":
#             print(f"Extracting files from {filename}...")
#             try:
#                 with zipfile.ZipFile(download_to_filepath, "r") as zf:
#                     # Get extraction info
#                     file_list = zf.namelist()
#                     print(f"  Extracting {len(file_list)} files...")
                    
#                     # Extract with progress for large archives
#                     extracted_count = 0
#                     for member in file_list:
#                         zf.extract(member, data_dirpath)
#                         extracted_count += 1
                        
#                         # Rate-limited extraction progress
#                         if extracted_count % 100 == 0 or extracted_count == len(file_list):
#                             print(f"  Extracted {extracted_count}/{len(file_list)} files...")
                
#                 print(f"‚úì Extraction complete: {len(file_list)} files")
                
#             except Exception as e:
#                 print(f"‚úó Error extracting {filename}: {e}")
        
#         # Rate limit between files
#         time.sleep(0.5)
    
#     print(f"\nüéâ Download complete! See data files in {data_dirpath}")
    
#     # Show final directory contents
#     try:
#         files = os.listdir(data_dirpath)
#         print(f"\nDownloaded files ({len(files)} total):")
#         for file in sorted(files)[:10]:  # Show first 10 files
#             file_path = os.path.join(data_dirpath, file)
#             if os.path.isfile(file_path):
#                 size_mb = os.path.getsize(file_path) / 1e6
#                 print(f"  üìÑ {file} ({size_mb:.1f} MB)")
#             else:
#                 print(f"  üìÅ {file}/")
        
#         if len(files) > 10:
#             print(f"  ... and {len(files) - 10} more files")
            
#     except Exception as e:
#         print(f"Error listing directory: {e}")

# # Alternative: Simplified version for very restrictive environments
# def download_simple():
#     """Simplified download with minimal output for restrictive Jupyter environments."""
#     DRYAD_DOI = "10.5061/dryad.dncjsxm85"
#     DATA_DIR = "data/"
    
#     os.makedirs(DATA_DIR, exist_ok=True)
#     data_dirpath = os.path.abspath(DATA_DIR)
    
#     print("Starting Dryad download (simplified mode)...")
    
#     # Get file list
#     DRYAD_ROOT = "https://datadryad.org"
#     urlified_doi = DRYAD_DOI.replace("/", "%2F")
#     versions_url = f"{DRYAD_ROOT}/api/v2/datasets/doi:{urlified_doi}/versions"
    
#     with urllib.request.urlopen(versions_url) as response:
#         versions_info = json.loads(response.read().decode())
    
#     files_url_path = versions_info["_embedded"]["stash:versions"][-1]["_links"]["stash:files"]["href"]
#     files_url = f"{DRYAD_ROOT}{files_url_path}"
    
#     with urllib.request.urlopen(files_url) as response:
#         files_info = json.loads(response.read().decode())
    
#     file_infos = files_info["_embedded"]["stash:files"]
    
#     # Download files with minimal output
#     for file_info in file_infos:
#         filename = file_info["path"]
#         if filename == "README.md":
#             continue
            
#         download_path = file_info["_links"]["stash:download"]["href"]
#         download_url = f"{DRYAD_ROOT}{download_path}"
#         download_to_filepath = os.path.join(data_dirpath, filename)
        
#         print(f"Downloading {filename}...")
#         urllib.request.urlretrieve(download_url, download_to_filepath)
        
#         if file_info["mimeType"] == "application/zip":
#             print(f"Extracting {filename}...")
#             with zipfile.ZipFile(download_to_filepath, "r") as zf:
#                 zf.extractall(data_dirpath)
    
#     print("Download complete!")

# # Run the download function
# print("Choose download method:")
# print("1. Full version with progress bars (recommended)")
# print("2. Simple version (if rate limiting issues persist)")

# # Uncomment the version you want to use:
# download_and_unzip_data()  # Full version
# # download_simple()  # Simple version

Choose download method:
1. Full version with progress bars (recommended)
2. Simple version (if rate limiting issues persist)
Data will be downloaded to: /home/ubuntu/data
Fetching file list from Dryad...
Found 5 files to download.
Skipping README.md

[2/5] Processing: t15_copyTask_neuralData.zip
File already exists (980.3 MB). Skipping download.
Extracting files from t15_copyTask_neuralData.zip...
‚úó Error extracting t15_copyTask_neuralData.zip: File is not a zip file

[3/5] Processing: t15_copyTask.pkl
File already exists (57.8 MB). Skipping download.

[4/5] Processing: t15_personalUse.pkl
File already exists (1.1 MB). Skipping download.

[5/5] Processing: t15_pretrained_rnn_baseline.zip
File already exists (484.9 MB). Skipping download.
Extracting files from t15_pretrained_rnn_baseline.zip...
  Extracting 12 files...
  Extracted 12/12 files...
‚úì Extraction complete: 12 files

üéâ Download complete! See data files in /home/ubuntu/data

Downloaded files (6 total):
  üìÅ __MACOSX/


In [5]:
# @title Load data from alternative formats
import os  # imported here: no earlier active cell brings `os` into scope
import pickle
import numpy as np

# Option 1: Try pickle files in priority order; the first one that loads wins.
# The Dryad download stores the second file as "t15_personalUse.pkl".
pkl_files = ["data/t15_copyTask.pkl", "data/t15_personalUse.pkl"]

# Reset explicitly so a stale value from a previous kernel run can never be
# mistaken for freshly loaded data.
loaded_data = None

for pkl_file in pkl_files:
    if os.path.exists(pkl_file):
        print(f"\nTrying to load pickle file: {pkl_file}")
        try:
            # SECURITY: pickle.load can execute arbitrary code. Only load files
            # obtained from a source you trust.
            with open(pkl_file, 'rb') as f:
                data = pickle.load(f)
                print(f"Successfully loaded!")
                print(f"Data type: {type(data)}")

                if isinstance(data, dict):
                    print(f"Dictionary keys: {list(data.keys())}")
                    # Check for neural data
                    for key in data.keys():
                        if 'neural' in key.lower() or 'signal' in key.lower():
                            print(f"  Found potential neural data in key: {key}")
                            print(f"  Shape: {data[key].shape if hasattr(data[key], 'shape') else 'N/A'}")

                # Store for processing
                loaded_data = data
                break

        except Exception as e:
            # Best effort: report the failure and fall through to the next candidate.
            print(f"Error loading {pkl_file}: {e}")

# Fail here with a clear message instead of a NameError in a later cell.
if loaded_data is None:
    raise RuntimeError(f"No pickle file could be loaded; tried: {pkl_files}")


Trying to load pickle file: data/t15_copyTask.pkl
Successfully loaded!
Data type: <class 'dict'>
Dictionary keys: ['post_implant_day', 'vocab_size', 'cue_sentence', 'cue_sentence_phonemes', 'decoded_logits', 'decoded_phonemes_raw', 'decoded_sentence', 'decoded_sentence_phonemes', 'speech_duration_s']


In [6]:
import os  # imported here: no earlier active cell brings `os` into scope

# Expose the Kumo API key through the environment when a local config module
# provides one; otherwise only report that interactive authentication is needed.
try:
    # Import from your config file
    from kumo_config import KUMO_API_KEY
    os.environ["KUMO_API_KEY"] = KUMO_API_KEY
    print("‚úì API key loaded from kumo_config.py")
except ImportError:
    print("‚ö†Ô∏è  kumo_config.py not found, will use interactive authentication")

‚úì API key loaded from kumo_config.py


In [7]:
# Extract Neural Signals and Phoneme Data from Pickle
#
# Reads `loaded_data` (the dict loaded from the pickle file) and produces:
#   - neural_signals:      decoded logits transposed to (channels x time points)
#   - SAMPLING_RATE:       estimated samples per second of `neural_signals`
#   - phoneme_annotations: one row per cued phoneme with synthetic timing
import numpy as np
import pandas as pd

# Extract the decoded logits as neural signals
print("Extracting data from pickle file...")

# Use decoded_logits as neural signal representation
# NOTE(review): judging by the key name these are decoder outputs rather than raw
# recordings; confirm that this is the intended "neural signal".
decoded_logits_list = loaded_data['decoded_logits']
print(f"Decoded logits is a list of {len(decoded_logits_list)} arrays.")

# Concatenate the list of arrays into a single NumPy array
if decoded_logits_list:
    # All per-trial arrays are stacked along axis 0, so trial boundaries are lost here.
    decoded_logits = np.concatenate(decoded_logits_list, axis=0)
    print(f"Concatenated decoded logits shape: {decoded_logits.shape}")

    # Convert to neural signals format (channels x time)
    # Assuming channels are the second dimension after concatenation, time is the first
    neural_signals = decoded_logits.T
    print(f"Neural signals shape: {neural_signals.shape} (channels √ó time points)")

    # Set sampling rate based on phoneme rate (assuming decoded_logits time points align with phoneme duration)
    # This might need adjustment based on actual data structure and timing
    # NOTE(review): the numerator counts samples from ALL concatenated trials, but the
    # denominator is the duration of the FIRST trial only. The recorded output of this
    # cell reports 84087.1 Hz, which looks like an artifact of that mismatch. Dividing by
    # the summed durations would likely be the intended estimate (TODO confirm). Changing
    # it also changes MIN_DISTANCE_SAMPLES in the peak-detection cell, which must then be
    # kept >= 1.
    SAMPLING_RATE = decoded_logits.shape[0] / loaded_data['speech_duration_s'][0] # Assuming speech_duration_s is a list
    print(f"Estimated sampling rate: {SAMPLING_RATE:.1f} Hz")

else:
    print("Error: decoded_logits list is empty.")
    neural_signals = np.array([]) # Initialize as empty array to prevent further errors
    SAMPLING_RATE = 1000.0 # Default or handle appropriately


# Create phoneme annotations
phonemes = loaded_data['cue_sentence_phonemes']
duration = loaded_data['speech_duration_s']
n_trials = len(phonemes) # Number of trials is the number of sentences/phoneme lists

all_phoneme_annotations = []

# Iterate through each trial to create phoneme annotations
for trial_idx in range(n_trials):
    trial_phonemes = phonemes[trial_idx]
    trial_duration = duration[trial_idx]
    n_phonemes_in_trial = len(trial_phonemes)

    if n_phonemes_in_trial > 0:
      # Create start and end times for phonemes in this trial
      # Timing is synthetic: start times are spread over [0, 0.9 * duration] and end times
      # over [0.1 * duration, duration], so every phoneme gets a window of 10% of the
      # trial duration. Windows overlap once a trial has more than 10 phonemes.
      # NOTE(review): these times restart at 0 for every trial, whereas the event
      # timestamps in the peak-detection cell are derived from the index into the
      # concatenated signal. Confirm the two time bases are meant to be compared.
      start_times = np.linspace(0, trial_duration * 0.9, n_phonemes_in_trial)
      end_times = np.linspace(trial_duration * 0.1, trial_duration, n_phonemes_in_trial)

      trial_annotations = pd.DataFrame({
          # NOTE(review): trial_id holds the post-implant day, not the trial index, so it
          # is presumably shared by trials recorded on the same day. Verify.
          'trial_id': loaded_data['post_implant_day'][trial_idx],
          'phoneme_id': [f'{trial_idx}_{i}' for i in range(n_phonemes_in_trial)], # Unique ID per phoneme: "<trial index>_<position>"
          'phoneme': trial_phonemes,
          'start_time': start_times,
          'end_time': end_times,
          'duration': end_times - start_times,
          'sequence_position': range(n_phonemes_in_trial),
          'total_sequence_length': n_phonemes_in_trial,
          'subject_id': 't15' # Assuming subject id is constant for this dataset
      })
      all_phoneme_annotations.append(trial_annotations)

# Concatenate all trial annotations into a single DataFrame
if all_phoneme_annotations:
  phoneme_annotations = pd.concat(all_phoneme_annotations, ignore_index=True)
else:
  phoneme_annotations = pd.DataFrame() # Empty DataFrame (no columns) if no phonemes found

print(f"\nCreated {len(phoneme_annotations)} phoneme annotations across {n_trials} trials.")
# print(f"Trial: Day {loaded_data['post_implant_day']} post-implant") # This will print a list, which is not very informative
print(f"Vocabulary size: {loaded_data['vocab_size'][0]} words (based on the first trial)") # Assuming vocab size is consistent

Extracting data from pickle file...
Decoded logits is a list of 1718 arrays.
Concatenated decoded logits shape: (346439, 41)
Neural signals shape: (41, 346439) (channels √ó time points)
Estimated sampling rate: 84087.1 Hz

Created 46276 phoneme annotations across 1718 trials.
Vocabulary size: 50 words (based on the first trial)


In [10]:
# @title: Extract Peaks and Valleys from Neural Signals (Enhanced)
from scipy.signal import find_peaks
from scipy.stats import zscore
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings

# Turn every channel of `neural_signals` into discrete peak/valley events.
#
# Inputs (from earlier cells): neural_signals (channels x samples), SAMPLING_RATE (Hz),
# loaded_data (dict with 'post_implant_day').
# Output: neural_events_df, one row per detected event, sorted by timestamp.
print("Processing neural signals into discrete events...")

# Validate inputs first
if 'neural_signals' not in locals():
    print("‚ùå Error: neural_signals not found")
    raise ValueError("neural_signals must be defined first")

if 'SAMPLING_RATE' not in locals():
    print("‚ö†Ô∏è Warning: SAMPLING_RATE not defined, using default 1000.0 Hz")
    SAMPLING_RATE = 1000.0

if 'loaded_data' not in locals():
    print("‚ö†Ô∏è Warning: loaded_data not found, using default trial_id")
    loaded_data = {'post_implant_day': ['default_trial']}

print(f"Input validation:")
print(f"  Neural signals shape: {neural_signals.shape}")
print(f"  Sampling rate: {SAMPLING_RATE} Hz")
print(f"  Channels to process: {neural_signals.shape[0]}")

# Initialize lists for events
neural_events = []
event_id = 0
channels_processed = 0
channels_skipped = 0

# Parameters for peak detection (make these configurable)
PEAK_PROMINENCE = 0.5
MIN_DISTANCE_MS = 10  # Minimum 10ms between peaks
# scipy.signal.find_peaks requires distance >= 1. At sampling rates below
# 1000 / MIN_DISTANCE_MS Hz the truncation would give 0, so clamp to one sample.
MIN_DISTANCE_SAMPLES = max(1, int(SAMPLING_RATE * (MIN_DISTANCE_MS / 1000)))

print(f"Peak detection parameters:")
print(f"  Prominence threshold: {PEAK_PROMINENCE}")
print(f"  Minimum distance: {MIN_DISTANCE_MS}ms ({MIN_DISTANCE_SAMPLES} samples)")

# Handle trial_id properly (it might be a list). It does not depend on the
# channel, so it is resolved once instead of on every loop iteration.
trial_id_value = loaded_data['post_implant_day']
if isinstance(trial_id_value, list):
    trial_id_value = trial_id_value[0] if trial_id_value else 'unknown'

# Process each channel with progress bar
for channel_idx in tqdm(range(neural_signals.shape[0]), desc="Processing channels"):
    channel_signal = neural_signals[channel_idx, :]

    # Skip if channel is flat or has invalid data
    signal_std = np.std(channel_signal)
    if signal_std < 1e-6:
        channels_skipped += 1
        continue

    # Check for NaN or infinite values
    if np.any(np.isnan(channel_signal)) or np.any(np.isinf(channel_signal)):
        print(f"‚ö†Ô∏è Warning: Channel {channel_idx} contains NaN/inf values, skipping")
        channels_skipped += 1
        continue

    # Normalize the signal
    try:
        normalized = zscore(channel_signal)

        # Handle case where zscore returns NaN (constant signal)
        if np.any(np.isnan(normalized)):
            print(f"‚ö†Ô∏è Warning: Channel {channel_idx} normalization failed, skipping")
            channels_skipped += 1
            continue

    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Channel {channel_idx} normalization error: {e}")
        channels_skipped += 1
        continue

    # Find peaks with error handling
    try:
        peak_indices, peak_properties = find_peaks(
            normalized,
            prominence=PEAK_PROMINENCE,
            distance=MIN_DISTANCE_SAMPLES
        )

        # Find valleys (peaks in inverted signal)
        valley_indices, valley_properties = find_peaks(
            -normalized,
            prominence=PEAK_PROMINENCE,
            distance=MIN_DISTANCE_SAMPLES
        )

    except Exception as e:
        print(f"‚ö†Ô∏è Warning: Peak detection failed for channel {channel_idx}: {e}")
        channels_skipped += 1
        continue

    # Create events for peaks
    for i, idx in enumerate(peak_indices):
        # find_peaks always returns 'prominences' when `prominence` is passed
        prominence = peak_properties['prominences'][i]

        neural_events.append({
            'event_id': event_id,
            'timestamp': idx / SAMPLING_RATE,
            'amplitude': channel_signal[idx],
            'normalized_amplitude': normalized[idx],
            'prominence': prominence,
            'channel': channel_idx,
            'event_type': 'peak',
            'channel_region': channel_idx // 8,  # Group into regions
            'trial_id': trial_id_value,
            'sample_index': idx
        })
        event_id += 1

    # Create events for valleys
    for i, idx in enumerate(valley_indices):
        prominence = valley_properties['prominences'][i]

        neural_events.append({
            'event_id': event_id,
            'timestamp': idx / SAMPLING_RATE,
            'amplitude': channel_signal[idx],
            # `normalized` itself was never inverted (only the argument passed to
            # find_peaks was), so the z-scored value at a valley is read directly.
            # Negating it here would flip its sign relative to `amplitude`.
            'normalized_amplitude': normalized[idx],
            'prominence': prominence,
            'channel': channel_idx,
            'event_type': 'valley',
            'channel_region': channel_idx // 8,
            'trial_id': trial_id_value,
            'sample_index': idx
        })
        event_id += 1

    channels_processed += 1

# Create DataFrame and sort by timestamp
if neural_events:
    neural_events_df = pd.DataFrame(neural_events).sort_values('timestamp').reset_index(drop=True)

    # Additional analysis
    print(f"\n‚úÖ Neural event extraction completed!")
    print(f"üìä Summary statistics:")
    print(f"  Channels processed: {channels_processed}")
    print(f"  Channels skipped: {channels_skipped}")
    print(f"  Total events extracted: {len(neural_events_df)}")
    print(f"  Peaks: {sum(neural_events_df['event_type'] == 'peak')}")
    print(f"  Valleys: {sum(neural_events_df['event_type'] == 'valley')}")
    print(f"  Time range: {neural_events_df['timestamp'].min():.3f}s - {neural_events_df['timestamp'].max():.3f}s")
    print(f"  Average events per channel: {len(neural_events_df) / channels_processed:.1f}")

    # Channel-wise analysis
    events_per_channel = neural_events_df.groupby('channel').size()
    print(f"  Most active channel: {events_per_channel.idxmax()} ({events_per_channel.max()} events)")
    print(f"  Least active channel: {events_per_channel.idxmin()} ({events_per_channel.min()} events)")

    # Amplitude analysis
    print(f"üìà Amplitude statistics:")
    print(f"  Peak amplitudes: {neural_events_df[neural_events_df['event_type'] == 'peak']['amplitude'].describe()}")
    print(f"  Valley amplitudes: {neural_events_df[neural_events_df['event_type'] == 'valley']['amplitude'].describe()}")

else:
    print("‚ùå No neural events extracted! Check your signal processing parameters.")
    neural_events_df = pd.DataFrame()

# --- ADVANCED VERSION WITH ADAPTIVE PARAMETERS ---
def extract_neural_events_adaptive(neural_signals, sampling_rate, 
                                 adaptive_prominence=True,
                                 min_events_per_channel=10,
                                 max_events_per_channel=1000):
    """
    Enhanced version with adaptive parameters for different channel characteristics.

    For each channel the z-scored signal is searched for peaks and valleys with
    progressively larger minimum peak distances (5 ms, 10 ms, 20 ms) until the
    number of detected events falls inside
    [min_events_per_channel, max_events_per_channel]. If no distance satisfies
    the bounds, the result for the largest distance is kept.

    Args:
        neural_signals: 2-D array indexed as [channel, sample].
        sampling_rate: Samples per second, used to convert sample indices to
            timestamps and durations to sample distances.
        adaptive_prominence: When True, derive the prominence threshold from the
            variance of the normalized signal (clamped to [0.3, 1.0]); otherwise
            use 0.5.
        min_events_per_channel: Lower bound of the accepted event count.
        max_events_per_channel: Upper bound of the accepted event count.

    Returns:
        DataFrame sorted by 'timestamp' with columns event_id, timestamp,
        amplitude, channel, event_type, prominence_used, distance_used. When no
        events are found, an empty DataFrame with those columns is returned.
    """
    print("üß† Adaptive neural event extraction...")

    event_columns = ['event_id', 'timestamp', 'amplitude', 'channel',
                     'event_type', 'prominence_used', 'distance_used']
    all_events = []
    event_id = 0

    # Try different distance parameters if not enough events.
    # find_peaks requires distance >= 1, so clamp (low sampling rates truncate
    # to 0), then drop duplicates so the same search is not repeated.
    distances = list(dict.fromkeys(
        max(1, int(sampling_rate * seconds)) for seconds in (0.005, 0.01, 0.02)
    ))

    for channel_idx in tqdm(range(neural_signals.shape[0]), desc="Adaptive processing"):
        channel_signal = neural_signals[channel_idx, :]

        # Skip channels with NaN/inf samples (z-scoring them yields only NaN,
        # so no peaks could be found anyway) as well as flat channels.
        if not np.all(np.isfinite(channel_signal)) or np.std(channel_signal) < 1e-6:
            continue

        normalized = zscore(channel_signal)

        # Adaptive prominence based on signal characteristics
        # NOTE(review): a z-scored signal has variance ~1, so this evaluates to
        # ~0.5 for every channel. Confirm whether the variance of the raw signal
        # (or another statistic) was intended.
        if adaptive_prominence:
            signal_variance = np.var(normalized)
            prominence = max(0.3, min(1.0, signal_variance * 0.5))
        else:
            prominence = 0.5

        for distance in distances:
            peak_indices, _ = find_peaks(normalized, prominence=prominence, distance=distance)
            valley_indices, _ = find_peaks(-normalized, prominence=prominence, distance=distance)

            total_events = len(peak_indices) + len(valley_indices)

            if min_events_per_channel <= total_events <= max_events_per_channel:
                break

        # Create events with adaptive parameters: peaks first, then valleys,
        # so event ids are assigned in the same order as before.
        for event_type, indices in (('peak', peak_indices), ('valley', valley_indices)):
            for idx in indices:
                all_events.append({
                    'event_id': event_id,
                    'timestamp': idx / sampling_rate,
                    'amplitude': channel_signal[idx],
                    'channel': channel_idx,
                    'event_type': event_type,
                    'prominence_used': prominence,
                    'distance_used': distance
                })
                event_id += 1

    if not all_events:
        # Sorting a column-less DataFrame by 'timestamp' would raise KeyError.
        return pd.DataFrame(columns=event_columns)

    return pd.DataFrame(all_events).sort_values('timestamp').reset_index(drop=True)

# --- QUALITY CONTROL VERSION ---
def extract_neural_events_with_qc(neural_signals, sampling_rate,
                                signal_quality_threshold=0.1,
                                artifact_threshold=5.0):
    """
    Extract peak/valley events per channel, with flat-channel and artifact rejection.

    Each row of ``neural_signals`` is handled as one channel. A channel is
    skipped as "flat" when its standard deviation is below 1e-6, and skipped
    as "noisy" when any z-scored sample exceeds ``artifact_threshold`` in
    absolute value. On the remaining channels, peaks are detected on the
    z-scored trace and valleys on its negation (prominence >= 0.5, at least
    10 ms apart).

    Parameters
    ----------
    neural_signals : 2-D array, indexed as [channel, sample]
        Raw signals; amplitudes in the output are read from this array
        (not from the z-scored copy).
    sampling_rate : float
        Samples per second. Used to convert sample indices to seconds and to
        turn the 10 ms minimum peak spacing into a sample count.
    signal_quality_threshold : float, default 0.1
        Currently NOT used by the quality-control logic; it is accepted only
        so that existing callers keep working.
    artifact_threshold : float, default 5.0
        Maximum allowed absolute z-score before a channel is rejected.

    Returns
    -------
    pandas.DataFrame
        One row per event with columns ``event_id``, ``timestamp`` (seconds),
        ``amplitude``, ``channel``, ``event_type`` ('peak' or 'valley'),
        ``signal_quality`` (the channel's standard deviation) and
        ``prominence`` (as reported by ``find_peaks``), sorted by timestamp
        with a fresh index. If no channel produces an event, an empty frame
        with these same columns is returned.
    """
    print("üîç Neural event extraction with quality control...")

    # Fixed schema so that an empty result still has a 'timestamp' column to
    # sort on (a column-less empty frame would raise KeyError).
    event_columns = ['event_id', 'timestamp', 'amplitude', 'channel',
                     'event_type', 'signal_quality', 'prominence']

    prominence_threshold = 0.5
    # 10 ms expressed in samples. find_peaks rejects distance < 1, which would
    # otherwise happen for sampling rates below 100 Hz.
    min_distance = max(1, int(sampling_rate * 0.01))

    events = []
    event_id = 0
    qc_stats = {'good_channels': 0, 'noisy_channels': 0, 'flat_channels': 0}

    for channel_idx in tqdm(range(neural_signals.shape[0]), desc="QC processing"):
        channel_signal = neural_signals[channel_idx, :]
        signal_std = np.std(channel_signal)

        # Flat channels carry no usable events (and cannot be z-scored safely).
        if signal_std < 1e-6:
            qc_stats['flat_channels'] += 1
            continue

        # Any extreme z-score marks the whole channel as artifact-contaminated.
        normalized = zscore(channel_signal)
        if np.any(np.abs(normalized) > artifact_threshold):
            qc_stats['noisy_channels'] += 1
            continue

        qc_stats['good_channels'] += 1

        # Peaks first, then valleys (peaks of the negated trace), so event ids
        # are assigned in the same order as before.
        for event_type, trace in (('peak', normalized), ('valley', -normalized)):
            indices, props = find_peaks(
                trace,
                prominence=prominence_threshold,
                distance=min_distance
            )
            # 'prominences' is always present because prominence= is passed.
            for idx, event_prominence in zip(indices, props['prominences']):
                events.append({
                    'event_id': event_id,
                    'timestamp': idx / sampling_rate,
                    'amplitude': channel_signal[idx],
                    'channel': channel_idx,
                    'event_type': event_type,
                    'signal_quality': signal_std,
                    'prominence': event_prominence
                })
                event_id += 1

    print("üìä Quality control results:")
    print(f"  Good channels: {qc_stats['good_channels']}")
    print(f"  Noisy channels (excluded): {qc_stats['noisy_channels']}")
    print(f"  Flat channels (excluded): {qc_stats['flat_channels']}")

    events_df = pd.DataFrame(events, columns=event_columns)
    return events_df.sort_values('timestamp').reset_index(drop=True)

# Optional: to use one of the advanced extractors instead, uncomment a line below.
# neural_events_df = extract_neural_events_adaptive(neural_signals, SAMPLING_RATE)
# neural_events_df = extract_neural_events_with_qc(neural_signals, SAMPLING_RATE)

# Closing banner for this cell (plain string: nothing is interpolated).
print("\nüéØ Neural event extraction complete and ready for causal analysis!")

Processing neural signals into discrete events...
Input validation:
  Neural signals shape: (41, 346439)
  Sampling rate: 84087.1359223301 Hz
  Channels to process: 41
Peak detection parameters:
  Prominence threshold: 0.5
  Minimum distance: 10ms (840 samples)


Processing channels: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 41/41 [00:01<00:00, 33.74it/s]



‚úÖ Neural event extraction completed!
üìä Summary statistics:
  Channels processed: 41
  Channels skipped: 0
  Total events extracted: 25123
  Peaks: 12594
  Valleys: 12529
  Time range: 0.000s - 4.120s
  Average events per channel: 612.8
  Most active channel: 17 (630 events)
  Least active channel: 3 (599 events)
üìà Amplitude statistics:
  Peak amplitudes: count    12594.000000
mean        23.822777
std          9.594965
min        -10.531623
25%         16.162013
50%         26.001776
75%         30.278698
max         60.857002
Name: amplitude, dtype: float64
  Valley amplitudes: count    12529.000000
mean       -14.346450
std          3.787599
min        -30.202309
25%        -16.828918
50%        -14.386750
75%        -12.071593
max         17.986731
Name: amplitude, dtype: float64

üéØ Neural event extraction complete and ready for causal analysis!


In [15]:
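# # Fixed Kumo Graph Creation with Proper Data Type Handling
# # NOTE: every line of this cell is commented out, so running it as-is does nothing.
# # The saved output shown below the cell appears to come from an earlier run made
# # while this code was still enabled.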
# import pandas as pd
# import numpy as np
# import kumoai.experimental.rfm as rfm

# def prepare_kumo_tables_fixed(neural_events_df, phoneme_annotations, causal_df):
#     """
#     Prepare tables for Kumo AI with proper data type handling
#     """
#     print("üìä Preparing tables for Kumo AI (with data type fixes)...")
    
#     # 1. Neural Events Table - Clean copy with proper data types
#     neural_table = neural_events_df.copy()
    
#     # Convert all categorical columns to strings to avoid categorical issues
#     for col in neural_table.columns:
#         if neural_table[col].dtype.name == 'category':
#             neural_table[col] = neural_table[col].astype(str)
    
#     # Ensure numeric columns are proper numeric types
#     numeric_cols = ['event_id', 'timestamp', 'channel', 'amplitude']
#     for col in numeric_cols:
#         if col in neural_table.columns:
#             neural_table[col] = pd.to_numeric(neural_table[col], errors='coerce')
    
#     # Ensure string columns are strings
#     string_cols = ['event_type', 'trial_id']
#     for col in string_cols:
#         if col in neural_table.columns:
#             neural_table[col] = neural_table[col].astype(str)
    
#     # Add required features with safe data types
#     neural_table['timestamp_ms'] = neural_table['timestamp'] * 1000
#     neural_table['channel_group'] = neural_table['channel'] // 4
    
#     # Safe amplitude normalization
#     amp_mean = neural_table['amplitude'].mean()
#     amp_std = neural_table['amplitude'].std()
#     if amp_std > 0:
#         neural_table['amplitude_normalized'] = (neural_table['amplitude'] - amp_mean) / amp_std
#     else:
#         neural_table['amplitude_normalized'] = 0.0
    
#     print(f"‚úì Neural events table: {len(neural_table)} events")
    
#     # 2. Phoneme Events Table - Clean copy with proper data types
#     phoneme_table = phoneme_annotations.copy()
    
#     # Convert categorical columns to strings
#     for col in phoneme_table.columns:
#         if phoneme_table[col].dtype.name == 'category':
#             phoneme_table[col] = phoneme_table[col].astype(str)
    
#     # Ensure required columns exist and have correct types
#     required_cols = ['phoneme_id', 'phoneme', 'start_time', 'end_time']
#     for col in required_cols:
#         if col not in phoneme_table.columns:
#             if col == 'phoneme_id':
#                 phoneme_table['phoneme_id'] = range(len(phoneme_table))
#             elif col == 'end_time' and 'start_time' in phoneme_table.columns:
#                 phoneme_table['end_time'] = phoneme_table['start_time'] + 0.1  # Default 100ms duration
#             else:
#                 print(f"‚ö†Ô∏è Warning: Missing required column {col}")
    
#     # Ensure numeric columns
#     numeric_cols = ['start_time', 'end_time']
#     for col in numeric_cols:
#         if col in phoneme_table.columns:
#             phoneme_table[col] = pd.to_numeric(phoneme_table[col], errors='coerce')
    
#     # Ensure string columns
#     phoneme_table['phoneme'] = phoneme_table['phoneme'].astype(str)
    
#     # Add features with safe data types
#     phoneme_table['start_time_ms'] = phoneme_table['start_time'] * 1000
#     if 'end_time' in phoneme_table.columns:
#         phoneme_table['duration_ms'] = (phoneme_table['end_time'] - phoneme_table['start_time']) * 1000
#     else:
#         phoneme_table['duration_ms'] = 100.0  # Default 100ms
    
#     # Safe phoneme category extraction
#     phoneme_table['phoneme_category'] = phoneme_table['phoneme'].str[:1].fillna('UNK')
    
#     print(f"‚úì Phoneme events table: {len(phoneme_table)} phonemes")
    
#     # 3. Relationships Table - Clean copy with proper data types
#     relationships_table = causal_df.copy()
    
#     # Convert categorical columns to strings
#     for col in relationships_table.columns:
#         if relationships_table[col].dtype.name == 'category':
#             relationships_table[col] = relationships_table[col].astype(str)
    
#     # Ensure required edge columns exist
#     if 'neural_event_id' not in relationships_table.columns:
#         print("‚ùå Error: Missing neural_event_id in causal relationships")
#         return None
#     if 'phoneme_id' not in relationships_table.columns:
#         print("‚ùå Error: Missing phoneme_id in causal relationships")
#         return None
    
#     # Ensure numeric types for edge columns
#     relationships_table['neural_event_id'] = pd.to_numeric(relationships_table['neural_event_id'], errors='coerce')
#     relationships_table['phoneme_id'] = pd.to_numeric(relationships_table['phoneme_id'], errors='coerce')
#     relationships_table['delay_ms'] = pd.to_numeric(relationships_table['delay_ms'], errors='coerce')
#     relationships_table['strength'] = pd.to_numeric(relationships_table['strength'], errors='coerce')
    
#     # Safe categorical feature creation
#     try:
#         relationships_table['strength_category'] = pd.cut(
#             relationships_table['strength'], 
#             bins=[0, 0.7, 0.85, 1.0], 
#             labels=['weak', 'medium', 'strong'],
#             include_lowest=True
#         ).astype(str)
#     except Exception as e:
#         print(f"‚ö†Ô∏è Warning: Could not create strength categories: {e}")
#         relationships_table['strength_category'] = 'medium'  # Default value
    
#     # Ensure string columns
#     string_cols = ['event_type', 'phoneme']
#     for col in string_cols:
#         if col in relationships_table.columns:
#             relationships_table[col] = relationships_table[col].astype(str)
    
#     print(f"‚úì Causal relationships table: {len(relationships_table)} edges")
    
#     # 4. Data validation
#     print(f"üîç Validating data types...")
    
#     # Check for any remaining categorical columns
#     for table_name, table in [('neural_events', neural_table), 
#                              ('phoneme_events', phoneme_table), 
#                              ('relationships', relationships_table)]:
#         categorical_cols = [col for col in table.columns if table[col].dtype.name == 'category']
#         if categorical_cols:
#             print(f"‚ö†Ô∏è Warning: {table_name} still has categorical columns: {categorical_cols}")
    
#     # Check for NaN values in key columns
#     key_checks = [
#         (neural_table, 'event_id', 'neural_events'),
#         (phoneme_table, 'phoneme_id', 'phoneme_events'),
#         (relationships_table, 'neural_event_id', 'relationships'),
#         (relationships_table, 'phoneme_id', 'relationships')
#     ]
    
#     for table, col, table_name in key_checks:
#         if col in table.columns:
#             nan_count = table[col].isna().sum()
#             if nan_count > 0:
#                 print(f"‚ö†Ô∏è Warning: {table_name}.{col} has {nan_count} NaN values")
    
#     return {
#         'neural_events': neural_table,
#         'phoneme_events': phoneme_table,
#         'neural_phoneme_edges': relationships_table
#     }

# def create_kumo_graph_safe(tables):
#     """
#     Create Kumo AI graph with enhanced error handling
#     """
#     print("üîó Creating Kumo AI graph (safe mode)...")
    
#     try:
#         # Display table info before creation
#         print("üìã Table information:")
#         for table_name, table in tables.items():
#             print(f"  {table_name}: {len(table)} rows, {len(table.columns)} columns")
#             print(f"    Data types: {dict(table.dtypes)}")
            
#             # Check for any remaining issues
#             problematic_cols = []
#             for col in table.columns:
#                 if table[col].dtype.name == 'category':
#                     problematic_cols.append(f"{col} (categorical)")
#                 elif table[col].dtype == 'object':
#                     # Check if object column contains mixed types
#                     try:
#                         sample_values = table[col].dropna().head().tolist()
#                         value_types = [type(v).__name__ for v in sample_values]
#                         if len(set(value_types)) > 1:
#                             problematic_cols.append(f"{col} (mixed types: {value_types})")
#                     except:
#                         pass
            
#             if problematic_cols:
#                 print(f"    ‚ö†Ô∏è Potential issues: {problematic_cols}")
        
#         # Create the graph
#         print("üî® Creating graph...")
#         graph = rfm.LocalGraph.from_data(tables)
        
#         # Configure the graph
#         print("‚öôÔ∏è Configuring graph...")
        
#         # Set temporal columns
#         if 'neural_events' in tables:
#             graph['neural_events'].time_column = 'timestamp'
#             graph['neural_events'].primary_key = 'event_id'
        
#         if 'phoneme_events' in tables:
#             graph['phoneme_events'].time_column = 'start_time'
#             graph['phoneme_events'].primary_key = 'phoneme_id'
        
#         # Configure edge table
#         if 'neural_phoneme_edges' in tables:
#             graph['neural_phoneme_edges'].source_column = 'neural_event_id'
#             graph['neural_phoneme_edges'].target_column = 'phoneme_id'
#             graph['neural_phoneme_edges'].source_table = 'neural_events'
#             graph['neural_phoneme_edges'].target_table = 'phoneme_events'
        
#         print("‚úÖ Graph created successfully!")
        
#         # Display graph statistics
#         print(f"üìà Graph Statistics:")
#         print(f"  Tables: {len(graph.tables)}")
#         for table_name, table in graph.tables.items():
#             print(f"  - {table_name}: {len(table)} rows")
        
#         # Try to visualize (this might also fail, so wrap in try-catch)
#         try:
#             print("üé® Visualizing graph schema...")
#             graph.visualize(show_columns=True)
#         except Exception as viz_error:
#             print(f"‚ö†Ô∏è Could not visualize graph: {viz_error}")
        
#         return graph
        
#     except Exception as e:
#         print(f"‚ùå Error creating graph: {e}")
#         print(f"üîç Error details: {type(e).__name__}: {str(e)}")
        
#         # Additional debugging
#         if "categorical" in str(e).lower():
#             print("üí° This appears to be a categorical data issue.")
#             print("   Try running the data cleaning steps again.")
        
#         return None

# def create_simple_prediction_queries(graph):
#     """
#     Create simplified prediction queries that are less likely to fail
#     """
#     print("üß† Creating simplified prediction models...")
    
#     try:
#         # Simple Neural ‚Üí Phoneme prediction
#         neural_to_phoneme_query = {
#             'target_table': 'phoneme_events',
#             'target_column': 'phoneme',
#             'feature_tables': ['neural_events'],
#             'features': ['neural_events.event_type', 'neural_events.channel'],
#             'time_column': 'start_time',
#             'training_window': '2s'
#         }
        
#         # Simple Phoneme ‚Üí Neural prediction
#         phoneme_to_neural_query = {
#             'target_table': 'neural_events',
#             'target_column': 'event_type',
#             'feature_tables': ['phoneme_events'],
#             'features': ['phoneme_events.phoneme'],
#             'time_column': 'timestamp',
#             'training_window': '1s'
#         }
        
#         print("‚úì Simplified prediction queries created")
#         return neural_to_phoneme_query, phoneme_to_neural_query
        
#     except Exception as e:
#         print(f"‚ùå Error creating prediction queries: {e}")
#         return None, None

# def run_fixed_kumo_pipeline(neural_events_df, phoneme_annotations, causal_df):
#     """
#     Run the Kumo pipeline with enhanced error handling
#     """
#     print("üöÄ Running FIXED Kumo AI pipeline...")
#     print("=" * 60)
    
#     # Step 1: Prepare data with proper type handling
#     tables = prepare_kumo_tables_fixed(neural_events_df, phoneme_annotations, causal_df)
#     if not tables:
#         print("‚ùå Data preparation failed")
#         return None
    
#     # Step 2: Create graph with safe mode
#     graph = create_kumo_graph_safe(tables)
#     if not graph:
#         print("‚ùå Graph creation failed")
#         return None
    
#     # Step 3: Try simple predictions
#     try:
#         print("üéØ Attempting basic graph queries...")
        
#         # Test basic queries to ensure graph works
#         neural_count = graph.query("SELECT COUNT(*) as count FROM neural_events")
#         phoneme_count = graph.query("SELECT COUNT(*) as count FROM phoneme_events")
        
#         print(f"‚úì Graph queries working:")
#         print(f"  Neural events: {neural_count['count'].iloc[0]}")
#         print(f"  Phoneme events: {phoneme_count['count'].iloc[0]}")
        
#         # Try to create simplified models
#         neural_to_phoneme_query, phoneme_to_neural_query = create_simple_prediction_queries(graph)
        
#         result = {
#             'graph': graph,
#             'tables': tables,
#             'neural_to_phoneme_query': neural_to_phoneme_query,
#             'phoneme_to_neural_query': phoneme_to_neural_query,
#             'status': 'graph_created'
#         }
        
#         print("‚úÖ Fixed pipeline completed successfully!")
#         print("üìä Graph is ready for predictions")
        
#         return result
        
#     except Exception as e:
#         print(f"‚ùå Error in pipeline execution: {e}")
#         return None

# # ============================================================================
# # RUN THE FIXED PIPELINE
# # ============================================================================

# print("üîß RUNNING FIXED KUMO PIPELINE")
# print("=" * 50)

# if 'causal_df' in globals() and len(causal_df) > 0:
#     print(f"‚úÖ Found causal_df with {len(causal_df)} relationships")
    
#     # Run the fixed pipeline
#     fixed_results = run_fixed_kumo_pipeline(neural_events_df, phoneme_annotations, causal_df)
    
#     if fixed_results and fixed_results['graph']:
#         print(f"\nüéâ SUCCESS! Graph created successfully!")
#         print(f"üìä You now have a working Kumo AI graph")
#         print(f"üîç Try some basic queries:")
        
#         # Example queries
#         try:
#             graph = fixed_results['graph']
            
#             # Basic statistics
#             print(f"\nüìà Basic Graph Statistics:")
            
#             # Neural events by type
#             event_types = graph.query("""
#                 SELECT event_type, COUNT(*) as count 
#                 FROM neural_events 
#                 GROUP BY event_type
#             """)
#             print(f"Event types: {dict(zip(event_types['event_type'], event_types['count']))}")
            
#             # Top phonemes
#             top_phonemes = graph.query("""
#                 SELECT phoneme, COUNT(*) as count 
#                 FROM phoneme_events 
#                 GROUP BY phoneme 
#                 ORDER BY count DESC 
#                 LIMIT 5
#             """)
#             print(f"Top phonemes: {dict(zip(top_phonemes['phoneme'], top_phonemes['count']))}")
            
#             # Causal relationship stats
#             causal_stats = graph.query("""
#                 SELECT 
#                     COUNT(*) as total_edges,
#                     AVG(delay_ms) as avg_delay,
#                     MIN(delay_ms) as min_delay,
#                     MAX(delay_ms) as max_delay
#                 FROM neural_phoneme_edges
#             """)
#             print(f"Causal relationships: {causal_stats.iloc[0].to_dict()}")
            
#         except Exception as e:
#             print(f"‚ö†Ô∏è Could not run example queries: {e}")
        
#         print(f"\nüéØ Next steps:")
#         print(f"1. Try training simple models on this graph")
#         print(f"2. Create prediction queries")
#         print(f"3. Build real-time prediction system")
        
#     else:
#         print(f"\n‚ùå Fixed pipeline also failed")
#         print(f"üí° The dataset might need more preprocessing")
        
# else:
#     print(f"‚ùå causal_df not found or empty")
#     print(f"Run the causal relationship creation code first")

üîß RUNNING FIXED KUMO PIPELINE
‚úÖ Found causal_df with 5009609 relationships
üöÄ Running FIXED Kumo AI pipeline...
üìä Preparing tables for Kumo AI (with data type fixes)...
‚úì Neural events table: 25123 events
‚úì Phoneme events table: 46276 phonemes
‚úì Causal relationships table: 5009609 edges
üîç Validating data types...
üîó Creating Kumo AI graph (safe mode)...
üìã Table information:
  neural_events: 25123 rows, 13 columns
    Data types: {'event_id': dtype('int64'), 'timestamp': dtype('float64'), 'amplitude': dtype('float32'), 'normalized_amplitude': dtype('float32'), 'prominence': dtype('float64'), 'channel': dtype('int64'), 'event_type': dtype('O'), 'channel_region': dtype('int64'), 'trial_id': dtype('O'), 'sample_index': dtype('int64'), 'timestamp_ms': dtype('float64'), 'channel_group': dtype('int64'), 'amplitude_normalized': dtype('float32')}
  phoneme_events: 46276 rows, 12 columns
    Data types: {'trial_id': dtype('int64'), 'phoneme_id': dtype('O'), 'phoneme': dty

In [17]:
!pip install graphviz

Defaulting to user installation because normal site-packages is not writeable


In [39]:
# #!/usr/bin/env python3
# """
# kumo_graph_final_fix.py - Final comprehensive fix for Kumo AI graph creation
# """

# import pandas as pd
# import numpy as np
# import kumoai.experimental.rfm as rfm

# def fix_phoneme_id_issue(causal_df, phoneme_annotations):
#     """
#     Fix the phoneme_id NaN issue by ensuring proper ID mapping
#     """
#     print("üîß Fixing phoneme_id mapping issue...")
    
#     # Check if phoneme_id in causal_df is string format
#     causal_df_fixed = causal_df.copy()
    
#     # If phoneme_id is string format like "0_1", "1_2", etc., we need to map to proper IDs
#     if 'phoneme_id' in causal_df_fixed.columns:
#         # Create a mapping from string phoneme_ids to numeric IDs
#         unique_phoneme_ids = phoneme_annotations['phoneme_id'].unique()
        
#         # Create mapping dictionary
#         phoneme_id_mapping = {}
#         for i, pid in enumerate(unique_phoneme_ids):
#             phoneme_id_mapping[pid] = i
        
#         # Map the causal_df phoneme_id to the new numeric IDs
#         causal_df_fixed['phoneme_id_numeric'] = causal_df_fixed['phoneme_id'].map(phoneme_id_mapping)
        
#         # If mapping failed, create sequential mapping
#         if causal_df_fixed['phoneme_id_numeric'].isna().all():
#             print("  Creating sequential phoneme ID mapping...")
#             # Get unique phoneme_ids from causal_df
#             unique_causal_phoneme_ids = causal_df_fixed['phoneme_id'].unique()
#             # Create sequential mapping
#             id_mapping = {pid: i for i, pid in enumerate(unique_causal_phoneme_ids)}
#             causal_df_fixed['phoneme_id_numeric'] = causal_df_fixed['phoneme_id'].map(id_mapping)
        
#         # Replace the old phoneme_id column
#         causal_df_fixed['phoneme_id'] = causal_df_fixed['phoneme_id_numeric']
#         causal_df_fixed = causal_df_fixed.drop('phoneme_id_numeric', axis=1)
        
#         # Ensure phoneme_annotations has matching numeric IDs
#         phoneme_annotations_fixed = phoneme_annotations.copy()
#         phoneme_annotations_fixed['phoneme_id'] = range(len(phoneme_annotations_fixed))
        
#         print(f"  ‚úì Fixed phoneme_id mapping")
#         print(f"  ‚úì Causal relationships: {len(causal_df_fixed)} (NaN count: {causal_df_fixed['phoneme_id'].isna().sum()})")
#         print(f"  ‚úì Phoneme annotations: {len(phoneme_annotations_fixed)}")
        
#         return causal_df_fixed, phoneme_annotations_fixed
    
#     return causal_df, phoneme_annotations

# def prepare_kumo_tables_fixed(neural_events_df, phoneme_annotations, causal_df):
#     """
#     Prepare tables for Kumo AI with proper data type handling and all fixes
#     """
#     print("üìä Preparing tables for Kumo AI (with comprehensive fixes)...")
    
#     # First fix the phoneme_id issue
#     causal_df_fixed, phoneme_annotations_fixed = fix_phoneme_id_issue(causal_df, phoneme_annotations)
    
#     # 1. Neural Events Table - Clean copy with proper data types
#     neural_table = neural_events_df.copy()
    
#     # Convert all categorical columns to strings to avoid categorical issues
#     for col in neural_table.columns:
#         if neural_table[col].dtype.name == 'category':
#             neural_table[col] = neural_table[col].astype(str)
    
#     # Ensure numeric columns are proper numeric types
#     numeric_cols = ['event_id', 'timestamp', 'channel', 'amplitude']
#     for col in numeric_cols:
#         if col in neural_table.columns:
#             neural_table[col] = pd.to_numeric(neural_table[col], errors='coerce')
    
#     # Ensure string columns are strings
#     string_cols = ['event_type', 'trial_id']
#     for col in string_cols:
#         if col in neural_table.columns:
#             neural_table[col] = neural_table[col].astype(str)
    
#     # Add required features with safe data types
#     neural_table['timestamp_ms'] = neural_table['timestamp'] * 1000
#     neural_table['channel_group'] = neural_table['channel'] // 4
    
#     # Safe amplitude normalization
#     amp_mean = neural_table['amplitude'].mean()
#     amp_std = neural_table['amplitude'].std()
#     if amp_std > 0:
#         neural_table['amplitude_normalized'] = (neural_table['amplitude'] - amp_mean) / amp_std
#     else:
#         neural_table['amplitude_normalized'] = 0.0
    
#     print(f"‚úì Neural events table: {len(neural_table)} events")
    
#     # 2. Phoneme Events Table - Clean copy with proper data types
#     phoneme_table = phoneme_annotations_fixed.copy()
    
#     # Convert categorical columns to strings
#     for col in phoneme_table.columns:
#         if phoneme_table[col].dtype.name == 'category':
#             phoneme_table[col] = phoneme_table[col].astype(str)
    
#     # Ensure required columns exist and have correct types
#     required_cols = ['phoneme_id', 'phoneme', 'start_time', 'end_time']
#     for col in required_cols:
#         if col not in phoneme_table.columns:
#             if col == 'phoneme_id':
#                 phoneme_table['phoneme_id'] = range(len(phoneme_table))
#             elif col == 'end_time' and 'start_time' in phoneme_table.columns:
#                 phoneme_table['end_time'] = phoneme_table['start_time'] + 0.1  # Default 100ms duration
#             else:
#                 print(f"‚ö†Ô∏è Warning: Missing required column {col}")
    
#     # Ensure numeric columns
#     numeric_cols = ['phoneme_id', 'start_time', 'end_time']
#     for col in numeric_cols:
#         if col in phoneme_table.columns:
#             phoneme_table[col] = pd.to_numeric(phoneme_table[col], errors='coerce')
    
#     # Ensure string columns
#     phoneme_table['phoneme'] = phoneme_table['phoneme'].astype(str)
    
#     # Add features with safe data types
#     phoneme_table['start_time_ms'] = phoneme_table['start_time'] * 1000
#     if 'end_time' in phoneme_table.columns:
#         phoneme_table['duration_ms'] = (phoneme_table['end_time'] - phoneme_table['start_time']) * 1000
#     else:
#         phoneme_table['duration_ms'] = 100.0  # Default 100ms
    
#     # Safe phoneme category extraction
#     phoneme_table['phoneme_category'] = phoneme_table['phoneme'].str[:1].fillna('UNK')
    
#     print(f"‚úì Phoneme events table: {len(phoneme_table)} phonemes")
    
#     # 3. Relationships Table - Clean copy with proper data types
#     relationships_table = causal_df_fixed.copy()
    
#     # Convert categorical columns to strings
#     for col in relationships_table.columns:
#         if relationships_table[col].dtype.name == 'category':
#             relationships_table[col] = relationships_table[col].astype(str)
    
#     # Ensure required edge columns exist
#     if 'neural_event_id' not in relationships_table.columns:
#         print("‚ùå Error: Missing neural_event_id in causal relationships")
#         return None
#     if 'phoneme_id' not in relationships_table.columns:
#         print("‚ùå Error: Missing phoneme_id in causal relationships")
#         return None
    
#     # Ensure numeric types for edge columns
#     relationships_table['neural_event_id'] = pd.to_numeric(relationships_table['neural_event_id'], errors='coerce')
#     relationships_table['phoneme_id'] = pd.to_numeric(relationships_table['phoneme_id'], errors='coerce')
#     relationships_table['delay_ms'] = pd.to_numeric(relationships_table['delay_ms'], errors='coerce')
#     relationships_table['strength'] = pd.to_numeric(relationships_table['strength'], errors='coerce')
    
#     # Remove any rows with NaN in key columns
#     before_count = len(relationships_table)
#     relationships_table = relationships_table.dropna(subset=['neural_event_id', 'phoneme_id'])
#     after_count = len(relationships_table)
    
#     if before_count != after_count:
#         print(f"  Removed {before_count - after_count} rows with NaN IDs")
    
#     # Safe categorical feature creation
#     try:
#         relationships_table['strength_category'] = pd.cut(
#             relationships_table['strength'], 
#             bins=[0, 0.7, 0.85, 1.0], 
#             labels=['weak', 'medium', 'strong'],
#             include_lowest=True
#         ).astype(str)
#     except Exception as e:
#         print(f"‚ö†Ô∏è Warning: Could not create strength categories: {e}")
#         relationships_table['strength_category'] = 'medium'  # Default value
    
#     # Ensure string columns
#     string_cols = ['event_type', 'phoneme']
#     for col in string_cols:
#         if col in relationships_table.columns:
#             relationships_table[col] = relationships_table[col].astype(str)
    
#     print(f"‚úì Causal relationships table: {len(relationships_table)} edges")
    
#     # 4. Final data validation
#     print(f"üîç Final validation...")
    
#     # Check for any remaining categorical columns
#     for table_name, table in [('neural_events', neural_table), 
#                              ('phoneme_events', phoneme_table), 
#                              ('relationships', relationships_table)]:
#         categorical_cols = [col for col in table.columns if table[col].dtype.name == 'category']
#         if categorical_cols:
#             print(f"‚ö†Ô∏è Warning: {table_name} still has categorical columns: {categorical_cols}")
    
#     # Check for NaN values in key columns
#     key_checks = [
#         (neural_table, 'event_id', 'neural_events'),
#         (phoneme_table, 'phoneme_id', 'phoneme_events'),
#         (relationships_table, 'neural_event_id', 'relationships'),
#         (relationships_table, 'phoneme_id', 'relationships')
#     ]
    
#     validation_passed = True
#     for table, col, table_name in key_checks:
#         if col in table.columns:
#             nan_count = table[col].isna().sum()
#             if nan_count > 0:
#                 print(f"‚ö†Ô∏è Warning: {table_name}.{col} has {nan_count} NaN values")
#                 validation_passed = False
    
#     if validation_passed:
#         print("‚úÖ All critical columns are clean")
    
#     return {
#         'neural_events': neural_table,
#         'phoneme_events': phoneme_table,
#         'neural_phoneme_edges': relationships_table
#     }

# def check_kumo_capabilities(graph):
#     """Check which methods are available in this Kumo version"""
    
#     print("üîç Checking Kumo AI capabilities...")
    
#     capabilities = {
#         'query': hasattr(graph, 'query'),
#         'visualize': hasattr(graph, 'visualize'),
#         'train': hasattr(graph, 'train'),
#         'tables': hasattr(graph, 'tables'),
#         'get_table': hasattr(graph, 'get_table'),
#         'create_prediction': hasattr(graph, 'create_prediction'),
#         'fit': hasattr(graph, 'fit')
#     }
    
#     print("üìã Available methods:")
#     for method, available in capabilities.items():
#         status = "‚úÖ" if available else "‚ùå"
#         print(f"  {status} {method}")
    
#     return capabilities

# def get_table_data(graph, table_name):
#     """Get table data using available methods"""
    
#     try:
#         # Method 1: Direct table access
#         if hasattr(graph, 'tables') and table_name in graph.tables:
#             table = graph.tables[table_name]
            
#             # Convert LocalTable to DataFrame if possible
#             if hasattr(table, 'to_pandas'):
#                 return table.to_pandas()
#             elif hasattr(table, 'data'):
#                 return table.data
#             elif hasattr(table, 'df'):
#                 return table.df
#             else:
#                 print(f"‚ö†Ô∏è Unknown table format for {table_name}")
#                 return None
                
#     except Exception as e:
#         print(f"‚ùå Error accessing table {table_name}: {e}")
#         return None

# def create_kumo_graph_safe(tables):
#     """
#     Create Kumo AI graph with enhanced error handling and compatibility checks
#     """
#     print("üîó Creating Kumo AI graph (safe mode with compatibility)...")
    
#     try:
#         # Display basic table info
#         print("üìã Table summary:")
#         for table_name, table in tables.items():
#             print(f"  {table_name}: {len(table)} rows")
        
#         # Create the graph
#         print("üî® Creating graph...")
#         graph = rfm.LocalGraph.from_data(tables)
        
#         # Configure the graph
#         print("‚öôÔ∏è Configuring graph...")
        
#         # Set temporal columns and primary keys
#         if 'neural_events' in tables:
#             graph['neural_events'].time_column = 'timestamp'
#             graph['neural_events'].primary_key = 'event_id'
        
#         if 'phoneme_events' in tables:
#             graph['phoneme_events'].time_column = 'start_time'
#             graph['phoneme_events'].primary_key = 'phoneme_id'
        
#         # Configure edge table
#         if 'neural_phoneme_edges' in tables:
#             graph['neural_phoneme_edges'].source_column = 'neural_event_id'
#             graph['neural_phoneme_edges'].target_column = 'phoneme_id'
#             graph['neural_phoneme_edges'].source_table = 'neural_events'
#             graph['neural_phoneme_edges'].target_table = 'phoneme_events'
        
#         print("‚úÖ Graph created and configured successfully!")
        
#         # Check capabilities
#         capabilities = check_kumo_capabilities(graph)
        
#         # Test basic functionality without using len() on LocalTable
#         print("üß™ Testing graph functionality...")
        
#         try:
#             # Test queries if available
#             if capabilities.get('query'):
#                 neural_count_result = graph.query("SELECT COUNT(*) as count FROM neural_events")
#                 phoneme_count_result = graph.query("SELECT COUNT(*) as count FROM phoneme_events")
#                 edges_count_result = graph.query("SELECT COUNT(*) as count FROM neural_phoneme_edges")
                
#                 neural_count = neural_count_result['count'].iloc[0]
#                 phoneme_count = phoneme_count_result['count'].iloc[0]
#                 edges_count = edges_count_result['count'].iloc[0]
                
#                 print(f"‚úÖ Graph queries successful:")
#                 print(f"  Neural events: {neural_count}")
#                 print(f"  Phoneme events: {phoneme_count}")
#                 print(f"  Edges: {edges_count}")
#             else:
#                 print("‚ö†Ô∏è Query method not available, using table access instead")
#                 # Try to get basic stats through table access
#                 for table_name in ['neural_events', 'phoneme_events', 'neural_phoneme_edges']:
#                     table_data = get_table_data(graph, table_name)
#                     if table_data is not None:
#                         print(f"  {table_name}: {len(table_data)} rows")
                
#         except Exception as query_error:
#             print(f"‚ö†Ô∏è Query test failed: {query_error}")
        
#         # Try visualization (optional)
#         if capabilities.get('visualize'):
#             try:
#                 print("üé® Attempting visualization...")
#                 graph.visualize(show_columns=True)
#                 print("‚úÖ Visualization successful")
#             except Exception as viz_error:
#                 print(f"‚ö†Ô∏è Visualization failed: {viz_error}")
#         else:
#             print("‚ö†Ô∏è Visualization not available in this Kumo version")
        
#         return graph
        
#     except Exception as e:
#         print(f"‚ùå Error creating graph: {e}")
#         print(f"üîç Error details: {type(e).__name__}: {str(e)}")
        
#         # Additional debugging
#         if "categorical" in str(e).lower():
#             print("üí° This appears to be a categorical data issue.")
#             print("   Try running the data cleaning steps again.")
        
#         return None

# def create_prediction_framework(graph):
#     """
#     Create prediction framework compatible with available Kumo methods
#     """
#     print("üß† Creating prediction framework...")
    
#     capabilities = check_kumo_capabilities(graph)
    
#     try:
#         # Method 1: Try modern RFM approach
#         if capabilities.get('train'):
#             print("üöÄ Using RFM training approach...")
            
#             # Create prediction queries
#             neural_to_phoneme_query = rfm.PredictionQuery(
#                 target_table="phoneme_events",
#                 target_column="phoneme",
#                 feature_tables=["neural_events"],
#                 time_column="start_time",
#                 features=["neural_events.event_type", "neural_events.channel"]
#             )
            
#             phoneme_to_neural_query = rfm.PredictionQuery(
#                 target_table="neural_events", 
#                 target_column="event_type",
#                 feature_tables=["phoneme_events"],
#                 time_column="timestamp",
#                 features=["phoneme_events.phoneme"]
#             )
            
#             return {
#                 'neural_to_phoneme_query': neural_to_phoneme_query,
#                 'phoneme_to_neural_query': phoneme_to_neural_query,
#                 'method': 'rfm_modern'
#             }
        
#         # Method 2: Try simple dictionary-based queries
#         else:
#             print("üöÄ Using simple prediction queries...")
            
#             neural_to_phoneme_query = {
#                 'target_table': 'phoneme_events',
#                 'target_column': 'phoneme',
#                 'feature_tables': ['neural_events'],
#                 'features': ['neural_events.event_type', 'neural_events.channel'],
#                 'time_column': 'start_time',
#                 'training_window': '2s'
#             }
            
#             phoneme_to_neural_query = {
#                 'target_table': 'neural_events',
#                 'target_column': 'event_type',
#                 'feature_tables': ['phoneme_events'],
#                 'features': ['phoneme_events.phoneme'],
#                 'time_column': 'timestamp',
#                 'training_window': '1s'
#             }
            
#             return {
#                 'neural_to_phoneme_query': neural_to_phoneme_query,
#                 'phoneme_to_neural_query': phoneme_to_neural_query,
#                 'method': 'simple_dict'
#             }
        
#     except Exception as e:
#         print(f"‚ùå Error creating prediction framework: {e}")
#         return None

# def run_fixed_kumo_pipeline(neural_events_df, phoneme_annotations, causal_df):
#     """
#     Run the Kumo pipeline with all fixes and compatibility handling
#     """
#     print("üöÄ Running COMPREHENSIVE FIXED Kumo AI pipeline...")
#     print("=" * 60)
    
#     # Step 1: Prepare data with all fixes
#     tables = prepare_kumo_tables_fixed(neural_events_df, phoneme_annotations, causal_df)
#     if not tables:
#         print("‚ùå Data preparation failed")
#         return None
    
#     # Step 2: Create graph with safe mode and compatibility
#     graph = create_kumo_graph_safe(tables)
#     if not graph:
#         print("‚ùå Graph creation failed")
#         return None
    
#     # Step 3: Create prediction framework
#     prediction_framework = create_prediction_framework(graph)
    
#     # Step 4: Analyze data using compatible methods
#     try:
#         print("üìä Analyzing graph data...")
        
#         analysis_results = {}
#         table_names = ['neural_events', 'phoneme_events', 'neural_phoneme_edges']
        
#         for table_name in table_names:
#             table_data = get_table_data(graph, table_name)
#             if table_data is not None:
#                 analysis_results[table_name] = table_data
#                 print(f"  ‚úÖ {table_name}: {len(table_data)} rows")
        
#         # Basic statistics if we have data
#         if 'neural_events' in analysis_results:
#             neural_df = analysis_results['neural_events']
#             if 'event_type' in neural_df.columns:
#                 event_counts = neural_df['event_type'].value_counts()
#                 print(f"  üìà Neural event types: {dict(event_counts)}")
        
#         if 'phoneme_events' in analysis_results:
#             phoneme_df = analysis_results['phoneme_events']
#             if 'phoneme' in phoneme_df.columns:
#                 top_phonemes = phoneme_df['phoneme'].value_counts().head(5)
#                 print(f"  üìà Top 5 phonemes: {list(top_phonemes.index)}")
        
#         if 'neural_phoneme_edges' in analysis_results:
#             edges_df = analysis_results['neural_phoneme_edges']
#             if 'delay_ms' in edges_df.columns:
#                 avg_delay = edges_df['delay_ms'].mean()
#                 print(f"  üìà Average causal delay: {avg_delay:.1f}ms")
        
#     except Exception as analysis_error:
#         print(f"‚ö†Ô∏è Analysis failed: {analysis_error}")
#         analysis_results = {}
    
#     result = {
#         'graph': graph,
#         'tables': tables,
#         'prediction_framework': prediction_framework,
#         'analysis_results': analysis_results,
#         'status': 'success'
#     }
    
#     print("‚úÖ Comprehensive fixed pipeline completed successfully!")
#     print("üìä Graph is ready for neural-phoneme predictions")
    
#     return result

# # ============================================================================
# # MAIN EXECUTION
# # ============================================================================

# def main():
#     print("üîß RUNNING COMPREHENSIVE FIXED KUMO PIPELINE")
#     print("=" * 50)
    
#     # Check for required variables
#     required_vars = ['neural_events_df', 'phoneme_annotations', 'causal_df']
#     missing_vars = []
    
#     for var in required_vars:
#         if var not in globals():
#             missing_vars.append(var)
    
#     if missing_vars:
#         print(f"‚ùå Missing required variables: {missing_vars}")
#         print("Make sure you have run the previous steps to create these DataFrames")
#         return None
    
#     # Check data availability
#     print(f"‚úÖ Found all required data:")
#     print(f"  neural_events_df: {len(neural_events_df)} events")
#     print(f"  phoneme_annotations: {len(phoneme_annotations)} phonemes") 
#     print(f"  causal_df: {len(causal_df)} relationships")
    
#     # Run the comprehensive fixed pipeline
#     comprehensive_results = run_fixed_kumo_pipeline(neural_events_df, phoneme_annotations, causal_df)
    
#     if comprehensive_results and comprehensive_results.get('graph'):
#         print(f"\nüéâ COMPREHENSIVE SUCCESS!")
#         print(f"‚úÖ Kumo AI graph created and working")
#         print(f"‚úÖ Compatible with your Kumo version")
#         print(f"‚úÖ Ready for neural-phoneme predictions")
        
#         print(f"\nüéØ What you can do now:")
#         print(f"1. Access table data: graph.tables['table_name']")
#         print(f"2. Train prediction models (if available)")
#         print(f"3. Build custom neural-phoneme classifiers")
#         print(f"4. Create real-time BCI applications")
        
#         return comprehensive_results
#     else:
#         print(f"\n‚ùå Comprehensive pipeline failed")
#         return None

# # Install graphviz instructions
# def show_graphviz_install():
#     """Show graphviz installation instructions"""
#     print("\nüì¶ GRAPHVIZ INSTALLATION (if needed):")
#     print("For Lambda Labs: sudo apt-get update && sudo apt-get install -y graphviz")
#     print("Then: pip install graphviz")

# # Run if this file is executed directly
# if __name__ == "__main__":
#     show_graphviz_install()
#     main()
# else:
#     # Run automatically when imported
#     if 'neural_events_df' in globals() and 'phoneme_annotations' in globals() and 'causal_df' in globals():
#         print("üöÄ Auto-running comprehensive fixed Kumo pipeline...")
#         comprehensive_results = main()
#         if comprehensive_results:
#             print("‚úÖ Comprehensive pipeline completed! Results stored in 'comprehensive_results' variable")

In [38]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import networkx as nx
from datetime import datetime

# Plot styling for every figure produced in this cell.
# Three alternatives are listed below; only Option 1 is active.

# Option 1 (active): the 'seaborn-v0_8' matplotlib style sheet
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    # The style sheet could not be loaded: reset matplotlib to its default
    # style and let seaborn apply its own "whitegrid" theme instead.
    plt.style.use('default')
    sns.set_theme(style="whitegrid")  # Modern seaborn approach

# Option 2 (disabled): other named style sheets, if the installed matplotlib ships them
# plt.style.use('seaborn')  # Generic seaborn
# plt.style.use('ggplot')   # ggplot2-inspired
# plt.style.use('bmh')      # Bayesian Methods for Hackers style

# Option 3 (disabled): built-in matplotlib style sheets
# plt.style.use('default')
# plt.style.use('classic')
# plt.style.use('fivethirtyeight')

# Default color palette for the seaborn/matplotlib plots created below
sns.set_palette("husl")

def setup_visualization_directories():
    """Build the ``graphs/`` output tree and return the paths that make it up.

    Returns:
        dict: ``'base'`` mapped to ``Path("graphs")``, followed by one entry
        per visualization category, each mapped to the sub-directory of the
        same name under ``graphs/``. Every directory exists on disk (relative
        to the current working directory) when the function returns.
    """
    root = Path("graphs")

    # Each category key doubles as the name of its sub-directory.
    categories = (
        'neural_events',
        'phoneme_analysis',
        'causal_relationships',
        'temporal_analysis',
        'channel_analysis',
        'interactive',
        'network_graphs',
        'summary_reports',
        'kumo_graphs',
    )

    # 'base' is inserted first so the root is created and reported before its children.
    directories = {'base': root}
    directories.update((category, root / category) for category in categories)

    print("Creating visualization directory structure...")
    for path in directories.values():
        path.mkdir(parents=True, exist_ok=True)
        print(f"  ‚úì {path}")

    return directories

def save_neural_events_visualizations(neural_events_df, dirs):
    """Create and save the static plots that summarise the neural events table.

    Visualizations produced:

    1. Pie chart of ``event_type`` counts
       (``event_type_distribution.png``).
    2. Event count per timestamp bucket and
    3. per-channel heatmap of the same buckets, saved together as the two
       panels of ``neural_activity_timeline.png``.
    4. Box plot of ``amplitude`` grouped by ``event_type``
       (``amplitude_by_event_type.png``).
    5. Bar chart of events per ``channel_region`` - only when that column is
       present (``events_by_region.png`` under ``dirs['channel_analysis']``).

    Args:
        neural_events_df: DataFrame with ``event_type``, ``timestamp``,
            ``channel`` and ``amplitude`` columns; ``channel_region`` is
            optional.
        dirs: Mapping with ``'neural_events'`` and ``'channel_analysis'``
            entries that support the ``/`` operator (e.g. ``pathlib.Path``
            objects).

    Returns:
        None. The number of visualizations actually produced is printed.
    """
    print("Creating neural events visualizations...")
    saved_count = 0

    # Timestamps rounded to one decimal place define the time buckets shared by
    # the timeline and the channel heatmap.
    time_buckets = neural_events_df['timestamp'].round(1)

    # 1. Event type distribution
    plt.figure(figsize=(10, 6))
    event_counts = neural_events_df['event_type'].value_counts()
    plt.pie(event_counts.values, labels=event_counts.index, autopct='%1.1f%%')
    plt.title('Neural Event Type Distribution')
    plt.savefig(dirs['neural_events'] / 'event_type_distribution.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 2. Events over time (top panel)
    plt.figure(figsize=(15, 8))
    plt.subplot(2, 1, 1)
    neural_events_df.groupby(time_buckets).size().plot()
    plt.title('Neural Events Over Time')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Number of Events')

    # 3. Channel activity heatmap (bottom panel of the same figure)
    plt.subplot(2, 1, 2)
    channel_activity = neural_events_df.groupby(['channel', time_buckets]).size().unstack(fill_value=0)
    # Only every 10th time bucket is drawn to keep the heatmap readable.
    sns.heatmap(channel_activity.iloc[:, ::10], cmap='viridis', cbar_kws={'label': 'Event Count'})
    plt.title('Channel Activity Heatmap')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Channel')

    plt.tight_layout()
    plt.savefig(dirs['neural_events'] / 'neural_activity_timeline.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 2  # two visualizations share this figure

    # 4. Amplitude distribution by event type
    plt.figure(figsize=(12, 6))
    neural_events_df.boxplot(column='amplitude', by='event_type', ax=plt.gca())
    plt.title('Amplitude Distribution by Event Type')
    plt.suptitle('')  # Remove the automatic "grouped by" super-title pandas adds
    plt.savefig(dirs['neural_events'] / 'amplitude_by_event_type.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 5. Channel region analysis (optional column)
    if 'channel_region' in neural_events_df.columns:
        plt.figure(figsize=(10, 6))
        region_counts = neural_events_df['channel_region'].value_counts().sort_index()
        plt.bar(region_counts.index, region_counts.values)
        plt.title('Events by Channel Region')
        plt.xlabel('Channel Region')
        plt.ylabel('Number of Events')
        plt.savefig(dirs['channel_analysis'] / 'events_by_region.png', dpi=300, bbox_inches='tight')
        plt.close()
        saved_count += 1

    # Report what was really produced instead of a fixed "5".
    print(f"   Saved {saved_count} neural events visualizations")

def save_phoneme_analysis_visualizations(phoneme_annotations, dirs):
    """Create and save the static plots that describe the phoneme annotations.

    Visualizations produced (all under ``dirs['phoneme_analysis']``):

    1. Bar chart of the 20 most frequent phonemes (``phoneme_frequency.png``).
    2. Histogram of ``duration`` with its mean marked - only when a
       ``duration`` column exists (``phoneme_duration_distribution.png``).
    3. Horizontal-bar timeline of the first 100 rows
       (``phoneme_timeline.png``); bars are 0.1 wide when there is no
       ``duration`` column.
    4. Phoneme count per ``sequence_position`` - only when that column exists
       (``sequence_position_analysis.png``).

    Args:
        phoneme_annotations: DataFrame with ``phoneme`` and ``start_time``
            columns; ``duration`` and ``sequence_position`` are optional.
        dirs: Mapping with a ``'phoneme_analysis'`` entry that supports the
            ``/`` operator (e.g. a ``pathlib.Path``).

    Returns:
        None. The number of visualizations actually saved is printed.
    """
    print("Creating phoneme analysis visualizations...")
    saved_count = 0
    has_duration = 'duration' in phoneme_annotations.columns

    # 1. Phoneme frequency distribution
    plt.figure(figsize=(15, 8))
    top_phonemes = phoneme_annotations['phoneme'].value_counts().head(20)
    positions = range(len(top_phonemes))
    plt.bar(positions, top_phonemes.values)
    plt.xticks(positions, top_phonemes.index, rotation=45)
    plt.title('Top 20 Most Frequent Phonemes')
    plt.xlabel('Phoneme')
    plt.ylabel('Frequency')
    plt.tight_layout()
    plt.savefig(dirs['phoneme_analysis'] / 'phoneme_frequency.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 2. Phoneme duration analysis (optional column)
    if has_duration:
        mean_duration = phoneme_annotations['duration'].mean()
        plt.figure(figsize=(12, 6))
        plt.hist(phoneme_annotations['duration'], bins=50, alpha=0.7, edgecolor='black')
        plt.title('Phoneme Duration Distribution')
        plt.xlabel('Duration (seconds)')
        plt.ylabel('Frequency')
        plt.axvline(mean_duration, color='red', linestyle='--', label=f"Mean: {mean_duration:.3f}s")
        plt.legend()
        plt.savefig(dirs['phoneme_analysis'] / 'phoneme_duration_distribution.png', dpi=300, bbox_inches='tight')
        plt.close()
        saved_count += 1

    # 3. Phoneme timeline
    plt.figure(figsize=(15, 8))
    sample_phonemes = phoneme_annotations.head(100)  # Sample for readability
    distinct_phonemes = sample_phonemes['phoneme'].unique()
    colors = plt.cm.Set3(np.linspace(0, 1, len(distinct_phonemes)))
    phoneme_colors = dict(zip(distinct_phonemes, colors))

    # One vectorised barh call draws every bar: y is the row's index label,
    # the bar starts at start_time and spans the duration (or 0.1 by default).
    bar_widths = sample_phonemes['duration'] if has_duration else 0.1
    bar_colors = [phoneme_colors[phoneme] for phoneme in sample_phonemes['phoneme']]
    plt.barh(sample_phonemes.index, bar_widths,
             left=sample_phonemes['start_time'], color=bar_colors, alpha=0.7)

    plt.title('Phoneme Timeline (First 100 phonemes)')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Phoneme Index')
    plt.savefig(dirs['phoneme_analysis'] / 'phoneme_timeline.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 4. Sequence position analysis (optional column)
    if 'sequence_position' in phoneme_annotations.columns:
        plt.figure(figsize=(12, 6))
        position_counts = phoneme_annotations['sequence_position'].value_counts().sort_index()
        plt.plot(position_counts.index, position_counts.values, marker='o')
        plt.title('Phoneme Count by Sequence Position')
        plt.xlabel('Position in Sequence')
        plt.ylabel('Number of Phonemes')
        plt.savefig(dirs['phoneme_analysis'] / 'sequence_position_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()
        saved_count += 1

    # Report what was really written instead of a fixed "4".
    print(f"  ‚úì Saved {saved_count} phoneme analysis visualizations")

def save_causal_relationships_visualizations(causal_df, dirs):
    """Create and save the static plots that describe the causal relationships.

    Visualizations produced (all under ``dirs['causal_relationships']``):

    1. Histogram of ``delay_ms`` with its mean marked
       (``delay_distribution.png``).
    2. Scatter of ``strength`` against ``delay_ms``
       (``strength_vs_delay.png``).
    3. Heatmap of the 20 most frequent (phoneme, channel) pairs
       (``phoneme_channel_heatmap.png``); skipped when there are no pairs.
    4. Pie chart of relationships per ``event_type`` - only when that column
       exists (``relationships_by_event_type.png``).

    Args:
        causal_df: DataFrame with ``delay_ms``, ``strength``, ``phoneme`` and
            ``channel`` columns; ``event_type`` is optional.
        dirs: Mapping with a ``'causal_relationships'`` entry that supports
            the ``/`` operator (e.g. a ``pathlib.Path``).

    Returns:
        None. The number of visualizations actually saved is printed.
    """
    print("Creating causal relationships visualizations...")
    saved_count = 0

    # 1. Delay distribution
    mean_delay = causal_df['delay_ms'].mean()
    plt.figure(figsize=(12, 6))
    plt.hist(causal_df['delay_ms'], bins=50, alpha=0.7, edgecolor='black')
    plt.title('Causal Delay Distribution')
    plt.xlabel('Delay (milliseconds)')
    plt.ylabel('Number of Relationships')
    plt.axvline(mean_delay, color='red', linestyle='--',
                label=f"Mean: {mean_delay:.1f}ms")
    plt.legend()
    plt.savefig(dirs['causal_relationships'] / 'delay_distribution.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 2. Strength vs Delay scatter plot
    plt.figure(figsize=(10, 8))
    plt.scatter(causal_df['delay_ms'], causal_df['strength'], alpha=0.1, s=1)
    plt.xlabel('Delay (ms)')
    plt.ylabel('Relationship Strength')
    plt.title('Relationship Strength vs Delay')
    plt.savefig(dirs['causal_relationships'] / 'strength_vs_delay.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 3. Top phoneme-channel combinations
    top_combinations = causal_df.groupby(['phoneme', 'channel']).size().sort_values(ascending=False).head(20)

    if top_combinations.empty:
        # A zero-sized matrix cannot be drawn as a heatmap, so skip this plot.
        print("  Skipping phoneme-channel heatmap: no phoneme/channel pairs found")
    else:
        # dict.fromkeys de-duplicates while keeping first-seen order, so rows and
        # columns follow the frequency ranking and are identical on every run
        # (iterating a set of strings is not stable across interpreter sessions).
        phoneme_rows = {
            phoneme: row
            for row, phoneme in enumerate(dict.fromkeys(p for p, _ in top_combinations.index))
        }
        channel_cols = {
            channel: col
            for col, channel in enumerate(dict.fromkeys(c for _, c in top_combinations.index))
        }

        # Pair counts laid out as phoneme rows x channel columns; pairs outside
        # the top 20 stay at zero.
        matrix = np.zeros((len(phoneme_rows), len(channel_cols)))
        for (phoneme, channel), count in top_combinations.items():
            matrix[phoneme_rows[phoneme], channel_cols[channel]] = count

        plt.figure(figsize=(15, 10))
        sns.heatmap(matrix, xticklabels=list(channel_cols), yticklabels=list(phoneme_rows),
                    annot=True, fmt='.0f', cmap='viridis')
        plt.title('Top Phoneme-Channel Relationships')
        plt.xlabel('Channel')
        plt.ylabel('Phoneme')
        plt.tight_layout()
        plt.savefig(dirs['causal_relationships'] / 'phoneme_channel_heatmap.png', dpi=300, bbox_inches='tight')
        plt.close()
        saved_count += 1

    # 4. Event type analysis (optional column)
    if 'event_type' in causal_df.columns:
        plt.figure(figsize=(12, 6))
        event_type_counts = causal_df['event_type'].value_counts()
        plt.pie(event_type_counts.values, labels=event_type_counts.index, autopct='%1.1f%%')
        plt.title('Causal Relationships by Neural Event Type')
        plt.savefig(dirs['causal_relationships'] / 'relationships_by_event_type.png', dpi=300, bbox_inches='tight')
        plt.close()
        saved_count += 1

    # Report what was really written instead of a fixed "4".
    print(f"  ‚úì Saved {saved_count} causal relationships visualizations")

def save_temporal_analysis_visualizations(neural_events_df, phoneme_annotations, causal_df, dirs,
                                          max_samples=10000):
    """Create and save the temporal overview and delay-pattern figures.

    Figures written under ``dirs['temporal_analysis']``:

    1. ``temporal_overview.png`` - three stacked panels: neural events per
       rounded ``timestamp``, phonemes per rounded ``start_time``, and
       relationship counts per rounded ``delay_ms / 1000``.
    2. ``delay_patterns_over_time.png`` - mean ``delay_ms`` inside each of 20
       equal-width delay bins of a random sample; falls back to a plain
       histogram of a sample if the binned version raises.

    Args:
        neural_events_df: DataFrame with a ``timestamp`` column.
        phoneme_annotations: DataFrame with a ``start_time`` column.
        causal_df: DataFrame with a ``delay_ms`` column.
        dirs: Mapping with a ``'temporal_analysis'`` entry that supports the
            ``/`` operator (e.g. a ``pathlib.Path``).
        max_samples: Upper bound on the number of ``causal_df`` rows used by
            each causal plot. Defaults to 10000; expected to be positive.

    Returns:
        None.
    """
    print("‚è∞ Creating temporal analysis visualizations...")

    # 1. Neural events and phonemes timeline
    plt.figure(figsize=(15, 10))

    # Plot neural events
    plt.subplot(3, 1, 1)
    neural_timeline = neural_events_df.groupby(neural_events_df['timestamp'].round(1)).size()
    plt.plot(neural_timeline.index, neural_timeline.values, label='Neural Events', alpha=0.7)
    plt.title('Neural Events Timeline')
    plt.ylabel('Event Count')
    plt.legend()

    # Plot phonemes
    plt.subplot(3, 1, 2)
    phoneme_timeline = phoneme_annotations.groupby(phoneme_annotations['start_time'].round(1)).size()
    plt.plot(phoneme_timeline.index, phoneme_timeline.values, label='Phonemes', color='orange', alpha=0.7)
    plt.title('Phoneme Timeline')
    plt.ylabel('Phoneme Count')
    plt.legend()

    # Plot relationship density
    plt.subplot(3, 1, 3)
    try:
        # Only the first `max_samples` relationships are used. Taking them with
        # head() avoids walking every row of a potentially huge frame.
        # NOTE: the plotted "time" is each relationship's own delay converted
        # to seconds, used as a simplified stand-in for a real timestamp.
        causal_times = (causal_df['delay_ms'].head(max_samples) / 1000.0).reset_index(drop=True)

        if not causal_times.empty:
            relationship_timeline = causal_times.groupby(causal_times.round(1)).size()
            plt.plot(relationship_timeline.index, relationship_timeline.values,
                     label='Causal Relationships', color='green', alpha=0.7)

    except Exception as e:
        # Best effort: keep the other two panels and draw a flat placeholder.
        print(f"  ‚ö†Ô∏è Skipping causal timeline due to: {e}")
        plt.plot([0, 1], [0, 0], label='Causal Relationships (unavailable)', color='green', alpha=0.7)

    plt.title('Causal Relationships Timeline')
    plt.xlabel('Time (seconds)')
    plt.ylabel('Relationship Count')
    plt.legend()

    plt.tight_layout()
    plt.savefig(dirs['temporal_analysis'] / 'temporal_overview.png', dpi=300, bbox_inches='tight')
    plt.close()

    # 2. Delay patterns over time
    plt.figure(figsize=(12, 8))

    try:
        # Sample data for performance
        sample_causal = causal_df.sample(min(max_samples, len(causal_df)))
        delay_bins = pd.cut(sample_causal['delay_ms'], bins=20)
        # observed=False keeps every one of the 20 bins (empty bins yield NaN)
        # and makes the categorical grouping behaviour explicit.
        delay_by_bin = sample_causal.groupby(delay_bins, observed=False)['delay_ms'].mean()

        plt.plot(range(len(delay_by_bin)), delay_by_bin.values, marker='o')
        plt.title('Average Causal Delay Distribution (Binned)')
        plt.xlabel('Delay Bin')
        plt.ylabel('Average Delay (ms)')
        plt.xticks(range(0, len(delay_by_bin), 2), rotation=45)

    except Exception as e:
        # Best effort: fall back to a plain histogram of a sample.
        print(f"  ‚ö†Ô∏è Simplified delay analysis due to: {e}")
        plt.hist(causal_df['delay_ms'].sample(min(max_samples, len(causal_df))), bins=20)
        plt.title('Causal Delay Distribution (Sample)')
        plt.xlabel('Delay (ms)')
        plt.ylabel('Frequency')

    plt.tight_layout()
    plt.savefig(dirs['temporal_analysis'] / 'delay_patterns_over_time.png', dpi=300, bbox_inches='tight')
    plt.close()

    print(f"  ‚úì Saved 2 temporal analysis visualizations")

def save_channel_analysis_visualizations(neural_events_df, causal_df, dirs):
    """Create and save the per-channel (and per-region) bar charts.

    Visualizations produced (all under ``dirs['channel_analysis']``):

    1. Events per channel (``channel_activity_overview.png``).
    2. Causal relationships per channel (``channel_relationships.png``).
    3. Events and causal relationships per ``channel_region``, as two stacked
       panels - only when ``neural_events_df`` has a ``channel_region`` column
       (``brain_region_analysis.png``).

    Neither input DataFrame is modified.

    Args:
        neural_events_df: DataFrame with a ``channel`` column;
            ``channel_region`` is optional.
        causal_df: DataFrame with a ``channel`` column.
        dirs: Mapping with a ``'channel_analysis'`` entry that supports the
            ``/`` operator (e.g. a ``pathlib.Path``).

    Returns:
        None. The number of visualizations actually saved is printed.
    """
    print("üì° Creating channel analysis visualizations...")
    saved_count = 0

    # 1. Channel activity overview
    plt.figure(figsize=(15, 8))
    channel_counts = neural_events_df['channel'].value_counts().sort_index()
    plt.bar(channel_counts.index, channel_counts.values)
    plt.title('Neural Activity by Channel')
    plt.xlabel('Channel')
    plt.ylabel('Number of Events')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(dirs['channel_analysis'] / 'channel_activity_overview.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 2. Channel involvement in causal relationships
    plt.figure(figsize=(15, 8))
    channel_relationships = causal_df['channel'].value_counts().sort_index()
    plt.bar(channel_relationships.index, channel_relationships.values, color='orange')
    plt.title('Channel Involvement in Causal Relationships')
    plt.xlabel('Channel')
    plt.ylabel('Number of Relationships')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(dirs['channel_analysis'] / 'channel_relationships.png', dpi=300, bbox_inches='tight')
    plt.close()
    saved_count += 1

    # 3. Channel regions analysis (if available)
    if 'channel_region' in neural_events_df.columns:
        plt.figure(figsize=(12, 8))

        # Region activity
        plt.subplot(2, 1, 1)
        region_activity = neural_events_df['channel_region'].value_counts().sort_index()
        plt.bar(region_activity.index, region_activity.values)
        plt.title('Neural Activity by Brain Region')
        plt.xlabel('Region')
        plt.ylabel('Number of Events')

        # Region relationship involvement
        plt.subplot(2, 1, 2)
        # Each channel is assigned the first region recorded for it in the
        # neural events. The mapped regions live in a local Series so the
        # caller's causal_df does not gain an extra 'region' column.
        channel_to_region = neural_events_df.groupby('channel')['channel_region'].first()
        causal_regions = causal_df['channel'].map(channel_to_region)
        region_relationships = causal_regions.value_counts().sort_index()
        plt.bar(region_relationships.index, region_relationships.values, color='orange')
        plt.title('Causal Relationships by Brain Region')
        plt.xlabel('Region')
        plt.ylabel('Number of Relationships')

        plt.tight_layout()
        plt.savefig(dirs['channel_analysis'] / 'brain_region_analysis.png', dpi=300, bbox_inches='tight')
        plt.close()
        saved_count += 1

    # Report what was really written instead of a fixed "3".
    print(f"  ‚úì Saved {saved_count} channel analysis visualizations")

def save_interactive_visualizations(neural_events_df, phoneme_annotations, causal_df, dirs):
    """Create and save the interactive plotly figures as standalone HTML files.

    Files written under ``dirs['interactive']``:

    1. ``neural_events_timeline.html`` - scatter of ``channel`` against
       ``timestamp`` for a random sample of at most 1000 events.
    2. ``phoneme_frequency.html`` - bar chart of the 20 most frequent phonemes.
    3. ``causal_relationships_3d.html`` - 3D scatter of ``delay_ms``,
       ``strength`` and ``channel`` for a random sample of at most 5000
       relationships.
    4. ``phoneme_channel_heatmap.html`` - relationship counts for the 10 most
       involved phonemes against the 20 most involved channels.

    Args:
        neural_events_df: DataFrame with ``timestamp``, ``channel``,
            ``event_type`` and ``amplitude`` columns.
        phoneme_annotations: DataFrame with a ``phoneme`` column.
        causal_df: DataFrame with ``delay_ms``, ``strength``, ``channel`` and
            ``phoneme`` columns.
        dirs: Mapping with an ``'interactive'`` entry that supports the ``/``
            operator (e.g. a ``pathlib.Path``).

    Returns:
        None.
    """
    print("üîÑ Creating interactive visualizations...")

    # 1. Interactive neural events timeline
    # Clamp the sample size: DataFrame.sample raises ValueError when asked for
    # more rows than the frame holds (sampling is without replacement).
    sample_events = neural_events_df.sample(min(1000, len(neural_events_df)))
    fig = px.scatter(sample_events, x='timestamp', y='channel',
                     color='event_type', hover_data=['amplitude'],
                     title='Interactive Neural Events Timeline (Sample)')
    fig.write_html(dirs['interactive'] / 'neural_events_timeline.html')

    # 2. Interactive phoneme frequency
    phoneme_counts = phoneme_annotations['phoneme'].value_counts().head(20)
    fig = px.bar(x=phoneme_counts.index, y=phoneme_counts.values,
                 title='Interactive Phoneme Frequency Distribution')
    fig.update_xaxes(title='Phoneme')
    fig.update_yaxes(title='Frequency')
    fig.write_html(dirs['interactive'] / 'phoneme_frequency.html')

    # 3. Interactive 3D scatter of causal relationships
    sample_causal = causal_df.sample(min(5000, len(causal_df)))
    fig = px.scatter_3d(sample_causal, x='delay_ms', y='strength', z='channel',
                        color='phoneme', title='3D Causal Relationships (Sample)')
    fig.write_html(dirs['interactive'] / 'causal_relationships_3d.html')

    # 4. Interactive heatmap
    pivot_data = causal_df.groupby(['phoneme', 'channel']).size().reset_index(name='count')
    top_phonemes = pivot_data.groupby('phoneme')['count'].sum().nlargest(10).index
    top_channels = pivot_data.groupby('channel')['count'].sum().nlargest(20).index

    # Keep only pairs whose phoneme AND channel are both among the busiest.
    filtered_data = pivot_data[pivot_data['phoneme'].isin(top_phonemes) &
                               pivot_data['channel'].isin(top_channels)]

    # Pairs that never co-occur have no row in filtered_data; show them as 0.
    heatmap_matrix = filtered_data.pivot(index='phoneme', columns='channel', values='count').fillna(0)

    fig = px.imshow(heatmap_matrix,
                    title='Interactive Phoneme-Channel Relationship Heatmap',
                    color_continuous_scale='viridis')
    fig.write_html(dirs['interactive'] / 'phoneme_channel_heatmap.html')

    print(f"  ‚úì Saved 4 interactive visualizations")

def save_network_graph_visualizations(causal_df, phoneme_annotations, dirs):
    """Draw a sampled phoneme-channel network graph and write its statistics.

    At most 1000 rows of ``causal_df`` are sampled at random. Each distinct
    phoneme in the sample becomes a ``P_<phoneme>`` node, each distinct
    channel a ``C_<channel>`` node, and each sampled row an undirected edge
    between the two, carrying ``weight`` (from ``strength``) and ``delay``
    (from ``delay_ms``). The graph is a simple ``nx.Graph``, so repeated
    phoneme/channel pairs collapse into a single edge.

    Outputs (inside ``dirs['network_graphs']``):
        - ``neural_phoneme_network.png``: the rendered graph.
        - ``network_stats.txt``: node/edge counts, average degree, density
          and, when the graph is connected, path length and diameter.

    Args:
        causal_df: DataFrame with 'phoneme', 'channel', 'strength' and
            'delay_ms' columns.
        phoneme_annotations: Not used by this function; kept so existing
            callers do not break.
        dirs: Mapping whose 'network_graphs' entry supports the ``/``
            operator (e.g. a ``pathlib.Path``).

    Returns:
        None. When ``causal_df`` has no rows nothing is written and the
        function returns after printing a notice.
    """

    print("üï∏Ô∏è Creating network graph visualizations...")

    # Sample data for network visualization (full dataset would be too large)
    sample_size = min(1000, len(causal_df))
    if sample_size == 0:
        # A graph without nodes has no average degree (division by zero) and
        # connectivity is undefined for it, so there is nothing to report.
        print("  Skipped network graph: no causal relationships to plot")
        return
    sample_causal = causal_df.sample(sample_size)

    # Create network graph
    G = nx.Graph()

    # Add nodes for phonemes and channels
    for phoneme in sample_causal['phoneme'].unique():
        G.add_node(f"P_{phoneme}", node_type='phoneme', label=phoneme)

    for channel in sample_causal['channel'].unique():
        G.add_node(f"C_{channel}", node_type='channel', label=f"Ch{channel}")

    # Add edges for relationships. Columns are walked in parallel instead of
    # using iterrows(): iterrows() coerces each row to one dtype, which can
    # turn an integer channel into a float and create a "C_3.0" node that has
    # no 'label' attribute.
    edge_columns = zip(
        sample_causal['phoneme'],
        sample_causal['channel'],
        sample_causal['strength'],
        sample_causal['delay_ms'],
    )
    for phoneme, channel, strength, delay_ms in edge_columns:
        G.add_edge(f"P_{phoneme}", f"C_{channel}", weight=strength, delay=delay_ms)

    # Create visualization; the figure is closed even if drawing/saving fails
    # so a failed run does not leave a 20x20 inch figure open.
    fig = plt.figure(figsize=(20, 20))
    try:
        pos = nx.spring_layout(G, k=1, iterations=50)

        # Split nodes by kind so each kind gets its own colour and size
        phoneme_nodes = [n for n in G.nodes() if n.startswith('P_')]
        channel_nodes = [n for n in G.nodes() if n.startswith('C_')]

        nx.draw_networkx_nodes(G, pos, nodelist=phoneme_nodes, node_color='lightblue',
                               node_size=300, label='Phonemes')
        nx.draw_networkx_nodes(G, pos, nodelist=channel_nodes, node_color='lightcoral',
                               node_size=200, label='Channels')

        # Draw edges
        nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.5)

        # Draw labels
        labels = {node: G.nodes[node]['label'] for node in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title('Neural-Phoneme Network Graph (Sample)', size=16)
        plt.legend()
        plt.axis('off')
        plt.savefig(dirs['network_graphs'] / 'neural_phoneme_network.png', dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    # Save network statistics
    node_count = G.number_of_nodes()
    degree_total = sum(degree for _, degree in G.degree())
    with open(dirs['network_graphs'] / 'network_stats.txt', 'w') as f:
        f.write(f"Network Statistics (Sample of {sample_size} relationships):\n")
        f.write(f"Nodes: {node_count}\n")
        f.write(f"Edges: {G.number_of_edges()}\n")
        f.write(f"Average degree: {degree_total / node_count:.2f}\n")
        f.write(f"Network density: {nx.density(G):.4f}\n")

        # Path metrics are only defined when every node is reachable
        if nx.is_connected(G):
            f.write(f"Average path length: {nx.average_shortest_path_length(G):.2f}\n")
            f.write(f"Diameter: {nx.diameter(G)}\n")

    print(f"  ‚úì Saved network graph and statistics")

def save_kumo_graph_visualizations(graph, dirs):
    """Save Kumo AI graph artifacts: a schema image and a text summary.

    Both steps are best effort and independent of each other: a failure while
    rendering the schema image is reported and does not prevent the text
    summary from being written (and vice versa). No exception escapes.

    Outputs (inside ``dirs['kumo_graphs']``):
        - ``kumo_schema.png``: written only if ``graph`` has a ``visualize``
          method and calling it succeeds.
        - ``kumo_graph_summary.txt``: creation time, the names in
          ``graph.tables`` (when that attribute exists) and a fixed
          configuration checklist.

    Args:
        graph: Kumo graph object. Only ``visualize`` and ``tables`` are
            accessed, each guarded by ``hasattr``.
        dirs: Mapping whose 'kumo_graphs' entry supports the ``/`` operator
            (e.g. a ``pathlib.Path``).

    Returns:
        None.
    """

    print("üìä Creating Kumo graph visualizations...")

    # Step 1: schema image. This might not work with all Kumo versions, so it
    # gets its own try block to keep a rendering failure from skipping step 2.
    try:
        if hasattr(graph, 'visualize'):
            # The target path is passed positionally so the call does not
            # depend on the keyword name.
            # NOTE(review): kumoai's visualize() appears to name this
            # parameter `path` rather than `save_path` - confirm against the
            # installed version.
            graph.visualize(str(dirs['kumo_graphs'] / 'kumo_schema.png'))
            print("  ‚úì Saved Kumo schema visualization")
        else:
            print("  ‚ö†Ô∏è Kumo visualization not available")
    except Exception as e:
        print(f"  ‚ö†Ô∏è Could not save Kumo visualizations: {e}")

    # Step 2: text summary, written regardless of how step 1 went.
    try:
        # Explicit UTF-8 so table names with non-ASCII characters are written
        # the same way on every platform.
        with open(dirs['kumo_graphs'] / 'kumo_graph_summary.txt', 'w', encoding='utf-8') as f:
            f.write("Kumo AI Graph Summary\n")
            f.write("=" * 30 + "\n")
            f.write(f"Created: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

            if hasattr(graph, 'tables'):
                f.write("Tables:\n")
                for table_name in graph.tables.keys():
                    f.write(f"  - {table_name}\n")

            # Static checklist; these lines are not derived from `graph`.
            f.write("\nGraph Configuration:\n")
            f.write("  - Temporal columns configured\n")
            f.write("  - Primary keys set\n")
            f.write("  - Edge relationships defined\n")

        print("  ‚úì Saved Kumo graph summary")

    except Exception as e:
        print(f"  ‚ö†Ô∏è Could not save Kumo graph summary: {e}")

def create_summary_report(neural_events_df, phoneme_annotations, causal_df, dirs):
    """Write HTML and plain-text summaries of the analysis dataset.

    All statistics are computed once and shared by both reports.

    Outputs (inside ``dirs['summary_reports']``, both UTF-8 encoded):
        - ``analysis_summary.html``
        - ``analysis_summary.txt``

    Args:
        neural_events_df: DataFrame with 'channel' and 'timestamp' columns.
        phoneme_annotations: DataFrame with a 'phoneme' column.
        causal_df: DataFrame with 'delay_ms' and 'strength' columns.
        dirs: Mapping whose 'summary_reports' entry supports the ``/``
            operator (e.g. a ``pathlib.Path``).

    Returns:
        None.

    Notes:
        - Empty inputs do not raise: the "Sparsity" figure falls back to 0
          when there are no possible event/phoneme pairs, and the "most
          active/common" entries fall back to ``N/A``.
        - Data-derived labels are HTML-escaped in the HTML report so values
          such as ``<sil>`` are shown literally instead of being parsed as
          markup.
    """
    import html  # local import: this block only sees part of the file's imports

    print("üìã Creating summary report...")

    def _most_frequent(series):
        """Return (label, count) for the most frequent value, or ('N/A', 0) if none."""
        counts = series.value_counts()
        if counts.empty:
            return "N/A", 0
        return str(counts.index[0]), int(counts.iloc[0])

    # One timestamp for both reports so they always agree
    generated_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')

    n_events = len(neural_events_df)
    n_annotations = len(phoneme_annotations)
    n_causal = len(causal_df)
    n_channels = neural_events_df['channel'].nunique()
    n_phonemes = phoneme_annotations['phoneme'].nunique()

    # Reported under the label "Sparsity": share of all possible
    # (neural event, phoneme annotation) pairs that appear in causal_df.
    possible_pairs = n_events * n_annotations
    sparsity_pct = (n_causal / possible_pairs) * 100 if possible_pairs else 0.0

    recording_span = neural_events_df['timestamp'].max() - neural_events_df['timestamp'].min()
    delay_mean = causal_df['delay_ms'].mean()
    delay_std = causal_df['delay_ms'].std()
    delay_min = causal_df['delay_ms'].min()
    delay_max = causal_df['delay_ms'].max()
    peak_strength = causal_df['strength'].max()

    top_channel, top_channel_events = _most_frequent(neural_events_df['channel'])
    top_phoneme, top_phoneme_count = _most_frequent(phoneme_annotations['phoneme'])

    # Escaped copies for the HTML report only; the text report keeps raw labels
    top_channel_html = html.escape(top_channel)
    top_phoneme_html = html.escape(top_phoneme)

    # Create HTML summary report
    html_content = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <meta charset="utf-8">
        <title>Neural-Phoneme Analysis Summary</title>
        <style>
            body {{ font-family: Arial, sans-serif; margin: 40px; }}
            .header {{ background-color: #f0f0f0; padding: 20px; border-radius: 10px; }}
            .section {{ margin: 20px 0; }}
            .stat {{ background-color: #e8f4f8; padding: 10px; margin: 5px 0; border-radius: 5px; }}
            .visualization-grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px; }}
            .viz-card {{ border: 1px solid #ddd; padding: 15px; border-radius: 10px; }}
        </style>
    </head>
    <body>
        <div class="header">
            <h1>üß† Neural-Phoneme BCI Analysis Summary</h1>
            <p>Generated on {generated_at}</p>
        </div>
        
        <div class="section">
            <h2>üìä Dataset Overview</h2>
            <div class="stat">Neural Events: {n_events:,}</div>
            <div class="stat">Phoneme Annotations: {n_annotations:,}</div>
            <div class="stat">Causal Relationships: {n_causal:,}</div>
            <div class="stat">Unique Channels: {n_channels}</div>
            <div class="stat">Unique Phonemes: {n_phonemes}</div>
            <div class="stat">Sparsity: {sparsity_pct:.2f}%</div>
        </div>
        
        <div class="section">
            <h2>‚è∞ Temporal Statistics</h2>
            <div class="stat">Neural Recording Duration: {recording_span:.2f} seconds</div>
            <div class="stat">Average Causal Delay: {delay_mean:.1f} ¬± {delay_std:.1f} ms</div>
            <div class="stat">Delay Range: {delay_min:.1f} - {delay_max:.1f} ms</div>
        </div>
        
        <div class="section">
            <h2>üéØ Key Findings</h2>
            <div class="stat">Most Active Channel: {top_channel_html} ({top_channel_events:,} events)</div>
            <div class="stat">Most Common Phoneme: {top_phoneme_html} ({top_phoneme_count:,} occurrences)</div>
            <div class="stat">Peak Relationship Strength: {peak_strength:.3f}</div>
        </div>
        
        <div class="section">
            <h2>üìÅ Visualization Categories</h2>
            <div class="visualization-grid">
                <div class="viz-card">
                    <h3>üß† Neural Events</h3>
                    <p>Event distributions, timeline analysis, amplitude patterns</p>
                </div>
                <div class="viz-card">
                    <h3>üó£Ô∏è Phoneme Analysis</h3>
                    <p>Frequency distributions, duration analysis, sequence patterns</p>
                </div>
                <div class="viz-card">
                    <h3>üîó Causal Relationships</h3>
                    <p>Delay distributions, strength analysis, phoneme-channel mappings</p>
                </div>
                <div class="viz-card">
                    <h3>‚è∞ Temporal Analysis</h3>
                    <p>Timeline correlations, delay patterns, temporal dynamics</p>
                </div>
                <div class="viz-card">
                    <h3>üì° Channel Analysis</h3>
                    <p>Channel activity, brain region analysis, spatial patterns</p>
                </div>
                <div class="viz-card">
                    <h3>üîÑ Interactive</h3>
                    <p>Plotly visualizations, 3D scatter plots, interactive heatmaps</p>
                </div>
            </div>
        </div>
    </body>
    </html>
    """

    # Explicit UTF-8: the report contains non-ASCII characters, which the
    # platform default encoding may be unable to represent.
    with open(dirs['summary_reports'] / 'analysis_summary.html', 'w', encoding='utf-8') as f:
        f.write(html_content)

    # Create text summary
    with open(dirs['summary_reports'] / 'analysis_summary.txt', 'w', encoding='utf-8') as f:
        f.write("NEURAL-PHONEME BCI ANALYSIS SUMMARY\n")
        f.write("=" * 50 + "\n")
        f.write(f"Generated: {generated_at}\n\n")

        f.write("DATASET OVERVIEW:\n")
        f.write(f"  Neural Events: {n_events:,}\n")
        f.write(f"  Phoneme Annotations: {n_annotations:,}\n")
        f.write(f"  Causal Relationships: {n_causal:,}\n")
        f.write(f"  Unique Channels: {n_channels}\n")
        f.write(f"  Unique Phonemes: {n_phonemes}\n")
        f.write(f"  Sparsity: {sparsity_pct:.2f}%\n\n")

        f.write("KEY STATISTICS:\n")
        f.write(f"  Average Causal Delay: {delay_mean:.1f} ¬± {delay_std:.1f} ms\n")
        f.write(f"  Most Active Channel: {top_channel}\n")
        f.write(f"  Most Common Phoneme: {top_phoneme}\n")
        f.write(f"  Peak Relationship Strength: {peak_strength:.3f}\n")

    print("  ‚úì Saved HTML and text summary reports")

def main_visualization_pipeline(neural_events_df, phoneme_annotations, causal_df, graph=None):
    """Run every visualization stage in order and return the output directories.

    Args:
        neural_events_df: Neural events table, forwarded to the stages that use it.
        phoneme_annotations: Phoneme annotations table, forwarded likewise.
        causal_df: Causal relationships table, forwarded likewise.
        graph: Optional Kumo graph; when not None its visualizations are
            saved after the regular stages.

    Returns:
        The mapping returned by ``setup_visualization_directories()``.
    """

    print("üé® STARTING COMPREHENSIVE VISUALIZATION PIPELINE")
    print("=" * 60)

    # Output folders must exist before any stage writes into them
    dirs = setup_visualization_directories()

    # Stages are wrapped in lambdas so each one is looked up and invoked only
    # when its turn comes, in exactly this order.
    stages = [
        lambda: save_neural_events_visualizations(neural_events_df, dirs),
        lambda: save_phoneme_analysis_visualizations(phoneme_annotations, dirs),
        lambda: save_causal_relationships_visualizations(causal_df, dirs),
        lambda: save_temporal_analysis_visualizations(neural_events_df, phoneme_annotations, causal_df, dirs),
        lambda: save_channel_analysis_visualizations(neural_events_df, causal_df, dirs),
        lambda: save_interactive_visualizations(neural_events_df, phoneme_annotations, causal_df, dirs),
        lambda: save_network_graph_visualizations(causal_df, phoneme_annotations, dirs),
    ]
    if graph is not None:
        # Kumo artifacts are only produced when a graph was supplied
        stages.append(lambda: save_kumo_graph_visualizations(graph, dirs))

    # The summary report always runs last
    stages.append(lambda: create_summary_report(neural_events_df, phoneme_annotations, causal_df, dirs))

    for run_stage in stages:
        run_stage()

    print("\nüéâ VISUALIZATION PIPELINE COMPLETE!")
    print(f"üìÅ All visualizations saved to: {dirs['base']}")
    print(f"üìã Summary report: {dirs['summary_reports'] / 'analysis_summary.html'}")

    return dirs

# Usage instructions
def show_usage():
    """Print the steps needed to run the visualization pipeline."""

    # One entry per printed line; an empty string yields a blank line
    usage_lines = (
        "\nüìñ USAGE INSTRUCTIONS:",
        "=" * 30,
        "1. Make sure you have the required packages:",
        "   pip install matplotlib seaborn plotly networkx",
        "",
        "2. Run the visualization pipeline:",
        "   dirs = main_visualization_pipeline(neural_events_df, phoneme_annotations, causal_df, graph)",
        "",
        "3. View your visualizations in the graphs/ directory",
        "4. Open graphs/summary_reports/analysis_summary.html in a browser",
    )
    for usage_line in usage_lines:
        print(usage_line)

if __name__ == "__main__":
    show_usage()

    # The pipeline only runs when all three input tables already exist as
    # globals (e.g. were produced by earlier notebook cells).
    required_vars = ['neural_events_df', 'phoneme_annotations', 'causal_df']
    if not all(name in globals() for name in required_vars):
        print(f"\n‚ö†Ô∏è Missing data - make sure you have: {required_vars}")
    else:
        print("\n‚úÖ Found all required data - running visualization pipeline...")
        # The Kumo graph is optional: taken from comprehensive_results['graph']
        # when that global exists, otherwise None.
        graph = globals().get('comprehensive_results', {}).get('graph')
        dirs = main_visualization_pipeline(neural_events_df, phoneme_annotations, causal_df, graph)


üìñ USAGE INSTRUCTIONS:
1. Make sure you have the required packages:
   pip install matplotlib seaborn plotly networkx

2. Run the visualization pipeline:
   dirs = main_visualization_pipeline(neural_events_df, phoneme_annotations, causal_df, graph)

3. View your visualizations in the graphs/ directory
4. Open graphs/summary_reports/analysis_summary.html in a browser

‚úÖ Found all required data - running visualization pipeline...
üé® STARTING COMPREHENSIVE VISUALIZATION PIPELINE
üìÅ Creating visualization directory structure...
  ‚úì graphs
  ‚úì graphs/neural_events
  ‚úì graphs/phoneme_analysis
  ‚úì graphs/causal_relationships
  ‚úì graphs/temporal_analysis
  ‚úì graphs/channel_analysis
  ‚úì graphs/interactive
  ‚úì graphs/network_graphs
  ‚úì graphs/summary_reports
  ‚úì graphs/kumo_graphs
üß† Creating neural events visualizations...
  ‚úì Saved 5 neural events visualizations
üó£Ô∏è Creating phoneme analysis visualizations...
  ‚úì Saved 4 phoneme analysis visualizations
ü

In [24]:
!pip install seaborn matplotlib plotly

Defaulting to user installation because normal site-packages is not writeable
