# Notebook for classifying the AgroPest-12 dataset with a CNN in PyTorch

Project-specific functions can be found in the `source` folder.

The Python files whose names start with `cnn`, together with `optuna.py`, are the ones relevant for this task.

In [None]:
# --- standard and third party library
import os
from pathlib import Path
import random
from datetime import datetime
import sys
import math
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from torchsummary import summary
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassConfusionMatrix
)

# --- project-specific imports
from source.cnn_model import train_model, FlexibleCNN, get_batch_size
from source.cnn_plotting import (
    plot_random_image_per_class,
    plot_random_predictions,
    plot_filter_weights,
    plot_image,
    plot_feature_maps,
    plot_loss_accuracy,
    find_conv_layers
)

from source.cnn_quality_control import label_histogram, evaluate_classification
from source.cnn_retrieve_images import set_seed, classification_collate_fn, YOLODataset


## Classification with PyTorch

### Settings and preprocessing

In [None]:
# Fix the random seed once so that repeated runs behave the same way
seed_number = 42
set_seed(seed_number, deterministic=True)

# Separate torch generator, seeded identically, that is handed to the DataLoaders
gen = torch.Generator()
gen.manual_seed(seed_number)

In [None]:
"""
Reference list of the class names that can be selected in the next section.
Available classes in AgroPest-12:
  [
    "Ants",
    "Bees",
    "Beetles",
    "Caterpillars",
    "Earthworms",
    "Earwigs",
    "Grasshoppers",
    "Moths",
    "Slugs",
    "Snails",
    "Wasps",
    "Weevils",
  ]
"""

In [None]:
# ---------------------------------------------------------------------------
# Notebook settings
# ---------------------------------------------------------------------------

# Directory layout: the working directory of the notebook and its parent
current_dir = Path.cwd()
parent_dir = current_dir.parent
agropest_dir = parent_dir / "datasets" / "agropest12"

# Dataset description file
yaml_path = agropest_dir / "data.yaml"

# --- Image retrieval -------------------------------------------------------
# Every value below may be None (use everything) or restricted to a small
# subset, which is handy when testing the code.
subset_classes = None   # None for all classes, or a list of class names, e.g. ['Ants', 'Snails']
n_images_train = None   # None for all images, or a small integer for a quick test run
n_images_valid = None   # None for all images, or e.g. int(n_images_train * 0.2)
n_images_test = None    # None for all images, or e.g. int(n_images_train * 0.2)

batch_size = get_batch_size(n_images_train)

# --- Model type and normalization ------------------------------------------
# True  -> use the pretrained model and weights
# False -> custom model or Optuna hyperparameter search
PRETRAINED_MODEL = False

# Normalization modes:
#   'auto'     - load the normalization file if it exists, else compute and save it
#   'compute'  - always compute the normalization and save/overwrite the file
#   'load'     - load the values from file
#   'ImageNet' - use the ImageNet values
NORMALIZE = 'ImageNet' if PRETRAINED_MODEL else 'load'
NORMALIZE_CACHE_PATH = agropest_dir / "cache" / "agropest12_norm.json"

# Pretrained must use (224, 224) to match ImageNet.
# (128, 128) was used for the Optuna hyperparameter search.
IMG_SIZE = (224, 224)

# 'none' or 'largest'. 'largest' extracts the part of the image where the
# largest object is located; some images are mostly background with only a
# small portion showing the object.
CROP_MODE = 'largest'

# --- Diagnostics -----------------------------------------------------------
DEBUG = False                    # print debug info
QUALITY_CONTROL = True           # process, print and plot quality-control information
COMPUTE_AVERAGE_IMG_SIZE = True  # report the average size of the images cropped to the object

# --- Output switches -------------------------------------------------------
SAVE_MODELS = True
SHOW_PLOT = True
SAVE_FIGURE = True

