In [6]:
import os
from PIL import Image
from torchvision.datasets.vision import VisionDataset
from torchvision import transforms

import re

def extract_object_class(filename):
    match = re.search(r"\((\d+)\)", filename)
    if match:
        return int(match.group(1))
    else:
        return 0  # fallback or special value if no match

class CIFAKEDataset(VisionDataset):
    def __init__(self, root, split='train', transform=None, target_transform=None):
        assert split in ['train', 'test'], "split must be 'train' or 'test'"
        super().__init__(root, transform=transform, target_transform=target_transform)
        
        self.data_dir = os.path.join(root, split)
        self.classes = ['REAL', 'FAKE']
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        
        self.samples = self._make_dataset()

    def _make_dataset(self):
        samples = []
        for class_name in self.classes:
            class_dir = os.path.join(self.data_dir, class_name)
            if not os.path.isdir(class_dir):
                continue
            for fname in os.listdir(class_dir):
                if fname.endswith(".jpg") or fname.endswith(".png"):
                    path = os.path.join(class_dir, fname)
                    label = self.class_to_idx[class_name]
                    samples.append((path, label))
        return samples

    def __getitem__(self, index):
        path, target = self.samples[index]
        filename = os.path.basename(path)
        object_class = extract_object_class(filename)

        image = Image.open(path).convert("RGB")

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, float(target), object_class
    
    def __len__(self):
        return len(self.samples)

In [7]:
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((32, 32)),  # same as CIFAR
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # CIFAR-like normalization
])

