<a href="https://colab.research.google.com/github/dhamu2908/DeepLearningAssignment2/blob/main/DL_Assign2_train_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import torchvision
from torchvision import transforms, models, datasets
import argparse
import time
import os

# Side length in pixels that every image is resized to (square) before it
# reaches the network.
IMAGE_SIZE = 224

# Per-channel (R, G, B) mean and standard deviation handed to
# transforms.Normalize. These are the ImageNet statistics that torchvision's
# pretrained weights expect.
NORMALIZATION_MEAN = [0.485, 0.456, 0.406]
NORMALIZATION_STD = [0.229, 0.224, 0.225]

def initialize_pretrained(base_model, trainable_layers, num_classes=10, weights="DEFAULT"):
    """Build a torchvision backbone, freeze it, and attach a fresh classifier head.

    Args:
        base_model: Case-insensitive architecture key: ``'googlenet'``,
            ``'inception'`` (Inception v3) or ``'resnet'`` (ResNet-50).
        trainable_layers: Number of top-level child modules, counted from the
            end of ``network.children()``, whose parameters are unfrozen.
            Parameter-free children (pooling, dropout) and the original ``fc``
            (which is replaced below) count toward this number too. Any value
            <= 0 keeps the whole backbone frozen.
        num_classes: Output size of the replacement ``fc`` layer.
        weights: Forwarded to the torchvision constructor. ``"DEFAULT"`` loads
            the best available pretrained weights, ``None`` gives random init.

    Returns:
        The configured ``nn.Module``. Its new ``fc`` layer is always trainable,
        and its ``forward`` returns a plain logits tensor in both train and
        eval mode.

    Raises:
        ValueError: If ``base_model`` is not a supported architecture.
    """
    model_architectures = {
        'googlenet': models.googlenet,
        'inception': models.inception_v3,
        'resnet': models.resnet50,
    }

    creator = model_architectures.get(base_model.lower())
    if creator is None:
        raise ValueError(f"Unsupported architecture: {base_model}")

    network = creator(weights=weights)

    # GoogLeNet / Inception v3 may carry auxiliary classifiers. While they are
    # active, forward() in training mode returns a namedtuple rather than a
    # tensor, which nn.CrossEntropyLoss cannot consume. Inception's aux head
    # also cannot run on IMAGE_SIZE=224 inputs. torchvision refuses
    # aux_logits=False together with pretrained weights, so the heads are
    # detached after construction instead.
    if getattr(network, "aux_logits", False):
        network.aux_logits = False
        for aux_head in ("aux1", "aux2", "AuxLogits"):
            if hasattr(network, aux_head):
                setattr(network, aux_head, None)

    # Freeze the whole backbone first...
    for param in network.parameters():
        param.requires_grad = False

    # ...then unfreeze the last `trainable_layers` top-level children.
    if trainable_layers > 0:
        for layer in list(network.children())[-trainable_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

    # Swap in a new head. A freshly created nn.Linear has requires_grad=True,
    # so the head is trained even when trainable_layers == 0.
    network.fc = nn.Linear(network.fc.in_features, num_classes)

    return network

def prepare_data_loaders(train_path, test_path, batch_size, val_fraction=0.2, seed=None):
    """Build train / validation / test DataLoaders from ImageFolder directories.

    Every image is resized to IMAGE_SIZE x IMAGE_SIZE, converted to a tensor and
    normalized with NORMALIZATION_MEAN / NORMALIZATION_STD. No data augmentation
    is applied, so the same transform is shared by all three splits.

    Args:
        train_path: ImageFolder root (one sub-directory per class). It is
            randomly split into training and validation subsets.
        test_path: ImageFolder root used as the held-out test set. It must
            define the same classes as ``train_path``.
        batch_size: Batch size used by all three loaders.
        val_fraction: Fraction of ``train_path`` images held out for
            validation; must lie strictly between 0 and 1. The training subset
            receives ``int((1 - val_fraction) * N)`` images.
        seed: Optional integer seed for the train/validation split. ``None``
            (the default) draws from torch's global RNG, so the split differs
            between runs unless the global seed is fixed.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles.

    Raises:
        ValueError: If ``val_fraction`` is out of range, or the two directory
            trees produce different class-to-index mappings.
    """
    if not 0.0 < val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in (0, 1), got {val_fraction}")

    data_transforms = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(NORMALIZATION_MEAN, NORMALIZATION_STD),
    ])

    full_dataset = datasets.ImageFolder(train_path, data_transforms)
    test_data = datasets.ImageFolder(test_path, data_transforms)

    # ImageFolder derives label indices independently for each root, from its
    # sorted sub-directory names. If the trees differ, test labels would be
    # silently misaligned with the model's output units.
    if test_data.class_to_idx != full_dataset.class_to_idx:
        raise ValueError(
            f"Class folders in {test_path!r} do not match those in {train_path!r}"
        )

    # Split the training tree into train / validation subsets.
    train_size = int((1 - val_fraction) * len(full_dataset))
    val_size = len(full_dataset) - train_size
    split_generator = (
        torch.default_generator if seed is None
        else torch.Generator().manual_seed(seed)
    )
    train_set, val_set = random_split(
        full_dataset, [train_size, val_size], generator=split_generator
    )

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(val_set, batch_size=batch_size),
        DataLoader(test_data, batch_size=batch_size),
    )

