# Hybrid (ResNet50 + SwinV2 Small) — Multiclass by Platform (Fake Images Only)

This notebook fine-tunes a hybrid model that fuses a torchvision ResNet50 and a Hugging Face SwinV2 Small backbone to classify fake images by the 'platform' field in metadata.json.

Notes:
- Filters the dataset to status == 'fake' and uses 'platform' as label.
- Uses AutoImageProcessor for SwinV2 preprocessing and ImageNet normalization for ResNet.
- Windows-friendly data loading (num_workers=0) and step logging enabled.

In [1]:
import json
import os

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.models as models

from sklearn.model_selection import train_test_split
import evaluate

from transformers import (
    AutoImageProcessor,
    Swinv2Model,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Environment check: report library versions and whether CUDA is usable.
import transformers

version_lines = [
    ('torch:', torch.__version__, 'cuda available:', torch.cuda.is_available()),
    ('transformers:', transformers.__version__),
]
for fields in version_lines:
    print(*fields)

torch: 2.9.0+cu126 cuda available: True
transformers: 4.57.1


In [2]:
# --- Load and prepare metadata ---
BASE_DIR = './data'
METADATA_FILE = os.path.join(BASE_DIR, 'metadata.json')
MODEL_CHECKPOINT = 'microsoft/swinv2-small-patch4-window8-256'

print(f'Loading metadata from: {METADATA_FILE}')
df = pd.read_json(METADATA_FILE)

# Some entries store image_file as a list; take its first element. This is
# decided per row. Checking only row 0 and then calling `.str[0]` would cut
# plain-string paths down to their first character when the column is mixed.
# Empty lists become None and are dropped together with missing values.
df['image_file'] = df['image_file'].map(
    lambda v: (v[0] if len(v) > 0 else None) if isinstance(v, list) else v
)
df = df[df['image_file'].notna()].copy()

# Build full image paths
df['full_path'] = df['image_file'].apply(lambda x: os.path.join(BASE_DIR, x))

# Filter to fake images and require platform
df = df[df['status'] == 'fake'].copy()
if 'platform' not in df.columns:
    raise ValueError("metadata.json must contain a 'platform' key for fake images.")
df = df[~df['platform'].isna()].copy()

# Drop entries whose image is not on disk, so the datasets never have to
# substitute a different sample for a missing file during training.
on_disk = df['full_path'].map(os.path.exists).astype(bool)
if not on_disk.all():
    print(f'Dropping {int((~on_disk).sum())} entries whose image file is missing.')
df = df[on_disk].copy()

# Build label vocab from platform values (sorted -> stable ids across runs)
label_names = sorted(df['platform'].unique().tolist())
label2id = {l: i for i, l in enumerate(label_names)}
id2label = {i: l for l, i in label2id.items()}
df['label_id'] = df['platform'].map(label2id).astype(int)

# Stratified split keeps per-platform proportions equal in train and val
train_df, val_df = train_test_split(
    df, test_size=0.2, random_state=42, stratify=df['label_id']
)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

print('Classes (platforms):', label_names)
print('Total:', len(df), 'Train:', len(train_df), 'Val:', len(val_df))
train_df.head()

Loading metadata from: ./data\metadata.json
Classes (platforms): ['dall-E3', 'im', 'sd']
Total: 7109 Train: 5687 Val: 1422


Unnamed: 0,id,image_file,prompts,platform,status,full_path,label_id
0,5970,fake_DALLE/DALL_E11.webp,A small orange airplane prepares to take off.,dall-E3,fake,./data\fake_DALLE/DALL_E11.webp,0
1,4757,fake_IMAGEN/806c023ce8b7e2b0bc95ebc9ae7739ed.png,hyperrealism woman wearing a black robe holdin...,im,fake,./data\fake_IMAGEN/806c023ce8b7e2b0bc95ebc9ae7...,1
2,4450,fake_IMAGEN/c51379554ca11eafb0d8d7f6ad6eb0c5.png,The square coaster was next to the circular gl...,im,fake,./data\fake_IMAGEN/c51379554ca11eafb0d8d7f6ad6...,1
3,1709,fake_SD/Image_sd2 34.jpg,A man brushing his child's teeth while the ch...,sd,fake,./data\fake_SD/Image_sd2 34.jpg,2
4,2698,fake_IMAGEN/image_fx_a_man_is_sitting_at_a_tab...,A man is sitting at a table with a drink,im,fake,./data\fake_IMAGEN/image_fx_a_man_is_sitting_a...,1


In [3]:
# --- Processor and Datasets (Swin + ResNet branches) ---
processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# Square input side used by both branches: the processor's 'height' when its
# `size` is a dict that has one, otherwise 256.
target = (
    processor.size['height']
    if isinstance(processor.size, dict) and 'height' in processor.size
    else 256
)

# Random PIL-level augmentations for the training split only; they run
# before either branch's own preprocessing.
train_augs = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=target),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.RandomRotation(degrees=10),
    ]
)

