In [1]:
import os
import torch
import json
import torch.nn as nn
import os
from utils import get_dataloader, load_test_dataset
from dataset import BratsDataset
import numpy as np
from tqdm import tqdm
from eval_utils import compute_dice
from skimage.transform import resize

# Arrange GPU devices starting from 0
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # Set the GPUs 2 and 3 to use

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

with open('models/model_subscriptions.json', 'r') as f:
    models_info = json.load(f)
model_name = 'ch3_32_interval_3_240'
model_info = next(
    (model for model in models_info if model['model_name'] == model_name), None)
print(model_info)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_model = torch.load(os.path.join('models', model_name, f"{model_name}.pth"))
_model = _model.to(device)
model = nn.DataParallel(_model).to(device)
model.eval()

test_dataloader = get_dataloader(dataset=BratsDataset, phase="test",
                                 resize_info=model_info['resize_info'], img_width=model_info['img_size'], data_type=model_info['used_channel'], batch_size=4)
# Check dataloader
# test_batch = next(iter(test_dataloader))
# batch_id, images, targets = test_batch['Id'], test_batch['image'], test_batch['mask']
# images = images.to(device)
# targets = targets.to(device)
# print('batch id', batch_id)
# print('loaded image, target shape', images.shape, targets.shape)

predictions_arr = []
batch_ids= []
cnt = 0
with torch.no_grad():
    for itr, data_batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc=f'Test Set'):
        if cnt == 5:
            break
        batch_id, images, targets = data_batch['Id'], data_batch['image'], data_batch['mask']
        batch_ids.append(batch_id)
        images = images.to(device)
        logits = model(images)
        pred = torch.sigmoid(logits).detach().cpu().numpy()
        threshold = 0.33
        pred = (pred >= threshold).astype(int)
        pred = np.array([resize(_pred, (3, 155, 240, 240), preserve_range=True)
                         for _pred in pred])
        predictions_arr.append(pred)
        del images, targets, logits, pred
        torch.cuda.empty_cache()
        cnt += 1

Device: cuda
Current cuda device: 0
Count of using GPUs: 2
{'model_name': 'ch3_32_interval_3_240', 'depth / in_channel / n_channel': [52, 3, 32], 'img_size': 240, 'used_channel': ['-t1n.nii.gz', '-t1c.nii.gz', '-t2f.nii.gz'], 'val score(dice/jaccard)': [88, 81], 'batch/total epoch/best epoch': [1, 50, 33], 'resize_info': [0, 155, 3], 'run time(m)': 1360.0}


Test Set:  16%|█▌        | 5/32 [00:44<03:58,  8.84s/it]


In [2]:
len(predictions_arr), predictions_arr[0].shape

(5, (4, 3, 155, 240, 240))

: 

In [22]:
d = []
for batch_id in batch_ids:
    d.append(load_test_dataset(batch_id))
targets = np.stack(d)
targets.shape

(20, 3, 155, 240, 240)

In [25]:
if(targets.shape==predictions.shape):
    print(compute_dice(targets, predictions))

0.8764808824034401
