<a href="https://colab.research.google.com/github/garlicxd/Fruit-Classification/blob/automation/Fruit_Classification_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Todo
- [ ] Early stopping (patience): stop training after several epochs without validation improvement
- [ ] Create graphs: record metrics during training so they can be plotted
- [ ] Compare different optimizers

In [None]:
#@title Function Declarations
import os
import glob
import shutil
import json
import time
import subprocess
import sys
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from sklearn.metrics import confusion_matrix

# ==========================================
# 1. Configuration
# ==========================================
class Config:
    """
    Central settings for a single fine-tuning experiment.

    Every setting is a plain instance attribute assigned in ``__init__``.
    Derived values (split directories, EXPERIMENT_NAME, output paths,
    NUM_CLASSES) are computed once from the base settings at construction
    time. Changing e.g. IMG_AUGMENT afterwards does NOT update
    EXPERIMENT_NAME or the save paths.
    """

    def __init__(self):
        # Pick the compute device first; this also prints what was detected.
        self.DEVICE = self._get_device()

        # Raw split folders expected inside the downloaded dataset.
        self.DATASET_FOLDER_PATH = "./split_ttv_dataset_type_of_plants"
        raw_root = self.DATASET_FOLDER_PATH
        self.BASE_TRAIN_DIR = os.path.join(raw_root, 'Train_Set_Folder')
        self.BASE_VAL_DIR = os.path.join(raw_root, 'Validation_Set_Folder')
        self.BASE_TEST_DIR = os.path.join(raw_root, 'Test_Set_Folder')

        # Working copy that create_filtered_subsets() rebuilds from the raw splits.
        self.BASE_FILTERED_DIR = './filtered_data'
        filtered_root = self.BASE_FILTERED_DIR
        self.FILTERED_TRAIN_DIR = os.path.join(filtered_root, 'train')
        self.FILTERED_VAL_DIR = os.path.join(filtered_root, 'val')
        self.FILTERED_TEST_DIR = os.path.join(filtered_root, 'test')

        # Training hyperparameters.
        self.IMG_SIZE = (224, 224)
        self.BATCH_SIZE = 32
        self.LEARNING_RATE = 0.01
        self.EPOCHS = 5
        self.MAX_IMAGES_PER_CLASS_TRAIN = 1000
        self.SHUFFLE_TRAINING = True

        # When False, files whose name starts with 'aug_' are left out of the filtered copy.
        self.IMG_AUGMENT = False

        # Output file names are keyed on the augmentation choice above.
        if self.IMG_AUGMENT:
            self.EXPERIMENT_NAME = "with_augmented_images"
        else:
            self.EXPERIMENT_NAME = "baseline_no_augmented_images"
        self.MODEL_SAVE_PATH = f'./resnet50_{self.EXPERIMENT_NAME}.pth'
        self.HISTORY_SAVE_PATH = f'./history_{self.EXPERIMENT_NAME}.json'
        self.CLASS_MAP_PATH = './class_mapping.json'

        # Class folder names to keep; list position becomes the class index in class_mapping.json.
        self.CLASSES_TO_USE = [
            "aloevera", "banana", "bilimbi", "cantaloupe", "cassava", "coconut",
            "corn", "cucumber", "curcuma", "eggplant", "galangal", "ginger",
            "guava", "kale", "longbeans", "mango", "melon", "orange", "paddy",
            "papaya", "peper chili", "pineapple", "pomelo", "shallot", "soybeans",
            "spinach", "sweet potatoes", "tobacco", "waterapple", "watermelon"
        ]
        self.NUM_CLASSES = len(self.CLASSES_TO_USE)

    def _get_device(self):
        """Return ``cuda`` when a GPU (NVIDIA CUDA or AMD ROCm build) is visible, else ``cpu``."""
        if not torch.cuda.is_available():
            print("⚠️ GPU not found. Using CPU.")
            return torch.device("cpu")

        gpu_name = torch.cuda.get_device_name(0)
        # ROCm builds of PyTorch set torch.version.hip and still use the "cuda" device type.
        backend = "ROCm (AMD)" if getattr(torch.version, 'hip', None) else "CUDA (NVIDIA)"
        print(f"✅ {backend} detected: {gpu_name}")
        return torch.device("cuda")

    def print_summary(self):
        """Print the device, experiment name, augmentation flag and class list."""
        class_list = ', '.join(self.CLASSES_TO_USE)
        print(f"Using device: {self.DEVICE}")
        print(f"Experiment: {self.EXPERIMENT_NAME}")
        print(f"Include augmented images: {self.IMG_AUGMENT}")
        print(f"Classes to train ({self.NUM_CLASSES}): {class_list}")

# ==========================================
# 2. Data Preparation Functions
# ==========================================
def setup_dataset(config):
    """
    Download and unzip the Kaggle plant-type dataset unless it is already present.

    If ``config.DATASET_FOLDER_PATH`` already exists nothing else happens; in
    particular no ``kaggle.json`` is required for re-runs. Otherwise the
    Kaggle API token is located in the current directory (or, on Colab,
    uploaded interactively). KAGGLE_CONFIG_DIR is set BEFORE the kaggle
    library is imported, because that import can fail (OSError) when it
    cannot find credentials.

    Errors are reported with print(); the function always returns None.
    """
    import zipfile

    # Check for the dataset first so an existing download never requires credentials.
    if os.path.exists(config.DATASET_FOLDER_PATH):
        print("Dataset folder already exists. Skipping download.")
        return

    cwd = os.getcwd()
    kaggle_json_path = os.path.join(cwd, "kaggle.json")
    is_colab = 'google.colab' in sys.modules

    # --- 1. Locate credentials (must happen before importing kaggle) ---
    print("Checking for 'kaggle.json'...")
    if os.path.exists(kaggle_json_path):
        print("✅ Found 'kaggle.json' in current directory.")
    elif is_colab:
        print("⚠️ 'kaggle.json' not found. Please upload it now:")
        from google.colab import files
        uploaded = files.upload()
        if "kaggle.json" not in uploaded:
            print("❌ Upload failed or cancelled. Exiting setup.")
            return
        print("✅ Upload successful.")
        # Fix permissions (required by Kaggle API)
        os.chmod(kaggle_json_path, 0o600)
    else:
        print(f"❌ 'kaggle.json' not found in {cwd}")
        print("   Please place the file here and run again.")
        return

    # Tell the Kaggle client where the token lives.
    os.environ['KAGGLE_CONFIG_DIR'] = cwd

    # --- 2. Make sure the kaggle package (which provides the CLI) is installed ---
    try:
        import kaggle
    except ImportError:
        print("Installing Kaggle API...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "kaggle"])
        import kaggle
    except OSError as e:
        print(f"⚠️ Warning: Kaggle import had an issue: {e}")
        print("Attempting to proceed with CLI subprocess...")

    # --- 3. Download and extract ---
    print(f"Dataset folder '{config.DATASET_FOLDER_PATH}' not found. Downloading...")
    try:
        subprocess.run(["kaggle", "datasets", "download", "-d", "yudhaislamisulistya/plants-type-datasets"], check=True)

        print("Unzipping dataset...")
        with zipfile.ZipFile("plants-type-datasets.zip", 'r') as zip_ref:
            zip_ref.extractall(".")
        print("Dataset downloaded and unzipped successfully.")
    except subprocess.CalledProcessError as e:
        print(f"❌ Error downloading dataset: {e}")
        print("Check your API token and internet connection.")
    except FileNotFoundError as e:
        # Raised when the `kaggle` executable is not on PATH or the archive was not written.
        print(f"❌ Error downloading dataset: {e}")
        print("Make sure the Kaggle CLI is installed and on PATH.")
    except zipfile.BadZipFile as e:
        print(f"❌ Downloaded archive is not a valid zip file: {e}")

def create_filtered_subsets(config):
    """
    Rebuild ``config.BASE_FILTERED_DIR`` as a filtered copy of the raw train/val/test splits.

    Any existing filtered directory is deleted first. For every split and every
    class in ``config.CLASSES_TO_USE``:
      * files whose basename starts with ``aug_`` are skipped unless
        ``config.IMG_AUGMENT`` is True;
      * the training split is capped at ``config.MAX_IMAGES_PER_CLASS_TRAIN``
        files per class, taken in sorted path order so the subset is reproducible.
    A destination folder is created for every class even when its source folder
    is missing (presumably so that every split exposes the same class folders).
    Finally, ``{class_name: index}`` is written as JSON to ``config.CLASS_MAP_PATH``.
    """
    if os.path.exists(config.BASE_FILTERED_DIR):
        print(f"Removing existing filtered data directory: {config.BASE_FILTERED_DIR}")
        shutil.rmtree(config.BASE_FILTERED_DIR)

    class_to_idx = {name: i for i, name in enumerate(config.CLASSES_TO_USE)}

    def process_split(src_base, dst_base, is_train=False):
        """Copy the selected images of one split from ``src_base`` into ``dst_base``."""
        if not os.path.exists(src_base):
            print(f"⚠️ Warning: Source directory {src_base} does not exist. Skipping.")
            return

        print(f"Creating filtered set at: {dst_base}...")
        for class_name in config.CLASSES_TO_USE:
            src_dir = os.path.join(src_base, class_name)
            dst_dir = os.path.join(dst_base, class_name)
            os.makedirs(dst_dir, exist_ok=True)

            if not os.path.exists(src_dir):
                print(f"⚠️ Warning: Class folder {src_dir} does not exist. Skipping.")
                continue

            # glob's result order is filesystem-dependent; sort it so the capped
            # training subset is identical on every machine and every run.
            all_images = sorted(glob.glob(os.path.join(src_dir, '*.*')))

            if config.IMG_AUGMENT:
                # Take all images (augmented and original)
                imgs = all_images
            else:
                # Take only original images
                imgs = [img for img in all_images if not os.path.basename(img).startswith('aug_')]

            if is_train:
                imgs = imgs[:config.MAX_IMAGES_PER_CLASS_TRAIN]

            for img_path in imgs:
                shutil.copy(img_path, dst_dir)

    process_split(config.BASE_TRAIN_DIR, config.FILTERED_TRAIN_DIR, is_train=True)
    process_split(config.BASE_VAL_DIR, config.FILTERED_VAL_DIR)
    process_split(config.BASE_TEST_DIR, config.FILTERED_TEST_DIR)

    with open(config.CLASS_MAP_PATH, 'w') as f:
        json.dump(class_to_idx, f)
    print(f"Class mapping saved to {config.CLASS_MAP_PATH}")

def get_dataloaders(config):
    """
    Build train/val/test DataLoaders over the filtered ImageFolder directories.

    Args:
        config: provides FILTERED_TRAIN_DIR / FILTERED_VAL_DIR / FILTERED_TEST_DIR,
            IMG_SIZE, BATCH_SIZE, SHUFFLE_TRAINING and DEVICE.

    Returns:
        ``(dataloaders, dataset_sizes)``: dicts keyed by 'train', 'val' and 'test'.
        Returns ``(None, None)`` if any of the three filtered split directories
        is missing.
    """
    # Validate every split before building datasets. Previously only 'train'
    # was checked, so a missing val/test folder crashed inside ImageFolder
    # instead of returning (None, None).
    if not os.path.exists(config.FILTERED_TRAIN_DIR):
        print("Filtered training directory missing. Cannot create dataloaders.")
        return None, None
    for split_name, split_dir in (('val', config.FILTERED_VAL_DIR), ('test', config.FILTERED_TEST_DIR)):
        if not os.path.exists(split_dir):
            print(f"Filtered {split_name} directory missing ({split_dir}). Cannot create dataloaders.")
            return None, None

    imgnet_mean = [0.485, 0.456, 0.406]
    imgnet_std = [0.229, 0.224, 0.225]

    # All splits currently share the same deterministic preprocessing (resize +
    # ImageNet normalization). Give 'train' its own Compose to add augmentation.
    base_transform = transforms.Compose([
        transforms.Resize(config.IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=imgnet_mean, std=imgnet_std)
    ])
    data_transforms = {'train': base_transform, 'val': base_transform, 'test': base_transform}

    split_dirs = {
        'train': config.FILTERED_TRAIN_DIR,
        'val': config.FILTERED_VAL_DIR,
        'test': config.FILTERED_TEST_DIR,
    }
    image_datasets = {
        split: datasets.ImageFolder(path, data_transforms[split])
        for split, path in split_dirs.items()
    }

    # Pinned host memory speeds up host-to-GPU copies (CUDA and ROCm both report type 'cuda').
    use_pin_memory = (config.DEVICE.type == 'cuda')
    # Use 2 workers when the CPU count is unknown, otherwise up to 4.
    cpu_count = os.cpu_count()
    num_workers = 2 if cpu_count is None else min(4, cpu_count)

    dataloaders = {
        split: DataLoader(
            dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=(config.SHUFFLE_TRAINING if split == 'train' else False),
            pin_memory=use_pin_memory,
            num_workers=num_workers,
        )
        for split, dataset in image_datasets.items()
    }

    dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

    print(f"Data Loaded: Train({dataset_sizes['train']}), Val({dataset_sizes['val']}), Test({dataset_sizes['test']})")
    return dataloaders, dataset_sizes

# ==========================================
# 3. Model & Training Functions
# ==========================================
def build_model(config):
    """
    Create an ImageNet-pretrained ResNet50 set up for head-only fine-tuning.

    Every pretrained parameter is frozen. The final fully connected layer is
    replaced with a fresh ``nn.Linear`` sized to ``config.NUM_CLASSES``, which
    is the only trainable part. The model is returned on ``config.DEVICE``.
    """
    print("Building model...")
    backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

    # Freeze everything that came with the pretrained weights.
    backbone.requires_grad_(False)

    # New layers default to requires_grad=True, so only this head will train.
    backbone.fc = nn.Linear(backbone.fc.in_features, config.NUM_CLASSES)
    return backbone.to(config.DEVICE)

def _snapshot_state_dict(model):
    """Return a deep copy of ``model.state_dict()`` that later optimizer steps cannot mutate."""
    import copy
    # state_dict() returns references to the live parameter/buffer tensors; without a
    # copy, the "best" weights would silently follow every subsequent update.
    return copy.deepcopy(model.state_dict())


def _run_epoch(config, model, loader, criterion, num_samples, desc, optimizer=None):
    """
    Run one pass over ``loader`` and return ``(mean_loss, accuracy)`` as Python floats.

    Trains (zero_grad / backward / step) when ``optimizer`` is given. Otherwise
    evaluates in eval mode with gradients disabled. Returns ``(0.0, 0.0)`` when
    ``num_samples`` is 0.
    """
    is_training = optimizer is not None
    model.train(is_training)
    running_loss = 0.0
    running_corrects = 0

    progress_bar = tqdm(loader, desc=desc)
    with torch.set_grad_enabled(is_training):
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)

            if is_training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)

            if is_training:
                loss.backward()
                optimizer.step()
                progress_bar.set_postfix(batch_loss=f"{loss.item():.4f}")

            # criterion returns the batch mean, so re-weight by batch size for a per-sample average.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += (preds == labels).sum().item()

    if num_samples == 0:
        return 0.0, 0.0
    return running_loss / num_samples, running_corrects / num_samples


def train_model(config, model, dataloaders, dataset_sizes):
    """
    Train the model's ``fc`` head with SGD and restore the best-validation weights.

    Args:
        config: provides DEVICE, LEARNING_RATE and EPOCHS.
        model: network with an ``fc`` attribute; only ``model.fc`` parameters are optimized.
        dataloaders: dict with 'train' and 'val' DataLoaders.
        dataset_sizes: dict with the number of samples in 'train' and 'val'.

    Returns:
        ``(model, history, best_val_acc)``. ``model`` has the weights of the epoch
        with the highest validation accuracy loaded. ``history`` holds per-epoch
        float lists under 'loss', 'accuracy', 'val_loss' and 'val_accuracy'.
        ``best_val_acc`` is a float (0.0 if no epoch beat 0).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.fc.parameters(), lr=config.LEARNING_RATE)

    print("Starting training...")
    start_time = time.time()

    best_model_wts = _snapshot_state_dict(model)
    best_val_acc = 0.0

    history = {"loss": [], "accuracy": [], "val_loss": [], "val_accuracy": []}

    for epoch in range(config.EPOCHS):
        print(f"\n--- Epoch {epoch+1}/{config.EPOCHS} ---")

        epoch_loss, epoch_acc = _run_epoch(
            config, model, dataloaders['train'], criterion,
            dataset_sizes['train'], "[Train]", optimizer=optimizer,
        )
        val_loss, val_acc = _run_epoch(
            config, model, dataloaders['val'], criterion,
            dataset_sizes['val'], "[Validate]",
        )

        print(f"  Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}")
        print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

        history["loss"].append(epoch_loss)
        history["accuracy"].append(epoch_acc)
        history["val_loss"].append(val_loss)
        history["val_accuracy"].append(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_wts = _snapshot_state_dict(model)
            print("  -> New best model found!")

    time_elapsed = time.time() - start_time
    print(f"\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
    print(f"Best Val Acc: {best_val_acc:.4f}")

    model.load_state_dict(best_model_wts)
    return model, history, best_val_acc

# ==========================================
# 4. Evaluation & Visualization Functions
# ==========================================
def evaluate_on_test(config, model, dataloaders, dataset_sizes):
    """
    Score ``model`` on the held-out test split.

    Returns:
        ``(loss, accuracy)`` as Python floats. ``loss`` is the per-sample mean
        cross-entropy over ``dataloaders['test']``, normalized by
        ``dataset_sizes['test']``.
    """
    print("--- Running Final Evaluation on Test Set ---")
    loss_fn = nn.CrossEntropyLoss()
    model.eval()

    total_loss = 0.0
    total_correct = 0

    with torch.no_grad():
        for batch_x, batch_y in tqdm(dataloaders['test'], desc="[Test]"):
            batch_x = batch_x.to(config.DEVICE)
            batch_y = batch_y.to(config.DEVICE)
            logits = model(batch_x)
            batch_loss = loss_fn(logits, batch_y)
            predicted = torch.max(logits, 1)[1]

            # loss_fn averages over the batch; undo that so the final mean is per sample.
            total_loss += batch_loss.item() * batch_x.size(0)
            total_correct += torch.sum(predicted == batch_y.data)

    n_test = dataset_sizes['test']
    mean_loss = total_loss / n_test
    accuracy = total_correct.double() / n_test

    print(f"\nTest Results -> Loss: {mean_loss:.4f} | Acc: {accuracy:.4f}")
    return float(mean_loss), float(accuracy)

def save_history(config, history, best_val_acc, test_loss, test_acc):
    """
    Write per-epoch metrics plus run metadata as JSON to ``config.HISTORY_SAVE_PATH``.

    Per-epoch values and ``best_val_acc`` are passed through float() so that
    tensor or numpy scalar values serialize. ``test_loss`` and ``test_acc`` are
    stored as given.
    """
    metric_keys = ('loss', 'accuracy', 'val_loss', 'val_accuracy')
    payload = {key: list(map(float, history[key])) for key in metric_keys}
    payload.update(
        experiment_name=config.EXPERIMENT_NAME,
        img_augment=config.IMG_AUGMENT,
        epochs=config.EPOCHS,
        best_val_acc=float(best_val_acc),
        test_loss=test_loss,
        test_accuracy=test_acc,
    )

    with open(config.HISTORY_SAVE_PATH, 'w') as fh:
        json.dump(payload, fh, indent=2)
    print(f"History saved to {config.HISTORY_SAVE_PATH}")

def plot_training_curves(history):
    """Show side-by-side train/val accuracy and loss curves, one point per epoch."""
    epoch_axis = range(1, len(history['accuracy']) + 1)
    plt.figure(figsize=(12, 5))

    # (panel title, training key, validation key) for the left and right subplot.
    panels = (
        ('Accuracy', 'accuracy', 'val_accuracy'),
        ('Loss', 'loss', 'val_loss'),
    )
    for position, (title, train_key, val_key) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epoch_axis, history[train_key], label='Train')
        plt.plot(epoch_axis, history[val_key], label='Val')
        plt.title(title)
        plt.legend()

    plt.show()

def plot_confusion_matrix_vis(config, model, dataloaders):
    """
    Plot the row-normalized confusion matrix of ``model`` on the test split.

    Rows are true classes and columns are predicted classes. Each row is
    divided by its total, so a cell is the fraction of that true class
    predicted as the column class. Rows with no samples are shown as 0.

    Args:
        config: provides DEVICE, and CLASSES_TO_USE as a fallback for axis labels.
        model: classifier returning one logit per class.
        dataloaders: dict whose 'test' DataLoader yields ``(inputs, labels)`` batches.
    """
    model.eval()
    y_true_all = []
    y_pred_all = []

    print("Generating Confusion Matrix...")
    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            outputs = model(inputs.to(config.DEVICE))
            _, preds = torch.max(outputs, 1)
            y_true_all.extend(labels.cpu().numpy())
            y_pred_all.extend(preds.cpu().numpy())

    # Name the axes in the index order the dataset actually used (ImageFolder
    # exposes it as `.classes`); fall back to the configured class list.
    class_names = list(getattr(dataloaders['test'].dataset, 'classes', config.CLASSES_TO_USE))
    label_ids = np.arange(len(class_names))

    # Pass `labels` explicitly. Without it sklearn only includes classes that
    # occur in y_true or y_pred, so one absent class shrinks the matrix and
    # shifts every tick label after it.
    cm = confusion_matrix(y_true_all, y_pred_all, labels=label_ids)

    # Normalize each row; 0/0 for empty rows becomes NaN and is mapped to 0.
    with np.errstate(all='ignore'):
        cm_norm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
        cm_norm = np.nan_to_num(cm_norm)

    fig, ax = plt.subplots(figsize=(14, 12))
    im = ax.imshow(cm_norm, interpolation='nearest', cmap=plt.cm.Blues)
    ax.figure.colorbar(im, ax=ax)

    ax.set(xticks=label_ids,
           yticks=label_ids,
           xticklabels=class_names,
           yticklabels=class_names,
           title='Normalized Confusion Matrix',
           ylabel='True label',
           xlabel='Predicted label')

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    plt.tight_layout()
    plt.show()

def compare_experiments():
    """
    Print and plot test accuracy for the baseline vs. augmented experiments.

    Does nothing beyond a notice unless both history JSON files exist in the
    current directory. Any error while loading or plotting is printed, not raised.
    """
    base_path = './history_baseline_no_augmented_images.json'
    aug_path = './history_with_augmented_images.json'

    if not (os.path.exists(base_path) and os.path.exists(aug_path)):
        print("\nSkipping comparison (run both experiment types to enable).")
        return

    print("\n--- Comparative Analysis ---")
    try:
        loaded = []
        for path in (base_path, aug_path):
            with open(path) as fh:
                loaded.append(json.load(fh))
        h_base, h_aug = loaded

        base_acc = h_base.get('test_accuracy', 0)
        print(f"Baseline Test Acc:  {base_acc:.4f}")
        aug_acc = h_aug.get('test_accuracy', 0)
        print(f"Augmented Test Acc: {aug_acc:.4f}")

        plt.figure(figsize=(6, 4))
        plt.bar(['Baseline', 'Augmented'], [base_acc, aug_acc], color=['blue', 'orange'])
        plt.title('Test Accuracy Comparison')
        plt.show()
    except Exception as e:
        # Best-effort report: a broken history file should not abort the notebook.
        print(f"Could not load comparisons: {e}")

# ==========================================
# 5. Main Execution
# ==========================================
def main():
    """
    Run the whole pipeline end to end.

    Steps: configure, download/filter data, build and train the model, save
    the weights, evaluate on the test set, save the history, plot, and compare
    experiments. Returns early (after printing an error) when the DataLoaders
    cannot be built.
    """
    config = Config()
    config.print_summary()

    # Data: download if needed, rebuild the filtered copy, wrap it in DataLoaders.
    setup_dataset(config)
    create_filtered_subsets(config)
    loaders, sizes = get_dataloaders(config)
    if loaders is None:
        print("❌ Error: Data not available. Exiting.")
        return

    # Model and training.
    net = build_model(config)
    net, train_history, best_val = train_model(config, net, loaders, sizes)

    # Persist the weights returned by train_model.
    torch.save(net.state_dict(), config.MODEL_SAVE_PATH)
    print(f"Best model saved to {config.MODEL_SAVE_PATH}")

    # Evaluation and record-keeping.
    test_loss, test_acc = evaluate_on_test(config, net, loaders, sizes)
    save_history(config, train_history, best_val, test_loss, test_acc)

    # Visualization and optional cross-experiment comparison.
    plot_training_curves(train_history)
    plot_confusion_matrix_vis(config, net, loaders)
    compare_experiments()

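# Script entry point: disabled in the notebook, where the cells below run each
# step individually. Uncomment to run the whole pipeline via main().
# if __name__ == "__main__":
#     main()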

### Initialize Configuration

This step sets up all the parameters for the experiment, including dataset paths, hyperparameters, and experiment names. It then prints a summary of the configuration.

In [None]:
# Build the experiment configuration (this also detects the compute device) and print a summary.
cfg = Config()
cfg.print_summary()

### Setup Data

This sequence of steps downloads the dataset (if not present), creates filtered subsets based on the configuration (e.g., with or without augmented images, and a limited number of training images per class), and then prepares the data loaders for training, validation, and testing. Note that any existing filtered-data directory is deleted and rebuilt on every run.

In [None]:
# Download the dataset if needed, rebuild the filtered copy, and create the DataLoaders.
setup_dataset(cfg)
create_filtered_subsets(cfg)
dataloaders, dataset_sizes = get_dataloaders(cfg)

# Stop here with a clear error; previously this only printed "Exiting." and the
# following cells then failed on `dataloaders` being None.
if dataloaders is None:
    raise RuntimeError("❌ Error: Data not available. Check the dataset download/filter output above.")
print("Data setup complete.")

### Initialize Model

This step loads an ImageNet-pretrained ResNet50, freezes all of its pretrained layers, and replaces the final fully connected layer with a new, trainable classification head sized to the number of classes in the dataset. The model is then moved to the specified device (GPU/CPU).

In [None]:
# Build the frozen ResNet50 with a fresh classification head, placed on cfg.DEVICE.
model = build_model(cfg)

### Train Model

This runs the training loop: only the new classification head is optimized (SGD with cross-entropy loss). Training and validation loss/accuracy are tracked for every epoch. The weights from the epoch with the best validation accuracy are kept in memory and restored into the model when training finishes.

In [None]:
# Train the classification head; returns the model, the per-epoch metric history, and the best validation accuracy.
model, history, best_val_acc = train_model(cfg, model, dataloaders, dataset_sizes)

### Save Model

The weights of the best performing model (based on validation accuracy) are saved to a file for later use.

In [None]:
# Persist the model's state_dict (weights only, not the architecture) to cfg.MODEL_SAVE_PATH.
torch.save(model.state_dict(), cfg.MODEL_SAVE_PATH)
print(f"Best model saved to {cfg.MODEL_SAVE_PATH}")

### Final Evaluation

The trained model is evaluated on the unseen test set to determine its generalization performance, providing final loss and accuracy metrics.

In [None]:
# Evaluate on the held-out test split; returns (loss, accuracy) as floats.
test_loss, test_acc = evaluate_on_test(cfg, model, dataloaders, dataset_sizes)

### Save History

All training and evaluation metrics, including losses, accuracies, and configuration details, are saved to a JSON file for record-keeping and future analysis.

In [None]:
# Write per-epoch metrics, test results and run metadata as JSON to cfg.HISTORY_SAVE_PATH.
save_history(cfg, history, best_val_acc, test_loss, test_acc)

### Visualization

This step generates and displays plots for the training and validation accuracy/loss curves, and a normalized confusion matrix to visualize model performance per class.

In [None]:
# Plot the learning curves, then the row-normalized confusion matrix on the test split.
plot_training_curves(history)
plot_confusion_matrix_vis(cfg, model, dataloaders)

### Compare Experiments (Optional)

If the history files of both the baseline run and the augmented run have been saved, this step loads them and compares their test accuracies (printed and as a bar chart). Otherwise it only prints a notice.

In [None]:
# Compare baseline vs. augmented runs; prints a skip notice unless both history files exist.
compare_experiments()