<a href="https://colab.research.google.com/github/garlicxd/Fruit-Classification/blob/main/Fruit_Classification_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Todo
- [ ] Early stopping (patience): stop training once the validation metric has not improved for several epochs
- [ ] Create graphs: record metrics during training for later plotting
- [ ] Compare different optimizers

In [None]:
# @title Download Dataset

# Define the path for the target dataset folder
dataset_folder_path = "./split_ttv_dataset_type_of_plants"

!pip install kaggle
from google.colab import files
import os
kaggle_dir = os.path.expanduser("~/.kaggle")
kaggle_json_path = os.path.join(kaggle_dir, "kaggle.json")
if not os.path.exists(kaggle_json_path):
    print("kaggle.json not found. Please upload your kaggle.json file:")
    uploaded = files.upload()
    if "kaggle.json" in uploaded:
        !mkdir -p ~/.kaggle
        !cp kaggle.json ~/.kaggle/
        !chmod 600 ~/.kaggle/kaggle.json
        print("Kaggle API key configured successfully!")
    else:
        print("Upload failed or kaggle.json not found in upload.")
else:
    print("Kaggle API key already configured.")

if not os.path.exists(dataset_folder_path):
    print(f"Dataset folder '{dataset_folder_path}' not found. Downloading and unzipping...")
    !kaggle datasets download -d yudhaislamisulistya/plants-type-datasets
    !unzip -q plants-type-datasets.zip
    print("Dataset downloaded and unzipped successfully.")
else:
    print(f"Dataset folder '{dataset_folder_path}' already exists. Skipping download and unzip.")

print("\nClasses in training directory:")
!ls ./split_ttv_dataset_type_of_plants/Train_Set_Folder/

In [None]:
# @title Parameters and Includes

import copy
import glob
import json
import os
import random
import shutil
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

BASE_DIR = './split_ttv_dataset_type_of_plants'
BASE_TRAIN_DIR = os.path.join(BASE_DIR, 'Train_Set_Folder')
BASE_VAL_DIR = os.path.join(BASE_DIR, 'Validation_Set_Folder')
BASE_TEST_DIR = os.path.join(BASE_DIR, 'Test_Set_Folder')
SHUFFLE_TRAINING = True

CLASSES_TO_USE = ["aloevera", "banana", "bilimbi", "cantaloupe", "cassava", "coconut", "corn", "cucumber", "curcuma", "eggplant", "galangal", "ginger", "guava", "kale", "longbeans", "mango", "melon", "orange", "paddy", "papaya", "peper chili", "pineapple", "pomelo", "shallot", "soybeans", "spinach", "sweet potatoes", "tobacco", "waterapple", "watermelon"]

MAX_IMAGES_PER_CLASS_TRAIN = 1000

IMG_SIZE = (224, 224)
BATCH_SIZE = 32
LEARNING_RATE = 0.01
EPOCHS = 5
# Fixed seed so the baseline and augmented runs differ in IMG_AUGMENT rather
# than in random draws (head initialisation, shuffle order).
SEED = 42

# TOGGLE THIS BETWEEN RUNS
IMG_AUGMENT = True

# Experiment tracking based on IMG_AUGMENT: per-experiment file names keep the
# two runs from overwriting each other's model and history.
EXPERIMENT_NAME = "with_augmented_images" if IMG_AUGMENT else "baseline_no_augmented_images"
MODEL_SAVE_PATH = f'./resnet50_{EXPERIMENT_NAME}.pth'
CLASS_MAP_PATH = './class_mapping.json'
HISTORY_SAVE_PATH = f'./history_{EXPERIMENT_NAME}.json'

# Seed every RNG the notebook relies on (Python, NumPy, PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
print(f"Classes to train: {', '.join(CLASSES_TO_USE)}")
print(f"Max training images per class: {MAX_IMAGES_PER_CLASS_TRAIN}")
print(f"Include augmented images: {IMG_AUGMENT}")

In [None]:
# @title Filtered Set Creation

base_filtered_dir = './filtered_data'
filtered_train_dir = os.path.join(base_filtered_dir, 'train')
filtered_val_dir = os.path.join(base_filtered_dir, 'val')
filtered_test_dir = os.path.join(base_filtered_dir, 'test')


def _is_augmented_file(path):
    """Return True if the file name carries the dataset's 'aug_' prefix."""
    return os.path.basename(path).startswith('aug_')


