In [15]:
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import matplotlib
from collections import defaultdict

import sys
sys.path.insert(0, r'c:\Users\18476\Desktop\Essentials-of-Text-and-Speech-Processing\Project\nejm-brain-to-text')

matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
matplotlib.rcParams['font.family'] = 'sans-serif'


# Load and Inspect Data

In [None]:
# Load pickled data
with open('../model_training/trained_models/baseline_rnn_v2_20251202_220821/checkpoint/val_metrics.pkl', 'rb') as f:
    dat = pickle.load(f)

# Print keys and structure
print("Available keys:", dat.keys())
print("\nData shapes/types:")
for key in dat.keys():
    if isinstance(dat[key], np.ndarray):
        print(f"  {key}: ndarray shape {dat[key].shape}, dtype {dat[key].dtype}")
    elif isinstance(dat[key], list):
        print(f"  {key}: list of length {len(dat[key])}")
        # Check structure of list elements
        if len(dat[key]) > 0:
            print(f"       First element type: {type(dat[key][0])}")
    elif isinstance(dat[key], dict):
        print(f"  {key}: dict with {len(dat[key])} keys")
    else:
        print(f"  {key}: {type(dat[key]).__name__} = {dat[key]}")

# Extract key data
logits = dat['logits']
n_time_steps = dat['n_time_steps']
decoded_seqs = dat['decoded_seqs']
true_seq = dat['true_seq']
phone_seq_lens = dat['phone_seq_lens']
transcription = dat['transcription']
losses = dat['losses']
block_nums = dat['block_nums']
trial_nums = dat['trial_nums']
day_indicies = dat['day_indicies']
day_PERs = dat['day_PERs']
avg_PER = dat['avg_PER']
avg_loss = dat['avg_loss']

# Helper function to flatten inhomogeneous arrays/lists
def flatten_inhomogeneous(arr):
    # Check if it's a list or numpy array
    if isinstance(arr, (list, np.ndarray)):
        # If it's a numpy array, check if it's object type or if we suspect it contains sequences
        # If it's a simple 1D array of numbers, we might not want to flatten unless we know it's grouped
        # But here we assume inputs are potentially grouped structures that need flattening
        
        flat = []
        for item in arr:
            if isinstance(item, list):
                flat.extend(item)
            elif isinstance(item, np.ndarray):
                if item.ndim > 0:
                    flat.extend(item)
                else:
                    flat.append(item.item())
            else:
                flat.append(item)
        return np.array(flat)
    return arr

# Flatten arrays that might contain sequences
day_indicies = flatten_inhomogeneous(day_indicies)
block_nums = flatten_inhomogeneous(block_nums)
trial_nums = flatten_inhomogeneous(trial_nums)
losses = flatten_inhomogeneous(losses)
n_time_steps = flatten_inhomogeneous(n_time_steps)
phone_seq_lens = flatten_inhomogeneous(phone_seq_lens)
decoded_seqs = flatten_inhomogeneous(decoded_seqs)
true_seq = flatten_inhomogeneous(true_seq)
transcription = flatten_inhomogeneous(transcription)

print(f"\nSummary Statistics:")
print(f"  Average PER: {avg_PER:.4f}")
print(f"  Average Loss: {np.mean(losses):.4f}")
print(f"  Total trials: {len(losses)}")
print(f"  Unique days: {len(np.unique(day_indicies))}")
print(f"  Unique blocks: {len(np.unique(block_nums))}")
print(f"  Day indicies shape: {day_indicies.shape if isinstance(day_indicies, np.ndarray) else len(day_indicies)}")

Available keys: dict_keys(['logits', 'n_time_steps', 'decoded_seqs', 'true_seq', 'phone_seq_lens', 'transcription', 'losses', 'block_nums', 'trial_nums', 'day_indicies', 'day_PERs', 'avg_PER', 'avg_loss'])

Data shapes/types:
  logits: list of length 41
       First element type: <class 'numpy.ndarray'>
  n_time_steps: list of length 41
       First element type: <class 'numpy.ndarray'>
  decoded_seqs: list of length 41
       First element type: <class 'list'>
  true_seq: list of length 41
       First element type: <class 'numpy.ndarray'>
  phone_seq_lens: list of length 41
       First element type: <class 'numpy.ndarray'>
  transcription: list of length 41
       First element type: <class 'numpy.ndarray'>
  losses: list of length 82
       First element type: <class 'numpy.ndarray'>
  block_nums: list of length 41
       First element type: <class 'numpy.ndarray'>
  trial_nums: list of length 41
       First element type: <class 'numpy.ndarray'>
  day_indicies: list of length 41
 

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (1426,) + inhomogeneous part.

