In [4]:
# ---------------------------
# Imports
# ---------------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import wandb

# ---------------------------
# W&B login (make sure you've logged in)
# ---------------------------
wandb.login()

# ---------------------------
# Hyperparameters and options
# ---------------------------
models_to_run = ['alexnet', 'resnet18']  # only AlexNet and ResNet18
batch_sizes = [16, 64]
learning_rates = [0.001, 0.0001]
augment_options = [True, False]
num_epochs = 5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------
# Dataset (example: ImageFolder)
# Replace 'path_to_data' with your dataset path
# ---------------------------
def get_data_loaders(batch_size, augment):
    if augment:
        transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor()
        ])
    else:
        transform = transforms.Compose([transforms.ToTensor()])

    train_dataset = datasets.ImageFolder(train_folder, transform=transform)
    val_dataset = datasets.ImageFolder(val_folder, transform=transforms.ToTensor())

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

# ---------------------------
# Training function
# ---------------------------
def train_model(model_name, batch_size, lr, augment):
    run_name = f"{model_name}_bs{batch_size}_lr{lr}_aug{augment}"
    wandb.init(project="safety_gear_detection_auto", name=run_name)

    train_loader, val_loader = get_data_loaders(batch_size, augment)

    # Load model
    if model_name == 'alexnet':
        model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
    elif model_name == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Replace final layer for number of classes
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, 2)  # replace 2 with your number of classes

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        correct_train = 0
        total_train = 0
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_acc = correct_train / total_train
        train_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        correct_val = 0
        total_val = 0
        val_loss = 0.0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_acc = correct_val / total_val
        val_loss = val_loss / len(val_loader)

        print(f"{run_name} | Epoch {epoch+1}/{num_epochs} | "
              f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")

        wandb.log({
            "epoch": epoch+1,
            "train_acc": train_acc,
            "train_loss": train_loss,
            "val_acc": val_acc,
            "val_loss": val_loss
        })

    wandb.finish()

# ---------------------------
# Run all experiments
# ---------------------------
for model_name in models_to_run:
    for bs in batch_sizes:
        for lr in learning_rates:
            for aug in augment_options:
                train_model(model_name=model_name, batch_size=bs, lr=lr, augment=aug)




0,1
epoch,▁▅█
train_acc,▁▇█
train_loss,█▅▁
val_acc,█▁▁
val_loss,█▂▁

0,1
epoch,3.0
train_acc,0.55
train_loss,1.20953
val_acc,0.5
val_loss,0.92966


alexnet_bs16_lr0.001_augTrue | Epoch 1/5 | Train Acc: 0.2750 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augTrue | Epoch 2/5 | Train Acc: 0.4250 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augTrue | Epoch 3/5 | Train Acc: 0.4000 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augTrue | Epoch 4/5 | Train Acc: 0.4250 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augTrue | Epoch 5/5 | Train Acc: 0.4500 | Val Acc: 0.5000


0,1
epoch,▁▃▅▆█
train_acc,▁▇▆▇█
train_loss,▅▃█▁▁
val_acc,▁▁▁▁▁
val_loss,▃█▂▁▁

0,1
epoch,5.0
train_acc,0.45
train_loss,1.32087
val_acc,0.5
val_loss,0.77669


alexnet_bs16_lr0.001_augFalse | Epoch 1/5 | Train Acc: 0.3500 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augFalse | Epoch 2/5 | Train Acc: 0.5250 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augFalse | Epoch 3/5 | Train Acc: 0.5750 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augFalse | Epoch 4/5 | Train Acc: 0.3500 | Val Acc: 0.5000
alexnet_bs16_lr0.001_augFalse | Epoch 5/5 | Train Acc: 0.5000 | Val Acc: 0.5000


0,1
epoch,▁▃▅▆█
train_acc,▁▆█▁▆
train_loss,█▅▁▂▅
val_acc,▁▁▁▁▁
val_loss,▇▂▁█▂

0,1
epoch,5.0
train_acc,0.5
train_loss,3.91095
val_acc,0.5
val_loss,1.22096


alexnet_bs16_lr0.0001_augTrue | Epoch 1/5 | Train Acc: 0.0500 | Val Acc: 0.5000
alexnet_bs16_lr0.0001_augTrue | Epoch 2/5 | Train Acc: 0.6000 | Val Acc: 0.7000
alexnet_bs16_lr0.0001_augTrue | Epoch 3/5 | Train Acc: 0.8750 | Val Acc: 0.8000
alexnet_bs16_lr0.0001_augTrue | Epoch 4/5 | Train Acc: 0.9000 | Val Acc: 0.8000
alexnet_bs16_lr0.0001_augTrue | Epoch 5/5 | Train Acc: 0.9750 | Val Acc: 0.9000


0,1
epoch,▁▃▅▆█
train_acc,▁▅▇▇█
train_loss,█▂▁▁▁
val_acc,▁▅▆▆█
val_loss,█▄▄▂▁

0,1
epoch,5.0
train_acc,0.975
train_loss,0.07802
val_acc,0.9
val_loss,0.22488


alexnet_bs16_lr0.0001_augFalse | Epoch 1/5 | Train Acc: 0.1000 | Val Acc: 0.5000
alexnet_bs16_lr0.0001_augFalse | Epoch 2/5 | Train Acc: 0.6750 | Val Acc: 0.8000
alexnet_bs16_lr0.0001_augFalse | Epoch 3/5 | Train Acc: 0.9250 | Val Acc: 0.8000
alexnet_bs16_lr0.0001_augFalse | Epoch 4/5 | Train Acc: 0.9000 | Val Acc: 0.9000
alexnet_bs16_lr0.0001_augFalse | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.9000


0,1
epoch,▁▃▅▆█
train_acc,▁▅▇▇█
train_loss,█▂▁▁▁
val_acc,▁▆▆██
val_loss,█▃▃▁▁

0,1
epoch,5.0
train_acc,1.0
train_loss,0.02145
val_acc,0.9
val_loss,0.3083


alexnet_bs64_lr0.001_augTrue | Epoch 1/5 | Train Acc: 0.0000 | Val Acc: 0.5000
alexnet_bs64_lr0.001_augTrue | Epoch 2/5 | Train Acc: 0.5000 | Val Acc: 0.5000
alexnet_bs64_lr0.001_augTrue | Epoch 3/5 | Train Acc: 0.5000 | Val Acc: 0.4000
alexnet_bs64_lr0.001_augTrue | Epoch 4/5 | Train Acc: 0.4500 | Val Acc: 0.4000
alexnet_bs64_lr0.001_augTrue | Epoch 5/5 | Train Acc: 0.5500 | Val Acc: 0.5000


0,1
epoch,▁▃▅▆█
train_acc,▁▇▇▇█
train_loss,█▂▇▄▁
val_acc,██▁▁█
val_loss,▂█▄▁▄

0,1
epoch,5.0
train_acc,0.55
train_loss,0.78654
val_acc,0.5
val_loss,6.15036


alexnet_bs64_lr0.001_augFalse | Epoch 1/5 | Train Acc: 0.0000 | Val Acc: 0.6000
alexnet_bs64_lr0.001_augFalse | Epoch 2/5 | Train Acc: 0.6750 | Val Acc: 0.5000
alexnet_bs64_lr0.001_augFalse | Epoch 3/5 | Train Acc: 0.5000 | Val Acc: 0.5000
alexnet_bs64_lr0.001_augFalse | Epoch 4/5 | Train Acc: 0.5250 | Val Acc: 0.5000
alexnet_bs64_lr0.001_augFalse | Epoch 5/5 | Train Acc: 0.5000 | Val Acc: 0.5000


0,1
epoch,▁▃▅▆█
train_acc,▁█▆▆▆
train_loss,▇▁█▃▂
val_acc,█▁▁▁▁
val_loss,▁█▃▂▂

0,1
epoch,5.0
train_acc,0.5
train_loss,2.77705
val_acc,0.5
val_loss,2.7468


alexnet_bs64_lr0.0001_augTrue | Epoch 1/5 | Train Acc: 0.0000 | Val Acc: 0.1000
alexnet_bs64_lr0.0001_augTrue | Epoch 2/5 | Train Acc: 0.1750 | Val Acc: 0.5000
alexnet_bs64_lr0.0001_augTrue | Epoch 3/5 | Train Acc: 0.5750 | Val Acc: 0.6000
alexnet_bs64_lr0.0001_augTrue | Epoch 4/5 | Train Acc: 0.7000 | Val Acc: 0.7000
alexnet_bs64_lr0.0001_augTrue | Epoch 5/5 | Train Acc: 0.6750 | Val Acc: 0.7000


0,1
epoch,▁▃▅▆█
train_acc,▁▃▇██
train_loss,█▄▂▁▁
val_acc,▁▆▇██
val_loss,█▃▂▁▁

0,1
epoch,5.0
train_acc,0.675
train_loss,0.79447
val_acc,0.7
val_loss,0.5423


alexnet_bs64_lr0.0001_augFalse | Epoch 1/5 | Train Acc: 0.0000 | Val Acc: 0.1000
alexnet_bs64_lr0.0001_augFalse | Epoch 2/5 | Train Acc: 0.1250 | Val Acc: 0.5000
alexnet_bs64_lr0.0001_augFalse | Epoch 3/5 | Train Acc: 0.6250 | Val Acc: 0.8000
alexnet_bs64_lr0.0001_augFalse | Epoch 4/5 | Train Acc: 0.8750 | Val Acc: 0.8000
alexnet_bs64_lr0.0001_augFalse | Epoch 5/5 | Train Acc: 0.7750 | Val Acc: 0.9000


0,1
epoch,▁▃▅▆█
train_acc,▁▂▆█▇
train_loss,█▄▂▁▁
val_acc,▁▅▇▇█
val_loss,█▃▂▁▁

0,1
epoch,5.0
train_acc,0.775
train_loss,0.52001
val_acc,0.9
val_loss,0.28157


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 83.5MB/s]


