In [1]:
from google.colab import drive
drive.mount('/drive')
!ln -s "/drive/MyDrive/LeafDisease" "/content/LeafDisease"

ValueError: mount failed

In [None]:
# %run "/content/LeafDisease/dataset/datasetup.ipynb"

In [None]:
# Load the model-definition scripts. `%run` executes each file and (IPython's
# default) copies the names it defines at top level into this notebook's namespace.
# The commented-out `from ... import *` lines are an earlier import-based approach;
# note that resnet50 has no `%run` counterpart below, so it is not loaded here.
# from resnet50 import *
# from densenet121 import *
# from efficientnetb0 import *
# from mobilenetv2 import *
# from leafnetv2 import *
# from leafNet import *
%run "/content/LeafDisease/models/densenet121.py"
%run "/content/LeafDisease/models/efficientnetb0.py"
%run "/content/LeafDisease/models/mobilenetv2.py"
%run "/content/LeafDisease/models/leafnetv2.py"
%run "/content/LeafDisease/models/leafnet.py"

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 137MB/s]


cpu


In [None]:
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc, f1_score
import pickle
import numpy as np
from google.colab import files

In [None]:
class BinaryClassifierTrainer:
    """Train, evaluate, and plot diagnostics for a two-class PyTorch classifier.

    Class index 1 is treated as the positive class throughout: ROC scores use
    ``softmax(outputs)[:, 1]`` and ``f1_score`` uses its default ``pos_label=1``.
    The per-model plotting methods save each figure as a PNG in the working
    directory and trigger a Colab browser download via ``files.download``.
    """

    def __init__(self, device, class_names=None):
        """
        Args:
            device: torch device (or device string) that batches are moved to.
            class_names: display labels for class indices 0 and 1, shown on the
                confusion-matrix axes. Defaults to ["Healthy", "Unhealthy"].
        """
        self.device = device
        # Copy into a fresh list so instances never share (and mutate) one default object.
        self.class_names = (
            list(class_names) if class_names is not None else ["Healthy", "Unhealthy"]
        )

    @staticmethod
    def _finalize_figure(fig, plot_path):
        """Lay out, save, display, close, and download a finished figure.

        The PNG is written *before* ``plt.show()``: Jupyter's inline backend closes
        figures once they are shown, so saving afterwards writes a blank image.
        """
        fig.tight_layout()
        fig.savefig(plot_path)
        plt.show()
        plt.close(fig)
        files.download(plot_path)

    @staticmethod
    def _format_epoch_summary(phase, loss, acc, f1=None):
        """Return a one-line metrics summary, e.g. "val Loss: 0.1234 | Acc: 0.9500 | F1: 0.9400"."""
        summary = f"{phase} Loss: {loss:.4f} | Acc: {acc:.4f}"
        if f1 is not None:
            summary += f" | F1: {f1:.4f}"
        return summary

    @staticmethod
    def _dump_results(save_path, history, y_true, y_pred, y_scores, best_loss):
        """Pickle the training history and the last validation epoch's outputs to `save_path`."""
        with open(save_path, "wb") as f:
            pickle.dump({
                "history": history,
                "y_true": y_true,
                "y_pred": y_pred,
                "y_scores": y_scores,
                "best_loss": best_loss,
            }, f)

    def plot_history(self, history, model_name="model"):
        """Plot per-epoch loss, accuracy, and validation F1, then save and download the figure.

        Args:
            history: dict with list values under "train_loss", "val_loss",
                "train_acc", "val_acc" and "val_f1" (the dict produced by ``train``).
            model_name: prefix of the saved file "<model_name>_training_history.png".
        """
        def epochs(series):
            # 1-based x-axis, matching the "Epoch k/N" log lines and plot_f1_curves.
            return range(1, len(series) + 1)

        fig, (ax_loss, ax_acc, ax_f1) = plt.subplots(1, 3, figsize=(18, 6))

        ax_loss.plot(epochs(history["train_loss"]), history["train_loss"], label="Train Loss")
        ax_loss.plot(epochs(history["val_loss"]), history["val_loss"], label="Val Loss")
        ax_loss.set(xlabel="Epoch", ylabel="Loss", title="Loss Curve")
        ax_loss.legend()

        ax_acc.plot(epochs(history["train_acc"]), history["train_acc"], label="Train Acc")
        ax_acc.plot(epochs(history["val_acc"]), history["val_acc"], label="Val Acc")
        ax_acc.set(xlabel="Epoch", ylabel="Accuracy", title="Accuracy Curve")
        ax_acc.legend()

        ax_f1.plot(epochs(history["val_f1"]), history["val_f1"], label="Val F1", color="purple")
        ax_f1.set(xlabel="Epoch", ylabel="F1 Score", title="F1 Score Curve")
        ax_f1.legend()

        self._finalize_figure(fig, f"{model_name}_training_history.png")

    def eval_plot(self, y_true, y_pred, model_name="model"):
        """Print and plot the 2x2 confusion matrix, then save and download the figure.

        Args:
            y_true: ground-truth class indices (0 or 1).
            y_pred: predicted class indices (0 or 1).
            model_name: used in the title and the file "<model_name>_confusion_matrix.png".
        """
        # Pin the label set so the matrix is always 2x2 and unpacks into four counts,
        # even when one class is absent from both y_true and y_pred.
        conf_matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
        tn, fp, fn, tp = conf_matrix.ravel()

        print("Confusion Matrix:\n", conf_matrix)
        print(f"TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}")

        fig, ax = plt.subplots()
        disp = ConfusionMatrixDisplay(confusion_matrix=conf_matrix, display_labels=self.class_names)
        disp.plot(ax=ax, cmap="Blues", values_format="d")
        ax.set_title(f"Confusion Matrix - {model_name}")

        self._finalize_figure(fig, f"{model_name}_confusion_matrix.png")

    def plot_roc(self, model, dataloader, model_name="model"):
        """Plot one model's ROC curve on `dataloader`, save/download it, and return the AUC.

        The model is put in eval mode and is not moved; it must already live on
        ``self.device``.
        """
        y_true, y_scores = self.get_probs_and_labels(model, dataloader)
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        roc_auc = auc(fpr, tpr)

        fig, ax = plt.subplots(figsize=(6, 6))
        ax.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
        ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")  # chance-level diagonal
        ax.set(xlabel="False Positive Rate", ylabel="True Positive Rate",
               title=f"ROC Curve - {model_name}")
        ax.legend(loc="lower right")
        ax.grid(True)

        self._finalize_figure(fig, f"{model_name}_roc_curve.png")
        return roc_auc

    def train(self, model, criterion, optimizer, dataloaders, image_datasets,
              num_epochs=10, patience=3, save_path="history.pkl",
              best_model_path="best_model.pth"):
        """Train with per-epoch validation and early stopping on validation loss.

        Args:
            model: the network; it is not moved here, so place it on ``self.device`` first.
            criterion: loss function called as ``criterion(outputs, labels)``.
            optimizer: optimizer over ``model``'s parameters.
            dataloaders: dict with "train" and "val" iterables of (inputs, labels).
            image_datasets: dict with "train" and "val" datasets; their ``len`` is the
                divisor for epoch loss/accuracy.
            num_epochs: maximum number of epochs.
            patience: stop after this many consecutive epochs without a lower val loss.
            save_path: pickle file receiving history, last val-epoch labels/preds/scores
                and best loss (written on early stop or after the final epoch).
            best_model_path: file the model's ``state_dict`` is saved to whenever the
                validation loss improves.

        Returns:
            (model, history, y_true, y_pred) where y_true/y_pred come from the last
            validation epoch. The returned model holds the *last* epoch's weights;
            the best weights are only in ``best_model_path``.
        """
        history = {
            "train_loss": [], "train_acc": [],
            "val_loss": [], "val_acc": [], "val_f1": [],
        }
        best_loss = float("inf")
        epochs_no_improve = 0
        final_y_true, final_y_pred, final_y_scores = [], [], []

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            for phase in ("train", "val"):
                is_train = phase == "train"
                if is_train:
                    model.train()
                else:
                    model.eval()

                running_loss = 0.0
                running_corrects = 0  # plain int, so an empty loader cannot break .float()
                y_true, y_pred, y_scores = [], [], []

                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)

                    if is_train:
                        optimizer.zero_grad()

                    with torch.set_grad_enabled(is_train):
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)
                        _, preds = torch.max(outputs, 1)
                        probs = torch.softmax(outputs, dim=1)[:, 1]  # P(class 1)

                        if is_train:
                            loss.backward()
                            optimizer.step()

                    y_true.extend(labels.cpu().numpy())
                    y_pred.extend(preds.cpu().numpy())
                    y_scores.extend(probs.detach().cpu().numpy())
                    # loss is a batch mean; weight by batch size for a per-sample epoch mean
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += (preds == labels).sum().item()

                dataset_size = len(image_datasets[phase])
                epoch_loss = running_loss / dataset_size
                epoch_acc = running_corrects / dataset_size

                history[f"{phase}_loss"].append(epoch_loss)
                history[f"{phase}_acc"].append(epoch_acc)

                if is_train:
                    print(self._format_epoch_summary(phase, epoch_loss, epoch_acc))
                    continue

                # ---- validation-only bookkeeping ----
                epoch_f1 = f1_score(y_true, y_pred)
                history["val_f1"].append(epoch_f1)
                # Print before the early-stopping check so the final epoch is logged too.
                print(self._format_epoch_summary(phase, epoch_loss, epoch_acc, epoch_f1))

                final_y_true, final_y_pred, final_y_scores = y_true, y_pred, y_scores

                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    epochs_no_improve = 0
                    torch.save(model.state_dict(), best_model_path)
                else:
                    epochs_no_improve += 1
                    print(f"No improvement for {epochs_no_improve} epoch(s)")

                if epochs_no_improve >= patience:
                    print("Early stopping triggered")
                    self._dump_results(save_path, history, final_y_true, final_y_pred,
                                       final_y_scores, best_loss)
                    return model, history, final_y_true, final_y_pred

        self._dump_results(save_path, history, final_y_true, final_y_pred,
                           final_y_scores, best_loss)
        return model, history, final_y_true, final_y_pred

    def compare_models(self, models_info, dataloader):
        """Overlay the ROC curves of several models evaluated on the same dataloader.

        Args:
            models_info: iterable of (name, model) pairs; each model is moved to
                ``self.device`` before evaluation.
            dataloader: iterable of (inputs, labels) batches.
        """
        fig, ax = plt.subplots(figsize=(8, 6))

        for name, model in models_info:
            model = model.to(self.device)
            y_true, y_scores = self.get_probs_and_labels(model, dataloader)
            fpr, tpr, _ = roc_curve(y_true, y_scores)
            ax.plot(fpr, tpr, lw=2, label=f"{name} (AUC = {auc(fpr, tpr):.2f})")

        ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
        ax.set(xlabel="False Positive Rate", ylabel="True Positive Rate",
               title="Model Comparison - ROC Curves")
        ax.legend(loc="lower right")
        ax.grid(True)
        plt.show()

    def get_probs_and_labels(self, model, dataloader):
        """Collect true labels and class-1 softmax probabilities over `dataloader`.

        Puts ``model`` in eval mode (left that way) and runs without gradients.

        Returns:
            (y_true, y_scores): flat lists with one entry per sample.
        """
        model.eval()
        y_true, y_scores = [], []

        with torch.no_grad():
            for inputs, labels in dataloader:
                outputs = model(inputs.to(self.device))
                probs = torch.softmax(outputs, dim=1)[:, 1]
                y_true.extend(labels.cpu().numpy())
                y_scores.extend(probs.cpu().numpy())

        return y_true, y_scores

    def plot_f1_curves(self, all_histories):
        """Plot validation F1 per epoch (1-based) for several models on one axes.

        Args:
            all_histories: mapping of model name -> history dict; entries without a
                "val_f1" key are skipped.
        """
        fig, ax = plt.subplots(figsize=(10, 6))

        for model_name, history in all_histories.items():
            if "val_f1" in history:
                f1_values = history["val_f1"]
                ax.plot(range(1, len(f1_values) + 1), f1_values, label=model_name)

        ax.set(xlabel="Epoch", ylabel="F1 Score", title="Validation F1 Score Comparison")
        ax.legend()
        ax.grid(True)
        plt.show()