In [2]:
"""
Simple scaling plots: one figure per metric (MMD, Wasserstein, energy distance)
as a function of the data dimension.
"""
import glob
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from pathlib import Path

# Configuration
# Glob (relative to the working directory) for the per-dimension result CSVs.
# load_data() reads the dimension from the "dim<N>" part of each matched path.
CSV_GLOB = "comparison/LowRankGaussianMixtureDataset_dim*_*.csv"
# Dimensions to include; load_data() skips files for any other dimension.
DIMS = [4, 8, 16, 32, 64, 128, 256, 512]
# Sampling step counts to show (labelled "NFE" in the legend and the table).
NFES = [8, 32, 128]

# Display names of the methods, in plotting / table order.
METHODS = ["CFM", "DFM", "Count Bridge"]
# Maps values of the CSV "bridge_type" column to the display names above.
METHOD_MAP = {
    "CFMBridge": "CFM",
    "DiscreteFlowBridge": "DFM", 
    "SkellamBridge": "Count Bridge"
}

# Line colour per method and line opacity per step count (used by plot()).
COLORS = {"CFM": "#1f77b4", "DFM": "#ff7f0e", "Count Bridge": "#2ca02c"}
ALPHAS = {8: 0.45, 32: 0.75, 128: 1.00}

# Available metrics: key -> (mean column, std column, label for axis/table).
# NOTE(review): "energy" reads the energy_distance_* columns but is labelled
# "EMD" (described elsewhere in this notebook as Earth Mover's Distance), and
# "wasserstein" is labelled $W_2$ while the cell docstring mentions W1 --
# confirm which labels match what the CSVs actually contain.
METRICS = {
    "mmd_rbf": ("mmd_rbf_mean", "mmd_rbf_std", "MMD"),
    "wasserstein": ("wasserstein_distance_mean", "wasserstein_distance_std", "$W_2$"),
    "energy": ("energy_distance_mean", "energy_distance_std", "EMD")
}

def load_data():
    """Load the result CSVs matched by CSV_GLOB, keyed by dimension.

    The dimension is parsed from the ``dim<N>`` part of each file *name*.
    Files whose name carries no such token are skipped (the glob
    ``dim*_*`` can match them), as are dimensions not listed in DIMS.

    For every loaded file:
      * ``SkellamBridge`` rows are kept only when ``model_type`` is
        ``"EnergyScoreLoss"``; rows of other bridge types are all kept.
      * a ``Method`` column is added by mapping ``bridge_type`` through
        METHOD_MAP, and only rows whose method is in METHODS are returned.

    Returns:
        dict[int, pandas.DataFrame]: filtered frame per dimension, inserted
        in ascending dimension order.
    """
    dim_pattern = re.compile(r"dim(\d+)")

    # Parse the dimension once per file, from the base name only, so that a
    # directory component can never supply the match.
    selected = []
    for file in glob.glob(CSV_GLOB):
        match = dim_pattern.search(Path(file).name)
        if match is None:
            continue
        dim = int(match.group(1))
        if dim in DIMS:
            selected.append((dim, file))

    # Sort by (dimension, path): glob order is OS-dependent, so this makes the
    # "last file wins" rule below deterministic when a dimension has several
    # files. NOTE(review): confirm that overwriting (rather than merging)
    # duplicate dimensions is what is wanted.
    selected.sort()

    data = {}
    for dim, file in selected:
        df = pd.read_csv(file)

        # Keep only Skellam Energy for Count Bridge
        is_skellam = df["bridge_type"] == "SkellamBridge"
        df = pd.concat([
            df[~is_skellam],
            df[is_skellam & (df["model_type"] == "EnergyScoreLoss")],
        ])

        df["Method"] = df["bridge_type"].map(METHOD_MAP)
        data[dim] = df[df["Method"].isin(METHODS)]

    return data

def build_series(data, metric="mmd_rbf"):
    """Collect, per (method, NFE) pair, the metric values across dimensions.

    Args:
        data: Mapping from dimension to a DataFrame with a "Method" column,
            the metric's mean/std columns and either an "n_steps" or an
            "n_sampling_steps_mean" column.
        metric: Key into METRICS selecting the mean and std columns.

    Returns:
        dict mapping (method, nfe) to a tuple of three numpy arrays:
        dimensions, means and standard deviations. Dimensions are visited in
        ascending order; a dimension with no row for the pair is omitted, and
        when several rows match only the first one is used.
    """
    mean_col, std_col, _ = METRICS[metric]

    def gather(method, nfe):
        dims, means, stds = [], [], []
        for dim in sorted(data.keys()):
            frame = data[dim]
            subset = frame[frame["Method"] == method]

            # Step count comes from "n_steps" when the file has it, otherwise
            # from the rounded mean number of sampling steps.
            if "n_steps" in subset.columns:
                steps = subset["n_steps"].astype(int)
            else:
                steps = subset["n_sampling_steps_mean"].round().astype(int)

            matches = subset[steps == nfe]
            if matches.empty:
                continue

            first = matches.iloc[0]
            dims.append(dim)
            means.append(first[mean_col])
            stds.append(first[std_col])
        return np.array(dims), np.array(means), np.array(stds)

    return {
        (method, nfe): gather(method, nfe)
        for method in METHODS
        for nfe in NFES
    }

