In [None]:
# Import required libraries
from google.colab import drive
import os
import zipfile
import shutil
import random
from transformers import AutoModelForImageClassification, AutoProcessor
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Check if GPU is available and set the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Step 1: Connect to Google Drive
drive.mount('/content/drive')

# Define file paths
zip_path = '/content/drive/MyDrive/cifake_dataset.zip'  # Path to the dataset zip file
# Extract directly under /content so the extracted folders line up with the
# /content/train/... and /content/test/... paths used by the rest of the notebook.
# NOTE(review): assumes the archive has top-level train/ and test/ folders — confirm.
unzip_location = '/content'  # Destination folder for extracted files

# Number of images per class to hold out for validation, and the seed that
# makes the selection reproducible across runs.
VALIDATION_SIZE = 10000
SPLIT_SEED = 42

# Paths for training and validation folders (defined before extraction so the
# extraction step can check whether the data is already in place).
train_real_src = os.path.join(unzip_location, 'train', 'REAL')
train_fake_src = os.path.join(unzip_location, 'train', 'FAKE')
validation_real_dst = os.path.join(unzip_location, 'validation', 'REAL')
validation_fake_dst = os.path.join(unzip_location, 'validation', 'FAKE')

# Step 2: Extract the ZIP file (may take some time)
# Skipped when the training folders already exist: re-extracting would recreate
# the files that were moved to validation, and the next move would then fail
# because the destination already contains them.
if not (os.path.isdir(train_real_src) and os.path.isdir(train_fake_src)):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(unzip_location)

# Step 3: Create validation folders if they do not exist
os.makedirs(validation_real_dst, exist_ok=True)
os.makedirs(validation_fake_dst, exist_ok=True)


def _move_to_validation(image_names, src_dir, dst_dir, desc):
    """Move images from src_dir to dst_dir until dst_dir holds VALIDATION_SIZE entries.

    Args:
        image_names: file names currently present in src_dir.
        src_dir: folder the images are taken from.
        dst_dir: existing folder the images are moved into.
        desc: label shown on the progress bar.

    Returns:
        The list of file names that were moved (empty when dst_dir is already full).
    """
    # Only top up to the target size, so re-running the cell does not keep
    # draining the training set.
    missing = max(VALIDATION_SIZE - len(os.listdir(dst_dir)), 0)
    # Shuffle a copy with a fixed seed: the pick is reproducible and does not
    # depend on the arbitrary order returned by os.listdir.
    candidates = list(image_names)
    random.Random(SPLIT_SEED).shuffle(candidates)
    selected = candidates[:missing]
    for img in tqdm(selected, desc=desc):
        shutil.move(os.path.join(src_dir, img), dst_dir)
    return selected


# Step 4: Transfer images from train/REAL to validation/REAL and from train/FAKE to validation/FAKE
real_images = sorted(os.listdir(train_real_src))
_move_to_validation(real_images, train_real_src, validation_real_dst,
                    "Transferring REAL images to validation")

fake_images = sorted(os.listdir(train_fake_src))
_move_to_validation(fake_images, train_fake_src, validation_fake_dst,
                    "Transferring FAKE images to validation")

In [None]:
# Folders holding the held-out test images, one per class
fake_test_path = '/content/test/FAKE'
real_test_path = '/content/test/REAL'

# Step 5: Fetch the pretrained image classifier from Hugging Face and place it
# on the selected device, then load the matching input processor.
model_identifier = "Organika/sdxl-detector"
classifier_model = AutoModelForImageClassification.from_pretrained(model_identifier)
classifier_model = classifier_model.to(device)
preprocess = AutoProcessor.from_pretrained(model_identifier)

In [None]:
# Batch assembly helper for the DataLoader
def collate_fn(batch):
    """Merge a list of ``(tensor, label)`` samples into one batch.

    The tensors are concatenated along dimension 0 (not stacked), so each
    sample's tensor must already carry a leading dimension. The labels are
    returned unchanged as a tuple, in sample order.
    """
    pixel_tensors, names = tuple(zip(*batch))
    merged = torch.cat(pixel_tensors, dim=0)
    return merged, names

