# D-TRAK Attribution Analysis

Analyze training data attribution for generated audio samples.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from IPython.display import Audio, display, HTML

%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 6)

## Load Attribution Scores

In [None]:
# Paths
output_dir = Path("/home/xiruij/stable-audio-tools/outputs/dtrak_attribution_20260216_005732")

# Load metadata stored next to the score memmap (counts and matrix shape).
with open(output_dir / "scores_query_x_train.memmap.meta.json", encoding="utf-8") as f:
    meta = json.load(f)

print(f"Query samples: {meta['query_num_examples']}")
print(f"Train samples: {meta['train_num_examples']}")
print(f"Score shape: {meta['shape']}")

# Load scores (query x train), read-only and memory-mapped.
scores = np.memmap(
    output_dir / "scores_query_x_train.memmap",
    dtype=np.float32,  # NOTE(review): assumes the writer stored float32 — confirm
    mode="r",
    shape=tuple(meta['shape'])
)

# Load IDs: one path per line; line k labels row/column k of `scores`.
with open(output_dir / "query_features.memmap.ids.txt", encoding="utf-8") as f:
    query_ids = [line.strip() for line in f]

with open(output_dir / "train_features.memmap.ids.txt", encoding="utf-8") as f:
    train_ids = [line.strip() for line in f]

# Later cells map score rows/columns to files purely by position, so a length
# mismatch would silently attach scores to the wrong audio files.
if scores.shape != (len(query_ids), len(train_ids)):
    raise ValueError(
        f"Score matrix shape {scores.shape} does not match "
        f"{len(query_ids)} query IDs x {len(train_ids)} train IDs"
    )

print(f"\nLoaded {len(query_ids)} query IDs and {len(train_ids)} train IDs")

## Distribution of Attribution Scores

In [None]:
# Plot the score distribution for (up to) the first 10 query samples,
# one histogram per query with its mean marked.
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()
n_plotted = min(len(axes), len(query_ids))

for i, ax in enumerate(axes[:n_plotted]):
    query_scores = scores[i]
    # Accumulate in float64: the scores themselves are stored as float32.
    mean_score = query_scores.mean(dtype=np.float64)

    ax.hist(query_scores, bins=50, alpha=0.7, edgecolor='black')
    ax.axvline(mean_score, color='red', linestyle='--', label=f'Mean: {mean_score:.2e}')
    ax.set_title(f'Query {i}: {Path(query_ids[i]).name}', fontsize=10)
    ax.set_xlabel('Attribution Score')
    ax.set_ylabel('Frequency')
    ax.legend(fontsize=8)
    ax.grid(alpha=0.3)

# With fewer than 10 queries the remaining panels would show empty axes.
for ax in axes[n_plotted:]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(output_dir / 'score_distributions.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nScore statistics:")
print(f"Overall mean: {scores.mean(dtype=np.float64):.4e}")
print(f"Overall std: {scores.std(dtype=np.float64):.4e}")
print(f"Overall min: {scores.min():.4e}")
print(f"Overall max: {scores.max():.4e}")

## Top/Middle/Bottom Influential Training Samples

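For each query sample, find `n_samples` training samples per group (the calls below use 3):
- **Top**: highest positive attribution (most influential)
- **Middle**: attribution closest to the query's median score (a neutral reference group; the median is not necessarily zero)
- **Bottom**: most negative attribution (strongest opposing attribution — not "least influential", which would be the scores nearest zero)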

In [None]:
def display_audio_samples(query_idx, n_samples=5):
    """
    Display a query sample and its top/middle/bottom attributed training samples.

    Reads the notebook-level ``scores`` (query x train), ``query_ids`` and
    ``train_ids`` loaded in an earlier cell.

    Args:
        query_idx: Row index into ``scores`` / ``query_ids``.
        n_samples: Number of training samples shown in each group
            (top, median-neighbourhood, bottom). ``0`` shows none.
    """
    query_path = query_ids[query_idx]
    query_scores = scores[query_idx]

    # Training indices in ascending order of attribution score.
    sorted_indices = np.argsort(query_scores)
    # Reverse, then take a prefix. Slicing with [-n_samples:] would return the
    # whole array when n_samples == 0.
    top_indices = sorted_indices[::-1][:n_samples]  # Highest scores first
    bottom_indices = sorted_indices[:n_samples]     # Lowest scores first

    # Middle: training samples whose score is closest to this query's median.
    median_score = np.median(query_scores)
    middle_indices = np.argsort(np.abs(query_scores - median_score))[:n_samples]

    print("=" * 80)
    print(f"QUERY SAMPLE {query_idx}: {Path(query_path).name}")
    print("=" * 80)

    if Path(query_path).exists():
        display(HTML("<h3>Query Sample</h3>"))
        display(Audio(query_path, rate=44100))
    else:
        print(f"⚠️ Query file not found: {query_path}")

    _show_train_samples(
        f"🔥 Top {n_samples} Most Influential Training Samples",
        top_indices,
        query_scores,
    )
    _show_train_samples(
        f"⚖️ {n_samples} Neutral Training Samples (near median)",
        middle_indices,
        query_scores,
    )
    _show_train_samples(
        f"❄️ Bottom {n_samples} Training Samples (most negative attribution)",
        bottom_indices,
        query_scores,
    )

    print("\n" + "=" * 80 + "\n")


def _show_train_samples(title, indices, query_scores):
    """Show an HTML heading, then a ranked score line and audio player per training index."""
    display(HTML(f"<h3>{title}</h3>"))
    for rank, idx in enumerate(indices, 1):
        train_path = train_ids[idx]
        score = query_scores[idx]
        print(f"\n#{rank} | Score: {score:.4e} | {Path(train_path).name}")
        if Path(train_path).exists():
            display(Audio(train_path, rate=44100))
        else:
            print(f"⚠️ File not found: {train_path}")

## Analysis: Query Sample 0

In [None]:
display_audio_samples(0, n_samples=3)  # first query, 3 training samples per group

## Analysis: Query Sample 1

In [None]:
display_audio_samples(1, n_samples=3)  # second query, 3 training samples per group

## Analysis: Query Sample 2

In [None]:
display_audio_samples(2, n_samples=3)  # third query, 3 training samples per group

## Analyze All Query Samples at Once

In [None]:
# Set to True to run the per-query audio analysis for every query sample.
# Off by default because it renders audio players for every query.
ANALYZE_ALL_QUERIES = False

if ANALYZE_ALL_QUERIES:
    for i in range(len(query_ids)):
        display_audio_samples(query_idx=i, n_samples=3)

## Export Top Attributions to CSV

In [None]:
import pandas as pd

# One summary row per query: its single highest- and lowest-scored training
# sample, plus the mean/std of its scores over all training samples.
summary_columns = [
    'query_idx',
    'query_file',
    'top_train_file',
    'top_score',
    'bottom_train_file',
    'bottom_score',
    'mean_score',
    'std_score',
]

results = []
for i, query_path in enumerate(query_ids):
    query_scores = scores[i]
    top_idx = np.argmax(query_scores)
    bottom_idx = np.argmin(query_scores)

    results.append({
        'query_idx': i,
        'query_file': Path(query_path).name,
        'top_train_file': Path(train_ids[top_idx]).name,
        'top_score': query_scores[top_idx],
        'bottom_train_file': Path(train_ids[bottom_idx]).name,
        'bottom_score': query_scores[bottom_idx],
        # Accumulate in float64: the stored scores are float32.
        'mean_score': query_scores.mean(dtype=np.float64),
        'std_score': query_scores.std(dtype=np.float64),
    })

# Explicit columns keep the header (and column order) even with zero queries.
df = pd.DataFrame(results, columns=summary_columns)
summary_path = output_dir / 'attribution_summary.csv'
df.to_csv(summary_path, index=False)
print("Saved attribution summary to:", summary_path)
df