In [None]:
# Mount Google Drive at /drive so files stored in Drive are reachable from
# this Colab runtime.
from google.colab import drive
drive.mount('/drive')
# Symlink creation is left disabled here; the next cell runs the same command.
# !ln -s "/drive/MyDrive/LeafDisease" "/content/LeafDisease"

Mounted at /drive


In [None]:
!ln -s "/drive/MyDrive/LeafDisease" "/content/LeafDisease"

In [None]:
# Disabled: would execute the dataset setup notebook from the linked folder.
# %run "/content/LeafDisease/dataset/datasetup.ipynb"
# List the working directory to confirm the LeafDisease link is present.
!ls

LeafDisease  sample_data


In [3]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc, f1_score, classification_report, accuracy_score
import pickle
import numpy as np
from google.colab import files
import pandas as pd
import seaborn as sns


In [4]:
class MultiClassClassifierTrainer:
    """Training, evaluation and plotting helpers for multi-class PyTorch classifiers.

    Targets are treated as integer class indices ``0 .. num_classes - 1`` in the
    same order as ``class_names``: predictions are the index of the largest
    model output per row, and the ROC code compares targets against the output
    column index.

    Args:
        device: Device that input batches (and, in ``compare_models``, the
            models) are moved to via ``.to(device)``.
        class_names: Sequence of display names, one per class.
    """

    def __init__(self, device, class_names):
        self.device = device
        self.class_names = class_names
        self.num_classes = len(class_names)

    def _collect_outputs(self, model, dataloader, want_probs=False):
        """Run ``model`` over ``dataloader`` without gradients and gather results.

        Args:
            model: Module already located on ``self.device``.
            dataloader: Iterable yielding ``(inputs, labels)`` batches.
            want_probs: When True, return per-class softmax probabilities;
                otherwise return the predicted class index per sample.

        Returns:
            Tuple ``(y_true, y_out)`` of NumPy arrays.
        """
        model.eval()
        y_true, y_out = [], []

        with torch.no_grad():
            for inputs, labels in dataloader:
                outputs = model(inputs.to(self.device))
                if want_probs:
                    batch_out = torch.softmax(outputs, dim=1)
                else:
                    _, batch_out = torch.max(outputs, 1)
                y_true.extend(labels.cpu().numpy())
                y_out.extend(batch_out.cpu().numpy())

        return np.array(y_true), np.array(y_out)

    def plot_history(self, history, model_name="model"):
        """Plot loss, accuracy and macro-F1 curves from a training history.

        Args:
            history: Mapping with the keys ``train_loss``, ``val_loss``,
                ``train_acc``, ``val_acc`` and ``val_f1_macro``, each a
                per-epoch sequence (the structure built by ``train``).
            model_name: Name shown in the figure title.
        """
        fig = plt.figure(figsize=(18, 6))

        # Loss plot
        plt.subplot(1, 3, 1)
        plt.plot(history["train_loss"], label="Train Loss")
        plt.plot(history["val_loss"], label="Val Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Loss Curve")
        plt.legend()

        # Accuracy plot
        plt.subplot(1, 3, 2)
        plt.plot(history["train_acc"], label="Train Acc")
        plt.plot(history["val_acc"], label="Val Acc")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.title("Accuracy Curve")
        plt.legend()

        # F1 Score plot (macro average for multi-class)
        plt.subplot(1, 3, 3)
        plt.plot(history["val_f1_macro"], label="Val F1 (Macro)", color='purple')
        plt.xlabel("Epoch")
        plt.ylabel("F1 Score")
        plt.title("F1 Score Curve (Macro Average)")
        plt.legend()

        fig.suptitle(f"Training History - {model_name}")
        # Reserve a strip at the top so the figure title does not overlap the
        # per-panel titles.
        plt.tight_layout(rect=(0, 0, 1, 0.95))
        plt.show()

    def eval_plot(self, y_true, y_pred, model_name="model"):
        """Print summary metrics and plot the confusion matrix.

        Args:
            y_true: Ground-truth class indices.
            y_pred: Predicted class indices.
            model_name: Name shown in the plot title.
        """
        # Pin the label set so the matrix always has one row/column per entry
        # in class_names, even if a class is missing from this evaluation set.
        label_ids = list(range(self.num_classes))
        conf_matrix = confusion_matrix(y_true, y_pred, labels=label_ids)

        print("Confusion Matrix:\n", conf_matrix)
        print(f"Overall Accuracy: {accuracy_score(y_true, y_pred):.4f}")
        print(f"Macro F1: {f1_score(y_true, y_pred, average='macro'):.4f}")

        disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix,
                                      display_labels=self.class_names)
        fig, ax = plt.subplots(figsize=(12, 10))
        disp.plot(ax=ax, cmap='Blues', values_format='d', xticks_rotation=45)
        plt.title(f"Confusion Matrix - {model_name}")
        plt.tight_layout()
        plt.show()

    def plot_roc(self, model, dataloader, model_name="model"):
        """Plot one-vs-rest ROC curves, one per class.

        Classes that have no positive or no negative samples in ``dataloader``
        are skipped (with a printed note) because their ROC curve is undefined.

        Args:
            model: Module already located on ``self.device``.
            dataloader: Iterable yielding ``(inputs, labels)`` batches.
            model_name: Name shown in the plot title.
        """
        y_true, y_probs = self._collect_outputs(model, dataloader, want_probs=True)

        if y_true.size == 0:
            print("No samples available; skipping ROC plot")
            return

        # Plot ROC for each class
        plt.figure(figsize=(10, 8))
        for i in range(self.num_classes):
            is_positive = (y_true == i)
            num_positive = int(is_positive.sum())
            if num_positive == 0 or num_positive == is_positive.size:
                print(f"Skipping ROC for {self.class_names[i]}: "
                      "needs both positive and negative samples")
                continue

            fpr, tpr, _ = roc_curve(is_positive, y_probs[:, i])
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2,
                     label=f'{self.class_names[i]} (AUC = {roc_auc:.2f})')

        # Diagonal reference line.
        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title(f'Multi-class ROC Curves - {model_name}')
        plt.legend(loc='lower right')
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    def train(self, model, criterion, optimizer, dataloaders, image_datasets,
              num_epochs=10, patience=3, save_path="history.pkl",
              model_path="best_model.pth"):
        """Train ``model`` with early stopping on the validation loss.

        Each epoch runs a ``"train"`` phase followed by a ``"val"`` phase.
        Whenever the validation loss improves, the model's ``state_dict`` is
        written to ``model_path``. The returned model carries the weights of
        the last epoch that ran, not necessarily the best ones.

        Args:
            model: Module to train; it is not moved to ``self.device`` here.
            criterion: Loss callable invoked as ``criterion(outputs, labels)``.
            optimizer: Optimizer stepping ``model``'s parameters.
            dataloaders: Mapping with ``"train"`` and ``"val"`` iterables of
                ``(inputs, labels)`` batches.
            image_datasets: Mapping with ``"train"`` and ``"val"`` entries whose
                ``len()`` is used to average loss and accuracy.
            num_epochs: Maximum number of epochs.
            patience: Number of consecutive non-improving validation epochs
                after which training stops.
            save_path: File the history dict is pickled to when training ends;
                pass ``None`` or ``""`` to skip writing it.
            model_path: File the best ``state_dict`` is saved to.

        Returns:
            Tuple ``(model, history, best_true, best_preds)``. ``best_true`` and
            ``best_preds`` are the validation targets and predictions from the
            best epoch, or empty lists if no epoch ever improved the loss.
        """
        history = {
            "train_loss": [], "train_acc": [],
            "val_loss": [], "val_acc": [],
            "val_f1_macro": [], "val_f1_weighted": []
        }
        best_loss = float('inf')
        epochs_no_improve = 0
        # Defined up front so every exit path has something to return, even
        # when no validation epoch improves (e.g. NaN loss or num_epochs=0).
        best_true, best_preds = [], []
        stop_early = False

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            for phase in ["train", "val"]:
                is_train = phase == "train"
                if is_train:
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0
                all_true, all_preds = [], []

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)

                    if is_train:
                        optimizer.zero_grad()

                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        _, preds = torch.max(outputs, 1)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    all_true.extend(labels.cpu().numpy())
                    all_preds.extend(preds.cpu().numpy())
                    # loss is a batch mean, so weight it by the batch size.
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (preds == labels).sum().item()

                num_samples = len(image_datasets[phase])
                epoch_loss = running_loss / num_samples
                epoch_acc = running_corrects / num_samples

                history[f"{phase}_loss"].append(epoch_loss)
                history[f"{phase}_acc"].append(epoch_acc)

                summary = f"{phase} Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.4f}"

                if not is_train:
                    # Multi-class F1 scores
                    f1_macro = f1_score(all_true, all_preds, average='macro')
                    f1_weighted = f1_score(all_true, all_preds, average='weighted')
                    history["val_f1_macro"].append(f1_macro)
                    history["val_f1_weighted"].append(f1_weighted)
                    summary += (f" | F1 Macro: {f1_macro:.4f}"
                                f" | F1 Weighted: {f1_weighted:.4f}")

                    # Early stopping bookkeeping
                    if epoch_loss < best_loss:
                        best_loss = epoch_loss
                        epochs_no_improve = 0
                        torch.save(model.state_dict(), model_path)
                        best_true, best_preds = all_true, all_preds
                    else:
                        epochs_no_improve += 1

                    stop_early = epochs_no_improve >= patience

                # Printed before any early exit so the last epoch's validation
                # metrics are always visible.
                print(summary)

            if stop_early:
                print("Early stopping triggered")
                break

        if save_path:
            with open(save_path, "wb") as history_file:
                pickle.dump(history, history_file)

        return model, history, best_true, best_preds

    def plot_classification_report(self, y_true, y_pred, model_name="model"):
        """Plot per-class precision/recall/F1 as a heatmap.

        Args:
            y_true: Ground-truth class indices.
            y_pred: Predicted class indices.
            model_name: Name shown in the plot title.

        Returns:
            DataFrame holding the full classification report, including the
            summary rows that are left out of the heatmap.
        """
        # Pin the label set so the report lines up with class_names even when
        # a class does not occur in y_true/y_pred; such classes score 0.
        report = classification_report(y_true, y_pred,
                                       labels=list(range(self.num_classes)),
                                       target_names=self.class_names,
                                       output_dict=True,
                                       zero_division=0)

        # Convert to DataFrame for better visualization
        report_df = pd.DataFrame(report).transpose()

        plt.figure(figsize=(12, 8))
        # Drop the three trailing summary rows and the support column.
        sns.heatmap(report_df.iloc[:-3, :-1].astype(float),
                    annot=True, cmap='Blues', fmt='.3f')
        plt.title(f"Classification Report - {model_name}")
        plt.tight_layout()
        plt.show()

        return report_df

    def compare_models(self, models_info, dataloader):
        """Compare multiple models by macro F1 score in a bar chart.

        Args:
            models_info: Iterable of ``(name, model)`` pairs.
            dataloader: Iterable yielding ``(inputs, labels)`` batches; every
                model is evaluated on it.
        """
        plt.figure(figsize=(10, 6))

        for name, model in models_info:
            model = model.to(self.device)
            all_true, all_preds = self._collect_outputs(model, dataloader)

            f1 = f1_score(all_true, all_preds, average='macro')
            plt.bar(name, f1, label=f'{name} (F1 = {f1:.3f})')

        plt.xlabel('Models')
        plt.ylabel('Macro F1 Score')
        plt.title('Model Comparison - Macro F1 Scores')
        plt.legend()
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.show()

    def plot_f1_curves(self, all_histories):
        """Plot validation F1 progression for several models side by side.

        Args:
            all_histories: Mapping of model name to a history dict. Histories
                lacking ``val_f1_macro`` / ``val_f1_weighted`` are left out of
                the corresponding panel.
        """
        plt.figure(figsize=(12, 6))

        panels = (
            ("val_f1_macro", "Macro F1 Score Comparison"),
            ("val_f1_weighted", "Weighted F1 Score Comparison"),
        )
        for position, (metric_key, panel_title) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            for model_name, history in all_histories.items():
                if metric_key in history:
                    # Epochs are shown 1-based on the x axis.
                    epochs = range(1, len(history[metric_key]) + 1)
                    plt.plot(epochs, history[metric_key], label=model_name)
            plt.xlabel("Epoch")
            plt.ylabel("F1 Score")
            plt.title(panel_title)
            plt.legend()
            plt.grid(True)

        plt.tight_layout()
        plt.show()