class ImageDataset(Dataset):
    """Dataset over the files of a single image folder.

    Each item is a ``(pixel_values, file_name)`` pair, where ``pixel_values`` is
    the ``'pixel_values'`` entry returned by
    ``preprocess(images=<RGB image>, return_tensors="pt")``.

    Args:
        image_folder: Path of the folder to read images from.
        preprocess: Callable invoked as ``preprocess(images=img, return_tensors="pt")``
            whose result is indexable by ``'pixel_values'``.

    Raises:
        FileNotFoundError: If ``image_folder`` does not exist.
    """

    def __init__(self, image_folder, preprocess):
        if not os.path.exists(image_folder):
            raise FileNotFoundError(f"Image folder {image_folder} does not exist.")
        self.image_folder = image_folder
        # Keep regular files only: sub-directories (e.g. notebook checkpoint
        # folders) cannot be opened as images and would otherwise inflate
        # len(). Sorting makes the item order independent of the arbitrary
        # order returned by os.listdir.
        self.image_files = sorted(
            name
            for name in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, name))
        )
        self.preprocess = preprocess

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, convert to RGB and preprocess the image at position ``idx``."""
        file_name = self.image_files[idx]
        img_path = os.path.join(self.image_folder, file_name)
        # The context manager releases the file handle as soon as the
        # converted copy has been made.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")
        inputs = self.preprocess(images=img, return_tensors="pt")
        return inputs['pixel_values'], file_name

# Define function to classify images and calculate accuracy
def calculate_accuracy(dataset, expected_label, batch_size=16):
    """Classify every item of ``dataset`` and return the share predicted as ``expected_label``.

    Uses the module-level ``classifier_model``, ``device`` and ``collate_fn``.

    Args:
        dataset: Sized dataset whose items are ``(pixel_values, name)`` pairs
            accepted by ``collate_fn``.
        expected_label: Label string (a value of the model's ``id2label``
            mapping) that counts as a correct prediction.
        batch_size: Number of images sent through the model at once.

    Returns:
        Percentage (0-100) of items whose predicted label equals
        ``expected_label``; ``0.0`` for an empty dataset.
    """
    total_images = len(dataset)
    if total_images == 0:
        # Nothing to classify; returning early also avoids dividing by zero.
        return 0.0

    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
    id2label = classifier_model.config.id2label
    correct_count = 0

    # The second element of each batch (the file names) is not needed here.
    for images, _ in tqdm(dataloader, desc="Calculating accuracy"):
        images = images.to(device)
        with torch.no_grad():
            output_logits = classifier_model(pixel_values=images)
        predicted_idx = torch.argmax(output_logits.logits, dim=-1)
        # Pull the whole batch of predictions to the host in one transfer
        # instead of one .item() call per image.
        correct_count += sum(
            id2label[class_id] == expected_label
            for class_id in predicted_idx.tolist()
        )

    accuracy_percentage = (correct_count / total_images) * 100
    return accuracy_percentage

# Step 6: Calculate detection accuracy for REAL and FAKE images
# Label strings the model is expected to emit for each class.
REAL_LABEL = "human"
FAKE_LABEL = "artificial"

# Verify the label assumption instead of trusting it: a label the model never
# produces would otherwise be reported as a silent 0% accuracy.
model_labels = set(classifier_model.config.id2label.values())
unknown_labels = [label for label in (REAL_LABEL, FAKE_LABEL) if label not in model_labels]
if unknown_labels:
    raise ValueError(
        f"Labels {unknown_labels} are not produced by {model_identifier}; "
        f"available labels: {sorted(model_labels)}"
    )

try:
    real_dataset = ImageDataset(real_test_path, preprocess)
    fake_dataset = ImageDataset(fake_test_path, preprocess)

    accuracy_real = calculate_accuracy(real_dataset, REAL_LABEL)
    accuracy_fake = calculate_accuracy(fake_dataset, FAKE_LABEL)

    # Display results
    print(f"Accuracy for REAL images (predicted as '{REAL_LABEL}'): {accuracy_real:.2f}%")
    print(f"Accuracy for FAKE images (predicted as '{FAKE_LABEL}'): {accuracy_fake:.2f}%")
except FileNotFoundError as e:
    # A missing test folder is reported rather than raised, so the rest of the
    # notebook can still run.
    print(e)


Using device: cuda


Transferring REAL images to validation: 100%|██████████| 10000/10000 [00:00<00:00, 28158.03it/s]
Transferring FAKE images to validation: 100%|██████████| 10000/10000 [00:00<00:00, 27118.76it/s]
Calculating accuracy: 100%|██████████| 625/625 [02:29<00:00,  4.19it/s]
Calculating accuracy: 100%|██████████| 625/625 [02:30<00:00,  4.15it/s]

Accuracy for REAL images (predicted as 'human'): 87.14%
Accuracy for FAKE images (predicted as 'artificial'): 27.68%