# Debug: Inspect day_PERs Structure

In [None]:
# Debug: Check day_PERs structure and flatten if needed
print(f"Type of day_PERs: {type(day_PERs)}")
print(f"day_PERs dtype: {day_PERs.dtype if isinstance(day_PERs, np.ndarray) else 'N/A'}")

# Helper function to flatten inhomogeneous arrays/lists (Redefined here to ensure it's available in this cell)
def flatten_inhomogeneous(arr):
    if isinstance(arr, (list, np.ndarray)):
        flat = []
        for item in arr:
            if isinstance(item, list):
                flat.extend(item)
            elif isinstance(item, np.ndarray):
                if item.ndim > 0:
                    flat.extend(item)
                else:
                    flat.append(item.item())
            else:
                flat.append(item)
        return np.array(flat)
    return arr

# Handle block_nums structure similarly using the helper function
block_nums = flatten_inhomogeneous(block_nums)

if isinstance(day_PERs, np.ndarray):
    print(f"day_PERs shape: {day_PERs.shape}")
    print(f"First 5 elements: {day_PERs[:5]}")
    print(f"Element types: {[type(x) for x in day_PERs[:3]]}")
elif isinstance(day_PERs, dict):
    print(f"day_PERs is a dict with keys: {list(day_PERs.keys())[:5]}")
    for k in list(day_PERs.keys())[:3]:
        v = day_PERs[k]
        print(f"  Key {k}: type={type(v)}, value={v}")

# Ensure losses is flattened and is a numpy array for indexing
losses = flatten_inhomogeneous(losses)

# Safe extraction of daily PERs
unique_days = np.unique(day_indicies)
daily_losses = []
daily_pers = []

for day in unique_days:
    day_mask = day_indicies == day
    daily_losses.append(np.mean(losses[day_mask]))
    
    # Handle different day_PERs formats
    per_val = np.nan
    if isinstance(day_PERs, np.ndarray):
        # If it's an array indexed by day
        if day < len(day_PERs):
            per_val = day_PERs[day]
    elif isinstance(day_PERs, dict):
        per_data = day_PERs.get(day, None)
        if isinstance(per_data, dict) and 'total_edit_distance' in per_data:
             # Calculate PER from stats
             if per_data['total_seq_length'] > 0:
                 per_val = per_data['total_edit_distance'] / per_data['total_seq_length']
             else:
                 per_val = np.nan
        else:
             per_val = per_data if per_data is not None else np.nan
    
    # Extract scalar if needed (in case it's still a list/array)
    if isinstance(per_val, (list, np.ndarray)):
        per_val = np.mean(per_val) if len(per_val) > 0 else np.nan
    
    daily_pers.append(per_val)

daily_losses = np.array(daily_losses)
daily_pers = np.array(daily_pers)

print(f"\nExtracted daily losses: {daily_losses}")
print(f"Extracted daily PERs: {daily_pers}")
print(f"Unique days count: {len(unique_days)}")

Type of day_PERs: <class 'dict'>
day_PERs dtype: N/A
day_PERs is a dict with keys: [0, 1, 2, 3, 4]
  Key 0: type=<class 'dict'>, value={'total_edit_distance': 0, 'total_seq_length': 0}
  Key 1: type=<class 'dict'>, value={'total_edit_distance': 84, 'total_seq_length': 962}
  Key 2: type=<class 'dict'>, value={'total_edit_distance': 103, 'total_seq_length': 1193}


IndexError: boolean index did not match indexed array along axis 0; size of axis is 82 but size of corresponding boolean axis is 1426

# Plot 1: Loss and PER by Day

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))

# Plot loss by day
ax1.plot(unique_days, daily_losses, 'b.-', linewidth=2, markersize=8)
ax1.set_xlabel('Day Index')
ax1.set_ylabel('Average Loss', color='b')
ax1.tick_params(axis='y', labelcolor='b')
ax1.grid(axis='y', alpha=0.3)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)
ax1.set_title('Training Loss by Day')

