# Comprehensive Exploratory Data Analysis: Music Generation Datasets

## Overview

This notebook provides a comprehensive exploratory data analysis of the processed datasets used in the multi-agent reinforcement learning framework for symbolic music generation. The analysis examines three key datasets:

1. **ComMU Bass Dataset**: Curated bass tracks from the ComMU dataset with rich metadata
2. **Bass Loops Dataset**: Collection of bass loops primarily from dance and jazz genres
3. **Combined Dataset**: Merged dataset combining both sources

### Analysis Goals

This EDA focuses on understanding:
- Musical feature distributions and characteristics
- Genre and style patterns across datasets
- Harmonic structures through chord progression analysis
- Temporal and rhythmic characteristics
- Dataset composition and split distributions
- Feature relationships and correlations

### Context

This analysis supports a multi-agent RL system that combines:
- **Perceiving Agent (GHSOM)**: Clusters musical features and discovers structural motifs
- **Generative Agent (DQN with LSTM)**: Selects musical elements to build sequences
- **Human Agent**: Provides feedback for reward shaping and adaptation

## 1. Setup and Data Loading

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import ast
import warnings
from collections import Counter
import json

# Visualization settings
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 6)
plt.rcParams['font.size'] = 10
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['axes.titlesize'] = 14
plt.rcParams['legend.fontsize'] = 10

warnings.filterwarnings('ignore')

# Set random seed for reproducibility
np.random.seed(42)

print("Libraries imported successfully")

In [None]:
# Define data paths with fallbacks
candidate_dirs = [
    Path('/workspace/data/subset'),
    Path.cwd() / 'data' / 'subset',
    Path.cwd().parent / 'data' / 'subset',
    Path.cwd().parent.parent / 'data' / 'subset',
]
DATA_DIR = next((p for p in candidate_dirs if (p / 'combined_metadata_clean.csv').exists()), None)
if DATA_DIR is None:
    raise FileNotFoundError(
        "Could not locate combined_metadata_clean.csv in any expected data directory."
    )
DATA_DIR = DATA_DIR.resolve()
print(f"Using data directory: {DATA_DIR}")

# Load datasets
df_combined = pd.read_csv(DATA_DIR / 'combined_metadata_clean.csv')
df_commu = pd.read_csv(DATA_DIR / 'commu' / 'bass' / 'metadata_clean.csv')
df_bass_loops = pd.read_csv(DATA_DIR / 'LM_bass_loops_matched' / 'metadata_clean.csv')

print(f"Combined Dataset: {len(df_combined):,} samples")
print(f"ComMU Bass Dataset: {len(df_commu):,} samples")
print(f"Bass Loops Dataset: {len(df_bass_loops):,} samples")
print(f"\nUnique sample IDs (combined): {df_combined['id'].nunique():,}")

In [None]:
# Display basic information about datasets
print("=" * 80)
print("COMBINED DATASET INFO")
print("=" * 80)
df_combined.info()  # info() prints directly and returns None, so it is not wrapped in print()
print("\n" + "=" * 80)
print("SAMPLE DATA")
print("=" * 80)
df_combined.head()

## 2. Dataset Composition Analysis

Understanding the composition and structure of our datasets is crucial for:
- Ensuring balanced training/validation/test splits
- Understanding data source distributions
- Identifying potential biases in the dataset
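
A quick way to check the first two goals is to compare split proportions within each source and to confirm that no sample ID appears in more than one split. The sketch below uses a hypothetical helper (`split_report` is not part of the project code) and assumes the `split`, `source`, and `id` columns used elsewhere in this notebook.

```python
import pandas as pd


def split_report(df: pd.DataFrame, split_col: str = "split",
                 source_col: str = "source", id_col: str = "id") -> tuple[pd.DataFrame, set]:
    """Return per-source split proportions and the set of IDs found in more than one split."""
    # Row-normalised crosstab: each source's row sums to 1.0, so sources compare directly
    proportions = pd.crosstab(df[source_col], df[split_col], normalize="index")
    # An ID assigned to several splits would leak information between training and evaluation
    splits_per_id = df.groupby(id_col)[split_col].nunique()
    leaked_ids = set(splits_per_id[splits_per_id > 1].index)
    return proportions, leaked_ids
```

For example, `props, leaked = split_report(df_combined)` returns one row per source. Similar rows together with an empty `leaked` set suggest the split is stratified by source and free of ID leakage.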

In [None]:
# Analyze dataset sources
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Source distribution
if 'source' in df_combined.columns:
    source_counts = df_combined['source'].value_counts()
    axes[0].pie(source_counts.values, labels=source_counts.index, autopct='%1.1f%%', startangle=90)
    axes[0].set_title('Dataset Source Distribution', fontsize=14, fontweight='bold')
else:
    axes[0].text(0.5, 0.5, 'Source column not available', ha='center', va='center')
    axes[0].set_title('Dataset Source Distribution', fontsize=14, fontweight='bold')

# Split distribution (fixed train/val/test order so bar colors stay consistent across runs)
split_counts = df_combined['split'].value_counts().reindex(['train', 'val', 'test'], fill_value=0)
colors = ['#3498db', '#2ecc71', '#e74c3c']
axes[1].bar(split_counts.index, split_counts.values, color=colors)
axes[1].set_title('Train/Validation/Test Split', fontsize=14, fontweight='bold')
axes[1].set_ylabel('Number of Samples')
axes[1].set_xlabel('Split')
for i, v in enumerate(split_counts.values):
    # Offset labels relative to the tallest bar instead of a fixed 200 samples
    axes[1].text(i, v + split_counts.max() * 0.01, f'{v:,}\n({v/len(df_combined)*100:.1f}%)', 
                ha='center', va='bottom', fontweight='bold')

# Split distribution by source
if 'source' in df_combined.columns:
    split_source = pd.crosstab(df_combined['split'], df_combined['source'])
    split_source.plot(kind='bar', stacked=True, ax=axes[2], color=['#9b59b6', '#f39c12'])
    axes[2].set_title('Split Distribution by Source', fontsize=14, fontweight='bold')
    axes[2].set_ylabel('Number of Samples')
    axes[2].set_xlabel('Split')
    axes[2].legend(title='Source')
    axes[2].tick_params(axis='x', rotation=0)
else:
    axes[2].text(0.5, 0.5, 'Source column not available', ha='center', va='center')
    axes[2].set_title('Split Distribution by Source', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

# Print detailed statistics
print("\nDetailed Split Statistics:")
print("=" * 50)
for split in ['train', 'val', 'test']:
    count = len(df_combined[df_combined['split'] == split])
    pct = count / len(df_combined) * 100
    print(f"{split.capitalize():10s}: {count:6,} samples ({pct:5.2f}%)")

## 3. Genre and Style Analysis

Genre distribution is critical for understanding:
- The stylistic diversity of the training data
- Genre imbalances the model may inherit from the training data
- Coverage of different musical styles for generalization
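
Because a single `genre` cell can hold several `|`-separated tags, counting and filtering should operate on individual tags. Below is a minimal sketch using pandas `str.split` and `explode`, assuming `|` is the only separator; `genre_counts_exploded` and `has_genre` are hypothetical helpers. Unlike a substring search with `str.contains`, `has_genre` matches whole tags, so a query for `rock` does not pick up a tag such as `rockabilly`.

```python
import pandas as pd


def genre_counts_exploded(genres: pd.Series, sep: str = "|") -> pd.Series:
    """Count individual genre tags when one cell may hold several tags joined by `sep`."""
    return (
        genres.dropna()
        .astype(str)
        .str.split(sep)               # "dance|house" -> ["dance", "house"] (single-char pattern is literal)
        .explode()                    # one row per tag
        .str.strip()
        .loc[lambda s: s != ""]       # drop empty tags left by stray separators
        .value_counts()
    )


def has_genre(genres: pd.Series, tag: str, sep: str = "|") -> pd.Series:
    """Boolean mask: True where `tag` is one of a row's tags (case-insensitive, whole-tag match)."""
    target = tag.lower()
    return genres.fillna("").astype(str).str.split(sep).apply(
        lambda tags: target in {t.strip().lower() for t in tags}
    )
```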

In [None]:
# Genre analysis
def parse_genres(df):
    """Parse genre field which may contain multiple genres separated by |."""
    all_genres = []
    for genre in df['genre'].dropna():
        # Split multi-genre entries on '|' and strip whitespace around each tag
        all_genres.extend(tag.strip() for tag in str(genre).split('|') if tag.strip())
    return pd.Series(all_genres).value_counts()

genre_counts = parse_genres(df_combined)

fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# Genre distribution bar chart
top_genres = genre_counts.head(15)
axes[0].barh(range(len(top_genres)), top_genres.values, color=sns.color_palette('viridis', len(top_genres)))
axes[0].set_yticks(range(len(top_genres)))
axes[0].set_yticklabels(top_genres.index)
axes[0].set_xlabel('Number of Tracks')
axes[0].set_title('Top 15 Genres in Dataset', fontsize=14, fontweight='bold')
axes[0].invert_yaxis()
for i, v in enumerate(top_genres.values):
    axes[0].text(v + 50, i, f'{v:,}', va='center')

# Genre distribution by source
if 'source' in df_combined.columns:
    genre_source_data = []
    for source in df_combined['source'].unique():
        source_df = df_combined[df_combined['source'] == source]
        source_genres = parse_genres(source_df)
        genre_source_data.append(source_genres.head(10))
    
    # Combine for comparison
    genre_comparison = pd.DataFrame(genre_source_data).T.fillna(0)
    genre_comparison.columns = df_combined['source'].unique()
    genre_comparison.plot(kind='bar', ax=axes[1], color=['#9b59b6', '#f39c12'])
    axes[1].set_title('Top Genres by Source', fontsize=14, fontweight='bold')
    axes[1].set_ylabel('Number of Tracks')
    axes[1].set_xlabel('Genre')
    axes[1].legend(title='Source')
    axes[1].tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

print(f"\nTotal unique genres: {len(genre_counts)}")
print(f"\nTop 10 Genres:")
print("=" * 50)
for genre, count in genre_counts.head(10).items():
    print(f"{genre:20s}: {count:6,} ({count/len(df_combined)*100:5.2f}%)")

## 4. Musical Feature Analysis

### 4.1 Tempo (BPM) Distribution

Tempo is a fundamental characteristic affecting:
- Energy and feel of generated music
- Rhythmic complexity
- Genre characteristics
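
To compare tempo across genres or splits, it can help to bucket BPM into coarse bands. The band edges below are illustrative assumptions, not a standard taxonomy, and `tempo_band` is a hypothetical helper. `pd.cut` returns an ordered categorical, so the bands sort and plot in tempo order.

```python
import pandas as pd

# Hypothetical tempo bands; each interval is (lower, upper], so 90 BPM falls in "slow"
TEMPO_BINS = [0, 90, 120, 140, float("inf")]
TEMPO_LABELS = ["slow", "moderate", "upbeat", "fast"]


def tempo_band(bpm: pd.Series) -> pd.Series:
    """Bucket BPM values into ordered tempo bands; missing BPM stays missing."""
    return pd.cut(bpm, bins=TEMPO_BINS, labels=TEMPO_LABELS, right=True)
```

For example, `pd.crosstab(tempo_band(df_combined['bpm']), df_combined['split'], normalize='columns')` shows whether each split has a similar tempo mix.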

In [None]:
# BPM analysis
fig, axes = plt.subplots(2, 2, figsize=(18, 12))

# Overall BPM distribution
axes[0, 0].hist(df_combined['bpm'].dropna(), bins=50, color='#3498db', edgecolor='black', alpha=0.7)
axes[0, 0].axvline(df_combined['bpm'].median(), color='red', linestyle='--', linewidth=2, label=f'Median: {df_combined["bpm"].median():.1f}')
axes[0, 0].axvline(df_combined['bpm'].mean(), color='orange', linestyle='--', linewidth=2, label=f'Mean: {df_combined["bpm"].mean():.1f}')
axes[0, 0].set_xlabel('BPM (Beats Per Minute)')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].set_title('BPM Distribution', fontsize=14, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# BPM by source
if 'source' in df_combined.columns:
    sources = df_combined['source'].unique()
    bpm_by_source = [df_combined[df_combined['source'] == src]['bpm'].dropna() for src in sources]
    axes[0, 1].violinplot(bpm_by_source, positions=range(len(sources)), showmeans=True, showmedians=True)
    axes[0, 1].set_xticks(range(len(sources)))
    axes[0, 1].set_xticklabels(sources)
    axes[0, 1].set_ylabel('BPM')
    axes[0, 1].set_title('BPM Distribution by Source', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)

# BPM by genre (top genres)
top_genre_list = genre_counts.head(8).index.tolist()
bpm_genre_data = []
genre_labels = []
for genre in top_genre_list:
    # regex=False: treat the genre name as plain text, not a regular expression
    mask = df_combined['genre'].str.contains(genre, na=False, case=False, regex=False)
    genre_bpm = df_combined[mask]['bpm'].dropna()
    if len(genre_bpm) > 0:
        bpm_genre_data.append(genre_bpm)
        genre_labels.append(genre)

bp = axes[1, 0].boxplot(bpm_genre_data, labels=genre_labels, patch_artist=True)
for patch, color in zip(bp['boxes'], sns.color_palette('Set2', len(genre_labels))):
    patch.set_facecolor(color)
axes[1, 0].set_xlabel('Genre')
axes[1, 0].set_ylabel('BPM')
axes[1, 0].set_title('BPM Distribution by Genre', fontsize=14, fontweight='bold')
axes[1, 0].tick_params(axis='x', rotation=45)
axes[1, 0].grid(True, alpha=0.3)

# BPM statistics by split
split_stats = df_combined.groupby('split')['bpm'].agg(['mean', 'median', 'std', 'min', 'max'])
x = np.arange(len(split_stats.index))
width = 0.35
axes[1, 1].bar(x - width/2, split_stats['mean'], width, label='Mean', color='#3498db')
axes[1, 1].bar(x + width/2, split_stats['median'], width, label='Median', color='#2ecc71')
axes[1, 1].set_xticks(x)
axes[1, 1].set_xticklabels(split_stats.index)
axes[1, 1].set_ylabel('BPM')
axes[1, 1].set_title('BPM Statistics by Split', fontsize=14, fontweight='bold')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print statistics
print("\nBPM Statistics:")
print("=" * 50)
print(df_combined['bpm'].describe())
print(f"\nBPM Range: {df_combined['bpm'].min():.1f} - {df_combined['bpm'].max():.1f}")

### 4.2 Pitch Range Analysis

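Pitch range describes the register a track occupies. In this metadata it is stored either as a category (`low` through `high`) or, where available, as a numeric span in semitones. It indicates:
- Melodic span and register
- Whether generated lines fit the instrument's playable range
- Room for melodic variation and expressiveness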
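
Since the categorical labels are ordinal, storing them as an ordered categorical keeps plots and summaries in register order (low to high) instead of alphabetical order. The sketch below assumes the five labels in `pitch_range_mapping` (defined in the next code cell) are the complete set; any other value becomes `NaN`. `pitch_range_ordered` and `ordered_counts` are hypothetical helpers.

```python
import pandas as pd

# Assumed category order, taken from the notebook's pitch_range_mapping (low -> high)
PITCH_ORDER = ["low", "mid_low", "mid", "mid_high", "high"]
PITCH_DTYPE = pd.CategoricalDtype(categories=PITCH_ORDER, ordered=True)


def pitch_range_ordered(pitch: pd.Series) -> pd.Series:
    """Cast pitch_range labels to an ordered categorical; unknown labels become NaN."""
    return pitch.astype(PITCH_DTYPE)


def ordered_counts(pitch: pd.Series) -> pd.Series:
    """Counts in register order, including categories with zero tracks."""
    return pitch_range_ordered(pitch).value_counts(sort=False)
```

Ordered categoricals also support comparisons such as `pitch_range_ordered(df_combined['pitch_range']) >= 'mid'`, which filters by register without a numeric encoding.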

In [None]:
# Convert pitch_range to numeric if it's categorical
pitch_range_mapping = {
    'low': 1,
    'mid_low': 2,
    'mid': 3,
    'mid_high': 4,
    'high': 5
}

df_combined['pitch_range_numeric'] = df_combined['pitch_range'].map(pitch_range_mapping)

# Also handle numeric pitch ranges if they exist
df_combined['pitch_range_value'] = pd.to_numeric(df_combined['pitch_range'], errors='coerce')

fig, axes = plt.subplots(2, 2, figsize=(18, 12))

# Pitch range distribution (categorical)
if df_combined['pitch_range_numeric'].notna().any():
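    pitch_counts = (
        df_combined['pitch_range']
        .fillna('unknown')
        .astype(str)
        .value_counts()
        # Order bars low -> high; labels outside the mapping (e.g. 'unknown') sort last
        .sort_index(key=lambda idx: idx.map(lambda p: pitch_range_mapping.get(p, len(pitch_range_mapping) + 1)))
    )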
    axes[0, 0].bar(range(len(pitch_counts)), pitch_counts.values, 
                   color=sns.color_palette('coolwarm', len(pitch_counts)))
    axes[0, 0].set_xticks(range(len(pitch_counts)))
    axes[0, 0].set_xticklabels(pitch_counts.index, rotation=45)
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('Pitch Range Distribution (Categorical)', fontsize=14, fontweight='bold')
    for i, v in enumerate(pitch_counts.values):
        axes[0, 0].text(i, v + 50, f'{v:,}', ha='center', va='bottom')

# Pitch range distribution (numeric)
if df_combined['pitch_range_value'].notna().any():
    axes[0, 1].hist(df_combined['pitch_range_value'].dropna(), bins=30, 
                    color='#9b59b6', edgecolor='black', alpha=0.7)
    axes[0, 1].set_xlabel('Pitch Range (Semitones)')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Pitch Range Distribution (Numeric)', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)

# Pitch range by source
if 'source' in df_combined.columns and df_combined['pitch_range_numeric'].notna().any():
    pitch_source = pd.crosstab(df_combined['pitch_range'].fillna('unknown'), df_combined['source'])
    pitch_source.plot(kind='bar', ax=axes[1, 0], color=['#9b59b6', '#f39c12'])
    axes[1, 0].set_title('Pitch Range by Source', fontsize=14, fontweight='bold')
    axes[1, 0].set_ylabel('Number of Tracks')
    axes[1, 0].set_xlabel('Pitch Range')
    axes[1, 0].legend(title='Source')
    axes[1, 0].tick_params(axis='x', rotation=45)

# Pitch range vs BPM scatter
if df_combined['pitch_range_numeric'].notna().any():
    scatter_data = df_combined[df_combined['pitch_range_numeric'].notna()]
    axes[1, 1].scatter(scatter_data['bpm'], scatter_data['pitch_range_numeric'], 
                       alpha=0.5, c=scatter_data['pitch_range_numeric'], 
                       cmap='coolwarm', s=20)
    axes[1, 1].set_xlabel('BPM')
    axes[1, 1].set_ylabel('Pitch Range (Encoded)')
    axes[1, 1].set_title('Pitch Range vs BPM', fontsize=14, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Print statistics
print("\nPitch Range Distribution:")
print("=" * 50)
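pitch_distribution = (
    df_combined['pitch_range']
    .fillna('unknown')
    .astype(str)
    .value_counts()
    # Same low -> high ordering as the bar chart above
    .sort_index(key=lambda idx: idx.map(lambda p: pitch_range_mapping.get(p, len(pitch_range_mapping) + 1)))
)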
print(pitch_distribution)
if df_combined['pitch_range_value'].notna().any():
    print("\nNumeric Pitch Range Statistics:")
    print(df_combined['pitch_range_value'].describe())

### 4.3 Time Signature Analysis

Time signatures define:
- Rhythmic framework
- Beat patterns
- Metrical complexity (e.g. simple vs. compound or odd meters)
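
Time signatures are stored as strings, so splitting them into numerator and denominator makes meter usable as a numeric feature. The sketch assumes the `'N/D'` string format (e.g. `'4/4'`); `parse_time_signature` is a hypothetical helper. The `quarters_per_measure` column expresses measure length in quarter notes, so 4/4 gives 4.0 and 6/8 gives 3.0.

```python
import pandas as pd


def parse_time_signature(ts: pd.Series) -> pd.DataFrame:
    """Split strings such as '4/4' or '6/8' into numeric numerator/denominator columns."""
    # Values that do not look like 'N/D' (including missing values) yield NaN
    parts = ts.astype(str).str.extract(r"^\s*(\d+)\s*/\s*(\d+)\s*$")
    parts.columns = ["numerator", "denominator"]
    parts = parts.apply(pd.to_numeric)
    # Measure length in quarter notes: numerator beats of (4 / denominator) quarters each
    parts["quarters_per_measure"] = parts["numerator"] * 4 / parts["denominator"]
    return parts
```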

In [None]:
# Time signature analysis
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Time signature distribution
time_sig_counts = df_combined['time_signature'].value_counts()
axes[0].pie(time_sig_counts.values, labels=time_sig_counts.index, autopct='%1.1f%%', 
            startangle=90, colors=sns.color_palette('Set3', len(time_sig_counts)))
axes[0].set_title('Time Signature Distribution', fontsize=14, fontweight='bold')

# Time signature by genre
if len(genre_labels) > 0:
    time_sig_genre_data = []
    for genre in genre_labels[:5]:  # Top 5 genres
        mask = df_combined['genre'].str.contains(genre, na=False, case=False, regex=False)
        genre_time_sigs = df_combined[mask]['time_signature'].value_counts()
        time_sig_genre_data.append(genre_time_sigs)
    
    time_sig_df = pd.DataFrame(time_sig_genre_data, index=genre_labels[:5]).fillna(0).T
    time_sig_df.plot(kind='bar', ax=axes[1], stacked=True)
    axes[1].set_title('Time Signature by Top Genres', fontsize=14, fontweight='bold')
    axes[1].set_ylabel('Number of Tracks')
    axes[1].set_xlabel('Time Signature')
    axes[1].legend(title='Genre', bbox_to_anchor=(1.05, 1), loc='upper left')
    axes[1].tick_params(axis='x', rotation=0)

# Time signature vs BPM
time_sig_bpm = df_combined.groupby('time_signature')['bpm'].agg(['mean', 'std'])
x = np.arange(len(time_sig_bpm.index))
axes[2].bar(x, time_sig_bpm['mean'], yerr=time_sig_bpm['std'], 
            capsize=5, color='#3498db', alpha=0.7)
axes[2].set_xticks(x)
axes[2].set_xticklabels(time_sig_bpm.index)
axes[2].set_ylabel('Average BPM')
axes[2].set_xlabel('Time Signature')
axes[2].set_title('Average BPM by Time Signature', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nTime Signature Distribution:")
print("=" * 50)
for ts, count in time_sig_counts.items():
    print(f"{ts:10s}: {count:6,} ({count/len(df_combined)*100:5.2f}%)")

### 4.4 Number of Measures

The number of measures indicates:
- Sequence length
- Structural complexity
- Learning challenges for the RL agent
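
Measures, BPM, and meter together determine how long a clip plays: duration (s) = measures × quarter notes per measure × 60 / BPM, assuming BPM counts quarter-note beats and the tempo is constant. For example, 8 measures of 4/4 at 120 BPM last 8 × 4 × 60 / 120 = 16 s. The sketch also counts steps on a 16th-note grid as a rough sequence-length proxy; that grid is an assumed tokenization, not necessarily the one the generative agent uses. Both functions are hypothetical helpers.

```python
import pandas as pd


def clip_duration_seconds(num_measures: pd.Series, bpm: pd.Series,
                          quarters_per_measure=4.0) -> pd.Series:
    """Approximate clip length in seconds: measures * quarters per measure * 60 / BPM.

    Assumes BPM counts quarter-note beats and the tempo does not change within the clip.
    `quarters_per_measure` may be a scalar (default: 4/4) or a per-row Series.
    """
    safe_bpm = bpm.where(bpm > 0)  # zero or negative tempo -> NaN instead of inf
    return num_measures * quarters_per_measure * 60.0 / safe_bpm


def sixteenth_steps(num_measures: pd.Series, quarters_per_measure=4.0) -> pd.Series:
    """Number of 16th-note grid steps (4 per quarter note) in each clip."""
    return num_measures * quarters_per_measure * 4
```

For example, `clip_duration_seconds(df_combined['num_measures'], df_combined['bpm']).describe()` summarizes clip lengths under a 4/4 assumption.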

In [None]:
# Number of measures analysis
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Measures distribution
axes[0].hist(df_combined['num_measures'].dropna(), bins=20, color='#2ecc71', 
             edgecolor='black', alpha=0.7)
axes[0].axvline(df_combined['num_measures'].median(), color='red', 
                linestyle='--', linewidth=2, label=f'Median: {df_combined["num_measures"].median():.0f}')
axes[0].axvline(df_combined['num_measures'].mean(), color='orange', 
                linestyle='--', linewidth=2, label=f'Mean: {df_combined["num_measures"].mean():.1f}')
axes[0].set_xlabel('Number of Measures')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Number of Measures Distribution', fontsize=14, fontweight='bold')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Measures by source
if 'source' in df_combined.columns:
    sources = df_combined['source'].unique()
    measures_by_source = [df_combined[df_combined['source'] == src]['num_measures'].dropna() 
                          for src in sources]
    bp = axes[1].boxplot(measures_by_source, labels=sources, patch_artist=True)
    for patch, color in zip(bp['boxes'], sns.color_palette('Set2', len(sources))):
        patch.set_facecolor(color)
    axes[1].set_ylabel('Number of Measures')
    axes[1].set_title('Measures Distribution by Source', fontsize=14, fontweight='bold')
    axes[1].grid(True, alpha=0.3)

# Measures vs BPM
axes[2].scatter(df_combined['num_measures'], df_combined['bpm'], 
                alpha=0.3, s=20, c='#e74c3c')
axes[2].set_xlabel('Number of Measures')
axes[2].set_ylabel('BPM')
axes[2].set_title('Number of Measures vs BPM', fontsize=14, fontweight='bold')
axes[2].grid(True, alpha=0.3)

# Add correlation coefficient
corr = df_combined[['num_measures', 'bpm']].corr().iloc[0, 1]
axes[2].text(0.05, 0.95, f'Correlation: {corr:.3f}', 
             transform=axes[2].transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

print("\nNumber of Measures Statistics:")
print("=" * 50)
print(df_combined['num_measures'].describe())

### 4.5 Velocity Analysis

Velocity (MIDI dynamics) reflects:
- Dynamic range
- Expressive potential
- Performance characteristics
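
MIDI velocity is a 7-bit value (0 to 127), so the velocity metadata can be sanity-checked before plotting. A minimal sketch; `velocity_issues` is a hypothetical helper, and the three flags are suggested checks rather than rules from the dataset documentation.

```python
import pandas as pd

MIDI_MIN, MIDI_MAX = 0, 127  # MIDI velocity is a 7-bit value


def velocity_issues(df: pd.DataFrame, lo: str = "min_velocity",
                    hi: str = "max_velocity") -> pd.DataFrame:
    """Flag rows whose velocity metadata is out of MIDI range or internally inconsistent."""
    v = df[[lo, hi]]
    out_of_range = ((v < MIDI_MIN) | (v > MIDI_MAX)).any(axis=1)
    inverted = df[lo] > df[hi]   # min above max suggests swapped or corrupt columns
    flat = df[lo] == df[hi]      # zero dynamic range: every note at the same velocity
    # Comparisons against NaN evaluate to False, so missing values are not flagged here
    return pd.DataFrame({"out_of_range": out_of_range, "inverted": inverted, "flat": flat})
```

For example, `velocity_issues(df_velocity).sum()` counts flagged rows per check.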

In [None]:
# Velocity analysis
# Filter out rows where velocity data is available
df_velocity = df_combined[(df_combined['min_velocity'].notna()) & 
                          (df_combined['max_velocity'].notna())].copy()

if len(df_velocity) > 0:
    df_velocity['velocity_range'] = df_velocity['max_velocity'] - df_velocity['min_velocity']
    
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    
    # Min velocity distribution
    axes[0, 0].hist(df_velocity['min_velocity'], bins=30, color='#3498db', 
                    edgecolor='black', alpha=0.7)
    axes[0, 0].set_xlabel('Minimum Velocity')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].set_title('Minimum Velocity Distribution', fontsize=14, fontweight='bold')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Max velocity distribution
    axes[0, 1].hist(df_velocity['max_velocity'], bins=30, color='#e74c3c', 
                    edgecolor='black', alpha=0.7)
    axes[0, 1].set_xlabel('Maximum Velocity')
    axes[0, 1].set_ylabel('Frequency')
    axes[0, 1].set_title('Maximum Velocity Distribution', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Velocity range distribution
    axes[1, 0].hist(df_velocity['velocity_range'], bins=30, color='#2ecc71', 
                    edgecolor='black', alpha=0.7)
    axes[1, 0].set_xlabel('Velocity Range (Max - Min)')
    axes[1, 0].set_ylabel('Frequency')
    axes[1, 0].set_title('Velocity Range Distribution', fontsize=14, fontweight='bold')
    axes[1, 0].grid(True, alpha=0.3)
    
    # Min vs Max velocity scatter
    axes[1, 1].scatter(df_velocity['min_velocity'], df_velocity['max_velocity'], 
                       alpha=0.5, s=20, c=df_velocity['velocity_range'], 
                       cmap='viridis')
    axes[1, 1].plot([0, 127], [0, 127], 'r--', alpha=0.5, label='y=x')
    axes[1, 1].set_xlabel('Minimum Velocity')
    axes[1, 1].set_ylabel('Maximum Velocity')
    axes[1, 1].set_title('Min vs Max Velocity', fontsize=14, fontweight='bold')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    cbar = plt.colorbar(axes[1, 1].collections[0], ax=axes[1, 1])
    cbar.set_label('Velocity Range')
    
    plt.tight_layout()
    plt.show()
    
    print("\nVelocity Statistics:")
    print("=" * 50)
    print("\nMinimum Velocity:")
    print(df_velocity['min_velocity'].describe())
    print("\nMaximum Velocity:")
    print(df_velocity['max_velocity'].describe())
    print("\nVelocity Range:")
    print(df_velocity['velocity_range'].describe())
else:
    print("\nNo velocity data available in the dataset.")

## 5. Harmonic Analysis: Chord Progressions

Chord progressions are crucial for:
- Understanding harmonic patterns
- Identifying common musical structures
- Training the perceiving agent to recognize motifs
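
Chord frequencies alone ignore order. Counting chord-to-chord transitions (bigrams) is one simple way to surface common progressions. The sketch assumes `chord_progressions` holds stringified, possibly nested lists of chord symbols (the format parsed in the cell below). It also assumes consecutive repeats mark a single chord held over several time slots, so it collapses them. `chord_transitions` is a hypothetical helper.

```python
import ast
from collections import Counter
from itertools import groupby


def flatten(items):
    """Recursively flatten arbitrarily nested lists of chord symbols."""
    for item in items:
        if isinstance(item, list):
            yield from flatten(item)
        else:
            yield item


def chord_transitions(progressions) -> Counter:
    """Count (previous_chord, next_chord) pairs across stringified progressions."""
    counts = Counter()
    for prog in progressions:
        if not isinstance(prog, str):
            continue
        try:
            parsed = ast.literal_eval(prog)
        except (ValueError, SyntaxError):
            continue  # skip malformed entries, as the notebook's parser does
        if not isinstance(parsed, list):
            continue
        # Collapse runs of the same chord so a held chord counts as one event
        chords = [chord for chord, _ in groupby(flatten(parsed))]
        counts.update(zip(chords, chords[1:]))
    return counts
```

For example, `chord_transitions(df_combined['chord_progressions'].dropna()).most_common(10)` lists the ten most frequent chord changes.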

In [None]:
# Chord progression analysis
def parse_chord_progressions(df, sample_size=None):
    """Extract individual chords from chord progression field."""
    all_chords = []
    
    # Sample if dataset is large
    df_sample = df.sample(n=min(sample_size or len(df), len(df)), random_state=42)
    
    for prog in df_sample['chord_progressions'].dropna():
        try:
            # Parse the string representation of list
            if isinstance(prog, str) and prog != '[]':
                chord_list = ast.literal_eval(prog)
                if isinstance(chord_list, list) and len(chord_list) > 0:
                    # Flatten nested lists
                    for item in chord_list:
                        if isinstance(item, list):
                            all_chords.extend(item)
                        else:
                            all_chords.append(item)
        except (ValueError, SyntaxError):
            continue
    
    return pd.Series(all_chords).value_counts()

# Get chord statistics
chord_counts = parse_chord_progressions(df_combined, sample_size=10000)

if len(chord_counts) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    
    # Top chords
    top_chords = chord_counts.head(20)
    axes[0, 0].barh(range(len(top_chords)), top_chords.values, 
                    color=sns.color_palette('viridis', len(top_chords)))
    axes[0, 0].set_yticks(range(len(top_chords)))
    axes[0, 0].set_yticklabels(top_chords.index)
    axes[0, 0].set_xlabel('Frequency')
    axes[0, 0].set_title('Top 20 Most Common Chords', fontsize=14, fontweight='bold')
    axes[0, 0].invert_yaxis()
    for i, v in enumerate(top_chords.values):
        axes[0, 0].text(v + 10, i, f'{v:,}', va='center')
    
    # Chord frequency distribution
    axes[0, 1].hist(chord_counts.values, bins=50, color='#9b59b6', 
                    edgecolor='black', alpha=0.7, log=True)
    axes[0, 1].set_xlabel('Chord Frequency')
    axes[0, 1].set_ylabel('Number of Chords (log scale)')
    axes[0, 1].set_title('Chord Frequency Distribution', fontsize=14, fontweight='bold')
    axes[0, 1].grid(True, alpha=0.3)
    
    # Chord quality analysis, weighted by how often each chord occurs.
    # Checks run in a fixed order so every chord falls into exactly one category
    # (e.g. 'Cdim' is not also counted as minor, and 'Cmaj7' is not counted as minor).
    def chord_quality(chord):
        c = str(chord)
        lowered = c.lower()
        if 'dim' in lowered:
            return 'Diminished'
        if 'aug' in lowered or '+' in c:
            return 'Augmented'
        if 'sus' in lowered:
            return 'Suspended'
        if '7' in c:
            return 'Seventh'
        rest = c[1:].lstrip('#b')  # text after the root letter and accidental
        if rest.startswith('m') and not rest.startswith('maj'):
            return 'Minor'
        return 'Major'
    
    chord_types = chord_counts.groupby(chord_counts.index.map(chord_quality)).sum()
    
    axes[1, 0].pie(chord_types.values, labels=chord_types.index, autopct='%1.1f%%', 
                   startangle=90, colors=sns.color_palette('Set2', len(chord_types)))
    axes[1, 0].set_title('Chord Quality Distribution (by occurrence)', fontsize=14, fontweight='bold')
    
    # Root note distribution, weighted by occurrence count.
    # The root is the leading letter A-G plus an optional accidental ('#' or 'b').
    roots = chord_counts.index.to_series().astype(str).str.extract(r'^([A-G][#b]?)', expand=False)
    root_counts = (
        chord_counts.groupby(roots.values).sum()
        .sort_values(ascending=False)
        .head(12)
    )
    axes[1, 1].bar(range(len(root_counts)), root_counts.values, 
                   color=sns.color_palette('husl', len(root_counts)))
    axes[1, 1].set_xticks(range(len(root_counts)))
    axes[1, 1].set_xticklabels(root_counts.index, rotation=45)
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_xlabel('Root Note')
    axes[1, 1].set_title('Root Note Distribution', fontsize=14, fontweight='bold')
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    print("\nChord Progression Statistics:")
    print("=" * 50)
    print(f"Total unique chords: {len(chord_counts):,}")
    print(f"Total chord occurrences: {chord_counts.sum():,}")
    print(f"\nTop 10 Most Common Chords:")
    for i, (chord, count) in enumerate(chord_counts.head(10).items(), 1):
        print(f"{i:2d}. {chord:15s}: {count:6,}")
else:
    print("\nNo chord progression data available or could not parse chord progressions.")

## 6. Instrument and Audio Key Analysis

In [None]:
# Instrument and key analysis
fig, axes = plt.subplots(2, 2, figsize=(18, 12))

# Instrument distribution
inst_counts = df_combined['inst'].value_counts().head(15)
axes[0, 0].barh(range(len(inst_counts)), inst_counts.values, 
                color=sns.color_palette('tab20', len(inst_counts)))
axes[0, 0].set_yticks(range(len(inst_counts)))
axes[0, 0].set_yticklabels(inst_counts.index)
axes[0, 0].set_xlabel('Frequency')
axes[0, 0].set_title('Top 15 Instruments', fontsize=14, fontweight='bold')
axes[0, 0].invert_yaxis()
for i, v in enumerate(inst_counts.values):
    axes[0, 0].text(v + 50, i, f'{v:,}', va='center')

# Audio key distribution
key_counts = df_combined['audio_key'].value_counts().head(15)
axes[0, 1].bar(range(len(key_counts)), key_counts.values, 
               color=sns.color_palette('rainbow', len(key_counts)))
axes[0, 1].set_xticks(range(len(key_counts)))
axes[0, 1].set_xticklabels(key_counts.index, rotation=45, ha='right')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Top 15 Audio Keys', fontsize=14, fontweight='bold')
for i, v in enumerate(key_counts.values):
    axes[0, 1].text(i, v + 50, f'{v:,}', ha='center', va='bottom', rotation=0)

# Track role distribution
role_counts = df_combined['track_role'].value_counts()
axes[1, 0].pie(role_counts.values, labels=role_counts.index, autopct='%1.1f%%', 
               startangle=90, colors=sns.color_palette('pastel', len(role_counts)))
axes[1, 0].set_title('Track Role Distribution', fontsize=14, fontweight='bold')

# Sample rhythm distribution
rhythm_counts = df_combined['sample_rhythm'].value_counts().head(10)
axes[1, 1].bar(range(len(rhythm_counts)), rhythm_counts.values, 
               color=sns.color_palette('muted', len(rhythm_counts)))
axes[1, 1].set_xticks(range(len(rhythm_counts)))
axes[1, 1].set_xticklabels(rhythm_counts.index, rotation=45, ha='right')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].set_title('Top 10 Sample Rhythms', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("\nInstrument Statistics:")
print("=" * 50)
print(f"Total unique instruments: {df_combined['inst'].nunique()}")
print(f"\nTop 10 Instruments:")
for inst, count in inst_counts.head(10).items():
    print(f"{inst:30s}: {count:6,} ({count/len(df_combined)*100:5.2f}%)")

print("\n" + "=" * 50)
print("Audio Key Statistics:")
print("=" * 50)
print(f"Total unique keys: {df_combined['audio_key'].nunique()}")

## 7. Feature Correlation Analysis

Understanding feature correlations helps:
- Identify redundant features
- Understand feature relationships
- Guide feature engineering for the RL agent
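
Pearson correlation measures linear relationships, while `pitch_range_numeric` is an ordinal encoding and `num_measures` takes few distinct values. Comparing Pearson with rank-based Spearman correlation (both available through `DataFrame.corr`) shows whether a relationship is monotonic but non-linear. `correlation_comparison` is a hypothetical helper.

```python
import pandas as pd


def correlation_comparison(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Pearson vs. Spearman correlation for each feature pair (pairwise-complete rows)."""
    pearson = df[cols].corr(method="pearson")
    spearman = df[cols].corr(method="spearman")
    rows = []
    for i, a in enumerate(cols):
        for b in cols[i + 1:]:
            rows.append({
                "feature_a": a,
                "feature_b": b,
                "pearson": pearson.loc[a, b],
                "spearman": spearman.loc[a, b],
            })
    return pd.DataFrame(rows)
```

For example, `correlation_comparison(df_combined, numeric_features)` (with the list built in the next cell) lists each pair once. A large gap between the two columns points to a non-linear or outlier-driven relationship.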

In [None]:
# Feature correlation analysis
# Select numeric features for correlation
numeric_features = ['bpm', 'num_measures']

# Add pitch_range_numeric if available
if 'pitch_range_numeric' in df_combined.columns and df_combined['pitch_range_numeric'].notna().any():
    numeric_features.append('pitch_range_numeric')

# Add velocity features if available
if 'min_velocity' in df_combined.columns and df_combined['min_velocity'].notna().any():
    numeric_features.extend(['min_velocity', 'max_velocity'])

# Compute correlation matrix
corr_matrix = df_combined[numeric_features].corr()

fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# Correlation heatmap
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, 
            square=True, linewidths=1, ax=axes[0], 
            cbar_kws={"shrink": 0.8}, fmt='.3f')
axes[0].set_title('Feature Correlation Matrix', fontsize=14, fontweight='bold')

# Rows with every numeric feature present (used for the pairplot below);
# the right-hand panel is left as a pointer to that separate figure
clean_numeric = df_combined[numeric_features].dropna()
axes[1].axis('off')
axes[1].text(0.5, 0.5, 'See detailed pairplot below', 
             ha='center', va='center', fontsize=12, style='italic')

plt.tight_layout()
plt.show()

# Create pairplot for detailed view
if len(clean_numeric) > 0:
    print("\nGenerating detailed pairplot (this may take a moment)...\n")
    pairplot_sample = clean_numeric.sample(
        n=min(500, len(clean_numeric)), random_state=42
    )
    sns.pairplot(pairplot_sample, diag_kind='kde', plot_kws={'alpha': 0.6})
    plt.suptitle('Feature Pairplot (Sample)', y=1.01, fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("\nInsufficient numeric data for pairplot.")

print("\nFeature Correlation Matrix:")
print("=" * 50)
print(corr_matrix)

## 8. Data Quality Assessment

In [None]:
# Data quality assessment
print("Data Quality Report")
print("=" * 80)

# Missing values analysis
missing_values = df_combined.isnull().sum()
missing_pct = (missing_values / len(df_combined)) * 100
missing_df = pd.DataFrame({
    'Missing Count': missing_values,
    'Percentage': missing_pct
}).sort_values('Missing Count', ascending=False)

print("\nMissing Values:")
print("-" * 80)
print(missing_df[missing_df['Missing Count'] > 0])

# Visualize missing values
fig, axes = plt.subplots(1, 2, figsize=(18, 6))

# Missing values bar chart
missing_features = missing_df[missing_df['Missing Count'] > 0]
if len(missing_features) > 0:
    axes[0].barh(range(len(missing_features)), missing_features['Percentage'].values, 
                 color='#e74c3c')
    axes[0].set_yticks(range(len(missing_features)))
    axes[0].set_yticklabels(missing_features.index)
    axes[0].set_xlabel('Percentage Missing (%)')
    axes[0].set_title('Missing Values by Feature', fontsize=14, fontweight='bold')
    axes[0].invert_yaxis()
    for i, v in enumerate(missing_features['Percentage'].values):
        axes[0].text(v + 0.5, i, f'{v:.1f}%', va='center')
else:
    axes[0].text(0.5, 0.5, 'No missing values found!', 
                 ha='center', va='center', fontsize=14, style='italic')
    axes[0].set_title('Missing Values by Feature', fontsize=14, fontweight='bold')

# Data completeness by split
completeness_by_split = []
for split in ['train', 'val', 'test']:
    split_df = df_combined[df_combined['split'] == split]
    completeness = (1 - split_df.isnull().sum() / len(split_df)) * 100
    completeness_by_split.append(completeness.mean())

axes[1].bar(['Train', 'Val', 'Test'], completeness_by_split, 
            color=['#3498db', '#2ecc71', '#e74c3c'])
axes[1].set_ylabel('Completeness (%)')
axes[1].set_title('Data Completeness by Split', fontsize=14, fontweight='bold')
axes[1].set_ylim([0, 105])
for i, v in enumerate(completeness_by_split):
    axes[1].text(i, v + 1, f'{v:.1f}%', ha='center', va='bottom', fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Duplicate analysis
print("\n" + "=" * 80)
print("Duplicate Analysis:")
print("-" * 80)
duplicates = df_combined.duplicated(subset=['id']).sum()
print(f"Duplicate IDs: {duplicates}")
print(f"Unique samples: {df_combined['id'].nunique():,}")

# Data type consistency
print("\n" + "=" * 80)
print("Data Types:")
print("-" * 80)
print(df_combined.dtypes)
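The duplicate check above compares only `id`. Because the combined dataset merges two independent sources, the same clip could in principle appear under two different IDs. The following is a minimal sketch of a content-level check. The helper `find_content_duplicates` and the column list in the usage comment are illustrative assumptions, not part of the original pipeline.

```python
import pandas as pd

def find_content_duplicates(df: pd.DataFrame, columns: list) -> pd.DataFrame:
    """Return every row whose values in `columns` match at least one other row.

    Values are stringified first so list-valued cells (e.g. parsed chord
    progressions) can be compared and hashed.
    """
    key = df[columns].astype(str).apply('|'.join, axis=1)
    mask = key.duplicated(keep=False)  # keep=False flags every member of a duplicate group
    return df.loc[mask].assign(content_key=key[mask]).sort_values('content_key')

# Hypothetical column choice; adjust to the features that define "the same clip":
# dupes = find_content_duplicates(df_combined, ['bpm', 'num_measures', 'time_signature', 'chord_progressions'])
# print(f"Content-level duplicate rows: {len(dupes)}")
```

Matching on a handful of metadata columns can flag distinct clips that happen to share those values, so treat the output as candidates for manual inspection rather than confirmed duplicates.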

## 9. Summary Statistics and Key Findings

In [None]:
# Generate comprehensive summary
print("=" * 80)
print("COMPREHENSIVE DATASET SUMMARY")
print("=" * 80)

summary = {
    'Dataset Size': {
        'Total Samples': len(df_combined),
        'Training Samples': len(df_combined[df_combined['split'] == 'train']),
        'Validation Samples': len(df_combined[df_combined['split'] == 'val']),
        'Test Samples': len(df_combined[df_combined['split'] == 'test']),
    },
    'Musical Characteristics': {
        'BPM Range': f"{df_combined['bpm'].min():.1f} - {df_combined['bpm'].max():.1f}",
        'Average BPM': f"{df_combined['bpm'].mean():.1f}",
        'Unique Genres': df_combined['genre'].nunique(),
        'Unique Instruments': df_combined['inst'].nunique(),
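        # mode() drops NaN, so guard on non-null values rather than on row count
        'Most Common Time Signature': (df_combined['time_signature'].mode().iloc[0]
                                       if df_combined['time_signature'].notna().any() else 'N/A'),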
    },
    'Complexity Metrics': {
        'Average Measures': f"{df_combined['num_measures'].mean():.1f}",
        'Measures Range': f"{df_combined['num_measures'].min():.0f} - {df_combined['num_measures'].max():.0f}",
    }
}

if 'source' in df_combined.columns:
    summary['Data Sources'] = dict(df_combined['source'].value_counts())

if len(chord_counts) > 0:
    summary['Harmonic Characteristics'] = {
        'Unique Chords': len(chord_counts),
        'Most Common Chord': chord_counts.index[0],
    }

# Print summary
for category, metrics in summary.items():
    print(f"\n{category}:")
    print("-" * 80)
    for metric, value in metrics.items():
        print(f"  {metric:30s}: {value}")

print("\n" + "=" * 80)

## 10. Key Insights and Recommendations

### Key Findings:

1. **Dataset Composition**
   - The combined dataset provides a diverse collection of bass-focused musical material
   - Samples are divided into train/validation/test splits; per-split counts are reported in Section 9 and per-split completeness in Section 8

2. **Musical Diversity**
   - Wide BPM range supporting various tempo-based generation tasks
   - Multiple genres represented, with cinematic and dance/jazz as primary categories
   - Rich harmonic content with diverse chord progressions

3. **Feature Characteristics**
   - Predominantly 4/4 time signature (standard for most genres)
   - Measure counts fall within a bounded range (see 'Measures Range' in Section 9), which suits pattern learning
   - Bass-focused pitch ranges appropriate for the generative task

4. **Data Quality**
   - Minimal missing values in critical features
   - Consistent data formatting across both sources
   - No significant duplicate issues
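Several findings above are qualitative ("predominantly 4/4", "minimal missing values"). The following is a minimal sketch, assuming the `time_signature` and `num_measures` columns used earlier in this notebook, that computes the figures behind them so the claims can be stated with numbers. The helper name `feature_profile` is hypothetical.

```python
import pandas as pd

def feature_profile(df: pd.DataFrame) -> dict:
    """Summarize the statistics behind the 'Feature Characteristics' and 'Data Quality' findings."""
    ts_share = df['time_signature'].value_counts(normalize=True)
    measures = df['num_measures'].dropna()
    return {
        # Most common meter and the fraction of samples that use it
        'top_time_signature': ts_share.index[0] if not ts_share.empty else None,
        'top_time_signature_share': float(ts_share.iloc[0]) if not ts_share.empty else 0.0,
        # Coefficient of variation (std / mean): small values mean similar clip lengths
        'measures_cv': float(measures.std() / measures.mean()) if len(measures) > 1 else float('nan'),
        # Percentage of missing values per column, worst first
        'missing_pct': (df.isnull().mean() * 100).sort_values(ascending=False).round(1).to_dict(),
    }

# profile = feature_profile(df_combined)
# print(f"{profile['top_time_signature']}: {profile['top_time_signature_share']:.1%} of samples")
```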

### Recommendations for Multi-Agent RL Training:

1. **GHSOM Pre-training**
   - Use BPM, pitch range, and chord progression features for clustering
   - Consider genre-specific clustering to capture style-specific motifs
   - Pay attention to tempo-based groupings for rhythmic pattern discovery

2. **DQN Agent Training**
   - Balance rewards across tempo ranges to avoid BPM bias
   - Use chord progression patterns as structural rewards
   - Consider using time signature as a constraint on the action space

3. **Human-in-the-Loop Adaptation**
   - Focus feedback collection on genre-specific preferences
   - Use BPM ranges to guide style-appropriate generation
   - Leverage chord progression familiarity for quality assessment

4. **Feature Engineering**
   - Consider binning BPM into tempo categories (slow, medium, fast)
   - Extract chord transition patterns for sequence modeling
   - Encode time signatures for structural constraints
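The three feature-engineering suggestions can be sketched with pandas and the standard library. This is a minimal sketch under stated assumptions: the tempo bin edges (90 and 130 BPM) are illustrative choices, not values derived from the datasets, and `chord_transitions` assumes each progression has already been parsed into a flat Python list of chord symbols. If the `chord_progressions` column holds string-encoded lists, parse it with `ast.literal_eval` first and flatten any nesting.

```python
from collections import Counter
import pandas as pd

# Illustrative tempo bin edges (assumption): [0, 90) slow, [90, 130) medium, [130, inf) fast
TEMPO_BINS = [0, 90, 130, float('inf')]
TEMPO_LABELS = ['slow', 'medium', 'fast']

def bin_tempo(bpm: pd.Series) -> pd.Series:
    """Discretize BPM into ordered tempo categories."""
    return pd.cut(bpm, bins=TEMPO_BINS, labels=TEMPO_LABELS, right=False)

def chord_transitions(progressions) -> Counter:
    """Count consecutive chord pairs (bigrams) across progressions.

    Immediate repeats of the same chord are collapsed first, so only actual
    chord changes are counted.
    """
    counts = Counter()
    for chords in progressions:
        changes = [c for i, c in enumerate(chords) if i == 0 or c != chords[i - 1]]
        counts.update(zip(changes, changes[1:]))
    return counts

def encode_time_signature(ts: pd.Series) -> pd.DataFrame:
    """One-hot encode time signatures (e.g. '4/4', '3/4') as 0/1 columns."""
    return pd.get_dummies(ts, prefix='ts', dtype=int)

# df_combined['tempo_class'] = bin_tempo(df_combined['bpm'])
# transitions = chord_transitions(parsed_progressions)  # parsed_progressions: hypothetical list of chord lists
# ts_onehot = encode_time_signature(df_combined['time_signature'])
```

Collapsing repeated chords is a design choice: progressions sampled at a fixed rhythmic grid repeat the same chord many times, and counting those self-transitions would swamp the actual harmonic movement.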

### Potential Challenges:

1. **Data Imbalance**
   - Some genres are underrepresented
   - Consider data augmentation or weighted sampling

2. **Missing Metadata**
   - Velocity information incomplete for some samples
   - Audio key often unknown in bass_loops dataset
   - May need imputation or fallback strategies

3. **Complexity Range**
   - Wide variation in measures and structure
   - May need difficulty-based curriculum learning
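Two of the mitigations above, weighted sampling for underrepresented genres and a difficulty-based curriculum, can be sketched as follows. This is a minimal sketch that assumes clip length (`num_measures`) is a reasonable proxy for difficulty; that proxy and the helper names are assumptions, not choices made elsewhere in this project.

```python
import numpy as np
import pandas as pd

def inverse_frequency_weights(labels: pd.Series) -> pd.Series:
    """Per-sample weights proportional to 1 / class frequency, normalized to sum to 1.

    With these weights every class receives the same total sampling probability.
    """
    freq = labels.map(labels.value_counts())
    weights = 1.0 / freq
    return weights / weights.sum()

def curriculum_stages(num_measures: pd.Series, n_stages: int = 3) -> pd.Series:
    """Assign each sample a stage (0 = shortest clips) by quantiles of measure count."""
    # rank(method='first') breaks ties so qcut can always form n_stages distinct bins
    return pd.qcut(num_measures.rank(method='first'), q=n_stages, labels=False)

# Usage sketch on the notebook's dataframe:
# rng = np.random.default_rng(42)
# w = inverse_frequency_weights(df_combined['genre'])
# idx = rng.choice(len(df_combined), size=len(df_combined), replace=True, p=w.to_numpy())
# df_combined['stage'] = curriculum_stages(df_combined['num_measures'])
```

Fully equalizing genres can heavily oversample very rare ones; tempering the weights (for example, raising them to a power between 0 and 1 before normalizing) is a common compromise.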

### Next Steps:

1. Perform GHSOM pre-training on selected features
2. Extract and analyze discovered motifs
3. Design reward functions based on feature distributions
4. Implement data augmentation for underrepresented categories
5. Set up monitoring for feature distribution shifts during training
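For step 5, one common way to quantify a shift in a numeric feature is the Population Stability Index (PSI), which compares binned proportions of a reference sample against a current sample. The following is a minimal NumPy sketch. The choice of PSI, the bin count, and the `generated_bpm` name in the usage comment are assumptions for illustration, not the project's monitoring setup.

```python
import numpy as np

def population_stability_index(reference, current, n_bins: int = 10, eps: float = 1e-6) -> float:
    """PSI = sum((cur - ref) * ln(cur / ref)) over bins defined by reference quantiles."""
    reference = np.asarray(reference, dtype=float)
    current = np.asarray(current, dtype=float)
    # Interior cut points at reference quantiles (duplicates removed for discrete features)
    cuts = np.unique(np.quantile(reference, np.linspace(0, 1, n_bins + 1)[1:-1]))
    n = len(cuts) + 1
    # Values outside the reference range fall into the first or last bin
    ref_frac = np.bincount(np.searchsorted(cuts, reference, side='right'), minlength=n) / len(reference)
    cur_frac = np.bincount(np.searchsorted(cuts, current, side='right'), minlength=n) / len(current)
    # Clip empty bins to eps so the log term stays finite
    ref_frac = np.clip(ref_frac, eps, None)
    cur_frac = np.clip(cur_frac, eps, None)
    return float(np.sum((cur_frac - ref_frac) * np.log(cur_frac / ref_frac)))

# train_bpm = df_combined.loc[df_combined['split'] == 'train', 'bpm'].dropna()
# psi = population_stability_index(train_bpm, generated_bpm)  # generated_bpm: hypothetical BPMs of generated samples
```

A frequently used heuristic treats PSI below about 0.1 as stable and above about 0.25 as a substantial shift. These thresholds are conventions, so calibrate them against shifts you already know matter for this task.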

## Conclusion

This EDA characterized the musical features, structure, and data quality of our training datasets. The results indicate a diverse collection of bass-focused musical material suitable for training a multi-agent RL system for symbolic music generation.

The datasets exhibit:
- Strong genre diversity
- Wide tempo and dynamic ranges
- Rich harmonic content
- Appropriate structural complexity

These characteristics make them well-suited for training perceiving agents (GHSOM) to discover musical patterns and generative agents (DQN+LSTM) to create coherent musical sequences with human-in-the-loop guidance.