In [1]:
import numpy as np
import pandas as pd
from os.path import join
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision import models
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split

# Dataset class
class AgeDataset(Dataset):
    """Image dataset described by a CSV annotation file.

    The annotation CSV must provide a ``file_id`` column holding image file
    names relative to ``data_path``. When ``train`` is true it must also
    provide an ``age`` column, and items are ``(image, age)`` pairs;
    otherwise items are the transformed image alone.
    """

    def __init__(self, data_path, annot_path, train=True):
        super().__init__()
        self.ann = pd.read_csv(annot_path)
        self.data_path = data_path
        self.train = train
        self.files = self.ann['file_id']
        # Labels are only looked up in training mode, so the column is only
        # required there.
        if train:
            self.ages = self.ann['age']
        self.transform = self._transform(224)

    @staticmethod
    def _convert_image_to_rgb(image):
        """Return ``image`` converted to RGB mode via its ``convert`` method."""
        return image.convert("RGB")

    def _transform(self, n_px):
        """Build the preprocessing pipeline: resize, RGB, tensor, normalize."""
        # Per-channel normalization constants; these are the values commonly
        # used with ImageNet-pretrained torchvision models.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        steps = [
            # NOTE(review): an int (not an (h, w) pair) is passed to Resize;
            # confirm all images share one aspect ratio so batches collate.
            Resize(n_px),
            self._convert_image_to_rgb,
            ToTensor(),
            Normalize(channel_mean, channel_std),
        ]
        return Compose(steps)

    def read_img(self, file_name):
        """Open ``file_name`` under ``data_path`` and return it transformed."""
        path = join(self.data_path, file_name)
        return self.transform(Image.open(path))

    def __getitem__(self, index):
        """Return ``(image, age)`` in training mode, else just ``image``."""
        img = self.read_img(self.files[index])
        if not self.train:
            return img
        return img, self.ages[index]

    def __len__(self):
        """Number of rows in the annotation file."""
        return len(self.files)

# Paths (the test annotation file doubles as the submission template)
train_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train'
train_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train.csv'
test_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/test'
test_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/submission.csv'

# Datasets: random 80/20 train/validation split of the labelled data,
# plus the unlabelled test set.
dataset = AgeDataset(train_path, train_ann, train=True)
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train
train_dataset, val_dataset = random_split(dataset, [n_train, n_val])
test_dataset = AgeDataset(test_path, test_ann, train=False)

