In [None]:
import argparse
import glob
import gluoncv
import mxnet as mx
import os
import pandas as pd

from matplotlib import pyplot as plt
from mxnet import gluon
from mxnet.gluon.data.vision import datasets
from tqdm import tqdm_notebook as tqdm

import rsna_heme as rsna

In [None]:
# Inference configuration: every tunable for this prediction run lives on `args`.
args = argparse.Namespace()

# Every *.params checkpoint in the bag is evaluated and the predictions are averaged.
# glob returns files in arbitrary order, so sort for a reproducible run order.
args.param_paths = sorted(glob.glob(os.path.join('../params/test_bagged', '*.params')))
assert args.param_paths, 'no *.params checkpoints found under ../params/test_bagged'
args.n_tta = 3  # passes over the test set per checkpoint
args.data_dir = '/mnt/Data2/datasets/rsna_heme/normalized'
args.dcm_dir = '/mnt/Data2/datasets/rsna_heme/stage_1_test_images'
args.pred_dir = '../predictions'

args.model_name = 'resnet50_V2'
args.pretrained = False
# One network output per label column; kept in sync with the prediction DataFrame columns.
args.classes = len(rsna.labels.heme_types)

args.ctx = [mx.gpu(3)]
args.batch_size = 60
args.num_workers = 4

# Outputs for this run are written under <pred_dir>/<timestamp>.
time_str = rsna.util.get_time()
args.save_prefix = os.path.join(args.pred_dir, time_str)

In [None]:
# Test images from the packed record file, with the shared preprocessing applied on read.
test_dataset = datasets.ImageRecordDataset(
    os.path.join(args.data_dir, 'test.rec'),
    flag=1,
    transform=rsna.transforms.common_transform,
)

# NOTE(review): the training-time transform is layered on top for the TTA passes.
# Presumably it is stochastic; otherwise every TTA pass yields identical predictions. Confirm.
tta_dataset = test_dataset.transform_first(rsna.transforms.train_transform)

test_data = gluon.data.DataLoader(
    tta_dataset,
    batch_size=args.batch_size,
    shuffle=False,  # keep record order: predictions are later joined to image IDs by position
    num_workers=args.num_workers,
)

In [None]:
# Debug helper (disabled): a loader over only the first 10 records, plus a plot of one
# train_transform-ed sample via rsna.util.plt_tensor. Uncomment to inspect the TTA inputs.
# sampler = gluon.data.SequentialSampler(10)
# test_data = gluon.data.DataLoader(test_dataset.transform_first(rsna.transforms.train_transform), batch_size=args.batch_size, sampler=sampler, num_workers=args.num_workers)
# rsna.util.plt_tensor(test_dataset.transform_first(rsna.transforms.train_transform)[1][0])

In [None]:
# Spot-check one decoded test image (common_transform only, no TTA transform).
# Clamp the index so a smaller test set does not raise IndexError.
sample_idx = min(33733, len(test_dataset) - 1)
fig, ax = plt.subplots()
ax.imshow(test_dataset[sample_idx][0].asnumpy())
ax.set_title('test image #{}'.format(sample_idx))
ax.axis('off');

In [None]:
# Build the network from the config namespace. Which fields get_model reads (model_name,
# pretrained, classes, ctx, ...) is presumably defined in rsna.cnn; confirm there.
# Checkpoint weights are loaded into this net later, one .params file at a time.
net = rsna.cnn.get_model(args)

In [None]:
# Predict the test set with every checkpoint in args.param_paths, args.n_tta passes each.
# Sigmoid probabilities are collected per pass; the per-image mean is taken in the next cell.
if not args.param_paths:
    raise ValueError('args.param_paths is empty: no .params checkpoints to predict with')
if args.n_tta < 1:
    raise ValueError('args.n_tta must be >= 1, got {}'.format(args.n_tta))

logger = rsna.logger.Logger(args.save_prefix, 'predictions')

runs = []  # one DataFrame per (checkpoint, TTA pass); rows follow test_data order
for param_path in tqdm(args.param_paths, desc='params'):
    net.load_parameters(param_path)
    for _ in tqdm(range(args.n_tta), desc='tta', leave=False):
        batch_probs = []
        for batch in tqdm(test_data, desc='batch', leave=False):
            data = gluon.utils.split_and_load(batch[0], ctx_list=args.ctx, batch_axis=0, even_split=False)
            # Run the forward pass on every device before any asnumpy() call.
            # split_and_load slices the batch in ctx order, so keeping every device's
            # output (not only outputs[0]) preserves row count and order with several GPUs.
            outputs = [net(X) for X in data]
            batch_probs.extend(
                pd.DataFrame(mx.nd.sigmoid(out).asnumpy(), columns=rsna.labels.heme_types)
                for out in outputs
            )
        runs.append(pd.concat(batch_probs, ignore_index=True))

# The next cell averages runs by row position, so every run must cover the same rows.
assert len({len(run) for run in runs}) == 1, 'prediction runs produced different row counts'
probs_all = pd.concat(runs)

In [None]:
# Average all (checkpoint, TTA pass) predictions for the same image. Each run carries a
# 0..N-1 positional index, so grouping on the index level pairs up the same test row.
probs = (
    probs_all
    .groupby(level=0)
    .mean()
    .round(4)
)

In [None]:
# Image IDs for the test DICOM directory. They are attached to the predictions by row
# position in the next cell. NOTE(review): this assumes ids_from_dir enumerates images in
# the same order as test.rec; confirm.
ids = rsna.labels.ids_from_dir(args.dcm_dir)

In [None]:
# Attach the image IDs to the averaged probabilities. The join is positional: both frames
# carry a 0..N-1 index. It therefore relies on test.rec and the DICOM listing enumerating
# images in the same order. NOTE(review): that ordering is not checkable here; confirm.
ids_frame = ids.reset_index()
# Fail loudly instead of silently misaligning. This also catches re-running this cell,
# which would otherwise append a second ID column.
assert 'ID' not in probs.columns, 'probs already has an ID column; re-run the averaging cell first'
assert len(ids_frame) == len(probs), (
    '{} image ids vs {} prediction rows; a positional join would misalign'.format(len(ids_frame), len(probs))
)
probs = pd.concat([probs, ids_frame], axis=1)

In [None]:
# Most confident 'any' predictions first, as a quick sanity check.
probs.sort_values('any', ascending=False)

In [None]:
# Reshape to one row per (image, label column). Melt the label columns into
# 'variable'/'Label', then fold the label name into the key as '<ID>_<label>'.
probs_long = (
    pd.melt(probs, id_vars='ID', value_name='Label')
    .sort_values(['ID', 'variable'])
    # Vectorised string concatenation instead of a row-wise apply(axis=1) join.
    .assign(ID=lambda frame: frame['ID'] + '_' + frame['variable'])
    .drop(columns='variable')
)

In [None]:
# Preview the long-format table (columns ID, Label) that is written to predictions.csv.
probs_long

In [None]:
# Write the predictions CSV into this run's timestamped directory, then close the logger.
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(args.save_prefix, exist_ok=True)
probs_long.to_csv(os.path.join(args.save_prefix, 'predictions.csv'), index=False)
logger.close()