In [1]:
import os
import torch 
from models_scripts import i3_res50, i3_res50_nl, disable_bn, enable_bn
from utilities_scripts import SAM, LR_Scheduler, get_criterion, LoadingBar, Log, initialize, RandAugment
from dataset_scripts import CTDataset, CTDatasetTestSimple
import json

from torch.utils.data import DataLoader
import torchvision

# --- Evaluation configuration ------------------------------------------------
batch_size = 1  # the evaluation loop reads batch[2][0] and targets.item(), i.e. one clip per batch
cuda_device_index = 0
n_class = 2  # number of classifier outputs (reported as 'non-covid' / 'covid')
# Base storage directory; override with the ICCV_MADU_STORAGE environment variable
# instead of editing hard-coded absolute paths.
storage_dir = os.environ.get("ICCV_MADU_STORAGE", "/home/sentic/storage2/iccv_madu")
fold_id = "1"  # key into the fold-split JSON; also selects the data folder below
root = os.path.join(storage_dir, "fold_" + fold_id)
num_workers = 2  # workers for the DataLoader
fold_test_path = "./test_folding.json"
checkpoint_dir = os.path.join(storage_dir, "checkpoints", "model1_basicAUG_fold1_0")
model_path = "checkpoint_model1_1_39.pth"
device = torch.device("cuda:" + str(cuda_device_index) if torch.cuda.is_available() else "cpu")
prepath = ""
replacer = ""
clip_len = 128  # clip length in frames; also defines the frame-count bands used for TTA subsampling
threshold = 0.41  # a view is labelled 1 when softmax P(class 1) >= threshold

In [2]:
# Fold split for the test set. JSON is read as UTF-8 explicitly so the result
# does not depend on the platform's locale encoding.
with open(fold_test_path, encoding="utf-8") as fhandle:
    fold_splitter_test = json.load(fhandle)
    
def find_frames_by_name(fname, paths=None, frames=None):
    """Return the frame count recorded for `fname` in the test fold split.

    Args:
        fname: clip path, compared for exact equality with the entries of `paths`.
        paths: list of clip paths. Defaults to
            ``fold_splitter_test[fold_id]["paths"]``, looked up at call time.
        frames: frame counts parallel to `paths`. Defaults to
            ``fold_splitter_test[fold_id]["frames"]``, looked up at call time.

    Returns:
        The frame count paired with the first matching path, or 0 if `fname`
        is not listed.
    """
    # Resolve defaults lazily from the configured fold. The previous defaults
    # were frozen to fold "1" when the function was defined, ignoring fold_id
    # and any later reload of fold_splitter_test.
    if paths is None or frames is None:
        fold = fold_splitter_test[fold_id]
        if paths is None:
            paths = fold["paths"]
        if frames is None:
            frames = fold["frames"]
    for path, num_frames in zip(paths, frames):
        if path == fname:
            return num_frames
    return 0

def choose_best_option(predictions_for_input, method="highest"):
    """Aggregate per-view ``(label, score)`` votes into one ``(label, score)`` decision.

    Args:
        predictions_for_input: sequence of tuples whose element 0 is the label
            (0 or 1) and element 1 is its score.
        method:
            * ``"highest"``: return the vote with the largest score; the
              first one wins on ties.
            * ``"frequent"``: majority label, with ties going to 1. Returns
              ``(1, 1.0)`` or ``(0, 0.0)``.
            * ``"weight"``: the label whose votes have the higher mean score,
              with ties going to 1. Returns ``(1, 1.0)`` or ``(0, 1.0)``.

    Returns:
        A ``(label, score)`` tuple.

    Raises:
        ValueError: if `method` is not one of the names above, or if
            ``method="highest"`` and there are no votes.
    """
    if method == "highest":
        if not predictions_for_input:
            raise ValueError("choose_best_option: no predictions to aggregate")
        # max() returns the first of tied entries, exactly like the stable
        # sorted(..., reverse=True)[0] it replaces, but in O(n).
        return max(predictions_for_input, key=lambda x: x[1])
    list_positive = [x for x in predictions_for_input if x[0] == 1]
    list_negative = [x for x in predictions_for_input if x[0] == 0]
    if method == "frequent":
        if len(list_positive) >= len(list_negative):
            return (1, 1.0)
        return (0, 0.0)
    if method == "weight":
        score_positive, score_negative = 0, 0
        if list_positive:
            score_positive = sum(x[1] for x in list_positive) / len(list_positive)
        if list_negative:
            score_negative = sum(x[1] for x in list_negative) / len(list_negative)
        if score_positive >= score_negative:
            return (1, 1.0)
        return (0, 1.0)
    # Previously an unknown method silently returned None, which crashed later
    # at best_option[0]; fail here with a message naming the bad value.
    raise ValueError(
        f"choose_best_option: unknown method {method!r}; "
        "expected 'highest', 'frequent' or 'weight'")

In [3]:
# Load the trained weights and move the model to the evaluation device.
# map_location=device (instead of a hard-coded 'cuda:N') lets the checkpoint
# load on a CPU-only machine, matching the CPU fallback used for `device`.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
path_checkpoint = os.path.join(checkpoint_dir, model_path)
dict_checkpoint = torch.load(path_checkpoint, map_location=device)
net_state_dict = dict_checkpoint['model_state_dict']

# The architecture must match the one the checkpoint was saved from.
model = i3_res50_nl(n_class)
model.load_state_dict(net_state_dict)
model.to(device)

I3Res50(
  (conv1): Conv3d(3, 64, kernel_size=(5, 7, 7), stride=(2, 2, 2), padding=(2, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool1): MaxPool3d(kernel_size=(2, 3, 3), stride=(2, 2, 2), padding=(0, 0, 0), dilation=1, ceil_mode=False)
  (maxpool2): MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0), dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv3d(64, 64, kernel_size=(3, 1, 1), stride=(1, 1, 1), padding=(1, 0, 0), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn3): BatchNorm3d(256, e

In [4]:
# Test split of the configured fold. The metrics do not depend on iteration
# order, so iterate deterministically (shuffle=False) to keep runs
# reproducible. prepath/replacer come from the config cell instead of being
# repeated as literals.
dataset_test = CTDatasetTestSimple(root=root,
                                   fold_id=fold_id,
                                   fold_splitter=fold_splitter_test,
                                   transforms=None,
                                   replacer=replacer,
                                   prepath=prepath,
                                   clip_len=clip_len,
                                   split="test")

dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [5]:
# import torch.nn as nn
# import torch
# model.eval()

# d = {"fname": [],
#     "label": [],
#     "score": []}

# trues = []
# predicted = []
# scores_list = []

# with torch.no_grad():
#     for batch in dataloader_test:
#         inputs, targets = (b.to(device) for b in batch[:2])
#         fname = batch[2][0]
#         T = inputs.shape[2]
#         predictions_for_input = []
#         if T <= clip_len:
#             original_num_frames = find_frames_by_name(fname)
#             #################################################
#             predictions = model(inputs) # forward input
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs1 = inputs[:, :, :original_num_frames, :, :]
#             inputs1 = torch.flip(inputs1, (2,))
#             inputs2 = inputs[:, :, original_num_frames:, :, :]
#             inputs = torch.cat([inputs1, inputs2], axis=2)
#             predictions = model(inputs) # forward input
#             del inputs1
#             del inputs2
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#         elif T > clip_len and T <= 2 * clip_len:
#             predictions = model(inputs) # forward input
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs = torch.flip(inputs, (2,))
#             predictions = model(inputs) # forward input
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ###################################################
#         elif T > 2 * clip_len and T <= 4 * clip_len:
#             leap = 2
#             offset = 0
#             inputs1 = inputs[:, :, offset::leap, :, :]
#             predictions = model(inputs1) # forward input
#             del inputs1
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs2 = torch.flip(inputs, (2,))
#             inputs2 = inputs2[:, :, offset::leap, :, :]
#             predictions = model(inputs2) # forward input
#             del inputs2
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             leap = 2
#             offset = 1
#             inputs1 = inputs[:, :, offset::leap, :, :]
#             predictions = model(inputs1) # forward input
#             del inputs1
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs2 = torch.flip(inputs, (2,))
#             inputs2 = inputs2[:, :, offset::leap, :, :]
#             predictions = model(inputs2) # forward input
#             del inputs2
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#         elif T > 4 * clip_len:
#             leap = 3
#             offset = 0
#             inputs1 = inputs[:, :, offset::leap, :, :]
#             predictions = model(inputs1) # forward input
#             del inputs1
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs2 = torch.flip(inputs, (2,))
#             inputs2 = inputs2[:, :, offset::leap, :, :]
#             predictions = model(inputs2) # forward input
#             del inputs2
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             leap = 3
#             offset = 1
#             inputs1 = inputs[:, :, offset::leap, :, :]
#             predictions = model(inputs1) # forward input
#             del inputs1
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs2 = torch.flip(inputs, (2,))
#             inputs2 = inputs2[:, :, offset::leap, :, :]
#             predictions = model(inputs2) # forward input
#             del inputs2
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             leap = 3
#             offset = 2
#             inputs1 = inputs[:, :, offset::leap, :, :]
#             predictions = model(inputs1) # forward input
#             del inputs1
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#             inputs2 = torch.flip(inputs, (2,))
#             inputs2 = inputs2[:, :, offset::leap, :, :]
#             predictions = model(inputs2) # forward input
#             del inputs2
#             scores = torch.nn.functional.softmax(predictions, dim=1)
#             if scores[0][1] >= threshold:
#                 prediction = 1
#             else:
#                 prediction = 0
#             score = scores[0][prediction].item()
#             predictions_for_input.append((prediction, score))
#             ##################################################
#         else:
#             print("NOT GOOD")
#         best_option = choose_best_option(predictions_for_input)
#         trues.append(targets.item())
#         predicted.append(best_option[0])

In [6]:
# from sklearn.metrics import classification_report

In [7]:
# print(classification_report(trues, predicted, target_names=['non-covid', 'covid']))

In [8]:
# from sklearn.metrics import f1_score

# print(f1_score(trues, predicted, average='macro'))

In [9]:
def simple_inference(inputs, leap=0, backward=False,
                     flipx=False, flipy=False,
                     original_num_frames=0,
                     offset=0):
    """Build one test-time-augmented view of a clip tensor (does NOT run the model).

    Despite its name, this only returns a transformed copy of `inputs`. The
    caller is responsible for feeding that copy to the model. Steps are
    applied in this order:

    1. ``backward``: reverse the first `original_num_frames` frames along
       dim 2 and keep the remaining frames in place. With
       ``original_num_frames=0``, nothing is reversed.
    2. ``flipx`` / ``flipy``: mirror along dim 3 / dim 4.
    3. ``leap != 0``: keep every `leap`-th frame along dim 2, starting at
       `offset`.

    Args:
        inputs: 5-D tensor indexed as ``[:, :, frame, :, :]``. Presumably
            (batch, channels, time, height, width) -- TODO confirm with the dataset.
        leap: temporal stride; 0 disables subsampling.
        backward: reverse the leading real frames in time.
        flipx: flip along dim 3.
        flipy: flip along dim 4.
        original_num_frames: number of leading frames that ``backward`` reverses.
        offset: first frame index kept when subsampling.

    Returns:
        A tensor that does not share storage with `inputs` (it starts from a clone).
    """
    view = inputs.clone().detach()
    if backward:
        head = torch.flip(view[:, :, :original_num_frames, :, :], (2,))
        tail = view[:, :, original_num_frames:, :, :]
        view = torch.cat([head, tail], dim=2)
        del head, tail
    if flipx:
        view = torch.flip(view, (3,))
    if flipy:
        view = torch.flip(view, (4,))
    if leap != 0:
        view = view[:, :, offset::leap, :, :]
    return view

def decide_score(predictions, threshold, predictions_for_input):
    """Turn one model output into a ``(label, confidence)`` vote and record it.

    Args:
        predictions: logits tensor. Softmax is taken over dim 1, and only the
            first row (batch index 0) is used.
        threshold: label 1 is chosen when the class-1 probability is >= this value.
        predictions_for_input: list that receives the vote (mutated in place).

    Returns:
        The same list object that was passed in.
    """
    first_row = torch.nn.functional.softmax(predictions, dim=1)[0]
    label = 1 if first_row[1] >= threshold else 0
    confidence = first_row[label].item()  # probability of the chosen label
    predictions_for_input.append((label, confidence))
    return predictions_for_input

In [10]:
import torch.nn as nn
import torch
model.eval()

# Test-time augmentation variants as (backward, flipx, flipy). The order is
# significant: with method="highest", the first of equally scored votes wins.
TTA_VARIANTS = [
    (False, False, False),
    (True, False, False),
    (False, True, False),
    (False, False, True),
    (False, True, True),
    (True, True, False),
    (True, False, True),
    (True, True, True),
]


def _temporal_leap(num_frames):
    """Return the temporal stride for a clip of `num_frames` frames (0 = no subsampling)."""
    if num_frames <= 2 * clip_len:
        return 0
    if num_frames <= 4 * clip_len:
        return 2
    return 3


def _tta_predictions(inputs, fname):
    """Run the model on every TTA view of one clip; return its list of (label, score) votes."""
    num_frames = inputs.shape[2]
    if num_frames <= clip_len:
        # For short clips, the backward view reverses only the frame count
        # recorded in the fold split. NOTE(review): presumably the dataset
        # pads short clips up to clip_len -- confirm against CTDatasetTestSimple.
        original_num_frames = find_frames_by_name(fname)
    else:
        original_num_frames = num_frames
    leap = _temporal_leap(num_frames)
    predictions_for_input = []
    # Every temporal phase (offset) is combined with every spatial/temporal flip.
    for offset in range(max(leap, 1)):
        for backward, flipx, flipy in TTA_VARIANTS:
            view = simple_inference(inputs, leap=leap, backward=backward,
                                    flipx=flipx, flipy=flipy,
                                    original_num_frames=original_num_frames,
                                    offset=offset)
            predictions = model(view)
            del view
            decide_score(predictions, threshold, predictions_for_input)
    return predictions_for_input


trues = []
predicted = []

with torch.no_grad():
    for batch in dataloader_test:
        inputs, targets = (b.to(device) for b in batch[:2])
        fname = batch[2][0]
        predictions_for_input = _tta_predictions(inputs, fname)
        # "highest" (not the misspelled "heighest", which made
        # choose_best_option return None and crash on best_option[0]).
        best_option = choose_best_option(predictions_for_input, method="highest")
        trues.append(targets.item())
        predicted.append(best_option[0])

In [11]:
from sklearn.metrics import classification_report
# labels=[0, 1] pins the class order so target_names always lines up, even if
# a run happens to contain (or predict) only one of the two classes.
print(classification_report(trues, predicted, labels=[0, 1], target_names=['non-covid', 'covid']))

              precision    recall  f1-score   support

   non-covid       0.89      0.97      0.93       209
       covid       0.95      0.85      0.90       165

    accuracy                           0.92       374
   macro avg       0.92      0.91      0.91       374
weighted avg       0.92      0.92      0.92       374



In [12]:
from sklearn.metrics import f1_score
# Macro F1 over both classes, matching the report's "macro avg" row.
print(f1_score(trues, predicted, labels=[0, 1], average='macro'))

0.914847049318791
