In [1]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("jangedoo/utkface-new")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/jangedoo/utkface-new?dataset_version_number=1...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 331M/331M [00:05<00:00, 67.7MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/jangedoo/utkface-new/versions/1


In [2]:
# Import Libraries
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from transformers import ViTFeatureExtractor, ViTForImageClassification
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using Device:", device)

# Dataset Paths
train_path = "/root/.cache/kagglehub/datasets/jangedoo/utkface-new/versions/1/UTKFace"
val_path = "/root/.cache/kagglehub/datasets/jangedoo/utkface-new/versions/1/crop_part1"

# Define Transformer for Preprocessing with Data Augmentation
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Random cropping
    transforms.RandomHorizontalFlip(),  # Random flipping
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize for ViT
])

# Custom Dataset Class for UTKFace
class UTKFaceDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_files = os.listdir(root_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.root_dir, img_name)
        image = Image.open(img_path).convert("RGB")

        # Extract Age & Gender from Filename (Format: age_gender_race.jpg)
        age, gender, _ = img_name.split("_")[:3]
        age = int(age) / 100.0  # Normalize age (0-1)
        gender = int(gender)  # 0: Male, 1: Female

        if self.transform:
            image = self.transform(image)

        return image, torch.tensor([age, gender], dtype=torch.float32)

# Load Datasets
train_dataset = UTKFaceDataset(train_path, transform)
val_dataset = UTKFaceDataset(val_path, transform)

# Data Loaders
batch_size = 64  # Adjust batch size for memory efficiency
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Load Pretrained ViT Model
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=2  # Age & Gender Prediction
)
model.to(device)

# Define Loss Functions
criterion_age = nn.MSELoss()  # Age regression loss
criterion_gender = nn.BCEWithLogitsLoss()  # Gender classification loss

# Loss Weighting
alpha = 0.7  # Weight for age loss
beta = 0.3   # Weight for gender loss

# Combined Loss Function
def combined_loss(outputs, labels):
    age_output, gender_output = outputs[:, 0], outputs[:, 1]
    age_label, gender_label = labels[:, 0], labels[:, 1]

    loss_age = criterion_age(age_output, age_label)
    loss_gender = criterion_gender(gender_output, gender_label)

    return alpha * loss_age + beta * loss_gender

# Define Optimizer with Lower Learning Rate for Stability
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=1e-4)


Using Device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:

# Training Loop
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss, age_mae, gender_correct, total_samples = 0, 0, 0, 0

    for images, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images).logits

        # Compute loss
        loss = combined_loss(outputs, labels)
        loss.backward()
        optimizer.step()

        # Extract predictions
        age_output = outputs[:, 0] * 100  # De-normalize age
        gender_output = outputs[:, 1]
        age_label = labels[:, 0] * 100  # De-normalize age
        gender_label = labels[:, 1].float()

        # Compute Metrics
        age_mae += torch.abs(age_output - age_label).sum().item()
        gender_pred = (torch.sigmoid(gender_output) > 0.5).float()
        gender_correct += (gender_pred == gender_label).sum().item()
        total_loss += loss.item()
        total_samples += labels.size(0)

    # Compute epoch metrics
    epoch_loss = total_loss / len(train_loader)
    epoch_age_mae = age_mae / total_samples
    epoch_gender_accuracy = (gender_correct / total_samples) * 100

    print(f"Epoch {epoch+1}: Loss = {epoch_loss:.4f}, Age MAE = {epoch_age_mae:.2f}, Gender Acc = {epoch_gender_accuracy:.2f}%")

# Validation Phase
model.eval()
total_loss, age_mae, gender_correct, total_samples = 0, 0, 0, 0

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Validating"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images).logits

        # Compute loss
        loss = combined_loss(outputs, labels)
        total_loss += loss.item()

        # Extract predictions
        age_output = outputs[:, 0] * 100  # De-normalize age
        gender_output = outputs[:, 1]
        age_label = labels[:, 0] * 100  # De-normalize
        gender_label = labels[:, 1].float()

        # Compute Metrics
        age_mae += torch.abs(age_output - age_label).sum().item()
        gender_pred = (torch.sigmoid(gender_output) > 0.5).float()
        gender_correct += (gender_pred == gender_label).sum().item()
        total_samples += labels.size(0)

# Compute final validation results
val_loss = total_loss / len(val_loader)
val_age_mae = age_mae / total_samples
val_gender_accuracy = (gender_correct / total_samples) * 100

print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Age MAE: {val_age_mae:.2f}, Validation Gender Accuracy: {val_gender_accuracy:.2f}%")

# Save the model
torch.save(model.state_dict(), "vit_utkface.pth")
print("Model Saved Successfully! ‚úÖ")

Training Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:15<00:00,  2.14s/it]


Epoch 1: Loss = 0.1249, Age MAE = 8.81, Gender Acc = 86.35%


Training Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:26<00:00,  2.17s/it]


Epoch 2: Loss = 0.0779, Age MAE = 6.96, Gender Acc = 90.45%


Training Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:27<00:00,  2.18s/it]


Epoch 3: Loss = 0.0698, Age MAE = 6.69, Gender Acc = 91.30%


Training Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:26<00:00,  2.17s/it]


Epoch 4: Loss = 0.0657, Age MAE = 6.45, Gender Acc = 91.95%


