## Convert test data to TFRecords

In [None]:
import os
import random
import subprocess
import sys
from glob import glob
from operator import itemgetter
from pathlib import Path
from shutil import *

import nibabel as nb
import nobrainer
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import metrics
from tensorflow.keras.optimizers import Adam

# Make the repository root importable so the local `models` and
# `dataloaders` packages used by the cells below can be found.
sys.path.append("..")
from dataloaders.dataset import get_dataset
from models import modelN
from models.modelN import CombinedClassifier

test_root_dir = "/tf/shank/HDDLinux/Stanford/data/mriqc-shared/test_ixi"
csv_path = os.path.join(test_root_dir, "csv")
tf_records_dir = os.path.join(test_root_dir, "tfrecords")

os.makedirs(tf_records_dir, exist_ok=True)

# testing.csv provides the paths (column "X") and labels (column "Y").
# Parse it once and take both columns from the same DataFrame.
test_csv_path = os.path.join(csv_path, "testing.csv")
test_df = pd.read_csv(test_csv_path)
test_paths = test_df["X"].values
test_labels = test_df["Y"].values

# (path, label) pairs, the form handed to nobrainer.tfrecord.write below.
test_D = list(zip(test_paths, test_labels))
test_write_path = os.path.join(tf_records_dir, 'data-test_shard-{shard:03d}.tfrec')

# Write sharded TFRecord files with 3 examples per shard.
nobrainer.tfrecord.write(
    features_labels=test_D,
    filename_template=test_write_path,
    examples_per_shard=3,
)

In [None]:
# Build the axial-plane test dataset from the "data-test_*" TFRecord shards.
# (No model path is needed here; ROOTDIR_B is only defined in the inference
# cell, so referencing it in this cell raised a NameError.)
test_root_dir = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/test_ixi'
tfrecords_path = os.path.join(test_root_dir, "tfrecords")
plane = "axial"
dataset_plane = get_dataset(
    file_pattern=os.path.join(tfrecords_path, "data-test_*"),
    n_classes=2,
    batch_size=16,
    volume_shape=(128, 128, 128),
    plane=plane,
    mode='test',
)

print(dataset_plane)

## Inference

In [None]:
import sys, os
sys.path.append('..')
from models.modelN import CombinedClassifier
from dataloaders.dataset import get_dataset


# Tf packages
import tensorflow as tf

def inference(tfrecords_path, weights_path, batch_size=16, threshold=0.5):
    """Evaluate a trained CombinedClassifier on the test TFRecords.

    Parameters
    ----------
    tfrecords_path : str
        Directory containing the ``data-test_*`` TFRecord shards.
    weights_path : str
        Path to the saved CombinedClassifier weights file.
    batch_size : int, optional
        Batch size handed to ``get_dataset`` (default 16).
    threshold : float, optional
        Scores strictly greater than this are labelled 1 (default 0.5).

    Returns
    -------
    results
        Whatever ``model.evaluate`` returns (loss and the compiled metrics).
    predictions : numpy.ndarray
        Thresholded model outputs cast to int (0 or 1).
    """
    model = CombinedClassifier(
        input_shape=(128, 128), dropout=0.4, wts_root=None, trainable=True
    )
    model.load_weights(os.path.abspath(weights_path))
    # The model is only evaluated here, so freeze it.
    model.trainable = False

    dataset_test = get_dataset(
        file_pattern=os.path.join(tfrecords_path, "data-test_*"),
        n_classes=2,
        batch_size=batch_size,
        volume_shape=(128, 128, 128),
        plane="combined",
        mode="test",
    )

    model.compile(
        loss=tf.keras.losses.binary_crossentropy,
        optimizer=Adam(learning_rate=1e-3),
        metrics=[
            metrics.BinaryAccuracy(name="accuracy"),
            metrics.Precision(name="precision"),
            metrics.Recall(name="recall"),
            metrics.AUC(name="auc"),
        ],
    )

    # get_dataset already receives batch_size, so the dataset is presumably
    # batched; Keras documents that batch_size must not be passed to
    # evaluate() for datasets.
    results = model.evaluate(dataset_test)
    predictions = (model.predict(dataset_test) > threshold).astype(int)
    return results, predictions
    
    
ROOTDIR_B = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/experiments/experiment_B/128'
ROOTDIR_A = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/experiments/experiment_A/128'
test_root_dir = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/test_ixi'

# Evaluate the combined classifier saved under experiment B on the test TFRecords.
model_save_path = os.path.join(ROOTDIR_B, "model_save_dir_full")
tfrecords_path = os.path.join(test_root_dir, "tfrecords")
print("TFRECORDS: ", tfrecords_path)
weights_path = os.path.join(model_save_path, 'weights/combined/best-wts.h5')