# DataLoaders share every setting except shuffling, which is enabled for
# training only so validation/test order stays aligned with the CSV rows.
loader_kwargs = dict(batch_size=64, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Prediction function
@torch.no_grad()
def predict(loader, model, device):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Args:
        loader: Iterable of batches. Each batch is either an image tensor or
            a list/tuple whose first element is the image tensor (as yielded
            by a labelled loader); any further elements are ignored.
        model: Module to evaluate; it is switched to eval mode and left
            in that mode on return.
        device: Device the images are moved to before the forward pass.

    Returns:
        A flat Python list with every output value, in loader order.
    """
    # ``torch.no_grad`` is instantiated explicitly: the bare ``@torch.no_grad``
    # form is only accepted by recent PyTorch releases.
    model.eval()
    predictions = []
    for batch in tqdm(loader):
        # Labelled loaders yield (images, targets); keep only the images.
        img = batch[0] if isinstance(batch, (list, tuple)) else batch
        pred = model(img.to(device))
        predictions.extend(pred.cpu().flatten().tolist())
    return predictions

# Validation loop
def validation_loop(val_loader, model, criterion, device):
    """Return the dataset-averaged loss of ``model`` over ``val_loader``.

    The model is put in eval mode (and left there) and gradients are
    disabled for the pass. Each batch loss is weighted by its batch size and
    the total is divided by ``len(val_loader.dataset)``.

    NOTE(review): the weighting assumes ``criterion`` returns a per-batch
    mean — confirm if a different reduction is ever passed in.
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch_images, batch_ages in tqdm(val_loader):
            batch_images = batch_images.to(device)
            # Targets become float with a trailing dimension of size 1.
            targets = batch_ages.to(device).float().unsqueeze(1)
            batch_loss = criterion(model(batch_images), targets)
            total_loss += batch_loss.item() * batch_images.size(0)
    n_samples = len(val_loader.dataset)
    return total_loss / n_samples

# Training loop
def train_model(model, train_loader, val_loader, test_loader, criterion, optimizer, device, num_epochs=5,
                submission_template='/kaggle/input/smai-24-age-prediction/content/faces_dataset/submission.csv',
                output_dir='/kaggle/working'):
    """Train ``model`` and write a test-set submission file after every epoch.

    Each epoch runs one pass over ``train_loader``, evaluates on
    ``val_loader``, prints both losses, then predicts on ``test_loader`` and
    saves ``submission_epoch_<n>.csv`` in ``output_dir``.

    Args:
        model: Module to optimize (modified in place).
        train_loader: Loader yielding ``(images, ages)`` training batches.
        val_loader: Loader yielding ``(images, ages)`` validation batches.
        test_loader: Loader yielding image batches for the submission.
        criterion: Loss callable applied to ``(outputs, targets)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device batches are moved to.
        num_epochs: Number of epochs to run.
        submission_template: CSV whose rows line up with ``test_loader``
            order; its ``age`` column is overwritten with predictions.
        output_dir: Directory the per-epoch submission files are written to.

    Returns:
        A list with one ``(train_loss, val_loss)`` tuple per epoch.
    """
    # The template does not change between epochs, so read it once up front
    # (this also surfaces a bad path before any training time is spent).
    submit = pd.read_csv(submission_template)
    history = []

    for epoch in range(num_epochs):
        # predict()/validation_loop() leave the model in eval mode, so
        # training mode has to be restored at the start of every epoch.
        model.train()
        train_running_loss = 0.0
        for images, ages in tqdm(train_loader):
            images, ages = images.to(device), ages.to(device).float().unsqueeze(1)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, ages)
            loss.backward()
            optimizer.step()
            # Weight each batch loss by its size so the epoch figure is a
            # per-sample average even when the last batch is smaller.
            train_running_loss += loss.item() * images.size(0)

        epoch_loss = train_running_loss / len(train_loader.dataset)
        val_loss = validation_loop(val_loader, model, criterion, device)
        print(f"Epoch {epoch + 1}, Training Loss: {epoch_loss:.4f}, Validation Loss: {val_loss:.4f}")
        history.append((epoch_loss, val_loss))

        preds = predict(test_loader, model, device)
        submit['age'] = preds
        submit.to_csv(join(output_dir, f'submission_epoch_{epoch + 1}.csv'), index=False)

    return history

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model: ImageNet-pretrained ResNet-34 whose final layer is replaced by a
# single-output linear layer for age regression. The explicit ``weights``
# enum replaces the deprecated ``pretrained=True`` flag and selects the same
# checkpoint.
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)

# Freeze every parameter whose name contains "layer1"; everything else
# (including the new head) stays trainable.
for name, param in model.named_parameters():
    if "layer1" in name:
        param.requires_grad = False

model = model.to(device)

# Loss function and optimizer (only trainable parameters are handed to Adam;
# frozen ones never receive gradients and would be skipped anyway).
criterion = nn.MSELoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.0001)

# Train the model
train_model(model, train_loader, val_loader, test_loader, criterion, optimizer, device, num_epochs=40)

# Create submission CSV file
# NOTE(review): this uses the weights from the final epoch. If validation loss
# rises late in training, one of the per-epoch submissions written by
# train_model may be the better candidate.
preds = predict(test_loader, model, device)
submit = pd.read_csv(test_ann)
submit['age'] = preds
submit.to_csv('/kaggle/working/submission.csv', index=False)


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 90.3MB/s]
100%|██████████| 267/267 [00:41<00:00,  6.47it/s]
100%|██████████| 67/67 [00:09<00:00,  7.14it/s]


Epoch 1, Training Loss: 539.2432, Validation Loss: 275.6130


