In [6]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from PIL import Image
from sklearn.metrics import accuracy_score, classification_report
from torch.cuda import amp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_BINS = 64
BATCH_SIZE = 16
EPOCHS = 50
ACCUMULATION_STEPS = 4
MIDJOURNEY_BASE = "/home/dhanraj/Documents/Midjourney_Exp2"  # change to your path

class PixelPredictor(nn.Module):
    def __init__(self, num_bins=64):
        super(PixelPredictor, self).__init__()
        self.num_bins = num_bins
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()
        self.upsample1 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu_upsample1 = nn.ReLU()
        self.final_conv = nn.Conv2d(64, num_bins, 1)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu_upsample1(self.bn3(self.upsample1(x)))
        return self.final_conv(x)

# ===== UTILS =====
def rgb_to_grayscale(img):
    weights = torch.tensor([0.2989, 0.5870, 0.1140], device=img.device).view(1, 3, 1, 1)
    return (img * weights).sum(dim=1)

def quantize_targets(gray_img, num_bins=64):
    bins = torch.linspace(0, 256, steps=num_bins + 1, device=gray_img.device)
    target_bins = torch.bucketize((gray_img * 255).long(), bins) - 1
    return target_bins.clamp(0, num_bins - 1)

def downsample(img, levels=3):
    downs = [img]
    for _ in range(levels):
        img = F.avg_pool2d(img, 2, 2)
        downs.append(img)
    return downs

def compute_entropy(probs):
    return -torch.sum(probs * torch.log(probs + 1e-8), dim=1)

def compute_nll(probs, targets):
    return -torch.log(torch.gather(probs, 1, targets.unsqueeze(1)).squeeze(1) + 1e-8)

def extract_features(model, img_tensor, num_bins=64):
    model.eval()
    downs = downsample(img_tensor)
    features = []
    with torch.no_grad():
        for i in range(3):
            high = downs[i]
            low = F.interpolate(downs[i + 1], scale_factor=2, mode='nearest')
            logits = model(low)
            probs = F.softmax(logits, dim=1)
            gray_high = rgb_to_grayscale(high)
            target = quantize_targets(gray_high, num_bins=num_bins)
            entropy = compute_entropy(probs)
            nll = compute_nll(probs, target)
            gap = nll.mean().item() - entropy.mean().item()
            features.append((nll.mean().item(), entropy.mean().item(), gap))
    D0, D1 = features[0][2], features[1][2]
    return {'D0': D0, 'D1': D1, 'delta01': D0 - D1,
            'abs_D0': abs(D0), 'abs_delta01': abs(D0 - D1)}

def classify_image(features, delta01_threshold=0.1):
    return 'synthetic' if abs(features['delta01']) > delta01_threshold else 'real'


transform_train = T.Compose([
    T.Resize((64, 64)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(15),
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
    T.ToTensor(),
])

transform_test = T.Compose([
    T.Resize((64, 64)),
    T.ToTensor(),
])

cifar_train = CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_indices = list(range(2000))
train_dataset = Subset(cifar_train, train_indices)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=4, pin_memory=True, persistent_workers=True)

