In [1]:
import sys
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import random_split, DataLoader

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.mps.is_available():
    device = torch.device("mps")

# Data transforms
# Add data augmentation BEFORE the standard resize/crop:
#   - transforms.Resize(512)
#   - transforms.RandomRotation(10)
#   - transforms.RandomVerticalFlip()
#   - transforms.RandomHorizontalFlip()
#   - transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
# Then keep the existing resize -> center crop -> to tensor -> normalize.
preprocess = transforms.Compose([
    transforms.Resize(512),
    transforms.RandomRotation(10),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),

    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dataset = torchvision.datasets.ImageFolder(
    root='../images',
    transform=preprocess
)

train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)

# Base model (frozen ResNet50 backbone) + small classifier head
resnet50_model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)
resnet50_model.fc = nn.Identity()
for param in resnet50_model.parameters():
    param.requires_grad = False
resnet50_model.eval()
resnet50_model = resnet50_model.to(device)

fc_model = nn.Sequential(
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1)
)
fc_model = fc_model.to(device)

model = nn.Sequential(
    resnet50_model,
    fc_model
)
model = model.to(device)

optimizer = torch.optim.Adam(fc_model.parameters(), lr=0.00025)
loss_fn = nn.BCEWithLogitsLoss()

# Training loop
# Train for 15 epochs (print a clear epoch header each time).
# After each epoch, save ONLY the fc head weights to "fc_model_{epoch}.pth".
for epoch in range(15):
    print(f"--- EPOCH: {epoch} ---")
    model.train()
    resnet50_model.eval()

    loss_sum = 0
    train_accurate = 0
    train_sum = 0
    for X, y in tqdm(train_dataloader):
        X = X.to(device)
        y = y.to(device).type(torch.float).reshape(-1, 1)

        outputs = model(X)
        optimizer.zero_grad()
        loss = loss_fn(outputs, y)
        loss_sum+=loss.item()
        loss.backward()
        optimizer.step()

        predictions = torch.sigmoid(outputs) > 0.5
        accurate = (predictions == y).sum().item()
        train_accurate+=accurate
        train_sum+=y.size(0)
    print("Training loss: ", loss_sum / len(train_dataloader))
    print("Training accuracy: ", train_accurate / train_sum)

    torch.save(fc_model.state_dict(), f"fc_model_{epoch}.pth")

    model.eval()
    val_loss_sum = 0
    val_accurate = 0
    val_sum = 0
    with torch.no_grad():
        for X, y in tqdm(val_dataloader):
            X = X.to(device)
            y = y.to(device).type(torch.float).reshape(-1, 1)

            outputs = model(X)
            loss = loss_fn(outputs, y)
            val_loss_sum+=loss.item()

            predictions = torch.sigmoid(outputs) > 0.5
            accurate = (predictions == y).sum().item()
            val_accurate+=accurate
            val_sum+=y.size(0)
    print("Validation loss: ", val_loss_sum / len(val_dataloader))
    print("Validation accuracy: ", val_accurate / val_sum)

--- EPOCH: 0 ---


100%|██████████| 47/47 [00:56<00:00,  1.21s/it]


Training loss:  0.340419369333602
Training accuracy:  0.8430976430976431


100%|██████████| 12/12 [00:14<00:00,  1.17s/it]


Validation loss:  0.3005264798800151
Validation accuracy:  0.8814016172506739
--- EPOCH: 1 ---


100%|██████████| 47/47 [00:56<00:00,  1.20s/it]


Training loss:  0.20912580271350575
Training accuracy:  0.907070707070707


100%|██████████| 12/12 [00:14<00:00,  1.17s/it]


Validation loss:  0.2308883424848318
Validation accuracy:  0.9083557951482479
--- EPOCH: 2 ---


100%|██████████| 47/47 [00:56<00:00,  1.21s/it]


Training loss:  0.18361771637771993
Training accuracy:  0.9245791245791246


100%|██████████| 12/12 [00:14<00:00,  1.17s/it]


Validation loss:  0.21019991611440977
Validation accuracy:  0.9191374663072777
--- EPOCH: 3 ---


100%|██████████| 47/47 [00:56<00:00,  1.21s/it]


