In [None]:
%autoreload 2

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np

# Compare training and validation loss curves across models.
#
# `models` maps a display name to the path of a pickled history object. The
# plotting code below only relies on each history supporting `in` and
# indexing with the keys 'loss' and 'val_loss', each giving a sized sequence
# of per-epoch values.
models = {
    'nuinsseg-mouse': 'models/nuinsseg_mouse_model_name/history.pkl',
    'nuinsseg-human': 'models/nuinsseg_human_model_name/history.pkl',
    'nuinsseg-human-mouse': 'models/nuinsseg_human_mouse_model_name/history.pkl'
}

histories = {}
for model_name, path in models.items():
    try:
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # history files that were produced by trusted training runs.
        with open(path, 'rb') as f:
            histories[model_name] = pickle.load(f)
    except FileNotFoundError:
        # Best-effort loading: keep going, but say which model is missing so
        # it does not silently vanish from the figure.
        print(f"✗ Could not find {path}")
        continue
    print(f"✓ Loaded {model_name} history")

colors = {'nuinsseg-mouse': '#1f77b4',
          'nuinsseg-human': '#ff7f0e',
          'nuinsseg-human-mouse': '#2ca02c'}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# One entry per panel: (axes, key in the history object, axis label).
# Both panels share identical styling, so they are drawn by the same loop.
loss_panels = (
    (ax1, 'loss', 'Training Loss'),
    (ax2, 'val_loss', 'Validation Loss'),
)

for ax, history_key, ylabel in loss_panels:
    ax.set_title(f'{ylabel} Comparison', fontsize=14, fontweight='bold')

    for model_name, history in histories.items():
        # A history without this key is simply left out of the panel.
        if history_key not in history:
            continue
        values = history[history_key]
        epochs = range(1, len(values) + 1)  # epochs are numbered from 1
        ax.plot(epochs, values,
                label=model_name,
                color=colors.get(model_name, 'black'),
                linewidth=2, alpha=0.8)

    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    # Only build a legend when something was drawn; calling legend() on an
    # axes with no labelled artists emits a warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_yscale('log')

plt.tight_layout()
plt.show()

In [None]:
import json
import matplotlib.pyplot as plt

# Compare evaluation metrics across models, for both the validation and test
# splits, as a function of the IoU threshold.
#
# `eval_files` maps model name -> split name -> path of a JSON results file.
# The plotting code reads two keys from each loaded file: 'iou_thresholds'
# (used as x values) and 'stats' (a sequence of mappings, one per x value,
# indexed by metric name).
eval_files = {
    'nuinsseg-mouse': {
        'val': 'models/nuinsseg_mouse_model_name/validation_evaluation_results.json',
        'test': 'models/nuinsseg_mouse_model_name/test_evaluation_results.json'
    },
    'nuinsseg-human': {
        'val': 'models/nuinsseg_human_model_name/validation_evaluation_results.json',
        'test': 'models/nuinsseg_human_model_name/test_evaluation_results.json'
    },
    'nuinsseg-human-mouse': {
        'val': 'models/nuinsseg_human_mouse_model_name/validation_evaluation_results.json',
        'test': 'models/nuinsseg_human_mouse_model_name/test_evaluation_results.json'
    }
}

eval_results = {}
for model_name, paths in eval_files.items():
    eval_results[model_name] = {}
    for split, path in paths.items():
        try:
            # Explicit UTF-8 so decoding does not depend on the platform's
            # default locale encoding.
            with open(path, 'r', encoding='utf-8') as f:
                eval_results[model_name][split] = json.load(f)
            print(f"✓ Loaded {model_name} {split} results")
        except FileNotFoundError:
            # Best-effort loading: report the missing file and keep going.
            print(f"✗ Could not find {path}")
            continue

colors = {'nuinsseg-mouse': '#1f77b4',
          'nuinsseg-human': '#ff7f0e',
          'nuinsseg-human-mouse': '#2ca02c'}

# Per-split styling: colour identifies the model, line style and marker
# identify the split.
line_styles = {'val': '-', 'test': '--'}
markers = {'val': 'o', 'test': 's'}

metrics = ['precision', 'recall', 'f1', 'accuracy']
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()  # 2x2 grid -> flat sequence, one axes per metric

for i, metric in enumerate(metrics):
    ax = axes[i]
    ax.set_title(f'{metric.title()} Comparison Across Models', fontsize=14, fontweight='bold')

    # Plot each model and split combination
    for model_name, results in eval_results.items():
        # Iterate the styled splits ('val', then 'test') so the split list
        # and the style dictionaries cannot drift apart.
        for split in line_styles:
            split_results = results.get(split)
            if not split_results:
                # Split file was missing or empty: nothing to draw.
                continue

            iou_thresholds = split_results['iou_thresholds']
            stats = split_results['stats']
            values = [stat[metric] for stat in stats]

            label = f"{model_name} ({split.capitalize()})"
            ax.plot(iou_thresholds, values,
                    color=colors.get(model_name, 'black'),
                    linestyle=line_styles[split],
                    marker=markers[split],
                    markersize=4,
                    linewidth=2,
                    alpha=0.8,
                    label=label)

    ax.set_xlabel('IoU Threshold', fontsize=12)
    ax.set_ylabel(metric.title(), fontsize=12)
    # Only build a legend when something was drawn; calling legend() on an
    # axes with no labelled artists emits a warning.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(fontsize=9, loc='best')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([0, 1])  # fixed y-range so all four panels share one scale

plt.tight_layout()
plt.show()