In [1]:
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc

In [31]:
class TrainingEvaluator:
    """
    Loads a pickled training-metrics bundle and offers quick plots and a summary.

    The pickle is read as a dict with these (optional) keys:
      - 'history': mapping of metric name -> per-epoch values, optionally
        paired with a 'val_<metric>' entry for the validation curve.
      - 'gpu_memory_log': list of {"epoch": ..., "gpu_memory_used_mb": ...}.
      - 'epoch_times': list of {"epoch": ..., "time_seconds": ...}.
      - 'total_training_time_seconds': total training time in seconds.
    Missing keys fall back to an empty container (or None for the total time).
    """

    def __init__(self, filepath):
        """
        Loads training history and logs from a .pkl file.

        Args:
            filepath: Path to the pickle file produced by the training run.

        Raises:
            FileNotFoundError: If `filepath` does not exist.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"No such file: {filepath}")

        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only load metric files produced by your own training runs.
        with open(filepath, 'rb') as f:
            data = pickle.load(f)

        self.history = data.get('history', {})
        self.gpu_memory_log = data.get('gpu_memory_log', [])
        self.epoch_times = data.get('epoch_times', [])
        self.total_time = data.get('total_training_time_seconds', None)

    def plot_metric(self, metric):
        """
        Plots a given training metric and, if present, its 'val_' counterpart.

        Prints a message and returns without plotting if `metric` is not a
        key of the training history.
        """
        if metric not in self.history:
            print(f"Metric '{metric}' not found in training history.")
            return

        # Fresh figure so successive calls (e.g. from plot_all_metrics) never
        # draw onto the previous metric's axes in interactive backends.
        plt.figure()
        plt.plot(self.history[metric], label=f'Train {metric}')
        val_key = f'val_{metric}'
        if val_key in self.history:
            plt.plot(self.history[val_key], label=f'Validation {metric}')

        plt.title(f"{metric.capitalize()} Over Epochs")
        plt.xlabel("Epoch")
        plt.ylabel(metric.capitalize())
        plt.legend()
        plt.grid(True)
        plt.show()

    def plot_all_metrics(self):
        """
        Plots every metric in the training history whose name does not start
        with 'val_'; validation curves are drawn alongside their train metric.
        """
        for metric in self.history.keys():
            if not metric.startswith('val_'):  # plot each metric once
                self.plot_metric(metric)

    def plot_gpu_memory_usage(self):
        """
        Plots GPU memory usage (MB) per epoch from `gpu_memory_log`.
        """
        if not self.gpu_memory_log:
            print("No GPU memory log available.")
            return

        epochs = [x["epoch"] for x in self.gpu_memory_log]
        mem_mb = [x["gpu_memory_used_mb"] for x in self.gpu_memory_log]

        plt.figure()
        plt.plot(epochs, mem_mb, marker='o')
        plt.title("GPU Memory Usage per Epoch (MB)")
        plt.xlabel("Epoch")
        plt.ylabel("Memory (MB)")
        plt.grid(True)
        plt.show()

    def plot_epoch_times(self):
        """
        Plots the duration (seconds) of each training epoch from `epoch_times`.
        """
        if not self.epoch_times:
            print("No epoch time data available.")
            return

        epochs = [x["epoch"] for x in self.epoch_times]
        times = [x["time_seconds"] for x in self.epoch_times]

        plt.figure()
        plt.plot(epochs, times, marker='o')
        plt.title("Epoch Training Time (Seconds)")
        plt.xlabel("Epoch")
        plt.ylabel("Time (s)")
        plt.grid(True)
        plt.show()

    def plot_auc_roc(self, y_true, y_scores):
        """
        Plots the ROC curve and its AUC for true labels and predicted scores.

        Args:
            y_true: Ground-truth labels, as accepted by sklearn's `roc_curve`.
            y_scores: Predicted scores/probabilities for the positive class.
        """
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)

        plt.figure()
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
        # Diagonal = chance-level classifier, for reference.
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic')
        plt.legend(loc='lower right')
        plt.grid(True)
        plt.show()

    def summary(self):
        """
        Prints a basic summary: tracked metrics, total training time, peak GPU
        memory and average epoch time (each line only when data is available).
        """
        print("=== Training Summary ===")
        print(f"Tracked metrics: {list(self.history.keys())}")
        # `is not None` so a recorded total of 0 seconds is still reported.
        if self.total_time is not None:
            mins = round(self.total_time / 60, 2)
            print(f"Total training time: {mins} minutes")
        if self.gpu_memory_log:
            # Log entries are dicts (see plot_gpu_memory_usage); max() over the
            # dicts themselves raises TypeError, so compare the MB values.
            # Bare numeric entries are accepted too, in case a log stores them.
            peak_mb = max(
                x["gpu_memory_used_mb"] if isinstance(x, dict) else x
                for x in self.gpu_memory_log
            )
            print(f"Max GPU memory used: {peak_mb} MB")
        if self.epoch_times:
            avg_time = sum(t["time_seconds"] for t in self.epoch_times) / len(self.epoch_times)
            print(f"Average epoch time: {avg_time:.2f} seconds")
        print("========================")

    def save_plot(self, plt_obj, filename, directory="plots"):
        """
        Saves the current figure of `plt_obj` to `directory/filename`.

        Args:
            plt_obj: Object exposing `savefig` (the pyplot module or a Figure).
            filename: File name, including extension (determines the format).
            directory: Output directory; created if it does not exist.

        NOTE(review): the plot_* methods above call plt.show(); in the inline
        notebook backend that typically closes the figure, so calling this
        afterwards with `plt` may save an empty canvas — confirm in your setup.
        """
        os.makedirs(directory, exist_ok=True)
        filepath = os.path.join(directory, filename)
        plt_obj.savefig(filepath)
        print(f"Plot saved: {filepath}")

In [32]:
# Pickled metrics bundle from the training run (the directory name suggests
# full fine-tuning on the small set). The path is relative to the notebook's
# working directory.
METRICS_PATH = '../training/training_metrics/full_fine_tuning_set_small/full_fine_tuning_set_small.pkl'
evaluator = TrainingEvaluator(METRICS_PATH)

TypeError: TrainingEvaluator.plot_auc_roc() missing 2 required positional arguments: 'y_true' and 'y_scores'