In [3]:
import json
import matplotlib.pyplot as plt
import os
import numpy as np

# --- Configuration ---
# All plots go to a separate directory so plots from the earlier version are not overwritten.
OUTPUT_DIR = "thesis_plots_v2"
FIG_SIZE = (10, 6)  # (width, height) in inches passed to every figure
FONT_SIZE = 12
TITLE_FONT_SIZE = 14
LINE_WIDTH = 2
MARKER_SIZE = 6

# Consistent colors/markers per pruning method. Keys must match the method
# names used in the results JSON; the plotting functions skip any method
# that has no entry in METHOD_COLORS.
METHOD_COLORS = {
    "BNScale": "royalblue",
    "MagnitudeL2": "darkorange",
    "Random": "forestgreen"
}
METHOD_MARKERS = {
    "BNScale": "o",
    "MagnitudeL2": "s",
    "Random": "^"
}

# Linestyles distinguishing models within comparative plots (keyed by model key).
MODEL_LINESTYLES = {
    "MobileNetV2": "-",  # Solid
    "ResNet18": "--",    # Dashed
    "LSTM": ":",         # Dotted
    "MLP": "-."          # Dash-dot
}

# Create the output directory if needed; exist_ok avoids a check-then-create race.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Set global plot styles. Apply the style sheet FIRST so the explicit rcParams
# below always win over anything the style sheet might define.
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams.update({'font.size': FONT_SIZE})

