<a href="https://colab.research.google.com/github/garlicxd/Fruit-Classification/blob/automation/Fruit_Classification_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Todo
- [ ] Early stopping (patience): stop training after several epochs without improvement
- [ ] Create graphs: record metrics during training so they can be plotted
- [ ] Compare different optimizers

In [None]:
#@title Function Declarations
import os
import glob
import shutil
import json
import time
import subprocess
import sys
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm.auto import tqdm
from sklearn.metrics import confusion_matrix

# ==========================================
# 1. Configuration Class
# ==========================================
class Config:
    """Experiment configuration built from default values plus per-run overrides.

    Every configuration key is exposed as an instance attribute. Keys whose
    default carries a non-empty suffix also contribute to ``DIR_NAME``, the
    name of the run's output directory.
    """

    def __init__(self, defaults, params=None, base_experiment_dir=None):
        """
        Initializes Config.
        Args:
            defaults (dict): Dictionary where Key -> (Value, Suffix_String).
                             Suffix_String is used for directory naming.
                             If Suffix_String is "", it is NOT added to naming_map.
                             Must provide DATASET_FOLDER_PATH and BASE_FILTERED_DIR
                             (directly or through params); they are read
                             unconditionally to derive the data paths.
            params (dict): Dictionary where Key -> Value (Overrides defaults).
            base_experiment_dir (str): Parent directory for output
                                       (current directory when omitted).
        Raises:
            AttributeError: If DATASET_FOLDER_PATH or BASE_FILTERED_DIR is missing.
        """
        self.config_values = {}
        self.naming_map = {}

        # 1. Parse Defaults (Separating Values and Naming Suffixes)
        for key, item in defaults.items():
            # Each item is expected to be a 2-tuple: (Value, "Suffix")
            val, suffix = item
            self.config_values[key] = val

            # Only keys with a non-empty suffix take part in directory naming
            if suffix:
                self.naming_map[key] = suffix

        # 2. Update with specific Params (Value overrides)
        if params:
            self.config_values.update(params)

        # 3. Set attributes dynamically
        for key, val in self.config_values.items():
            setattr(self, key, val)

        # 4. Handle Derived Attributes
        if hasattr(self, 'CLASSES_TO_USE'):
            self.NUM_CLASSES = len(self.CLASSES_TO_USE)

        # Accept a device given as a plain string ("cuda" / "cpu")
        if hasattr(self, 'DEVICE') and isinstance(self.DEVICE, str):
            self.DEVICE = torch.device(self.DEVICE)

        # 5. Automate DIR_NAME creation
        # Keys are sorted so the same settings always give the same name;
        # the effective (possibly overridden) value is used, not the default.
        name_parts = [
            f"{self.config_values[key]}{self.naming_map[key]}"
            for key in sorted(self.naming_map)
        ]
        self.DIR_NAME = "_".join(name_parts) if name_parts else "default_run"

        # 6. Setup Output Directory
        self.base_experiment_dir = base_experiment_dir if base_experiment_dir else "."
        self.OUTPUT_DIR = os.path.join(self.base_experiment_dir, self.DIR_NAME)

        # Derived Paths for Data
        self.BASE_TRAIN_DIR = os.path.join(self.DATASET_FOLDER_PATH, 'Train_Set_Folder')
        self.BASE_VAL_DIR = os.path.join(self.DATASET_FOLDER_PATH, 'Validation_Set_Folder')
        self.BASE_TEST_DIR = os.path.join(self.DATASET_FOLDER_PATH, 'Test_Set_Folder')

        self.FILTERED_TRAIN_DIR = os.path.join(self.BASE_FILTERED_DIR, 'train')
        self.FILTERED_VAL_DIR = os.path.join(self.BASE_FILTERED_DIR, 'val')
        self.FILTERED_TEST_DIR = os.path.join(self.BASE_FILTERED_DIR, 'test')

    def print_summary(self):
        """Print the run name, output directory, device and the naming-relevant settings."""
        print(f"\n--- Configuration Summary: {self.DIR_NAME} ---")
        print(f"Output Directory: {self.OUTPUT_DIR}")
        print(f"Device: {self.DEVICE}")
        # Print only keys that determine the experiment name
        for k in self.naming_map:
            print(f"{k}: {getattr(self, k)}")
        print("------------------------------------------\n")

    def save_config(self):
        """Write all upper-case attributes to ``<OUTPUT_DIR>/config.json``.

        Values keep their native JSON type (numbers, booleans, lists), so the
        file can be loaded back without parsing stringified Python literals.
        """
        os.makedirs(self.OUTPUT_DIR, exist_ok=True)
        # Filter for uppercase keys (standard config params)
        config_dict = {k: v for k, v in self.__dict__.items() if k.isupper()}
        path = os.path.join(self.OUTPUT_DIR, 'config.json')
        with open(path, 'w') as f:
            # default=str is a deliberate fallback for values JSON cannot
            # encode natively, such as the torch.device stored in DEVICE.
            json.dump(config_dict, f, indent=4, default=str)