# Plot PER by day
valid_mask = ~np.isnan(daily_pers)
if np.any(valid_mask):
    ax2.plot(unique_days[valid_mask], daily_pers[valid_mask] * 100, 'r.-', linewidth=2, markersize=8)
else:
    ax2.plot(unique_days, daily_pers * 100, 'r.-', linewidth=2, markersize=8)
ax2.set_xlabel('Day Index')
ax2.set_ylabel('Phone Error Rate (%)', color='r')
ax2.tick_params(axis='y', labelcolor='r')
ax2.grid(axis='y', alpha=0.3)
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.set_title('Phone Error Rate by Day')

plt.tight_layout()
plt.show()


# Plot 2: Loss and PER Distribution

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Loss histogram
axes[0, 0].hist(losses, bins=50, color='blue', alpha=0.7, edgecolor='black')
axes[0, 0].axvline(avg_loss, color='red', linestyle='--', linewidth=2, label=f'Mean: {avg_loss:.4f}')
axes[0, 0].set_xlabel('Loss')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('Distribution of Losses')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# PER histogram
valid_pers = daily_pers[~np.isnan(daily_pers)]
if len(valid_pers) > 0:
    axes[0, 1].hist(valid_pers * 100, bins=30, color='red', alpha=0.7, edgecolor='black')
    axes[0, 1].axvline(avg_PER * 100, color='blue', linestyle='--', linewidth=2, label=f'Mean: {avg_PER*100:.2f}%')
    axes[0, 1].set_xlabel('Phone Error Rate (%)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Distribution of Daily PERs')
    axes[0, 1].legend()
else:
    axes[0, 1].text(0.5, 0.5, 'No valid PER data', ha='center', va='center')
axes[0, 1].grid(alpha=0.3)

# Loss by block
unique_blocks = np.unique(block_nums)
block_losses = [np.mean(losses[block_nums == b]) for b in unique_blocks]
axes[1, 0].bar(unique_blocks, block_losses, color='steelblue', alpha=0.7, edgecolor='black')
axes[1, 0].set_xlabel('Block Number')
axes[1, 0].set_ylabel('Average Loss')
axes[1, 0].set_title('Average Loss by Block')
axes[1, 0].grid(axis='y', alpha=0.3)

# Number of trials per day
trials_per_day = defaultdict(int)
for day in day_indicies:
    trials_per_day[int(day)] += 1
days_sorted = sorted(trials_per_day.keys())
trial_counts = [trials_per_day[d] for d in days_sorted]
axes[1, 1].bar(days_sorted, trial_counts, color='coral', alpha=0.7, edgecolor='black')
axes[1, 1].set_xlabel('Day Index')
axes[1, 1].set_ylabel('Number of Trials')
axes[1, 1].set_title('Trials per Day')
axes[1, 1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()


# Plot 3: Detailed Performance Analysis

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(14, 10))

# Cumulative loss
cumsum_losses = np.cumsum(losses)
axes[0].plot(cumsum_losses, 'b-', linewidth=1.5)
axes[0].fill_between(range(len(cumsum_losses)), cumsum_losses, alpha=0.3)
axes[0].set_ylabel('Cumulative Loss')
axes[0].set_title('Cumulative Loss Over All Trials')
axes[0].grid(axis='y', alpha=0.3)
axes[0].spines['top'].set_visible(False)
axes[0].spines['right'].set_visible(False)

# Loss by trial (with rolling average)
window = 20
if len(losses) >= window:
    rolling_avg = np.convolve(losses, np.ones(window)/window, mode='valid')
    axes[1].plot(losses, 'lightblue', alpha=0.5, label='Loss per trial')
    axes[1].plot(range(window-1, len(losses)), rolling_avg, 'b-', linewidth=2, label=f'Rolling avg (window={window})')
    axes[1].set_ylabel('Loss')
    axes[1].set_title('Loss per Trial with Rolling Average')
    axes[1].legend()
    axes[1].grid(alpha=0.3)
    axes[1].spines['top'].set_visible(False)
    axes[1].spines['right'].set_visible(False)
else:
    axes[1].plot(losses, 'b.-', linewidth=1.5)
    axes[1].set_ylabel('Loss')
    axes[1].set_title('Loss per Trial')
    axes[1].grid(alpha=0.3)

