In [1]:
import numpy as np
import pandas as pd
from glob import glob
from os.path import join
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision import models
import torch.optim as optim
from torch.utils.data import DataLoader

class AgeDataset(torch.utils.data.Dataset):
    """Face-image dataset for age regression.

    Reads an annotation CSV with a ``file_id`` column (plus an ``age`` column
    when ``train`` is True) and loads the matching images from ``data_path``.

    Args:
        data_path: Directory holding the image files listed in ``file_id``.
        annot_path: Path to the annotation CSV.
        train: If True, ``__getitem__`` returns ``(image, age)``; otherwise it
            returns only the image and the ``age`` column is never read.
        img_size: Side length in pixels of the square image tensors produced.
    """

    def __init__(self, data_path, annot_path, train=True, img_size=224):
        super().__init__()
        self.data_path = data_path
        self.train = train
        self.ann = pd.read_csv(annot_path)
        self.files = self.ann['file_id']
        # Ages are only looked up for training data.
        self.ages = self.ann['age'] if train else None
        self.transform = self._transform(img_size)

    @staticmethod
    def _convert_image_to_rgb(image):
        """Return ``image`` converted to 3-channel RGB."""
        return image.convert("RGB")

    def _transform(self, n_px):
        """Build the preprocessing pipeline yielding ``3 x n_px x n_px`` tensors."""
        return Compose([
            # Convert first: PIL always resamples "P"/"1" mode images with
            # nearest-neighbour, which degrades the resize.
            self._convert_image_to_rgb,
            # Resize the shorter side to n_px, then crop the centre so every
            # sample has the same spatial size and batches can be stacked.
            Resize(n_px),
            torchvision.transforms.CenterCrop(n_px),
            ToTensor(),
            # Standard ImageNet mean/std used by torchvision's pretrained models.
            Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def read_img(self, file_name):
        """Load ``file_name`` from ``data_path`` and return the transformed tensor."""
        im_path = join(self.data_path, file_name)
        # The context manager releases the file handle once pixels are decoded.
        with Image.open(im_path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        """Return ``(image, age)`` for training data, else just the image."""
        # Positional lookup so a re-indexed/subset annotation frame still works.
        file_name = self.files.iloc[index]
        img = self.read_img(file_name)
        if self.train:
            age = self.ages.iloc[index]
            return img, age
        return img

    def __len__(self):
        """Number of annotated samples."""
        return len(self.files)

# Kaggle input locations for the face images and their annotation CSVs.
train_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train'
train_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/train.csv'
test_path = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/test'
test_ann = '/kaggle/input/smai-24-age-prediction/content/faces_dataset/submission.csv'

# With train=False no ages are read; the test CSV is later reused as the
# submission template.
train_dataset = AgeDataset(data_path=train_path, annot_path=train_ann, train=True)
test_dataset = AgeDataset(data_path=test_path, annot_path=test_ann, train=False)

# Shuffle only for training; the test loader keeps CSV row order.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ResNet-18 with torchvision's ImageNet weights. The ``weights`` enum replaces
# the deprecated ``pretrained=True`` flag; IMAGENET1K_V1 is exactly the
# checkpoint that ``pretrained=True`` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# Replace the classification head with a single-output layer for regression.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)
model = model.to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def train_model(model, train_loader, criterion, optimizer, num_epochs=30, device=None):
    """Train ``model`` in place for ``num_epochs`` passes over ``train_loader``.

    Args:
        model: Network mapping an image batch to one prediction per sample.
        train_loader: DataLoader yielding ``(images, ages)`` batches.
        criterion: Loss applied as ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs: Number of full passes over the data.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function no
            longer depends on a module-level ``device`` global.

    Returns:
        List of per-epoch mean training losses, weighted by batch size.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    n_samples = len(train_loader.dataset)
    epoch_losses = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, ages in tqdm(train_loader):
            images = images.to(device)
            # Float column vector to match a (batch, 1) regression output.
            ages = ages.to(device).float().unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, ages)
            loss.backward()
            optimizer.step()

            # Re-weight by batch size so a smaller final batch is counted
            # correctly (assumes criterion averages over the batch).
            running_loss += loss.item() * images.size(0)
        epoch_loss = running_loss / n_samples
        epoch_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    return epoch_losses

# Fit the regressor using train_model's default number of epochs.
train_model(
    model=model,
    train_loader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
)

@torch.no_grad()
def predict(loader, model, device=None):
    """Run ``model`` over ``loader`` and return a flat list of predictions.

    Gradients are disabled by ``torch.no_grad``. The model is switched to
    eval mode and left there.

    Args:
        loader: DataLoader yielding image batches, or ``(images, targets)``
            batches whose targets are ignored.
        model: Network to evaluate.
        device: Device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), so the function no
            longer depends on a module-level ``device`` global.

    Returns:
        List with one numpy scalar per sample, in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    predictions = []
    for batch in tqdm(loader):
        # Default collate turns (image, target) samples into a list; accept
        # those too so labelled loaders can be evaluated.
        img = batch[0] if isinstance(batch, (tuple, list)) else batch
        img = img.to(device)
        pred = model(img)
        # No .detach() needed: no autograd graph is built under no_grad.
        predictions.extend(pred.flatten().cpu().numpy())
    return predictions

# Predict ages for the test images. test_loader does not shuffle, so the
# prediction order matches the CSV rows.
preds = predict(loader=test_loader, model=model)

# Fill the 'age' column of a fresh copy of the submission template and save it.
submit = pd.read_csv(test_ann).assign(age=preds)
submit.to_csv('/kaggle/working/submission.csv', index=False)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 142MB/s]
100%|██████████| 334/334 [02:42<00:00,  2.05it/s]


Epoch 1/30, Loss: 155.6160


100%|██████████| 334/334 [01:40<00:00,  3.32it/s]


Epoch 2/30, Loss: 61.5524


100%|██████████| 334/334 [01:38<00:00,  3.39it/s]


Epoch 3/30, Loss: 51.5176


100%|██████████| 334/334 [01:38<00:00,  3.40it/s]


Epoch 4/30, Loss: 41.5515


100%|██████████| 334/334 [01:40<00:00,  3.34it/s]


Epoch 5/30, Loss: 34.6815


100%|██████████| 334/334 [01:40<00:00,  3.31it/s]


Epoch 6/30, Loss: 29.3736


100%|██████████| 334/334 [01:40<00:00,  3.31it/s]


Epoch 7/30, Loss: 24.5885


100%|██████████| 334/334 [01:41<00:00,  3.29it/s]


Epoch 8/30, Loss: 18.8987


100%|██████████| 334/334 [01:38<00:00,  3.38it/s]


Epoch 9/30, Loss: 15.7060


100%|██████████| 334/334 [01:38<00:00,  3.39it/s]


Epoch 10/30, Loss: 12.5119


100%|██████████| 334/334 [01:41<00:00,  3.29it/s]


Epoch 11/30, Loss: 10.3566


100%|██████████| 334/334 [01:40<00:00,  3.32it/s]


Epoch 12/30, Loss: 9.1963


100%|██████████| 334/334 [01:40<00:00,  3.34it/s]


Epoch 13/30, Loss: 8.5416


100%|██████████| 334/334 [01:41<00:00,  3.29it/s]


Epoch 14/30, Loss: 7.4713


100%|██████████| 334/334 [01:41<00:00,  3.30it/s]


Epoch 15/30, Loss: 6.9215


100%|██████████| 334/334 [01:42<00:00,  3.25it/s]


Epoch 16/30, Loss: 6.5279


100%|██████████| 334/334 [01:42<00:00,  3.27it/s]


Epoch 17/30, Loss: 6.2593


100%|██████████| 334/334 [01:38<00:00,  3.38it/s]


Epoch 18/30, Loss: 5.6176


100%|██████████| 334/334 [01:38<00:00,  3.38it/s]


Epoch 19/30, Loss: 5.4627


100%|██████████| 334/334 [01:41<00:00,  3.30it/s]


Epoch 20/30, Loss: 5.5008


100%|██████████| 334/334 [01:40<00:00,  3.32it/s]


Epoch 21/30, Loss: 5.0898


100%|██████████| 334/334 [01:39<00:00,  3.35it/s]


Epoch 22/30, Loss: 4.9166


100%|██████████| 334/334 [01:40<00:00,  3.34it/s]


Epoch 23/30, Loss: 6.0575


100%|██████████| 334/334 [01:42<00:00,  3.27it/s]


Epoch 24/30, Loss: 5.3520


100%|██████████| 334/334 [01:38<00:00,  3.38it/s]


Epoch 25/30, Loss: 4.2426


100%|██████████| 334/334 [01:38<00:00,  3.40it/s]


Epoch 26/30, Loss: 3.7723


100%|██████████| 334/334 [01:37<00:00,  3.42it/s]


Epoch 27/30, Loss: 3.4209


100%|██████████| 334/334 [01:38<00:00,  3.38it/s]


Epoch 28/30, Loss: 3.3434


100%|██████████| 334/334 [01:39<00:00,  3.35it/s]


Epoch 29/30, Loss: 4.0559


100%|██████████| 334/334 [01:38<00:00,  3.38it/s]


Epoch 30/30, Loss: 3.6187


100%|██████████| 31/31 [00:13<00:00,  2.30it/s]
