In [1]:
# AI vs Real Image Classifier - Full Inference & Confusion Matrix

import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np
from train import SimpleCNN
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

ModuleNotFoundError: No module named 'seaborn'

In [None]:
# ------------------------
# Configuration
# ------------------------
# Prefer a CUDA device when torch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Saved state_dict for SimpleCNN, loaded in the next cell.
model_path = 'models/cnn_model.pth'
# Root of the evaluation set; one sub-folder per entry in class_names.
test_dir = 'data/test/'  # must have FAKE and REAL subfolders

# Position in this list == integer label: 0 -> FAKE, 1 -> REAL.
# The same indices are used for ground truth (folder order) and predictions.
class_names = ['FAKE', 'REAL']

In [None]:
# ------------------------
# Load Model
# ------------------------
# Build the architecture, move it to the configured device, then restore the
# trained weights from model_path.
# map_location remaps the saved tensors onto `device`. Without it, torch.load
# restores tensors onto the device they were saved from, which fails for a
# GPU-trained checkpoint on a CPU-only machine.
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints,
# or pass weights_only=True on torch versions that support it.
model = SimpleCNN().to(device)
model.load_state_dict(torch.load(model_path, map_location=device))
# Inference mode for layers such as dropout/batch-norm; the trailing ';'
# suppresses the module repr in the notebook output.
model.eval();

In [None]:
# ------------------------
# Image Transform
# ------------------------
# Preprocessing applied to every test image before it reaches the model:
#   1. resize to a fixed 224x224 input,
#   2. convert the PIL image to a float tensor,
#   3. normalize with mean 0.5 / std 0.5 (the single values apply to every channel).
# NOTE(review): this must match the preprocessing used in train.py — confirm.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(_preprocess_steps)

In [None]:
# ------------------------
# Inference Function
# ------------------------
def predict_image(img_path, threshold=0.5):
    """Classify a single image file with the loaded model.

    Args:
        img_path: Path to an image file readable by PIL.
        threshold: Decision threshold on the model's scalar output. Outputs
            strictly greater than it map to index 1 (``class_names[1]``);
            everything else maps to index 0. Defaults to 0.5.

    Returns:
        A tuple ``(label_idx, score, img)``:
        - the predicted class index (0 or 1),
        - the raw scalar model output as a Python number,
        - the RGB PIL image that was classified.

    NOTE(review): the 0.5 default assumes SimpleCNN ends in a sigmoid, so its
    output is a probability. If it returns a raw logit, the matching threshold
    is 0.0 — confirm against train.py.
    """
    # Context manager so the file handle is closed; convert() returns a
    # separate, fully loaded image that stays valid after the file closes.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    img_t = transform(img).unsqueeze(0).to(device)  # add a leading batch dimension
    with torch.no_grad():
        pred = model(img_t)
    # .item() requires a single-element output; read it once and reuse it.
    score = pred.item()
    label_idx = int(score > threshold)
    return label_idx, score, img

In [None]:
# ------------------------
# Loop through test images
# ------------------------
# Run the model on every file under test_dir/<class>/. Collect ground-truth
# and predicted indices for the confusion matrix, plus a few samples per
# class for plotting.
SAMPLES_PER_CLASS = 3  # how many predictions of each class to keep for plotting

y_true = []
y_pred = []

samples_to_plot = []

for label_idx, label in enumerate(class_names):
    folder = os.path.join(test_dir, label)
    # os.listdir order is arbitrary; sort so results and plotted samples are reproducible.
    files = sorted(os.listdir(folder))
    kept_for_class = 0
    for f in files:
        img_path = os.path.join(folder, f)
        # Skip sub-directories (e.g. .ipynb_checkpoints) and hidden files
        # (e.g. .DS_Store) that PIL cannot open.
        if f.startswith('.') or not os.path.isfile(img_path):
            continue
        pred_idx, score, img = predict_image(img_path)
        y_true.append(label_idx)
        y_pred.append(pred_idx)
        # Count per class, not the global list length, so the plot shows
        # both FAKE and REAL examples.
        if kept_for_class < SAMPLES_PER_CLASS:
            samples_to_plot.append((img, class_names[pred_idx], score))
            kept_for_class += 1

In [None]:
# ------------------------
# Confusion Matrix
# ------------------------
# Pass labels explicitly so the matrix is always
# len(class_names) x len(class_names), even if a class never appears in
# y_true or y_pred. Without it, sklearn sizes the matrix from the labels
# present and the display labels no longer line up.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))
disp = ConfusionMatrixDisplay(cm, display_labels=class_names)

# Draw with sklearn's own display (rows = true label, columns = predicted
# label), so this cell does not need seaborn.
fig, ax = plt.subplots(figsize=(6, 6))
disp.plot(ax=ax, cmap="Blues", values_format="d")
ax.set_xlabel("Predicted")
ax.set_ylabel("True")
ax.set_title("Confusion Matrix")
plt.show()

In [None]:
# ------------------------
# Plot sample predictions
# ------------------------
# One panel per stored sample, 3 per row. Each title is the predicted label
# and the raw model score returned by predict_image.
n_cols = 3
# Ceil-divide so any number of samples fits; always draw at least one row.
n_rows = max(1, (len(samples_to_plot) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 3 * n_rows), squeeze=False)
for ax, (img, label, score) in zip(axes.flat, samples_to_plot):
    ax.imshow(img)
    ax.set_title(f"{label} ({score:.2f})")
# Hide ticks on every panel; this also blanks grid cells left without a sample.
for ax in axes.flat:
    ax.axis('off')
fig.suptitle("Sample CNN Predictions")
plt.show()