In [26]:
import os

from get_config import get_config_dict

from pathlib import Path

import tifffile as tiff

from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.saving import load_model

import numpy as np
import pandas as pd

# Project settings; later cells read "preprocessed_data_path", "model_path" and "pred_path" from it
config = get_config_dict()

### Load the data


In [27]:
# Load the preprocessed slice images (no labels) for every subject and scan


def load_data(data_dir: Path):
    """Load every preprocessed TIFF slice under ``data_dir``, grouped by subject and scan.

    Layout implied by the glob patterns below::

        data_dir/<subject>/**/I*/**/*.tiff

    Each directory directly under ``data_dir`` is a subject. Each directory below a
    subject whose name starts with ``I`` is one scan. Each ``*.tiff`` file under a
    scan is one slice. All directories and files are visited in sorted order, so
    the result is reproducible across filesystems. Stray files at the subject or
    scan level are ignored.

    Args:
        data_dir: Root directory of the preprocessed dataset (a ``str`` is also accepted).

    Returns:
        ``(X, ids)`` where
          - ``X[s][k][i]`` is slice ``i`` of scan ``k`` of subject ``s``. It is
            rescaled with ``(v + 1) / 2`` and repeated to 3 channels.
          - ``ids[subject_id][scan_id]`` is the list of slice file stems, in the
            same order as the matching entries of ``X``.

    Raises:
        ValueError: if two subject directories, or two scan directories of one
            subject, share the same stem. That would otherwise silently
            desynchronise ``X`` and ``ids``.
    """
    data_dir = Path(data_dir)
    X = []
    ids = {}
    for subject_path in sorted(p for p in data_dir.glob('*') if p.is_dir()):
        subject_id = subject_path.stem
        if subject_id in ids:
            raise ValueError(f"duplicate subject id {subject_id!r} in {data_dir}")
        scan_ids = {}
        subject_scans = []

        for scan_path in sorted(p for p in subject_path.glob('**/I*') if p.is_dir()):
            scan_id = scan_path.stem
            if scan_id in scan_ids:
                raise ValueError(f"duplicate scan id {scan_id!r} under {subject_path}")
            scan_imgs = []
            slice_names = []
            for img_path in sorted(scan_path.glob('**/*.tiff')):
                img_array = img_to_array(tiff.imread(img_path))

                # shift the range from (-1, 1) to (0, 1)
                # NOTE(review): assumes the TIFFs are stored in [-1, 1] — confirm
                img_array = (img_array + 1) / 2

                # convert to a 3-channel image
                # (assumes a single-channel slice after img_to_array — TODO confirm)
                img_array = np.repeat(img_array, 3, axis=-1)

                scan_imgs.append(img_array)
                slice_names.append(img_path.stem)

            scan_ids[scan_id] = slice_names
            subject_scans.append(scan_imgs)

        ids[subject_id] = scan_ids
        X.append(subject_scans)
    return X, ids


In [28]:
preprocessed_data_path = config["preprocessed_data_path"]

# Root directory of the extracted, preprocessed dataset
data_dir = Path(preprocessed_data_path)
X, ids = load_data(data_dir)

# Collapse each subject's nested scans/slices lists into a single ndarray
# (scans x slices x H x W x C). NumPy can only do this when all scans of a
# subject have the same number of slices.
for subject_pos, subject_id in enumerate(ids):
    X[subject_pos] = np.array(X[subject_pos])
    print(f"Shape of {subject_id} images:", X[subject_pos].shape)
    

Shape of 002_S_0413 images: (4, 30, 180, 180, 3)
Shape of 005_S_0221 images: (1, 30, 180, 180, 3)


### Inference

In [29]:
# Class names; prediction column i is reported under classes[i].
# NOTE(review): this order must match the label encoding used at training time — confirm.
classes = [
    'MCI',
    'AD',
    'CN',
]
classes

['MCI', 'AD', 'CN']

In [30]:
# Restore the trained Keras model from the configured location
model_path = config['model_path']
model = load_model(filepath=model_path)

In [31]:
# Inference: run the model on every scan and keep the raw per-slice class
# scores (no argmax), so they can be averaged per scan / subject later on.
y_pred = [
    np.array([model.predict(scan_imgs, verbose=0) for scan_imgs in subject_imgs])
    for subject_imgs in X
]

for subject_id, subject_preds in zip(ids, y_pred):
    print(f"Shape of {subject_id} preds:", subject_preds.shape)

Shape of 002_S_0413 preds: (4, 30, 3)
Shape of 005_S_0221 preds: (1, 30, 3)


In [32]:
# Index -> class-name lookup, derived from `classes` rather than a hardcoded count
class_mapper = dict(enumerate(classes))

# One record per slice: its identifiers plus one score column per class.
# `ids` and `y_pred` share the same subject/scan/slice order (both are built
# by walking the loaded data in the same sequence).
records = []
for subject_id, subject_preds in zip(ids.keys(), y_pred):
    for scan_id, scan_preds in zip(ids[subject_id].keys(), subject_preds):
        for slice_name, slice_pred in zip(ids[subject_id][scan_id], scan_preds):
            records.append({
                "subject_id": subject_id,
                "scan_id": scan_id,
                "slice_name": slice_name,
                **dict(zip(classes, slice_pred)),
            })

# Explicit columns keep the schema (and set_index) valid even when no slices were loaded
df = pd.DataFrame(records, columns=["subject_id", "scan_id", "slice_name", *classes]) \
    .set_index(['subject_id', 'scan_id'])

df

Unnamed: 0_level_0,Unnamed: 1_level_0,slice_name,MCI,AD,CN
subject_id,scan_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
002_S_0413,I120917,ADNI_002_S_0413_MR_MPR__GradWarp__B1_Correctio...,5.611733e-05,0.000025,9.999189e-01
002_S_0413,I120917,ADNI_002_S_0413_MR_MPR__GradWarp__B1_Correctio...,2.274170e-04,0.011541,9.882320e-01
002_S_0413,I120917,ADNI_002_S_0413_MR_MPR__GradWarp__B1_Correctio...,1.261345e-05,0.000004,9.999835e-01
002_S_0413,I120917,ADNI_002_S_0413_MR_MPR__GradWarp__B1_Correctio...,3.925137e-05,0.001247,9.987140e-01
002_S_0413,I120917,ADNI_002_S_0413_MR_MPR__GradWarp__B1_Correctio...,1.089373e-05,0.011412,9.885775e-01
...,...,...,...,...,...
005_S_0221,I102054,ADNI_005_S_0221_MR_MPR__GradWarp__B1_Correctio...,5.100378e-06,0.999975,1.951559e-05
005_S_0221,I102054,ADNI_005_S_0221_MR_MPR__GradWarp__B1_Correctio...,4.096058e-06,0.999996,5.617907e-08
005_S_0221,I102054,ADNI_005_S_0221_MR_MPR__GradWarp__B1_Correctio...,6.112429e-09,0.994901,5.099166e-03
005_S_0221,I102054,ADNI_005_S_0221_MR_MPR__GradWarp__B1_Correctio...,6.548838e-06,0.999965,2.852597e-05


### Saving the predictions

In [33]:
# Output directory for the JSON prediction files. Wrapped in Path(), as is done
# for the data directory, so a plain-string config value works too. With
# exist_ok=True, mkdir already tolerates an existing directory, so no separate
# exists() check is needed.
pred_path = Path(config['pred_path'])
pred_path.mkdir(parents=True, exist_ok=True)

In [34]:
# Save slice-level predictions to `slice_predictions.json`, nested as
# {subject_id: {scan_id: {slice_name: {class: score}}}}.
slice_pred_path = pred_path.joinpath('slice_predictions.json')


def _scan_slices_to_dict(subject_frame):
    """Map scan_id -> {slice_name: {class: score}} for one subject's rows."""
    per_scan = (
        subject_frame
        .droplevel('subject_id')
        .groupby('scan_id')
        .apply(lambda scan_frame: scan_frame.droplevel('scan_id').to_dict(orient='index'))
    )
    return per_scan.to_dict()


slice_level_preds = df.set_index('slice_name', append=True)
(
    slice_level_preds
    .groupby('subject_id')
    .apply(_scan_slices_to_dict)
    .to_json(slice_pred_path, orient='index', indent=4)
)

In [35]:
# Save scan-level predictions to `scan_predictions.json`: the per-slice class
# scores averaged over all slices of each scan, nested as
# {subject_id: {scan_id: {class: mean score}}}.
scan_pred_path = pred_path.joinpath('scan_predictions.json')

# Select score columns via the shared `classes` list so they stay in sync with the model's classes
scan_level_preds = df[classes].groupby(['subject_id', 'scan_id']).aggregate('mean')
scan_level_preds.groupby(level=0)\
    .apply(lambda subject_frame:
        subject_frame\
        .droplevel(0)\
        .to_dict(orient='index')
    )\
    .to_json(scan_pred_path, orient='index', indent=4)

In [36]:
# Save subject-level predictions to `subject_predictions.json`.
# Each subject's score is the mean of its scan-level means, so every scan
# carries equal weight regardless of how many slices it contains.
subject_pred_path = pred_path.joinpath('subject_predictions.json')

subject_level_preds = scan_level_preds.groupby(level='subject_id').mean()
subject_level_preds.to_json(subject_pred_path, orient='index', indent=4)

In [37]:
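# NOTE(review): disabled draft — it uses `plt`, `y` and `j`, none of which are defined in the cells above; fix or delete this cell.
# # Plot some sample images from each class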
# num_images_to_plot = 4
# # random_scan = 
# plt.figure(figsize=(15, 1 * num_images_to_plot))

# for i in range(num_images_to_plot):
#     plt.subplot(num_images_to_plot, i)
#     cls_indices = np.where(y == j)[0]
#     plt.imshow(X[cls_indices[i]])
#     plt.axis('off')
#     if i == 0:
#         plt.title(classes[j])

# plt.tight_layout()
# plt.show()