In [None]:
import os
import shutil
import re
import pandas as pd
import numpy as np
from dotenv import load_dotenv
import nibabel as nib

load_dotenv()

# Data Organization

In [80]:
from IPython.display import clear_output
def printUpdate(count, total):
  clear_output(wait=True)
  percent = 100.*count/total
  print(f"{count}/{total} ({percent:.2f}%)", end="")
  print(f'[{"██"*int(percent//10)}{"--"*int((100-percent)//10)}]')



for i in range(10):
  printUpdate(i, 10)

9/10 (90.00%)[██████████████████--]


In [None]:
def moveRename(source, dest, num_files = 291):
    """
    Moves all .hdr and .img images from source location to a single directory
    """
    rawDirPat = r"(?:[\W\S]+?)OAS2_([0-9]{4})_MR([0-9]{1})/RAW"
    i = 0
    for root, dir, files in os.walk(source):
        r_match = re.findall(rawDirPat, root)
        if len(r_match) > 0:
            subID = r_match[0][0]
            session = r_match[0][1]
            new_name = f"{subID}_{session}"
            for f in files:
                fname, fext = os.path.splitext(f)
                if fext == ".img" or fext == ".hdr":
                    f_match = re.findall(r"mpr-([0-9]{1}).nifti", fname)
                    if len(f_match) > 0:
                        f_num = f_match[0]
                        old_name = os.path.join(root, f)
                        new_name = os.path.join(
                            dest, f"OAS2_{subID}_MR{f_num}_V{session}.nifti{fext}"
                        )

                        printUpdate(i, num_files)
                        i += 1
                        shutil.copy2(old_name, new_name)
    print()

In [82]:
def convertToNii(source, num_files = 291):
    """
    Convert images from .img to .nii format and get rid of .img and .hdr files.
    """

    i = 0
    for root, dir, files in os.walk(source):
        for f in files:
            fbase, fext = os.path.splitext(f)
            if fext == ".img":
                # print(f"Converting {f}")
                printUpdate(i, num_files)
                i += 1
                fname = os.path.join(root, f)
                img = nib.load(fname)
                nib.save(img, fname.replace(".img", ".nii"))
                os.remove(os.path.join(root, fbase + ".hdr"))
                os.remove(os.path.join(root, fbase + ".img"))
    print()

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy import misc

def toRGB(data):
    x, y = data.shape[:2]
    data = (data-data.min())/(data.max()-data.min())

    img_arr = np.empty(shape=(x,y,4))
    img_arr[:, :, :3] = data
    img_arr[:, :, 3] = 1.
    return img_arr

