# Kenyan Food Classification Training Pipeline (Trainer Version)
This notebook demonstrates a modular neural network training pipeline using the scripts in the `trainer` folder.

In [1]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim


# Trainer imports
from trainer.trainer import Trainer
from trainer.metrics import AccuracyEstimator
from trainer.configuration import SystemConfig, DatasetConfig, DataloaderConfig, OptimizerConfig, TrainerConfig
from trainer.utils import setup_system, patch_configs

from trainer.data_loader import get_data_loaders
from trainer.model import get_model

from trainer.tensorboard_visualizer import TensorBoardVisualizer



  albumentations.PadIfNeeded(min_height=300, min_width=300, border_mode=0, value=0),
  albumentations.PadIfNeeded(min_height=300, min_width=300, border_mode=0, value=0),
  result = _ensure_odd_values(result, info.field_name)


In [2]:
# Visualizer (TensorBoard)
# Created before the model and trainer cells because the same instance is
# handed to both the model-graph export and the Trainer further down.
visualizer = TensorBoardVisualizer()

## Configuration
Set up all configuration objects for the pipeline.

In [3]:


# System-level settings; system_config.seed is reused by the data loaders.
system_config = SystemConfig()
setup_system(system_config)

# patch_configs hands back the dataloader and trainer configs with the epoch
# count and batch size overridden here.
# NOTE(review): trainer_config.device is read in later cells, so device
# selection presumably happens inside patch_configs — confirm in trainer.utils.
dataloader_config, trainer_config = patch_configs(
    epoch_num_to_set=50,
    batch_size_to_set=16,
)

# Optimizer and scheduler hyperparameters (library defaults from the config class).
optimizer_config = OptimizerConfig()

Using MPS backend for PyTorch


## Create Loaders and Model
Create the data loaders and the model needed to start training.

In [4]:
# Data loaders
# Build the training and validation loaders from ./data. The call also returns
# the number of classes, which is passed to get_model in the model cell.
# Batch size, worker count and test_size come from the patched
# dataloader_config; the seed comes from system_config.
# NOTE(review): test_size is presumably the held-out validation fraction —
# confirm in trainer.data_loader.

train_loader, val_loader, num_classes = get_data_loaders(
    data_root="./data",
    batch_size=dataloader_config.batch_size,
    num_workers=dataloader_config.num_workers,
    seed=system_config.seed,
    data_augmentation=True,
    test_size=dataloader_config.test_size,
)

In [8]:
# Model: build the network for `num_classes` outputs and place it on the
# device chosen in the trainer config.
model = get_model(
    num_classes=num_classes,
    pretrained=True,
    freeze_backbone=True,
    trainable_layers=1,
).to(trainer_config.device)

# Export the model graph to the visualizer, traced with one random
# 3-channel 224x224 input that lives on the same device as the model.
dummy_input = torch.randn(1, 3, 224, 224).to(trainer_config.device)
visualizer.add_model_graph(model, input_tensor=dummy_input)


In [9]:

# Optimizer: AdamW over the model parameters, with learning rate and weight
# decay taken from optimizer_config.
optimizer = optim.AdamW(
    model.parameters(),
    lr=optimizer_config.learning_rate,
    weight_decay=optimizer_config.weight_decay,
)

# Scheduler: StepLR scales the learning rate by `gamma` once every
# `step_size` scheduler steps.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=optimizer_config.scheduler_step_size,
    gamma=optimizer_config.scheduler_gamma,
)

In [10]:
# Loss and metric
# Cross-entropy loss for training; the accuracy estimator is configured for
# top-1 only, which matches the "top1" key that the Trainer's get_key_metric
# callback reads in the training cell.
loss_fn = nn.CrossEntropyLoss()
metric_fn = AccuracyEstimator(topk=(1,))

## Training
Run the Trainer pipeline.

In [None]:
# Wire the model, loaders, loss, metric, optimizer and scheduler into the
# Trainer. Device, checkpoint frequency, checkpoint directory and progress-bar
# setting all come from trainer_config.
trainer = Trainer(
    model=model,
    loader_train=train_loader,
    loader_test=val_loader,
    loss_fn=loss_fn,
    metric_fn=metric_fn,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    device=trainer_config.device,
    model_saving_frequency=trainer_config.model_saving_frequency,
    save_dir=trainer_config.model_dir,
    model_name_prefix="kenyanfood_model",
    # Each batch is indexed by key: model input under "image", labels under "target".
    data_getter=lambda sample: sample["image"],
    # torch.as_tensor returns an existing tensor as-is and still converts
    # lists/arrays. torch.tensor() on a tensor makes a copy and emits the
    # copy-construct UserWarning that showed up in this cell's output.
    target_getter=lambda sample: torch.as_tensor(sample["target"]),
    stage_progress=trainer_config.progress_bar,
    visualizer=visualizer,
    # The value stored under "top1" is the metric the Trainer treats as the key one.
    get_key_metric=lambda metric: metric["top1"]
)
# Train for the configured number of epochs; the returned metrics are plotted
# in the final cell.
metrics = trainer.fit(trainer_config.epoch_num)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  target_getter=lambda sample: torch.tensor(sample["target"]),


  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

  0%|          | 0/327 [00:00<?, ?it/s]

  0%|          | 0/82 [00:00<?, ?it/s]

In [None]:
# Close the TensorBoard writer now that training has finished.
# NOTE(review): presumably this flushes any pending event data to disk —
# confirm in trainer.tensorboard_visualizer.
visualizer.close_tensorboard()

In [None]:
# Plot the loss and validation-accuracy curves returned by trainer.fit
import matplotlib.pyplot as plt

# Per-epoch series from the metrics dict
epochs = metrics['epoch']
train_loss = metrics['train_loss']
val_loss = metrics['test_loss']
# Entries are either dicts carrying a 'top1' value or already plain values
val_acc = [
    (entry['top1'] if isinstance(entry, dict) and 'top1' in entry else entry)
    for entry in metrics['test_metric']
]

# One figure, two side-by-side panels: losses on the left, accuracy on the right
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(epochs, train_loss, label='Train Loss')
loss_ax.plot(epochs, val_loss, label='Validation Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss over Epochs')
loss_ax.legend()

acc_ax.plot(epochs, val_acc, label='Validation Accuracy (Top-1)')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Validation Accuracy over Epochs')
acc_ax.legend()

fig.tight_layout()
plt.show()