resnet18_bs16_lr0.001_augTrue | Epoch 1/5 | Train Acc: 0.6500 | Val Acc: 0.5000
resnet18_bs16_lr0.001_augTrue | Epoch 2/5 | Train Acc: 0.9250 | Val Acc: 0.5000
resnet18_bs16_lr0.001_augTrue | Epoch 3/5 | Train Acc: 0.9500 | Val Acc: 0.8000
resnet18_bs16_lr0.001_augTrue | Epoch 4/5 | Train Acc: 0.9500 | Val Acc: 0.5000
resnet18_bs16_lr0.001_augTrue | Epoch 5/5 | Train Acc: 0.9500 | Val Acc: 0.5000


0,1
epoch,▁▃▅▆█
train_acc,▁▇███
train_loss,█▄▂▁▃
val_acc,▁▁█▁▁
val_loss,▂▅▁█▇

0,1
epoch,5.0
train_acc,0.95
train_loss,0.21708
val_acc,0.5
val_loss,11.44249


resnet18_bs16_lr0.001_augFalse | Epoch 1/5 | Train Acc: 0.6500 | Val Acc: 1.0000
resnet18_bs16_lr0.001_augFalse | Epoch 2/5 | Train Acc: 0.9500 | Val Acc: 0.8000
resnet18_bs16_lr0.001_augFalse | Epoch 3/5 | Train Acc: 0.9750 | Val Acc: 0.5000
resnet18_bs16_lr0.001_augFalse | Epoch 4/5 | Train Acc: 0.9500 | Val Acc: 0.7000
resnet18_bs16_lr0.001_augFalse | Epoch 5/5 | Train Acc: 0.9750 | Val Acc: 0.8000