# ==========================================
# 2. Data Preparation
# ==========================================
def setup_dataset(config):
    """Fetch the plants dataset from Kaggle when it is not on disk yet.

    Nothing happens if ``config.DATASET_FOLDER_PATH`` already exists.
    Otherwise the Kaggle CLI downloads the archive into the current directory
    and it is extracted there. Any failure is reported with a printed message
    and is not re-raised.
    """
    if os.path.exists(config.DATASET_FOLDER_PATH):
        return

    print("Dataset not found. Downloading...")
    download_cmd = [
        "kaggle", "datasets", "download",
        "-d", "yudhaislamisulistya/plants-type-datasets",
    ]
    try:
        subprocess.run(download_cmd, check=True)
        import zipfile
        with zipfile.ZipFile("plants-type-datasets.zip", 'r') as archive:
            archive.extractall(".")
        print("Dataset downloaded.")
    except Exception as err:
        # Best effort: report the problem and let the caller carry on.
        print(f"Error downloading dataset: {err}")

def create_filtered_subsets(config):
    """Rebuild the filtered train/val/test folders from the base dataset.

    Any existing ``config.BASE_FILTERED_DIR`` is deleted first. For each split
    whose source folder exists, one sub-folder per entry of
    ``config.CLASSES_TO_USE`` is created and filled with copies of that
    class's images. Files whose name starts with ``aug_`` are skipped unless
    ``config.IMG_AUGMENT`` is truthy. The training split is capped at
    ``config.MAX_IMAGES_PER_CLASS_TRAIN`` images per class.

    Images are taken in sorted path order, so the capped training subset is
    the same on every run and on every filesystem.
    """
    if os.path.exists(config.BASE_FILTERED_DIR):
        shutil.rmtree(config.BASE_FILTERED_DIR)

    def process_split(src_base, dst_base, is_train=False):
        """Copy the selected classes of one split from src_base to dst_base."""
        if not os.path.exists(src_base):
            return
        os.makedirs(dst_base, exist_ok=True)

        for class_name in config.CLASSES_TO_USE:
            src_dir = os.path.join(src_base, class_name)
            dst_dir = os.path.join(dst_base, class_name)
            # The class folder is created even when the source has no such class.
            os.makedirs(dst_dir, exist_ok=True)

            if not os.path.exists(src_dir):
                continue

            # glob returns entries in filesystem order; sorting makes the
            # [:MAX] slice below reproducible. '*.*' can also match
            # directories, which shutil.copy cannot handle, so keep files only.
            all_images = sorted(
                path for path in glob.glob(os.path.join(src_dir, '*.*'))
                if os.path.isfile(path)
            )

            if config.IMG_AUGMENT:
                imgs = all_images
            else:
                imgs = [img for img in all_images
                        if not os.path.basename(img).startswith('aug_')]

            if is_train:
                imgs = imgs[:config.MAX_IMAGES_PER_CLASS_TRAIN]

            for img_path in imgs:
                shutil.copy(img_path, dst_dir)

    print("Creating filtered dataset...")
    process_split(config.BASE_TRAIN_DIR, config.FILTERED_TRAIN_DIR, is_train=True)
    process_split(config.BASE_VAL_DIR, config.FILTERED_VAL_DIR)
    process_split(config.BASE_TEST_DIR, config.FILTERED_TEST_DIR)

