In [57]:
import os
import pandas as pd

from pyment.models import RegressionSFCN


IMAGE_FOLDER = os.path.join(os.path.expanduser('~'), 'data', 'IXI', 'freesurfer+fsl')
LABELS_FILE = os.path.join(os.path.expanduser('~'), 'data', 'IXI', 'IXI.xls')
CSV_FILE = os.path.join(os.path.expanduser('~'), 'data', 'IXI', 'IXI.csv')

# Creates a simplified CSV with only image path and age to ease the interaction with tensorflow below
if not os.path.isfile(CSV_FILE):
    subjects = os.listdir(IMAGE_FOLDER)
    ids = {int(subject[3:6]): subject for subject in subjects}
    labels = pd.read_excel(LABELS_FILE)
    labels = labels[~pd.isna(labels['AGE'])]
    labels['age'] = labels['AGE']
    labels['id'] = labels['IXI_ID'].apply(lambda x: ids[x] if x in ids else None)
    labels = labels[~pd.isna(labels['id'])]
    labels['path'] = labels['id'].apply(lambda x: os.path.join(IMAGE_FOLDER, x, 'mri', 'cropped.nii.gz'))
    labels = labels[labels['path'].apply(lambda x: os.path.isfile(x))]
    labels = labels[['path', 'age']]
    labels.to_csv(CSV_FILE, index=False)
    
MODEL = RegressionSFCN
WEIGHTS = 'brain-age-2022'
# Min and max age should match that of the model (e.g. 3-95 for brain-age-2022),
# not that of the dataset
MIN_AGE = 3
MAX_AGE = 95

BATCH_SIZE = 4
NUM_THREADS = 8

model = MODEL(weights=WEIGHTS, prediction_range=(MIN_AGE, MAX_AGE))

In [58]:
import nibabel as nib
import numpy as np
import tensorflow as tf

from typing import Dict, Tuple


dataset = tf.data.experimental.make_csv_dataset(CSV_FILE, batch_size=BATCH_SIZE, shuffle=True, select_columns=['path', 'age'])

n_rows = len(pd.read_csv(CSV_FILE))
train_len = int(n_rows * 0.8)
train = dataset.take(train_len)
validation = dataset.skip(train_len)

def load_niftis(paths: str):
    images = np.asarray([nib.load(path.numpy().decode()).get_fdata() for path in paths])
    images = np.expand_dims(images, axis=-1)

    return images

def load_row(row: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor]:
    image = tf.py_function(load_niftis, [row['path']], [tf.float32])[0]

    return image, row['age']

def configure_nifti_dataset(dataset: tf.Tensor, shuffle: bool = False) -> tf.Tensor:
    dataset = dataset.map(load_row, num_parallel_calls=NUM_THREADS)
    dataset = dataset.shuffle(buffer_size=4 * BATCH_SIZE, reshuffle_each_iteration=True) \
              if shuffle else dataset
    dataset = dataset.prefetch(BATCH_SIZE)

    return dataset

train = configure_nifti_dataset(train, shuffle=True)
validation = configure_nifti_dataset(validation)

In [59]:
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['mae'])
model.fit(train, validation_data=validation)

  1/420 [..............................] - ETA: 4:23:58 - loss: 13.6904 - mae: 2.8566

KeyboardInterrupt: 