def _copy_filtered_split(source_root, dest_root, use_augmented, max_per_class=None):
    """Copy the selected images of every class in CLASSES_TO_USE into `dest_root`.

    Args:
        source_root: split directory holding one sub-folder per class.
        dest_root: output split directory; class sub-folders are always created.
        use_augmented: True copies only 'aug_*' files, False only the originals.
        max_per_class: optional cap on files copied per class (None = no cap).
    """
    for class_name in CLASSES_TO_USE:
        dest_dir = os.path.join(dest_root, class_name)
        os.makedirs(dest_dir, exist_ok=True)
        # glob() order depends on the filesystem; sort so a capped subset is reproducible.
        candidates = sorted(glob.glob(os.path.join(source_root, class_name, '*.*')))
        selected = [p for p in candidates if _is_augmented_file(p) == use_augmented]
        if max_per_class is not None:
            selected = selected[:max_per_class]
        for img_path in selected:
            shutil.copy(img_path, dest_dir)


# Start clean so files from a run with the other IMG_AUGMENT setting don't leak in.
if os.path.exists(base_filtered_dir):
    print(f"Removing existing filtered data directory: {base_filtered_dir}")
    shutil.rmtree(base_filtered_dir)

# ImageFolder numbers classes in sorted folder-name order; build the saved mapping
# the same way so the JSON mapping always agrees with the dataset labels.
idx_to_class = sorted(CLASSES_TO_USE)
class_to_idx = {class_name: i for i, class_name in enumerate(idx_to_class)}

print("Creating filtered training set...")
_copy_filtered_split(BASE_TRAIN_DIR, filtered_train_dir, use_augmented=IMG_AUGMENT,
                     max_per_class=MAX_IMAGES_PER_CLASS_TRAIN)
print(f"Filtered training set created at: {filtered_train_dir}")

# Validation and test always use the original images, so the baseline and the
# augmented run are scored on identical held-out data.
# NOTE(review): assumes every class has non-'aug_' images in the val/test folders
# (ImageFolder rejects empty class folders); see the "Dataset Size Comparison" cell.
print("Creating filtered validation set...")
_copy_filtered_split(BASE_VAL_DIR, filtered_val_dir, use_augmented=False)
print(f"Filtered validation set created at: {filtered_val_dir}")

print("Creating filtered test set...")
_copy_filtered_split(BASE_TEST_DIR, filtered_test_dir, use_augmented=False)
print(f"Filtered test set created at: {filtered_test_dir}")

with open(CLASS_MAP_PATH, 'w') as f:
    json.dump(class_to_idx, f)
print(f"Class mapping saved to {CLASS_MAP_PATH}")

NUM_CLASSES = len(CLASSES_TO_USE)

In [None]:
# @title Define Transforms and DataLoaders

# ImageNet channel statistics, matching the pretrained ResNet50 weights.
imgnet_mean = [0.485, 0.456, 0.406]
imgnet_std = [0.229, 0.224, 0.225]

# Each split gets its own (currently identical) deterministic pipeline:
# resize -> tensor -> ImageNet normalization.
data_transforms = {
    split: transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=imgnet_mean, std=imgnet_std),
    ])
    for split in ('train', 'val', 'test')
}

# One ImageFolder per filtered split directory.
image_datasets = {
    split: datasets.ImageFolder(split_root, data_transforms[split])
    for split, split_root in (
        ('train', filtered_train_dir),
        ('val', filtered_val_dir),
        ('test', filtered_test_dir),
    )
}

# Only the training loader may shuffle; evaluation order stays fixed.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=SHUFFLE_TRAINING if split == 'train' else False,
    )
    for split in image_datasets
}

dataset_sizes = {split: len(image_datasets[split]) for split in ('train', 'val', 'test')}
print(f"Total training images: {dataset_sizes['train']}")
print(f"Total validation images: {dataset_sizes['val']}")
print(f"Total test images: {dataset_sizes['test']}")

In [None]:
# @title Define the Model (ResNet50) and Freeze Layers

# Start from ImageNet-pretrained weights and freeze every existing parameter.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
model.requires_grad_(False)

# Replace the classification head; the new Linear layer's parameters are
# trainable by default, so only the head is learned.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
model = model.to(DEVICE)

print("Model definition complete. New final layer:")
print(model.fc)

In [None]:
# @title Define Loss and Optimizer

# Multi-class classification loss computed on raw logits.
criterion = nn.CrossEntropyLoss()

# Only the new head is handed to the optimizer (the backbone is frozen).
# TODO: compare against Adam once the SGD settings have been checked.
optimizer = optim.SGD(params=model.fc.parameters(), lr=LEARNING_RATE)
print(f"Using optimizer: SGD with LR={LEARNING_RATE}")

In [None]:
# @title Training Loop (with Validation)


