In [1]:
!pip install datasets torchvision torch



In [2]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, random_split, ConcatDataset
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import matplotlib.pyplot as plt
import numpy as np

In [3]:
# Deterministic preprocessing for validation / inference (no random augmentation):
# resize to a fixed 64x64, convert to a tensor, then normalize every RGB channel
# with (x - 0.5) / 0.5.
test_transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class DataSetWrapper(Dataset):
  """Expose a dict-style dataset as (image, label) pairs with one constant label.

  Each item of ``dataset`` is indexed with the key ``"image"``; every returned
  sample gets the same ``label`` (set at construction) as a ``torch.long`` tensor.
  If ``transform`` is truthy it is applied to the image before returning.
  """
  def __init__(self, dataset, label, transform = None):
    self.dataset, self.label, self.transform = dataset, label, transform

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, i):
    image = self.dataset[i]["image"]
    image = self.transform(image) if self.transform else image
    return image, torch.tensor(self.label, dtype=torch.long)


In [4]:
class Hybrid_CNN_Vit(nn.Module):
  """Hybrid classifier: a small CNN feature extractor followed by a Transformer encoder.

  The CNN downsamples the input by 8x (three ``MaxPool2d(2)`` stages). Every spatial
  position of the resulting feature map becomes one token of dimension
  ``cnn_channels``. Learned positional embeddings are added, the tokens pass through
  a Transformer encoder, are mean-pooled, and a LayerNorm + Linear head produces
  class logits.

  Args:
    image_size: side length of the square input images. The positional embeddings
      hold exactly ``(image_size // 8) ** 2`` tokens, so inputs must match this size.
    num_classes: number of output logits.
    cnn_channels: channels of the last conv layer, which is also the token dimension
      (must be divisible by ``num_heads`` for multi-head attention).
    num_heads: attention heads per encoder layer.
    num_layers: number of stacked encoder layers.
    dropout: dropout inside the Transformer encoder layers.
    cnn_dropout: dropout after each conv stage (default 0.4).
    head_dropout: dropout before the final Linear layer (default 0.6).

  NOTE: the attribute names (``cnn``, ``positional_embeddings``, ``transformer``,
  ``clasifier`` -- sic) and the layer order inside each ``nn.Sequential`` determine
  the state_dict keys; changing them breaks loading of saved checkpoints.
  """
  def __init__(self, image_size = 64,
               num_classes = 2,
               cnn_channels = 32,
               num_heads = 4,
               num_layers = 1,
               dropout = 0.4,
               cnn_dropout = 0.4,
               head_dropout = 0.6):
    super().__init__()
    # CNN feature extractor: three Conv -> ReLU -> Dropout -> MaxPool(2) stages, each
    # halving H and W (64x64 -> 32x32 -> 16x16 -> 8x8 for the default image_size).
    # Keep the layer order: the Sequential indices (cnn.0, cnn.4, cnn.8) are state_dict keys.
    self.cnn = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Dropout(cnn_dropout),
        nn.MaxPool2d(2),
        nn.Conv2d(32, 64, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Dropout(cnn_dropout),
        nn.MaxPool2d(2),
        nn.Conv2d(64, cnn_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Dropout(cnn_dropout),
        nn.MaxPool2d(2),
    )
    # Token grid after three 2x downsamplings; each spatial cell is one token.
    self.feature_h = image_size // 8
    self.feature_w = image_size // 8
    self.seq_len = self.feature_h * self.feature_w
    self.embedded_dim = cnn_channels
    # One learned positional embedding per token, broadcast over the batch dimension.
    self.positional_embeddings = nn.Parameter(
        torch.randn(1, self.seq_len, self.embedded_dim)
    )
    # Transformer encoder over the token sequence (batch_first: (batch, seq, dim)).
    encoder_layer = nn.TransformerEncoderLayer(
        d_model=self.embedded_dim,
        nhead=num_heads,
        dim_feedforward=self.embedded_dim * 4,
        dropout=dropout,
        activation="gelu",
        batch_first=True,
    )
    self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
    # Classification head. The attribute name `clasifier` (sic) appears in saved
    # checkpoints' state_dict keys, so it must not be renamed.
    self.clasifier = nn.Sequential(
        nn.LayerNorm(self.embedded_dim),
        nn.Dropout(head_dropout),
        nn.Linear(self.embedded_dim, num_classes),
    )

  def forward(self, x):
    """Map a batch of images (batch, 3, image_size, image_size) to logits (batch, num_classes)."""
    features = self.cnn(x)                          # (b, c, H // 8, W // 8)
    tokens = features.flatten(2).transpose(1, 2)    # (b, seq_len, c)
    # Out-of-place add: an in-place `+=` here would mutate a permuted view of the
    # CNN output, which is fragile under autograd and saves nothing meaningful.
    tokens = tokens + self.positional_embeddings
    tokens = self.transformer(tokens)
    pooled = tokens.mean(dim=1)                     # average over tokens
    return self.clasifier(pooled)

In [5]:
def accuracy_from_logits(logits, labels):
  """Return the fraction of rows whose highest-scoring index (over dim 1) equals the label.

  Args:
    logits: tensor of scores with classes along dim 1.
    labels: tensor of target class indices compared element-wise with the predictions.

  Returns:
    A Python float: the mean of the per-row correctness indicators.
  """
  hits = logits.argmax(dim=1).eq(labels)
  return hits.float().mean().item()

In [6]:
class TransformSubset(Dataset):
    """Wrap an (image, label) dataset and apply ``transform`` to images on access.

    Useful for giving different transforms to splits that share one underlying
    dataset. When ``transform`` is falsy, items are returned unchanged.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        img, target = self.subset[idx]
        if not self.transform:
            return img, target
        return self.transform(img), target

### Loading the Model from Hugging Face Hub for Prediction

First, we import `hf_hub_download` from `huggingface_hub` to programmatically download the trained weights. Then, we initialize the `Hybrid_CNN_Vit` model and load the downloaded state dictionary into it. The cell directly below defines a `predict_image` helper that is used later for inference.

In [7]:
def predict_image(model, image_tensor):
    """Classify a single preprocessed image.

    Args:
        model: classifier returning logits with classes along dim 1.
        image_tensor: one image without a batch dimension (e.g. the output of
            ``test_transform``); a batch dimension of size 1 is added here.

    Returns:
        ``(predicted_class, probabilities)``: the index of the most probable class as
        a Python int, and a 1-D numpy array of softmax probabilities for that image.
    """
    model.eval()  # disable dropout for deterministic inference
    # Run on the device holding the model's weights instead of relying on a global
    # `device` that is only defined in a later cell (hidden notebook state).
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else image_tensor.device
    with torch.no_grad():
        batch = image_tensor.unsqueeze(0).to(target_device)
        probabilities = torch.softmax(model(batch), dim=1)
        predicted_class = probabilities.argmax(dim=1)
    return predicted_class.item(), probabilities.cpu().numpy()[0]


In [8]:
from huggingface_hub import hf_hub_download

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Your Hugging Face repository ID and the filename of your model
repo_id = "Tomisin05/anime-ai-human-detector"  # Use your actual repo_id
filename = "best_model.pth"

# Download the model file. Fail loudly on error: continuing would leave
# `downloaded_model_path` undefined and the model with random weights, so every
# later prediction would silently be meaningless.
try:
    downloaded_model_path = hf_hub_download(repo_id=repo_id, filename=filename)
except Exception as e:
    raise RuntimeError(
        f"Failed to download model from Hugging Face Hub: {e}\n"
        "Please ensure the repo_id and filename are correct, and your token (if private repo) is configured."
    ) from e
print(f"Model downloaded to: {downloaded_model_path}")

# Initialize the model architecture (dropout only acts in training mode, so its
# value does not affect inference or the weight shapes being loaded).
loaded_model_from_hub = Hybrid_CNN_Vit(dropout=0.3).to(device)

# Load the state dictionary. The checkpoint is a plain state dict downloaded from the
# network, so weights_only=True refuses to unpickle arbitrary Python objects.
try:
    state_dict = torch.load(downloaded_model_path, map_location=device, weights_only=True)
    loaded_model_from_hub.load_state_dict(state_dict)
except Exception as e:
    raise RuntimeError(f"Failed to load model state dictionary: {e}") from e
loaded_model_from_hub.eval()  # Set the model to evaluation mode
print("Model loaded successfully from Hugging Face Hub and set to evaluation mode.")

best_model.pth:   0%|          | 0.00/220k [00:00<?, ?B/s]

Model downloaded to: /root/.cache/huggingface/hub/models--Tomisin05--anime-ai-human-detector/snapshots/b6908e6e5c53b46f9aa80a2bfe01adfe55f2a11e/best_model.pth
Model loaded successfully from Hugging Face Hub and set to evaluation mode.


### Making Predictions with the Loaded Model

Now that the model is loaded, you can use it to make predictions on new images. The `predict_image` helper defined earlier classifies a single preprocessed image; the cells below demonstrate it on an image uploaded from your device.

### Loading the best model for reuse

To load the saved model, first create an instance of your `Hybrid_CNN_Vit` class with the same architecture parameters the checkpoint was trained with (`image_size`, `num_classes`, `cnn_channels`, `num_heads`, `num_layers`; all are left at their default values here). Dropout rates only act in training mode, so the `dropout=0.3` passed when loading does not affect inference. Then, load the saved state dictionary into this new model instance. Finally, set the model to evaluation mode with `model.eval()` before inference: this disables dropout (and would make batch-normalization layers use their running statistics, although this model has none).

### Uploading an image from your device and making a prediction

First, we'll use `google.colab.files` to let you upload an image file from your local machine. Make sure to upload an image file (e.g., `.jpg`, `.png`); if several files are uploaded, only the first one is classified.

In [None]:
from google.colab import files
from PIL import Image
import io

# Index -> human-readable label for the classifier's output logits.
class_names = {0: "Human-drawn", 1: "AI-generated"}

# Upload a file from your local machine
uploaded = files.upload()

# Only the first uploaded file is classified; warn instead of silently ignoring the rest.
uploaded_file_name = next(iter(uploaded), None)

if uploaded_file_name is None:
    print("No file uploaded. Please run the cell again and upload an image.")
    print("Cannot proceed with prediction without an uploaded file.")
else:
    print(f'User uploaded file "{uploaded_file_name}"')
    if len(uploaded) > 1:
        print(f"Note: {len(uploaded) - 1} additional file(s) uploaded; only the first is classified.")

    # Decode the uploaded bytes; convert("RGB") drops alpha / expands grayscale.
    img_bytes = uploaded[uploaded_file_name]
    uploaded_image = Image.open(io.BytesIO(img_bytes)).convert("RGB")

    # Apply the same deterministic preprocessing used for validation, then predict.
    transformed_uploaded_image = test_transform(uploaded_image)
    predicted_class, probabilities = predict_image(loaded_model_from_hub, transformed_uploaded_image)

    print(f"\nPrediction for uploaded image ('{uploaded_file_name}'):")
    print(f"Predicted Class: {class_names[predicted_class]}")
    print("Probabilities: " + ", ".join(
        f"{class_names[idx]}: {prob:.4f}" for idx, prob in enumerate(probabilities)
    ))

    # Show the preprocessed image the model actually saw; undo Normalize(mean=0.5, std=0.5).
    img_display = transformed_uploaded_image.cpu().numpy().transpose((1, 2, 0))
    img_display = np.clip(0.5 * img_display + 0.5, 0, 1)

    fig, ax = plt.subplots()
    ax.imshow(img_display)
    ax.set_title(f"Uploaded Image | Predicted: {class_names[predicted_class]}")
    ax.axis("off")
    plt.show()