def compute_accuracy(model, data_loader, device):
    """Return the top-1 accuracy of ``model`` over every batch in ``data_loader``.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored before returning (even if an exception is
    raised), so calling this in the middle of training does not silently
    change dropout / batch-norm behavior.

    Args:
        model: Module whose output is indexed as (batch, num_classes) logits;
            the predicted class is the arg-max over dim 1.
        data_loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to; presumably the model's device.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in data_loader:
                images, labels = images.to(device), labels.to(device)
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return correct / total

def train_model(model, train_loader, val_loader, epochs, device,
                lr=0.001, checkpoint_path='best_model.pth'):
    """Fine-tune ``model`` with Adam and cross-entropy, checkpointing the best epoch.

    Only parameters with ``requires_grad=True`` are given to the optimizer, so
    frozen backbone layers stay fixed. After each epoch it prints the
    sample-weighted mean training loss, the training and validation accuracy,
    and the wall-clock time. Whenever validation accuracy beats every earlier
    epoch of this call, the state dict is saved to ``checkpoint_path``. The
    first epoch always saves.

    Args:
        model: Module that returns a logits tensor in training mode.
        train_loader: DataLoader of ``(images, labels)`` training batches.
        val_loader: DataLoader used to select the checkpoint.
        epochs: Number of passes over ``train_loader``.
        device: Device the model and batches are moved to.
        lr: Adam learning rate.
        checkpoint_path: File the best-so-far state dict is written to; it is
            overwritten on every improvement.

    Returns:
        The model after the FINAL epoch. Its weights are not reloaded from the
        best checkpoint.
    """
    criterion = nn.CrossEntropyLoss()
    # Move the model before building the optimizer so the optimizer is set up
    # on the parameters' final device.
    model = model.to(device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr)

    # -inf guarantees a checkpoint is written after the first epoch, even if
    # validation accuracy is 0.
    best_acc = float('-inf')

    for epoch in range(epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is per-sample even when
            # the last batch is smaller.
            running_loss += loss.item() * images.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        val_acc = compute_accuracy(model, val_loader, device)
        # Extra full pass over the training set in eval mode; this is
        # included in the reported epoch time.
        train_acc = compute_accuracy(model, train_loader, device)
        epoch_time = time.time() - start_time

        print(f"Epoch {epoch+1}/{epochs} | "
              f"Time: {epoch_time:.1f}s | "
              f"Loss: {epoch_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | "
              f"Val Acc: {val_acc:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)

    return model

def main_execution(args):
    """Run the full pipeline: load data, build the model, train it, then test it.

    Args:
        args: ``argparse.Namespace`` with ``train_dataset_path``,
            ``test_dataset_path``, ``batch_size``, ``epochs`` and
            ``unfreezed_layers_from_end``. An optional ``base_model``
            attribute selects the backbone; it defaults to ``'googlenet'`` so
            hand-built Namespaces without it keep working.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, test_loader = prepare_data_loaders(
        args.train_dataset_path,
        args.test_dataset_path,
        args.batch_size,
    )

    network = initialize_pretrained(
        base_model=getattr(args, "base_model", "googlenet"),
        trainable_layers=args.unfreezed_layers_from_end,
    ).to(device)

    trained_model = train_model(
        network,
        train_loader,
        val_loader,
        args.epochs,
        device,
    )

    # NOTE: this scores the weights from the final epoch. The best-validation
    # checkpoint that train_model saves to disk is not reloaded here.
    test_acc = compute_accuracy(trained_model, test_loader, device)
    print(f"\nFinal Test Accuracy: {test_acc:.4f}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fine-tune a pretrained torchvision CNN on ImageFolder data."
    )
    parser.add_argument("-ptrn", "--train_dataset_path",
                        type=str, default="inaturalist_12K/train",
                        help="ImageFolder root, randomly split into train/validation")
    parser.add_argument("-ptst", "--test_dataset_path",
                        type=str, default="inaturalist_12K/val",
                        help="ImageFolder root used as the held-out test set")
    parser.add_argument("-ep", "--epochs", type=int, default=15,
                        help="number of training epochs")
    parser.add_argument("-bs", "--batch_size", type=int, default=256,
                        help="batch size for all data loaders")
    parser.add_argument("-ul", "--unfreezed_layers_from_end",
                        type=int, default=0,
                        help="top-level backbone children to unfreeze, counted "
                             "from the end (0 = train only the new head)")
    parser.add_argument("-bm", "--base_model", type=str, default="googlenet",
                        choices=["googlenet", "inception", "resnet"],
                        help="pretrained backbone to fine-tune")

    cli_args = parser.parse_args()
    main_execution(cli_args)