# --- Helper Function to Load JSON ---
def load_json_data(filepath):
    """Load and return the JSON document stored at *filepath*.

    The file is decoded as UTF-8 explicitly so results load identically on
    every platform (the default text encoding is locale-dependent).

    Args:
        filepath: Path (str or path-like) of the JSON file.

    Returns:
        The parsed JSON value (dict/list/str/number/bool/None).

    Raises:
        FileNotFoundError / OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

# --- Individual Model Plotting Function (from previous response, slightly adapted) ---
def plot_individual_model_metric(
    model_data,
    model_name_display,
    model_name_file,
    metric_key,
    y_label,
    is_higher_better=True,
    output_dir=OUTPUT_DIR,
    log_scale_y=False,
    y_multiplier=1.0,
    y_unit_prefix=""
):
    """Plot one metric against pruning ratio for every pruning method of one model.

    One line is drawn per method that has an entry in ``METHOD_COLORS``. The
    x-axis covers every numeric ratio key reported by any method; ratios a
    method lacks, or whose value cannot be converted to ``float``, are drawn
    as gaps (``NaN``). The figure is saved as
    ``<output_dir>/individual_<model_name_file>_<metric_key lowercased>_vs_pruning.png``.

    Args:
        model_data: Mapping ``method name -> {ratio string -> {metric key -> value}}``
            (shape inferred from the lookups below; presumably the loaded results JSON).
        model_name_display: Model name used in the title and warnings.
        model_name_file: Model name used in the output filename.
        metric_key: Metric to read from each ratio entry.
        y_label: Y-axis label and title fragment.
        is_higher_better: Currently unused; kept for interface compatibility.
        output_dir: Directory the PNG is written to (must already exist).
        log_scale_y: Use a logarithmic y-axis.
        y_multiplier: Factor applied to every value before plotting (unit scaling).
        y_unit_prefix: Unit appended to ``y_label`` in parentheses when non-empty.

    Returns:
        Path of the saved PNG file.
    """
    # Collect every ratio key any method reports. A non-numeric key would crash
    # the float-keyed sort, so such keys are reported and left out.
    ratio_values = {}
    bad_keys = set()
    for method_results in model_data.values():
        if not isinstance(method_results, dict):
            continue
        for key in method_results:
            try:
                ratio_values[key] = float(key)
            except (TypeError, ValueError):
                bad_keys.add(key)
    if bad_keys:
        print(f"Warning: Ignoring non-numeric pruning ratio keys {sorted(bad_keys, key=str)} for {model_name_display}")
    ratio_keys = sorted(ratio_values, key=ratio_values.__getitem__)
    ratios = [ratio_values[k] for k in ratio_keys]

    fig, ax = plt.subplots(figsize=FIG_SIZE)
    try:
        for method_name, results in model_data.items():
            if method_name not in METHOD_COLORS:
                print(f"Warning: Method '{method_name}' not in METHOD_COLORS for individual plot: {y_label} of {model_name_display}")
                continue
            if not isinstance(results, dict):
                print(f"Warning: Results for method '{method_name}' of {model_name_display} are not a ratio mapping. Skipping.")
                continue

            # One value per shared ratio so every line aligns with the x-ticks;
            # NaN leaves a gap instead of shifting later points.
            values = []
            for r_str in ratio_keys:
                entry = results.get(r_str)
                if not isinstance(entry, dict) or metric_key not in entry:
                    values.append(np.nan)
                    continue
                try:
                    values.append(float(entry[metric_key]) * y_multiplier)
                except (TypeError, ValueError) as e:
                    print(f"Warning: Could not convert metric '{metric_key}' for {model_name_display}, {method_name} at ratio {r_str}. Error: {e}. Value: {entry.get(metric_key)}")
                    values.append(np.nan)

            ax.plot(
                ratios,
                values,
                label=method_name,
                color=METHOD_COLORS[method_name],
                marker=METHOD_MARKERS.get(method_name, "o"),  # don't crash if a marker is missing
                linewidth=LINE_WIDTH,
                markersize=MARKER_SIZE
            )

        unit_suffix = f" ({y_unit_prefix})" if y_unit_prefix else ""
        ax.set_xlabel("Pruning Ratio")
        ax.set_ylabel(f"{y_label}{unit_suffix}")
        ax.set_title(f"{y_label} vs. Pruning Ratio for {model_name_display}", fontsize=TITLE_FONT_SIZE)
        ax.set_xticks(ratios, labels=[str(r) for r in ratios])

        if log_scale_y:
            ax.set_yscale('log')

        # Only draw a legend when at least one labelled line exists (avoids a
        # matplotlib "No artists with labels" warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(title="Pruning Method")
        fig.tight_layout()

        filename = f"individual_{model_name_file}_{metric_key.lower()}_vs_pruning.png"
        filepath = os.path.join(output_dir, filename)
        fig.savefig(filepath, dpi=300)
        print(f"Saved: {filename}")
    finally:
        # Always release the figure, even when plotting or saving fails, so a
        # long batch of plots does not accumulate open figures.
        plt.close(fig)
    return filepath

# --- Comparative Plotting Function ---
def plot_comparative_group(
    models_data_dict, # {"ModelName1": data1_dict, "ModelName2": data2_dict}
    model_names_in_group_display, # list of {'key': model key, 'display': legend name} dicts
    model_group_name_display,
    model_group_name_file,
    metric_key,
    y_label,
    is_higher_better=True,
    output_dir=OUTPUT_DIR,
    log_scale_y=False,
    y_multiplier=1.0,
    y_unit_prefix=""
):
    """Overlay one metric vs. pruning ratio for several models in a single figure.

    Each (model, method) pair becomes one line. Color and marker identify the
    pruning method (``METHOD_COLORS`` / ``METHOD_MARKERS``). Linestyle identifies
    the model (``MODEL_LINESTYLES``, solid when the model is not listed). The
    x-axis covers the union of numeric ratio keys across all models and
    methods; missing or unconvertible values are drawn as gaps (``NaN``). The
    figure is saved as
    ``<output_dir>/comparative_<model_group_name_file>_<metric_key lowercased>_vs_pruning.png``
    with the legend to the right of the axes.

    Args:
        models_data_dict: Mapping ``model key -> model results``, where each model
            result maps ``method -> {ratio string -> {metric key -> value}}``.
            Falsy entries (models that failed to load) are skipped.
        model_names_in_group_display: List of ``{'key': ..., 'display': ...}``
            dicts giving each model key's legend name; unlisted keys are shown as-is.
        model_group_name_display: Group name used in the title and warnings.
        model_group_name_file: Group name used in the output filename.
        metric_key: Metric to read from each ratio entry.
        y_label: Y-axis label and title fragment.
        is_higher_better: Currently unused; kept for interface compatibility.
        output_dir: Directory the PNG is written to (must already exist).
        log_scale_y: Use a logarithmic y-axis.
        y_multiplier: Factor applied to every value (one shared scale for all models).
        y_unit_prefix: Unit appended to ``y_label`` in parentheses when non-empty.

    Returns:
        Path of the saved PNG file.
    """
    # Legend name per model key; setdefault keeps the first entry for a key,
    # the same result as a first-match search.
    display_names = {}
    for name_entry in model_names_in_group_display or ():
        display_names.setdefault(name_entry['key'], name_entry['display'])

    # Union of ratio keys over every loaded model and method. A non-numeric key
    # would crash the float-keyed sort, so such keys are reported and left out.
    ratio_values = {}
    bad_keys = set()
    for model_data in models_data_dict.values():
        if not model_data:
            continue
        for method_results in model_data.values():
            if not isinstance(method_results, dict):
                continue
            for key in method_results:
                try:
                    ratio_values[key] = float(key)
                except (TypeError, ValueError):
                    bad_keys.add(key)
    if bad_keys:
        print(f"Warning: Ignoring non-numeric pruning ratio keys {sorted(bad_keys, key=str)} for {model_group_name_display}")
    ratio_keys = sorted(ratio_values, key=ratio_values.__getitem__)
    ratios = [ratio_values[k] for k in ratio_keys]

    fig, ax = plt.subplots(figsize=FIG_SIZE)
    try:
        for model_name_key, model_data in models_data_dict.items():
            if not model_data:
                continue  # Skip if model data wasn't loaded

            model_display_name = display_names.get(model_name_key, model_name_key)
            linestyle = MODEL_LINESTYLES.get(model_name_key, "-")  # looked up by model key

            for method_name, results in model_data.items():
                if method_name not in METHOD_COLORS:
                    print(f"Warning: Method '{method_name}' for model '{model_display_name}' not in METHOD_COLORS for comparative plot. Skipping.")
                    continue
                if not isinstance(results, dict):
                    print(f"Warning: Results for method '{method_name}' of model '{model_display_name}' are not a ratio mapping. Skipping.")
                    continue

                # One value per shared ratio so lines align with the x-ticks;
                # NaN leaves a gap for missing points.
                values = []
                for r_str in ratio_keys:
                    entry = results.get(r_str)
                    if not isinstance(entry, dict) or metric_key not in entry:
                        values.append(np.nan)
                        continue
                    try:
                        values.append(float(entry[metric_key]) * y_multiplier)
                    except (TypeError, ValueError) as e:
                        print(f"Warning: Could not convert metric '{metric_key}' for {model_display_name}, {method_name} at ratio {r_str}. Error: {e}. Value: {entry.get(metric_key)}")
                        values.append(np.nan)

                ax.plot(
                    ratios,
                    values,
                    label=f"{model_display_name} - {method_name}",
                    color=METHOD_COLORS[method_name],
                    linestyle=linestyle,
                    marker=METHOD_MARKERS.get(method_name, "o"),  # don't crash if a marker is missing
                    linewidth=LINE_WIDTH,
                    markersize=MARKER_SIZE
                )

        unit_suffix = f" ({y_unit_prefix})" if y_unit_prefix else ""
        ax.set_xlabel("Pruning Ratio")
        ax.set_ylabel(f"{y_label}{unit_suffix}")
        ax.set_title(f"Comparison of {y_label} vs. Pruning Ratio for {model_group_name_display}", fontsize=TITLE_FONT_SIZE)
        ax.set_xticks(ratios, labels=[str(r) for r in ratios])

        if log_scale_y:
            ax.set_yscale('log')

        # Place legend outside the plot (only if something was drawn).
        if ax.get_legend_handles_labels()[0]:
            ax.legend(title="Model - Method", bbox_to_anchor=(1.03, 1), loc='upper left', borderaxespad=0.)
        fig.subplots_adjust(right=0.75)  # leave room on the right for the legend

        filename = f"comparative_{model_group_name_file}_{metric_key.lower()}_vs_pruning.png"
        filepath = os.path.join(output_dir, filename)
        # bbox_inches="tight" expands the saved area to include the outside
        # legend; subplots_adjust alone can still clip long legend labels.
        fig.savefig(filepath, dpi=300, bbox_inches="tight")
        print(f"Saved: {filename}")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    return filepath


# --- Main Script ---
if __name__ == "__main__":
    # Per-model data file, display name, group ("type") and metrics to plot.
    # "multiplier"/"unit_prefix" rescale values, e.g. 1e-6 with "M" plots millions.
    models_config = {
        "MobileNetV2": {
            "file": "mobileNetV2_complete_results.json",
            "display_name": "MobileNetV2",
            "type": "CNN",
            "metrics": {
                "accuracy": {"label": "Accuracy (%)", "higher_better": True},
                "params": {"label": "Parameters", "higher_better": False, "multiplier": 1e-6, "unit_prefix": "M"},
                "macs": {"label": "MACs", "higher_better": False, "multiplier": 1e-6, "unit_prefix": "M"},
                "size_mb": {"label": "Size (MB)", "higher_better": False},
                "loss": {"label": "Loss", "higher_better": False}
            }
        },
        "ResNet18": {
            "file": "resnet18_complete_results.json",
            "display_name": "ResNet-18",
            "type": "CNN",
            "metrics": {
                "accuracy": {"label": "Accuracy (%)", "higher_better": True},
                "params": {"label": "Parameters", "higher_better": False, "multiplier": 1e-6, "unit_prefix": "M"},
                "macs": {"label": "MACs", "higher_better": False, "multiplier": 1e-7, "unit_prefix": "10M"}, # MACs are large
                "size_mb": {"label": "Size (MB)", "higher_better": False},
                "loss": {"label": "Loss", "higher_better": False}
            }
        },
        "LSTM": {
            "file": "lstm_complete_results.json",
            "display_name": "LSTM",
            "type": "TimeSeries",
            "metrics": {
                "mse": {"label": "Mean Squared Error (MSE)", "higher_better": False},
                "mae": {"label": "Mean Absolute Error (MAE)", "higher_better": False},
                "params": {"label": "Parameters", "higher_better": False, "multiplier": 1e-3, "unit_prefix": "K"},
                "macs": {"label": "MACs", "higher_better": False, "multiplier": 1e-6, "unit_prefix": "M"},
                "size_mb": {"label": "Size (MB)", "higher_better": False},
                "loss": {"label": "Loss", "higher_better": False}
            }
        },
        "MLP": {
            "file": "mlp_complete_results.json",
            "display_name": "MLP",
            "type": "TimeSeries", # Assuming used for a time-series like regression task
            "metrics": {
                "mse": {"label": "Mean Squared Error (MSE)", "higher_better": False},
                "mae": {"label": "Mean Absolute Error (MAE)", "higher_better": False},
                "params": {"label": "Parameters", "higher_better": False, "multiplier": 1e-3, "unit_prefix": "K"},
                "macs": {"label": "MACs", "higher_better": False, "multiplier": 1e-3, "unit_prefix": "K"},
                "size_mb": {"label": "Size (MB)", "higher_better": False},
                "loss": {"label": "Loss", "higher_better": False}
            }
        }
    }

    # Load all data; a model whose file cannot be read/parsed is stored as None
    # and skipped by every later step.
    all_model_data_loaded = {}
    for model_key, config in models_config.items():
        try:
            all_model_data_loaded[model_key] = load_json_data(config["file"])
        except FileNotFoundError:
            print(f"ERROR: Data file not found: {config['file']}. Skipping {config['display_name']}.")
            all_model_data_loaded[model_key] = None
        except (OSError, ValueError) as e:
            # OSError: other I/O failures; ValueError covers json.JSONDecodeError
            # and UnicodeDecodeError. Anything else is a bug and should surface.
            print(f"An error occurred while loading {config['file']}: {e}")
            all_model_data_loaded[model_key] = None

    # --- 1. Generate individual plots for each model ---
    print("\n--- Generating Individual Model Plots ---")
    for model_key, config in models_config.items():
        model_data = all_model_data_loaded.get(model_key)
        if model_data is None:
            continue # Skip if data wasn't loaded

        model_name_file = model_key.lower().replace('-', '')
        for metric_key, metric_config in config["metrics"].items():
            plot_individual_model_metric(
                model_data=model_data,
                model_name_display=config["display_name"],
                model_name_file=model_name_file,
                metric_key=metric_key,
                y_label=metric_config["label"],
                is_higher_better=metric_config["higher_better"],
                y_multiplier=metric_config.get("multiplier", 1.0),
                y_unit_prefix=metric_config.get("unit_prefix", "")
            )

    # --- 2./3. Generate comparative plots per model group ---
    comparison_groups = [
        {
            "type": "CNN",
            "header": "\n--- Generating Comparative CNN Plots ---",
            "skip_message": "Skipping comparative CNN plots as data for one or more CNN models is missing.",
            "display": "CNNs",
            "file": "cnns",
            "metrics": ["accuracy", "params", "macs", "size_mb"],
        },
        {
            "type": "TimeSeries",
            "header": "\n--- Generating Comparative Time-Series Plots ---",
            "skip_message": "Skipping comparative Time-Series plots as data for one or more models is missing.",
            "display": "Time-Series Models",
            "file": "timeseries",
            "metrics": ["mse", "mae", "params", "macs", "size_mb"],
        },
    ]

    for group in comparison_groups:
        print(group["header"])
        group_keys = [k for k, v in models_config.items() if v["type"] == group["type"]]
        group_data = {k: all_model_data_loaded[k] for k in group_keys if all_model_data_loaded.get(k)}

        # Only compare when every model of the group loaded, so a comparative
        # plot is never silently missing a model.
        if not group_keys or len(group_data) != len(group_keys):
            print(group["skip_message"])
            continue

        group_display_names = [{'key': k, 'display': models_config[k]['display_name']} for k in group_keys]
        # All curves share one y-axis, so one scale is required. It is taken from
        # the group's first model (MobileNetV2 for CNNs, LSTM for time series);
        # the shared config is only read, never modified.
        ref_metrics = models_config[group_keys[0]]["metrics"]

        for metric_key in group["metrics"]:
            metric_config = ref_metrics[metric_key]
            plot_comparative_group(
                models_data_dict=group_data,
                model_names_in_group_display=group_display_names,
                model_group_name_display=group["display"],
                model_group_name_file=group["file"],
                metric_key=metric_key,
                y_label=metric_config["label"],
                is_higher_better=metric_config["higher_better"],
                y_multiplier=metric_config.get("multiplier", 1.0),
                y_unit_prefix=metric_config.get("unit_prefix", "")
            )

    print(f"\nAll plots saved to '{OUTPUT_DIR}' directory.")


--- Generating Individual Model Plots ---
Saved: individual_mobilenetv2_accuracy_vs_pruning.png
Saved: individual_mobilenetv2_params_vs_pruning.png
Saved: individual_mobilenetv2_macs_vs_pruning.png
Saved: individual_mobilenetv2_size_mb_vs_pruning.png
Saved: individual_mobilenetv2_loss_vs_pruning.png
Saved: individual_resnet18_accuracy_vs_pruning.png
Saved: individual_resnet18_params_vs_pruning.png
Saved: individual_resnet18_macs_vs_pruning.png
Saved: individual_resnet18_size_mb_vs_pruning.png
Saved: individual_resnet18_loss_vs_pruning.png
Saved: individual_lstm_mse_vs_pruning.png
Saved: individual_lstm_mae_vs_pruning.png
Saved: individual_lstm_params_vs_pruning.png
Saved: individual_lstm_macs_vs_pruning.png
Saved: individual_lstm_size_mb_vs_pruning.png
Saved: individual_lstm_loss_vs_pruning.png
Saved: individual_mlp_mse_vs_pruning.png
Saved: individual_mlp_mae_vs_pruning.png
Saved: individual_mlp_params_vs_pruning.png
Saved: individual_mlp_macs_vs_pruning.png
Saved: individual_mlp_siz