In [1]:
import sys
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import random_split, DataLoader

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.mps.is_available():
    device = torch.device("mps")

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = torchvision.datasets.ImageFolder(
    root='../images',
    transform=preprocess
)
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)

resnet50_model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)
resnet50_model.fc = nn.Identity()
for param in resnet50_model.parameters():
    param.requires_grad = False
resnet50_model.eval()
resnet50_model = resnet50_model.to(device)

fc_model = nn.Sequential(
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1)
)
fc_model = fc_model.to(device)

model = nn.Sequential(
    resnet50_model,
    fc_model
)
model = model.to(device)

# Use a smaller LR (e.g., 0.00025) for the Adam optimizer on fc_model params
optimizer = torch.optim.Adam(fc_model.parameters(), lr=0.00025)

loss_fn = nn.BCEWithLogitsLoss()

for epoch in range(10):
    # Put model in train mode; keep frozen ResNet backbone in eval mode inside training loop
    model.train()
    resnet50_model.eval()

    loss_sum = 0
    # Track training accuracy: accumulate correct predictions and counts
    train_accurate, train_sum = 0, 0

    for X, y in train_dataloader:
        X = X.to(device)
        y = y.to(device).type(torch.float).reshape(-1, 1)

        outputs = model(X)
        optimizer.zero_grad()
        loss = loss_fn(outputs, y)
        loss_sum += loss.item()
        loss.backward()
        optimizer.step()

        # Compute predictions = sigmoid(outputs) > 0.5 and update train_accurate / train_sum
        predictions = torch.sigmoid(outputs) > 0.5
        train_accurate += (predictions == y).sum().item()
        train_sum += y.size(0)

    # Print average training loss and training accuracy for the epoch
    print("Training loss:", loss_sum / len(train_dataloader))
    print("Training accuracy:", train_accurate / train_sum)

    # Add a validation loop in eval/no_grad() computing avg val loss and val accuracy
    model.eval()
    with torch.no_grad():
        val_loss_sum, val_accurate, val_sum = 0, 0, 0
        for X, y in val_dataloader:
            X = X.to(device)
            y = y.to(device).type(torch.float).reshape(-1, 1)

            outputs = model(X)
            loss = loss_fn(outputs, y)
            val_loss_sum += loss.item()

            preds = torch.sigmoid(outputs) > 0.5
            val_accurate += (preds == y).sum().item()
            val_sum += y.size(0)

    print("Validation loss:", val_loss_sum / len(val_dataloader))
    print("Validation accuracy:", val_accurate / val_sum)

Training loss: 0.3367737937480845
Training accuracy: 0.8505050505050505
Validation loss: 0.22291619951526323
Validation accuracy: 0.9110512129380054
Training loss: 0.22949915553661102
Training accuracy: 0.9050505050505051
Validation loss: 0.18017095451553664
Validation accuracy: 0.9164420485175202
Training loss: 0.16701401365881272
Training accuracy: 0.934006734006734
Validation loss: 0.17790906938413778
Validation accuracy: 0.921832884097035
Training loss: 0.14810873647319509
Training accuracy: 0.9488215488215488
Validation loss: 0.16170590557157993
Validation accuracy: 0.9326145552560647
Training loss: 0.13247888780971792
Training accuracy: 0.9542087542087542
Validation loss: 0.21936001380284628
Validation accuracy: 0.9083557951482479
Training loss: 0.12458149914411788
Training accuracy: 0.9542087542087542
Validation loss: 0.16808504021416107
Validation accuracy: 0.9299191374663073
Training loss: 0.0997519130085377
Training accuracy: 0.9676767676767677
Validation loss: 0.146174860186