## Convert test data to TFRecords

In [None]:
import random
import nobrainer
import os, sys
sys.path.append("..")
import numpy as np
import nibabel as nb
from glob import glob
from pathlib import Path
from shutil import *
import subprocess
from operator import itemgetter
import pandas as pd

test_root_dir = "/tf/shank/HDDLinux/Stanford/data/mriqc-shared/test_ixi"
csv_path = os.path.join(test_root_dir, "csv")
tf_records_dir = os.path.join(test_root_dir, "tfrecords")

os.makedirs(tf_records_dir, exist_ok=True)

# testing.csv supplies the test paths (column "X") and labels (column "Y").
test_csv_path = os.path.join(csv_path, "testing.csv")
# Parse the CSV once instead of once per column.
test_df = pd.read_csv(test_csv_path)
test_paths = test_df["X"].values
test_labels = test_df["Y"].values
# (path, label) pairs, in CSV row order, as passed to nobrainer below.
test_D = list(zip(test_paths, test_labels))
# The "{shard:03d}" placeholder is presumably filled in by nobrainer with a
# zero-padded shard index.
test_write_path = os.path.join(tf_records_dir, 'data-test_shard-{shard:03d}.tfrec')

nobrainer.tfrecord.write(
    features_labels=test_D,
    filename_template=test_write_path,
    examples_per_shard=3)

In [None]:
# get_dataset is otherwise only imported in the Inference section further down;
# import it here so this cell also works when the notebook is run top to bottom.
from dataloaders.dataset import get_dataset

test_root_dir = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/test_ixi'
tfrecords_path = os.path.join(test_root_dir, "tfrecords")
plane = "axial"
# Single-plane test dataset over the "data-test_*" shards written above.
dataset_plane = get_dataset(
    file_pattern=os.path.join(tfrecords_path, "data-test_*"),
    n_classes=2,
    batch_size=16,
    volume_shape=(128, 128, 128),
    plane=plane,
    mode='test',
)

print(dataset_plane)

## Inference

In [None]:
import sys, os
sys.path.append('..')
from models.modelN import CombinedClassifier
from dataloaders.dataset import get_dataset


# Tf packages
import tensorflow as tf

def inference(tfrecords_path, weights_path, batch_size=16, threshold=0.5):
    """Evaluate the combined-plane classifier on the test TFRecord shards.

    Builds a ``CombinedClassifier``, loads ``weights_path`` into it, then runs
    ``model.evaluate`` and ``model.predict`` over the ``data-test_*`` shards
    found in ``tfrecords_path``.

    Args:
        tfrecords_path: Directory containing the ``data-test_*`` shards.
        weights_path: Path to the saved weights of the combined model.
        batch_size: Batch size handed to ``get_dataset`` (default 16).
        threshold: Cutoff applied to ``model.predict`` outputs to obtain
            0/1 predictions (default 0.5).

    Returns:
        A ``(results, predictions)`` tuple: whatever ``model.evaluate``
        returns, and the thresholded predictions as an int array.
    """
    model = CombinedClassifier(
        input_shape=(128, 128), dropout=0.4, wts_root=None, trainable=True)

    model.load_weights(os.path.abspath(weights_path))
    # Freeze after loading: the model is only evaluated here, never fitted.
    model.trainable = False

    dataset_test = get_dataset(
        file_pattern=os.path.join(tfrecords_path, "data-test_*"),
        n_classes=2,
        batch_size=batch_size,
        volume_shape=(128, 128, 128),
        plane='combined',
        mode='test',
    )

    # `metrics` and `Adam` are never imported in this notebook; go through the
    # `tf.keras` namespace so the function does not raise NameError.
    model.compile(
        loss=tf.keras.losses.binary_crossentropy,
        # No fit() call follows, so the optimizer is never stepped here.
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        metrics=[
            tf.keras.metrics.BinaryAccuracy(name="accuracy"),
            tf.keras.metrics.Precision(name="precision"),
            tf.keras.metrics.Recall(name="recall"),
            tf.keras.metrics.AUC(name="auc"),
        ],
    )

    # Batching is configured in get_dataset; Keras documents that batch_size
    # must not be passed to evaluate() for dataset inputs.
    results = model.evaluate(dataset_test)
    predictions = (model.predict(dataset_test) > threshold).astype(int)
    return results, predictions
    
    
