# Case Study Analysis

Amazon Product Classification - DATA304 Final Project

**Objective:** Detailed analysis of individual predictions, error patterns, and model behavior

## Setup

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import pickle
from collections import Counter
import networkx as nx

# Project imports
from src.data_preprocessing import DataLoader

# Global look-and-feel applied to every figure produced in this notebook
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Input folder for saved predictions and output folder for generated figures
predictions_dir = Path('../results/predictions')
fig_dir = Path('../results/images/case_study')
fig_dir.mkdir(parents=True, exist_ok=True)  # safe to re-run: creates parents, ignores existing

# Banner, then the location where figures will be written
for banner_line in (
    f"{'='*70}",
    f"              CASE STUDY ANALYSIS",
    f"{'='*70}",
    f"✓ Images will be saved to: {fig_dir}\n",
):
    print(banner_line)

In [None]:
# Load the raw dataset through the project DataLoader. Later cells read
# `num_classes`, `test_corpus` and `hierarchy` from this object (presumably
# populated by load_all() — confirm in src.data_preprocessing).
data_loader = DataLoader(data_dir='../data/raw/Amazon_products')
data_loader.load_all()

print(f"✓ Loaded {data_loader.num_classes} classes")
print(f"✓ Test samples: {len(data_loader.test_corpus)}")

# Build the mapping class id (int) -> class name from the tab-separated file
class_names = {}
with open('../data/raw/Amazon_products/classes.txt', 'r', encoding='utf-8') as f:
    for line in f:
        # Partition the *stripped* line and test the separator on that same
        # text. Testing for the tab on the raw line and then splitting the
        # stripped one fails with ValueError on rows such as "12\t\n", because
        # strip() removes the trailing tab before the split happens.
        class_id, sep, class_name = line.strip().partition('\t')
        if sep:
            class_names[int(class_id)] = class_name

print(f"✓ Loaded {len(class_names)} class names")

## 1. Load Predictions

In [None]:
# Locate the saved prediction files of the selected model and unpack the latest
MODEL_NAME = 'baseline'  # Change to analyze different models

pred_files = list(predictions_dir.glob(f'{MODEL_NAME}_*.pkl'))

if not pred_files:
    # Nothing to analyse: leave the three globals as None so that the
    # following cells can detect the situation and skip their work.
    print(f"⚠️  No prediction files found for {MODEL_NAME}")
    print(f"   Run predict.py first!")
    pids, predictions, probabilities = None, None, None
else:
    # "Latest" means the path that sorts last by name, not by modification
    # time (presumably file names embed a sortable timestamp — verify against
    # the script that writes them).
    pred_file = max(pred_files)
    print(f"Loading: {pred_file.name}\n")

    # NOTE(review): unpickling executes code embedded in the file; only load
    # prediction files produced by this project.
    with open(pred_file, 'rb') as f:
        results = pickle.load(f)

    pids = results['pids']
    predictions = results['predictions']
    probabilities = results['probabilities']

    print(f"✓ Loaded predictions for {len(pids)} samples")
    print(f"✓ Model: {results.get('model_name', 'unknown')}")
    print(f"✓ Threshold: {results.get('threshold', 0.5)}")

## 2. Prediction Confidence Distribution

In [None]:
if predictions is None:
    print("⚠️  No predictions loaded")