Training loss:  0.1722581849453297
Training accuracy:  0.9232323232323232


100%|██████████| 12/12 [00:14<00:00,  1.19s/it]


Validation loss:  0.21452444108823934
Validation accuracy:  0.8975741239892183
--- EPOCH: 4 ---


100%|██████████| 47/47 [00:56<00:00,  1.20s/it]


Training loss:  0.17610980388014874
Training accuracy:  0.9313131313131313


100%|██████████| 12/12 [00:14<00:00,  1.18s/it]


Validation loss:  0.19007102058579525
Validation accuracy:  0.9191374663072777
--- EPOCH: 5 ---


100%|██████████| 47/47 [00:57<00:00,  1.23s/it]


Training loss:  0.1572688008718034
Training accuracy:  0.936026936026936


100%|██████████| 12/12 [00:14<00:00,  1.21s/it]


Validation loss:  0.21260770099858442
Validation accuracy:  0.9299191374663073
--- EPOCH: 6 ---


100%|██████████| 47/47 [00:57<00:00,  1.23s/it]


Training loss:  0.13941328366227607
Training accuracy:  0.9447811447811448


100%|██████████| 12/12 [00:14<00:00,  1.20s/it]


Validation loss:  0.23282325702408949
Validation accuracy:  0.9083557951482479
--- EPOCH: 7 ---


100%|██████████| 47/47 [00:57<00:00,  1.22s/it]


Training loss:  0.13687751592790826
Training accuracy:  0.9474747474747475


100%|██████████| 12/12 [00:14<00:00,  1.20s/it]


Validation loss:  0.1750738782187303
Validation accuracy:  0.9433962264150944
--- EPOCH: 8 ---


100%|██████████| 47/47 [00:58<00:00,  1.24s/it]


Training loss:  0.11971397258024266
Training accuracy:  0.9508417508417508


100%|██████████| 12/12 [00:14<00:00,  1.20s/it]


Validation loss:  0.20990366178254286
Validation accuracy:  0.9245283018867925
--- EPOCH: 9 ---


100%|██████████| 47/47 [00:57<00:00,  1.23s/it]


Training loss:  0.13257353435805502
Training accuracy:  0.9454545454545454


100%|██████████| 12/12 [00:14<00:00,  1.25s/it]


Validation loss:  0.21037899578611055
Validation accuracy:  0.9110512129380054
--- EPOCH: 10 ---


100%|██████████| 47/47 [00:58<00:00,  1.24s/it]


Training loss:  0.11939226764630764
Training accuracy:  0.9542087542087542


100%|██████████| 12/12 [00:14<00:00,  1.20s/it]


Validation loss:  0.23941806455453238
Validation accuracy:  0.9191374663072777
--- EPOCH: 11 ---


100%|██████████| 47/47 [00:57<00:00,  1.21s/it]


Training loss:  0.10236803707765772
Training accuracy:  0.958922558922559


100%|██████████| 12/12 [00:14<00:00,  1.17s/it]


Validation loss:  0.19003129191696644
Validation accuracy:  0.9353099730458221
--- EPOCH: 12 ---


100%|██████████| 47/47 [00:56<00:00,  1.21s/it]


Training loss:  0.09466455797565744
Training accuracy:  0.9663299663299664


100%|██████████| 12/12 [00:14<00:00,  1.22s/it]


Validation loss:  0.1854741731658578
Validation accuracy:  0.9299191374663073
--- EPOCH: 13 ---


100%|██████████| 47/47 [00:58<00:00,  1.24s/it]


Training loss:  0.10431738885396973
Training accuracy:  0.9582491582491582


100%|██████████| 12/12 [00:14<00:00,  1.20s/it]


Validation loss:  0.18827213905751705
Validation accuracy:  0.9353099730458221
--- EPOCH: 14 ---


100%|██████████| 47/47 [00:57<00:00,  1.23s/it]


Training loss:  0.09118585668979808
Training accuracy:  0.9683501683501684


100%|██████████| 12/12 [00:14<00:00,  1.20s/it]

Validation loss:  0.22851158306002617
Validation accuracy:  0.9110512129380054



