In [11]:
import os
import torch
import json
import torch.nn as nn
import os
from utils import get_dataloader, load_test_dataset
from dataset import BratsDataset
import numpy as np
from tqdm import tqdm
from eval_utils import compute_dice
from skimage.transform import resize

# Arrange GPU devices starting from 0
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # Set the GPUs 2 and 3 to use

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', device)
print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

with open('models/model_subscriptions.json', 'r') as f:
    models_info = json.load(f)
model_name = 'ch3_32_interval_3_240'
model_info = next(
    (model for model in models_info if model['model_name'] == model_name), None)
print(model_info)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_model = torch.load(os.path.join('models', model_name, f"{model_name}.pth"))
_model = _model.to(device)
model = nn.DataParallel(_model).to(device)
model.eval()

test_dataloader = get_dataloader(dataset=BratsDataset, phase="test",
                                 resize_info=model_info['resize_info'], img_width=model_info['img_size'], data_type=model_info['used_channel'], batch_size=4)
# Check dataloader
# test_batch = next(iter(test_dataloader))
# batch_id, images, targets = test_batch['Id'], test_batch['image'], test_batch['mask']
# images = images.to(device)
# targets = targets.to(device)
# print('batch id', batch_id)
# print('loaded image, target shape', images.shape, targets.shape)


Device: cuda
Current cuda device: 0
Count of using GPUs: 2
{'model_name': 'ch3_32_interval_3_240', 'depth / in_channel / n_channel': [52, 3, 32], 'img_size': 240, 'used_channel': ['-t1n.nii.gz', '-t1c.nii.gz', '-t2f.nii.gz'], 'val score(dice/jaccard)': [88, 81], 'batch/total epoch/best epoch': [1, 50, 33], 'resize_info': [0, 155, 3], 'run time(m)': 1360.0}


In [17]:
dice_scores = []
csv = []
with torch.no_grad():
    for itr, data_batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader), desc=f'Test Set'):
        if(itr==3): break
        batch_id, images, targets = data_batch['Id'], data_batch['image'], data_batch['mask']
        images = images.to(device)
        logits = model(images)
        pred = torch.sigmoid(logits).detach().cpu().numpy()
        threshold = 0.33
        pred = (pred >= threshold).astype(int)
        pred = np.array([resize(_pred, (3, 155, 240, 240), preserve_range=True)
                         for _pred in pred])
        if (pred.shape == targets.shape):
            dice_score = compute_dice(targets, pred)
            print(f"Dice score({itr}): {dice_score}")
            csv.append({'batch_id':batch_id, 'score':dice_score})
            dice_scores.append(dice_score)
        else:
            print(pred.shape, targets.shape)
        del images, targets, logits, pred
        torch.cuda.empty_cache()
dice_scores = np.array(dice_scores)
print("Total Dice score:", dice_scores.mean())

Test Set:   3%|▎         | 1/32 [00:10<05:21, 10.36s/it]

Dice score(0): 0.9082398872684248


Test Set:   6%|▋         | 2/32 [00:18<04:25,  8.86s/it]

Dice score(1): 0.8739172389505194


Test Set:   9%|▉         | 3/32 [00:25<04:11,  8.66s/it]

Dice score(2): 0.8748600830269028





Total Dice score: 0.8856724030819491


In [18]:
csv

[{'batch_id': ['BraTS-GLI-00744-000',
   'BraTS-GLI-01205-000',
   'BraTS-GLI-01161-000',
   'BraTS-GLI-00714-001'],
  'score': 0.9082398872684248},
 {'batch_id': ['BraTS-GLI-01419-000',
   'BraTS-GLI-00120-000',
   'BraTS-GLI-01314-000',
   'BraTS-GLI-00231-000'],
  'score': 0.8739172389505194},
 {'batch_id': ['BraTS-GLI-01510-000',
   'BraTS-GLI-00715-001',
   'BraTS-GLI-00322-000',
   'BraTS-GLI-01476-000'],
  'score': 0.8748600830269028}]

In [21]:
import pandas as pd
df = pd.DataFrame(csv)
df.to_excel(os.path.join('models',model_name,'logs','dice_score_inference.xlsx'))

In [15]:
import pandas as pd

# Data to be written into the Excel file
data = {
    'batch_id': ['John', 'Alice', 'Bob'],
    'Age': [30, 25, 35],
    'Country': ['USA', 'UK', 'Canada']
}

# Create a DataFrame from the data
df = pd.DataFrame(data)

df

Unnamed: 0,Name,Age,Country
0,John,30,USA
1,Alice,25,UK
2,Bob,35,Canada