train_dataset = CIFAKEDataset(root="./cifake", split='train', transform=transform)
test_dataset = CIFAKEDataset(root="./cifake", split='test', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


for images, labels, objclass in train_loader:
    print(images.shape, labels.shape)
    print(labels)
    print(objclass)
    # Here you can add your training code
    break  # Remove this break to iterate through the entire dataset    

torch.Size([64, 3, 32, 32]) torch.Size([64])
tensor([0., 0., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 0., 0., 1.,
        1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 1., 1., 1.,
        0., 1., 0., 1., 1., 0., 0., 0., 1., 0.], dtype=torch.float64)
tensor([ 5,  2,  2,  0,  5,  9,  4,  8,  0,  5,  9,  6,  8,  8,  8,  4,  7,  3,
         2,  8,  6,  3,  4,  2,  6,  5,  0,  7,  9,  9,  9,  0,  0,  4,  9,  7,
         6,  5,  0, 10,  9,  4,  2,  7,  6,  4,  3,  3, 10, 10,  8,  0,  9,  4,
         9,  5, 10,  9,  5,  5,  5,  7,  5,  0])


In [8]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class CIFAKEDataModule(pl.LightningDataModule):
    def __init__(self, data_dir, batch_size=64, num_workers=4, model_type="resnet"):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        image_size = 224 if model_type == "vit" else 32
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def setup(self, stage=None):
        self.train_dataset = CIFAKEDataset(self.data_dir, split='train', transform=self.transform)
        self.test_dataset = CIFAKEDataset(self.data_dir, split='test', transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, vit_b_16
from collections import defaultdict



def get_model(backbone="resnet"):
    if backbone == "resnet":
        model = resnet18(pretrained=False)
        model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        model.maxpool = nn.Identity()
        in_features = model.fc.in_features
        model.fc = nn.Identity()
    elif backbone == "vit":
        model = vit_b_16(pretrained=False)
        in_features = model.heads.head.in_features
        model.heads.head = nn.Identity()
    else:
        raise ValueError(f"Unsupported backbone: {backbone}")
    
    return model, in_features

class LitCIFAKEClassifier(pl.LightningModule):
    def __init__(self, architecture="vit", learning_rate=1e-3, use_logits=True):
        super().__init__()
        self.save_hyperparameters()
        self.use_logits = use_logits
        self.val_stats = defaultdict(lambda: {"real": {"correct": 0, "total": 0},
                                              "fake": {"correct": 0, "total": 0}})

        self.model, in_features = get_model(backbone=architecture)


        if use_logits:
            self.classifier = nn.Linear(in_features, 1)
            self.loss_fn = nn.BCEWithLogitsLoss()
        else:
            self.classifier = nn.Linear(in_features, 2)
            self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        features = self.model(x)
        return self.classifier(features)

    def training_step(self, batch, batch_idx):
        x, y, _ = batch
        logits = self(x)

        if self.use_logits:
            y = y.float().unsqueeze(1)
            loss = self.loss_fn(logits, y)
            preds = (torch.sigmoid(logits) > 0.5).int()
            acc = (preds == y.int()).float().mean()
        else:
            loss = self.loss_fn(logits, y)
            preds = torch.argmax(logits, dim=1)
            acc = (preds == y).float().mean()

        self.log("train_loss", loss)
        self.log("train_acc", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y, obj_cls = batch
        logits = self(x)

        if self.use_logits:
            y_float = y.float().unsqueeze(1)
            loss = self.loss_fn(logits, y_float)
            preds = (torch.sigmoid(logits) > 0.5).int().squeeze(1)
        else:
            loss = self.loss_fn(logits, y)
            preds = torch.argmax(logits, dim=1)

        for pred, label, cls_id in zip(preds.cpu(), y.cpu(), obj_cls.cpu()):
            label_str = "real" if label == 0 else "fake"
            self.val_stats[cls_id.item()][label_str]["total"] += 1
            if pred.item() == label.item():
                self.val_stats[cls_id.item()][label_str]["correct"] += 1

        return {"val_loss": loss}   
    def on_validation_epoch_end(self):
        real_correct, real_total = 0, 0
        fake_correct, fake_total = 0, 0

        print("\nPer-class accuracy (val):")
        for cls_id, stats in sorted(self.val_stats.items()):
            real = stats["real"]
            fake = stats["fake"]

            real_acc = real["correct"] / real["total"] if real["total"] > 0 else 0.0
            fake_acc = fake["correct"] / fake["total"] if fake["total"] > 0 else 0.0

            real_correct += real["correct"]
            real_total += real["total"]
            fake_correct += fake["correct"]
            fake_total += fake["total"]

            print(f"Class {cls_id}: REAL acc = {real_acc:.3f}, FAKE acc = {fake_acc:.3f}")

        overall_real_acc = real_correct / real_total if real_total > 0 else 0.0
        overall_fake_acc = fake_correct / fake_total if fake_total > 0 else 0.0
        print(f"\nOverall REAL accuracy: {overall_real_acc:.3f}")
        print(f"Overall FAKE accuracy: {overall_fake_acc:.3f}")
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

In [10]:
from pytorch_lightning import Trainer


dm = CIFAKEDataModule(data_dir="./cifake", batch_size=64, model_type="resnet")
model = LitCIFAKEClassifier(architecture="resnet", learning_rate=1e-3, use_logits=True)

trainer = Trainer(max_epochs=10, accelerator="auto")
trainer.fit(model, datamodule=dm)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/computri/anaconda3/envs/deepfakes/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch

Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:00<00:00,  5.75it/s]
Per-class accuracy (val):
Class 0: REAL acc = 1.000, FAKE acc = 0.000
Class 2: REAL acc = 1.000, FAKE acc = 0.000
Class 3: REAL acc = 1.000, FAKE acc = 0.000
Class 4: REAL acc = 1.000, FAKE acc = 0.000
Class 5: REAL acc = 1.000, FAKE acc = 0.000
Class 6: REAL acc = 1.000, FAKE acc = 0.000
Class 7: REAL acc = 1.000, FAKE acc = 0.000
Class 8: REAL acc = 1.000, FAKE acc = 0.000
Class 9: REAL acc = 1.000, FAKE acc = 0.000
Class 10: REAL acc = 1.000, FAKE acc = 0.000

Overall REAL accuracy: 1.000
Overall FAKE accuracy: 0.000
Epoch 0: 100%|██████████| 1563/1563 [00:23<00:00, 67.16it/s, v_num=8]      
Per-class accuracy (val):
Class 0: REAL acc = 0.760, FAKE acc = 0.992
Class 2: REAL acc = 0.786, FAKE acc = 0.992
Class 3: REAL acc = 0.828, FAKE acc = 0.928
Class 4: REAL acc = 0.898, FAKE acc = 0.984
Class 5: REAL acc = 0.926, FAKE acc = 0.984
Class 6: REAL acc = 0.891, FAKE acc = 0.977
Class 7: REAL acc = 0.926, FAKE 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1563/1563 [00:26<00:00, 58.36it/s, v_num=8]


In [14]:
import torch
from diffusers import StableDiffusionXLPipeline
from pathlib import Path
from PIL import Image
from tqdm import tqdm

classes = [
    "airplane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]
# Config
prompts = [f"A realistic {cl}" for cl in classes]
output_dir = Path("generated_sdxl")
output_dir.mkdir(exist_ok=True)
seed = 42
num_images_per_prompt = 1
height = 1024
width = 1024
guidance_scale = 7.5
num_inference_steps = 30
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load SDXL pipeline
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe.to(device)

# Set generator seed
# generator = torch.Generator(device).manual_seed(seed)

# Generate images
for prompt in tqdm(prompts, desc="Generating images"):
    for i in range(num_images_per_prompt):
        image = pipe(
            prompt=prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            # generator=generator
        ).images[0]

        # small_image = image.resize((32, 32), resample=Image.BICUBIC)
        
        # Save image
        filename = f"{prompt.replace(' ', '_')[:50]}_{i+1}.jpg"
        # small_image.save(output_dir / filename)
        image.save(output_dir / filename)

Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  3.76it/s]
100%|██████████| 30/30 [00:07<00:00,  4.09it/s]<?, ?it/s]
100%|██████████| 30/30 [00:07<00:00,  4.06it/s]<01:11,  7.94s/it]
100%|██████████| 30/30 [00:07<00:00,  4.06it/s]<01:03,  7.96s/it]
100%|██████████| 30/30 [00:07<00:00,  4.05it/s]<00:55,  7.98s/it]
100%|██████████| 30/30 [00:07<00:00,  4.05it/s]<00:47,  7.98s/it]
100%|██████████| 30/30 [00:07<00:00,  4.04it/s]<00:39,  7.99s/it]
100%|██████████| 30/30 [00:07<00:00,  4.03it/s]<00:32,  8.01s/it]
100%|██████████| 30/30 [00:07<00:00,  4.03it/s]<00:24,  8.02s/it]
100%|██████████| 30/30 [00:07<00:00,  4.02it/s]<00:16,  8.03s/it]
100%|██████████| 30/30 [00:07<00:00,  4.03it/s]<00:08,  8.03s/it]
Generating images: 100%|██████████| 10/10 [01:20<00:00,  8.01s/it]


