In [39]:
# Dependencies 
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
from datasets import load_dataset
from tqdm import tqdm
import gradio as gr
from PIL import Image


In [40]:
# Load the DiffusionDB dataset
dataset = load_dataset("poloclub/diffusiondb", 'large_random_1k')

# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])




In [44]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['image', 'prompt', 'seed', 'step', 'cfg', 'sampler', 'width', 'height', 'user_name', 'timestamp', 'image_nsfw', 'prompt_nsfw'],
        num_rows: 1000
    })
})


In [45]:
# Custom Dataset Class
class DiffusionDBDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = Image.open(self.dataset[idx]['image']).convert("RGB")
        label = self.dataset[idx]['label']
        if self.transform:
            image = self.transform(image)
        return image, label


In [50]:
# Create DataLoader
train_dataset = DiffusionDBDataset(dataset['train'], transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Load the DiffusionDB dataset
dataset = load_dataset("poloclub/diffusiondb", 'large_random_1k')

# Determine the number of unique labels
num_labels = len(set(example['prompt'] for example in dataset['train']))

# Load the pre-trained ViT model with the correct number of labels
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(
    model_name, 
    num_labels=num_labels,
    ignore_mismatched_sizes=True  # Add this line to ignore size mismatches
)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Move model to the appropriate device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)




Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([124]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([124, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]



ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe