# Evaluate model

This notebook runs the trained model on the validation cases, measures the Dice metric and saves the predictions as `.nii.gz` files.

NB: the CT data uses the [LPS orientation](https://www.slicer.org/wiki/Coordinate_systems) convention.

### Downloading data


The data is available on Yandex Disk: https://disk.yandex.ru/d/pWEKt6D3qi3-aw

In [16]:
import zipfile
from urllib.parse import urlencode

import requests

# Yandex Disk public-resource API: turns the public share link into a direct download URL.
base_url = 'https://cloud-api.yandex.net/v1/disk/public/resources/download?'
public_key = 'https://disk.yandex.ru/d/pWEKt6D3qi3-aw'

final_url = base_url + urlencode(dict(public_key=public_key))
response = requests.get(final_url, timeout=60)
# Fail loudly on HTTP errors instead of a confusing KeyError on 'href' below.
response.raise_for_status()
download_url = response.json()['href']

dist_path = 'AVUCTK_cases.zip'
# Stream the archive to disk in 1 MiB chunks instead of holding the whole zip in memory.
with requests.get(download_url, stream=True, timeout=60) as response:
    response.raise_for_status()
    with open(dist_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)

# Unpack into the current working directory.
with zipfile.ZipFile(dist_path, 'r') as zip_ref:
    zip_ref.extractall()

### Preparing variables

In [38]:
import copy
import json
import os
import sys

import numpy as np
import torch
import monai

MMAR_ROOT = 'kidney-mmar'

# Make the MMAR's custom transforms importable. The star import is kept because later
# cells use names it provides (e.g. ConcatImages).
# NOTE(review): switch to explicit imports once my_transforms' public names are confirmed.
sys.path.append(os.path.join(MMAR_ROOT, 'custom'))
from my_transforms import *

KIDNEY_DATASET = 'dataset.json'
DATASET_ROOT = os.getcwd()

with open(KIDNEY_DATASET, 'r') as f:
    dataset_json = json.load(f)

# Checkpoint dict: the network weights live under 'model', the training config under 'train_conf'.
model = torch.load(os.path.join(MMAR_ROOT, 'models', 'model.pt'), map_location=torch.device('cpu'))
train_config = model['train_conf']

# Fields kept from every validation record: the three input images and the label volume.
CASE_KEYS = ('artery', 'vein', 'excret', 'label')
validation_with_root = [{key: obj[key] for key in CASE_KEYS} for obj in dataset_json['validation']]

### Creating preprocessing transforms pipeline

In [26]:
import warnings


def _transform_alias(label):
    """Alias of a config entry: the part after '#' in 'Name#alias', otherwise the whole label."""
    return label.split('#')[1] if '#' in label else label


def _resolve_transform_spec(spec, aliases):
    """Return ``(name, args)`` for one validation pre-transform entry of the MMAR config.

    An entry either points at a train-stage transform via ``{'ref': alias}``, names a
    MONAI transform via ``'name'``, or names a custom transform via ``'path'``.
    A missing ``'args'`` is treated as no keyword arguments.
    """
    if 'ref' in spec:
        spec = aliases[spec['ref']]
    name = spec['name'] if 'name' in spec else spec['path']
    return name, spec.get('args', {})


# Index train-stage transforms by alias so validation entries can reference them.
# 'name' entries always win; 'path' (custom) entries are indexed only if the alias is free.
transform_map = {}
for trfm in train_config['train']['pre_transforms']:
    if 'name' in trfm:
        transform_map[_transform_alias(trfm['name'])] = trfm
    elif 'path' in trfm:
        transform_map.setdefault(_transform_alias(trfm['path']), trfm)

pre_transforms = []
for transform_json in train_config['validate']['pre_transforms']:
    transform_name, transform_args = _resolve_transform_spec(transform_json, transform_map)
    transform_name = transform_name.split('#')[0]
    if '.' not in transform_name:
        # Plain names are looked up in monai.transforms.
        pre_transforms.append(getattr(monai.transforms, transform_name)(**transform_args))
    else:
        # Dotted paths are custom transforms. Only ConcatImages is supported: it is built with
        # fixed keys and the config args are ignored, so flag any other custom transform.
        if transform_name.rsplit('.', 1)[-1] != 'ConcatImages':
            warnings.warn(f'Custom transform {transform_name!r} is replaced by ConcatImages')
        pre_transforms.append(ConcatImages(['artery', 'vein', 'excret'], 'image'))


### Creating postprocessing transforms pipeline

In [27]:
import re

# Names under which MONAI exports the dictionary-based AsDiscrete / Invert transforms.
AS_DISCRETE_NAMES = {'AsDiscreteD', 'AsDiscreted', 'AsDiscreteDict'}
INVERT_NAMES = {'InvertD', 'Invertd', 'InvertDict'}
# Index into `pre_transforms` of the transform that Invertd is asked to undo.
# NOTE(review): hard-coded position — confirm it matches the order of the validation
# pre_transforms in the MMAR config.
INVERT_SOURCE_INDEX = 3
# Default of the legacy `logit_thresh` argument, used when a config enables
# `threshold_values` without giving an explicit threshold.
LEGACY_LOGIT_THRESH = 0.5


def _monai_at_least(major, minor):
    """True when the installed MONAI version is >= major.minor (only major/minor are compared)."""
    numbers = re.findall(r'\d+', monai.__version__)[:2]
    return tuple(int(n) for n in numbers) >= (major, minor)


def _per_key(value, convert):
    """Apply ``convert`` to a scalar, or element-wise to a per-key list/tuple (returned as a list)."""
    if isinstance(value, (list, tuple)):
        return [convert(v) for v in value]
    return convert(value)


def _migrate_as_discrete_args(args):
    """Rewrite legacy AsDiscreted arguments into the ``threshold`` / ``to_onehot`` form.

    ``threshold_values`` (+ optional ``logit_thresh``) becomes ``threshold``: the threshold
    where enabled, ``None`` otherwise. ``n_classes`` (+ ``to_onehot``) becomes ``to_onehot``:
    the class count where enabled, ``None`` otherwise. Returns a new dict; ``args`` is untouched.
    """
    migrated = dict(args)
    logit_thresh = migrated.pop('logit_thresh', LEGACY_LOGIT_THRESH)
    if 'threshold_values' in migrated:
        migrated['threshold'] = _per_key(migrated.pop('threshold_values'),
                                         lambda enabled: logit_thresh if enabled else None)
    if 'n_classes' in migrated:
        n_classes = migrated.pop('n_classes')
        migrated['to_onehot'] = _per_key(migrated.get('to_onehot', False),
                                         lambda enabled: n_classes if enabled else None)
    return migrated


# The legacy-argument migration is needed for every MONAI 1.x release, not only '1.0.0'.
needs_legacy_migration = _monai_at_least(1, 0)

post_transforms = []
for transform_json in train_config['validate']['post_transforms']:
    transform_name = transform_json['name']
    # Work on a copy so the checkpoint's train_config is not mutated (keeps the cell re-runnable).
    transform_args = dict(transform_json.get('args', {}))
    if needs_legacy_migration and transform_name in AS_DISCRETE_NAMES:
        transform_args = _migrate_as_discrete_args(transform_args)
    if transform_name in INVERT_NAMES:
        transform_args['transform'] = pre_transforms[INVERT_SOURCE_INDEX]
        transform_args['device'] = 'cpu'
    post_transforms.append(getattr(monai.transforms, transform_name)(**transform_args))

### Loading model and inferer

In [36]:
# Prefer the first CUDA device when one is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# `model` is the checkpoint dict loaded while preparing variables (it also supplied
# `train_config`); reuse it instead of reading model.pt from disk a second time.
model_cfg = train_config['train']['model']
model_name, model_args = model_cfg['name'], model_cfg['args']
model_arch = getattr(monai.networks.nets, model_name)(**model_args)
model_arch.load_state_dict(model['model'])
model_arch = model_arch.eval().to(device)

inferer_cfg = train_config['validate']['inferer']
inferer_name, inferer_args = inferer_cfg['name'], inferer_cfg['args']
inferer = getattr(monai.inferers, inferer_name)(**inferer_args)

### Inference

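Run the preprocessing, model and postprocessing pipelines on every validation case, report the Dice scores, and save the predictions to the `eval` directory.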

In [39]:
import warnings

import numpy as np

# Writes each merged label map under eval/, keeping the case sub-folder
# (e.g. eval/case_1/pred_12_trans.nii.gz).
save_transform = monai.transforms.SaveImaged(keys=['pred'],
                                             separate_folder=False,
                                             output_dtype=np.uint8,
                                             meta_keys=['pred_meta_dict'],
                                             data_root_dir='data',
                                             output_dir='eval'
                                            )

# `compute_meandice` is deprecated in favour of `compute_dice` in newer MONAI; use whichever exists.
compute_dice = getattr(monai.metrics, 'compute_dice', None) or monai.metrics.compute_meandice


def pred_meta_from_image(image_meta):
    """Return a shallow copy of ``image_meta`` whose ``filename_or_obj`` file name gets a ``pred_`` prefix."""
    pred_meta = copy.copy(image_meta)
    folder, sep, file_name = pred_meta['filename_or_obj'].rpartition('/')
    pred_meta['filename_or_obj'] = f'{folder}{sep}pred_{file_name}'
    return pred_meta


def merge_onehot_to_labels(onehot):
    """Collapse a channel-first (C, ...) prediction into a single label map.

    Channel ``c`` is scaled by ``c`` (so channel 0 contributes 0). Each voxel takes the first
    non-zero scaled channel, i.e. where several channels are set the lowest label wins.
    """
    shape = (-1,) + (1,) * (onehot.dim() - 1)
    labels = onehot * torch.arange(onehot.shape[0], device=onehot.device).view(shape)
    merged = labels[0]
    for channel in labels[1:]:
        merged = torch.where(merged != 0, merged, channel)
    return merged


dices = []
for out in validation_with_root:
    case = out['artery'].split('/')[-2]
    for trfm in pre_transforms:
        out = trfm(out)
    with torch.no_grad():
        out['pred'] = inferer(out['image'].unsqueeze(0).to(device), model_arch)
    out['label'] = out['label'].unsqueeze(0)
    out['pred'] = out['pred'].squeeze(0)  # drop only the batch dim added above
    for trfm in post_transforms:
        out = trfm(out)
    out['pred_meta_dict'] = pred_meta_from_image(out['image_meta_dict'])
    out['label'] = out['label'].permute(1, 0, 2, 3, 4)
    out['pred'] = out['pred'].unsqueeze(0)
    if out['pred'].shape == out['label'].shape:
        dices.append(compute_dice(out['pred'].cpu(), out['label'].cpu(), include_background=False))
    else:
        warnings.warn(f'{case}: prediction shape {tuple(out["pred"].shape)} != label shape '
                      f'{tuple(out["label"].shape)}; case excluded from Dice statistics')
    out['pred'] = merge_onehot_to_labels(out['pred'].squeeze(0)).unsqueeze(0).cpu().numpy()
    save_transform(out)

if not dices:
    raise RuntimeError('No Dice scores were computed: prediction and label shapes never matched.')
# nanmean: a class missing from a case's ground truth may yield NaN; skip it instead of
# turning the whole column into NaN.
all_dices = torch.cat(dices)

# NOTE(review): confirm the 'Urethra' label — in a kidney / excretory-phase setting 'Ureter' seems likelier.
header = ['Artery', 'Vein', 'Urethra', 'Neoplasm', 'Kidney']

print('| {:^43} |'.format('Dice metric statistics'))
print('='*47)
print('| {:} | {:} | {:} | {:} | {:} |'.format(*header))
print('='*47)
print('| {:^6.2f} | {:^4.2f} | {:^7.2f} | {:^8.2f} | {:^6.2f} |'.format(*torch.nanmean(all_dices, dim=0).cpu().numpy()))
print('-'*47)
print('| Mean = {:^36.2f} |'.format(torch.nanmean(all_dices).cpu().item()))

file written: eval/case_1/pred_12_trans.nii.gz.
file written: eval/case_2/pred_12_trans.nii.gz.
file written: eval/case_3/pred_12_trans.nii.gz.
file written: eval/case_4/pred_12_trans.nii.gz.
file written: eval/case_5/pred_12_trans.nii.gz.
file written: eval/case_6/pred_12_trans.nii.gz.
file written: eval/case_7/pred_12_trans.nii.gz.
file written: eval/case_8/pred_12_trans.nii.gz.
file written: eval/case_9/pred_12_trans.nii.gz.
|           Dice metric statistics            |
| Artery | Vein | Urethra | Neoplasm | Kidney |
|  0.86  | 0.80 |  0.80   |   0.58   |  0.89  |
-----------------------------------------------
| Mean =                 0.79                 |