def _run_epoch(loader, train, desc):
    """Run one pass of `model` over `loader` and return (mean_loss, accuracy).

    Args:
        loader: DataLoader yielding (inputs, labels) batches.
        train: True for an optimisation pass (train mode, gradients, optimizer
            steps); False for evaluation (eval mode, gradients disabled).
        desc: label shown on the tqdm progress bar.

    Returns:
        Tuple of Python floats (sample-weighted mean loss, accuracy); both are
        0.0 if the loader yields no samples.
    """
    # NOTE(review): train mode also lets the frozen backbone's BatchNorm layers
    # update their running statistics; confirm that is intended.
    model.train(mode=train)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    progress_bar = tqdm(loader, desc=desc)
    # Autograd is only needed for the optimisation pass.
    with torch.set_grad_enabled(train):
        for inputs, labels in progress_bar:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            if train:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            _, preds = torch.max(outputs, 1)  # index of the largest logit

            if train:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            batch_size = inputs.size(0)
            # Weight by batch size: the final batch may be smaller.
            total_loss += batch_loss * batch_size
            total_correct += (preds == labels).sum().item()
            total_seen += batch_size

            if train:
                progress_bar.set_postfix(batch_loss=f"{batch_loss:.4f}")

    if total_seen == 0:
        return 0.0, 0.0
    return total_loss / total_seen, total_correct / total_seen


print("Starting training...")
start_time = time.time()

# state_dict() returns references to the live parameter tensors, which keep
# changing as training continues; deep-copy so the snapshot holds the best weights.
best_model_wts = copy.deepcopy(model.state_dict())
best_val_acc = 0.0

# Per-epoch metrics for plotting and for the saved history file.
history = {
    "loss": [],
    "accuracy": [],
    "val_loss": [],
    "val_accuracy": [],
}

for epoch in range(EPOCHS):
    print(f"\n--- Epoch {epoch+1}/{EPOCHS} ---")

    epoch_loss, epoch_acc = _run_epoch(dataloaders['train'], train=True, desc="[Train]")
    val_loss, val_acc = _run_epoch(dataloaders['val'], train=False, desc="[Validate]")

    print("Epoch Summary:")
    print(f"  Train Loss: {epoch_loss:.4f} | Train Acc: {epoch_acc:.4f}")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    history["loss"].append(epoch_loss)
    history["accuracy"].append(epoch_acc)
    history["val_loss"].append(val_loss)
    history["val_accuracy"].append(val_acc)

    # Keep a snapshot of the weights with the best validation accuracy so far.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        print(f"  -> New best model found! (Val Acc: {best_val_acc:.4f})")

time_elapsed = time.time() - start_time
print(f"\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
print(f"Best Val Acc: {best_val_acc:.4f}")

# Restore the best-validation weights before saving / evaluating.
model.load_state_dict(best_model_wts)

In [None]:
# @title Save Training History for Comparison

import json

# Per-epoch curves, coerced to plain floats so they serialise cleanly.
history_to_save = {
    key: [float(value) for value in history[key]]
    for key in ('loss', 'accuracy', 'val_loss', 'val_accuracy')
}
# Run metadata read back by the comparison cell.
history_to_save.update(
    experiment_name=EXPERIMENT_NAME,
    img_augment=IMG_AUGMENT,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    best_val_acc=float(best_val_acc),
)

with open(HISTORY_SAVE_PATH, 'w') as history_file:
    json.dump(history_to_save, history_file, indent=2)

print(f"\nTraining history saved to {HISTORY_SAVE_PATH}")
print(f"  Experiment: {EXPERIMENT_NAME}")
print(f"  Best Val Acc: {best_val_acc:.4f}")

In [None]:
# @title Save Best Model

# Persist only the state dict; the Load Model cell rebuilds the architecture itself.
print(f"Saving best model state dictionary to {MODEL_SAVE_PATH}...")
torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)
print("Best model saved successfully.")

In [None]:
# @title Load Model

print("--- Loading Model for Inference ---")

print(f"Loading class mapping from {CLASS_MAP_PATH}...")
with open(CLASS_MAP_PATH, 'r') as f:
    loaded_class_to_idx = json.load(f)
LOADED_NUM_CLASSES = len(loaded_class_to_idx)
print(f"Loaded {LOADED_NUM_CLASSES} classes.")

# Same architecture as in training; no pretrained weights are fetched because
# every parameter is overwritten by the checkpoint below.
model_to_load = models.resnet50(weights=None)
num_ftrs = model_to_load.fc.in_features
model_to_load.fc = nn.Linear(num_ftrs, LOADED_NUM_CLASSES)

print(f"Loading model weights from {MODEL_SAVE_PATH}...")
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only restricts unpickling to tensors/containers, enough for a state dict.
state_dict = torch.load(MODEL_SAVE_PATH, map_location=DEVICE, weights_only=True)
model_to_load.load_state_dict(state_dict)