0,1
epoch,▁▃▅▆█
train_acc,▁▇█▇█
train_loss,█▄▁▂▁
val_acc,█▅▁▄▅
val_loss,▁▂█▄▃

0,1
epoch,5.0
train_acc,0.975
train_loss,0.0731
val_acc,0.8
val_loss,3.92444


resnet18_bs16_lr0.0001_augTrue | Epoch 1/5 | Train Acc: 0.6750 | Val Acc: 0.7000
resnet18_bs16_lr0.0001_augTrue | Epoch 2/5 | Train Acc: 0.9250 | Val Acc: 0.8000
resnet18_bs16_lr0.0001_augTrue | Epoch 3/5 | Train Acc: 0.9500 | Val Acc: 0.9000
resnet18_bs16_lr0.0001_augTrue | Epoch 4/5 | Train Acc: 1.0000 | Val Acc: 0.9000
resnet18_bs16_lr0.0001_augTrue | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.9000


0,1
epoch,▁▃▅▆█
train_acc,▁▆▇██
train_loss,█▄▂▂▁
val_acc,▁▅███
val_loss,█▄▂▁▁

0,1
epoch,5.0
train_acc,1.0
train_loss,0.01972
val_acc,0.9
val_loss,0.18716


resnet18_bs16_lr0.0001_augFalse | Epoch 1/5 | Train Acc: 0.7000 | Val Acc: 0.9000
resnet18_bs16_lr0.0001_augFalse | Epoch 2/5 | Train Acc: 1.0000 | Val Acc: 0.9000
resnet18_bs16_lr0.0001_augFalse | Epoch 3/5 | Train Acc: 1.0000 | Val Acc: 0.9000
resnet18_bs16_lr0.0001_augFalse | Epoch 4/5 | Train Acc: 1.0000 | Val Acc: 0.9000
resnet18_bs16_lr0.0001_augFalse | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.9000


