# Foundational Model Training: MobileNetV4 & MobileNetV5

This notebook fine-tunes ImageNet-pretrained MobileNetV4 and MobileNetV5 models (loaded via `timm`) on the released dataset, so that they can serve as foundational backbones for downstream agricultural vision tasks.

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import timm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from pathlib import Path
import time
import copy

# Select the compute device: use CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Configuration: run-wide hyper-parameters and paths read by the cells below.
CONFIG = dict(
    data_dir='../data/release',  # Root of the released dataset
    models_to_train=[
        # timm model names of the form '<architecture>.<pretrained tag>'.
        'mobilenetv4_conv_medium.e500_r256_in1k',  # MobileNetV4 conv-medium variant
        # NOTE(review): confirm this pretrained tag exists in the installed timm
        # (e.g. via timm.list_pretrained('mobilenetv5*')) before a long run.
        'mobilenetv5_base.e500_r256_in1k',  # MobileNetV5 base variant
    ],
    batch_size=32,
    num_epochs=20,
    learning_rate=1e-4,
    image_size=256,  # Side length (pixels) of the square crops fed to the models
    seed=42,
    num_workers=4,  # DataLoader worker processes
    checkpoint_dir='../models/foundational',
)

# Create the checkpoint/history output directory up front; no error if it already exists.
os.makedirs(CONFIG['checkpoint_dir'], exist_ok=True)

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), and configures cuDNN to pick deterministic
    kernels instead of autotuning them.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random

    # Only affects Python interpreters started after this point (e.g. spawned
    # worker processes); the running process's str-hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, not only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing, which can vary between runs.
    torch.backends.cudnn.benchmark = False

seed_everything(CONFIG['seed'])

In [None]:
# Data Transforms.
# Training: random crop/flip/rotation/colour jitter for augmentation.
# Validation: deterministic resize (~1.14x the crop size) followed by a centre crop.
# Both pipelines finish with the same ImageNet mean/std normalisation.
_crop = CONFIG['image_size']
_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

_train_steps = [
    transforms.RandomResizedCrop(_crop),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _normalize,
]
_val_steps = [
    transforms.Resize(int(_crop * 1.14)),
    transforms.CenterCrop(_crop),
    transforms.ToTensor(),
    _normalize,
]

data_transforms = dict(
    train=transforms.Compose(_train_steps),
    val=transforms.Compose(_val_steps),
)

# Data Loading
def get_dataloaders(data_dir):
    """Build train/val DataLoaders from an ImageFolder-style directory tree.

    If ``data_dir`` contains both ``train/`` and ``val/`` sub-directories, they
    are used as-is. Otherwise the images are split 80/20 at random (reproducibly,
    seeded with ``CONFIG['seed']``): from ``data_dir/train`` if that exists, else
    from ``data_dir`` itself.

    Args:
        data_dir: Root directory of the dataset.

    Returns:
        ``(dataloaders, dataset_sizes, class_names)``. The first two are dicts
        keyed by ``'train'`` / ``'val'``, holding ``DataLoader`` objects and
        sample counts respectively. ``class_names`` is the list of class folder
        names from the training data.

    Raises:
        ValueError: if ``train/`` and ``val/`` define different class-to-index maps.
    """
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')

    if os.path.isdir(train_dir) and os.path.isdir(val_dir):
        train_dataset = ImageFolder(train_dir, data_transforms['train'])
        val_dataset = ImageFolder(val_dir, data_transforms['val'])
        # ImageFolder assigns label indices from the sorted folder names, so a
        # class missing on one side would silently shift every later label.
        if train_dataset.class_to_idx != val_dataset.class_to_idx:
            raise ValueError(
                f"Class folders differ between {train_dir} and {val_dir}; "
                "labels would be misaligned."
            )
        class_names = train_dataset.classes
    else:
        # Prefer splitting train/ when only it exists; indexing the root would
        # otherwise turn 'train' itself into the single class.
        split_root = train_dir if os.path.isdir(train_dir) else data_dir
        print(f"Warning: {train_dir} and/or {val_dir} not found. "
              f"Using random split fallback on {split_root}.")
        # Two ImageFolder views over the same files so each split keeps its own
        # transform. Assigning .transform on a dataset shared by both splits
        # would also strip augmentation from the training split.
        train_base = ImageFolder(split_root, transform=data_transforms['train'])
        val_base = ImageFolder(split_root, transform=data_transforms['val'])
        n_total = len(train_base)
        n_train = int(0.8 * n_total)
        generator = torch.Generator().manual_seed(CONFIG['seed'])
        order = torch.randperm(n_total, generator=generator).tolist()
        train_dataset = torch.utils.data.Subset(train_base, order[:n_train])
        val_dataset = torch.utils.data.Subset(val_base, order[n_train:])
        class_names = train_base.classes

    image_datasets = {'train': train_dataset, 'val': val_dataset}
    dataloaders = {
        phase: DataLoader(ds, batch_size=CONFIG['batch_size'],
                          shuffle=(phase == 'train'),
                          num_workers=CONFIG['num_workers'])
        for phase, ds in image_datasets.items()
    }
    dataset_sizes = {phase: len(ds) for phase, ds in image_datasets.items()}

    return dataloaders, dataset_sizes, class_names