class MidJourneyDataset(torch.utils.data.Dataset):
    def __init__(self, base_dir, split='test', transform=None, max_samples_per_class=200):
        self.real_dir = os.path.join(base_dir, split, 'REAL')
        self.fake_dir = os.path.join(base_dir, split, 'FAKE')
        self.transform = transform
        self.real_files = [os.path.join(self.real_dir, f) for f in os.listdir(self.real_dir)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))][:max_samples_per_class]
        self.fake_files = [os.path.join(self.fake_dir, f) for f in os.listdir(self.fake_dir)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))][:max_samples_per_class]
        self.files = self.real_files + self.fake_files
        self.labels = [0]*len(self.real_files) + [1]*len(self.fake_files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert('RGB')
        if self.transform: img = self.transform(img)
        return img, self.labels[idx]

mid_test_dataset = MidJourneyDataset(MIDJOURNEY_BASE, split='test', transform=transform_test)
mid_test_loader = DataLoader(mid_test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=4, pin_memory=True, persistent_workers=True)

def train_model_with_accumulation(model, train_loader, device, epochs=50, accumulation_steps=4):
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs//2)
    criterion = nn.CrossEntropyLoss()
    scaler = amp.GradScaler()

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()
        for batch_idx, (imgs, _) in enumerate(train_loader):
            imgs = imgs.to(device)
            low_res = F.avg_pool2d(imgs, 2, 2)
            gray_imgs = rgb_to_grayscale(imgs)
            targets = quantize_targets(gray_imgs, num_bins=model.num_bins)

            with amp.autocast():
                logits = model(low_res)
                loss = criterion(logits, targets) / accumulation_steps
            
            scaler.scale(loss).backward()

            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            total_loss += loss.item() * accumulation_steps

        scheduler.step()
        if epoch % 5 == 0 or epoch <= 5:
            print(f"Epoch {epoch}/{epochs}, Loss={total_loss/len(train_loader):.6f}, LR={scheduler.get_last_lr()[0]:.6f}")


def save_model(model, path="pixel_predictor_accum.pth"):
    torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

def load_model(path="pixel_predictor_accum.pth"):
    model = PixelPredictor(num_bins=NUM_BINS)
    model.load_state_dict(torch.load(path, map_location=device))
    model.to(device).eval()
    print(f"Model loaded from {path}")
    return model

if __name__ == "__main__":
    print("=== TRAINING ===")
    model = PixelPredictor(num_bins=NUM_BINS)
    train_model_with_accumulation(model, train_loader, device,
                                  epochs=EPOCHS, accumulation_steps=ACCUMULATION_STEPS)
    save_model(model)

    print("\n=== TESTING ON MIDJOURNEY ===")
    y_true, y_pred = [], []
    for imgs, labels in mid_test_loader:
        imgs = imgs.to(device)
        for i in range(imgs.size(0)):
            feats = extract_features(model, imgs[i].unsqueeze(0), num_bins=NUM_BINS)
            pred = classify_image(feats, delta01_threshold=0.1)
            y_pred.append(pred)
            y_true.append('real' if labels[i] == 0 else 'synthetic')

    acc = accuracy_score(y_true, y_pred)
    print(f"MidJourney Test Accuracy: {acc*100:.2f}%")
    print(classification_report(y_true, y_pred, digits=4))


=== TRAINING ===


  scaler = amp.GradScaler()
  with amp.autocast():


Epoch 1/50, Loss=3.013132, LR=0.009961
Epoch 2/50, Loss=2.422997, LR=0.009843
Epoch 3/50, Loss=2.421217, LR=0.009649
Epoch 4/50, Loss=2.382593, LR=0.009382
Epoch 5/50, Loss=2.396114, LR=0.009045
Epoch 10/50, Loss=2.266308, LR=0.006545
Epoch 15/50, Loss=2.187725, LR=0.003455
Epoch 20/50, Loss=2.275485, LR=0.000955
Epoch 25/50, Loss=2.171755, LR=0.000000
Epoch 30/50, Loss=2.148370, LR=0.000955
Epoch 35/50, Loss=2.186818, LR=0.003455
Epoch 40/50, Loss=2.243596, LR=0.006545
Epoch 45/50, Loss=2.173289, LR=0.009045
Epoch 50/50, Loss=2.245404, LR=0.010000
Model saved to pixel_predictor_accum.pth

=== TESTING ON MIDJOURNEY ===
MidJourney Test Accuracy: 52.84%
              precision    recall  f1-score   support

        real     0.5862    0.1504    0.2394       113
   synthetic     0.5200    0.8966    0.6582       116

    accuracy                         0.5284       229
   macro avg     0.5531    0.5235    0.4488       229
weighted avg     0.5527    0.5284    0.4516       229

