In [1]:
import glob
import os
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import cv2
import scipy.ndimage as ndimage
import torch.optim as optim
import time
import shutil
from sklearn.metrics import roc_curve, auc
from argparse import ArgumentParser, Namespace

import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
import math
from functools import partial
from torch.utils.tensorboard import SummaryWriter

import torchio as tio
from tqdm.auto import tqdm

import sys
sys.path.append('../src')
from seg_model_utils.torchio_transforms import *
from seg_model_utils.brats2021_dataset import BraTS2021
from seg_model_utils.augmentations3d import *
from seg_model_utils.visualization import *
from seg_model_utils.seg_model import UNet3D_v2

In [3]:
# Restore the trained single-output-channel UNet from its last-epoch checkpoint and
# switch it to inference mode (the returned module is discarded to avoid a repr dump).
ckpt_fn = '../output/seg_model_256--1/seg_model_256--1_last_epoch.pth'
model = UNet3D_v2(out_channels=1).cuda()
model.load_state_dict(torch.load(ckpt_fn)['model_state_dict'])
_ = model.eval()

In [4]:
fold = 0
df = pd.read_csv('../input/train_labels_folds-v1.csv')

# BraTS case ids that inference is run on in this notebook.
missing_ids = [169, 197, 245, 308, 408, 564, 794, 998]

# Registered volumes are stored as <5-digit zero-padded id>.npy.
npy_dir = '../input/registered_cases/train/'
sample_fns_test = [os.path.join(npy_dir, f'{case_id:05d}.npy') for case_id in missing_ids]

In [7]:
# Sanity check: report whether each input volume is present on disk.
for npy_fn in sample_fns_test:
    found = os.path.exists(npy_fn)
    print(f'{npy_fn} exists: {found}')

../input/registered_cases/train/00169.npy exists: True
../input/registered_cases/train/00197.npy exists: True
../input/registered_cases/train/00245.npy exists: True
../input/registered_cases/train/00308.npy exists: True
../input/registered_cases/train/00408.npy exists: True
../input/registered_cases/train/00564.npy exists: True
../input/registered_cases/train/00794.npy exists: True
../input/registered_cases/train/00998.npy exists: True


In [10]:
# Inference-only dataset over the selected volumes: no labels and no augmentation.
test_ds = BraTS2021(
    npy_fns_list=sample_fns_test,
    mode='test',
    label_list=None,
    augmentations=None,
    volume_normalize=True,
    max_out_size=(256, 256, 96),
)

In [19]:
def inference_sample_tta(model, image, batch_size=1, n_flips=4):
    """Run flip test-time augmentation (TTA) on one 3D image and average the predictions.

    The image is flipped along combinations of its spatial axes (dims 1-3 of a
    channel-first CxWxHxD tensor). Each flipped copy is passed through ``model``.
    The segmentation outputs are flipped back into the original orientation, and
    both the segmentation and the classification outputs are averaged.

    Args:
        model: module returning ``(seg_batch, clf_logits)`` for a batched input.
            Inputs are moved to the device of the model's parameters (CPU if it
            has no parameters). Assumes the segmentation output keeps the input's
            spatial layout in dims 2-4 of the batch -- TODO confirm for new models.
        image: tensor of shape CxWxHxD (no batch dimension), e.g. 3xWxHxD.
        batch_size: number of flipped copies passed through the model per call.
        n_flips: number of flip variants to use (1-8). Variant 0 is the
            un-flipped image. The default of 4 uses identity plus the three
            single-axis flips.

    Returns:
        (seg, clf): the mean segmentation output (on CPU, in the same orientation
        as ``image``) and the mean sigmoid-activated classification output (on CPU).

    Raises:
        ValueError: if ``n_flips`` is not in 1..8 or ``batch_size`` < 1.
    """
    # Each entry lists the spatial dims to flip. A flip is its own inverse, so the
    # same dims also undo the augmentation on the predicted mask.
    flip_dims = [(), (1,), (2,), (3,), (1, 2), (1, 3), (1, 2, 3), (2, 3)]
    if not 1 <= n_flips <= len(flip_dims):
        raise ValueError(f'n_flips must be in [1, {len(flip_dims)}], got {n_flips}')
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    model.eval()
    # Follow the model's device instead of assuming CUDA.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cpu')

    def _flip(im, index):
        dims = flip_dims[index]
        return torch.flip(im, list(dims)) if dims else im

    # torch.stack copies the data, so ``image`` itself is never modified.
    batch = torch.stack([_flip(image, index) for index in range(n_flips)], dim=0)
    seg_flipped_list, clf_list = [], []

    with torch.no_grad():
        for start in range(0, n_flips, batch_size):
            seg_chunk, clf_chunk = model(batch[start:start + batch_size].to(device))
            # Iterating a tensor yields its per-sample slices along dim 0.
            seg_flipped_list.extend(seg_chunk.detach().cpu())
            # logits to preds
            clf_list.extend(torch.sigmoid(clf_chunk.detach().cpu()))

    # Flip masks back so every prediction is in the input's orientation, then average.
    seg = torch.stack([_flip(s, index) for index, s in enumerate(seg_flipped_list)], dim=0).mean(dim=0)
    clf = torch.stack(clf_list, dim=0).mean(dim=0)
    return seg, clf

def create_segs(out_dir, val_ds, model):
    """Predict a segmentation for every sample of ``val_ds`` and save it to disk.

    Layout created under ``out_dir`` (missing parent directories are created too):
        oof_preds/<BraTSID:05d>_seg.npy -- TTA-averaged segmentation output per sample
        vis/                            -- created, but nothing is written here

    Args:
        out_dir: root output directory.
        val_ds: indexable, sized dataset whose items are dicts containing at least
            ``'image'`` (model input tensor) and ``'BraTSID'`` (id convertible to int).
        model: model passed to ``inference_sample_tta``.
    """
    pred_dir = os.path.join(out_dir, 'oof_preds')
    vis_dir = os.path.join(out_dir, 'vis')
    # makedirs(exist_ok=True) also creates missing parents and is safe to re-run.
    for directory in (out_dir, pred_dir, vis_dir):
        os.makedirs(directory, exist_ok=True)

    for val_index in tqdm(range(len(val_ds))):
        sample = val_ds[val_index]
        bratsid = f'{int(sample["BraTSID"]):05d}'

        # Only the segmentation is persisted; the classification output is unused here.
        seg, _ = inference_sample_tta(model, sample['image'])

        # save oof preds
        seg_fn = os.path.join(pred_dir, f'{bratsid}_seg.npy')
        np.save(seg_fn, seg.cpu().numpy())
        

In [11]:
# Output root for predictions on the selected cases. makedirs also creates
# ../output if it is missing, and exist_ok makes the cell safe to re-run.
out_dir = os.path.join('../output/', 'seg_model_256_missing_cases')
os.makedirs(out_dir, exist_ok=True)

In [20]:
# Run TTA inference over every case in test_ds and save predictions under out_dir/oof_preds.
create_segs(out_dir=out_dir, val_ds=test_ds, model=model)

  0%|          | 0/8 [00:00<?, ?it/s]

In [23]:
# Threshold each saved prediction at 0.5 and store it as a boolean mask with the same file name.
# create_segs() writes its predictions to <out_dir>/oof_preds, so read from there.
# Nothing in this notebook populates a 'segmentations/' folder.
seg_dir = '../output/seg_model_256_missing_cases/oof_preds/'
binary_dir = '../output/seg_model_256_missing_cases/binary_segmentations/'
os.makedirs(binary_dir, exist_ok=True)
for fn in sorted(os.listdir(seg_dir)):
    if not fn.endswith('.npy'):
        continue  # skip anything that is not a saved prediction array
    seg = np.load(os.path.join(seg_dir, fn))
    # np.bool was removed in NumPy 1.24. The builtin bool gives the same np.bool_ dtype.
    seg = (seg > 0.5).astype(bool)
    np.save(os.path.join(binary_dir, fn), seg)