In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, ConcatDataset, Subset
import numpy as np
from PIL import Image


# Prefer the GPU when PyTorch reports one; otherwise everything runs on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Caps on how many images each source contributes. The CIFAR value is passed
# to CIFARRealSubset as its total `max_samples`; the MidJourney value is
# applied separately to the REAL and FAKE folders.
CIFAR_SAMPLES_PER_CLASS = 500
MIDJOURNEY_SAMPLES_PER_CLASS = 500

# Training-loop settings.
BATCH_SIZE = 64
EPOCHS = 50
PATIENCE = 7  # epochs without a validation-accuracy improvement before stopping

# Root of the MidJourney image folders and the names of its split sub-folders.
MIDJOURNEY_BASE = '/home/dhanraj/Documents/Midjourney_Exp2'
TRAIN_SPLIT = 'train'
TEST_SPLIT = 'test'


# Per-channel mean / std handed to transforms.Normalize in both pipelines.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, random flip and colour jitter, then tensor + normalise.
_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: the same resize / tensor / normalise steps, no augmentation.
_test_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
test_transform = transforms.Compose(_test_steps)


class MidJourneyDataset(torch.utils.data.Dataset):
    """Real-vs-fake image dataset read from ``<base_dir>/<split>/{REAL,FAKE}``.

    Images found under ``REAL`` are labelled 0 and images under ``FAKE`` are
    labelled 1. Only files whose name ends in .jpg, .jpeg or .png
    (case-insensitive) are used.

    Args:
        base_dir: Directory that contains the split folders.
        split: Name of the split sub-folder (for example 'train' or 'test').
        transform: Optional callable applied to each RGB PIL image.
        max_samples_per_class: Upper bound on the number of files taken from
            each of the two class folders.

    Raises:
        OSError: If either class folder cannot be listed.
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, base_dir, split, transform=None, max_samples_per_class=500):
        self.real_dir = os.path.join(base_dir, split, 'REAL')
        self.fake_dir = os.path.join(base_dir, split, 'FAKE')
        self.transform = transform

        self.real_files = self._list_images(self.real_dir, max_samples_per_class)
        self.fake_files = self._list_images(self.fake_dir, max_samples_per_class)

        # REAL files come first, then FAKE; labels follow the same order.
        self.file_list = self.real_files + self.fake_files
        self.labels = [0] * len(self.real_files) + [1] * len(self.fake_files)

    @classmethod
    def _list_images(cls, directory, limit):
        """Return up to ``limit`` image paths from ``directory`` in sorted order."""
        # os.listdir returns entries in arbitrary order, so sort before slicing;
        # otherwise the selected subset can differ between runs and machines.
        names = sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(cls._IMAGE_EXTENSIONS)
        )
        return [os.path.join(directory, name) for name in names[:limit]]

    def __len__(self):
        """Return the total number of selected files across both classes."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``transform`` when
        one was given.
        """
        # convert() produces an independent in-memory image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(self.file_list[idx]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[idx]



class CIFARRealSubset(torch.utils.data.Dataset):
    """In-memory copy of the leading items of a dataset, all relabelled as 0.

    The wrapped dataset is iterated in its own order and at most
    ``max_samples`` images are kept. The original labels are discarded and
    every kept image receives the target 0.

    Args:
        original_dataset: Iterable yielding ``(image, label)`` pairs.
        max_samples: Maximum number of images to keep.
        transform: Optional callable applied to an image when it is fetched.
    """

    def __init__(self, original_dataset, max_samples=500, transform=None):
        self.transform = transform
        self.max_samples = max_samples

        kept = []
        for sample in original_dataset:
            if len(kept) >= max_samples:
                break
            image, _ = sample  # the source label is intentionally ignored
            kept.append(image)

        self.data = kept
        self.targets = [0] * len(kept)

    def __len__(self):
        """Return how many images were kept."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for position ``idx``, transformed if configured."""
        image, target = self.data[idx], self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, target