model_to_load = model_to_load.to(DEVICE)
model_to_load.eval()

print("Model loaded successfully and set to evaluation mode.")

In [None]:
# @title 3. Final Testing on Test Set

print("--- Running Final Evaluation on Test Set ---")

# The model's output indices follow the saved class mapping, while the test
# labels come from ImageFolder; refuse to score if the two disagree.
if image_datasets['test'].class_to_idx != loaded_class_to_idx:
    raise ValueError(
        f"Class mapping in {CLASS_MAP_PATH} does not match the test dataset's class indices."
    )

test_running_loss = 0.0
test_running_corrects = 0

# Get a reverse mapping from index to class name
loaded_idx_to_class = {v: k for k, v in loaded_class_to_idx.items()}

# Set model to evaluation mode and disable gradients
model_to_load.eval()
with torch.no_grad():

    progress_bar_test = tqdm(dataloaders['test'], desc="[Test]")

    for inputs, labels in progress_bar_test:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)

        outputs = model_to_load(inputs)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)

        test_running_loss += loss.item() * inputs.size(0)
        # Plain int accumulator: the summary below works even if the loader is empty.
        test_running_corrects += (preds == labels).sum().item()

# Calculate final test statistics
n_test = dataset_sizes['test']
final_test_loss = test_running_loss / n_test if n_test else 0.0
final_test_acc = test_running_corrects / n_test if n_test else 0.0

print("\n--- Final Test Results ---")
print(f"  Test Loss: {final_test_loss:.4f}")
print(f"  Test Acc:  {final_test_acc:.4f} ({test_running_corrects}/{n_test})")

# Save test results to this experiment's history file (written by the history cell).
with open(HISTORY_SAVE_PATH, 'r') as f:
    history_data = json.load(f)

history_data['test_loss'] = float(final_test_loss)
history_data['test_accuracy'] = float(final_test_acc)

with open(HISTORY_SAVE_PATH, 'w') as f:
    json.dump(history_data, f, indent=2)

print(f"\nTest results saved to {HISTORY_SAVE_PATH}")

## Model evaluation

In [None]:
import matplotlib.pyplot as plt

def show_history(history):
    """Plot train vs. validation accuracy (left) and loss (right) per epoch.

    Args:
        history: dict with equal-length per-epoch lists under 'accuracy',
            'val_accuracy', 'loss' and 'val_loss'.
    """
    plt.figure(figsize=(20, 6))

    epoch_axis = range(1, len(history['accuracy']) + 1)
    # (subplot position, metric key, panel title, y-axis label)
    panels = (
        (121, 'accuracy', 'model accuracy', 'accuracy'),
        (122, 'loss', 'model loss', 'loss'),
    )
    for position, metric, panel_title, y_label in panels:
        plt.subplot(position)
        plt.plot(epoch_axis, history[metric])
        plt.plot(epoch_axis, history['val_' + metric])
        plt.title(panel_title)
        plt.ylabel(y_label)
        plt.xlabel('epoch')
        plt.legend(['train', 'validation'], loc='upper left')

    plt.show()

show_history(history)


In [None]:
# Redefine imshow to handle normalization
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))

    #undo normalization
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)

    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
# NOTE(review): `inputs` / `classes` are not read by any later cell visible here
# (later cells use their own loop variables); this only pulls one batch.
inputs, classes = next(iter(dataloaders['train']))

