<a href="https://colab.research.google.com/github/michael-L-i/CS229-Final-Project/blob/main/Physics_Based_ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install PyWavelets

Collecting PyWavelets
  Downloading pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.0 kB)
Downloading pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/4.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m4.5/4.5 MB[0m [31m196.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m96.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: PyWavelets
Successfully installed PyWavelets-1.8.0


In [None]:
from google.colab import drive
drive.mount('/content/drive')

import copy
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset, random_split
from PIL import Image
import os
from tqdm import tqdm
import logging
import matplotlib.pyplot as plt
import pandas as pd
import sys
import random
import numpy as np
import cv2
import pywt

Mounted at /content/drive


In [None]:
# ========================
# SET RANDOMNESS
# ========================

# Seed every RNG used below (Python, NumPy, PyTorch CPU and CUDA).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# For CUDA convolution determinism
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ========================
# PARAMETERS AND LOGGING
# ========================

# Set up logging to both file and console.
# StreamHandler already flushes its stream after every record, so no custom
# flush hook is needed.
handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('training_log.log'),
        handler
    ],
    # basicConfig() silently does nothing when the root logger already has
    # handlers (e.g. in environments that pre-configure logging, or when this
    # cell is re-run). force=True replaces them so INFO records are not lost.
    force=True,
)

# Base configuration dictionary for easy tuning
base_config = {
    "resnet_version": 50,  # Options: 18, 34, 50, 101, 152
    "batch_size": 32,
    "learning_rate": 0.00005,
    "weight_decay": 1e-4,
    "num_epochs": 5,
    "momentum": 0.9,         # Not used with Adam (but useful for SGD)
    "use_scheduler": True,
    "train_split": 0.8       # This is no longer used for splitting
}

# Number of DataLoader worker processes (used by every loader in this file)
n_workers = 15


In [None]:
class WaveletTransform(object):
    """Map a PIL image to an RGB image built from wavelet detail coefficients.

    The input is reduced to grayscale, a single-level 2D discrete wavelet
    transform is applied, and the three detail sub-bands (horizontal,
    vertical, diagonal) are each min-max scaled to uint8 and stacked as the
    R, G and B channels of the returned PIL image. The approximation
    sub-band is discarded.

    Args:
        wavelet: Wavelet name passed to ``pywt.dwt2`` (default ``'haar'``).
    """

    def __init__(self, wavelet='haar'):
        self.wavelet = wavelet

    @staticmethod
    def _normalize(channel):
        """Min-max scale ``channel`` and return it as a uint8 array.

        A constant input maps to all zeros; the small epsilon in the
        denominator avoids a divide-by-zero in that case.
        """
        channel = channel.astype(np.float32)
        channel_min, channel_max = channel.min(), channel.max()
        channel = (channel - channel_min) / (channel_max - channel_min + 1e-6)
        return (channel * 255).astype(np.uint8)

    def __call__(self, pil_img):
        # Convert PIL Image to NumPy array
        img = np.array(pil_img)

        if img.ndim == 2:
            # Already single-channel: there is no colour axis to convert, and
            # checking shape[-1] here would inspect the image width instead.
            img_gray = img
        else:
            # If there's an alpha channel, drop it
            if img.shape[-1] == 4:
                img = img[..., :3]
            # PIL arrays are RGB-ordered, so convert to grayscale directly
            # rather than going through an intermediate BGR copy.
            img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # Single-level 2D DWT; only the detail sub-bands are used.
        _, (cH, cV, cD) = pywt.dwt2(img_gray, self.wavelet)

        # Stack detail coefficients as 3 channels: (H, V, D) -> (R, G, B)
        wavelet_3ch = np.dstack(
            [self._normalize(cH), self._normalize(cV), self._normalize(cD)]
        )

        # Convert NumPy array back to a PIL Image in RGB mode
        return Image.fromarray(wavelet_3ch, mode='RGB')

In [None]:
# ========================
# DATASET DEFINITIONS
# ========================