def plot(series, metric="mmd_rbf"):
    """Draw metric-vs-dimension curves for every (method, NFE) pair and save them.

    Args:
        series: Mapping from (method, nfe) to a (dims, means, stds) tuple of
            arrays, with an entry for every pair in METHODS x NFES.
        metric: Key into METRICS; selects the y-axis label and the file name.

    Writes ``figs/scaling_<metric>.png`` (creating ``figs/`` if needed), closes
    the figure and prints the saved path.
    """
    metric_name = METRICS[metric][2]
    Path("figs").mkdir(exist_ok=True)

    fig, ax = plt.subplots(figsize=(7.0, 2.4), dpi=300)
    pairs = [(m, nfe) for m in METHODS for nfe in NFES]

    def dim_is_plotted(d):
        # True when at least one non-empty series contains this dimension.
        for pair in pairs:
            dims = series[pair][0]
            if len(dims) > 0 and d in dims:
                return True
        return False

    # Axes: log2 x-axis with one tick per dimension that actually has data.
    ax.set_xscale("log", base=2)
    ax.set_xticks([d for d in DIMS if dim_is_plotted(d)])
    ax.set_xticklabels([str(d) for d in ax.get_xticks()])
    ax.grid(axis="y", linewidth=0.6, alpha=0.35)
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.set_xlabel("Dimension $d$")
    ax.set_ylabel(f"{metric_name} " + r"($\downarrow$)")

    # Curves: colour encodes the method, opacity the step count; the largest
    # step count is drawn thicker and with a slightly stronger error band.
    for method, nfe in pairs:
        x, y, yerr = series[(method, nfe)]
        if len(x) == 0:
            continue

        colour = COLORS[method]
        line_alpha = ALPHAS[nfe]
        is_largest = nfe == max(NFES)

        ax.plot(x, y, color=colour, alpha=line_alpha,
                linewidth=2.0 if is_largest else 1.5,
                marker="o", markersize=3.5)
        ax.fill_between(x, y - yerr, y + yerr, color=colour,
                        alpha=0.10 if is_largest else 0.06)

    # Leave head-room, then place two legends above the axes: methods on the
    # left, step counts on the right.
    fig.subplots_adjust(top=0.75)

    legend_style = dict(frameon=False, handlelength=1.2,
                        columnspacing=0.8, handletextpad=0.5)
    method_handles = [
        Line2D([0], [0], color=COLORS[m], linewidth=2.0, marker="o",
               markersize=4, label=m)
        for m in METHODS
    ]
    nfe_handles = [
        Line2D([0], [0], color="black", alpha=ALPHAS[n], linewidth=2.0,
               label=f"NFE={n}")
        for n in NFES
    ]

    method_legend = ax.legend(handles=method_handles, title="Method",
                              ncol=len(METHODS), loc="upper left",
                              bbox_to_anchor=(0.00, 1.35), **legend_style)
    ax.legend(handles=nfe_handles, title="Steps", ncol=len(NFES),
              loc="upper right", bbox_to_anchor=(1.00, 1.35), **legend_style)
    # The second ax.legend() call replaces the first, so re-attach it.
    ax.add_artist(method_legend)

    fig.tight_layout()
    filename = f"figs/scaling_{metric}.png"
    fig.savefig(filename, bbox_inches="tight")
    plt.close(fig)
    print(f"Plot saved: {filename}")

def main(metric="mmd_rbf"):
    """Load the CSV data, build the series for `metric` and save its figure."""
    plot(build_series(load_data(), metric), metric)

def plot_all_metrics():
    """Save one figure per entry of METRICS, loading the CSV data only once."""
    loaded = load_data()
    for metric in METRICS:
        print(f"Plotting {metric}...")
        plot(build_series(loaded, metric), metric)

# Render and save the scaling figure for every metric in METRICS.
plot_all_metrics()


