# Import Library

In [1]:
from transformers import ViTModel, ViTFeatureExtractor
from transformers import ViTFeatureExtractor

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
import torch
from tqdm import tqdm
from sklearn.metrics import mean_absolute_error, r2_score
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
from scipy.stats import pearsonr
import joblib

In [2]:
if torch.cuda.is_available():
    device = "cuda"

elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} as device.")

Using mps as device.


# Function Library

In [3]:
def train_model(model, train_loader, val_loader, y_scaler, epochs=5):
    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for imgs, _, idxs in tqdm(train_loader):
            imgs = imgs.to(device)
            labels = torch.tensor([bmis_scaled[idx] for idx in idxs], dtype=torch.float32).unsqueeze(1).to(device)

            preds = model(imgs)
            loss = criterion(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        print(f"\nEpoch {epoch+1}: Train Loss = {train_loss/len(train_loader):.4f}")

        # Validation
        model.eval()
        val_preds, val_labels = [], []

        with torch.no_grad():
            for imgs, _, idxs in val_loader:
                imgs = imgs.to(device)
                labels = torch.tensor(bmis_scaled[idxs], dtype=torch.float32).unsqueeze(1).to(device)

                preds = model(imgs)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_preds_real = y_scaler.inverse_transform(np.array(val_preds).reshape(-1, 1)).ravel()
        val_labels_real = y_scaler.inverse_transform(np.array(val_labels).reshape(-1, 1)).ravel()

        mae = mean_absolute_error(val_labels_real, val_preds_real)
        r2 = r2_score(val_labels_real, val_preds_real)
        r, _ = pearsonr(val_labels_real, val_preds_real)

        print(f"Val MAE: {mae:.2f} | Val R²: {r2:.3f} | Pearson r: {r:.3f}")

        # Step the scheduler
        scheduler.step(mae)

In [4]:
class ViTBMIRegressor(nn.Module):
    def __init__(self):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        self.regressor = nn.Sequential(
            nn.Linear(self.vit.config.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        outputs = self.vit(pixel_values=x)
        cls_token = outputs.last_hidden_state[:, 0]
        return self.regressor(cls_token)

In [5]:
class BMIDataset(Dataset):
    def __init__(self, csv_path, image_dir, transform=None):
        self.df = pd.read_csv(csv_path)
        self.image_dir = image_dir
        self.transform = transform
        
        # Remove any rows without image files
        self.df['full_path'] = self.df['name'].apply(lambda x: os.path.join(image_dir, x))
        self.df = self.df[self.df['full_path'].apply(os.path.exists)].reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = Image.open(row['full_path']).convert('RGB')

        if self.transform:
            img = self.transform(img)

        label = row['bmi']
        return img, label, idx 

# Data Loading

In [6]:
# Basic transform
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)  # Normalized to [-1, 1]
])

# Dataset
dataset = BMIDataset(
    csv_path='../Data/data.csv',
    image_dir='../Data/Images',
    transform=img_transform
)

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [9]:
# Load BMI column and scale
df = pd.read_csv('../Data/data.csv')
bmis = df['bmi'].values
y_scaler = StandardScaler()
bmis_scaled = y_scaler.fit_transform(bmis.reshape(-1, 1)).ravel()

joblib.dump(y_scaler, "y_scaler.pkl")

['y_scaler.pkl']

In [10]:
extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

vit_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=extractor.image_mean, std=extractor.image_std)
])



In [11]:
model = ViTBMIRegressor().to(device)

In [12]:
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

# Optional learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.5)

In [13]:
train_model(model, train_loader, val_loader, y_scaler, epochs=10)

100%|██████████| 100/100 [02:28<00:00,  1.48s/it]



Epoch 1: Train Loss = 0.9870
Val MAE: 6.22 | Val R²: -0.003 | Pearson r: -0.015


100%|██████████| 100/100 [01:58<00:00,  1.19s/it]



Epoch 2: Train Loss = 0.9602
Val MAE: 6.22 | Val R²: -0.006 | Pearson r: -0.025


100%|██████████| 100/100 [01:53<00:00,  1.14s/it]



Epoch 3: Train Loss = 0.9357
Val MAE: 6.22 | Val R²: -0.011 | Pearson r: -0.024


100%|██████████| 100/100 [01:31<00:00,  1.10it/s]



Epoch 4: Train Loss = 0.8969
Val MAE: 6.26 | Val R²: -0.024 | Pearson r: -0.026


100%|██████████| 100/100 [01:36<00:00,  1.03it/s]



Epoch 5: Train Loss = 0.8195
Val MAE: 6.26 | Val R²: -0.050 | Pearson r: -0.019


100%|██████████| 100/100 [01:39<00:00,  1.01it/s]



Epoch 6: Train Loss = 0.6986
Val MAE: 6.43 | Val R²: -0.104 | Pearson r: -0.013


100%|██████████| 100/100 [02:05<00:00,  1.26s/it]



Epoch 7: Train Loss = 0.5187
Val MAE: 6.62 | Val R²: -0.159 | Pearson r: -0.000


100%|██████████| 100/100 [02:31<00:00,  1.52s/it]



Epoch 8: Train Loss = 0.4100
Val MAE: 6.77 | Val R²: -0.222 | Pearson r: 0.006


100%|██████████| 100/100 [02:24<00:00,  1.44s/it]



Epoch 9: Train Loss = 0.3227
Val MAE: 7.02 | Val R²: -0.291 | Pearson r: 0.002


100%|██████████| 100/100 [02:29<00:00,  1.50s/it]



Epoch 10: Train Loss = 0.2598
Val MAE: 6.97 | Val R²: -0.301 | Pearson r: 0.005


In [14]:
torch.save(model.state_dict(), "vit_model_final.pt")