100%|██████████| 31/31 [00:04<00:00,  7.02it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:05<00:00, 11.39it/s]


Epoch 2, Training Loss: 113.7068, Validation Loss: 71.8692


100%|██████████| 31/31 [00:02<00:00, 10.89it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:06<00:00, 10.57it/s]


Epoch 3, Training Loss: 37.1229, Validation Loss: 52.3290


100%|██████████| 31/31 [00:02<00:00, 11.16it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:06<00:00, 10.68it/s]


Epoch 4, Training Loss: 25.4616, Validation Loss: 56.3968


100%|██████████| 31/31 [00:02<00:00, 10.66it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.20it/s]


Epoch 5, Training Loss: 17.3835, Validation Loss: 53.2633


100%|██████████| 31/31 [00:02<00:00, 10.55it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.30it/s]


Epoch 6, Training Loss: 12.9492, Validation Loss: 51.1041


100%|██████████| 31/31 [00:02<00:00, 10.41it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:06<00:00, 10.82it/s]


Epoch 7, Training Loss: 10.7487, Validation Loss: 54.4057


100%|██████████| 31/31 [00:03<00:00, 10.14it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.55it/s]


Epoch 8, Training Loss: 9.3706, Validation Loss: 51.8366


100%|██████████| 31/31 [00:02<00:00, 11.14it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.37it/s]


Epoch 9, Training Loss: 8.0108, Validation Loss: 61.4762


100%|██████████| 31/31 [00:02<00:00, 11.26it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:06<00:00, 10.47it/s]


Epoch 10, Training Loss: 7.4363, Validation Loss: 48.6431


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]
100%|██████████| 267/267 [00:39<00:00,  6.73it/s]
100%|██████████| 67/67 [00:06<00:00, 10.90it/s]


Epoch 11, Training Loss: 6.2144, Validation Loss: 49.5937


100%|██████████| 31/31 [00:02<00:00, 10.75it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.46it/s]


Epoch 12, Training Loss: 5.6485, Validation Loss: 47.8275


100%|██████████| 31/31 [00:03<00:00,  9.87it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:06<00:00, 10.81it/s]


Epoch 13, Training Loss: 4.8489, Validation Loss: 47.1076


100%|██████████| 31/31 [00:02<00:00, 10.51it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:06<00:00, 11.14it/s]


Epoch 14, Training Loss: 4.4900, Validation Loss: 46.9475


100%|██████████| 31/31 [00:02<00:00, 11.03it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.53it/s]


Epoch 15, Training Loss: 4.5048, Validation Loss: 46.5547


100%|██████████| 31/31 [00:02<00:00, 11.16it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:06<00:00,  9.96it/s]


Epoch 16, Training Loss: 4.3013, Validation Loss: 51.2486


100%|██████████| 31/31 [00:03<00:00, 10.05it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.53it/s]


Epoch 17, Training Loss: 4.0951, Validation Loss: 46.4079


100%|██████████| 31/31 [00:02<00:00, 10.70it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:06<00:00, 10.52it/s]


Epoch 18, Training Loss: 3.8350, Validation Loss: 46.4486


100%|██████████| 31/31 [00:02<00:00, 10.89it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:06<00:00, 10.58it/s]


Epoch 19, Training Loss: 3.6229, Validation Loss: 46.5452


100%|██████████| 31/31 [00:02<00:00, 10.59it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.52it/s]


Epoch 20, Training Loss: 3.3552, Validation Loss: 47.3383


100%|██████████| 31/31 [00:02<00:00, 10.95it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.25it/s]


Epoch 21, Training Loss: 3.2921, Validation Loss: 46.2328


100%|██████████| 31/31 [00:02<00:00, 10.51it/s]
100%|██████████| 267/267 [00:39<00:00,  6.73it/s]
100%|██████████| 67/67 [00:06<00:00, 10.67it/s]


Epoch 22, Training Loss: 3.2904, Validation Loss: 46.5107


100%|██████████| 31/31 [00:02<00:00, 10.34it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.46it/s]