In [None]:
# Generic function to display predictions for a few images
def visualize_model(model, num_images=10, split='val'):
    """Plot `num_images` images from `split`, titled with predicted/true class names.

    Args:
        model: classifier returning one logit per class.
        num_images: how many images to draw; laid out on two rows.
        split: key into `dataloaders` / `image_datasets`. Defaults to 'val' so
            predictions are shown on images the model was not trained on.

    The model's train/eval mode is restored before returning, even on error.
    """
    if num_images <= 0:
        return
    was_training = model.training
    model.eval()
    # ImageFolder's class list is in label-index order, so labels index it directly.
    class_names = image_datasets[split].classes
    n_cols = (num_images + 1) // 2  # ceil(num_images / 2): odd counts still fit the grid
    plt.figure(figsize=(3 * n_cols, 7))
    images_so_far = 0

    try:
        with torch.no_grad():
            for inputs, labels in dataloaders[split]:
                outputs = model(inputs.to(DEVICE))
                _, preds = torch.max(outputs, 1)

                for j in range(inputs.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(2, n_cols, images_so_far)
                    ax.axis('off')
                    ax.set_title('pred/true: {}/{}'.format(class_names[preds[j].item()],
                                                           class_names[labels[j].item()]))
                    imshow(inputs[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

In [None]:
visualize_model(model, num_images=10)  # show predictions of the trained model on a few sample images

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(y_true, y_pred, class_names, normalize=False, title=None, cmap=plt.cm.Blues):
    """Draw an annotated confusion matrix (rows = true, columns = predicted) and return its axes.

    Args:
        y_true: integer ground-truth labels (indices into `class_names`).
        y_pred: integer predicted labels, same length as `y_true`.
        class_names: names in label-index order; also fixes the matrix size, so
            classes without samples still get a row and a column.
        normalize: if True, divide each row by its total so a row shows how the
            samples of that true class were distributed; empty rows become zeros.
        title: figure title; a default is chosen from `normalize` when None.
        cmap: matplotlib colormap for the cells.

    Returns:
        The matplotlib Axes the matrix was drawn on.
    """
    if title is None:
        title = 'Normalized confusion matrix' if normalize else 'Confusion matrix, without normalization'

    n_classes = len(class_names)
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(n_classes))

    if normalize:
        with np.errstate(all='ignore'):
            cm = cm.astype(float) / cm.sum(axis=1, keepdims=True)
            cm = np.nan_to_num(cm)  # handle rows with zero support

    # Grow with the class count, but keep a readable floor (a few classes would
    # otherwise give a figure only a couple of inches wide) and a 24-inch cap.
    width = min(24, max(8, 0.6 * n_classes))
    fig, ax = plt.subplots(figsize=(width, 10))
    im = ax.imshow(cm, interpolation='nearest', cmap=cmap, aspect='auto')
    fig.colorbar(im, ax=ax)

    ax.set(
        xticks=np.arange(n_classes),
        yticks=np.arange(n_classes),
        xticklabels=class_names,
        yticklabels=class_names,
        title=title,
        ylabel='True label',
        xlabel='Predicted label'
    )
    ax.tick_params(axis='x', labelrotation=45, pad=12)
    ax.tick_params(axis='y', pad=6)

    # Cell annotations; small font keeps large matrices legible, and the text
    # colour flips on dark cells.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.0 if cm.size else 0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    fontsize=8,
                    color="white" if cm[i, j] > thresh else "black")

    # tight_layout computes margins that fit the rotated tick labels.
    fig.tight_layout()
    return ax

# --- Collect predictions over a FULL split, not just the last batch ---
# Use held-out data: a confusion matrix on the training split mostly reflects
# what the model has memorised. Set to 'val' or 'train' to inspect other splits.
CM_SPLIT = 'test'

model.eval()
y_true_all = []
y_pred_all = []

with torch.no_grad():
    for inputs, y_true in dataloaders[CM_SPLIT]:
        inputs = inputs.to(DEVICE)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        y_true_all.append(y_true.cpu().numpy())
        y_pred_all.append(preds.cpu().numpy())

y_true_all = np.concatenate(y_true_all)
y_pred_all = np.concatenate(y_pred_all)
assert len(y_true_all) == len(y_pred_all)
print(f"Collected {len(y_true_all)} predictions from the '{CM_SPLIT}' split.")

# Class names in the label-index order ImageFolder used to build the targets.
class_names = image_datasets[CM_SPLIT].classes

np.set_printoptions(precision=2)

# Plot
plot_confusion_matrix(y_true_all, y_pred_all, class_names, normalize=False,
                      title=f'Confusion matrix ({CM_SPLIT} set), without normalization')
plt.show()

plot_confusion_matrix(y_true_all, y_pred_all, class_names, normalize=True,
                      title=f'Normalized confusion matrix ({CM_SPLIT} set)')
plt.show()


In [None]:
# @title Data Distribution Analysis

import matplotlib.pyplot as plt
import numpy as np

# Count samples per class for each split
def count_class_samples(dataset_dir, classes=None):
    """Count the files in each class sub-folder of `dataset_dir`.

    Args:
        dataset_dir: split directory holding one sub-folder per class.
        classes: class folder names to count; defaults to CLASSES_TO_USE.

    Returns:
        dict mapping class name -> number of files matching '*.*' (0 when the
        folder is missing), in the order of `classes`.
    """
    if classes is None:
        classes = CLASSES_TO_USE
    class_counts = {}
    for class_name in classes:
        class_dir = os.path.join(dataset_dir, class_name)
        if os.path.exists(class_dir):
            class_counts[class_name] = len(glob.glob(os.path.join(class_dir, '*.*')))
        else:
            class_counts[class_name] = 0
    return class_counts

# Get counts for all splits
train_counts = count_class_samples(filtered_train_dir)
val_counts = count_class_samples(filtered_val_dir)
test_counts = count_class_samples(filtered_test_dir)

# (split name, per-class counts, bar colour)
split_panels = (
    ('Training', train_counts, 'steelblue'),
    ('Validation', val_counts, 'orange'),
    ('Test', test_counts, 'green'),
)

# Create visualization: one bar chart per split
fig, axes = plt.subplots(3, 1, figsize=(16, 12))
for ax, (split_name, counts, color) in zip(axes, split_panels):
    # Plain lists: dict views are not accepted as bar data by every matplotlib version.
    ax.bar(list(counts.keys()), list(counts.values()), color=color)
    ax.set_title(f'{split_name} Set Distribution', fontsize=14, fontweight='bold')
    ax.set_ylabel('Number of Samples')
    ax.tick_params(axis='x', rotation=45)
    ax.grid(axis='y', alpha=0.3)
axes[-1].set_xlabel('Class Name')

plt.tight_layout()
plt.show()


def _print_split_stats(header, counts):
    """Print per-class counts (largest first) with percentages; return the split total."""
    total = sum(counts.values())
    print(header)
    for class_name, count in sorted(counts.items(), key=lambda item: item[1], reverse=True):
        percentage = 100 * count / total if total > 0 else 0
        print(f"  {class_name:20s}: {count:4d} samples ({percentage:5.2f}%)")
    print(f"  {'TOTAL':20s}: {total:4d} samples")
    return total


# Print statistics
print("\n=== Dataset Distribution Statistics ===\n")
total_train = _print_split_stats("TRAINING SET:", train_counts)
total_val = _print_split_stats("\nVALIDATION SET:", val_counts)
total_test = _print_split_stats("\nTEST SET:", test_counts)

# Check for imbalance (largest / smallest class in the training split)
train_counts_list = list(train_counts.values())
if train_counts_list:
    max_count = max(train_counts_list)
    min_count = min(train_counts_list)
    # An empty class makes the ratio infinite, which is reported as significant imbalance.
    imbalance_ratio = max_count / min_count if min_count > 0 else float('inf')

    print("\n=== Imbalance Analysis ===")
    print(f"Training set imbalance ratio: {imbalance_ratio:.2f}:1")

    if imbalance_ratio > 3:
        print("WARNING: Significant class imbalance detected!")
        print("Consider using class weights or data augmentation.")
    elif imbalance_ratio > 1.5:
        print("Moderate class imbalance detected.")
        print("Monitor per-class performance in confusion matrix.")
    else:
        print("Dataset is reasonably balanced.")

In [None]:
# @title Compare Baseline vs Augmented Results

import json
import matplotlib.pyplot as plt
import os

# Try to load both history files
baseline_path = './history_baseline_no_augmented_images.json'
augmented_path = './history_with_augmented_images.json'


def _load_history(path, label):
    """Return the parsed history JSON at `path`, or None if the file does not exist.

    Prints a ✓/✗ status line using `label` (e.g. 'Baseline').
    """
    try:
        with open(path, 'r') as f:
            loaded = json.load(f)
    except FileNotFoundError:
        print(f"✗ {label} results not found at {path}")
        return None
    print(f"✓ Loaded {label.lower()} results")
    return loaded


def _plot_comparison(ax, baseline_series, augmented_series, title, ylabel,
                     labels=('Baseline', 'With Augmented Images')):
    """Plot one per-epoch metric for both runs on `ax` (the runs may differ in length)."""
    ax.plot(range(1, len(baseline_series) + 1), baseline_series, 'b-o', label=labels[0], linewidth=2)
    ax.plot(range(1, len(augmented_series) + 1), augmented_series, 'r-s', label=labels[1], linewidth=2)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _final_metric_values(run, gap):
    """Final-epoch [train acc, val acc, test acc, train-val gap] for the summary bars.

    A missing test accuracy becomes NaN (no bar drawn) instead of 0, which
    would read as a genuine 0% result.
    """
    return [
        run['accuracy'][-1],
        run['val_accuracy'][-1],
        run.get('test_accuracy', np.nan),
        gap[-1],
    ]


def _print_run_summary(header, run, gap):
    """Print final/best metrics of one run; `gap` is its per-epoch train-val accuracy gap."""
    print(header)
    print(f"  Final Train Acc:     {run['accuracy'][-1]:.4f}")
    print(f"  Final Val Acc:       {run['val_accuracy'][-1]:.4f}")
    print(f"  Best Val Acc:        {run['best_val_acc']:.4f}")
    if 'test_accuracy' in run:
        print(f"  Test Acc:            {run['test_accuracy']:.4f}")
    print(f"  Train-Val Gap:       {gap[-1]:.4f}")
    print(f"  Final Train Loss:    {run['loss'][-1]:.4f}")
    print(f"  Final Val Loss:      {run['val_loss'][-1]:.4f}")


history_baseline = _load_history(baseline_path, 'Baseline')
history_augmented = _load_history(augmented_path, 'Augmented')
baseline_exists = history_baseline is not None
augmented_exists = history_augmented is not None

if not (baseline_exists and augmented_exists):
    print("\n⚠️  You need to run training twice:")
    print("   1. First run with IMG_AUGMENT = False")
    print("   2. Second run with IMG_AUGMENT = True")
    print("   Then run this cell again to compare.")
else:
    # ==================== PLOTTING ====================
    fig = plt.figure(figsize=(20, 12))

    # Positive gap = training accuracy above validation accuracy (overfitting signal).
    gap_baseline = [t - v for t, v in zip(history_baseline['accuracy'], history_baseline['val_accuracy'])]
    gap_augmented = [t - v for t, v in zip(history_augmented['accuracy'], history_augmented['val_accuracy'])]

    _plot_comparison(plt.subplot(2, 3, 1), history_baseline['accuracy'], history_augmented['accuracy'],
                     'Training Accuracy Comparison', 'Accuracy')
    _plot_comparison(plt.subplot(2, 3, 2), history_baseline['val_accuracy'], history_augmented['val_accuracy'],
                     'Validation Accuracy Comparison', 'Accuracy')
    ax3 = plt.subplot(2, 3, 3)
    _plot_comparison(ax3, gap_baseline, gap_augmented, 'Overfitting Gap (Train - Val Acc)', 'Gap',
                     labels=('Baseline Gap', 'Augmented Gap'))
    ax3.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    _plot_comparison(plt.subplot(2, 3, 4), history_baseline['loss'], history_augmented['loss'],
                     'Training Loss Comparison', 'Loss')
    _plot_comparison(plt.subplot(2, 3, 5), history_baseline['val_loss'], history_augmented['val_loss'],
                     'Validation Loss Comparison', 'Loss')

    # 6. Summary bar chart of final-epoch values
    ax6 = plt.subplot(2, 3, 6)
    metrics = ['Train Acc', 'Val Acc', 'Test Acc', 'Train-Val Gap']
    baseline_vals = _final_metric_values(history_baseline, gap_baseline)
    augmented_vals = _final_metric_values(history_augmented, gap_augmented)

    x = np.arange(len(metrics))
    width = 0.35
    ax6.bar(x - width/2, baseline_vals, width, label='Baseline', color='steelblue')
    ax6.bar(x + width/2, augmented_vals, width, label='With Augmented Images', color='coral')
    ax6.set_ylabel('Value')
    ax6.set_title('Final Metrics Comparison', fontsize=14, fontweight='bold')
    ax6.set_xticks(x)
    ax6.set_xticklabels(metrics, rotation=45, ha='right')
    ax6.legend()
    ax6.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

    # ==================== NUMERICAL COMPARISON ====================
    print("\n" + "="*70)
    print("                    DETAILED COMPARISON REPORT")
    print("="*70)

    _print_run_summary("\nBASELINE (No Augmented Images):", history_baseline, gap_baseline)
    _print_run_summary("\nWITH AUGMENTED IMAGES:", history_augmented, gap_augmented)

    # Accuracies are fractions, so their differences are absolute changes:
    # report them in percentage points (pp), not as relative percentages.
    print("\nIMPROVEMENTS (Augmented - Baseline):")
    val_improvement = history_augmented['best_val_acc'] - history_baseline['best_val_acc']
    gap_reduction = gap_baseline[-1] - gap_augmented[-1]

    print(f"  Best Val Acc Change:     {val_improvement:+.4f} ({val_improvement*100:+.2f} pp)")
    print(f"  Overfitting Reduction:   {gap_reduction:+.4f} ({gap_reduction*100:+.2f} pp)")

    if 'test_accuracy' in history_baseline and 'test_accuracy' in history_augmented:
        test_improvement = history_augmented['test_accuracy'] - history_baseline['test_accuracy']
        print(f"  Test Acc Change:         {test_improvement:+.4f} ({test_improvement*100:+.2f} pp)")

    print("\nCONCLUSIONS:")

    if val_improvement > 0.01:  # more than 1 percentage point
        print(f"  ✓ Using augmented images IMPROVED validation accuracy by {val_improvement*100:.2f} pp")
    elif val_improvement < -0.01:
        print(f"  ✗ Using augmented images DECREASED validation accuracy by {abs(val_improvement)*100:.2f} pp")
    else:
        print("  ≈ Using augmented images had MINIMAL IMPACT on validation accuracy")

    if gap_reduction > 0.02:  # gap shrank by more than 2 percentage points
        print(f"  ✓ Augmented images REDUCED overfitting (gap decreased by {gap_reduction*100:.2f} pp)")
    elif gap_reduction < -0.02:
        print(f"  ✗ Augmented images INCREASED overfitting (gap increased by {abs(gap_reduction)*100:.2f} pp)")
    else:
        print("  ≈ Augmented images had MINIMAL IMPACT on overfitting")

    # Training set composition, as selected by the "Filtered Set Creation" cell.
    print("\nDATASET SIZES:")
    print(f"  Training cap per run:    {MAX_IMAGES_PER_CLASS_TRAIN} images per class "
          f"(at most {len(CLASSES_TO_USE) * MAX_IMAGES_PER_CLASS_TRAIN} in total)")
    print("  Baseline run:            original (non-'aug_') training images only")
    print("  Augmented run:           'aug_'-prefixed training images only")

    print("\n" + "="*70)

In [None]:
# @title Dataset Size Comparison

import matplotlib.pyplot as plt

# Count images with and without augmentation
def count_images_by_type(base_dir, classes):
    """Tally original and ``aug_``-prefixed image files across class folders.

    Args:
        base_dir: Root directory holding one subfolder per class.
        classes: Iterable of class folder names to inspect. Folders that do
            not exist under ``base_dir`` are skipped without error.

    Returns:
        A ``(original_count, augmented_count)`` tuple. A file counts as
        augmented when its basename starts with ``'aug_'`` and as original
        otherwise. Only names matching the glob ``*.*`` (i.e. containing a
        dot) are considered.
    """
    n_original = 0
    n_augmented = 0

    for cls in classes:
        folder = os.path.join(base_dir, cls)
        if not os.path.exists(folder):
            continue  # missing class folder: contributes nothing
        for path in glob.glob(os.path.join(folder, '*.*')):
            if os.path.basename(path).startswith('aug_'):
                n_augmented += 1
            else:
                n_original += 1

    return n_original, n_augmented

# Count original vs. aug_-prefixed files in every split so the on-disk
# composition of the dataset can be inspected.
train_orig, train_aug = count_images_by_type(BASE_TRAIN_DIR, CLASSES_TO_USE)
val_orig, val_aug = count_images_by_type(BASE_VAL_DIR, CLASSES_TO_USE)
test_orig, test_aug = count_images_by_type(BASE_TEST_DIR, CLASSES_TO_USE)

# Visualize: one bar chart per split (original / augmented / total).
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

splits = ['Train', 'Validation', 'Test']
original = [train_orig, val_orig, test_orig]
augmented = [train_aug, val_aug, test_aug]

for i, (split, orig, aug) in enumerate(zip(splits, original, augmented)):
    ax[i].bar(['Original', 'Augmented', 'Total'],
              [orig, aug, orig + aug],
              color=['steelblue', 'coral', 'green'])
    ax[i].set_title(f'{split} Set')
    ax[i].set_ylabel('Number of Images')
    ax[i].grid(axis='y', alpha=0.3)

# Overall title so the figure stands alone when the notebook is skimmed.
fig.suptitle('Original vs. Augmented Images per Split')
plt.tight_layout()
plt.show()


def _print_split_composition(label, orig, aug, show_ratio=False):
    """Print the original / augmented / total image counts for one split.

    Args:
        label: Heading text, printed as ``"<label> Set:"``.
        orig: Number of original (non ``aug_``) images in the split.
        aug: Number of ``aug_``-prefixed images in the split.
        show_ratio: Also print the ``aug / orig`` ratio. The ratio is guarded
            so an empty or mis-pathed split does not raise ZeroDivisionError.
    """
    print(f"\n{label} Set:")
    print(f"  Original images:   {orig}")
    print(f"  Augmented images:  {aug}")
    print(f"  Total:             {orig + aug}")
    if show_ratio:
        if orig:
            print(f"  Augmentation ratio: {aug/orig:.2f}x")
        else:
            print("  Augmentation ratio: n/a (no original images found)")


print("Dataset Composition:")
_print_split_composition("Training", train_orig, train_aug, show_ratio=True)
_print_split_composition("Validation", val_orig, val_aug)
_print_split_composition("Test", test_orig, test_aug)

# Note: Limited by MAX_IMAGES_PER_CLASS_TRAIN
# NOTE(review): the IMG_AUGMENT=True line below says the augmented run uses
# "only augmented images", while the comparison summary earlier in this
# notebook describes it as "original + aug_ prefixed". Confirm against the
# dataset-loading code and make the two messages agree.
print(f"\nNote: Training is limited to {MAX_IMAGES_PER_CLASS_TRAIN} images per class")
print(f"  With IMG_AUGMENT=False: Uses only original images")
print(f"  With IMG_AUGMENT=True:  Uses only augmented images")