In [None]:
import matplotlib.pyplot as plt  # ensure this is imported at the top

# Hyper-parameter sweep: for each (optimizer, learning rate, StepLR decay)
# configuration, train a fresh FruitCNN and then:
#   * print per-epoch train/val accuracy,
#   * print a test-set classification report and append it to `report_path`,
#   * append final train/val metrics to `all_results`,
#   * plot accuracy and loss curves.
# Names defined in earlier cells and used here: report_path, FruitCNN, device,
# torch, nn, optim, epochs, train_loader, val_loader, test_loader, train_data,
# classification_report, all_results.

# Only run these 3 configurations: (optimizer name, initial lr, StepLR gamma)
configurations = [
    ('SGD', 0.01, 0.5),
    ('Adam', 0.01, 0.5),
    ('AdamW', 0.005, 0.9),
]


def _build_optimizer(opt_name, params, lr):
    """Return the optimizer named *opt_name* over *params* with learning rate *lr*.

    Raises:
        ValueError: if *opt_name* is not one of 'SGD', 'Adam', 'AdamW'.
    """
    if opt_name == 'SGD':
        return optim.SGD(params, lr=lr, momentum=0.9)
    if opt_name == 'Adam':
        return optim.Adam(params, lr=lr)
    if opt_name == 'AdamW':
        return optim.AdamW(params, lr=lr)
    raise ValueError(f"Unknown optimizer: {opt_name}")


def _run_epoch(model, loader, criterion, optimizer=None):
    """Make one pass over *loader*; train if *optimizer* is given, else evaluate.

    Returns:
        (accuracy, mean per-sample loss) over every sample in *loader*.

    The criterion used in this cell is nn.CrossEntropyLoss() with its default
    reduction='mean', so each batch loss is a per-sample mean. Re-weighting it
    by the batch size and dividing by the sample count gives a loss that does
    not depend on how many batches the loader yields. The raw sum of batch
    losses made the train and val loss curves incomparable.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    correct, total, loss_sum = 0, 0, 0.0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size
    # Batch tensors are locals here, so they are released on return instead of
    # lingering as globals that pin device memory across configurations.
    return correct / total, loss_sum / total


def _predict(model, loader):
    """Return (y_true, y_pred) as lists of integer class indices over *loader*."""
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for images, labels in loader:
            outputs = model(images.to(device))
            y_pred.extend(outputs.argmax(dim=1).cpu().tolist())
            y_true.extend(labels.tolist())
    return y_true, y_pred


def _plot_curves(train_series, val_series, train_label, val_label, ylabel, title):
    """Show one figure plotting a train and a validation series against epoch."""
    plt.figure()
    plt.plot(train_series, label=train_label)
    plt.plot(val_series, label=val_label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


class_names = train_data.classes
# Explicit label indices keep classification_report from raising ValueError
# (target_names length mismatch) when a class never appears in y_true or y_pred.
# NOTE(review): assumes output index i corresponds to train_data.classes[i]; the
# original target_names usage already relied on this.
class_indices = list(range(len(class_names)))

with open(report_path, "w", encoding="utf-8") as report_file:
    for opt_name, lr, lrd in configurations:
        print(f"\n=== Training with {opt_name}, lr={lr}, lr_decay={lrd} ===")
        model = FruitCNN().to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = _build_optimizer(opt_name, model.parameters(), lr)
        # Multiply the learning rate by `lrd` every 5 epochs.
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=lrd)

        train_acc, val_acc, train_loss, val_loss = [], [], [], []

        for epoch in range(epochs):
            epoch_acc, epoch_loss = _run_epoch(model, train_loader, criterion, optimizer)
            train_acc.append(epoch_acc)
            train_loss.append(epoch_loss)

            epoch_acc, epoch_loss = _run_epoch(model, val_loader, criterion)
            val_acc.append(epoch_acc)
            val_loss.append(epoch_loss)

            scheduler.step()

            print(f"Epoch {epoch+1}, Train Acc: {train_acc[-1]:.4f}, Val Acc: {val_acc[-1]:.4f}")

        # Test + Classification Report
        y_true, y_pred = _predict(model, test_loader)
        report = classification_report(
            y_true, y_pred, labels=class_indices, target_names=class_names
        )
        print("\nClassification Report:")
        print(report)

        report_file.write(
            f"=== Optimizer: {opt_name}, LR: {lr}, Decay: {lrd} ===\n{report}\n\n"
        )
        # Persist each finished report so a crash in a later run keeps it.
        report_file.flush()

        # Loss values are mean per-sample losses (see _run_epoch).
        all_results.append({
            'optimizer': opt_name,
            'lr': lr,
            'lr_decay': lrd,
            'final_train_acc': train_acc[-1],
            'final_val_acc': val_acc[-1],
            'final_train_loss': train_loss[-1],
            'final_val_loss': val_loss[-1],
        })

        title_prefix = f'{opt_name} lr={lr}, decay={lrd}'
        _plot_curves(train_acc, val_acc, 'Train Acc', 'Val Acc',
                     'Accuracy', f'{title_prefix} - Accuracy')
        _plot_curves(train_loss, val_loss, 'Train Loss', 'Val Loss',
                     'Loss', f'{title_prefix} - Loss')

        # Free this run's device memory before the next configuration. The
        # optimizer (param groups + momentum/Adam state) and the scheduler
        # (which holds the optimizer) also reference the model's parameters, so
        # deleting only `model` would not let empty_cache() release them.
        del model, optimizer, scheduler
        torch.cuda.empty_cache()
