In [None]:
# Imports
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTImageProcessor

# Pick the compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
img_size = 224  # ViT uses 224x224 images

# Training pipeline: fixed-size resize, light random augmentation
# (horizontal flip, rotation up to 15 degrees), then tensor conversion and
# per-channel normalisation with mean 0.5 / std 0.5.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Validation pipeline: same resize and normalisation, no random augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=(img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Dataset root; it is expected to contain "train" and "test" sub-directories.
path = r"/content/drive/MyDrive/testing - Copy"

# Load dataset (the "test" directory is used as the validation split)
train_data = datasets.ImageFolder(f"{path}/train", transform=train_transforms)
val_data = datasets.ImageFolder(f"{path}/test", transform=val_transforms)

# Create data loaders; only the training loader shuffles its samples
train_loader = DataLoader(dataset=train_data, shuffle=True, batch_size=32)
val_loader = DataLoader(dataset=val_data, shuffle=False, batch_size=32)

# Class names as discovered by ImageFolder from the training directory
class_names = train_data.classes
print(class_names)

In [None]:
from transformers import ViTForImageClassification, AutoImageProcessor

# Load pre-trained ViT model and image processor
model_name = "google/vit-base-patch16-224"  # ViT Base, patch size 16, 224x224 images

# Derive the label set from the dataset instead of hard-coding it, and store
# the human-readable names in the model config so a saved checkpoint carries
# its own index -> class-name mapping.
id2label = dict(enumerate(class_names))
label2id = {name: idx for idx, name in id2label.items()}

model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(class_names),
    id2label=id2label,
    label2id=label2id,
    # The checkpoint's classification head has a different output size than
    # ours, so it must be dropped and re-initialised rather than loaded.
    ignore_mismatched_sizes=True,
)
model.to(device)

processor = AutoImageProcessor.from_pretrained(model_name)
# Shown so the values can be compared by eye with the Normalize() arguments
# used in the torchvision transforms.
print("Expected image mean:", processor.image_mean, "std:", processor.image_std)

# Freeze the ViT backbone; everything outside `model.vit` stays trainable.
for param in model.vit.parameters():
    param.requires_grad = False

In [None]:
import torch.optim as optim

# Define optimizer
# Only hand the optimizer parameters that can actually receive gradients;
# parameters frozen with requires_grad=False are never updated anyway.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()

# Training loop: one pass over train_loader, then one evaluation pass over
# val_loader, per epoch.
num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        logits = outputs.logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight by batch size so the epoch
        # figure is a per-sample average even when the last batch is smaller.
        running_loss += loss.item() * images.size(0)
    epoch_loss = running_loss / len(train_data)

    # Validation phase
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            logits = outputs.logits
            _, preds = torch.max(logits, dim=1)    # Predicted class indices
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    # An empty validation set would otherwise raise ZeroDivisionError.
    val_acc = correct / total if total else float("nan")

    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {epoch_loss:.4f} - Val Accuracy: {val_acc:.4f}")

In [None]:
# Save the fine-tuned model together with the image processor config, so the
# "vit_model" directory is a complete checkpoint (weights, model config and
# preprocessing settings) rather than weights alone.
model.save_pretrained("vit_model")
processor.save_pretrained("vit_model")

In [None]:
from huggingface_hub import HfApi

repo_name = "O-ww-O/custom-vit"  # Change this to your desired repo name

# Make sure the target repository exists; exist_ok avoids an error on re-runs.
api = HfApi()
api.create_repo(repo_name, exist_ok=True)

# Upload the model and the matching image processor config, so the Hub repo
# also describes how inputs must be preprocessed.
model.push_to_hub(repo_name)
processor.push_to_hub(repo_name)

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Reload the fine-tuned checkpoint from the Hub.
# The number of labels comes from the checkpoint's own config. Deliberately
# NOT passing ignore_mismatched_sizes: if the classifier head did not match,
# loading should fail loudly instead of silently replacing the trained head
# with randomly initialised weights.
model_name = "O-ww-O/custom-vit"
model = ViTForImageClassification.from_pretrained(model_name)

In [None]:
from PIL import Image

# Place the model on the selected device (GPU if available, else CPU) and
# switch it to evaluation mode before running predictions
model = model.to(device)
model.eval()

def predict_image(image_path):
    """Classify a single image file and return the predicted class name.

    The image is loaded as RGB, preprocessed with ``val_transforms`` and run
    through the global ``model`` on ``device`` without gradient tracking.

    Args:
        image_path: Path of the image to classify (anything accepted by
            ``PIL.Image.open``).

    Returns:
        The entry of ``class_names`` at the index of the largest logit.
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied into the RGB image.
    with Image.open(image_path) as img:
        rgb_img = img.convert("RGB")
    # Apply the same transforms as validation (resize, tensor, normalize);
    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = val_transforms(rgb_img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor).logits
        pred = torch.argmax(logits, dim=1).item()
    return class_names[pred]

# Classify one sample image from the test split and print its predicted class name
sample_image_path = "/content/drive/MyDrive/testing - Copy/test/REAL/0003 (5).jpg"
print(predict_image(sample_image_path))