Training Epoch 5: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:25<00:00,  2.17s/it]


Epoch 5: Loss = 0.0629, Age MAE = 6.28, Gender Acc = 92.24%


Training Epoch 6: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:25<00:00,  2.17s/it]


Epoch 6: Loss = 0.0600, Age MAE = 6.19, Gender Acc = 92.69%


Training Epoch 7: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:25<00:00,  2.17s/it]


Epoch 7: Loss = 0.0585, Age MAE = 6.05, Gender Acc = 92.82%


Training Epoch 8: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:24<00:00,  2.17s/it]


Epoch 8: Loss = 0.0562, Age MAE = 6.02, Gender Acc = 93.07%


Training Epoch 9: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:24<00:00,  2.17s/it]


Epoch 9: Loss = 0.0547, Age MAE = 5.99, Gender Acc = 93.36%


Training Epoch 10: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 371/371 [13:25<00:00,  2.17s/it]


Epoch 10: Loss = 0.0522, Age MAE = 5.98, Gender Acc = 93.79%


Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 153/153 [02:00<00:00,  1.27it/s]


Validation Loss: 0.0661
Validation Age MAE: 5.86, Validation Gender Accuracy: 90.78%
Model Saved Successfully! ‚úÖ


In [5]:
import torch
import numpy as np
import cv2
import os
from PIL import Image
from torchvision import transforms
from google.colab import files
from google.colab.output import eval_js
from IPython.display import display, Javascript
from base64 import b64decode

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using Device:", device)

# Load the saved model
model_path = "vit_utkface.pth"  # Path to the saved model
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=2  # Age & Gender Prediction
)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
print("‚úÖ Model loaded successfully!")

# Define the same transformations used during training
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to ViT input size
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize as used during training
])

# üîπ Function to preprocess image
def preprocess_image(img_path):
    """Loads and preprocesses an image for age & gender prediction."""
    img = Image.open(img_path).convert("RGB")
    img = transform(img).unsqueeze(0).to(device)  # Add batch dimension and move to device
    return img

# üîπ Function to predict age & gender
def predict_age_gender(img_path):
    """Predicts age & gender from an input image."""
    img_array = preprocess_image(img_path)

    with torch.no_grad():
        outputs = model(img_array).logits  # Perform inference

    # Extract predictions
    age_output = outputs[0, 0].item() * 100  # De-normalize age
    gender_output = outputs[0, 1].item()
    gender_pred = "Male" if gender_output < 0.5 else "Female"  # Gender prediction

    print(f"Predicted Age: {age_output:.2f}")
    print(f"Predicted Gender: {gender_pred}")

# üîπ Function to capture image from webcam
def take_photo():
    """Captures an image using Colab's JavaScript webcam interface."""
    js = Javascript('''
        async function takePhoto() {
            const video = document.createElement('video');
            const stream = await navigator.mediaDevices.getUserMedia({ video: true });
            document.body.appendChild(video);
            video.srcObject = stream;
            await new Promise((resolve) => (video.onloadedmetadata = resolve));
            video.play();

            return new Promise((resolve) => {
                document.addEventListener('keydown', async (event) => {
                    if (event.key === 'Enter') {
                        event.preventDefault();  // Prevent default Enter key behavior
                        const canvas = document.createElement('canvas');
                        canvas.width = video.videoWidth;
                        canvas.height = video.videoHeight;
                        canvas.getContext('2d').drawImage(video, 0, 0);
                        stream.getTracks().forEach(track => track.stop());
                        document.body.removeChild(video);
                        resolve(canvas.toDataURL('image/jpeg'));
                    }
                });
            });
        }
    ''')

    display(js)
    image_data = eval_js('takePhoto()')
    image_bytes = b64decode(image_data.split(',')[1])
    image_path = "/content/captured_image.jpg"
    with open(image_path, "wb") as f:
        f.write(image_bytes)
    print("üì∏ Image Captured!")
    return image_path

# üîπ Function to upload image
def upload_image():
    """Uploads an image and returns its filename."""
    uploaded = files.upload()
    for filename in uploaded.keys():
        print(f"üì∏ Uploaded image: {filename}")
        return filename
    return None

# üîπ Main function to choose input method and predict
def main():
    while True:
        choice = input("Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: ")
        if choice == '1':
            image_path = upload_image()
        elif choice == '2':
            image_path = take_photo()
        elif choice == '3':
            print("üö™ Exiting...")
            break
        else:
            print("‚ùå Invalid choice. Please enter 1, 2, or 3.")
            continue

        if image_path:
            predict_age_gender(image_path)
            os.remove(image_path)  # Cleanup after prediction

# üîπ Run the program
main()

Using Device: cuda


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load(model_path, map_location=device))


‚úÖ Model loaded successfully!
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 28.69
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 28.69
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 27.40
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 23.04
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 20.01
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 25.37
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 36.87
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 26.84
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 2


<IPython.core.display.Javascript object>

üì∏ Image Captured!
Predicted Age: 28.49
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 1


Saving sun.jpg to sun.jpg
üì∏ Uploaded image: sun.jpg
Predicted Age: 27.52
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 1


Saving ib.jpg to ib.jpg
üì∏ Uploaded image: ib.jpg
Predicted Age: 33.07
Predicted Gender: Male
Choose input method: (1) Upload Image (2) Capture Photo (3) Exit: 3
üö™ Exiting...
