# In [None]:
import sys
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import random_split, DataLoader
# Pick the fastest available backend: CUDA GPU first, then Apple-Silicon MPS,
# otherwise fall back to the CPU.
# The MPS probe goes through torch.backends.mps, the long-standing documented
# availability check; torch.mps.is_available() appears to exist only in recent
# PyTorch releases and would raise AttributeError on older ones.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Per-image input pipeline, applied by the dataset on every load:
#   1. resize (an int size scales the shorter side to 256 px),
#   2. take the central 224x224 crop,
#   3. convert the image to a float tensor,
#   4. standardize each RGB channel with the mean/std below (the values
#      torchvision documents for its ImageNet-pretrained models).
_channel_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
_pipeline_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _channel_normalize,
]
preprocess = transforms.Compose(_pipeline_steps)

# Images are read from ../images through ImageFolder (one sub-directory per
# class); every sample goes through the `preprocess` pipeline when loaded.
dataset = torchvision.datasets.ImageFolder(
    root='../images',
    transform=preprocess
)
# Random 80% / 20% train/validation split, given as fractional lengths.
train_dataset, val_dataset = random_split(dataset, [0.8, 0.2])

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Evaluation visits every sample once and its metrics do not depend on batch
# order, so the validation loader is left unshuffled (deterministic order).
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Frozen feature extractor: ResNet-50 with the IMAGENET1K_V1 pretrained weights.
resnet50_model = torchvision.models.resnet50(
    weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1
)
# Swap the final classifier for a pass-through so the network returns the
# features that used to feed it.
resnet50_model.fc = nn.Identity()
# Disable gradients for every backbone parameter in a single call.
resnet50_model.requires_grad_(False)
# Inference mode, then move to the selected device.
resnet50_model = resnet50_model.eval().to(device)

# Trainable head: 2048 input features -> 1024 hidden units (ReLU) -> one raw
# logit per sample. Built and moved to the target device in one expression.
fc_model = nn.Sequential(
    nn.Linear(2048, 1024),
    nn.ReLU(),
    nn.Linear(1024, 1),
).to(device)

# Full network: backbone features are fed straight into the head.
model = nn.Sequential(resnet50_model, fc_model).to(device)

# Adam only receives the head's parameters, so the backbone is never stepped.
# The learning rate is the smaller value (2.5e-4) that the former TODO on this
# line asked for, replacing the previous 1e-3.
optimizer = torch.optim.Adam(fc_model.parameters(), lr=0.00025)

# Binary cross-entropy computed directly on raw logits (the sigmoid is applied
# inside the loss), which matches the head's single un-activated output.
loss_fn = nn.BCEWithLogitsLoss()

# Train the head for 10 epochs; after each epoch report the average loss and
# accuracy on the training split and on the held-out validation split.
# NOTE(review): BCE with a single logit assumes the dataset has exactly two
# classes (ImageFolder labels 0 and 1) — confirm against ../images.
for epoch in range(10):
    # model.train() flips every sub-module into train mode, so the frozen
    # backbone is switched back to eval right away: in train mode its
    # BatchNorm layers would keep updating their running statistics even
    # though their parameters do not require gradients.
    model.train()
    resnet50_model.eval()

    loss_sum = 0.0       # sum of per-sample training losses
    train_accurate = 0   # correctly classified training samples
    train_sum = 0        # training samples seen this epoch

    for X, y in train_dataloader:
        X = X.to(device)
        # The loss needs float targets with the same (N, 1) shape as the logits.
        y = y.to(device).type(torch.float).reshape(-1, 1)

        optimizer.zero_grad()
        outputs = model(X)
        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()

        # The loss is a batch mean; weight it by the batch size so a smaller
        # final batch does not skew the epoch average.
        batch_size = y.shape[0]
        loss_sum += loss.item() * batch_size

        predictions = torch.sigmoid(outputs.detach()) > 0.5
        train_accurate += (predictions == (y > 0.5)).sum().item()
        train_sum += batch_size

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    train_loss = loss_sum / max(train_sum, 1)
    train_accuracy = train_accurate / max(train_sum, 1)

    # Validation pass: everything in eval mode, no gradient tracking.
    model.eval()
    val_loss_sum = 0.0
    val_accurate = 0
    val_sum = 0

    with torch.no_grad():
        for X, y in val_dataloader:
            X = X.to(device)
            y = y.to(device).type(torch.float).reshape(-1, 1)

            outputs = model(X)
            batch_size = y.shape[0]
            val_loss_sum += loss_fn(outputs, y).item() * batch_size

            predictions = torch.sigmoid(outputs) > 0.5
            val_accurate += (predictions == (y > 0.5)).sum().item()
            val_sum += batch_size

    val_loss = val_loss_sum / max(val_sum, 1)
    val_accuracy = val_accurate / max(val_sum, 1)

    print(
        f"Epoch {epoch + 1}: "
        f"train loss {train_loss:.4f}, train acc {train_accuracy:.4f}, "
        f"val loss {val_loss:.4f}, val acc {val_accuracy:.4f}"
    )
