# SwinV2 Small — Multiclass by Platform (Fake Images)

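This notebook fine-tunes a SwinV2 classifier to categorize images by the 'platform' field in metadata.json.
- The 'platform' value is used as the class label; rows whose 'platform' is missing are dropped.
- As written, the data cell does not filter on 'status', so real images are kept as their own class (shown as 'None' in the printed class list). To train on fake images only, filter to status=='fake' in the data cell before the label vocabulary is built.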

In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset
from torchvision import transforms

from sklearn.model_selection import train_test_split
import evaluate

from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Report library versions and whether CUDA is visible to torch
import transformers

print(f'torch: {torch.__version__} cuda available: {torch.cuda.is_available()}')
print(f'transformers: {transformers.__version__}')

torch: 2.9.0+cu126 cuda available: True
transformers: 4.57.1


In [15]:
# --- Load and prepare metadata ---
BASE_DIR = './data'
MODEL_CHECKPOINT = 'microsoft/swinv2-small-patch4-window8-256'
METADATA_FILE = os.path.join(BASE_DIR, 'metadata.json')

print(f'Loading metadata from: {METADATA_FILE}')
df = pd.read_json(METADATA_FILE)

# If the first entry of image_file is a list, keep only element 0 of every entry
first_entry = df['image_file'].iloc[0]
if isinstance(first_entry, list):
    df['image_file'] = df['image_file'].str[0]

# Image locations are stored relative to BASE_DIR
df['full_path'] = [os.path.join(BASE_DIR, rel_path) for rel_path in df['image_file']]

# Keep only rows that carry a platform value (the class label)
df = df[df['platform'].notna()].copy()

# Label vocabulary: platforms in sorted order, numbered from 0
label_names = sorted(df['platform'].unique().tolist())
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}
df['label_id'] = df['platform'].map(label2id).astype(int)

# 80/20 split, stratified on the label so both parts share the class mix
train_df, val_df = (
    part.reset_index(drop=True)
    for part in train_test_split(
        df, test_size=0.2, random_state=42, stratify=df['label_id']
    )
)

print('Classes (platforms):', label_names)
print('Total:', len(df), 'Train:', len(train_df), 'Val:', len(val_df))
train_df.head()

Loading metadata from: ./data\metadata.json
Classes (platforms): ['None', 'dall-E3', 'im', 'sd']
Total: 14062 Train: 11249 Val: 2813


Unnamed: 0,id,image_file,prompts,platform,status,full_path,label_id
0,1389,fake_SD/image_214.jpg,A woman in costume is marching with a large d...,sd,fake,./data\fake_SD/image_214.jpg,3
1,791,fake_SD/6219a249-4462-4f76-ac23-37621c17751b.jpg,A shrub that has been shaped to look like a dog.,sd,fake,./data\fake_SD/6219a249-4462-4f76-ac23-37621c1...,3
2,1200,fake_SD/image_25.jpg,A older bearded man wearing a sports jacket m...,sd,fake,./data\fake_SD/image_25.jpg,3
3,8971,real/SD_dataset_000000252659.jpg,real,,real,./data\real/SD_dataset_000000252659.jpg,0
4,12412,real/hf_unsplash_24713.jpg,real,,real,./data\real/hf_unsplash_24713.jpg,0


In [16]:
# --- Processor and Dataset ---
processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# Derive the crop size from the processor config, defaulting to 256
size = processor.size
if size is None:
    target = 256
elif isinstance(size, dict):
    target = size.get('height') or size.get('shortest_edge') or 256
else:
    target = int(size)

# Augmentations applied to PIL images of the training split only
train_augs = transforms.Compose([
    transforms.RandomResizedCrop(size=target),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.RandomRotation(degrees=10),
])

class PlatformDataset(Dataset):
    """Map-style dataset that yields processor-ready tensors with platform labels.

    Args:
        df: Frame with a 'full_path' column (image location) and a
            'label_id' column (integer class id).
        processor: Callable invoked as ``processor(images=..., return_tensors='pt')``;
            its result must expose a 'pixel_values' entry.
        transforms: Optional callable applied to the RGB image before the
            processor. ``None`` passes the image through unchanged.
    """

    def __init__(self, df, processor, transforms=None):
        self.df = df
        self.processor = processor
        self.transforms = transforms

    def __len__(self):
        """Return the number of rows in the backing frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{'pixel_values': tensor, 'labels': tensor}`` for row ``idx``.

        If the image file of the row is missing, the following rows are tried
        in order (wrapping around) and the first readable one is returned
        together with its own label.

        Raises:
            FileNotFoundError: If no row in the dataset points to an existing file.
        """
        total = len(self)
        candidate = idx
        # Bounded scan: at most one attempt per row, so a dataset with no
        # readable files fails with a clear error instead of recursing forever.
        for _ in range(max(total, 1)):
            row = self.df.iloc[candidate]
            img_path = row['full_path']
            label_id = int(row['label_id'])
            try:
                # Context manager closes the underlying file handle promptly.
                with Image.open(img_path) as handle:
                    image = handle.convert('RGB')
                break
            except FileNotFoundError:
                candidate = (candidate + 1) % total
        else:
            raise FileNotFoundError(
                f'None of the {total} image files referenced by the dataset could be found'
            )

        img = self.transforms(image) if self.transforms is not None else image
        inputs = self.processor(images=img, return_tensors='pt')
        # Drop dimension 0 when it has size 1 (single image passed to the processor).
        pixel_values = inputs['pixel_values'].squeeze(0)
        return {'pixel_values': pixel_values, 'labels': torch.tensor(label_id, dtype=torch.long)}

# Only the training split is augmented; validation images go straight to the processor
train_dataset = PlatformDataset(train_df, processor, transforms=train_augs)
val_dataset = PlatformDataset(val_df, processor)
print(f'Datasets -> train: {len(train_dataset)} val: {len(val_dataset)}')

Datasets -> train: 11249 val: 2813


In [17]:
# --- Metrics ---
metric = evaluate.load('accuracy')


def compute_metrics(eval_pred):
    """Score one evaluation pass with the loaded accuracy metric.

    Args:
        eval_pred: Pair ``(logits, labels)`` as handed over by the Trainer.

    Returns:
        The result of ``metric.compute`` for the arg-max class predictions.
    """
    scores, references = eval_pred
    # Predicted class = index of the largest logit along the last axis
    predicted = np.argmax(scores, axis=-1)
    return metric.compute(predictions=predicted, references=references)

In [20]:
# --- TrainingArguments ---
use_fp16 = torch.cuda.is_available()
training_args = TrainingArguments(
    # Output locations and reporting
    output_dir='./swinv2-multiclass-realtrue',
    logging_dir='./logs',
    report_to='none',
    disable_tqdm=False,
    # Optimisation
    num_train_epochs=12,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    fp16=use_fp16,
    # Evaluate and checkpoint once per epoch; keep the most accurate checkpoint
    eval_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    # Log the training loss every 10 steps
    logging_strategy='steps',
    logging_steps=10,
    # Data loading
    remove_unused_columns=False,
    dataloader_pin_memory=torch.cuda.is_available(),
    dataloader_num_workers=0,
)

In [21]:
# --- Model + Trainer ---
# The checkpoint ships a 1000-way classifier; ignore_mismatched_sizes lets the
# head be re-initialised with one output per platform label.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(label_names),
    ignore_mismatched_sizes=True,
    label2id=label2id,
    id2label=id2label,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `tokenizer=` is deprecated for Trainer; `processing_class=` is its replacement.
    processing_class=processor,
    compute_metrics=compute_metrics,
    # Stop when the tracked metric has not improved for 2 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

Some weights of Swinv2ForImageClassification were not initialized from the model checkpoint at microsoft/swinv2-small-patch4-window8-256 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([4, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([4]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


In [22]:
# Show the id -> platform mapping the classifier head was built with
print('Classes:', id2label)
print('🚀 Starting training...')
_ = trainer.train()
print('✅ Training finished!')

# Score the model currently held by the trainer on the validation split
metrics = trainer.evaluate()
print('Validation metrics:', metrics)

# best_model_checkpoint may be absent or None, hence the getattr default
best_ckpt = getattr(trainer.state, 'best_model_checkpoint', None)
save_dir = './best-model-swinv2-multiclass-realtrue'
if best_ckpt:
    print('Best checkpoint:', best_ckpt)
trainer.save_model(save_dir)
print('Saved model to', save_dir)

Classes: {0: 'None', 1: 'dall-E3', 2: 'im', 3: 'sd'}
ðŸš€ Starting training...


Epoch,Training Loss,Validation Loss,Accuracy
1,0.5362,0.689419,0.734092
2,0.4493,0.674576,0.80839
3,0.2816,0.489135,0.869534
4,0.2874,0.54442,0.8738
5,0.4047,0.470495,0.902595
6,0.208,1.24485,0.795592
7,0.2112,0.780988,0.860647


âœ… Training finished!


Validation metrics: {'eval_loss': 0.470495343208313, 'eval_accuracy': 0.9025950942054746, 'eval_runtime': 87.5298, 'eval_samples_per_second': 32.138, 'eval_steps_per_second': 2.011, 'epoch': 7.0}
Best checkpoint: ./swinv2-multiclass-realtrue\checkpoint-7035
Saved model to ./best-model-swinv2-multiclass-realtrue
Saved model to ./best-model-swinv2-multiclass-realtrue
