In [1]:
import numpy as np
import pandas as pd
from os.path import join
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip
from torchvision import models
import torch.optim as optim
from torch.utils.data import DataLoader

class AgeDataset(torch.utils.data.Dataset):
    """Face-image dataset for age regression, driven by an annotation CSV.

    The CSV is read with pandas and must have a ``file_id`` column (image file
    names, joined onto ``data_path``). When ``train`` is True it must also have
    an ``age`` column; items are then ``(image_tensor, age)``. Otherwise items
    are just ``image_tensor``.

    Args:
        data_path: Directory containing the image files.
        annot_path: Path of the annotation CSV.
        train: If True, return labels and apply a random horizontal flip.
    """

    def __init__(self, data_path, annot_path, train=True):
        super().__init__()
        self.data_path = data_path
        self.train = train
        self.ann = pd.read_csv(annot_path)
        self.files = self.ann['file_id']
        # Labels are only read for the training split.
        self.ages = self.ann['age'] if train else None
        self.transform = self._transform(224, train)

    @staticmethod
    def _convert_image_to_rgb(image):
        """Convert a PIL image to RGB mode (Normalize below expects 3 channels)."""
        return image.convert("RGB")

    def _transform(self, n_px, train):
        """Build the preprocessing pipeline.

        Resize to ``n_px``, force RGB, convert to a tensor and normalize with
        the standard ImageNet mean/std. Training pipelines additionally begin
        with a random horizontal flip.
        """
        # NOTE(review): Resize with a single int scales the shorter side only,
        # so batching assumes all images share an aspect ratio (e.g. square
        # face crops) — confirm for this dataset.
        transforms = [
            Resize(n_px),
            self._convert_image_to_rgb,
            ToTensor(),
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        if train:
            transforms.insert(0, RandomHorizontalFlip())
        return Compose(transforms)

    def read_img(self, file_name):
        """Load ``file_name`` from ``data_path`` and return the transformed tensor."""
        im_path = join(self.data_path, file_name)
        # The context manager closes the file handle deterministically; the
        # pipeline ends in ToTensor/Normalize, so the returned tensor does not
        # reference the closed image.
        with Image.open(im_path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        """Return ``(img, age)`` for the training split, else just ``img``."""
        # .iloc keeps lookups positional (consistent with __len__), independent
        # of whatever index labels the annotation frame carries.
        file_name = self.files.iloc[index]
        img = self.read_img(file_name)
        if self.train:
            return img, self.ages.iloc[index]
        return img

    def __len__(self):
        return len(self.files)

# Root of the competition data on the Kaggle input volume.
data_root = '/kaggle/input/smai-24-age-prediction/content/faces_dataset'

# Training split: image directory plus labelled annotations.
train_path = join(data_root, 'train')
train_ann = join(data_root, 'train.csv')
train_dataset = AgeDataset(train_path, train_ann, train=True)

# Test split: file ids come from the submission template; no labels used.
test_path = join(data_root, 'test')
test_ann = join(data_root, 'submission.csv')
test_dataset = AgeDataset(test_path, test_ann, train=False)

# Only the training batches are shuffled; test predictions must stay in
# submission-row order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=32)

# Prefer the GPU when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Pretrained ResNet-50 whose classification head is replaced by a single
# linear output, turning it into a one-value (age) regressor.
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=1)
model = model.to(device)

# Mean-squared-error regression objective optimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)

def train_model(model, train_loader, criterion, optimizer, num_epochs=40, device=None):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode.
        train_loader: DataLoader yielding ``(images, ages)`` batches; the length
            of its ``dataset`` is used to average the per-epoch loss.
        criterion: Loss comparing ``model(images)`` with the ages, which are
            cast to float and unsqueezed to a ``(batch, 1)`` column (assumes
            each batch of ages is 1-D).
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of epochs to run.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter, so no module-level ``device`` is needed.

    Returns:
        List with the sample-weighted mean training loss of each epoch.
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, ages in tqdm(train_loader):
            images = images.to(device)
            ages = ages.to(device).float().unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, ages)
            loss.backward()
            optimizer.step()

            # Weight by batch size so a short final batch averages correctly
            # (assumes the criterion averages over the batch).
            running_loss += loss.item() * images.size(0)
        epoch_loss = running_loss / len(train_loader.dataset)
        history.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    return history

# Fit the network using train_model's default number of epochs.
train_model(model=model, train_loader=train_loader, criterion=criterion, optimizer=optimizer)

@torch.no_grad()
def predict(loader, model, device=None):
    """Run ``model`` over ``loader`` and return one prediction per sample.

    Args:
        loader: DataLoader yielding image batches. Batches collated as
            ``(images, targets)`` lists/tuples (e.g. a labelled loader) are
            accepted too; the targets are ignored.
        model: Trained network; switched to eval mode and left there.
        device: Device for the inputs. Defaults to the device of the model's
            first parameter, so no module-level ``device`` is needed.

    Returns:
        Flat list of numpy scalars, in the order the loader yields samples.
    """
    if device is None:
        device = next(model.parameters()).device
    model.eval()
    predictions = []
    for batch in tqdm(loader):
        img = batch[0] if isinstance(batch, (list, tuple)) else batch
        img = img.to(device)
        pred = model(img)
        # No .detach() needed: outputs carry no grad under @torch.no_grad().
        predictions.extend(pred.flatten().cpu().numpy())
    return predictions