class RecaptureDataset(Dataset):
    """Binary image dataset of single-capture (0) vs. recaptured (1) photos.

    ``root_dir`` is expected to contain two subfolders,
    'SingleCaptureImages' (label 0) and 'RecapturedImages' (label 1). Both
    are searched recursively for .png/.jpg/.jpeg files (case-insensitive).

    Args:
        root_dir: Directory holding the two class subfolders.
        transform: Optional callable applied to each loaded RGB PIL image.
        is_test: Stored on the instance; not otherwise used by this class.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, is_test=False):
        self.transform = transform
        self.is_test = is_test

        self.single_capture_path = os.path.join(root_dir, 'SingleCaptureImages')
        self.recapture_path = os.path.join(root_dir, 'RecapturedImages')
        self.single_capture_images = self._collect_images(self.single_capture_path)
        self.recapture_images = self._collect_images(self.recapture_path)

        # Single-capture images come first, so labels are a run of 0s then 1s.
        self.all_images = self.single_capture_images + self.recapture_images
        self.labels = ([0] * len(self.single_capture_images)) + ([1] * len(self.recapture_images))

    @classmethod
    def _collect_images(cls, folder):
        """Return the sorted paths of all image files found under ``folder``.

        os.walk yields entries in an arbitrary, filesystem-dependent order.
        Sorting gives every index a stable meaning, so an index-based split
        with a fixed seed selects the same files on every run.
        """
        found = []
        for dirpath, _, filenames in os.walk(folder):
            found.extend(
                os.path.join(dirpath, filename)
                for filename in filenames
                if filename.lower().endswith(cls.IMAGE_EXTENSIONS)
            )
        found.sort()
        return found

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``; image is RGB-converted."""
        img_path = self.all_images[idx]
        # The context manager closes the underlying file as soon as the
        # converted copy has been produced.
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

# ========================
# TRANSFORMATIONS
# ========================

# Preprocessing pipeline shared by the training, validation and test splits.
# Every step is deterministic; no random augmentation is applied.
train_transform = transforms.Compose(
    [
        # RGB image -> 3-channel image of wavelet detail coefficients
        WaveletTransform(),
        # Fixed spatial size expected by the network input
        transforms.Resize(size=(224, 224)),
        # PIL image -> float tensor
        transforms.ToTensor(),
        # Per-channel standardisation (the usual ImageNet statistics)
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ========================
# DATASET & DATALOADERS
# ========================

# Dataset root on the mounted Drive; adjust to your own location
root_dir = "/content/drive/MyDrive/CS229_datasets/processed"

# Index every image once; the three splits below are subsets of this dataset
full_dataset = RecaptureDataset(root_dir=root_dir, transform=train_transform, is_test=False)
dataset_size = len(full_dataset)

# 80% / 10% of the images go to training / validation (truncated to ints);
# the test split takes whatever remains so the three sizes always sum up.
train_size, val_size = int(0.8 * dataset_size), int(0.1 * dataset_size)
test_size = dataset_size - (train_size + val_size)
logging.info(f"Total images: {dataset_size}, Training: {train_size}, Validation: {val_size}, Testing: {test_size}")

# A dedicated, seeded generator makes the random split repeatable
generator = torch.Generator()
generator.manual_seed(seed)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=generator,
)

def get_dataloaders(config, train_dataset, val_dataset, test_dataset):
    """Build the train/validation/test DataLoaders.

    Args:
        config: Mapping that provides ``"batch_size"``.
        train_dataset: Dataset for the training loader (shuffled).
        val_dataset: Dataset for the validation loader (not shuffled).
        test_dataset: Dataset for the test loader (not shuffled).

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    The number of worker processes comes from the module-level ``n_workers``.
    """
    loader_kwargs = {
        "batch_size": config["batch_size"],
        "num_workers": n_workers,
        "pin_memory": True,
    }
    # persistent_workers / prefetch_factor are only valid with worker
    # processes; DataLoader raises ValueError if they are set while
    # num_workers == 0 (e.g. when debugging in the main process).
    if n_workers > 0:
        loader_kwargs["persistent_workers"] = True
        loader_kwargs["prefetch_factor"] = 3

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader

# ========================
# MODEL DEFINITION
# ========================

class RecaptureResNet(nn.Module):
    """ImageNet-pretrained ResNet with a single-logit head for binary output.

    Args:
        resnet_version: ResNet depth; one of 18, 34, 50, 101, 152. Any other
            value raises ``KeyError``.

    The forward pass returns raw logits (no sigmoid applied here).
    """

    def __init__(self, resnet_version=18):
        super().__init__()
        resnet_models = {
            18: models.resnet18,
            34: models.resnet34,
            50: models.resnet50,
            101: models.resnet101,
            152: models.resnet152,
        }
        # Load the pretrained ResNet model (default 3-channel input).
        # `pretrained=True` is deprecated in torchvision; "IMAGENET1K_V1" is
        # the checkpoint that flag selected, so the loaded weights are
        # unchanged.
        self.model = resnet_models[resnet_version](weights="IMAGENET1K_V1")

        # Replace the final fully connected layer with a single output unit
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, 1)

    def forward(self, x):
        return self.model(x)

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ========================
# TRAINING & EVALUATION FUNCTIONS
# ========================