def get_data_loaders():
    """Build the training and test DataLoaders.

    The training set concatenates a CIFAR-10 subset (every image labelled 0
    by CIFARRealSubset) with the MidJourney training split. The test set is
    the MidJourney test split only.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # Settings shared by both loaders.
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4)

    # CIFAR-10 is loaded without a transform; CIFARRealSubset applies the
    # training transform lazily when an item is fetched.
    cifar_source = datasets.CIFAR10(root='./data', train=True, download=True)
    cifar_part = CIFARRealSubset(
        cifar_source,
        max_samples=CIFAR_SAMPLES_PER_CLASS,
        transform=train_transform,
    )

    midjourney_part = MidJourneyDataset(
        MIDJOURNEY_BASE,
        TRAIN_SPLIT,
        transform=train_transform,
        max_samples_per_class=MIDJOURNEY_SAMPLES_PER_CLASS,
    )

    combined = ConcatDataset([cifar_part, midjourney_part])
    train_loader = DataLoader(combined, shuffle=True, **loader_kwargs)

    # MidJourney test split (real + fake) with the evaluation transform.
    # NOTE(review): the cap here is a literal 500, independent of
    # MIDJOURNEY_SAMPLES_PER_CLASS - confirm that is intended.
    held_out = MidJourneyDataset(
        MIDJOURNEY_BASE,
        TEST_SPLIT,
        transform=test_transform,
        max_samples_per_class=500,
    )
    test_loader = DataLoader(held_out, shuffle=False, **loader_kwargs)

    return train_loader, test_loader



def get_model():
    """Create a pretrained ResNet-18 with a 2-way output layer on DEVICE.

    Returns:
        The model, already moved to DEVICE.
    """
    net = models.resnet18(pretrained=True)
    # Swap the final fully connected layer for a 2-class head (real vs fake).
    net.fc = nn.Linear(net.fc.in_features, 2)
    net = net.to(DEVICE)
    return net



def train_one_epoch(model, dataloader, criterion, optimizer):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        Tuple ``(loss, accuracy)``: the batch losses weighted by batch size
        and divided by the number of samples seen, and the fraction of
        samples whose arg-max prediction equals the label.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, labels in dataloader:
        inputs = inputs.to(DEVICE)
        labels = labels.to(DEVICE)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller last batch is
        # not over-represented in the epoch figure.
        loss_sum += batch_loss.item() * batch_size
        predicted = torch.max(logits, 1)[1]
        n_correct += (predicted == labels).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

def evaluate(model, dataloader, criterion):
    """Score ``model`` on ``dataloader`` without updating any weights.

    Args:
        model: Network to evaluate; switched to eval mode here.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.

    Returns:
        Tuple ``(loss, accuracy, predictions, labels)``. ``loss`` is the
        batch losses weighted by batch size over the sample count,
        ``accuracy`` is the fraction of correct arg-max predictions, and the
        last two are NumPy arrays collected in loader order.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    all_preds, all_labels = [], []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            batch_size = inputs.size(0)

            logits = model(inputs)
            loss_sum += criterion(logits, labels).item() * batch_size

            predicted = torch.max(logits, 1)[1]
            n_correct += (predicted == labels).sum().item()
            n_seen += batch_size

            # Move results to host memory so they can be returned as NumPy arrays.
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen
    return mean_loss, accuracy, np.array(all_preds), np.array(all_labels)


