In [1]:
# train_resnet_model.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import pandas as pd
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 40
BATCH_SIZE = 32
IMG_SIZE = 224

TRAIN_CSV = "./data/train.csv"
VAL_CSV = "./data/val.csv"
TEST_CSV = "./data/test.csv"
TRAIN_DIR = "./data/train"
VAL_DIR = "./data/val"
TEST_DIR = "./data/test"
MODEL_SAVE_PATH = "./resnet_based_models"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

class PhotoDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.data.iloc[idx, 0])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = self.data.iloc[idx, 1]
        return image, label

# Data augmentation + normalization
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

train_dataset = PhotoDataset(TRAIN_CSV, TRAIN_DIR, transform)
val_dataset = PhotoDataset(VAL_CSV, VAL_DIR, transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Compute class weights
y_train = pd.read_csv(TRAIN_CSV)['label'].values
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

# Load pretrained ResNet18 and modify classifier
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Sequential(
    nn.Linear(resnet.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 1)  # Binary classification
)
model = resnet.to(DEVICE)

criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights_tensor[1])
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        inputs = inputs.to(DEVICE)
        labels = labels.float().unsqueeze(1).to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Save model
    if (epoch + 1) % 5 == 0:
        save_path = os.path.join(MODEL_SAVE_PATH, f"resnet_model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved at {save_path}")


Epoch 1/40: 100%|██████████| 125/125 [05:19<00:00,  2.55s/it]


Epoch 1: Loss = 0.8218


Epoch 2/40: 100%|██████████| 125/125 [05:08<00:00,  2.47s/it]


Epoch 2: Loss = 0.6288


Epoch 3/40: 100%|██████████| 125/125 [05:15<00:00,  2.52s/it]


Epoch 3: Loss = 0.6042


Epoch 4/40: 100%|██████████| 125/125 [05:15<00:00,  2.52s/it]


Epoch 4: Loss = 0.5314


Epoch 5/40: 100%|██████████| 125/125 [05:11<00:00,  2.49s/it]


Epoch 5: Loss = 0.4684
✅ Model saved at ./resnet_based_models/resnet_model_epoch_5.pth


Epoch 6/40: 100%|██████████| 125/125 [05:18<00:00,  2.55s/it]


Epoch 6: Loss = 0.4889


Epoch 7/40: 100%|██████████| 125/125 [05:03<00:00,  2.42s/it]


Epoch 7: Loss = 0.3914


Epoch 8/40: 100%|██████████| 125/125 [05:03<00:00,  2.43s/it]


Epoch 8: Loss = 0.3886


Epoch 9/40: 100%|██████████| 125/125 [03:47<00:00,  1.82s/it]


Epoch 9: Loss = 0.3747


Epoch 10/40: 100%|██████████| 125/125 [03:23<00:00,  1.63s/it]


Epoch 10: Loss = 0.2985
✅ Model saved at ./resnet_based_models/resnet_model_epoch_10.pth


Epoch 11/40: 100%|██████████| 125/125 [03:27<00:00,  1.66s/it]


Epoch 11: Loss = 0.3264


Epoch 12/40: 100%|██████████| 125/125 [03:21<00:00,  1.61s/it]


Epoch 12: Loss = 0.2932


Epoch 13/40: 100%|██████████| 125/125 [03:22<00:00,  1.62s/it]


Epoch 13: Loss = 0.2789


Epoch 14/40: 100%|██████████| 125/125 [03:24<00:00,  1.63s/it]


Epoch 14: Loss = 0.3115


Epoch 15/40: 100%|██████████| 125/125 [03:21<00:00,  1.62s/it]


Epoch 15: Loss = 0.2531
✅ Model saved at ./resnet_based_models/resnet_model_epoch_15.pth


Epoch 16/40: 100%|██████████| 125/125 [03:28<00:00,  1.66s/it]


Epoch 16: Loss = 0.2367


Epoch 17/40: 100%|██████████| 125/125 [03:21<00:00,  1.61s/it]


Epoch 17: Loss = 0.2117


Epoch 18/40: 100%|██████████| 125/125 [03:18<00:00,  1.59s/it]


Epoch 18: Loss = 0.1684


Epoch 19/40: 100%|██████████| 125/125 [03:23<00:00,  1.62s/it]


Epoch 19: Loss = 0.1851


Epoch 20/40: 100%|██████████| 125/125 [03:23<00:00,  1.63s/it]


Epoch 20: Loss = 0.1744
✅ Model saved at ./resnet_based_models/resnet_model_epoch_20.pth


Epoch 21/40: 100%|██████████| 125/125 [03:26<00:00,  1.65s/it]