else:
    # Per-sample summaries of the probabilities of the predicted classes:
    # the largest one and their mean. Both lists are reused by later cells.
    max_confidences = []
    avg_confidences = []

    for i, pred_classes in enumerate(predictions):
        if len(pred_classes) == 0:
            # A sample without any predicted class is recorded as zero confidence
            max_confidences.append(0.0)
            avg_confidences.append(0.0)
            continue

        # Probabilities of exactly the classes that were predicted for sample i
        class_probs = [probabilities[i][c] for c in pred_classes]
        max_confidences.append(max(class_probs))
        avg_confidences.append(np.mean(class_probs))

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # One entry per histogram panel: data, bar colour, x-label, title, legend text
    panels = (
        (max_confidences, '#2E86AB', 'Max Confidence Score',
         f'{MODEL_NAME.upper()} - Max Confidence Distribution',
         f'Mean: {np.mean(max_confidences):.3f}'),
        (avg_confidences, '#F18F01', 'Average Confidence Score',
         f'{MODEL_NAME.upper()} - Average Confidence Distribution',
         f'Mean: {np.mean(avg_confidences):.3f}'),
    )

    for panel, (scores, bar_color, x_label, title, mean_label) in zip(axes, panels):
        panel.hist(scores, bins=50, alpha=0.8, edgecolor='black', color=bar_color)
        # Dashed vertical line marking the mean of the plotted scores
        panel.axvline(np.mean(scores), color='red', linestyle='--', linewidth=2,
                      label=mean_label)
        panel.set_xlabel(x_label, fontsize=13, fontweight='bold')
        panel.set_ylabel('Frequency', fontsize=13, fontweight='bold')
        panel.set_title(title, fontsize=15, fontweight='bold')
        panel.legend(fontsize=11)
        panel.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(fig_dir / 'confidence_distribution.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {fig_dir / 'confidence_distribution.png'}")
    print(f"\nConfidence Statistics:")
    print(f"  Max confidence - Mean: {np.mean(max_confidences):.3f}, Std: {np.std(max_confidences):.3f}")
    print(f"  Avg confidence - Mean: {np.mean(avg_confidences):.3f}, Std: {np.std(avg_confidences):.3f}")

## 3. Example Predictions - High Confidence Cases

In [None]:
if predictions is None:
    print("⚠️  No predictions loaded")
else:
    # Positions of the five samples with the largest max-confidence, highest
    # first. Kept as a NumPy array because a later cell slices it.
    high_conf_indices = np.argsort(max_confidences)[::-1][:5]  # Top 5

    print("="*80)
    print("           HIGH CONFIDENCE PREDICTIONS (Top 5)")
    print("="*80 + "\n")

    for rank, idx in enumerate(high_conf_indices, start=1):
        pid = pids[idx]
        pred_classes = predictions[idx]
        text = data_loader.test_corpus[pid]

        # Header: document id, its confidence and a short text preview
        print(f"[Example {rank}] Document ID: {pid}")
        print(f"Max Confidence: {max_confidences[idx]:.4f}")
        print(f"\nText (first 200 chars):")
        print(f"  {text[:200]}...")
        print(f"\nPredicted Classes ({len(pred_classes)}):")

        # At most the first five predicted classes are listed in full
        shown_classes = pred_classes[:5]
        for cls_id in shown_classes:
            cls_name = class_names.get(cls_id, f"Unknown-{cls_id}")
            conf = probabilities[idx][cls_id]
            print(f"  - Class {cls_id}: {cls_name} (confidence: {conf:.4f})")

        # Summarise whatever was not listed
        if len(pred_classes) > 5:
            print(f"  ... and {len(pred_classes)-5} more classes")

        print("\n" + "-"*80 + "\n")

## 4. Example Predictions - Low Confidence Cases

In [None]:
if predictions is None:
    print("⚠️  No predictions loaded")
else:
    # Only samples with at least one predicted class are candidates; samples
    # without predictions carry a placeholder confidence of 0.0 and would
    # otherwise fill the whole bottom-5 list.
    valid_indices = []
    valid_confidences = []
    for i, p in enumerate(predictions):
        if len(p) > 0:
            valid_indices.append(i)
            valid_confidences.append(max_confidences[i])

    # argsort works on positions inside the filtered list, so map them back
    # to positions in the full prediction list (plain Python list of ints).
    lowest_positions = np.argsort(valid_confidences)[:5]  # Bottom 5
    low_conf_indices = [valid_indices[pos] for pos in lowest_positions]

    print("="*80)
    print("           LOW CONFIDENCE PREDICTIONS (Bottom 5)")
    print("="*80 + "\n")

    for rank, idx in enumerate(low_conf_indices, start=1):
        pid = pids[idx]
        pred_classes = predictions[idx]
        text = data_loader.test_corpus[pid]

        # Header: document id, its confidence and a short text preview
        print(f"[Example {rank}] Document ID: {pid}")
        print(f"Max Confidence: {max_confidences[idx]:.4f}")
        print(f"\nText (first 200 chars):")
        print(f"  {text[:200]}...")
        print(f"\nPredicted Classes ({len(pred_classes)}):")

        # At most the first five predicted classes are listed in full
        shown_classes = pred_classes[:5]
        for cls_id in shown_classes:
            cls_name = class_names.get(cls_id, f"Unknown-{cls_id}")
            conf = probabilities[idx][cls_id]
            print(f"  - Class {cls_id}: {cls_name} (confidence: {conf:.4f})")

        # Summarise whatever was not listed
        if len(pred_classes) > 5:
            print(f"  ... and {len(pred_classes)-5} more classes")

        print("\n" + "-"*80 + "\n")

## 5. Visualization - Example Predictions

In [None]:
if predictions is not None:
    # Plot the 3 most confident and the 2 least confident samples, one panel
    # per sample, each showing up to 10 predicted classes as horizontal bars.
    #
    # high_conf_indices is a NumPy array while low_conf_indices is a plain
    # Python list, which has no .tolist(); convert element-wise so that both
    # container types are accepted.
    sample_indices = (
        [int(i) for i in high_conf_indices[:3]]
        + [int(i) for i in low_conf_indices[:2]]
    )

    if not sample_indices:
        # plt.subplots() rejects zero rows, so there is nothing to draw
        print("⚠️  No samples available to visualize")
    else:
        # squeeze=False always yields a 2-D array of axes, so the single-panel
        # case needs no special handling; take the only column.
        fig, axes = plt.subplots(len(sample_indices), 1,
                                 figsize=(14, 4*len(sample_indices)),
                                 squeeze=False)
        axes = axes[:, 0]

        for ax_idx, idx in enumerate(sample_indices):
            ax = axes[ax_idx]
            pid = pids[idx]
            pred_classes = predictions[idx][:10]  # First 10 predicted classes
            confidences = [probabilities[idx][c] for c in pred_classes]
            # Class names are truncated to 30 characters to keep tick labels short
            labels = [f"C{c}: {class_names.get(c, 'Unknown')[:30]}" for c in pred_classes]

            # Horizontal bars coloured by confidence (red = low, green = high)
            y_pos = np.arange(len(labels))
            colors = plt.cm.RdYlGn(np.array(confidences))

            ax.barh(y_pos, confidences, color=colors, edgecolor='black', alpha=0.8)
            ax.set_yticks(y_pos)
            ax.set_yticklabels(labels, fontsize=9)
            ax.set_xlabel('Confidence Score', fontsize=11, fontweight='bold')
            ax.set_title(f'Doc {pid} - Text: "{data_loader.test_corpus[pid][:60]}..."',
                         fontsize=12, fontweight='bold')
            ax.set_xlim(0, 1.0)
            ax.grid(axis='x', alpha=0.3)
            ax.invert_yaxis()  # first predicted class at the top

            # Numeric confidence next to the end of each bar
            for i, conf in enumerate(confidences):
                ax.text(conf + 0.02, i, f'{conf:.3f}',
                        va='center', fontsize=9, fontweight='bold')

        plt.tight_layout()
        plt.savefig(fig_dir / 'example_predictions.png', dpi=300, bbox_inches='tight')
        plt.show()

        print(f"✓ Saved: {fig_dir / 'example_predictions.png'}")
else:
    print("⚠️  No predictions loaded")

## 6. Hierarchy Consistency Analysis

In [None]:
if predictions is not None:
    # Build the class hierarchy as a directed graph with edges parent -> child
    G = nx.DiGraph()
    for parent, child in data_loader.hierarchy:
        G.add_edge(parent, child)

    # A prediction is consistent when every predicted class comes together with
    # ALL of its ancestors (not only its direct parent). Each
    # (predicted class, ancestor) pair is checked once per document.
    violations = []
    total_pairs = 0

    # The ancestor set of a class does not depend on the document, so compute
    # it once per class instead of once per (document, class) occurrence.
    ancestors_of = {}

    for idx, pred_classes in enumerate(predictions):
        predicted = set(pred_classes)  # constant-time membership tests below

        for child in pred_classes:
            # Classes that do not appear in the hierarchy cannot be checked
            if child not in G:
                continue

            if child not in ancestors_of:
                ancestors_of[child] = nx.ancestors(G, child)

            for ancestor in ancestors_of[child]:
                total_pairs += 1
                if ancestor not in predicted:
                    violations.append({
                        'doc_id': pids[idx],
                        'child': child,
                        'missing_ancestor': ancestor
                    })

    violation_rate = len(violations) / total_pairs if total_pairs > 0 else 0

    print(f"Hierarchy Consistency Analysis:")
    print(f"  Total parent-child pairs checked: {total_pairs}")
    print(f"  Hierarchy violations: {len(violations)}")
    print(f"  Violation rate: {violation_rate:.2%}")

    if violations:
        # NOTE: the "parent" named in the message may be any ancestor of the class
        print(f"\nExample violations (first 5):")
        for i, v in enumerate(violations[:5], 1):
            child_name = class_names.get(v['child'], f"Unknown-{v['child']}")
            parent_name = class_names.get(v['missing_ancestor'], f"Unknown-{v['missing_ancestor']}")
            print(f"  {i}. Doc {v['doc_id']}: Predicted '{child_name}' but missing parent '{parent_name}'")

    # Bar chart: consistent pairs versus violating pairs
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    categories = ['Consistent', 'Violations']
    values = [total_pairs - len(violations), len(violations)]
    colors = ['#2E86AB', '#E76F51']

    bars = ax.bar(categories, values, color=colors, alpha=0.8, edgecolor='black', width=0.5)
    ax.set_ylabel('Number of Predictions', fontsize=13, fontweight='bold')
    ax.set_title(f'{MODEL_NAME.upper()} - Hierarchy Consistency', fontsize=15, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)

    # Annotate each bar with its count and its share of all checked pairs
    for bar, val in zip(bars, values):
        height = bar.get_height()
        pct = val / total_pairs * 100 if total_pairs > 0 else 0
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{val}\n({pct:.1f}%)', ha='center', va='bottom',
               fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.savefig(fig_dir / 'hierarchy_violations.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"\n✓ Saved: {fig_dir / 'hierarchy_violations.png'}")
else:
    print("⚠️  No predictions loaded")

## 7. Error Analysis - Multi-Label Statistics

In [None]:
if predictions is not None:
    # Number of predicted labels for every sample
    labels_per_sample = [len(p) for p in predictions]
    n_samples = len(labels_per_sample)
    no_prediction_count = sum(1 for l in labels_per_sample if l == 0)

    # Bucket the samples by how many labels were predicted (inclusive bounds)
    label_bins = [
        ('0 labels', 0, 0),
        ('1 label', 1, 1),
        ('2-3 labels', 2, 3),
        ('4-5 labels', 4, 5),
        ('6-10 labels', 6, 10),
        ('11+ labels', 11, float('inf')),
    ]
    label_groups = {
        name: sum(1 for l in labels_per_sample if low <= l <= high)
        for name, low, high in label_bins
    }

    fig, axes = plt.subplots(1, 2, figsize=(16, 6))

    # Left panel: bar chart of the label-count buckets
    groups = list(label_groups.keys())
    counts = list(label_groups.values())
    colors = plt.cm.viridis(np.linspace(0, 1, len(groups)))

    bars = axes[0].bar(range(len(groups)), counts, color=colors, alpha=0.8, edgecolor='black')
    axes[0].set_xticks(range(len(groups)))
    axes[0].set_xticklabels(groups, rotation=15, ha='right')
    axes[0].set_ylabel('Number of Samples', fontsize=13, fontweight='bold')
    axes[0].set_title(f'{MODEL_NAME.upper()} - Prediction Pattern Distribution',
                     fontsize=15, fontweight='bold')
    axes[0].grid(axis='y', alpha=0.3)

    # Annotate each bar with its count and share of all samples
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        # Guard against an empty prediction list (division by zero)
        pct = count / n_samples * 100 if n_samples else 0.0
        axes[0].text(bar.get_x() + bar.get_width()/2., height,
                    f'{count}\n({pct:.1f}%)', ha='center', va='bottom',
                    fontsize=10, fontweight='bold')

    # Right panel: max confidence against number of labels, for samples that
    # have at least one prediction
    labels_with_conf = [(l, max_confidences[i]) for i, l in enumerate(labels_per_sample) if l > 0]
    if labels_with_conf:
        label_counts, confs = zip(*labels_with_conf)

        axes[1].scatter(label_counts, confs, alpha=0.5, s=20, color='#2E86AB')
        axes[1].set_xlabel('Number of Predicted Labels', fontsize=13, fontweight='bold')
        axes[1].set_ylabel('Max Confidence', fontsize=13, fontweight='bold')
        axes[1].set_title('Confidence vs Number of Labels', fontsize=15, fontweight='bold')
        axes[1].grid(alpha=0.3)

        # A linear trend needs at least two distinct x values; when every
        # sample has the same label count the fit is rank-deficient and the
        # line would collapse onto a single x position, so it is skipped.
        if len(set(label_counts)) >= 2:
            z = np.polyfit(label_counts, confs, 1)
            p = np.poly1d(z)
            x_line = np.linspace(min(label_counts), max(label_counts), 100)
            axes[1].plot(x_line, p(x_line), "r--", linewidth=2, alpha=0.8, label='Trend')
            axes[1].legend()

    plt.tight_layout()
    plt.savefig(fig_dir / 'error_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    print(f"✓ Saved: {fig_dir / 'error_analysis.png'}")
    print(f"\nMulti-Label Statistics:")
    no_prediction_pct = no_prediction_count / n_samples * 100 if n_samples else 0.0
    print(f"  Samples with no predictions: {no_prediction_count} ({no_prediction_pct:.1f}%)")
    # Mean, median and max are undefined for an empty prediction list
    if n_samples:
        print(f"  Average labels per sample: {np.mean(labels_per_sample):.2f}")
        print(f"  Median: {np.median(labels_per_sample):.0f}")
        print(f"  Max: {max(labels_per_sample)}")
else:
    print("⚠️  No predictions loaded")

## Summary

✅ All case study visualizations saved to: `results/images/case_study/`

**Generated Files:**
- `confidence_distribution.png` - Max and average confidence score distributions
- `example_predictions.png` - Detailed visualization of high/low confidence examples
- `hierarchy_violations.png` - Hierarchy consistency analysis
- `error_analysis.png` - Multi-label prediction patterns and confidence trends

---

### Key Insights:

**Confidence Patterns:**
- High confidence predictions: Model is certain about specific classes
- Low confidence predictions: Ambiguous cases or borderline classifications
- Average confidence (the mean probability over a sample's predicted classes) indicates the model's overall certainty for that sample

**Prediction Quality:**
- Examples show how model interprets text features
- Class names reveal if predictions make semantic sense
- Multiple labels indicate product category complexity

**Hierarchy Consistency:**
- Measures whether each predicted class is accompanied by all of its ancestor classes in the hierarchy
- Violations indicate potential model improvements needed
- Useful for hierarchical loss function evaluation

**Error Patterns:**
- Samples with no predictions: Threshold too strict or unclear text
- Samples with many predictions: General products or threshold too loose
- Confidence vs label count correlation reveals model behavior

---

### Recommendations:

1. **High Confidence Cases**: Use as training examples for future iterations
2. **Low Confidence Cases**: Review for data quality or ambiguity issues
3. **Hierarchy Violations**: Consider hierarchical loss constraints
4. **No Predictions**: Adjust threshold or improve silver labeling

---

### Next Steps:
1. Analyze specific failure cases manually
2. Adjust confidence threshold based on precision-recall trade-off
3. Improve hierarchy constraint enforcement
4. Collect human annotations for ambiguous cases