# --- Optuna ----------------------------------------------------------------
OPTUNA = False       # run the Optuna hyperparameter search or not
OPTUNA_TRAILS = 50   # number of trials; choose a small number when testing the code
# Epochs per trial. For test runs, trials, epochs and n_images_train must still
# be large enough to give variance in the Optuna results, e.g. (10, 10 and 20).
OPTUNA_EPOCHS = 30

# --- Regular training ------------------------------------------------------
N_EPOCHS = 50        # epochs when not using Optuna; choose a small number when testing the code

# --- Output directories (created on demand) --------------------------------
figures_path = parent_dir / "figures"   # saved figures
models_path = parent_dir / "models"     # saved models and summaries
for output_dir in (figures_path, models_path):
    output_dir.mkdir(parents=True, exist_ok=True)

# Timestamp used in the names of saved files and figures
timestamp = datetime.now().strftime("%Y-%m-%d-%H-%M")


In [None]:
# One-off computation of the normalization values for the AgroPest dataset.
# This is computationally heavy: switch it on once, then leave it at False so
# that the values are read from disk on later runs.

COMPUTE_NORMALIZATION = False

if COMPUTE_NORMALIZATION:
    # Whole training split, all classes, no cropping
    normalization_settings = dict(
        yaml_path=yaml_path,
        subset_classes=None,
        max_images=None,
        split='train',
        img_size=IMG_SIZE,
        normalize='compute',
        norm_file=NORMALIZE_CACHE_PATH,
        crop_strategy='none',
        debug=True,
        auto_norm_sample=10000,
    )
    dataset_train_norm = YOLODataset(**normalization_settings)

In [None]:
# Load the train / valid / test datasets according to the notebook settings
# and wrap each one in a DataLoader.

workers = 2  # number of DataLoader worker subprocesses

# Arguments that are identical for all three splits
shared_dataset_args = dict(
    yaml_path=yaml_path,
    subset_classes=subset_classes,
    img_size=IMG_SIZE,
    normalize=NORMALIZE,
    norm_file=NORMALIZE_CACHE_PATH,
    crop_strategy=CROP_MODE,
)
shared_loader_args = dict(
    batch_size=batch_size,
    collate_fn=classification_collate_fn,
    num_workers=workers,
)

# Evaluation loaders get their own generator and are not shuffled.
# Shuffling validation/test data serves no purpose, and sharing the training
# generator `gen` with them would make the training shuffle order depend on
# how many times the evaluation loaders were iterated before training.
eval_gen = torch.Generator().manual_seed(seed_number)

# Train (shuffled, driven by the dedicated training generator)
dataset_train = YOLODataset(
    max_images=n_images_train,
    split='train',
    debug=True,
    **shared_dataset_args,
)
dataloader_train = DataLoader(dataset_train, shuffle=True, generator=gen, **shared_loader_args)

# Valid
dataset_valid = YOLODataset(
    max_images=n_images_valid,
    split='valid',
    debug=DEBUG,
    **shared_dataset_args,
)
dataloader_valid = DataLoader(dataset_valid, shuffle=False, generator=eval_gen, **shared_loader_args)

# Test
dataset_test = YOLODataset(
    max_images=n_images_test,
    split='test',
    debug=DEBUG,
    **shared_dataset_args,
)
dataloader_test = DataLoader(dataset_test, shuffle=False, generator=eval_gen, **shared_loader_args)


#### Quality control

In [None]:
# Report the average size of the cropped images in the training set.
# The numbers are used when deciding on the resize/transform settings.

if COMPUTE_AVERAGE_IMG_SIZE:
    from statistics import mean

    # Each sample's "crop_size" entry holds (W, H) before resizing
    crop_sizes = [dataset_train[i]["crop_size"] for i in range(len(dataset_train))]

    if crop_sizes:
        avg_crop_w = mean([w for w, h in crop_sizes])
        avg_crop_h = mean([h for w, h in crop_sizes])
        print(f"Average cropped size: (Width: {avg_crop_w:.1f}, Height: {avg_crop_h:.1f})")
    else:
        # statistics.mean raises on empty input, so report the situation instead
        print("Average cropped size: not available (training dataset is empty)")


