In [12]:
import os
import random
import time
from pathlib import Path
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

import matplotlib.pyplot as plt
from torchvision.utils import make_grid

In [1]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

x = torch.randn(10000, 10000, device=device)
print("Computation successful on", device)


Using device: cuda
Computation successful on cuda


In [4]:
# CONFIGURATION

data_root = Path(r"C:/Users/Jimmy/OneDrive/Desktop/test/DS6050_Ai_Detection")  # adjust as needed
train_dir = data_root / "train"
val_dir = data_root / "validation"

batch_size = 32
num_epochs = 5
learning_rate = 1e-4
train_percent = 0.5  # train/validate on 50% of data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [5]:
# DATA TRANSFORMS

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for ResNet
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [6]:
# LOAD DATASETS

train_dataset = datasets.ImageFolder(root=str(train_dir), transform=transform)
val_dataset = datasets.ImageFolder(root=str(val_dir), transform=transform)

def subset_dataset(dataset, percent):
    """Return a subset of the dataset based on the given percent (0 < percent <= 1)."""
    if percent >= 1.0:
        return dataset
    subset_size = int(len(dataset) * percent)
    indices = random.sample(range(len(dataset)), subset_size)
    return torch.utils.data.Subset(dataset, indices)

train_dataset = subset_dataset(train_dataset, train_percent)
val_dataset = subset_dataset(val_dataset, train_percent)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

print(f"Training on {len(train_dataset)} images, validating on {len(val_dataset)} images.")

Training on 57600 images, validating on 14170 images.


In [7]:
# MODEL SETUP (ResNet50)

model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # binary classification (real vs fake)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to C:\Users\Jimmy/.cache\torch\hub\checkpoints\resnet50-0676ba61.pth
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:02<00:00, 40.2MB/s]


In [13]:
grid_size = 2  
num_images = grid_size * grid_size

# Assuming val_loader is your validation DataLoader
data_iter = iter(val_loader)
images, labels = next(data_iter)

# Subsample fixed indices: half real (label=0), half fake (label=1)
real_indices = (labels == 0).nonzero(as_tuple=True)[0][:num_images//2]
fake_indices = (labels == 1).nonzero(as_tuple=True)[0][:num_images//2]
subset_indices = torch.cat([real_indices, fake_indices])[:num_images]

fixed_images = images[subset_indices].to(device)
fixed_labels = labels[subset_indices].to(device)

In [None]:
# Training loop

def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    best_val_acc = 0.0
    start_time = time.time()

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        print(f"\nüîπ Epoch {epoch+1}/{num_epochs}")

        # Wrap train loader with tqdm
        train_pbar = tqdm(train_loader, desc="Training", unit="batch")
        for images, labels in train_pbar:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            train_acc = 100 * correct / total
            avg_loss = running_loss / (len(train_loader) if len(train_loader) > 0 else 1)

            # Live update progress bar
            train_pbar.set_postfix({
                "Loss": f"{avg_loss:.4f}",
                "Train Acc": f"{train_acc:.2f}%"
            })

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        val_pbar = tqdm(val_loader, desc="Validating", unit="batch", leave=False)
        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = outputs.max(1)
                val_total += labels.size(0)
                val_correct += predicted.eq(labels).sum().item()

            #-------------images
            outputs_grid = model(fixed_images)
            preds_grid = torch.argmax(outputs_grid, dim=1)

        val_acc = 100 * val_correct / val_total

        label_names = ['Real', 'Fake']
        true_names = [label_names[l.item()] for l in fixed_labels]
        pred_names = [label_names[p.item()] for p in preds_grid]

        # Plot the grid
        fig, axs = plt.subplots(3, 1, figsize=(8, 8))

        # --- Row 1: images ---
        grid_img = make_grid(fixed_images.cpu(), nrow=grid_size, normalize=True)
        axs[0].imshow(grid_img.permute(1, 2, 0))
        axs[0].set_title("Validation Samples")
        axs[0].axis("off")

        # --- Row 2: true labels ---
        true_labels_img = torch.zeros_like(fixed_images)
        axs[1].imshow(torch.ones_like(grid_img))  # placeholder
        axs[1].set_title("True Labels")
        axs[1].axis("off")
        for i, lbl in enumerate(true_names):
            axs[1].text((i % grid_size) * 120 + 50, (i // grid_size) * 120 + 60, lbl,
                        ha='center', va='center', fontsize=12, color='black')

        # --- Row 3: predicted labels ---
        axs[2].imshow(torch.ones_like(grid_img))
        axs[2].set_title("Predicted Labels")
        axs[2].axis("off")
        for i, lbl in enumerate(pred_names):
            axs[2].text((i % grid_size) * 120 + 50, (i // grid_size) * 120 + 60, lbl,
                        ha='center', va='center', fontsize=12,
                        color='green' if lbl == true_names[i] else 'red')

        plt.tight_layout()
        plt.show()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_resnet50_baseline.pth")

        print(f"Epoch {epoch+1}/{num_epochs} "
              f"Loss: {avg_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

    print(f"\nTraining complete in {(time.time() - start_time)/60:.2f} minutes.")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

In [15]:
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)


üîπ Epoch 1/5


Training:  11%|‚ñà‚ñè        | 203/1800 [00:33<04:26,  5.99batch/s, Loss=0.0041, Train Acc=98.58%] 


KeyboardInterrupt: 