Plotting mmd_rbf...
Plot saved: figs/scaling_mmd_rbf.png
Plotting wasserstein...
Plot saved: figs/scaling_wasserstein.png
Plotting energy...
Plot saved: figs/scaling_energy.png


In [8]:
def _rows_at_nfe(method_data, nfe):
    """Return the rows of `method_data` whose sampling-step count equals `nfe`.

    The step count is read from ``n_steps`` when that column exists, otherwise
    from the rounded ``n_sampling_steps_mean``. Rows with a missing step count
    never match (a bare ``astype(int)`` would raise on NaN).
    """
    if "n_steps" in method_data.columns:
        steps = method_data["n_steps"]
    else:
        steps = method_data["n_sampling_steps_mean"].round()
    known = steps.notna()
    return method_data[known & (steps.where(known, 0).astype(int) == nfe)]


def _format_mean_std(mean_val, std_val):
    """Format ``mean $\\pm$ std`` with a precision chosen from the mean's magnitude."""
    # Use the magnitude so that negative means are not all pushed into
    # scientific notation by the "< 0.001" test.
    magnitude = abs(mean_val)
    if magnitude < 0.001:
        spec = ".2e"
    elif magnitude < 0.01:
        spec = ".4f"
    elif magnitude < 1:
        spec = ".3f"
    else:
        spec = ".2f"
    # LaTeX math \pm instead of a raw "±" keeps the .tex file pure ASCII.
    return f"{mean_val:{spec}} $\\pm$ {std_val:{spec}}"


def _best_method_per_metric(df):
    """Map each METRICS key to the method with the lowest mean in `df`.

    The minimum is taken over all NFES. NaN means are ignored (every
    comparison against NaN is false). Ties go to the method listed first in
    METHODS. The value is "" when no method has a usable row.
    """
    best = {}
    for metric_key, (mean_col, _, _) in METRICS.items():
        best_val, best_method = float("inf"), ""
        for method in METHODS:
            method_data = df[df["Method"] == method]
            for nfe in NFES:
                matches = _rows_at_nfe(method_data, nfe)
                if matches.empty:
                    continue
                value = matches.iloc[0][mean_col]
                if value < best_val:
                    best_val, best_method = value, method
        best[metric_key] = best_method
    return best