# ResNet branch: deterministic resize + centre crop, then ImageNet mean/std normalization.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
resnet_base_transforms = transforms.Compose(
    [
        transforms.Resize(size=target),
        transforms.CenterCrop(size=target),
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

class HybridMultiDataset(Dataset):
    """Dataset yielding paired SwinV2 / ResNet inputs and a platform label.

    Each item is a dict with:
      - 'pixel_values': tensor produced by ``processor`` for the SwinV2 branch,
      - 'resnet_pixel_values': tensor from ``resnet_transforms`` for the ResNet
        branch (the notebook-level ``resnet_base_transforms`` by default),
      - 'labels': scalar ``torch.long`` class id read from ``row['label_id']``.

    Args:
        df: frame with at least 'full_path' and 'label_id' columns.
        processor: image processor called as
            ``processor(images=..., return_tensors='pt')``.
        transforms: optional PIL -> PIL augmentation applied before both branches.
        resnet_transforms: optional PIL -> tensor transform for the ResNet branch;
            when None, the module-level ``resnet_base_transforms`` is used.

    A row whose image file is missing is skipped in favour of the next row
    (wrapping around); the skipped path is recorded in ``missing_paths``.
    """

    def __init__(self, df, processor, transforms=None, resnet_transforms=None):
        self.df = df
        self.processor = processor
        self.transforms = transforms
        self.resnet_transforms = resnet_transforms
        # Paths found missing while loading. NOTE: with dataloader workers each
        # worker process updates its own copy of this set.
        self.missing_paths = set()

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        n_rows = len(self)
        pos = idx
        # Try the requested row, then successive rows (wrapping) if a file is
        # missing. One bounded pass: an all-missing dataset raises a clear error
        # instead of recursing until RecursionError.
        for _ in range(max(n_rows, 1)):
            # Out-of-range idx (or an empty frame) raises IndexError here,
            # which keeps the sequence/iteration protocol working.
            row = self.df.iloc[pos]
            img_path = row['full_path']
            try:
                # The context manager closes the file; convert() returns a fully
                # loaded copy that remains valid afterwards.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
            except FileNotFoundError:
                self.missing_paths.add(img_path)
                pos = (pos + 1) % n_rows
                continue
            return self._encode(image, int(row['label_id']))
        raise FileNotFoundError(
            f'No image file could be opened for any of the {n_rows} rows '
            f'(requested index {idx}).'
        )

    def _encode(self, image, label_id):
        """Augment ``image`` (if configured) and build the per-branch tensors."""
        pil_img = self.transforms(image) if self.transforms is not None else image

        # SwinV2 branch (processor handles resizing/normalization)
        swin_inputs = self.processor(images=pil_img, return_tensors='pt')
        pixel_values = swin_inputs['pixel_values'].squeeze(0)  # drop batch dim of 1

        # ResNet branch (ImageNet-normalized tensor by default)
        resnet_tf = (
            self.resnet_transforms
            if self.resnet_transforms is not None
            else resnet_base_transforms
        )
        resnet_pixel_values = resnet_tf(pil_img)

        return {
            'pixel_values': pixel_values,
            'resnet_pixel_values': resnet_pixel_values,
            'labels': torch.tensor(label_id, dtype=torch.long),
        }

# Training items get random augmentations; validation items get only the
# deterministic per-branch preprocessing (transforms defaults to None).
train_dataset, val_dataset = (
    HybridMultiDataset(df=train_df, processor=processor, transforms=train_augs),
    HybridMultiDataset(df=val_df, processor=processor),
)
print('Datasets -> train:', len(train_dataset), 'val:', len(val_dataset))

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Datasets -> train: 5687 val: 1422


In [4]:
# --- Hybrid Model (ResNet50 + SwinV2 Small) ---
class ResNetSwinClassifier(nn.Module):
    """Late-fusion classifier over concatenated ResNet50 and SwinV2 features.

    Args:
        num_labels: number of output classes.
        swin_checkpoint: Hugging Face checkpoint for the SwinV2 backbone.
        hidden_dim: width of the fusion head's hidden layer.
        dropout: dropout probability inside the fusion head.
    """

    def __init__(
        self,
        num_labels,
        swin_checkpoint='microsoft/swinv2-small-patch4-window8-256',
        hidden_dim=256,
        dropout=0.3,
    ):
        super().__init__()
        # ResNet50 with ImageNet weights; newer torchvision exposes a weights
        # enum, older versions only accept pretrained=True.
        try:
            weights = models.ResNet50_Weights.DEFAULT
            self.resnet = models.resnet50(weights=weights)
        except AttributeError:
            self.resnet = models.resnet50(pretrained=True)
        # Read the feature width from the classification head before removing
        # it, instead of hard-coding 2048.
        resnet_out_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()

        # SwinV2 (feature extractor)
        self.swin = Swinv2Model.from_pretrained(swin_checkpoint)
        swin_out_dim = self.swin.config.hidden_size

        # Fusion head
        self.classifier = nn.Sequential(
            nn.Linear(resnet_out_dim + swin_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_labels),
        )

    def forward(self, pixel_values=None, resnet_pixel_values=None, labels=None):
        """Fuse both backbones' features and classify.

        Args:
            pixel_values: SwinV2 input batch (from the image processor).
            resnet_pixel_values: ResNet input batch.
            labels: optional ``torch.long`` class ids for the cross-entropy loss.

        Returns:
            ``{'loss': ..., 'logits': ...}`` when ``labels`` is given, otherwise
            ``{'logits': ...}``. The 'loss' key is omitted rather than set to
            None: when predicting without labels, Trainer treats every returned
            value as an output tensor to detach, and a None entry breaks that.
        """
        # Swin features (pooled output, or mean over tokens if no pooler)
        swin_outputs = self.swin(pixel_values=pixel_values)
        swin_feat = getattr(swin_outputs, 'pooler_output', None)
        if swin_feat is None:
            swin_feat = swin_outputs.last_hidden_state.mean(dim=1)

        # ResNet features
        resnet_feat = self.resnet(resnet_pixel_values)

        # Fuse
        combined = torch.cat([resnet_feat, swin_feat], dim=1)
        logits = self.classifier(combined)

        if labels is None:
            return {'logits': logits}
        loss = nn.CrossEntropyLoss()(logits, labels)
        return {'loss': loss, 'logits': logits}

In [5]:
# --- Metrics ---
# Accuracy scorer from the `evaluate` library, loaded once and reused per evaluation.
metric = evaluate.load('accuracy')


def compute_metrics(eval_pred):
    """Score one evaluation pass by accuracy.

    Args:
        eval_pred: unpackable pair of (logits, labels) handed over by Trainer;
            the predicted class is the arg-max over the last logits axis.

    Returns:
        The dict produced by ``metric.compute`` (keyed 'accuracy').
    """
    scores, references = eval_pred
    return metric.compute(
        predictions=np.argmax(scores, axis=-1),
        references=references,
    )

In [8]:
# --- TrainingArguments ---
use_fp16 = torch.cuda.is_available()  # mixed precision only when a GPU is present
training_args = TrainingArguments(
    output_dir='./hybrid-multiclass50',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    eval_strategy='epoch',
    save_strategy='epoch',  # must match eval_strategy for load_best_model_at_end
    num_train_epochs=15,
    fp16=use_fp16,
    learning_rate=2e-5,
    logging_dir='./logs',
    logging_strategy='steps',
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',  # compared as 'eval_accuracy' from compute_metrics
    # Each epoch checkpoint stores ResNet50 + SwinV2 weights plus optimizer state,
    # so keeping all 15 wastes a lot of disk. With load_best_model_at_end, Trainer
    # retains the best checkpoint in addition to the most recent one.
    save_total_limit=2,
    seed=42,  # explicit for reproducibility (same value as the library default)
    remove_unused_columns=False,  # batches are custom dicts whose keys match forward()
    report_to='none',
    dataloader_pin_memory=torch.cuda.is_available(),
    dataloader_num_workers=0,  # Windows-friendly: load data in the main process
    disable_tqdm=False,
)

In [9]:
# --- Model + Trainer ---
model = ResNetSwinClassifier(num_labels=len(label_names))

# `processing_class` replaces Trainer's deprecated `tokenizer` argument; the
# image processor is then saved alongside each checkpoint and the final model.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=processor,
    compute_metrics=compute_metrics,
    # Stop if val accuracy fails to improve for 5 consecutive evaluations (epochs)
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)

  trainer = Trainer(


In [10]:
print('Classes:', id2label)
print('🚀 Starting training...')
train_result = trainer.train()
print('✅ Training finished!')

# With load_best_model_at_end=True this evaluates the restored best weights.
metrics = trainer.evaluate()
print('Validation metrics:', metrics)

best_ckpt = getattr(trainer.state, 'best_model_checkpoint', None)
save_dir = './best-model-hybrid-multiclass50'
if best_ckpt:
    print('Best checkpoint:', best_ckpt)
trainer.save_model(save_dir)

# ResNetSwinClassifier is a plain nn.Module with no HF config, so the
# id <-> platform mapping is not stored with the weights; persist it explicitly.
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'label_mapping.json'), 'w', encoding='utf-8') as fh:
    json.dump({'id2label': id2label, 'label2id': label2id}, fh, indent=2, ensure_ascii=False)
print('Saved model to', save_dir)

Classes: {0: 'dall-E3', 1: 'im', 2: 'sd'}
ðŸš€ Starting training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.3866,0.543783,0.797468
2,0.1719,0.344063,0.881857
3,0.368,0.457172,0.876231
4,0.2946,0.37414,0.892405
5,0.2427,0.361702,0.912096
6,0.1863,0.957632,0.811533
7,0.3107,0.678007,0.875527
8,0.0739,0.425549,0.921238
9,0.3023,0.697281,0.883966
10,0.3466,0.708805,0.889592


âœ… Training finished!


Validation metrics: {'eval_loss': 0.4255494773387909, 'eval_accuracy': 0.9212376933895922, 'eval_runtime': 64.4528, 'eval_samples_per_second': 22.063, 'eval_steps_per_second': 1.381, 'epoch': 13.0}
Best checkpoint: ./hybrid-multiclass50\checkpoint-5688
Saved model to ./best-model-hybrid-multiclass50