In [None]:
# Quality control: print information about the images in a batch to confirm
# that the import of images worked.
if QUALITY_CONTROL:
    from source.cnn_quality_control import validate_dataset

    validate_dataset(
        dataset_test,
        dataloader_test,
    )

In [None]:
# Quality control: plot one image from each class in the subset
if QUALITY_CONTROL:
    plot_random_image_per_class(
        dataset_train,
        verbose=True,
        save_figure=SAVE_FIGURE,
        show_plot=SHOW_PLOT,
        figures_path=figures_path,
        timestamp=timestamp,
    )


In [None]:
# Quality control: print the number of images per class for the training,
# validation and test datasets.
if QUALITY_CONTROL:
    splits_to_report = (
        ('Training dataset', dataset_train),
        ('Validation dataset', dataset_valid),
        ('Test dataset', dataset_test),
    )
    for split_title, split_dataset in splits_to_report:
        print()
        print(split_title)
        _ = label_histogram(split_dataset)


## Generate and train PyTorch CNN model

In [None]:
# Number of output classes: all 12 AgroPest-12 classes when no subset is
# selected, otherwise the size of the chosen subset.
# `is None` is the identity check for the None singleton; `== None` was used before.
n_classes = 12 if subset_classes is None else len(subset_classes)

In [None]:
# Set up ResNet-18 with pretrained ImageNet weights

if PRETRAINED_MODEL:
    import torchvision
    from torchvision.models import ResNet18_Weights

    weights = ResNet18_Weights.IMAGENET1K_V1
    model_pretrained = torchvision.models.resnet18(weights=weights)

    # AgroPest-12 is a small dataset, so every pretrained parameter is frozen
    for frozen_param in model_pretrained.parameters():
        frozen_param.requires_grad = False

    # Swap in a fresh fully connected head sized for the AgroPest-12 classes;
    # it is created after the freeze, so it is the only part left trainable.
    num_classes = n_classes
    head_in_features = model_pretrained.fc.in_features
    model_pretrained.fc = nn.Linear(head_in_features, num_classes)


In [None]:
# Retrieve the number of image channels from one test batch and build the
# model input size from it.
batch = next(iter(dataloader_test))
images = batch['images']
channels = images.shape[1]  # assumes batches are laid out (N, C, H, W) — TODO confirm
input_dim = (channels, *IMG_SIZE)

### Optuna option for tuning hyperparameters (not used together with the pretrained model and weights in this project)

In [None]:
# Optuna hyperparameter search (skipped when the pretrained model is used)
if OPTUNA and not PRETRAINED_MODEL:
    from source.optuna import create_objective
    import optuna

    objective = create_objective(
        n_classes,
        input_dim,
        dataset_train,
        dataset_valid,
        num_epochs=OPTUNA_EPOCHS,
        device='cpu',
    )

    # The study is stored in a timestamped SQLite database in the models folder
    db_path = models_path / f"{timestamp}_optuna_study.db"
    optuna_path = f"sqlite:///{db_path}"

    median_pruner = optuna.pruners.MedianPruner(
        n_startup_trials=5,
        n_warmup_steps=5,
        interval_steps=1,
        n_min_trials=5,
    )
    study = optuna.create_study(
        direction="maximize",
        pruner=median_pruner,
        storage=optuna_path,
        load_if_exists=False,
    )

    study.optimize(
        objective,
        n_trials=OPTUNA_TRAILS,
        timeout=None,
        gc_after_trial=True,
    )

    print("Best value (valid_acc):", study.best_value)
    print("Best params:", study.best_params)