def main_train():
    """Train the classifier with early stopping and reload the best weights.

    Each epoch trains on the training loader and then evaluates. The
    learning-rate scheduler follows the evaluation loss, while checkpointing
    and early stopping follow the evaluation accuracy. The weights are
    written to 'best_model.pth' whenever the accuracy improves.

    Returns:
        Tuple ``(model, test_loader)`` with the best checkpoint loaded.
    """
    train_loader, test_loader = get_data_loaders()
    model = get_model()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.3, patience=3
    )

    checkpoint_path = 'best_model.pth'
    best_acc = 0.0
    stale_epochs = 0  # consecutive epochs without an accuracy improvement

    for epoch in range(EPOCHS):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        # NOTE(review): the metrics used for model selection come from
        # test_loader, the same loader that is returned to the caller.
        val_loss, val_acc, _, _ = evaluate(model, test_loader, criterion)

        print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {train_loss:.4f} - Train Acc: {train_acc:.4f} - "
              f"Val Loss: {val_loss:.4f} - Val Acc: {val_acc:.4f}")

        scheduler.step(val_loss)

        if val_acc > best_acc:
            # New best: remember it, reset the patience window, checkpoint.
            best_acc, stale_epochs = val_acc, 0
            torch.save(model.state_dict(), checkpoint_path)
            continue

        stale_epochs += 1
        if stale_epochs >= PATIENCE:
            print("Early stopping triggered.")
            break

    print(f"Best Validation Accuracy: {best_acc*100:.2f}%")

    # Return the best-scoring weights rather than those of the last epoch.
    model.load_state_dict(torch.load(checkpoint_path))
    return model, test_loader

if __name__ == '__main__':
    # Import the reporting dependency first so a missing package is noticed
    # before the long training run rather than after it.
    import sklearn.metrics as metrics

    trained_model, test_loader = main_train()

    # Reuse the shared evaluation loop instead of duplicating it here; only
    # the collected predictions and labels are needed for the report.
    _, _, all_preds, all_labels = evaluate(
        trained_model, test_loader, nn.CrossEntropyLoss()
    )

    print("Classification Report on MidJourney Test set:")
    print(metrics.classification_report(all_labels, all_preds, digits=4))
    accuracy = metrics.accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {accuracy*100:.2f}%")



Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/dhanraj/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████████████████████████████████| 44.7M/44.7M [00:00<00:00, 67.3MB/s]


Epoch 1/50 - Train Loss: 0.3243 - Train Acc: 0.8599 - Val Loss: 0.1805 - Val Acc: 0.9476
Epoch 2/50 - Train Loss: 0.0645 - Train Acc: 0.9778 - Val Loss: 0.1823 - Val Acc: 0.9432
Epoch 3/50 - Train Loss: 0.0287 - Train Acc: 0.9926 - Val Loss: 0.3913 - Val Acc: 0.8428
Epoch 4/50 - Train Loss: 0.0121 - Train Acc: 0.9960 - Val Loss: 0.2444 - Val Acc: 0.9389
Epoch 5/50 - Train Loss: 0.0170 - Train Acc: 0.9966 - Val Loss: 0.1802 - Val Acc: 0.9476
Epoch 6/50 - Train Loss: 0.0080 - Train Acc: 0.9993 - Val Loss: 0.2341 - Val Acc: 0.9258
Epoch 7/50 - Train Loss: 0.0066 - Train Acc: 1.0000 - Val Loss: 0.2384 - Val Acc: 0.9301
Epoch 8/50 - Train Loss: 0.0038 - Train Acc: 1.0000 - Val Loss: 0.1986 - Val Acc: 0.9520
Epoch 9/50 - Train Loss: 0.0029 - Train Acc: 1.0000 - Val Loss: 0.2217 - Val Acc: 0.9520
Epoch 10/50 - Train Loss: 0.0019 - Train Acc: 1.0000 - Val Loss: 0.2117 - Val Acc: 0.9476
Epoch 11/50 - Train Loss: 0.0018 - Train Acc: 1.0000 - Val Loss: 0.2086 - Val Acc: 0.9520
Epoch 12/50 - Train

In [4]:
import torch
from torchvision import transforms, models
from PIL import Image
import os


# Prefer the GPU when PyTorch reports one; otherwise run on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint file read by load_trained_model().
MODEL_PATH = "best_model.pth"


# Per-channel mean / std handed to transforms.Normalize.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Inference pipeline: resize to 224x224, convert to a tensor, normalise.
_inference_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
test_transform = transforms.Compose(_inference_steps)