In [17]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18


def get_cifar10_resnet18(pretrained=False):
    model = resnet18(pretrained=pretrained)

    # Modify first conv layer for CIFAR-10 (3x32x32 images)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    model.maxpool = nn.Identity()  # Remove maxpool to preserve spatial dims

    # Adjust the classifier head
    model.fc = nn.Linear(model.fc.in_features, 1)

    return model

In [20]:
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_dataset = CIFAKEDataset(root="./cifake", split='train', transform=transform)
test_dataset = CIFAKEDataset(root="./cifake", split='test', transform=transform)

trainloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
testloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)
model = get_cifar10_resnet18().to(device)

# criterion = nn.CrossEntropyLoss()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Training loop
for epoch in range(10):
    model.train()
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device).unsqueeze(1)  # Ensure labels are of shape [batch_size, 1]

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        print(loss.item())

    print(f"Epoch {epoch+1}, Loss: {running_loss / len(trainloader):.4f}")

cuda




0.6944637165870517
1.308629118998164
1.2532340542238671
0.6802178724339711
0.7665213337750174
0.685993341329322
0.5310521216888446
0.5712127753940877
0.6962785335635999
0.4546670992640429
0.4985916171644931
0.5945665165272658
0.5004427712992765
0.5026514723510616
0.5642008900243027
0.472436145940037
0.3619069795749965
0.5855000329538598
0.4008216590009397
0.4520472006761338
0.44582374513629475
0.4592725640432036
0.45641244127909886
0.48097450894601934
0.496886498935055
0.54875304219604
0.44979505700757727
0.35535492436611094
0.38276074075110955
0.39296519897834514
0.3394628198511782
0.4314740957979666
0.3899199337529353
0.2978153595323647
0.34406026990563987
0.40180102901757664
0.3010148691066661
0.39254595261988356
0.41340322479536695
0.46671170952504326
0.3878223739311579
0.42399148195249836
0.3298147538757803
0.4146134170669029
0.39987199002075613
0.42787874913301494
0.3795781960579916
0.42677120859752904
0.31784006890484306
0.3636873925161126
0.32083583642247504
0.429817040364469
0

KeyboardInterrupt: 