model = CombinedClassifier(
    input_shape=(128, 128), dropout=0.4, wts_root=None, trainable=True
)
model.load_weights(os.path.abspath(weights_path))
# Evaluation only: freeze the model, as `inference` does.
model.trainable = False

test_file_pattern = os.path.join(tfrecords_path, "data-test_*")
print(test_file_pattern)

dataset_test = get_dataset(
    file_pattern=test_file_pattern,
    n_classes=2,
    batch_size=16,
    volume_shape=(128, 128, 128),
    plane='combined',
    mode='test',
)

print(dataset_test)

METRICS = [
    metrics.BinaryAccuracy(name="accuracy"),
    metrics.Precision(name="precision"),
    metrics.Recall(name="recall"),
    metrics.AUC(name="auc"),
]

model.compile(
    loss=tf.keras.losses.binary_crossentropy,
    optimizer=Adam(learning_rate=1e-3),
    metrics=METRICS,
)

# get_dataset already receives batch_size=16, so the dataset is presumably
# batched; Keras documents that batch_size must not be passed to evaluate()
# for datasets.
results = model.evaluate(dataset_test)
predictions = (model.predict(dataset_test) > 0.5).astype(int)

In [None]:
# Planes to run single-plane submodels for; add 'axial' and/or 'sagittal'
# to evaluate those planes as well.
planes = ['coronal']

for plane in planes:

    model = modelN.Submodel(
        input_shape=(128, 128),
        dropout=0.2,
        name=plane,
        include_top=True,
        weights=None,
        trainable=False,
    )

    # Per-plane weights live under <model_save_path>/weights/<plane>/.
    # Build the path once so the printed path is the one actually loaded.
    plane_weights_path = os.path.join(model_save_path, 'weights', plane, 'best-wts.h5')
    print(plane_weights_path)
    model.load_weights(plane_weights_path)

    dataset_plane = get_dataset(
        file_pattern=os.path.join(tfrecords_path, "data-test_*"),
        n_classes=2,
        batch_size=16,
        volume_shape=(128, 128, 128),
        plane=plane,
        mode='test',
    )

    # Fresh metric objects per model: Keras metrics are stateful.
    METRICS = [
        metrics.BinaryAccuracy(name="accuracy"),
        metrics.Precision(name="precision"),
        metrics.Recall(name="recall"),
        metrics.AUC(name="auc"),
    ]

    model.summary()

    model.compile(
        loss=tf.keras.losses.binary_crossentropy,
        optimizer=Adam(learning_rate=1e-3),
        metrics=METRICS,
    )

    # Only predictions are computed here (call model.evaluate(dataset_plane)
    # for metrics). `predictions` is overwritten on every iteration, so only
    # the last plane's result remains after the loop.
    predictions = (model.predict(dataset_plane) > 0.5).astype(int)

In [None]:
# Total number of thresholded predictions (element count of the array).
print(predictions.size)

In [2]:
import csv

path = '/home/shank/Stanford/nondefaced-detector/examples/sample_vols/example.csv'

# Collect the volume paths listed in the first column of the CSV.
if not path.endswith('.csv'):
    # Without this check `filepaths` would simply be undefined and the
    # preprocessing cell would fail later with a NameError.
    raise ValueError(f"Expected a .csv file listing volume paths, got: {path}")

filepaths = []
skip_header = True
with open(path, newline="") as csvfile:
    reader = csv.reader(csvfile, delimiter=",")
    if skip_header:
        # The default avoids StopIteration on an empty file.
        next(reader, None)
    # csv.reader yields [] for blank lines; skip them instead of failing on row[0].
    filepaths.extend(row[0] for row in reader if row)

from nondefaced_detector.preprocess import preprocess, cleanup_files
from nondefaced_detector.preprocess import preprocess_parallel

num_parallel_calls = None
if num_parallel_calls is None:
    # Prefer the number of CPUs this process may run on (this can be smaller
    # than os.cpu_count() under an affinity mask). os.sched_getaffinity is
    # not available on every platform, so fall back to os.cpu_count().
    if hasattr(os, "sched_getaffinity"):
        num_parallel_calls = len(os.sched_getaffinity(0))
    else:
        num_parallel_calls = os.cpu_count() or 1

outputs = preprocess_parallel(
    filepaths,
    num_parallel_calls=num_parallel_calls,
    with_label=False,
)

print(outputs)

# To delete the preprocessed files afterwards, call cleanup_files(outputs).

Preprocessing 6 examples


100%|██████████| 6/6 [00:05<00:00,  1.11it/s]

['/home/shank/Stanford/nondefaced-detector/examples/sample_vols/faced/preprocessed/example1.nii.gz', '/home/shank/Stanford/nondefaced-detector/examples/sample_vols/faced/preprocessed/example2.nii.gz', '/home/shank/Stanford/nondefaced-detector/examples/sample_vols/faced/preprocessed/example3.nii.gz', '/home/shank/Stanford/nondefaced-detector/examples/sample_vols/defaced/preprocessed/example1.nii.gz', '/home/shank/Stanford/nondefaced-detector/examples/sample_vols/defaced/preprocessed/example2.nii.gz', '/home/shank/Stanford/nondefaced-detector/examples/sample_vols/defaced/preprocessed/example3.nii.gz']





