## Make predictions on the ImageCLEF test dataset (421 samples)

In [1]:
colab = False
if colab:
    # Mount drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    # Set path to working directory
    import sys
    sys.path.append('/content/gdrive/My Drive/ImageCLEF2021/')
    %cd /content/gdrive/My\ Drive/ImageCLEF2021/

In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import numpy as np
import pickle
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras.models import load_model


CFG_NAME = "augmented"

ROOT_DIR = os.path.abspath("../")
LOG_PATH = os.path.join(ROOT_DIR, "logs", CFG_NAME)

DATASET_PATH = os.path.join(ROOT_DIR, "dataset/")
TEST_VOLUMES_PATH = os.path.join(DATASET_PATH, "test_volumes_numpy")
PATIENT_NAMES_PATH = os.path.join(DATASET_PATH, "patient_names.txt")

with open(PATIENT_NAMES_PATH, "rb") as fp:
  PATIENT_NAMES_ = pickle.load(fp)

PATIENT_NAMES = [x+".nii.gz" for x in PATIENT_NAMES_]
print(PATIENT_NAMES[:3])

['TST_0001.nii.gz', 'TST_0002.nii.gz', 'TST_0003.nii.gz']


## Helpers

In [4]:
def get_number(filename):
  return int(filename[:filename.find('.')])
        
def sort_paths(paths):
  paths.sort(key = get_number)
  return paths

### Get test data

In [5]:
# 1.npy, 2.npy, 3.npy etc.
volume_path_ = sort_paths(os.listdir(TEST_VOLUMES_PATH))
# Sorted patient names in folder
volume_paths = [os.path.join(TEST_VOLUMES_PATH, image_id) for image_id in volume_path_]

print(volume_paths[:3])

['/home/hz/tbt-classification/dataset/test_volumes_numpy/1.npy', '/home/hz/tbt-classification/dataset/test_volumes_numpy/2.npy', '/home/hz/tbt-classification/dataset/test_volumes_numpy/3.npy']


### Check images

In [6]:
ct = 200
image = np.load(volume_paths[ct])
image.shape

(70, 224, 224, 1)

In [7]:
np.min(image), np.max(image)

(0.0, 1.0)

In [8]:
np.unique(image)

array([0.0000000e+00, 1.4012985e-45, 2.8025969e-45, ..., 9.9999970e-01,
       9.9999976e-01, 1.0000000e+00], dtype=float32)

### Load model

In [11]:
from tensorflow.keras.models import load_model
from keras.utils.data_utils import get_file

def get_model():
    URL = "https://github.com/hasibzunair/ViPTT-Net/releases/tag/v0.0.1/download/ViPTT-Net-CLEF-TBT.h5"
    weights_path = get_file("ViPTT-Net-CLEF-TBT.h5", URL)
    model = load_model(weights_path, compile = False)
    return model

In [12]:
model = None
model = get_model()
model.summary()

Downloading data from https://github.com/hasibzunair/ViPTT-Net/releases/tag/v0.0.1/download/ViPTT-Net-CLEF-TBT.h5


Exception: URL fetch failure on https://github.com/hasibzunair/ViPTT-Net/releases/tag/v0.0.1/download/ViPTT-Net-CLEF-TBT.h5 : 404 -- Not Found

### Make predictions

In [12]:
class_names = [1,2,3,4,5]
class_dict = {i: cat for (i, cat) in enumerate(class_names)}

predictions = []

for path in tqdm(volume_paths):
    features = np.load(path)
    features = np.expand_dims(features, axis=0)
    pred = model.predict(features)
    pred = np.argmax(pred)
    predictions.append(class_dict[pred])

print(predictions[:3])
print(len(predictions))

100%|██████████| 421/421 [10:35<00:00,  1.51s/it]

[2, 1, 5]
421





In [13]:
with open('{}/{}_submission.txt'.format(LOG_PATH, CFG_NAME), 'w') as f:
    for n, p in zip(PATIENT_NAMES, predictions):
        print(n,",", p)
        f.write(str(n))
        f.write(",")
        f.write(str(p))
        f.write("\n")
        
f.close()

TST_0001.nii.gz , 2
TST_0002.nii.gz , 1
TST_0003.nii.gz , 5
TST_0004.nii.gz , 1
TST_0005.nii.gz , 5
TST_0006.nii.gz , 4
TST_0007.nii.gz , 2
TST_0008.nii.gz , 1
TST_0009.nii.gz , 1
TST_0010.nii.gz , 2
TST_0011.nii.gz , 3
TST_0012.nii.gz , 1
TST_0013.nii.gz , 2
TST_0014.nii.gz , 3
TST_0015.nii.gz , 1
TST_0016.nii.gz , 1
TST_0017.nii.gz , 4
TST_0018.nii.gz , 2
TST_0019.nii.gz , 2
TST_0020.nii.gz , 2
TST_0021.nii.gz , 1
TST_0022.nii.gz , 2
TST_0023.nii.gz , 1
TST_0024.nii.gz , 5
TST_0025.nii.gz , 1
TST_0026.nii.gz , 2
TST_0027.nii.gz , 1
TST_0028.nii.gz , 1
TST_0029.nii.gz , 4
TST_0030.nii.gz , 1
TST_0031.nii.gz , 5
TST_0032.nii.gz , 5
TST_0033.nii.gz , 1
TST_0034.nii.gz , 1
TST_0035.nii.gz , 1
TST_0036.nii.gz , 2
TST_0037.nii.gz , 1
TST_0038.nii.gz , 1
TST_0039.nii.gz , 5
TST_0040.nii.gz , 4
TST_0041.nii.gz , 1
TST_0042.nii.gz , 2
TST_0043.nii.gz , 1
TST_0044.nii.gz , 4
TST_0045.nii.gz , 1
TST_0046.nii.gz , 5
TST_0047.nii.gz , 1
TST_0048.nii.gz , 1
TST_0049.nii.gz , 2
TST_0050.nii.gz , 1


In [14]:
print("Done!")

Done!