def _evaluate_accuracy(model, loader, desc):
    """Return the accuracy of ``model`` on ``loader`` as a percentage.

    The model is switched to eval mode and left there. A sample is predicted
    positive when sigmoid(logit) > 0.5. Returns NaN if the loader yields no
    samples, instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc, leave=True):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predictions = (torch.sigmoid(outputs) > 0.5).int()
            # Outputs are (batch, 1); flatten to line up with the 1-D labels.
            correct += (predictions.view(-1) == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total if total else float("nan")


def train_and_evaluate(config, train_dataset, val_dataset, test_dataset):
    """Train a RecaptureResNet, validating each epoch, then score the test set.

    Args:
        config: Mapping with required keys ``"batch_size"``,
            ``"resnet_version"``, ``"learning_rate"``, ``"weight_decay"``,
            ``"num_epochs"`` and ``"use_scheduler"``. Optional keys:
            ``"pos_weight"`` (default 3.0), ``"scheduler_step_size"``
            (default 5) and ``"scheduler_gamma"`` (default 0.1).
        train_dataset: Dataset used for optimisation.
        val_dataset: Dataset scored after every epoch.
        test_dataset: Dataset scored once after training.

    Returns:
        ``(train_losses, val_accuracy, test_accuracy, model)`` where
        ``train_losses`` holds the mean batch loss of each epoch,
        ``val_accuracy`` is the last epoch's validation accuracy in percent
        (NaN if no epoch ran), ``test_accuracy`` is in percent, and ``model``
        is the trained network, left in eval mode.
    """
    train_loader, val_loader, test_loader = get_dataloaders(config, train_dataset, val_dataset, test_dataset)
    model = RecaptureResNet(resnet_version=config["resnet_version"]).to(device)

    # Weighted binary cross entropy: Increase weight for class 1
    pos_weight = torch.tensor([float(config.get("pos_weight", 3.0))]).to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    optimizer = optim.Adam(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])
    scheduler = None
    if config["use_scheduler"]:
        scheduler = optim.lr_scheduler.StepLR(
            optimizer,
            step_size=config.get("scheduler_step_size", 5),
            gamma=config.get("scheduler_gamma", 0.1),
        )

    train_losses = []
    # Defined up front so the return statement is valid even with 0 epochs.
    val_accuracy = float("nan")
    num_epochs = config["num_epochs"]
    for epoch in range(num_epochs):
        # Validation leaves the model in eval mode, so re-enable training
        # behaviour at the start of every epoch.
        model.train()
        running_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['num_epochs']}", leave=True):
            images, labels = images.to(device), labels.float().to(device)
            # BCEWithLogitsLoss needs targets shaped like the (batch, 1) logits.
            labels = labels.view(-1, 1)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        num_batches = len(train_loader)
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        train_losses.append(epoch_loss)
        if scheduler is not None:
            scheduler.step()
        logging.info(f"Epoch {epoch+1}/{config['num_epochs']} - Loss: {epoch_loss:.4f}")
        print(f"Epoch {epoch+1}/{config['num_epochs']} - Loss: {epoch_loss:.4f}")

        # Evaluate on validation set at the end of each epoch
        val_accuracy = _evaluate_accuracy(model, val_loader, "Evaluating on Validation")
        logging.info(f"Epoch {epoch+1}/{config['num_epochs']} - Validation Accuracy: {val_accuracy:.2f}%")
        print(f"Epoch {epoch+1}/{config['num_epochs']} - Validation Accuracy: {val_accuracy:.2f}%")

    # Final evaluation on test set
    test_accuracy = _evaluate_accuracy(model, test_loader, "Evaluating on Test")
    logging.info(f"Test Accuracy: {test_accuracy:.2f}%")
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    return train_losses, val_accuracy, test_accuracy, model

In [None]:
# ========================
# TRAIN & SAVE MODEL
# ========================

# Hyper-parameters for the final training run
tuned_config = dict(
    resnet_version=50,  # Options: 18, 34, 50, 101, 152
    batch_size=128,
    learning_rate=5e-5,
    weight_decay=1e-4,
    num_epochs=15,
    use_scheduler=True,
    train_split=0.8,  # Not used anymore for splitting
)

# Train on the 8:1:1 split. train_and_evaluate returns
# (train_losses, val_accuracy, test_accuracy, model); only the trained model
# (the last element) is needed here.
model = train_and_evaluate(tuned_config, train_dataset, val_dataset, test_dataset)[-1]

# Persist the weights only (state_dict), not the whole module object
model_path = f'recapture_resnet{tuned_config["resnet_version"]}.pth'
torch.save(model.state_dict(), model_path)
logging.info(f"Model saved as {model_path}")
print(f"Model saved as {model_path}")

Epoch 1/15: 100%|██████████| 15/15 [02:32<00:00, 10.16s/it]


Epoch 1/15 - Loss: 0.6480


Evaluating on Validation: 100%|██████████| 2/2 [01:47<00:00, 53.92s/it] 


Epoch 1/15 - Validation Accuracy: 40.17%


Epoch 2/15: 100%|██████████| 15/15 [01:14<00:00,  4.94s/it]


Epoch 2/15 - Loss: 0.0692


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.77s/it]


Epoch 2/15 - Validation Accuracy: 86.75%


Epoch 3/15: 100%|██████████| 15/15 [01:12<00:00,  4.83s/it]


Epoch 3/15 - Loss: 0.0219


Evaluating on Validation: 100%|██████████| 2/2 [00:34<00:00, 17.06s/it]


Epoch 3/15 - Validation Accuracy: 97.86%


Epoch 4/15: 100%|██████████| 15/15 [01:12<00:00,  4.81s/it]


Epoch 4/15 - Loss: 0.0095


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.88s/it]


Epoch 4/15 - Validation Accuracy: 99.57%


Epoch 5/15: 100%|██████████| 15/15 [01:13<00:00,  4.90s/it]


Epoch 5/15 - Loss: 0.0047


Evaluating on Validation: 100%|██████████| 2/2 [00:34<00:00, 17.03s/it]


Epoch 5/15 - Validation Accuracy: 97.86%


Epoch 6/15: 100%|██████████| 15/15 [01:15<00:00,  5.00s/it]


Epoch 6/15 - Loss: 0.0040


Evaluating on Validation: 100%|██████████| 2/2 [00:34<00:00, 17.21s/it]


Epoch 6/15 - Validation Accuracy: 97.86%


Epoch 7/15: 100%|██████████| 15/15 [01:13<00:00,  4.91s/it]


Epoch 7/15 - Loss: 0.0029


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.77s/it]


Epoch 7/15 - Validation Accuracy: 97.86%


Epoch 8/15: 100%|██████████| 15/15 [01:12<00:00,  4.81s/it]


Epoch 8/15 - Loss: 0.0033


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.63s/it]


Epoch 8/15 - Validation Accuracy: 97.86%


Epoch 9/15: 100%|██████████| 15/15 [01:12<00:00,  4.84s/it]


Epoch 9/15 - Loss: 0.0030


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.89s/it]


Epoch 9/15 - Validation Accuracy: 97.86%


Epoch 10/15: 100%|██████████| 15/15 [01:13<00:00,  4.92s/it]


Epoch 10/15 - Loss: 0.0027


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.79s/it]


Epoch 10/15 - Validation Accuracy: 98.29%


Epoch 11/15: 100%|██████████| 15/15 [01:11<00:00,  4.79s/it]


Epoch 11/15 - Loss: 0.0032


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.74s/it]


Epoch 11/15 - Validation Accuracy: 97.86%


Epoch 12/15: 100%|██████████| 15/15 [01:15<00:00,  5.01s/it]


Epoch 12/15 - Loss: 0.0024


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.66s/it]


Epoch 12/15 - Validation Accuracy: 98.29%


Epoch 13/15: 100%|██████████| 15/15 [01:13<00:00,  4.87s/it]


Epoch 13/15 - Loss: 0.0029


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.73s/it]


Epoch 13/15 - Validation Accuracy: 97.86%


Epoch 14/15: 100%|██████████| 15/15 [01:13<00:00,  4.88s/it]


Epoch 14/15 - Loss: 0.0026


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.76s/it]


Epoch 14/15 - Validation Accuracy: 98.29%


Epoch 15/15: 100%|██████████| 15/15 [01:13<00:00,  4.89s/it]


Epoch 15/15 - Loss: 0.0025


Evaluating on Validation: 100%|██████████| 2/2 [00:33<00:00, 16.73s/it]


Epoch 15/15 - Validation Accuracy: 98.29%


Evaluating on Test: 100%|██████████| 2/2 [01:43<00:00, 51.85s/it] 


Test Accuracy: 97.86%
Model saved as recapture_resnet50.pth
