In [None]:
import argparse
import mxnet as mx
import os
import pandas as pd

from matplotlib import pyplot as plt
from mxnet import gluon
from mxnet.gluon.data.vision import datasets
from tqdm import tqdm_notebook as tqdm

import rsna_heme as rsna

In [None]:
# Inference configuration, collected on one namespace so that later cells can
# read every setting as `args.<name>`.
args = argparse.Namespace(
    # Checkpoint to evaluate and the locations of the inputs / outputs.
    params_path='../params/resnet50_v2_fold0_2019-10-08_193926/resnet50_v2_fold0_2019-10-08_193926_best.params',
    data_dir='/mnt/Data2/datasets/rsna_heme/normalized',
    dcm_dir='/mnt/Data2/datasets/rsna_heme/stage_1_test_images',
    pred_dir='../predictions',
    # Model-zoo architecture name and the number of units in the output layer.
    model_name='resnet50_V2',
    classes=6,
    # Device list and data-loading settings.
    ctx=[mx.gpu(3)],
    batch_size=60,
    num_workers=4,
)

# Each run writes into its own sub-directory of `pred_dir`, named after the
# string returned by rsna.util.get_time().
time_str = rsna.util.get_time()
args.save_prefix = os.path.join(args.pred_dir, time_str)

In [None]:
# Create the output directory for this run (including missing parents).
# `exist_ok=True` keeps the cell safe to re-run and avoids the race between a
# separate existence check and the directory creation.
os.makedirs(args.save_prefix, exist_ok=True)

In [None]:
# Open the test-set RecordIO file found in the data directory.
rec_path = os.path.join(args.data_dir, 'test.rec')
test_dataset = datasets.ImageRecordDataset(rec_path, flag=1)

# Batch the dataset with the validation transform applied to the first element
# of every sample. Shuffling is off, so batches arrive in dataset order.
test_data = gluon.data.DataLoader(
    test_dataset.transform_first(rsna.transforms.val_transform),
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
)

In [None]:
# Sanity check: show the first element of sample 19 from the untransformed
# dataset.
sample_image = test_dataset[19][0]
plt.imshow(sample_image.asnumpy())

In [None]:
# Build the architecture only. ImageNet weights are not requested
# (pretrained=False) because load_parameters() below fills every parameter
# from the checkpoint; fetching them first would only add a download and a
# network dependency.
net = gluon.model_zoo.vision.get_model(args.model_name, pretrained=False)

# Swap the model-zoo output layer for a Dense layer with `args.classes` units
# so the network structure matches what the checkpoint is loaded into.
with net.name_scope():
    net.output = gluon.nn.Dense(args.classes)

# Load the checkpoint, move all parameters to the configured devices, and
# hybridize the network for inference.
net.load_parameters(args.params_path)
net.collect_params().reset_ctx(args.ctx)
net.hybridize()

In [None]:
# Empty frame with one column per entry of rsna.labels.heme_types; the
# inference loop below fills it with the network's sigmoid outputs.
probs = pd.DataFrame(columns=rsna.labels.heme_types)

In [None]:
# Run the network over the whole test loader and collect per-class
# probabilities (sigmoid of the raw outputs, rounded to 3 decimals).
#
# Per-slice frames are gathered in a list and concatenated once after the
# loop: DataFrame.append was removed in pandas 2.0 and re-copied the whole
# frame on every batch.
batch_frames = []
for batch in tqdm(test_data):
    data = gluon.utils.split_and_load(batch[0], ctx_list=args.ctx, batch_axis=0, even_split=False)
    # Dispatch every device slice before reading any result back.
    outputs = [net(X) for X in data]
    # Keep the output of *every* slice (not just the first), in slice order,
    # so no predictions are lost when `args.ctx` lists several devices.
    for output in outputs:
        batch_frames.append(
            pd.DataFrame(mx.nd.sigmoid(output).asnumpy().round(3), columns=probs.columns)
        )

# An empty loader leaves the empty `probs` frame untouched; pd.concat would
# raise on an empty list.
if batch_frames:
    probs = pd.concat(batch_frames)

In [None]:
# Image IDs obtained by rsna.labels.ids_from_dir from the test DICOM directory.
# NOTE(review): the next cell joins these to `probs` by row position, so they
# are assumed to be in the same order as the records in test.rec — confirm.
ids = rsna.labels.ids_from_dir(args.dcm_dir)

In [None]:
# Attach the IDs to the predictions column-wise. Both sides get a fresh
# 0..n-1 index, so the join is purely positional.
# A length mismatch would be padded with NaN by concat(axis=1) and silently
# misalign predictions and IDs, so fail loudly instead.
if len(probs) != len(ids):
    raise ValueError(
        'Row count mismatch: {} predictions vs {} ids'.format(len(probs), len(ids))
    )
probs = pd.concat([probs.reset_index(drop=True), ids.reset_index()], axis=1)

In [None]:
# Display the combined predictions/IDs frame for a visual check.
probs

In [None]:
# Reshape to long format: one row per (ID, column) pair with the value in
# 'Label', sorted by ID and then by column name. The row key is then rewritten
# as '<ID>_<column name>' and the helper 'variable' column is dropped, leaving
# the columns ['ID', 'Label'].
# The key is built with vectorized string concatenation instead of a row-wise
# apply, which calls a Python lambda once per row.
probs_long = (
    pd.melt(probs, id_vars='ID', value_name='Label')
    .sort_values(['ID', 'variable'])
    .assign(ID=lambda frame: frame['ID'] + '_' + frame['variable'])
    .drop(columns=['variable'])
)

In [None]:
# Write the long-format predictions into this run's output directory, without
# the frame index.
predictions_path = os.path.join(args.save_prefix, 'predictions.csv')
probs_long.to_csv(predictions_path, index=False)