In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
from pathlib import Path

# Dataset root; it is expected to contain "train", "val" and "test"
# sub-folders laid out the way torchvision's ImageFolder reads them.
root = Path("/Users/Chandraprakash.Patra/Downloads/codes/photo_of_photo/version_3.0/dataset")

# Training preprocessing: fixed 224x224 resize plus light augmentation.
#
# ColorJitter injects controlled randomness into brightness and contrast.
# Photo-of-photo images often carry screen glare, uneven lighting, exposure
# shifts, and moiré contrast distortions. A static pipeline lets the model
# overfit the exact capture conditions; mild jitter forces the network to
# learn features that survive these variations, reducing its dependence on
# accidental lighting cues and stabilizing the classifier.
train_t = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
    ]
)

# Evaluation preprocessing: the same resize, with no random augmentation.
val_t = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# One ImageFolder per split. Only the training split is augmented; the
# validation and test splits share the deterministic transform.
train_ds, val_ds, test_ds = (
    datasets.ImageFolder(root / split_name, transform=split_transform)
    for split_name, split_transform in (
        ("train", train_t),
        ("val", val_t),
        ("test", val_t),
    )
)

# Batches of 16 loaded by 2 worker processes; only training data is shuffled.
train_loader, val_loader = (
    DataLoader(split_ds, batch_size=16, shuffle=do_shuffle, num_workers=2)
    for split_ds, do_shuffle in ((train_ds, True), (val_ds, False))
)

# Training is pinned to the CPU. Alternatives that were tried and left disabled:
#   torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   'mps'
device = "cpu"

# ImageNet-pretrained ResNet-18 whose final fully connected layer is replaced
# by a fresh 2-output head for the binary classification task.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=2)
model = model.to(device)

# Cross-entropy over the two logits; Adam is given every model parameter.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def run_epoch(loader, train_mode):
    """Run one full pass over ``loader`` and report mean loss and accuracy.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and
    ``device`` objects.

    Args:
        loader: Iterable yielding ``(imgs, labels)`` batches.
        train_mode: If truthy, the model is switched to training mode and the
            optimizer takes a step after every batch. Otherwise the model is
            switched to eval mode and gradient tracking is disabled.

    Returns:
        A ``(mean_loss, accuracy)`` tuple, both averaged over all samples
        seen during the pass.
    """
    # Equivalent to calling model.train() / model.eval() explicitly.
    model.train(bool(train_mode))

    seen = 0
    hits = 0
    running_loss = 0

    with torch.set_grad_enabled(train_mode):
        for batch, targets in loader:
            batch = batch.to(device)
            targets = targets.to(device)

            logits = model(batch)
            batch_loss = criterion(logits, targets)

            if train_mode:
                # Clear stale gradients, backpropagate, then update weights.
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

            # The criterion returns a per-batch mean, so weight it by the
            # batch size to obtain a correct per-sample average at the end.
            running_loss += batch_loss.item() * batch.size(0)

            predicted = logits.max(dim=1).indices
            seen += targets.size(0)
            hits += (predicted == targets).sum().item()

    return running_loss / seen, hits / seen

# Four epochs: each one is a training pass followed by a validation pass.
for epoch in range(4):
    # Tuple elements are evaluated left to right, so training still runs
    # before validation.
    (train_loss, train_acc), (val_loss, val_acc) = (
        run_epoch(train_loader, True),
        run_epoch(val_loader, False),
    )
    print(f"- epoch {epoch+1}: train_acc={train_acc:.3f} val_acc={val_acc:.3f}")

# Persist only the learned weights (state dict), not the whole model object.
torch.save(obj=model.state_dict(), f="resnet18_live_photo.pth")
print("saved")

- epoch 1: train_acc=0.880 val_acc=0.969
- epoch 2: train_acc=0.959 val_acc=0.953
- epoch 3: train_acc=0.979 val_acc=0.969
- epoch 4: train_acc=0.981 val_acc=0.969


# Observation 
##### 1st run for observation
- epoch 1: train_acc=0.874 val_acc=0.938
- epoch 2: train_acc=0.979 val_acc=0.953
- epoch 3: train_acc=0.984 val_acc=0.969
- epoch 4: train_acc=0.986 val_acc=0.969
- epoch 5: train_acc=0.988 val_acc=0.953
- epoch 6: train_acc=0.994 val_acc=0.969
- epoch 7: train_acc=0.996 val_acc=0.969
- epoch 8: train_acc=0.981 val_acc=0.953
- epoch 9: train_acc=0.988 val_acc=0.969
- epoch 10: train_acc=0.983 val_acc=0.969
- epoch 11: train_acc=0.986 val_acc=0.969
- epoch 12: train_acc=0.998 val_acc=0.953
- epoch 13: train_acc=0.996 val_acc=0.953
- epoch 14: train_acc=0.992 val_acc=0.969
- epoch 15: train_acc=1.000 val_acc=0.969

# Chosen number of epochs
##### Training saturated early. Validation accuracy stabilised at 0.953–0.969 from epoch 2 onward. Later epochs only inflate train accuracy without shifting validation. That is the signature of a converged classifier with mild overfit drift.

##### Right stopping point: epoch 3 or 4.

- epoch 1: train_acc=0.880 val_acc=0.969
- epoch 2: train_acc=0.959 val_acc=0.953
- epoch 3: train_acc=0.979 val_acc=0.969
- epoch 4: train_acc=0.981 val_acc=0.969

