# Visualization Module

> Plotting functions for validation results and confusion matrices.

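This module provides:
- `plot_confusion_matrix()`: Heatmap visualization of a confusion matrix
- `plot_validation_results()`: Confusion-matrix figures for a `ValidationResult`
- `plot_metrics_comparison()`: Bar chart comparing metrics across model versions
- `plot_class_distribution()`: Bar chart of the training data class distribution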

In [None]:
#| default_exp visualization

In [None]:
#| export
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from typing import Optional, List, Dict, Any

from openness_classifier.validation import ValidationResult, ClassificationMetrics

In [None]:
#| export
# Default style settings for publication-quality figures.
# NOTE(review): the keys look like matplotlib rcParams names, but none of the
# plotting functions visible in this notebook applies them -- confirm how
# callers are expected to use this dict.
FIGURE_STYLE = {
    # Overall canvas
    "figure.figsize": (8, 6),
    "font.size": 12,
    # Axes title and axis labels
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    # Tick labels
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
}

# Openness categories, from most to least open, paired with the display label
# used on plot axes (two-word labels are broken over two lines).
_CATEGORY_TABLE = (
    ("open", "Open"),
    ("mostly_open", "Mostly\nOpen"),
    ("mostly_closed", "Mostly\nClosed"),
    ("closed", "Closed"),
)

# Parallel lists: CATEGORY_LABELS[i] is the display label of CATEGORY_ORDER[i].
CATEGORY_ORDER = [key for key, _ in _CATEGORY_TABLE]
CATEGORY_LABELS = [label for _, label in _CATEGORY_TABLE]

## Confusion Matrix

In [None]:
#| export
def plot_confusion_matrix(
    confusion_matrix: np.ndarray,
    title: str = "Confusion Matrix",
    labels: Optional[List[str]] = None,
    normalize: bool = False,
    cmap: str = "Blues",
    figsize: tuple = (8, 6),
    save_path: Optional[str | Path] = None,
) -> plt.Figure:
    """Plot a confusion matrix as an annotated heatmap.

    Rows are labelled as the true label and columns as the predicted label.

    Args:
        confusion_matrix: NxN confusion matrix. Anything ``np.asarray``
            accepts (an ndarray or nested sequences) is allowed.
        title: Plot title
        labels: Tick labels used on both axes (default: openness categories)
        normalize: If True, normalize by row (true label) so each cell shows
            a proportion; rows whose total is zero are left as all zeros
        cmap: Colormap name
        figsize: Figure size
        save_path: Optional path to save figure (written at 300 dpi)

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: If ``confusion_matrix`` is not two-dimensional.
    """
    # Coerce first so nested lists work and the shape can be validated
    # before any figure is created.
    matrix = np.asarray(confusion_matrix)
    if matrix.ndim != 2:
        raise ValueError(
            "confusion_matrix must be 2-dimensional, "
            f"got {matrix.ndim} dimension(s)"
        )

    if labels is None:
        labels = CATEGORY_LABELS

    if normalize:
        row_totals = matrix.sum(axis=1, keepdims=True)
        # Divide only where the row total is non-zero: empty rows keep the
        # zeros from `out` and no divide-by-zero warning is emitted.
        cm = np.divide(
            matrix,
            row_totals,
            out=np.zeros(matrix.shape, dtype=float),
            where=row_totals != 0,
        )
        fmt = '.2f'
    else:
        cm = matrix
        # The 'd' format code rejects floats, so counts stored with a float
        # dtype fall back to the general format.
        fmt = 'd' if np.issubdtype(matrix.dtype, np.integer) else 'g'

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Plot heatmap
    sns.heatmap(
        cm,
        annot=True,
        fmt=fmt,
        cmap=cmap,
        xticklabels=labels,
        yticklabels=labels,
        ax=ax,
        cbar_kws={'label': 'Proportion' if normalize else 'Count'},
    )

    ax.set_xlabel('Predicted Label')
    ax.set_ylabel('True Label')
    ax.set_title(title)

    # Lay out this figure explicitly rather than whichever one pyplot
    # currently considers active.
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

In [None]:
#| export
def plot_validation_results(
    result: ValidationResult,
    save_dir: Optional[str | Path] = None,
) -> Dict[str, plt.Figure]:
    """Plot all visualizations for a validation result.

    For each of the 'data' and 'code' entries present in
    ``result.confusion_matrices``, creates a raw-count confusion matrix and a
    row-normalized one. Both targets get the same pair of plots.

    Args:
        result: ValidationResult to visualize; only its
            ``confusion_matrices`` mapping is read
        save_dir: Optional directory to save figures. It is created
            (including parents) if missing; figures are written as
            ``confusion_matrix_<target>[_normalized].png``

    Returns:
        Dictionary of {name: figure} for all created plots, keyed
        ``<target>_confusion`` and ``<target>_confusion_normalized``
    """
    out_dir = None
    if save_dir:
        out_dir = Path(save_dir)
        out_dir.mkdir(parents=True, exist_ok=True)

    # (key in result.confusion_matrices, base plot title)
    targets = (
        ('data', 'Data Availability Classification'),
        ('code', 'Code Availability Classification'),
    )
    # (normalize flag, suffix for dict key / file name, suffix for title)
    variants = (
        (False, '', ''),
        (True, '_normalized', ' (Normalized)'),
    )

    figures = {}
    for key, base_title in targets:
        if key not in result.confusion_matrices:
            continue
        matrix = result.confusion_matrices[key]
        for normalize, name_suffix, title_suffix in variants:
            target_path = None
            if out_dir:
                target_path = out_dir / f'confusion_matrix_{key}{name_suffix}.png'
            figures[f'{key}_confusion{name_suffix}'] = plot_confusion_matrix(
                matrix,
                title=f'{base_title}{title_suffix}',
                normalize=normalize,
                save_path=target_path,
            )

    return figures

## Metrics Comparison

In [None]:
#| export
def plot_metrics_comparison(
    metrics_list: List[ClassificationMetrics],
    labels: List[str],
    title: str = "Model Comparison",
    figsize: tuple = (10, 6),
    save_path: Optional[str | Path] = None,
) -> plt.Figure:
    """Plot bar chart comparing metrics across models/runs.

    Draws three grouped bars per entry (accuracy, Cohen's kappa, macro F1)
    plus a dashed reference line at 0.8.

    Args:
        metrics_list: List of ClassificationMetrics to compare; the
            ``accuracy``, ``cohens_kappa`` and ``macro_f1`` attributes are read
        labels: Labels for each metrics set, one per entry of ``metrics_list``
        title: Plot title
        figsize: Figure size
        save_path: Optional path to save figure (written at 300 dpi)

    Returns:
        matplotlib Figure object

    Raises:
        ValueError: If ``metrics_list`` and ``labels`` differ in length.
    """
    # Fail early with a clear message; otherwise a single metrics entry would
    # be silently broadcast across every label.
    if len(metrics_list) != len(labels):
        raise ValueError(
            "metrics_list and labels must have the same length "
            f"(got {len(metrics_list)} and {len(labels)})"
        )

    # Extract metrics
    accuracies = [m.accuracy for m in metrics_list]
    kappas = [m.cohens_kappa for m in metrics_list]
    macro_f1s = [m.macro_f1 for m in metrics_list]

    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(len(labels))
    width = 0.2

    # Plot bars: one group per model, centred on its tick
    ax.bar(x - width, accuracies, width, label='Accuracy', color='#2ecc71')
    ax.bar(x, kappas, width, label="Cohen's Kappa", color='#3498db')
    ax.bar(x + width, macro_f1s, width, label='Macro F1', color='#9b59b6')

    ax.axhline(y=0.8, color='gray', linestyle='--', alpha=0.5, label='Target (80%)')

    ax.set_xlabel('Model')
    ax.set_ylabel('Score')
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)

    # Cohen's kappa can be negative; extend the axis downward in that case so
    # the bar is not clipped. NaN values are ignored when picking the floor.
    y_floor = float(np.nanmin([0.0, *kappas]))
    ax.set_ylim(y_floor, 1)

    # Build the legend last so the target line's label is included.
    ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

## Class Distribution

In [None]:
#| export
def plot_class_distribution(
    distribution: Dict[str, int],
    title: str = "Training Data Distribution",
    figsize: tuple = (8, 5),
    save_path: Optional[str | Path] = None,
) -> plt.Figure:
    """Plot bar chart of class distribution.

    Bars follow the order of CATEGORY_ORDER. Keys of ``distribution`` that
    are not listed in CATEGORY_ORDER are not plotted.

    Args:
        distribution: Dict mapping category names to counts
        title: Plot title
        figsize: Figure size
        save_path: Optional path to save figure (written at 300 dpi)

    Returns:
        matplotlib Figure object
    """
    fig, ax = plt.subplots(figsize=figsize)

    # One fixed color per category, aligned with CATEGORY_ORDER
    # (greens for the open side, reds for the closed side).
    category_colors = ['#27ae60', '#2ecc71', '#e74c3c', '#c0392b']

    # Keep label, count and color together per category so that a missing
    # category cannot shift the colors of the ones that follow it.
    present = [
        (label, distribution[key], color)
        for key, label, color in zip(CATEGORY_ORDER, CATEGORY_LABELS, category_colors)
        if key in distribution
    ]
    labels = [label for label, _, _ in present]
    counts = [count for _, count, _ in present]
    colors = [color for _, _, color in present]

    bars = ax.bar(labels, counts, color=colors)

    # Add count labels on bars
    for bar, count in zip(bars, counts):
        ax.text(
            bar.get_x() + bar.get_width()/2,
            bar.get_height() + 0.5,
            str(count),
            ha='center',
            va='bottom',
            fontsize=10,
        )

    ax.set_xlabel('Openness Category')
    ax.set_ylabel('Count')
    ax.set_title(title)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig

In [None]:
# Smoke test for confusion matrix plotting: a hand-written 4x4 matrix of
# counts (plot_confusion_matrix labels rows as true and columns as predicted).
cm = np.array([
    [10, 2, 0, 0],
    [1, 8, 1, 0],
    [0, 2, 6, 2],
    [0, 0, 1, 9],
])

# No `labels` argument is passed, so the default category labels are used.
fig = plot_confusion_matrix(cm, title="Test Confusion Matrix")
plt.show()
print("Visualization module ready!")

In [None]:
#| hide
# Run nbdev's export step. Presumably this writes the `#| export` cells to the
# module named by `#| default_exp visualization` -- confirm against nbdev docs.
import nbdev; nbdev.nbdev_export()