Epoch 21: Loss = 0.2557


Epoch 22/40: 100%|██████████| 125/125 [03:29<00:00,  1.68s/it]


Epoch 22: Loss = 0.1999


Epoch 23/40: 100%|██████████| 125/125 [03:31<00:00,  1.69s/it]


Epoch 23: Loss = 0.1607


Epoch 24/40: 100%|██████████| 125/125 [03:35<00:00,  1.72s/it]


Epoch 24: Loss = 0.1598


Epoch 25/40: 100%|██████████| 125/125 [03:43<00:00,  1.79s/it]


Epoch 25: Loss = 0.1643
✅ Model saved at ./resnet_based_models/resnet_model_epoch_25.pth


Epoch 26/40: 100%|██████████| 125/125 [03:35<00:00,  1.72s/it]


Epoch 26: Loss = 0.1698


Epoch 27/40: 100%|██████████| 125/125 [03:39<00:00,  1.76s/it]


Epoch 27: Loss = 0.1388


Epoch 28/40: 100%|██████████| 125/125 [03:49<00:00,  1.84s/it]


Epoch 28: Loss = 0.1253


Epoch 29/40: 100%|██████████| 125/125 [03:41<00:00,  1.77s/it]


Epoch 29: Loss = 0.1570


Epoch 30/40: 100%|██████████| 125/125 [03:42<00:00,  1.78s/it]


Epoch 30: Loss = 0.1151
✅ Model saved at ./resnet_based_models/resnet_model_epoch_30.pth


Epoch 31/40: 100%|██████████| 125/125 [03:21<00:00,  1.62s/it]


Epoch 31: Loss = 0.1858


Epoch 32/40: 100%|██████████| 125/125 [03:19<00:00,  1.59s/it]


Epoch 32: Loss = 0.1499


Epoch 33/40: 100%|██████████| 125/125 [03:19<00:00,  1.60s/it]


Epoch 33: Loss = 0.0660


Epoch 34/40: 100%|██████████| 125/125 [03:22<00:00,  1.62s/it]


Epoch 34: Loss = 0.1362


Epoch 35/40: 100%|██████████| 125/125 [03:18<00:00,  1.59s/it]


Epoch 35: Loss = 0.0905
✅ Model saved at ./resnet_based_models/resnet_model_epoch_35.pth


Epoch 36/40: 100%|██████████| 125/125 [03:17<00:00,  1.58s/it]


Epoch 36: Loss = 0.0748


Epoch 37/40: 100%|██████████| 125/125 [03:19<00:00,  1.60s/it]


Epoch 37: Loss = 0.1338


Epoch 38/40: 100%|██████████| 125/125 [03:22<00:00,  1.62s/it]


Epoch 38: Loss = 0.1053


Epoch 39/40: 100%|██████████| 125/125 [03:36<00:00,  1.73s/it]


Epoch 39: Loss = 0.0473


Epoch 40/40: 100%|██████████| 125/125 [03:31<00:00,  1.69s/it]

Epoch 40: Loss = 0.1310
✅ Model saved at ./resnet_based_models/resnet_model_epoch_40.pth





## Oversampling based approach

In [4]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
import pandas as pd
from tqdm import tqdm
from sklearn.utils.class_weight import compute_class_weight
import numpy as np

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 40
BATCH_SIZE = 32
IMG_SIZE = 224

# Paths
TRAIN_CSV = "./data/train.csv"
VAL_CSV = "./data/val.csv"
TRAIN_DIR = "./data/train"
VAL_DIR = "./data/val"
MODEL_SAVE_PATH = "./resnet_based_models_oversampling"
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Load train & val CSVs
train_df = pd.read_csv(TRAIN_CSV)
val_df = pd.read_csv(VAL_CSV)

# Get all label==1 from val and add to train
val_label_1 = val_df[val_df["label"] == 1].copy()
val_label_1["source"] = "val"
train_df["source"] = "train"
combined_df = pd.concat([train_df, val_label_1], ignore_index=True)

# Now balance the dataset (upsample minority)
label_counts = combined_df["label"].value_counts()
print(f"Before balancing:\n{label_counts}")

min_class = combined_df[combined_df["label"] == 1]
maj_class = combined_df[combined_df["label"] == 0]

if len(min_class) < len(maj_class):
    # Upsample minority
    min_class_upsampled = min_class.sample(len(maj_class), replace=True, random_state=42)
    balanced_df = pd.concat([maj_class, min_class_upsampled], ignore_index=True)
    print(f"✅ Upsampled label=1 to match label=0: {len(min_class_upsampled)} samples.")