In [9]:
"""Methods to predict using trained models"""

import functools
import os

import numpy as np
import tensorflow as tf
import multiprocessing as mp

from pathlib import Path
from tqdm import tqdm

from nondefaced_detector.helpers       import utils
from nondefaced_detector.models.modelN import CombinedClassifier


def _predict(volume, model):
    """Return predictions from `inputs`.

    This is a general prediction method.

    Parameters
    ---------

    Returns
    ------
    """
    
    if not isinstance(volume, (np.ndarray)):
        raise ValueError("volume is not a numpy ndarray")
        
    ds = _structural_slice(volume, plane="combined", n_slices=n_slices)
    ds = tf.data.Dataset.from_tensor_slices(ds)
    ds = ds.batch(batch_size=1, drop_remainder=False)

    predicted = model.predict(ds)

    return predicted


def predict(volumes, model_path, n_slices=32):
    
    if not isinstance(volumes, list):
        raise ValueError('Volumes need to be a list of paths to preprocessed MRI volumes.')
    
    outputs = []
    model = _get_model(model_path)
    
    for path in tqdm(volumes, total=len(volumes)):
        vol,_,_ = utils.load_vol(path)
        predicted = _predict(vol, model)
        
        outputs.append((path, predicted[0][0]))
        
    return outputs
    
    
def _structural_slice(x, plane, n_slices=16):

    """Transpose dataset based on the plane

    Parameters
    ----------
    x:

    plane:

    n_slices:

    Returns
    -------
    """

    options = ["sagittal", "coronal", "axial", "combined"]

    if isinstance(plane, str) and plane in options:
        idxs = np.random.randint(x.shape[0], size=(n_slices, 3))
        if plane == "sagittal":
            midx = idxs[:, 0]
            x = x

        if plane == "coronal":
            midx = idxs[:, 1]
            x = tf.transpose(x, perm=[1, 2, 0])

        if plane == "axial":
            midx = idxs[:, 2]
            x = tf.transpose(x, perm=[2, 0, 1])

        if plane == "combined":
            temp = {}
            for op in options[:-1]:
                temp[op] = _structural_slice(x, op, n_slices)
            x = temp

        if not plane == "combined":
            x = tf.squeeze(tf.gather_nd(x, midx.reshape(n_slices, 1, 1)), axis=1)
            x = tf.math.reduce_mean(x, axis=0, keepdims=True)
            x = tf.expand_dims(x, axis=-1)
            x = tf.convert_to_tensor(x)

        return x
    else:
        raise ValueError(
            "Expected plane to be one of [sagittal, coronal, axial, combined]"
        )


def _get_model(model_path):

    """Return `tf.keras.Model` object from a filepath.

    Parameters
    ----------
    path: str, path to HDF5 or SavedModel file.

    Returns
    -------
    Instance of `tf.keras.Model`.

    Raises
    ------
    `ValueError` if cannot load model.
    """

    try:
        p = Path(model_path).resolve()

        model = CombinedClassifier(input_shape=(128, 128), wts_root=p, trainable=False)

        combined_weights = list(Path(os.path.join(p, "combined")).glob("*.h5"))[
            0
        ].resolve()

        model.load_weights(combined_weights)
        model.trainable = False

        return model

    except Exception as e:
        print(e)
        pass

    raise ValueError("Failed to load model.")
    
# Score every preprocessed volume with the bundled pretrained weights.
preds = predict(
    outputs,
    model_path="/home/shank/Stanford/nondefaced-detector/nondefaced_detector/models/pretrained_weights",
)


100%|██████████| 6/6 [00:02<00:00,  2.81it/s]


In [10]:
# Show the (path, score) pairs returned by `predict`.
print(preds)

[('/home/shank/Stanford/nondefaced-detector/examples/sample_vols/faced/preprocessed/example1.nii.gz', 0.99998486), ('/home/shank/Stanford/nondefaced-detector/examples/sample_vols/faced/preprocessed/example2.nii.gz', 0.9999981), ('/home/shank/Stanford/nondefaced-detector/examples/sample_vols/faced/preprocessed/example3.nii.gz', 0.9970654), ('/home/shank/Stanford/nondefaced-detector/examples/sample_vols/defaced/preprocessed/example1.nii.gz', 0.016103715), ('/home/shank/Stanford/nondefaced-detector/examples/sample_vols/defaced/preprocessed/example2.nii.gz', 0.9974597), ('/home/shank/Stanford/nondefaced-detector/examples/sample_vols/defaced/preprocessed/example3.nii.gz', 0.0201056)]