In [None]:
# Show the Optuna optimization history and the parameter importances
if OPTUNA and not PRETRAINED_MODEL:
    from optuna.visualization import plot_optimization_history, plot_param_importances

    for build_figure in (plot_optimization_history, plot_param_importances):
        build_figure(study).show()


In [None]:
# Choose the architecture parameters for the custom CNN: either the best
# parameters found by Optuna or a predefined fallback architecture.

# Settings that are the same in both cases
activation_choice = nn.ReLU  # fixed activation
use_batchnorm = True         # always enabled

if OPTUNA and not PRETRAINED_MODEL:
    # Extract the best parameters from the Optuna study
    best_params = study.best_params

    # Convolutional layers: one (filters, kernel_size) pair per layer, kernel size fixed to 3
    n_conv_layers = best_params['n_conv_layers']
    conv_layers = [
        (best_params[f'conv{i+1}_filters'], 3)
        for i in range(n_conv_layers)
    ]

    # Fully connected layers: optional hidden layer followed by the fixed output layer
    fc_layers = [n_classes]
    if best_params['use_hidden_layer']:
        fc_layers.insert(0, best_params['fc_hidden_neurons'])

    # Dropout values
    dropout_fc = best_params['dropout_fc']
    dropout_conv = best_params['dropout_conv']

else:
    # Fallback to the predefined custom model architecture
    conv_layers = [(32, 3), (64, 3), (128, 3), (256, 3)]
    fc_layers = [n_classes]
    dropout_fc = 0.3
    dropout_conv = 0.0



### Build model and train

In [None]:
# Assemble the custom CNN from the architecture parameters chosen earlier

custom_model_settings = dict(
    input_size=input_dim,
    num_classes=n_classes,
    conv_layers=conv_layers,
    fc_layers=fc_layers,
    activation=activation_choice,
    dropout_fc=dropout_fc,
    dropout_conv=dropout_conv,
    use_batchnorm=use_batchnorm,
    pool_type="max",
    global_pool="avg",
)
model_custom = FlexibleCNN(**custom_model_settings)
model_custom = model_custom.to('cpu')


In [None]:
# Select the model that is used from here on: the pretrained network or the
# custom/Optuna network.

model = model_pretrained if PRETRAINED_MODEL else model_custom
print('Pretrained model' if PRETRAINED_MODEL else 'Custom model')

In [None]:
# Save the torchsummary model summary (layers and parameter counts) to a text file
path_summary = models_path / f'{timestamp}_model_summary_pretrained{PRETRAINED_MODEL}.txt'

if SAVE_MODELS:
    from contextlib import redirect_stdout

    # redirect_stdout restores whatever sys.stdout was before, also when
    # summary() raises. Assigning sys.__stdout__ afterwards would not: in a
    # notebook sys.stdout is the kernel's own stream, so later prints would
    # no longer appear in the cell output.
    with open(path_summary, "w", encoding="utf-8") as f, redirect_stdout(f):
        # torchsummary defaults to device="cuda"; the model lives on the CPU here
        summary(model, input_dim, batch_size=batch_size, device='cpu')


# Alternative model summary
#from torchinfo import summary
#summary(model)

In [None]:
# Training

# Learning rate: taken from the Optuna result when a search was run, else a
# fixed default. `best_params` only exists when the search ran, i.e. when
# OPTUNA is on and the pretrained model is off, so the same condition is used
# here to avoid a NameError for OPTUNA=True with PRETRAINED_MODEL=True.
use_optuna_params = OPTUNA and not PRETRAINED_MODEL
learning_rate = best_params['learning_rate'] if use_optuna_params else 0.001

# The pretrained model only updates the weights of its fully connected layer;
# the custom model trains all of its parameters.
trainable_params = model.fc.parameters() if PRETRAINED_MODEL else model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=learning_rate, weight_decay=1e-4)

loss_fn = nn.CrossEntropyLoss()
epochs = N_EPOCHS


