In [None]:
import argparse
import os
import sys
import yaml
import torch
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from thop import profile
from tqdm import tqdm
import csv

sys.path.append("../")

from networks.unet import UNet
from networks.attention_unet import AttnUNet
import src.data_utils as data_utils
from src.metrics import dice_loss, dice_coefficient

In [None]:
# Run selection: every artifact path used below is derived from this directory.
run_dir = "../runs/unet_0725_1746"

# Files read by the following cells: the YAML run summary and the per-epoch
# metrics CSV, both located directly inside the run directory.
yaml_path, csv_path = (
    os.path.join(run_dir, artifact_name)
    for artifact_name in ('summary.yaml', 'metrics.csv')
)

In [None]:
# Load the run summary (YAML) into `config`.
with open(yaml_path, 'r') as f:
    config = yaml.safe_load(f)

# Load the per-epoch metrics as a list of row dicts, one per CSV line.
# csv.DictReader performs no type conversion, so every value in `history`
# is a string and must be converted before any numeric use.
# newline='' hands newline handling to the csv module, as its documentation
# requires for file objects given to a reader.
with open(csv_path, 'r', newline='') as f:
    history = list(csv.DictReader(f))

In [None]:
# Select the epoch with the highest validation Dice score.
# The rows in `history` come from csv.DictReader, so every field is a string.
# The key therefore converts to float: comparing the raw strings would rank
# them lexicographically (e.g. '9.5e-05' > '0.93') and could pick the wrong
# epoch.
if not history:
    raise ValueError("metrics history is empty; cannot select a best epoch")
best_row = max(history, key=lambda row: float(row['val_dice']))

# Typed copy of the best row (ints/floats instead of strings); the plotting
# cell reads 'epoch' and 'val_dice' from it.
best_epoch_metrics = {
    'epoch': int(best_row['epoch']),
    'val_dice': float(best_row['val_dice']),
    'train_loss': float(best_row['train_loss']),
    'val_loss': float(best_row['val_loss'])
}

In [None]:
from matplotlib import ticker

# Per-epoch metrics as a DataFrame (same CSV the best-epoch cell reads).
metrics_df = pd.read_csv(csv_path)

# Plot train/val loss (left y-axis) and validation Dice (right y-axis) against
# the epoch, and mark the best epoch. The figure is only displayed, not saved.
# NOTE: plt.style.use changes the style globally for the rest of the kernel.
plt.style.use('petroff10')

x   = metrics_df['epoch']
y11 = metrics_df['train_loss']
y12 = metrics_df['val_loss']
y2  = metrics_df['val_dice']


fig, ax1 = plt.subplots(figsize=(6, 4))

# Left y-axis: train loss (solid) and validation loss (dashed)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss', color='brown', fontsize=10)
ax1.plot(x, y11, '-', color='brown', label='Train Loss')
ax1.plot(x, y12, '--', color='brown', label='Val Loss')
ax1.tick_params(axis='y', labelcolor='brown')
ax1.yaxis.set_major_locator(ticker.MaxNLocator(nbins=8, prune=None)) # Keep nbins consistent
ax1.set_ylim(bottom=0)

# Integer ticks on the x-axis, since epochs are whole numbers
ax1.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
# Clamp the x-range to the recorded epochs. Skipped for a single-epoch run,
# where identical left/right limits would be degenerate.
if x.min() < x.max():
    ax1.set_xlim(left=x.min(), right=x.max())

# Right y-axis: validation Dice, sharing the x-axis with the loss curves
ax2 = ax1.twinx()
ax2.set_ylabel('Dice Coefficient', color='darkblue', fontsize=10)
ax2.plot(x, y2, '-', color='darkblue', label='Val Dice')
ax2.tick_params(axis='y', labelcolor='darkblue')

ax2.yaxis.set_major_locator(ticker.MaxNLocator(nbins=8, prune=None)) # Keep nbins consistent
ax2.set_ylim(bottom=0, top=1)


# Mark the best epoch: a vertical line at the epoch and a star at its Dice value
ax2.axvline(x=best_epoch_metrics['epoch'], color='seagreen', linestyle=':', linewidth=2, label=f"Best Epoch ({best_epoch_metrics['epoch']})")
ax2.scatter(best_epoch_metrics['epoch'], best_epoch_metrics['val_dice'], color='sandybrown', marker='*', s=230, zorder=5, label=f"Best Dice: {best_epoch_metrics['val_dice']:.4f}")


# --- Consolidate Legends ---
# Collect handles and labels from both axes so a single legend covers all curves
lines1, labels1 = ax1.get_legend_handles_labels()
lines2, labels2 = ax2.get_legend_handles_labels()

lines = lines1 + lines2
labels = labels1 + labels2

# Colour the outer spines to match the axis they belong to. They are set on
# ax2 because the twin axes is the one created last, on top of ax1.
ax2.spines['left'].set_color('brown')
ax2.spines['left'].set_linewidth(1.5)
ax2.spines['right'].set_color('darkblue')
ax2.spines['right'].set_linewidth(1.5)

# Outward ticks on both axes. Set explicitly per axes: plt.tick_params would
# only act on the current axes (ax2).
ax1.tick_params(direction="out")
ax2.tick_params(direction="out")

# Single legend on ax1 holding the entries of both axes.
# markerscale shrinks the star in the legend relative to the s=230 star drawn
# on the plot; the line entries carry no markers, so they are unaffected.
legend = ax1.legend(lines, labels, fontsize=9, loc='upper left', bbox_to_anchor=(0.08, 1), frameon=True, fancybox=False, framealpha=0, markerscale=0.5) # Adjust bbox_to_anchor if needed


ax1.set_title(f"Training Metrics: {config['model_type']} on {config['dataset']}", fontsize=14, pad=10)
fig.tight_layout()
plt.show()
# Release this specific figure rather than whichever one is current
plt.close(fig)