ROOTDIR_B = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/experiments/experiment_B/128'
ROOTDIR_A = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/experiments/experiment_A/128'
test_root_dir = '/tf/shank/HDDLinux/Stanford/data/mriqc-shared/test_ixi'

# Weights are taken from experiment B; ROOTDIR_A is not used in this cell.
model_save_path = os.path.join(ROOTDIR_B, "model_save_dir_full")
tfrecords_path = os.path.join(test_root_dir, "tfrecords")
print("TFRECORDS: ", tfrecords_path)
weights_path = os.path.join(model_save_path, 'weights/combined/best-wts.h5')

model = CombinedClassifier(
    input_shape=(128, 128), dropout=0.4, wts_root=None, trainable=True
)
model.load_weights(os.path.abspath(weights_path))

test_file_pattern = os.path.join(tfrecords_path, "data-test_*")
print(test_file_pattern)

dataset_test = get_dataset(
    file_pattern=test_file_pattern,
    n_classes=2,
    batch_size=16,
    volume_shape=(128, 128, 128),
    plane='combined',
    mode='test',
)

print(dataset_test)

# `metrics` and `Adam` are never imported in this notebook; reference them
# through `tf.keras` so the cell does not raise NameError.
METRICS = [
    tf.keras.metrics.BinaryAccuracy(name="accuracy"),
    tf.keras.metrics.Precision(name="precision"),
    tf.keras.metrics.Recall(name="recall"),
    tf.keras.metrics.AUC(name="auc"),
]

model.compile(
    loss=tf.keras.losses.binary_crossentropy,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    metrics=METRICS,
)

# Batching is configured in get_dataset; Keras documents that batch_size must
# not be passed to evaluate() for dataset inputs.
results = model.evaluate(dataset_test)
# Binarise outputs at 0.5 into an int array.
predictions = (model.predict(dataset_test) > 0.5).astype(int)

In [None]:
# Only CombinedClassifier is imported from models.modelN elsewhere; bind the
# module itself so Submodel can be referenced.
from models import modelN

# Planes to evaluate; 'axial' / 'sagittal' can be added back to the list.
planes = ['coronal']

for plane in planes:

    model = modelN.Submodel(
        input_shape=(128, 128),
        dropout=0.2,
        name=plane,
        include_top=True,
        weights=None,
        trainable=False,
    )

    # Compute the path once so the printed path is the one actually loaded
    # (the previous print omitted the 'weights' directory).
    plane_weights_path = os.path.join(model_save_path, 'weights', plane, 'best-wts.h5')
    print(plane_weights_path)
    model.load_weights(plane_weights_path)

    dataset_plane = get_dataset(
        file_pattern=os.path.join(tfrecords_path, "data-test_*"),
        n_classes=2,
        batch_size=16,
        volume_shape=(128, 128, 128),
        plane=plane,
        mode='test',
    )

    # `metrics` and `Adam` are never imported in this notebook; use `tf.keras`.
    METRICS = [
        tf.keras.metrics.BinaryAccuracy(name="accuracy"),
        tf.keras.metrics.Precision(name="precision"),
        tf.keras.metrics.Recall(name="recall"),
        tf.keras.metrics.AUC(name="auc"),
    ]

    model.summary()

    model.compile(
        loss=tf.keras.losses.binary_crossentropy,
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        metrics=METRICS,
    )

    # NOTE: evaluate() is intentionally not run here, so METRICS are compiled
    # but not reported; only thresholded predictions are produced.
    predictions = (model.predict(dataset_plane) > 0.5).astype(int)

In [None]:
# Total number of thresholded predictions held in `predictions` (all elements).
print(predictions.size)

In [None]:
import matplotlib.pyplot as plt  # not used below; retained for ad-hoc plotting

# Re-run the most recently built `model` batch by batch over `dataset_plane`
# (both assigned in the per-plane loop above). Count correct and incorrect 0/1
# predictions, and print every mismatch.
corr = 0
incorr = 0
for x, y in dataset_plane.as_numpy_iterator():
    # Threshold at 0.5 as in the cells above; flatten once per batch instead of
    # once per element.
    batch_predictions = (model.predict(x) > 0.5).astype(int).flatten()
    batch_labels = y.flatten().astype(int)
    for predicted, actual in zip(batch_predictions, batch_labels):
        if predicted != actual:
            incorr += 1
            print("Predicted: ", predicted, "Actual: ", actual)
        else:
            corr += 1

In [None]:
# Correct vs. incorrect prediction counts accumulated by the loop above.
print(f"{corr} {incorr}")