history = train_model(
    model=model,
    num_epochs=epochs,
    train_dl=dataloader_train,
    valid_dl=dataloader_valid,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device='cpu',
    verbose=True,
    patience_val=15
)


## Results from classification

In [None]:
# Plot loss and accuracy per epoch from the training history
plot_loss_accuracy(
    history,
    figures_path,
    timestamp,
    save_figure=SAVE_FIGURE,
    show_plot=SHOW_PLOT,
)

In [None]:
# Evaluate the classification on the test set: print metrics and plot the
# confusion matrix.

evaluation_settings = dict(
    model=model,
    dataloader=dataloader_test,
    num_classes=n_classes,
    figures_path=figures_path,
    timestamp=timestamp,
    device='cpu',
    class_names=subset_classes,
    save_figure=SAVE_FIGURE,
    show_plot=SHOW_PLOT,
)
results = evaluate_classification(**evaluation_settings)

display(results)

In [None]:
# Write the trained model weights to file. The file name records the image
# size and the test accuracy and macro F1 score.
if SAVE_MODELS:
    # Read the metrics into variables first: reusing the enclosing quote
    # character inside an f-string replacement field is a SyntaxError before
    # Python 3.12.
    test_accuracy = results["accuracy"]
    test_macro_f1 = results["macro_f1"]
    path_model = models_path / (
        f"{timestamp}_model_summary_test-img_size{IMG_SIZE}"
        f"_accuracy{test_accuracy:.4f}_avgf1_{test_macro_f1:.4f}_weights.pth"
    )
    torch.save(model.state_dict(), path_model)

In [None]:
# Plot random predictions, one sample per class, in a grid at most two columns wide

test_class_names = dataset_test.subset_class_names
n_classes = len(test_class_names)  # e.g., 2 for ['Ants', 'Snails']
n_cols = min(2, n_classes)
n_rows = math.ceil(n_classes / n_cols)

prediction_plot_settings = dict(
    num_samples=n_classes,
    path=figures_path,
    time=timestamp,
    number_images=n_images_train,
    rows=n_rows,
    cols=n_cols,
    class_names=test_class_names,
    device="cpu",
    save_figures=SAVE_FIGURE,
    show_plot=SHOW_PLOT,
)
plot_random_predictions(model, dataset_test, **prediction_plot_settings)


In [None]:
# Plot filter weights in a 2x2 grid
plot_filter_weights(
    model,
    figures_path,
    timestamp,
    n_images_train,
    rows=2,
    cols=2,
    channel=0,  # channel as in RGB
    save_figures=SAVE_FIGURE,
    show_plot=SHOW_PLOT,
)


In [None]:
# Plot feature maps for one sample image

# Draw a random sample from the test dataset
random_index = np.random.randint(0, len(dataset_test))
sample = dataset_test[random_index]
sample_img = sample["image"]
true_name = sample.get("class_name", "")

# Show the sample itself first
plot_image(
    sample_img,
    figures_path,
    timestamp,
    n_images_train,
    title=f'Image sample from test dataset (True: {true_name})',
    save_figures=SAVE_FIGURE,
    show_plot=SHOW_PLOT,
    dataset=dataset_test,
)

# Retrieve the convolutional layers; only the first and the last are shown
conv_layers = find_conv_layers(model)
indices_to_show = [0, len(conv_layers) - 1]

# Arguments that are the same for every plotted layer
feature_map_settings = dict(
    model=model,
    image=sample_img,
    path=figures_path,
    time=timestamp,
    number_images=n_images_train,
    num_maps=4,
    rows=2,
    cols=2,
    cmap="gray",
    save_figures=SAVE_FIGURE,
    show_plot=SHOW_PLOT,
    device="cpu",
    input_is_tensor=True,
    preprocess=None,
)

for idx in range(len(conv_layers)):
    if idx not in indices_to_show:
        continue
    plot_feature_maps(layers_to_show=[idx], **feature_map_settings)