# Phone sequence length distribution
axes[2].hist(phone_seq_lens, bins=30, color='green', alpha=0.7, edgecolor='black')
axes[2].axvline(np.mean(phone_seq_lens), color='red', linestyle='--', linewidth=2, label=f'Mean: {np.mean(phone_seq_lens):.1f}')
axes[2].set_xlabel('Phone Sequence Length')
axes[2].set_ylabel('Frequency')
axes[2].set_title('Distribution of Phone Sequence Lengths')
axes[2].legend()
axes[2].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nPhone sequence length statistics:")
print(f"  Min: {np.min(phone_seq_lens)}")
print(f"  Max: {np.max(phone_seq_lens)}")
print(f"  Mean: {np.mean(phone_seq_lens):.2f}")
print(f"  Median: {np.median(phone_seq_lens):.2f}")


# Plot 4: Sequence Comparison Analysis

In [None]:
# Compare decoded vs true sequences
print("Sequence Comparison Samples:")
print("=" * 80)

num_samples = min(5, len(decoded_seqs))
for i in range(num_samples):
    print(f"\nTrial {i+1}:")
    print(f"  Loss: {losses[i]:.4f}")
    print(f"  True seq length: {len(true_seq[i]) if isinstance(true_seq[i], (list, np.ndarray)) else 'N/A'}")
    print(f"  Decoded seq length: {len(decoded_seqs[i]) if isinstance(decoded_seqs[i], (list, np.ndarray)) else 'N/A'}")
    print(f"  Phone seq length: {phone_seq_lens[i]}")
    print(f"  Block: {block_nums[i]}, Trial: {trial_nums[i]}, Day: {day_indicies[i]}")
    if transcription and i < len(transcription):
        print(f"  Transcription: {transcription[i]}")

# Plot time steps distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# n_time_steps histogram
axes[0].hist(n_time_steps, bins=30, color='purple', alpha=0.7, edgecolor='black')
axes[0].axvline(np.mean(n_time_steps), color='red', linestyle='--', linewidth=2, label=f'Mean: {np.mean(n_time_steps):.1f}')
axes[0].set_xlabel('Time Steps')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Distribution of Time Steps')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Loss vs time steps
axes[1].scatter(n_time_steps, losses, alpha=0.5, s=30)
axes[1].set_xlabel('Time Steps')
axes[1].set_ylabel('Loss')
axes[1].set_title('Loss vs Time Steps')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print(f"\nTime steps statistics:")
print(f"  Min: {np.min(n_time_steps)}")
print(f"  Max: {np.max(n_time_steps)}")
print(f"  Mean: {np.mean(n_time_steps):.2f}")
print(f"  Median: {np.median(n_time_steps):.2f}")


# Summary Statistics

In [None]:
print("\n" + "="*80)
print("COMPREHENSIVE DATA ANALYSIS SUMMARY")
print("="*80)

print("\nðŸ“Š LOSS STATISTICS:")
print(f"  Average Loss: {avg_loss:.6f}")
print(f"  Min Loss: {np.min(losses):.6f}")
print(f"  Max Loss: {np.max(losses):.6f}")
print(f"  Std Dev: {np.std(losses):.6f}")
print(f"  Median: {np.median(losses):.6f}")

print("\nðŸ“Š PER STATISTICS:")
print(f"  Average PER: {avg_PER:.6f}")
if len(valid_pers) > 0:
    print(f"  Min PER: {np.min(valid_pers):.6f}")
    print(f"  Max PER: {np.max(valid_pers):.6f}")
    print(f"  Std Dev: {np.std(valid_pers):.6f}")
    print(f"  Median: {np.median(valid_pers):.6f}")
else:
    print(f"  No valid PER data")

print("\nðŸ“Š TRIAL INFORMATION:")
print(f"  Total Trials: {len(losses)}")
print(f"  Unique Days: {len(np.unique(day_indicies))}")
print(f"  Unique Blocks: {len(np.unique(block_nums))}")
print(f"  Days in data: {sorted(np.unique(day_indicies).astype(int).tolist())}")

print("\nðŸ“Š SEQUENCE INFORMATION:")
print(f"  Total sequences: {len(phone_seq_lens)}")
print(f"  Avg phone seq length: {np.mean(phone_seq_lens):.2f}")
print(f"  Avg time steps: {np.mean(n_time_steps):.2f}")

print("\nðŸ“Š TRIALS PER DAY:")
for day in sorted(trials_per_day.keys()):
    print(f"  Day {day}: {trials_per_day[day]} trials")

print("\n" + "="*80)