Epoch 23, Training Loss: 3.1519, Validation Loss: 50.1269


100%|██████████| 31/31 [00:02<00:00, 10.96it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.57it/s]


Epoch 24, Training Loss: 3.4135, Validation Loss: 47.1188


100%|██████████| 31/31 [00:02<00:00, 10.90it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:06<00:00, 10.44it/s]


Epoch 25, Training Loss: 3.2868, Validation Loss: 47.3061


100%|██████████| 31/31 [00:03<00:00,  9.74it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:05<00:00, 11.40it/s]


Epoch 26, Training Loss: 3.8093, Validation Loss: 50.6216


100%|██████████| 31/31 [00:02<00:00, 11.10it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.43it/s]


Epoch 27, Training Loss: 3.3910, Validation Loss: 51.5660


100%|██████████| 31/31 [00:03<00:00,  9.41it/s]
100%|██████████| 267/267 [00:39<00:00,  6.73it/s]
100%|██████████| 67/67 [00:06<00:00, 10.75it/s]


Epoch 28, Training Loss: 2.9339, Validation Loss: 45.7873


100%|██████████| 31/31 [00:03<00:00, 10.31it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:06<00:00, 10.81it/s]


Epoch 29, Training Loss: 2.4113, Validation Loss: 45.0852


100%|██████████| 31/31 [00:02<00:00, 11.03it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.25it/s]


Epoch 30, Training Loss: 2.2392, Validation Loss: 45.1585


100%|██████████| 31/31 [00:03<00:00,  9.96it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:06<00:00, 10.02it/s]


Epoch 31, Training Loss: 2.1684, Validation Loss: 45.8086


100%|██████████| 31/31 [00:02<00:00, 10.55it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:06<00:00, 10.49it/s]


Epoch 32, Training Loss: 2.1722, Validation Loss: 46.1755


100%|██████████| 31/31 [00:02<00:00, 10.42it/s]
100%|██████████| 267/267 [00:39<00:00,  6.76it/s]
100%|██████████| 67/67 [00:05<00:00, 11.42it/s]


Epoch 33, Training Loss: 2.3109, Validation Loss: 46.1239


100%|██████████| 31/31 [00:02<00:00, 10.92it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:06<00:00, 10.26it/s]


Epoch 34, Training Loss: 2.1714, Validation Loss: 45.2755


100%|██████████| 31/31 [00:02<00:00, 10.37it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:06<00:00, 11.10it/s]


Epoch 35, Training Loss: 2.0935, Validation Loss: 46.4906


100%|██████████| 31/31 [00:02<00:00, 10.77it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.21it/s]


Epoch 36, Training Loss: 2.1381, Validation Loss: 47.3651


100%|██████████| 31/31 [00:02<00:00, 10.58it/s]
100%|██████████| 267/267 [00:39<00:00,  6.73it/s]
100%|██████████| 67/67 [00:06<00:00, 10.61it/s]


Epoch 37, Training Loss: 2.0884, Validation Loss: 46.2333


100%|██████████| 31/31 [00:02<00:00, 10.39it/s]
100%|██████████| 267/267 [00:39<00:00,  6.75it/s]
100%|██████████| 67/67 [00:05<00:00, 11.46it/s]


Epoch 38, Training Loss: 2.4905, Validation Loss: 47.6073


100%|██████████| 31/31 [00:03<00:00, 10.31it/s]
100%|██████████| 267/267 [00:39<00:00,  6.74it/s]
100%|██████████| 67/67 [00:05<00:00, 11.44it/s]


Epoch 39, Training Loss: 3.0177, Validation Loss: 49.0025


100%|██████████| 31/31 [00:02<00:00, 11.09it/s]
100%|██████████| 267/267 [00:39<00:00,  6.72it/s]
100%|██████████| 67/67 [00:06<00:00, 10.18it/s]


Epoch 40, Training Loss: 4.9771, Validation Loss: 53.3894


100%|██████████| 31/31 [00:03<00:00, 10.10it/s]
100%|██████████| 31/31 [00:02<00:00, 11.17it/s]