def get_dataloaders(config):
    """Create the train/val/test DataLoaders over the filtered image folders.

    All three splits share one preprocessing pipeline: resize to
    ``config.IMG_SIZE``, convert to tensor, normalize with the ImageNet
    mean and standard deviation.

    Returns:
        tuple: ``(dataloaders, sizes)``, two dicts keyed by 'train', 'val' and
        'test' holding the DataLoader and the dataset length of each split.
    """
    split_names = ['train', 'val', 'test']

    preprocess = transforms.Compose([
        transforms.Resize(config.IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # First build every dataset ...
    image_datasets = {}
    for split in split_names:
        split_dir = getattr(config, f'FILTERED_{split.upper()}_DIR')
        image_datasets[split] = datasets.ImageFolder(split_dir, preprocess)

    # ... then wrap each one in a loader. Memory is pinned only on CUDA.
    pin = config.DEVICE.type == 'cuda'
    dataloaders = {}
    for split in split_names:
        # Only the training split may be shuffled, and only if configured.
        shuffle_split = config.SHUFFLE_TRAINING if split == 'train' else False
        dataloaders[split] = DataLoader(
            image_datasets[split],
            batch_size=config.BATCH_SIZE,
            shuffle=shuffle_split,
            pin_memory=pin,
            num_workers=2,
        )

    sizes = {split: len(image_datasets[split]) for split in split_names}
    return dataloaders, sizes

# ==========================================
# 3. Model & Training
# ==========================================
def build_model(config):
    """Return a ResNet-50 with a frozen backbone and a new classification head.

    The network is loaded with the IMAGENET1K_V2 weights and all of its
    parameters have gradients disabled. The final ``fc`` layer is then replaced
    by a new linear layer with ``config.NUM_CLASSES`` outputs, and the model
    is moved to ``config.DEVICE``.
    """
    backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

    # Freeze every pretrained weight.
    for weight in backbone.parameters():
        weight.requires_grad = False

    # The replacement head is created after the freeze, so its parameters
    # keep the default requires_grad setting of a new layer.
    head_inputs = backbone.fc.in_features
    backbone.fc = nn.Linear(head_inputs, config.NUM_CLASSES)

    return backbone.to(config.DEVICE)

def train_model(config, model, dataloaders, dataset_sizes):
    """Train the model's ``fc`` layer and return the best-validating weights.

    Runs ``config.EPOCHS`` epochs of SGD (learning rate
    ``config.LEARNING_RATE``) with cross-entropy loss over
    ``dataloaders['train']``, evaluating on ``dataloaders['val']`` after each
    epoch. Only ``model.fc`` parameters are handed to the optimizer.

    Args:
        config: Object providing DEVICE, LEARNING_RATE and EPOCHS.
        model: Module with an ``fc`` sub-module; it is updated in place.
        dataloaders (dict): Iterables of ``(inputs, labels)`` batches under
            the keys 'train' and 'val'.
        dataset_sizes (dict): Number of samples under 'train' and 'val',
            used to average the loss and accuracy.

    Returns:
        tuple: ``(model, history)``. ``model`` carries the weights of the
        epoch with the highest validation accuracy (the initial weights if no
        epoch scored above 0). ``history`` maps 'loss', 'accuracy',
        'val_loss' and 'val_accuracy' to per-epoch lists of floats.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.fc.parameters(), lr=config.LEARNING_RATE)

    def snapshot_weights():
        # state_dict() hands back references to the live tensors, which keep
        # changing as training continues. Clone them so the snapshot really
        # preserves the weights of the epoch it was taken in.
        return {name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()}

    history = {"loss": [], "accuracy": [], "val_loss": [], "val_accuracy": []}
    best_acc = 0.0
    best_weights = snapshot_weights()

    print(f"Starting training for {config.EPOCHS} epochs...")
    for epoch in range(config.EPOCHS):
        # ---- training pass ----
        model.train()
        running_loss, running_corrects = 0.0, 0

        for inputs, labels in dataloaders['train']:
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # The criterion averages over the batch; re-weight by batch size
            # so the epoch figure is a true per-sample mean.
            running_loss += loss.item() * inputs.size(0)
            # Count with Python ints so an empty loader does not leave a
            # plain int where a tensor method would later be called.
            running_corrects += torch.sum(torch.max(outputs, 1)[1] == labels).item()

        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects / dataset_sizes['train']

        # ---- validation pass ----
        model.eval()
        val_loss_running, val_corrects = 0.0, 0
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss_running += loss.item() * inputs.size(0)
                val_corrects += torch.sum(torch.max(outputs, 1)[1] == labels).item()

        val_loss = val_loss_running / dataset_sizes['val']
        val_acc = val_corrects / dataset_sizes['val']

        history['loss'].append(epoch_loss)
        history['accuracy'].append(epoch_acc)
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_acc)

        print(f"Ep {epoch+1}: Train Loss {epoch_loss:.4f} Acc {epoch_acc:.4f} | Val Loss {val_loss:.4f} Acc {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            best_weights = snapshot_weights()

    # Roll back to the best epoch before handing the model to the caller.
    model.load_state_dict(best_weights)
    return model, history

def evaluate_and_save(config, model, dataloaders, dataset_sizes, history):
    """Evaluate on the test split and write all run artifacts to OUTPUT_DIR.

    Prints the test accuracy, then saves into ``config.OUTPUT_DIR``:
    ``model.pth`` (state dict), ``history.json``, ``training_curves.png``
    (accuracy and loss per epoch) and ``confusion_matrix.png``.

    Args:
        config: Object providing DEVICE, OUTPUT_DIR, NUM_CLASSES and
            CLASSES_TO_USE.
        model: Trained module to evaluate.
        dataloaders (dict): Must contain a 'test' iterable of
            ``(inputs, labels)`` batches.
        dataset_sizes (dict): Must contain the 'test' sample count.
        history (dict): Per-epoch 'accuracy', 'val_accuracy', 'loss' and
            'val_loss' lists.
    """
    model.eval()
    test_corrects = 0
    y_true, y_pred = [], []

    with torch.no_grad():
        for inputs, labels in dataloaders['test']:
            inputs, labels = inputs.to(config.DEVICE), labels.to(config.DEVICE)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            # Count with Python ints so an empty loader does not leave a
            # plain int where a tensor method would later be called.
            test_corrects += torch.sum(preds == labels).item()
            y_true.extend(labels.cpu().tolist())
            y_pred.extend(preds.cpu().tolist())

    test_acc = test_corrects / dataset_sizes['test']
    print(f"Test Accuracy: {test_acc:.4f}")

    os.makedirs(config.OUTPUT_DIR, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(config.OUTPUT_DIR, 'model.pth'))

    with open(os.path.join(config.OUTPUT_DIR, 'history.json'), 'w') as f:
        json.dump(history, f)

    # Training curves: accuracy on the left, loss on the right.
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.plot(history['accuracy'], label='Train')
    plt.plot(history['val_accuracy'], label='Val')
    plt.title('Accuracy')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(history['loss'], label='Train')
    plt.plot(history['val_loss'], label='Val')
    plt.title('Loss')
    plt.legend()
    plt.savefig(os.path.join(config.OUTPUT_DIR, 'training_curves.png'))
    plt.close()

    # Fix the label set so the matrix is always NUM_CLASSES x NUM_CLASSES.
    # Without it, a class absent from both y_true and y_pred is dropped and
    # the tick labels below no longer line up with the rows and columns.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(config.NUM_CLASSES))

    # Label indices come from the dataset, which may order classes differently
    # from config.CLASSES_TO_USE (torchvision's ImageFolder sorts folder
    # names). Prefer the dataset's own class list when it is available.
    test_dataset = getattr(dataloaders['test'], 'dataset', None)
    class_names = getattr(test_dataset, 'classes', None)
    if not class_names or len(class_names) != config.NUM_CLASSES:
        class_names = config.CLASSES_TO_USE

    plt.figure(figsize=(10, 10))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.colorbar()
    tick_marks = np.arange(config.NUM_CLASSES)
    plt.xticks(tick_marks, class_names, rotation=90)
    plt.yticks(tick_marks, class_names)
    plt.tight_layout()
    plt.savefig(os.path.join(config.OUTPUT_DIR, 'confusion_matrix.png'))
    plt.close()

# ==========================================
# 4. Orchestrators
# ==========================================
def get_next_experiment_dir(base_path="."):
    """Return the path of the next unused ``experiment_<N>`` under base_path.

    N is one more than the largest number found among existing
    ``experiment_*`` entries. Entries whose second underscore-separated field
    is not an integer are ignored. With no numbered entries the result is
    ``experiment_1``.
    """
    def run_number(path):
        # "experiment_12" -> 12; anything unparsable counts as 0.
        try:
            return int(os.path.basename(path).split('_')[1])
        except (IndexError, ValueError):
            return 0

    pattern = os.path.join(base_path, "experiment_*")
    # The leading 0 is the floor used when nothing (or only junk) is found.
    highest = max([0] + [run_number(entry) for entry in glob.glob(pattern)])
    return os.path.join(base_path, f"experiment_{highest + 1}")

def run_with_config(defaults, params=None, base_experiment_dir=None):
    """Run one complete experiment: configure, prepare data, train, evaluate.

    Args:
        defaults (dict): Key -> (Value, Suffix) mapping passed to Config.
        params (dict): Optional value overrides for this run.
        base_experiment_dir (str): Parent directory for the run's output.
    """
    config = Config(defaults=defaults, params=params, base_experiment_dir=base_experiment_dir)
    config.print_summary()
    config.save_config()

    # Data: make sure the dataset exists, rebuild the filtered copy, load it.
    setup_dataset(config)
    create_filtered_subsets(config)
    loaders, split_sizes = get_dataloaders(config)

    # Model: build, train, then evaluate and persist the artifacts.
    network = build_model(config)
    network, training_history = train_model(config, network, loaders, split_sizes)

    evaluate_and_save(config, network, loaders, split_sizes, training_history)

def run_all(defaults, grid, base_experiment_dir=None):
    """Run one experiment for every combination of the grid's values.

    Args:
        defaults (dict): Key -> (Value, Suffix) mapping passed to each run.
        grid (dict): Key -> list of values to try. The Cartesian product of
            all lists is executed. An empty (or None) grid runs a single
            experiment with the defaults unchanged.
        base_experiment_dir (str): Directory that receives all runs. When
            omitted, the next free ``experiment_<N>`` directory is used.
    """
    grid = grid or {}
    keys = list(grid)
    # product() over zero value lists yields exactly one empty tuple, so an
    # empty grid becomes a single run with no overrides instead of failing
    # while unpacking zip(*grid.items()).
    combinations = [
        dict(zip(keys, combo))
        for combo in itertools.product(*(grid[key] for key in keys))
    ]

    print(f"Total experiments to run: {len(combinations)}")

    if base_experiment_dir is None:
        base_experiment_dir = get_next_experiment_dir()

    os.makedirs(base_experiment_dir, exist_ok=True)
    print(f"All runs will be saved to: {base_experiment_dir}")

    for i, params in enumerate(combinations):
        print(f"\n>>> Running Experiment {i+1}/{len(combinations)}")
        run_with_config(defaults=defaults, params=params, base_experiment_dir=base_experiment_dir)
        # Return cached GPU memory before the next run starts.
        torch.cuda.empty_cache()

In [None]:
#@title Default Config - set folder naming
# Format: "KEY": (Value, "suffix")
# If suffix is "", it is excluded from the naming map.
# Use the GPU when torch can see one, otherwise run on the CPU.
_default_device = "cuda" if torch.cuda.is_available() else "cpu"

# Class folder names to include in the filtered dataset.
_default_classes = [
    "aloevera", "banana", "bilimbi", "cantaloupe", "cassava", "coconut",
    "corn", "cucumber", "curcuma", "eggplant", "galangal", "ginger",
    "guava", "kale", "longbeans", "mango", "melon", "orange", "paddy",
    "papaya", "peper chili", "pineapple", "pomelo", "shallot", "soybeans",
    "spinach", "sweet potatoes", "tobacco", "waterapple", "watermelon"
]

default_config_tuples = {
    # --- environment and paths (not part of the run directory name) ---
    "DEVICE": (_default_device, ""),
    "DATASET_FOLDER_PATH": ("./split_ttv_dataset_type_of_plants", ""),
    "BASE_FILTERED_DIR": ("./filtered_data", ""),
    "IMG_SIZE": ((224, 224), ""),
    # --- hyperparameters; a non-empty suffix adds the value to the dir name ---
    "BATCH_SIZE": (32, "bs"),          # dir name part: "32bs"
    "LEARNING_RATE": (0.01, "lr"),     # dir name part: "0.01lr"
    "EPOCHS": (5, ""),
    "MAX_IMAGES_PER_CLASS_TRAIN": (1000, ""),
    "SHUFFLE_TRAINING": (True, ""),
    "IMG_AUGMENT": (False, "aug"),     # dir name part: "Falseaug"
    "CLASSES_TO_USE": (_default_classes, "")
}

In [None]:
#@title Solo Run Default Config
# Single run with the defaults untouched; output goes under "solo_run".
run_with_config(default_config_tuples, base_experiment_dir="solo_run")

In [None]:
#@title Automated Run - add the value options to the grid
# Each key lists the values to try; every combination is run once.
grid = dict(
    LEARNING_RATE=[0.01, 0.005],
    BATCH_SIZE=[16, 32],
    IMG_AUGMENT=[True, False],
)

# 3. Run Automation
run_all(default_config_tuples, grid)