# Schema Analysis Tool - Google Colab Notebook

This notebook provides a complete, reproducible workflow for analyzing schema experiment data.

## Overview

The analysis pipeline consists of:
1. **Load & Merge**: Combine all CSV files from raw data
2. **Clean**: Remove rows with a missing group ID (the `session_group` column)
3. **Preprocess**: Rename faces, transform angles
4. **Filter**: Keep trials with 3° < end angle < 43°, then exclude subjects with more than 2 invalid trials
5. **Balance**: Remove unmatched trials
6. **Analyze**: Calculate D-values and statistics
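
Steps 1 and 2 can be approximated in plain pandas, which helps when you want to check what goes into the pipeline. This is a minimal sketch, not the package's `load_and_merge_csvs`: it assumes every CSV in the folder shares the same columns and that the group column is `session_group` (the column the pipeline's cleaning step drops on). The `load_and_clean` name and the `source_file` column are hypothetical additions for illustration.

```python
import glob
import os

import pandas as pd


def load_and_clean(data_dir: str, group_col: str = "session_group") -> pd.DataFrame:
    """Hypothetical stand-in for Steps 1-2: merge every CSV in data_dir,
    then drop rows whose group column is missing."""
    paths = sorted(glob.glob(os.path.join(data_dir, "*.csv")))
    if not paths:
        raise FileNotFoundError(f"No CSV files found in {data_dir!r}")
    # Tag each row with the file it came from so merged rows stay traceable.
    frames = [pd.read_csv(p).assign(source_file=os.path.basename(p)) for p in paths]
    merged = pd.concat(frames, ignore_index=True)
    # Step 2: rows without a group assignment cannot be analyzed.
    return merged.dropna(subset=[group_col]).reset_index(drop=True)
```

Comparing `len(load_and_clean(DATA_DIR))` with the count printed after Step 2 is a quick sanity check, provided the package really does a plain concatenation.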

---

## 1. Setup and Installation

First, we'll clone the repository and install the schema_analysis package.

In [None]:
# Clone the repository
!git clone https://github.com/BackyardBrains/schema-analysis.git

# Install the package
!pip install -q ./schema-analysis

print("✓ Installation complete!")
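
The message above prints even if `pip` reported an error. A small check, assuming the import name is `schema_analysis` (as used in the next cell), confirms the package is actually importable; `package_installed` is a hypothetical helper built on the standard library's `importlib.util.find_spec`.

```python
import importlib.util


def package_installed(name: str) -> bool:
    """Return True if the top-level package `name` is importable here."""
    # Refresh import-system caches so a package installed moments ago is seen.
    importlib.invalidate_caches()
    return importlib.util.find_spec(name) is not None


if package_installed("schema_analysis"):
    print("✓ schema_analysis is importable")
else:
    print("✗ schema_analysis not found; check the pip output above")
```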

## 2. Import Required Libraries

Import all necessary libraries for the analysis.

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from schema_analysis.data_loader import load_and_merge_csvs
from schema_analysis import TubeTrials

# Set plotting style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Libraries imported successfully!")

## 3. Upload Your Data

You have two options for providing your data:

### Option A: Upload CSV files directly
Run the cell below to upload your CSV files.

In [None]:
from google.colab import files

# Create data directory
os.makedirs('data/raw', exist_ok=True)

# Upload files
print("Please upload your CSV files:")
uploaded = files.upload()

# Move uploaded files to data/raw directory
for filename in uploaded.keys():
    os.rename(filename, f'data/raw/{filename}')
    print(f"✓ Moved {filename} to data/raw/")

print("\n✓ Upload complete!")

### Option B: Mount Google Drive
If your data is in Google Drive, run this cell instead of Option A and set `DATA_DIR` to your data folder's path.

In [None]:
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# Update this path to point to your data folder in Google Drive
DATA_DIR = '/content/drive/MyDrive/your-data-folder'

print(f"✓ Google Drive mounted! Data directory: {DATA_DIR}")
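
Whichever option you used, it helps to confirm that the folder holds the CSV files you expect before running the pipeline. A minimal sketch, assuming the CSVs sit directly in the folder rather than in subfolders; `list_csvs` is a hypothetical helper, not part of `schema_analysis`.

```python
import glob
import os


def list_csvs(data_dir: str) -> list[str]:
    """Return the sorted names of the .csv files directly inside data_dir."""
    if not os.path.isdir(data_dir):
        raise NotADirectoryError(f"{data_dir!r} is not a directory (check the path)")
    # Non-recursive; the pattern is case-sensitive on Linux, so '.CSV' files are not matched.
    pattern = os.path.join(data_dir, "*.csv")
    return sorted(os.path.basename(p) for p in glob.glob(pattern))
```

For example, run `print(list_csvs('data/raw'))` after Option A, or `print(list_csvs(DATA_DIR))` after Option B.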

## 4. Configuration

Set the analysis parameters. These are the standard values used in the official analysis.

In [None]:
# Analysis parameters
# Keep DATA_DIR if Option B (Google Drive) already set it; otherwise use the Option A upload folder.
DATA_DIR = globals().get('DATA_DIR', 'data/raw')
MIN_ANGLE = 3
MAX_ANGLE = 43
MAX_INVALID_TRIALS = 2

print("Configuration:")
print(f"  Data directory: {DATA_DIR}")
print(f"  Valid angle range: {MIN_ANGLE}° - {MAX_ANGLE}°")
print(f"  Max invalid trials per subject: {MAX_INVALID_TRIALS}")
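
To make the three parameters concrete, here is a minimal pandas sketch of the filtering step. It is not the `TubeTrials` implementation: it assumes hypothetical columns `subject_id` and `end_angle` (with angles already transformed as in Step 3), the strict inequality the pipeline prints (`MIN_ANGLE < end_angle < MAX_ANGLE`), and that a subject is excluded when they have more than `MAX_INVALID_TRIALS` invalid trials.

```python
import pandas as pd


def mark_validity(df: pd.DataFrame, min_angle: float = 3, max_angle: float = 43,
                  max_invalid_trials: int = 2) -> pd.DataFrame:
    """Hypothetical sketch of Step 4: flag valid angles, then valid subjects."""
    out = df.copy()
    # Angle rule: valid only if the end angle lies strictly inside (min_angle, max_angle).
    out["valid_angle"] = (out["end_angle"] > min_angle) & (out["end_angle"] < max_angle)
    # Number of invalid trials per subject, broadcast back to each of that subject's rows.
    n_invalid = (~out["valid_angle"]).astype(int).groupby(out["subject_id"]).transform("sum")
    # A subject stays in only if they have at most max_invalid_trials invalid trials.
    out["valid_subject"] = n_invalid <= max_invalid_trials
    # Selecting valid trials (Step 5) would then keep rows that pass both checks.
    out["valid"] = out["valid_angle"] & out["valid_subject"]
    return out
```

With the configuration above, the call would be `mark_validity(df, MIN_ANGLE, MAX_ANGLE, MAX_INVALID_TRIALS)`. Note that angles of exactly 3° or 43° fail the rule under this strict reading.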

## 5. Run the Analysis

This cell runs the complete standardized analysis pipeline.

In [None]:
print("=" * 70)
print("STANDARDIZED ANALYSIS PIPELINE")
print("=" * 70)

# Step 1: Load & Merge Data
print("\n[STEP 1] Loading and merging CSV files...")
merged_df = load_and_merge_csvs(DATA_DIR)
print(f"Loaded {len(merged_df)} total trials")

# Step 2: Clean Data
print("\n[STEP 2] Cleaning data (removing rows with missing session_group)...")
initial_count = len(merged_df)
merged_df = merged_df.dropna(subset=['session_group'])
cleaned_count = len(merged_df)
dropped = initial_count - cleaned_count
print(f"Removed {dropped} rows with missing session_group")
print(f"Remaining: {cleaned_count} trials")

# Steps 3-6: Process with TubeTrials
trials = TubeTrials(merged_df)

print("\n[STEP 3] Preprocessing (rename faces, transform angles)...")
trials.process_angles()
print("✓ Angles processed")

print("\n[STEP 4] Filtering trials...")
print(f"  Applying angle rule: {MIN_ANGLE} < end_angle < {MAX_ANGLE}")
trials.mark_valid_angles(min_angle=MIN_ANGLE, max_angle=MAX_ANGLE)

print(f"  Validating subjects (max {MAX_INVALID_TRIALS} invalid trials)...")
trials.mark_valid_subjects(max_invalid_trials=MAX_INVALID_TRIALS)

print("\n[STEP 5] Selecting valid trials...")
clean_trials = trials.select(valid_only=True)
print(f"Clean trials: {len(clean_trials)}")

print("\n[STEP 6] Balancing (removing unmatched trials)...")
results = clean_trials.calc_d_values()

# Step 7: Validity Report & Statistics
print("\n[STEP 7] Attrition Report & Statistics...")
trials.get_validity_stats()
stats_df = clean_trials.calc_stats()

print(f"Valid pairs: {len(results)}")
print(f"Paired trials: {len(results) * 2}")

print("\n" + "=" * 70)
print("✓ ANALYSIS COMPLETE")
print("=" * 70)
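
Step 6 keeps only trials that have a partner before computing D-values, which is why the paired-trial count is twice the number of pairs. The exact pairing rule and sign convention live in `calc_d_values`; the sketch below is a hypothetical illustration only. It assumes each pair is one trial per condition (a hypothetical `condition` column with levels `A` and `B`) for the same `subject_id` and `face_id`, and defines D as the end angle in `A` minus the end angle in `B`.

```python
import pandas as pd


def pair_and_diff(df: pd.DataFrame, cond_a: str = "A", cond_b: str = "B") -> pd.DataFrame:
    """Hypothetical sketch of Step 6: drop unmatched trials, then compute D per pair."""
    keys = ["subject_id", "face_id"]
    # One row per (subject, face), one column per condition; pivot raises
    # ValueError if a subject has two trials for the same face and condition.
    wide = df.pivot(index=keys, columns="condition", values="end_angle")
    # Make sure both condition columns exist even if one never occurs.
    wide = wide.reindex(columns=[cond_a, cond_b])
    # Balancing: a trial without a partner in the other condition is dropped.
    paired = wide.dropna(subset=[cond_a, cond_b])
    # D for each surviving pair (the sign convention here is an assumption).
    return (paired[cond_a] - paired[cond_b]).rename("d").reset_index()
```

Pivoting to one column per condition makes "unmatched" explicit: any row with a missing value in either column has no partner and is removed before the difference is taken.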

## 6. View Results

Display the calculated D-values and statistics.

In [None]:
print("D-Value Pairs (first 20):")
print("=" * 70)
display(results.head(20))

print("\n\nStatistics by Face ID:")
print("=" * 70)
display(stats_df)
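
`stats_df` reports a mean, SEM and p-value for each face ID, but the notebook does not show which test produces the p-value. As a hedged cross-check, the per-face mean and SEM can be recomputed from `results` with a pandas `groupby`, along with a one-sample t statistic against zero (a common choice for difference scores, but an assumption here). A p-value for that statistic could then be obtained with, for example, `scipy.stats.ttest_1samp`. `per_face_summary` is a hypothetical helper.

```python
import pandas as pd


def per_face_summary(results: pd.DataFrame) -> pd.DataFrame:
    """Hedged sketch: count, mean, SEM and one-sample t statistic (vs. 0) of D per face."""
    grouped = results.groupby("face_id")["d"]
    # 'sem' uses the sample standard deviation (ddof=1) divided by sqrt(n).
    summary = grouped.agg(n="count", mean="mean", sem="sem").reset_index()
    # t = mean / SEM tests whether the mean D differs from zero.
    summary["t"] = summary["mean"] / summary["sem"]
    summary["df"] = summary["n"] - 1
    return summary
```

If `per_face_summary(results)` matches the `mean` and `sem` columns of `stats_df`, the remaining question is only how the p-values were derived.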

## 7. Visualize Results

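Plot four panels: the distribution of all D-values, D-values by face ID, the mean D-value per face ID (±SEM), and the p-value per face ID against the 0.05 threshold.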

In [None]:
# Create figure with subplots
fig, axes = plt.subplots(2, 2, figsize=(15, 12))

# 1. Distribution of D-values
axes[0, 0].hist(results['d'], bins=30, edgecolor='black', alpha=0.7)
axes[0, 0].axvline(results['d'].mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {results["d"].mean():.2f}')
axes[0, 0].set_xlabel('D-value (degrees)', fontsize=12)
axes[0, 0].set_ylabel('Count (pairs)', fontsize=12)
axes[0, 0].set_title('Distribution of D-values (All Faces)', fontsize=14, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. D-values by Face ID
face_order = sorted(results['face_id'].unique())
sns.boxplot(data=results, x='face_id', y='d', order=face_order, ax=axes[0, 1])
axes[0, 1].set_xlabel('Face ID', fontsize=12)
axes[0, 1].set_ylabel('D-value (degrees)', fontsize=12)
axes[0, 1].set_title('D-values by Face ID', fontsize=14, fontweight='bold')
axes[0, 1].grid(True, alpha=0.3, axis='y')

# 3. Mean D-value by Face ID with error bars
x_pos = range(len(stats_df))
axes[1, 0].bar(x_pos, stats_df['mean'], yerr=stats_df['sem'], 
               capsize=5, alpha=0.7, edgecolor='black')
axes[1, 0].set_xticks(x_pos)
axes[1, 0].set_xticklabels(stats_df['face_id'])
axes[1, 0].set_xlabel('Face ID', fontsize=12)
axes[1, 0].set_ylabel('Mean D-value (degrees)', fontsize=12)
axes[1, 0].set_title('Mean D-value by Face ID (±SEM)', fontsize=14, fontweight='bold')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# 4. P-values by Face ID
colors = ['green' if p < 0.05 else 'red' for p in stats_df['p_value']]
axes[1, 1].bar(x_pos, stats_df['p_value'], color=colors, alpha=0.7, edgecolor='black')
axes[1, 1].axhline(0.05, color='black', linestyle='--', linewidth=2, label='p = 0.05')
axes[1, 1].set_xticks(x_pos)
axes[1, 1].set_xticklabels(stats_df['face_id'])
axes[1, 1].set_xlabel('Face ID', fontsize=12)
axes[1, 1].set_ylabel('P-value', fontsize=12)
axes[1, 1].set_title('Statistical Significance by Face ID', fontsize=14, fontweight='bold')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\n✓ Visualizations generated!")

## 8. Download Results

Save the results to CSV files and download them.

In [None]:
from google.colab import files  # imported here too, in case Option A was skipped

# Create results directory
os.makedirs('results', exist_ok=True)

# Save results
results_file = 'results/d_values.csv'
stats_file = 'results/statistics.csv'

results.to_csv(results_file, index=False)
stats_df.to_csv(stats_file, index=False)

print("Results saved to:")
print(f"  - {results_file}")
print(f"  - {stats_file}")

# Download files
print("\nDownloading files...")
files.download(results_file)
files.download(stats_file)

print("\n✓ Download complete!")
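
If you prefer a single download, the `results` folder can be bundled into one archive with the standard library's `shutil.make_archive`. `zip_folder` and the archive name `schema_results` are arbitrary, hypothetical choices; keep the archive outside the folder being zipped so it does not try to include itself.

```python
import shutil


def zip_folder(folder: str, archive_base: str) -> str:
    """Zip the contents of `folder` into `<archive_base>.zip`; return the archive path."""
    # root_dir makes the paths stored inside the archive relative to `folder`.
    return shutil.make_archive(archive_base, "zip", root_dir=folder)
```

In Colab this would be used as `files.download(zip_folder('results', 'schema_results'))`.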

## 9. Summary Statistics

Display a comprehensive summary of the analysis.

In [None]:
print("=" * 70)
print("ANALYSIS SUMMARY")
print("=" * 70)

print("\nData Processing:")
print(f"  Initial trials: {initial_count}")
print(f"  After cleaning: {cleaned_count}")
print(f"  After filtering: {len(clean_trials)}")
print(f"  Valid pairs: {len(results)}")
print(f"  Paired trials: {len(results) * 2}")

print("\nOverall Statistics:")
print(f"  Mean D-value: {results['d'].mean():.4f}°")
print(f"  Std Dev: {results['d'].std():.4f}°")
print(f"  SEM: {results['d'].sem():.4f}°")
print(f"  Min D-value: {results['d'].min():.4f}°")
print(f"  Max D-value: {results['d'].max():.4f}°")

print("\nFace IDs Analyzed:")
def sig_label(p):
    """Map a p-value to the significance codes in the legend below."""
    if p < 0.001:
        return "***"
    if p < 0.01:
        return "**"
    if p < 0.05:
        return "*"
    return "ns"

for _, row in stats_df.iterrows():
    print(f"  {row['face_id']}: Mean={row['mean']:.4f}°, SEM={row['sem']:.4f}°, p={row['p_value']:.4f} {sig_label(row['p_value'])}")

print("\n" + "=" * 70)
print("Legend: *** p<0.001, ** p<0.01, * p<0.05, ns = not significant")
print("=" * 70)
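
The processing counts above can also be laid out as an attrition table that shows what share of the initial trials survives each stage. A minimal sketch; `attrition_rows` and `print_attrition` are hypothetical helpers and the stage labels are illustrative.

```python
def attrition_rows(stages: list[tuple[str, int]]) -> list[tuple[str, int, float]]:
    """For (label, count) pairs in pipeline order, add the percent of the first count kept."""
    if not stages:
        return []
    first = stages[0][1]
    # Guard against a zero starting count to avoid division by zero.
    return [(label, count, 100.0 * count / first if first else 0.0) for label, count in stages]


def print_attrition(stages: list[tuple[str, int]]) -> None:
    """Print one aligned line per stage: label, count, percent retained."""
    for label, count, pct in attrition_rows(stages):
        print(f"  {label:<18}{count:>8}  ({pct:5.1f}%)")
```

In this notebook the call would be `print_attrition([('Initial', initial_count), ('After cleaning', cleaned_count), ('After filtering', len(clean_trials)), ('Paired trials', len(results) * 2)])`.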

---

## Optional: Verification with Dummy Data

The cell below runs the repository's verification script on dummy data with known expected results, so you can confirm that the pipeline works in your environment.

In [None]:
# Download verification script from repository (-O overwrites any copy from a previous run)
!wget -q -O run_verification.py https://raw.githubusercontent.com/BackyardBrains/schema-analysis/main/verification/run_verification.py

# Run verification
!python run_verification.py

print("\n✓ Verification script finished. If all tests above passed, the analysis pipeline is working correctly.")

---

## Need Help?

- **Repository**: [BackyardBrains/schema-analysis](https://github.com/BackyardBrains/schema-analysis)
- **Documentation**: See the `README.md` and `verification/README.md` in the repository
- **Issues**: Report bugs or ask questions on the GitHub Issues page

---

## Citation

If you use this analysis tool in your research, please cite:

```
Schema Analysis Tool
https://github.com/BackyardBrains/schema-analysis
```