def create_latex_table(save_to_file=True):
    """Build a LaTeX table with one row per (dimension, method, NFE).

    Columns are Dim, Method, NFE and one ``mean $\\pm$ std`` column per entry
    of METRICS. For each dimension and metric, every row of the method that
    reaches the lowest mean (over all NFE) is set in bold. Combinations with
    no data are shown as ``--``. The generated markup uses ``\\multirow`` and
    the ``\\toprule``/``\\midrule``/``\\bottomrule`` rules, so the including
    document needs packages that provide them.

    Args:
        save_to_file: When true, write ``tables/scaling_comprehensive.tex``
            (creating ``tables/`` if needed) and print the path; otherwise
            print the table to stdout without touching the file system.
    """
    data = load_data()
    available_dims = sorted(d for d in DIMS if d in data)

    # Dim + Method + NFE + one column per metric
    n_columns = 3 + len(METRICS)
    colspec = "lll" + "c" * len(METRICS)

    lines = [
        "\\begin{table}[h!]",
        "\\centering",
        "\\tiny",
        "\\caption{Performance Comparison Across Dimensions and Methods}",
        "\\label{tab:scaling_performance}",
        f"\\begin{{tabular}}{{{colspec}}}",
        "\\toprule",
    ]

    header = ["Dim", "Method", "NFE"] + [label for _, _, label in METRICS.values()]
    lines.append(" & ".join(header) + " \\\\")
    lines.append("\\midrule")

    for dim_idx, dim in enumerate(available_dims):
        df = data[dim]
        best_by_metric = _best_method_per_metric(df)

        for i, method in enumerate(METHODS):
            method_data = df[df["Method"] == method]

            for j, nfe in enumerate(NFES):
                # The dimension label spans the whole group and the method
                # label spans its NFE rows; the other rows leave those cells empty.
                if i == 0 and j == 0:
                    dim_cell = f"\\multirow{{{len(METHODS) * len(NFES)}}}{{*}}{{{dim}}}"
                else:
                    dim_cell = ""
                if j == 0:
                    method_cell = f"\\multirow{{{len(NFES)}}}{{*}}{{{method}}}"
                else:
                    method_cell = ""
                row = [dim_cell, method_cell, str(nfe)]

                matches = _rows_at_nfe(method_data, nfe)
                for metric_key, (mean_col, std_col, _) in METRICS.items():
                    if matches.empty:
                        row.append("--")
                        continue
                    record = matches.iloc[0]
                    cell = _format_mean_std(record[mean_col], record[std_col])
                    if method == best_by_metric[metric_key]:
                        cell = f"\\textbf{{{cell}}}"
                    row.append(cell)

                lines.append(" & ".join(row) + " \\\\")

            # Partial rule between methods inside one dimension group
            if i < len(METHODS) - 1:
                lines.append(f"\\cline{{2-{n_columns}}}")

        # Full rule between dimension groups
        if dim_idx < len(available_dims) - 1:
            lines.append("\\hline")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table}")
    lines.append("")
    lines.append("% Note: Values show mean +/- standard deviation. Bold marks the method with the lowest mean per metric at each dimension (over all NFE).")
    # NOTE(review): the "EMD" column is filled from energy_distance_* columns;
    # confirm that "Earth Mover's Distance" is the right expansion.
    lines.append("% Metrics: MMD = Maximum Mean Discrepancy, W_2 = 2-Wasserstein Distance, EMD = Earth Mover's Distance")

    if save_to_file:
        tables_dir = Path("tables")
        tables_dir.mkdir(exist_ok=True)
        filename = tables_dir / "scaling_comprehensive.tex"
        # Explicit encoding: the default is platform-dependent.
        with open(filename, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        print(f"Comprehensive scaling table saved to: {filename}")
    else:
        print("Comprehensive Scaling Table:")
        print("=" * 50)
        for line in lines:
            print(line)

def create_compact_latex_table(save_to_file=True):
    """Build a compact LaTeX table naming the best method per dimension and metric.

    Each row is one dimension. Each metric cell shows, in bold, the method
    with the lowest mean over all NFES, followed by that value in
    parentheses, or ``--`` when no method has data. NaN means are ignored and
    ties go to the method listed first in METHODS. The column spec and header
    are derived from METRICS, so they stay in step with the metric list.

    Args:
        save_to_file: When true, write ``tables/best_results.tex`` (creating
            ``tables/`` if needed) and print the path; otherwise print the
            table to stdout without touching the file system.
    """
    data = load_data()
    available_dims = sorted(d for d in DIMS if d in data)

    def rows_at_nfe(method_data, nfe):
        # Rows whose step count equals nfe; a missing step count never matches
        # (a bare astype(int) would raise on NaN).
        if "n_steps" in method_data.columns:
            steps = method_data["n_steps"]
        else:
            steps = method_data["n_sampling_steps_mean"].round()
        known = steps.notna()
        return method_data[known & (steps.where(known, 0).astype(int) == nfe)]

    labels = [label for _, _, label in METRICS.values()]
    lines = [
        "\\begin{table}[h!]",
        "\\centering",
        "\\caption{Best performing method per dimension (lowest values)}",
        "\\label{tab:best_performance}",
        "\\begin{tabular}{l" + "c" * len(METRICS) + "}",
        "\\toprule",
        " & ".join(["Dimension"] + labels) + " \\\\",
        "\\midrule",
    ]

    for dim in available_dims:
        df = data[dim]
        row = [str(dim)]

        for mean_col, _, _ in METRICS.values():
            best_val, best_method = float("inf"), ""

            # Compare value by value: a NaN never wins a "<" comparison, so it
            # cannot hide a method's valid results the way min() over a list
            # that starts with NaN does.
            for method in METHODS:
                method_data = df[df["Method"] == method]
                for nfe in NFES:
                    matches = rows_at_nfe(method_data, nfe)
                    if matches.empty:
                        continue
                    value = matches.iloc[0][mean_col]
                    if value < best_val:
                        best_val, best_method = value, method

            if not best_method:
                row.append("--")
                continue

            # Choose the format from the magnitude so negative values are not
            # all forced into scientific notation.
            spec = ".2e" if abs(best_val) < 0.001 else ".3f"
            row.append(f"\\textbf{{{best_method}}} ({best_val:{spec}})")

        lines.append(" & ".join(row) + " \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\end{table}")

    if save_to_file:
        tables_dir = Path("tables")
        tables_dir.mkdir(exist_ok=True)
        filename = tables_dir / "best_results.tex"
        with open(filename, "w", encoding="utf-8") as f:
            f.write("\n".join(lines))
        print(f"Best results table saved to: {filename}")
    else:
        for line in lines:
            print(line)

# Generate and save the comprehensive table.
# NOTE(review): create_compact_latex_table() is defined above but is not
# called here -- confirm whether the compact table should be written as well.
print("Creating LaTeX tables and saving to files...")
create_latex_table(save_to_file=True)

Creating LaTeX tables and saving to files...
Comprehensive scaling table saved to: tables/scaling_comprehensive.tex
