In [None]:
!pip install transformers==4.25.1

In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
from google.colab import drive
import os
from PIL import Image, UnidentifiedImageError

**Data Preparation & Cleaning:**
- This step unzips the dataset and removes corrupted images.
- We apply data augmentation (random horizontal flips) to the training data, and normalization to both the training and evaluation data.

In [None]:
drive.mount('/content/drive')

zip_path = "/content/drive/MyDrive/catsanddogs.zip"
extract_dir = "./"

if not os.path.exists("./PetImages"):
    !unzip -q "{zip_path}" -d "{extract_dir}"
    print("Dataset unzipped.")
else:
    print("PetImages folder already exists, skipping unzip.")


In [None]:
def _passes_pillow_verify(path):
    """Return True when Pillow opens `path` and `verify()` raises nothing."""
    try:
        with Image.open(path) as handle:
            handle.verify()
    except (UnidentifiedImageError, IOError, SyntaxError):
        return False
    return True


def remove_corrupted_images(directory):
    """Delete every file under `directory` that fails Pillow verification.

    The tree is walked recursively with `os.walk`. Each file that cannot be
    opened and verified is reported, removed from disk and counted; the total
    is printed once the walk is finished. Nothing is returned.
    """
    removed_count = 0
    for folder, _subfolders, filenames in os.walk(directory):
        for name in filenames:
            file_path = os.path.join(folder, name)
            if _passes_pillow_verify(file_path):
                continue
            print(f"Removing corrupted file: {file_path}")
            os.remove(file_path)
            removed_count += 1
    print(f"Removed {removed_count} corrupted image files.")


remove_corrupted_images("./kagglecatsanddogs_5340/PetImages")

In [None]:
# Run on the GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# Root folder handed to ImageFolder below; it is the same path that the
# corrupted-image cleanup above was run on.
data_dir = "./kagglecatsanddogs_5340/PetImages"

In [None]:
def is_valid_image(filename):
    """Return True if `filename` ends with a supported image extension.

    The comparison is case-insensitive; accepted extensions are
    .jpg, .jpeg, .png and .bmp.
    """
    lowered = filename.lower()
    return any(lowered.endswith(ext) for ext in ('.jpg', '.jpeg', '.png', '.bmp'))

In [None]:
# Loaded for its image_mean / image_std, which the Normalize transforms below reuse.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

In [None]:
# Steps shared by both pipelines: resize to 224x224 and normalize with the
# mean / std taken from the feature extractor.
resize_to_input = transforms.Resize((224, 224))
normalize = transforms.Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Training pipeline: the shared steps plus a random horizontal flip.
train_transform = transforms.Compose(
    [resize_to_input, transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize]
)

# Validation / test pipeline: the shared steps only, no augmentation.
val_test_transform = transforms.Compose(
    [resize_to_input, transforms.ToTensor(), normalize]
)

In [None]:
# Index the images under data_dir, keeping only files accepted by
# is_valid_image; samples are loaded through the training transform.
image_folder_options = {
    "root": data_dir,
    "transform": train_transform,
    "is_valid_file": is_valid_image,
}
full_dataset = datasets.ImageFolder(**image_folder_options)
print("Total images found:", len(full_dataset))

**Dataset Splitting:**
- The data is divided into 70% training, 15% validation, and 15% test sets.

In [None]:
# Split sizes: 70% train, 15% validation, and whatever remains (~15%) for test,
# so the three sizes always add up to the full dataset.
dataset_size = len(full_dataset)
train_size = int(0.7 * dataset_size)
val_size   = int(0.15 * dataset_size)
test_size  = dataset_size - train_size - val_size

# Seed the split so that Restart & Run All reproduces the same partition;
# without a fixed generator the subsets change on every run.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_generator,
)
print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}, Test samples: {len(test_dataset)}")

In [None]:
# The subsets returned by `random_split` all wrap the SAME `full_dataset`
# object, so assigning `val_dataset.dataset.transform` would also replace the
# training transform and silently switch off augmentation. Instead, index the
# same files a second time with the evaluation transform and re-point the
# validation / test subsets (same indices) at that copy.
eval_dataset = datasets.ImageFolder(
    root=data_dir,
    transform=val_test_transform,
    is_valid_file=is_valid_image,
)
# The split indices are only meaningful if both datasets list the same samples
# in the same order.
assert eval_dataset.samples == full_dataset.samples, "eval dataset is not aligned with full_dataset"

val_dataset = torch.utils.data.Subset(eval_dataset, val_dataset.indices)
test_dataset = torch.utils.data.Subset(eval_dataset, test_dataset.indices)

In [None]:
# Settings shared by all three loaders; only the training loader shuffles.
batch_size = 64
loader_options = {"batch_size": batch_size, "num_workers": 2}

train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

In [None]:
# Report the size of each split again before training (same counts as printed
# right after the split).
print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}, Testing samples: {len(test_dataset)}")

**Model Setup & Training:**
- We fine-tune a pre-trained Vision Transformer (ViT) using mixed precision, which lets us handle a larger batch size of 64 efficiently on a 15GB GPU.
- We use AdamW as the optimizer with a learning rate of 5e-5, and cross-entropy as the loss function.

In [None]:
# Training hyperparameters and loss.
num_epochs = 5
criterion = nn.CrossEntropyLoss()