def convertToJPG(source, num_files = 291, transverse_slice = 116):
    """
    Convert images from .nii to .jpg format

    i = 0
    shape = (0,0,0,0)
    for root, dir, files in os.walk(source):
        for f in files:
            fbase, fext = os.path.splitext(f)
            if fext == ".nii":
                # print(f"Converting {f}")
                printUpdate(i, num_files)
                i += 1
                fname = os.path.join(root, f)
                img = nib.load(fname)
                data = img.get_fdata()[transverse_slice,:,:]

                img_arr = toRGB(data)
                shape = img_arr.shape
                plt.imsave(fname.replace(".nifti.nii", ".jpg"), img_arr)
                # os.remove(os.path.join(root, fbase + ".nii"))
    printUpdate(num_files, num_files)

    return shape


In [84]:
def removeVisits3to5(source):
    """
    Removes all images taken at visits 3-5
    """
    for root, dir, files in os.walk(source):
        for f in files:
            m = re.match(r"OAS2_[0-9]{4}_MR[0-9]{1}_V([0-9]{1})", f)
            session_num = int(m.groups()[0])
            if session_num > 2:
                print(f"Removing {f}")
                os.remove(os.path.join(root, f))

def removeSubjects(subjectIDs:list, source):
    for root, dir, files in os.walk(source):
        for f in files:
            for id in subjectIDs:
                m = re.match(re.compile("("+id+")"), f)
                if m!= None:
                    subjectIDs.remove(id)
                    os.remove(os.path.join(root, f))


In [85]:
# Only run once, to move data
# Set True to move files
move_files = False
convert_files = False

In [86]:
if move_files:
  moveRename("datasets/OAS2", "datasets/OAS2_nii", 2733)

In [87]:
if move_files:
  removeVisits3to5("datasets/OAS2_nii")


In [88]:
if move_files:
  convertToNii("datasets/OAS2_nii", 1366)

In [89]:
# Remove subject MRIs who had an age > 95
if move_files:
  removeSubjects(["OAS2_0051_MR3", "OAS2_0087_MR1", "OAS2_0087_MR2"], os.getenv("OAS2NII"))

# Load Data

In [90]:
def moveByClass(df, num_files = 291, removeOriginal=True):
  train_path = os.path.join(os.getenv("OAS2NII"), "train")
  test_path = os.path.join(os.getenv("OAS2NII"), "test")
  validate_path = os.path.join(os.getenv("OAS2NII"), "validate")

  data_paths = [train_path, test_path, validate_path]
  for path in data_paths:
    if not os.path.exists(path):
      os.makedirs(path)
      if not os.path.exists(os.path.join(path, "class_1")):
        os.makedirs(os.path.join(path, "class_1"))
      if not os.path.exists(os.path.join(path, "class_0")):
        os.makedirs(os.path.join(path, "class_0"))

  new_paths = []
  i = 0
  for index, row in df.iterrows():
    old_name = row["file"]
    fname = os.path.split(old_name)[1]
    if row["Group"] == 0:
      new_name = os.path.join(os.getenv("OAS2NII"), row["Split"], "class_0", fname)
    else:
      new_name = os.path.join(os.getenv("OAS2NII"), row["Split"], "class_1", fname)
    new_paths += [new_name]
    printUpdate(i, num_files)
    i += 1
    if os.path.exists(old_name):
      os.rename(old_name, new_name)
    print(f'Old: {old_name}\nNew:{new_name}\n')

  files = [f for f in os.listdir(os.getenv("OAS2NII")) if os.path.isfile(os.path.join(os.getenv("OAS2NII"), f))]
  if removeOriginal:
    for f in files:
      if os.path.exists(os.path.join(os.getenv("OAS2NII"), f)):
        os.remove(os.path.join(os.getenv("OAS2NII"),f))

  df["file"] = new_paths
  return df.copy(deep=True)

def moveBackToSource(df, deleteInstead=False):
  for root, dir, files in os.walk(os.getenv("OAS2NII")):
    if root!=os.getenv("OAS2NII"):
      for f in files:
        if deleteInstead:
          os.remove(os.path.join(root, f))
        else:
          os.rename(os.path.join(root, f), os.path.join(os.getenv("OAS2NII"), f))

  for index, row in df.iterrows():
    old_name = row["file"]
    fname = os.path.split(old_name)[-1]
    new_name = os.path.join(os.getenv("OAS2NII"), fname)
    df.loc[index, "file"] = new_name

  return df.copy(deep=True)


In [91]:
load_dotenv()
df = pd.read_excel("OAS2-normalized.xlsx")

In [92]:
fnames = []

df.set_index("MRI ID", inplace=True)

for index, row in df.iterrows():
  fname = index + "_V" + str(row["Visit"]+1) +  ".jpg"
  fnames += [os.path.join(os.getenv("OAS2NII"),fname)]

df["file"] = fnames


In [93]:
if convert_files:
  data_shape = convertToJPG("datasets/OAS2_nii", 1366, transverse_slice=134)
else:
  data_shape = (256, 128, 4)

In [None]:
from sklearn.model_selection import train_test_split

def makeTVTSplit(df):
  x_train, x_temp, y_train, y_temp = train_test_split(df.drop(columns=["Group"]), df["Group"], test_size=0.2, stratify=df[['Group',"Sex_F"]])

  strat = pd.DataFrame(x_temp)
  strat["Group"] = y_temp

  x_test, x_val, y_test, y_val = train_test_split(x_temp, y_temp, stratify=strat[['Group',"Sex_F"]],test_size=0.2)

  # df_new = moveBackToSource(df, deleteInstead=True)
  train = x_train.copy(deep=True)
  train["Split"] = ["train"]*train.shape[0]
  train["Group"] = y_train.values

  validate = x_val.copy(deep=True)
  validate["Split"] = ["validate"]*validate.shape[0]
  validate["Group"] = y_val.values

  test = x_test.copy(deep=True)
  test["Split"] = ["test"]*test.shape[0]
  test["Group"] = y_test.values

  print(f'\tTest: {len(test)}')
  print(f'\tTrain: {len(train)}')
  print(f'\tValidate: {len(validate)}')

  df_new = pd.merge(train, test, how="outer")
  df_new = pd.merge(df_new, validate, how="outer")

  # df_new = moveByClass(df_new, removeOriginal=True)

  return df_new


In [95]:
df = makeTVTSplit(df)

290/291 (99.66%)[██████████████████]
Old: datasets/OAS2_nii/OAS2_0186_MR2_V2.jpg
New:datasets/OAS2_nii/train/class_0/OAS2_0186_MR2_V2.jpg



# Keras

In [96]:
# Needed to add cuda to path for GPU utilization
# !source ~/.profile

In [97]:
import keras
from keras import layers
from keras import ops
import SimpleITK as sitk
from keras import Sequential
import tensorflow as tf


In [98]:
resnet_base = keras.applications.ResNet50(
  include_top=False,
  weights='imagenet',
  input_shape=(data_shape[0], data_shape[1], 3)
)

resnet_base.trainable = False

model = Sequential([
  resnet_base,
  layers.Flatten(),
  layers.Dense(384, activation='relu'),
  layers.Dense(2, activation='softmax')
])

model.compile(
  optimizer=keras.optimizers.Adam(),
  loss='sparse_categorical_crossentropy',
  metrics=['accuracy']
)

In [99]:
def makeKerasTVTDatasets(num_epochs, df, resplit=True):
  if resplit:
    df_new = makeTVTSplit(df)
  else:
    df_new = df.copy(deep=True)

  dset_train = keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(os.getenv("OAS2NII"),"train"),
    # validation_split=0.2,
    # subset="training",
    seed=73,
    image_size=data_shape[:2],
    batch_size=32,
    label_mode='binary')

  dset_validate = keras.preprocessing.image_dataset_from_directory(
    directory=os.path.join(os.getenv("OAS2NII"),"validate"),
    # validation_split=0.5,
    # subset="validation",
    seed=73,
    image_size=data_shape[:2],
    batch_size=32,
    label_mode='binary')

  return dset_train, dset_validate, df_new


In [100]:
dset_train, dset_validate, df = makeKerasTVTDatasets(50, df, resplit=True)

290/291 (99.66%)[██████████████████]
Old: datasets/OAS2_nii/train/class_0/OAS2_0186_MR2_V2.jpg
New:datasets/OAS2_nii/train/class_0/OAS2_0186_MR2_V2.jpg

Found 204 files belonging to 2 classes.
Found 20 files belonging to 2 classes.


In [101]:
history = model.fit(dset_train, validation_data=dset_validate, epochs=50)

# Save the model
model.save("oas2-model-t134.keras")

Epoch 1/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 518ms/step - accuracy: 0.5183 - loss: 69.7995





[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 1s/step - accuracy: 0.5179 - loss: 69.5989 - val_accuracy: 0.6000 - val_loss: 24.9426
Epoch 2/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5230 - loss: 17.4938 - val_accuracy: 0.4000 - val_loss: 12.9499
Epoch 3/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.5364 - loss: 5.9287 - val_accuracy: 0.6000 - val_loss: 5.9390
Epoch 4/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6747 - loss: 2.6385 - val_accuracy: 0.5500 - val_loss: 3.8717
Epoch 5/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8172 - loss: 1.0606 - val_accuracy: 0.6000 - val_loss: 2.8120
Epoch 6/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.8428 - loss: 0.7944 - val_accuracy: 0.6000 - val_loss: 2.2469
Epoch 7/50
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[

## Predictions

In [None]:
def getModelPredictions(model, df, columnName="CM_Val"):
  prediction = []
  df_test = df[df["Split"]=="test"].copy(deep=True)

  for index, row in df_test.iterrows():
    img = keras.preprocessing.image.load_img(row["file"], target_size=data_shape[:2])
    img_arr = keras.preprocessing.image.img_to_array(img)
    img_arr = tf.expand_dims(img_arr, 0)

    prediction += [np.argmax(model.predict(img_arr))]

  df_test.loc[:,"Prediction"] = prediction
  cm_val = []
  for index, row in df_test.iterrows():
    g = int(row["Group"])
    p = int(row["Prediction"])
    if g == p:
      if g == 0:
        cm_val += ["TN"]
      else:
        cm_val += ["TP"]
    else:
      if g == 0:
        cm_val += ["FP"]
      else:
        cm_val += ["FN"]

  df_test[columnName] = cm_val
  return df_test


In [103]:
def getCMV(prediction):
  cmv = prediction.value_counts()
  for m in ["TP", "TN", "FP", "FN"]:
    if m not in cmv.index:
      cmv[m] = 0
  return cmv

In [104]:
def getPrecision(cm_val:pd.Series):
  vals = cm_val.value_counts()
  for m in ["TP", "TN", "FP", "FN"]:
    if m not in vals.index:
      vals[m] = 0
  return float(vals["TP"])/(vals["TP"]+vals["FP"])

def getAccuracy(cm_val:pd.Series):
  vals = cm_val.value_counts()
  for m in ["TP", "TN", "FP", "FN"]:
    if m not in vals.index:
      vals[m] = 0
  return float(vals["TP"] + vals["TN"])/(vals["TP"]+vals["FP"]+vals["TN"]+vals["FN"])


def getRecall(cm_val:pd.Series):
  vals = cm_val.value_counts()
  for m in ["TP", "TN", "FP", "FN"]:
    if m not in vals.index:
      vals[m] = 0
  return float(vals["TP"])/(vals["TP"]+vals["TN"])


In [None]:
def printModelMetrics(df_test, modelName=""):
  width = 64
  liner = "-"*width
  print(liner)
  print(f'{"Precision:":<20}{getPrecision(df_test["CM_Val"]):.2f}')
  print(f'{"Precision, Male:":<20}{getPrecision(df_test[df_test["Sex_F"]==0]["CM_Val"]):.2f}')
  print(f'{"Precision, Female:":<20}{getPrecision(df_test[df_test["Sex_F"]==1]["CM_Val"]):.2f}')
  print()
  print(f'{"Accuracy:":<20}{getAccuracy(df_test["CM_Val"]):.2f}')
  print(f'{"Accuracy, Male:":<20}{getAccuracy(df_test[df_test["Sex_F"]==0]["CM_Val"]):.2f}')
  print(f'{"Accuracy, Female:":<20}{getAccuracy(df_test[df_test["Sex_F"]==1]["CM_Val"]):.2f}')
  print()
  print(f'{"Recall:":<20}{getRecall(df_test["CM_Val"]):.2f}')
  print(f'{"Recall, Male:":<20}{getRecall(df_test[df_test["Sex_F"]==0]["CM_Val"]):.2f}')
  print(f'{"Recall, Female:":<20}{getRecall(df_test[df_test["Sex_F"]==1]["CM_Val"]):.2f}')

  print(liner)


In [107]:
df_p = getModelPredictions(model, df)

FileNotFoundError: [Errno 2] No such file or directory: 'datasets/OAS2_nii/test/class_1/OAS2_0002_MR2_V2.jpg'

In [None]:
printModelMetrics(df_p)

----------------------------------------------------------------
Precision:          0.79
Precision, Male:    0.80
Precision, Female:  0.75

Accuracy:           0.74
Accuracy, Male:     0.74
Accuracy, Female:   0.75

Recall:             0.31
Recall, Male:       0.57
Recall, Female:     0.14
----------------------------------------------------------------