0,1
epoch,▁▃▅▆█
train_acc,▁████
train_loss,█▂▁▁▁
val_acc,▁▁▁▁▁
val_loss,█▇▅▄▁

0,1
epoch,5.0
train_acc,1.0
train_loss,0.01282
val_acc,0.9
val_loss,0.2998


resnet18_bs64_lr0.001_augTrue | Epoch 1/5 | Train Acc: 0.4750 | Val Acc: 0.7000
resnet18_bs64_lr0.001_augTrue | Epoch 2/5 | Train Acc: 0.9750 | Val Acc: 0.9000
resnet18_bs64_lr0.001_augTrue | Epoch 3/5 | Train Acc: 1.0000 | Val Acc: 0.8000
resnet18_bs64_lr0.001_augTrue | Epoch 4/5 | Train Acc: 1.0000 | Val Acc: 0.8000
resnet18_bs64_lr0.001_augTrue | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.7000


0,1
epoch,▁▃▅▆█
train_acc,▁████
train_loss,█▁▁▁▁
val_acc,▁█▅▅▁
val_loss,▂▁▃▆█

0,1
epoch,5.0
train_acc,1.0
train_loss,0.00431
val_acc,0.7
val_loss,2.41779


resnet18_bs64_lr0.001_augFalse | Epoch 1/5 | Train Acc: 0.4750 | Val Acc: 0.6000
resnet18_bs64_lr0.001_augFalse | Epoch 2/5 | Train Acc: 1.0000 | Val Acc: 0.8000
resnet18_bs64_lr0.001_augFalse | Epoch 3/5 | Train Acc: 1.0000 | Val Acc: 1.0000
resnet18_bs64_lr0.001_augFalse | Epoch 4/5 | Train Acc: 1.0000 | Val Acc: 1.0000
resnet18_bs64_lr0.001_augFalse | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.7000


0,1
epoch,▁▃▅▆█
train_acc,▁████
train_loss,█▁▁▁▁
val_acc,▁▅██▃
val_loss,█▅▁▂▇

0,1
epoch,5.0
train_acc,1.0
train_loss,0.00048
val_acc,0.7
val_loss,0.54426


resnet18_bs64_lr0.0001_augTrue | Epoch 1/5 | Train Acc: 0.5250 | Val Acc: 0.5000
resnet18_bs64_lr0.0001_augTrue | Epoch 2/5 | Train Acc: 0.9000 | Val Acc: 0.5000
resnet18_bs64_lr0.0001_augTrue | Epoch 3/5 | Train Acc: 1.0000 | Val Acc: 0.5000
resnet18_bs64_lr0.0001_augTrue | Epoch 4/5 | Train Acc: 1.0000 | Val Acc: 0.5000
resnet18_bs64_lr0.0001_augTrue | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.6000


0,1
epoch,▁▃▅▆█
train_acc,▁▇███
train_loss,█▅▂▂▁
val_acc,▁▁▁▁█
val_loss,▁▁▆█▇

0,1
epoch,5.0
train_acc,1.0
train_loss,0.06665
val_acc,0.6
val_loss,0.57458


resnet18_bs64_lr0.0001_augFalse | Epoch 1/5 | Train Acc: 0.5500 | Val Acc: 0.7000
resnet18_bs64_lr0.0001_augFalse | Epoch 2/5 | Train Acc: 1.0000 | Val Acc: 0.7000
resnet18_bs64_lr0.0001_augFalse | Epoch 3/5 | Train Acc: 1.0000 | Val Acc: 0.8000
resnet18_bs64_lr0.0001_augFalse | Epoch 4/5 | Train Acc: 1.0000 | Val Acc: 0.8000
resnet18_bs64_lr0.0001_augFalse | Epoch 5/5 | Train Acc: 1.0000 | Val Acc: 0.8000


0,1
epoch,▁▃▅▆█
train_acc,▁████
train_loss,█▃▂▁▁
val_acc,▁▁███
val_loss,█▅▃▂▁

0,1
epoch,5.0
train_acc,1.0
train_loss,0.01523
val_acc,0.8
val_loss,0.36791