def load_trained_model():
    """Rebuild the 2-class ResNet-18 and load its weights from MODEL_PATH.

    Returns:
        The model on DEVICE, in eval mode.
    """
    net = models.resnet18(pretrained=False)
    # The head must match the 2-output layer stored in the checkpoint.
    net.fc = torch.nn.Linear(net.fc.in_features, 2)

    # map_location lets the checkpoint load onto whatever DEVICE resolved to.
    state = torch.load(MODEL_PATH, map_location=DEVICE)
    net.load_state_dict(state)

    net.to(DEVICE)
    net.eval()
    return net


def predict_single_image(image_path, model):
    """Classify one image file and return "REAL" or "FAKE".

    Args:
        image_path: Path of the image to open.
        model: Callable network returning class scores for a 1-image batch.

    Returns:
        "REAL" when the arg-max class index is 0, "FAKE" when it is 1.
    """
    label_map = {0: "REAL", 1: "FAKE"}

    rgb_image = Image.open(image_path).convert('RGB')
    # unsqueeze(0) adds the batch dimension in front of the transformed image.
    batch = test_transform(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
        pred_class = torch.max(logits, 1)[1].item()

    return label_map[pred_class]

def predict_multiple_images(image_paths, model):
    """Classify each path, recording failures instead of raising.

    Args:
        image_paths: Iterable of image paths.
        model: Network forwarded to predict_single_image.

    Returns:
        List of ``(path, outcome)`` tuples in input order, where ``outcome``
        is the predicted label or ``"Error: <message>"`` if that image failed.
    """
    outcomes = []
    for path in image_paths:
        try:
            verdict = predict_single_image(path, model)
        except Exception as exc:
            # Best effort: one failing image must not abort the whole batch.
            verdict = f"Error: {exc}"
        outcomes.append((path, verdict))
    return outcomes


if __name__ == "__main__":
    # Everything below is indented consistently so it stays inside the main
    # guard, where `model` is defined.
    model = load_trained_model()

    # Images to classify; 'person.jpeg' appears twice and is predicted twice.
    test_image_list = [
        '/home/dhanraj/Downloads/lion.jpg',
        '/home/dhanraj/Downloads/srk.jpg',
        '/home/dhanraj/Downloads/person.jpeg',
        '/home/dhanraj/Downloads/mount.jpg',
        '/home/dhanraj/Downloads/girl.png',
        '/home/dhanraj/Downloads/person.jpeg',
        '/home/dhanraj/Downloads/IMG-20250811-WA0004.jpg',
        '/home/dhanraj/Downloads/IMG-20250811-WA0006.jpg',
        '/home/dhanraj/Downloads/kitty.jpeg',
        '/home/dhanraj/Downloads/ponnu.jpg',
        '/home/dhanraj/Downloads/Gina.png',
        '/home/dhanraj/Downloads/girl.jpg',
        '/home/dhanraj/Downloads/new.jpeg',
        '/home/dhanraj/Downloads/Tom.png',
        '/home/dhanraj/Downloads/Allan.png',
    ]

    predictions = predict_multiple_images(test_image_list, model)

    for path, pred in predictions:
        print(f"{path} => {pred}")
    


/home/dhanraj/Downloads/lion.jpg => REAL
/home/dhanraj/Downloads/srk.jpg => REAL
/home/dhanraj/Downloads/person.jpeg => REAL
/home/dhanraj/Downloads/mount.jpg => REAL
/home/dhanraj/Downloads/girl.png => REAL
/home/dhanraj/Downloads/person.jpeg => REAL
/home/dhanraj/Downloads/IMG-20250811-WA0004.jpg => REAL
/home/dhanraj/Downloads/IMG-20250811-WA0006.jpg => REAL
/home/dhanraj/Downloads/kitty.jpeg => REAL
/home/dhanraj/Downloads/ponnu.jpg => REAL
/home/dhanraj/Downloads/Gina.png => FAKE
/home/dhanraj/Downloads/girl.jpg => FAKE
/home/dhanraj/Downloads/new.jpeg => REAL
/home/dhanraj/Downloads/Tom.png => REAL
/home/dhanraj/Downloads/Allan.png => REAL