# Build the shared train/val loaders once; every model trained below reuses them.
(dataloaders,
 dataset_sizes,
 class_names) = get_dataloaders(CONFIG['data_dir'])
print(f"Classes: {len(class_names)}")
print(f"Dataset sizes: {dataset_sizes}")

In [None]:
def _run_phase(model, loader, criterion, optimizer, is_train):
    """Run one pass over ``loader``; return ``(summed loss, number of correct predictions)``.

    Gradients are enabled and the optimizer is stepped only when ``is_train``.
    """
    model.train(is_train)  # train(False) is equivalent to eval()
    loss_sum = 0.0
    n_correct = 0
    with torch.set_grad_enabled(is_train):
        for inputs, labels in tqdm(loader, leave=False):
            inputs = inputs.to(device)
            labels = labels.to(device)

            if is_train:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if is_train:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so re-weight by batch size.
            loss_sum += loss.item() * inputs.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    return loss_sum, n_correct


def train_model(model_name, dataloaders, dataset_sizes, num_epochs=25, *, num_classes=None):
    """Fine-tune a pretrained timm model and return the best-validation weights.

    Each epoch runs a training pass then a validation pass. Whenever validation
    accuracy improves, the weights are snapshotted and written to
    ``CONFIG['checkpoint_dir']/<model_name>_best.pth``.

    Args:
        model_name: timm model name, passed to ``timm.create_model`` with
            ``pretrained=True``.
        dataloaders: Dict with ``'train'`` and ``'val'`` DataLoaders.
        dataset_sizes: Dict with ``'train'`` and ``'val'`` sample counts, used to
            average the per-epoch loss and accuracy.
        num_epochs: Number of epochs; also the cosine schedule's ``T_max``.
        num_classes: Size of the classifier head. Defaults to
            ``len(class_names)`` (the module-level class list).

    Returns:
        ``(model, history)``. ``model`` has the best validation weights loaded.
        ``history`` maps ``'train_loss'``, ``'train_acc'``, ``'val_loss'`` and
        ``'val_acc'`` to per-epoch float lists.
    """
    print(f"\nStarting training for {model_name}...")

    if num_classes is None:
        num_classes = len(class_names)
    model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
    # Stepped once per epoch (after the training pass), so T_max is in epochs.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    save_path = os.path.join(CONFIG['checkpoint_dir'], f"{model_name}_best.pth")

    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_train = phase == 'train'
            loss_sum, n_correct = _run_phase(
                model, dataloaders[phase], criterion, optimizer, is_train)

            if is_train:
                scheduler.step()

            epoch_loss = loss_sum / dataset_sizes[phase]
            epoch_acc = n_correct / dataset_sizes[phase]

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if not is_train and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(best_model_wts, save_path)
                print(f"New best model saved to {save_path}")

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Free the gradient buffers left over from the last training step; the
    # returned model is meant for inference or later fine-tuning.
    optimizer.zero_grad(set_to_none=True)
    model.load_state_dict(best_model_wts)
    return model, history

In [None]:
# Execute Training: fine-tune each configured backbone in turn, plot its
# curves, and save its per-epoch history as CSV next to the checkpoints.
for model_name in CONFIG['models_to_train']:
    print("=================================================================")
    print(f"TRAINING MODEL: {model_name}")
    print("=================================================================")

    # Drop the timm pretrained tag (text after the first '.') for titles and filenames.
    clean_name = model_name.split('.')[0]

    # Release the previous iteration's model before building the next one, so
    # two networks are never resident on the GPU at the same time. The last
    # trained model still remains bound to `trained_model` after the loop.
    trained_model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    trained_model, history = train_model(model_name, dataloaders, dataset_sizes, CONFIG['num_epochs'])

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(history['train_loss'], label='Train Loss')
    ax_loss.plot(history['val_loss'], label='Val Loss')
    ax_loss.legend()
    ax_loss.set_title(f'{clean_name} Loss')

    ax_acc.plot(history['train_acc'], label='Train Acc')
    ax_acc.plot(history['val_acc'], label='Val Acc')
    ax_acc.legend()
    ax_acc.set_title(f'{clean_name} Accuracy')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations

    pd.DataFrame(history).to_csv(os.path.join(CONFIG['checkpoint_dir'], f"{clean_name}_history.csv"), index=False)