# Start from the pretrained ViT checkpoint with a 2-label classification head.
# `ignore_mismatched_sizes=True` lets the checkpoint load even where its weight
# shapes differ from this 2-label configuration.
pretrained_checkpoint = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(
    pretrained_checkpoint,
    num_labels=2,
    ignore_mismatched_sizes=True,
).to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)

In [None]:
print("Starting training...")

# Mixed precision through torch.cuda.amp only applies on CUDA. When the notebook
# falls back to the CPU, disable it explicitly so autocast and the GradScaler
# become plain pass-throughs.
use_amp = device.type == "cuda"
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def _run_epoch(loader, training):
    """Run one full pass over `loader` and return `(mean_loss, accuracy)`.

    Args:
        loader: DataLoader yielding `(images, labels)` batches.
        training: when True the model is put in train mode and the optimizer
            is stepped after every batch; when False the model is put in eval
            mode and gradients are disabled.

    Returns:
        Tuple of the sample-weighted mean loss and the fraction of correctly
        classified samples over the whole pass.
    """
    model.train(training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(images).logits
                loss = criterion(logits, labels)

            if training:
                # Scale the loss before backward, then step and update through
                # the scaler (all no-ops around a normal step when AMP is off).
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

            batch_count = labels.size(0)
            # `loss` is a batch mean; weight it by the batch size so the final
            # average is per sample even when the last batch is smaller.
            loss_sum += loss.item() * batch_count
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_count

    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(num_epochs):
    # Training phase, then validation phase.
    train_loss, train_acc = _run_epoch(train_loader, training=True)
    val_loss, val_acc = _run_epoch(val_loader, training=False)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

print("Training complete!")

**Training Metrics:**
- In epoch 1: Train Loss = 0.0302 (99.05% accuracy), Val Loss = 0.0166 (99.49% accuracy).
- In the remaining epochs, training and validation accuracy stayed almost perfect, with only minor fluctuations in loss.

In [None]:
# Persist the model's state_dict (its parameters) to a local file.
model_save_path = "vit_cats_dogs.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
# Final evaluation on the held-out test split.
model.eval()

# Autocast through torch.cuda.amp only applies on CUDA; disable it explicitly
# when the notebook is running on the CPU fallback.
use_amp = device.type == "cuda"

test_loss, test_correct, test_total = 0.0, 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images).logits
            loss = criterion(outputs, labels)

        batch_count = labels.size(0)
        # `loss` is a batch mean; weight by batch size for a per-sample average.
        test_loss += loss.item() * batch_count
        test_correct += (outputs.argmax(dim=1) == labels).sum().item()
        test_total += batch_count

test_loss /= test_total
test_acc  = test_correct / test_total
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

We obtained excellent accuracy on the test set as well. The high accuracy on the training, validation, and test sets means that our fine-tuned ViT model is highly effective at classifying cat vs. dog images, with excellent generalization.

The code below is the app.py file that I used to deploy the model on Hugging Face using Gradio.

In [None]:
import torch
import torchvision.transforms as transforms
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import gradio as gr

# Use the GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# The feature extractor supplies the mean / std used for normalization.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Inference preprocessing: resize, convert to tensor, normalize.
normalize = transforms.Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor(), normalize]
)

# Build the 2-label ViT, then overwrite its parameters with the fine-tuned
# weights stored in vit_cats_dogs.pth.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=2,
    ignore_mismatched_sizes=True,
)
fine_tuned_state = torch.load("vit_cats_dogs.pth", map_location=device)
model.load_state_dict(fine_tuned_state)
model.to(device)
model.eval()

def classify_image(image: Image.Image):
    """Classify an uploaded picture as a cat or a dog.

    Args:
        image: PIL image supplied by the Gradio image input, or None when the
            form is submitted without an upload.

    Returns:
        A `(label, probabilities)` tuple: `label` is "Cat" or "Dog", and
        `probabilities` is a dict with the softmax probability of each class.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio passes None for an empty image input; show a readable message
        # in the UI instead of failing on `None.convert`.
        raise gr.Error("Please upload an image first.")

    image = image.convert("RGB")
    img_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor).logits
        probs = torch.nn.functional.softmax(logits, dim=1)

    # Assumes class index 0 is "Cat" and 1 is "Dog" — TODO confirm against the
    # training dataset's class_to_idx mapping.
    prob_cat = probs[0][0].item()
    prob_dog = probs[0][1].item()
    label = "Cat" if prob_cat > prob_dog else "Dog"
    return label, {"Cat": prob_cat, "Dog": prob_dog}

# UI components: one image input, then a label and a JSON view for the two
# values returned by classify_image.
image_input = gr.Image(type="pil", label="Upload an image")
result_outputs = [
    gr.Label(label="Predicted Label"),
    gr.JSON(label="Probabilities"),
]

interface = gr.Interface(
    fn=classify_image,
    inputs=image_input,
    outputs=result_outputs,
    title="Cat vs. Dog Classifier",
    description="Upload an image of a cat or a dog and get the prediction.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()

**References**

- PyTorch: https://pytorch.org/docs/stable/index.html
- Dataset: https://www.microsoft.com/en-us/download/details.aspx?id=54765
- Torchvision: https://pytorch.org/vision/stable/index.html
- Transformers: https://huggingface.co/docs/transformers
- Gradio: https://www.gradio.app/docs
- Pillow: https://pillow.readthedocs.io/en/stable/
- Hugging Face Spaces: https://huggingface.co/docs/hub/spaces
- ViT pretrained model: https://huggingface.co/google/vit-base-patch16-224
- ViT paper: https://arxiv.org/pdf/2010.11929