# Predict on the (unshuffled) test loader so rows line up with the template.
preds = predict(test_loader, model)

# Reload the submission template, fill its 'age' column with the predictions
# and write the result to the Kaggle working directory.
submit = pd.read_csv(test_ann).assign(age=preds)
submit.to_csv('/kaggle/working/submission.csv', index=False)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 169MB/s]
100%|██████████| 667/667 [05:10<00:00,  2.15it/s]


Epoch 1/40, Loss: 134.8246


100%|██████████| 667/667 [02:50<00:00,  3.92it/s]


Epoch 2/40, Loss: 76.6526


100%|██████████| 667/667 [02:48<00:00,  3.97it/s]


Epoch 3/40, Loss: 66.1359


100%|██████████| 667/667 [02:48<00:00,  3.96it/s]


Epoch 4/40, Loss: 60.7405


100%|██████████| 667/667 [02:47<00:00,  3.97it/s]


Epoch 5/40, Loss: 55.0343


100%|██████████| 667/667 [02:56<00:00,  3.78it/s]


Epoch 6/40, Loss: 50.1860


100%|██████████| 667/667 [02:48<00:00,  3.95it/s]


Epoch 7/40, Loss: 46.5915


100%|██████████| 667/667 [02:48<00:00,  3.95it/s]


Epoch 8/40, Loss: 42.6036


100%|██████████| 667/667 [02:48<00:00,  3.96it/s]


Epoch 9/40, Loss: 37.9126


100%|██████████| 667/667 [02:48<00:00,  3.95it/s]


Epoch 10/40, Loss: 34.7590


100%|██████████| 667/667 [02:48<00:00,  3.95it/s]


Epoch 11/40, Loss: 32.3691


100%|██████████| 667/667 [02:48<00:00,  3.97it/s]


Epoch 12/40, Loss: 29.9168


100%|██████████| 667/667 [02:48<00:00,  3.97it/s]


Epoch 13/40, Loss: 27.1000


100%|██████████| 667/667 [02:48<00:00,  3.97it/s]


Epoch 14/40, Loss: 25.5135


100%|██████████| 667/667 [02:48<00:00,  3.96it/s]


Epoch 15/40, Loss: 22.9679


100%|██████████| 667/667 [02:49<00:00,  3.94it/s]


Epoch 16/40, Loss: 20.5791


100%|██████████| 667/667 [02:48<00:00,  3.97it/s]


Epoch 17/40, Loss: 19.7139


100%|██████████| 667/667 [02:49<00:00,  3.95it/s]


Epoch 18/40, Loss: 18.0203


100%|██████████| 667/667 [02:50<00:00,  3.91it/s]


Epoch 19/40, Loss: 16.2756


100%|██████████| 667/667 [02:48<00:00,  3.96it/s]


Epoch 20/40, Loss: 14.9222


100%|██████████| 667/667 [02:47<00:00,  3.99it/s]


Epoch 21/40, Loss: 13.7116


100%|██████████| 667/667 [02:47<00:00,  3.99it/s]


Epoch 22/40, Loss: 13.3002


100%|██████████| 667/667 [02:47<00:00,  3.99it/s]


Epoch 23/40, Loss: 12.3528


100%|██████████| 667/667 [02:47<00:00,  3.98it/s]


Epoch 24/40, Loss: 11.3644


100%|██████████| 667/667 [02:48<00:00,  3.96it/s]


Epoch 25/40, Loss: 11.1597


100%|██████████| 667/667 [02:48<00:00,  3.96it/s]


Epoch 26/40, Loss: 10.9036


100%|██████████| 667/667 [02:47<00:00,  3.97it/s]


Epoch 27/40, Loss: 10.7838


100%|██████████| 667/667 [02:49<00:00,  3.95it/s]


Epoch 28/40, Loss: 10.1026


100%|██████████| 667/667 [02:50<00:00,  3.92it/s]


Epoch 29/40, Loss: 8.7055


100%|██████████| 667/667 [02:48<00:00,  3.95it/s]


Epoch 30/40, Loss: 8.2000


100%|██████████| 667/667 [02:49<00:00,  3.93it/s]


Epoch 31/40, Loss: 8.5891


100%|██████████| 667/667 [02:49<00:00,  3.92it/s]


Epoch 32/40, Loss: 8.9071


100%|██████████| 667/667 [02:50<00:00,  3.92it/s]


Epoch 33/40, Loss: 8.1000


100%|██████████| 667/667 [02:49<00:00,  3.94it/s]


Epoch 34/40, Loss: 7.3260


100%|██████████| 667/667 [02:49<00:00,  3.95it/s]


Epoch 35/40, Loss: 6.5990


100%|██████████| 667/667 [02:49<00:00,  3.92it/s]


Epoch 36/40, Loss: 6.9436


100%|██████████| 667/667 [02:49<00:00,  3.93it/s]


Epoch 37/40, Loss: 6.3946


100%|██████████| 667/667 [02:49<00:00,  3.94it/s]


Epoch 38/40, Loss: 6.4934


100%|██████████| 667/667 [02:49<00:00,  3.94it/s]


Epoch 39/40, Loss: 6.0386


100%|██████████| 667/667 [02:49<00:00,  3.93it/s]


Epoch 40/40, Loss: 6.1048


100%|██████████| 61/61 [00:19<00:00,  3.07it/s]