else:
    # Upsample majority (rare case)
    maj_class_upsampled = maj_class.sample(len(min_class), replace=True, random_state=42)
    balanced_df = pd.concat([min_class, maj_class_upsampled], ignore_index=True)

# Shuffle dataset
balanced_df = balanced_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Save balanced dataset for tracking
balanced_df.to_csv("train_balanced.csv", index=False)

# Dataset class
class PhotoDataset(Dataset):
    def __init__(self, dataframe, img_dirs, transform=None):
        self.data = dataframe
        self.img_dirs = img_dirs
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        img_dir = self.img_dirs[row['source']]
        img_path = os.path.join(img_dir, row['path'])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = row['label']
        return image, label

# Transforms
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Datasets
train_dataset = PhotoDataset(balanced_df, {"train": TRAIN_DIR, "val": VAL_DIR}, train_transform)
val_dataset = PhotoDataset(val_df, {"val": VAL_DIR}, val_transform)

# Loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Class weights (should be 1:1 now, but still good practice)
y_train = balanced_df['label'].values
class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(DEVICE)

# Model
resnet = models.resnet18(pretrained=True)
resnet.fc = nn.Sequential(
    nn.Linear(resnet.fc.in_features, 128),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(128, 1)
)
model = resnet.to(DEVICE)

# Loss & Optimizer
criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights_tensor[1])
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        inputs = inputs.to(DEVICE)
        labels = labels.float().unsqueeze(1).to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"📘 Epoch {epoch+1}: Loss = {avg_loss:.4f}")

    # Save model every 5 epochs
    if (epoch + 1) % 2 == 0:
        save_path = os.path.join(MODEL_SAVE_PATH, f"resnet_model_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"✅ Model saved at {save_path}")




Before balancing:
label
0    3800
1     216
Name: count, dtype: int64
✅ Upsampled label=1 to match label=0: 3800 samples.


Epoch 1/40: 100%|██████████| 238/238 [06:27<00:00,  1.63s/it]


📘 Epoch 1: Loss = 0.3794


Epoch 2/40: 100%|██████████| 238/238 [06:28<00:00,  1.63s/it]


📘 Epoch 2: Loss = 0.1876
✅ Model saved at ./resnet_based_models_oversampling/resnet_model_epoch_2.pth


Epoch 3/40: 100%|██████████| 238/238 [06:28<00:00,  1.63s/it]


📘 Epoch 3: Loss = 0.1164


Epoch 4/40: 100%|██████████| 238/238 [06:23<00:00,  1.61s/it]


📘 Epoch 4: Loss = 0.0826
✅ Model saved at ./resnet_based_models_oversampling/resnet_model_epoch_4.pth


Epoch 5/40: 100%|██████████| 238/238 [06:31<00:00,  1.64s/it]


📘 Epoch 5: Loss = 0.0619


Epoch 6/40: 100%|██████████| 238/238 [06:29<00:00,  1.64s/it]


📘 Epoch 6: Loss = 0.0502
✅ Model saved at ./resnet_based_models_oversampling/resnet_model_epoch_6.pth


Epoch 7/40: 100%|██████████| 238/238 [06:44<00:00,  1.70s/it]


📘 Epoch 7: Loss = 0.0444


Epoch 8/40: 100%|██████████| 238/238 [06:48<00:00,  1.72s/it]


📘 Epoch 8: Loss = 0.0429
✅ Model saved at ./resnet_based_models_oversampling/resnet_model_epoch_8.pth


Epoch 9/40: 100%|██████████| 238/238 [06:52<00:00,  1.73s/it]


📘 Epoch 9: Loss = 0.0246


Epoch 10/40: 100%|██████████| 238/238 [07:05<00:00,  1.79s/it]


📘 Epoch 10: Loss = 0.0324
✅ Model saved at ./resnet_based_models_oversampling/resnet_model_epoch_10.pth


Epoch 11/40: 100%|██████████| 238/238 [07:11<00:00,  1.81s/it]


📘 Epoch 11: Loss = 0.0218


Epoch 12/40: 100%|██████████| 238/238 [06:34<00:00,  1.66s/it]


📘 Epoch 12: Loss = 0.0188
✅ Model saved at ./resnet_based_models_oversampling/resnet_model_epoch_12.pth


Epoch 13/40: 100%|██████████| 238/238 [06:52<00:00,  1.73s/it]


📘 Epoch 13: Loss = 0.0242


Epoch 14/40:  72%|███████▏  | 171/238 [04:49<01:53,  1.69s/it]


KeyboardInterrupt: 