In [1]:
import cv2
import os, shutil
import numpy as np
import pandas as pd
from glob import glob
from tqdm import tqdm
tqdm.pandas()
from scipy.ndimage import binary_closing, binary_opening, measurements
import cupy as cp
import sys
import gc
import matplotlib.pyplot as plt
from mmseg.apis import init_segmentor, inference_segmentor
from mmcv.utils import config

In [2]:
# df = pd.read_csv("/root/autodl-tmp/train.csv")
# df.iloc[74650:74670]

In [3]:
# One threshold triple per model. The spelling "THRESHOD" is kept because the
# same name is written into test_cfg as a key below.
THRESHOD = [[0.9,0.5,0.5], [0.8,0.7,0.5]]
# Side length passed to BuildDataset when the inference dataset is created.
SIZE = 360
configs = [
    "/root/custom_cfg/Swin_Small_Uper.py",
]

ckpts = [
    "/root/swin_small_25000_anno_only_832.pth"
]

models = []
# NOTE(review): img_norm_cfg is not referenced by any other visible cell.
img_norm_cfg = dict(mean=[0,0,0], std=[1,1,1], to_rgb=False)
# zip() stops at the shortest input: with one config/checkpoint only the first
# THRESHOD entry is used and the second one is silently ignored.
for cfg, ckpt, thrs in zip(configs, ckpts, THRESHOD):
    # `cfg` is rebound from the config path to the loaded Config object.
    cfg = config.Config.fromfile(cfg)
    # Custom test_cfg keys; their meaning is defined by the segmentor
    # implementation, which is not visible here — TODO confirm.
    cfg.model.test_cfg.THRESHOD = thrs
    cfg.model.test_cfg.logits = False
    cfg.model.test_cfg.nature = False
    cfg.model.test_cfg.train_test = True
    cfg.model.test_cfg.multi_label = True
    # Presumably pipeline[1] is the test-time augmentation step; restrict it to
    # a single scale without flipping — verify against the config file.
    cfg.data.test.pipeline[1].img_ratios=[1.0]
    cfg.data.test.pipeline[1].flip=False
    model = init_segmentor(cfg, ckpt, device='cuda:0')
    models.append(model)

load checkpoint from local path: /root/swin_small_25000_anno_only_832.pth


In [4]:
# df = pd.read_csv("/root/autodl-tmp/train.csv")
# for i in range(len(df)):
#     if df.loc[i,'id'].rsplit('_',2)[0]=='case43_day18' and df.loc[i,'class']=='large_bowel' and pd.isnull(df.loc[i,'segmentation'])!=True:
#         df.loc[i,'segmentation']=np.nan
#         break
# df.to_csv('/root/autodl-tmp/train.csv', index=False)

In [5]:
# df = pd.read_csv("/root/autodl-tmp/train.csv")
# df.drop(columns=['Unnamed: 0'],inplace=True)
# df.to_csv("/root/autodl-tmp/train.csv", index=False)

In [6]:
# `df` keeps the CSV row order for the final export; `df_train` is the sorted
# working copy that the following cells extend with extra columns.
df = pd.read_csv("/root/autodl-tmp/train_fix.csv")
df_train = df.sort_values(["id", "class"]).reset_index(drop=True)

# ids look like "case101_day20_slice_0001": token 0 is the patient and the
# first two tokens together identify one scan ("days").
id_tokens = df_train.id.str.split("_")
df_train["patient"] = id_tokens.str[0]
df_train["days"] = id_tokens.str[:2].str.join("_")

# A slice counts as empty when every one of its class rows has no segmentation.
slice_has_no_mask = df_train.segmentation.isna().groupby(df_train.id).all()

num_slices = np.unique(df_train.id).size
num_empty_slices = slice_has_no_mask.sum()
num_patients = np.unique(df_train.patient).size
num_days = np.unique(df_train.days).size

summary = {
    "#slices:": num_slices,
    "#empty slices:": num_empty_slices,
    "#patients": num_patients,
    "#days": num_days
}
print(summary)

{'#slices:': 38496, '#empty slices:': 21900, '#patients': 85, '#days': 274}


In [7]:
# Order the scan files by "<case_day>_<file name>" so that they line up with
# df_train, which is sorted by id. The path components are picked by their
# fixed position under /root/autodl-tmp/train.
all_image_files = sorted(
    glob("/root/autodl-tmp/train/*/*/scans/*.png"),
    key=lambda path: "_".join([path.split("/")[5], path.split("/")[7]]),
)

# File stems look like "slice_0001_266_266_1.50_1.50"; split each one once and
# read the fields from the right-hand side.
name_fields = [os.path.basename(path)[:-4].split("_") for path in all_image_files]
size_x = [int(fields[-4]) for fields in name_fields]
size_y = [int(fields[-3]) for fields in name_fields]
spacing_x = [float(fields[-2]) for fields in name_fields]
spacing_y = [float(fields[-1]) for fields in name_fields]
slice_numbers = [int(fields[-5]) for fields in name_fields]

# Every slice owns three consecutive rows in df_train (one per class), so each
# per-file value is repeated three times.
df_train["image_files"] = np.repeat(all_image_files, 3)
df_train["spacing_x"] = np.repeat(spacing_x, 3)
df_train["spacing_y"] = np.repeat(spacing_y, 3)
df_train["size_x"] = np.repeat(size_x, 3)
df_train["size_y"] = np.repeat(size_y, 3)
df_train["slice"] = np.repeat(slice_numbers, 3)
df_train

Unnamed: 0,id,class,segmentation,patient,days,image_files,spacing_x,spacing_y,size_x,size_y,slice
0,case101_day20_slice_0001,large_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,1
1,case101_day20_slice_0001,small_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,1
2,case101_day20_slice_0001,stomach,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,1
3,case101_day20_slice_0002,large_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,2
4,case101_day20_slice_0002,small_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,2
...,...,...,...,...,...,...,...,...,...,...,...
115483,case9_day22_slice_0143,small_bowel,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,143
115484,case9_day22_slice_0143,stomach,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,143
115485,case9_day22_slice_0144,large_bowel,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,144
115486,case9_day22_slice_0144,small_bowel,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,144


# Convert data to 2.5D

In [8]:
# Unused in this cell now; left in place because it rebinds the notebook-wide
# `tqdm` name and other cells may rely on that.
from tqdm.notebook import tqdm

# 2.5D input: every row gets the paths of three slices of the same scan — the
# row shifted by +6, the row itself, and the row shifted by -6. With three
# class rows per slice a shift of 6 rows equals two slices.
channels=3
strides=[6,0,-6]
path_cols = [f'image_path_{i:02d}' for i in range(channels)]
for col, stride in zip(path_cols, strides):
    df_train[col] = df_train.groupby(['days'])['image_files'].shift(stride)

# At the start/end of a scan the shifted neighbour does not exist (NaN). Fall
# back to the entry at position 1 for the first and last channel. This is done
# column-wise on a copy, so the image_path_XX columns keep their NaNs and the
# result no longer relies on mutating lists through temporary row copies.
neighbour_paths = df_train[path_cols].copy()
fallback = neighbour_paths[path_cols[1]]
for edge_col in (path_cols[0], path_cols[-1]):
    neighbour_paths[edge_col] = neighbour_paths[edge_col].where(
        neighbour_paths[edge_col].notna(), fallback
    )
df_train['image_paths'] = neighbour_paths.values.tolist()

print(len(df_train))
df_train.image_paths[434]

115488


  0%|          | 0/115488 [00:00<?, ?it/s]

['/root/autodl-tmp/train/case101/case101_day22/scans/slice_0001_266_266_1.50_1.50.png',
 '/root/autodl-tmp/train/case101/case101_day22/scans/slice_0001_266_266_1.50_1.50.png',
 '/root/autodl-tmp/train/case101/case101_day22/scans/slice_0003_266_266_1.50_1.50.png']

### Remark for 3 more slices

* Scans (case/day IDs) whose slice numbering ends at 80:
* Case89:  day21
* Case131: day15
* Case117: day13 | day15 | day16 | day17
* Case35:  day12 | day13 | day15 | day18
* Case118: day0
* Case146: day25
* Case34:  day0 | day15 | day16

In [9]:
# Passed on to infer() later; never filled in the visible cells.
id_ban = []
# Scans listed in the "ends at 80" remark; they use a lower slice cut-off.
list_80 = ['case89_day21', 'case131_day15', 'case117_day13','case117_day15', 'case117_day16',
           'case117_day17', 'case35_day12', 'case35_day13', 'case35_day15', 'case35_day18', 
           'case118_day0', 'case146_day25', 'case34_day0', 'case34_day15', 'case34_day16']

# True for every id whose class rows all lack a segmentation.
slice_is_empty = df_train['segmentation'].isna().groupby(df_train['id']).all()

# Candidate region: slice > 50 for the scans in list_80, slice > 90 for every
# scan. Rows outside the region are never flagged.
in_tail = (df_train['slice'] > 90) | (
    (df_train['slice'] > 50) & df_train['days'].isin(list_80)
)
df_train['empty'] = in_tail & df_train['id'].map(slice_is_empty)

rm_df = df_train[df_train['empty']].reset_index(drop=True)
# Keep the first 9 flagged rows (3 slices x 3 classes) of each scan. The
# grouping statement is left as it was so the resulting row order is unchanged.
rm_df = rm_df.groupby('days', as_index=False).apply(lambda x: x.iloc[:9,:]).reset_index().drop(columns=['level_0','level_1'])
re_mask_list = rm_df.id.unique().tolist()
rm_df

Unnamed: 0,id,class,segmentation,patient,days,image_files,spacing_x,spacing_y,size_x,size_y,slice,image_path_00,image_path_01,image_path_02,image_paths,empty
0,case101_day20_slice_0112,large_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,112,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,[/root/autodl-tmp/train/case101/case101_day20/...,True
1,case101_day20_slice_0112,small_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,112,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,[/root/autodl-tmp/train/case101/case101_day20/...,True
2,case101_day20_slice_0112,stomach,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,112,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,[/root/autodl-tmp/train/case101/case101_day20/...,True
3,case101_day20_slice_0113,large_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,113,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,[/root/autodl-tmp/train/case101/case101_day20/...,True
4,case101_day20_slice_0113,small_bowel,,case101,case101_day20,/root/autodl-tmp/train/case101/case101_day20/s...,1.5,1.5,266,266,113,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,/root/autodl-tmp/train/case101/case101_day20/s...,[/root/autodl-tmp/train/case101/case101_day20/...,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2461,case9_day22_slice_0121,small_bowel,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,121,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,[/root/autodl-tmp/train/case9/case9_day22/scan...,True
2462,case9_day22_slice_0121,stomach,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,121,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,[/root/autodl-tmp/train/case9/case9_day22/scan...,True
2463,case9_day22_slice_0122,large_bowel,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,122,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,[/root/autodl-tmp/train/case9/case9_day22/scan...,True
2464,case9_day22_slice_0122,small_bowel,,case9,case9_day22,/root/autodl-tmp/train/case9/case9_day22/scans...,1.5,1.5,360,310,122,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,/root/autodl-tmp/train/case9/case9_day22/scans...,[/root/autodl-tmp/train/case9/case9_day22/scan...,True


### Load images and masks

In [10]:
def load_img(path, size=360):
    """Read one scan slice as float32 and bring it to ``size`` x ``size``.

    Args:
        path: image file to read with ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``.
        size: side length of the square output.

    Returns:
        ``(img, ori_shape)`` where ``ori_shape`` is the (rows, cols) shape of
        the file on disk. Images of 234x234, 266x266 and 276x276 are resized;
        310x360 images are zero-padded with 25 rows at the top and bottom and
        then resized. Any other shape is returned unchanged.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports a missing or unreadable file by returning None;
        # fail here with the path instead of an AttributeError further down.
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    img = img.astype(np.float32)
    ori_shape = img.shape[:2]
    new_shape = (size, size)
    if ori_shape != new_shape:
        if ori_shape in ((234, 234), (266, 266), (276, 276)):
            img = cv2.resize(img, new_shape)
        elif ori_shape == (310, 360):
            # Pad the rows up to 360 so the image is square before resizing.
            img = np.pad(img, ((25, 25), (0, 0)), 'constant')
            img = cv2.resize(img, new_shape)
    return img, ori_shape

def load_imgs(img_paths, size=360):
    """Stack several slices into one ``(size, size, len(img_paths))`` array.

    Each path is read with ``load_img`` and stored as one channel. The stack is
    then divided by its overall maximum (plus a small epsilon). The second
    return value is the original shape reported for the last path.
    """
    n_channels = len(img_paths)
    stack = np.zeros((size, size, n_channels), dtype=np.float32)
    for channel, path in enumerate(img_paths):
        plane, ori_shape = load_img(path, size=size)
        stack[..., channel] = plane
    # The epsilon keeps the division finite for an all-zero stack.
    scale = stack.max()+1e-7
    stack = stack / scale
    return stack, ori_shape

def mask2rle(msk):
    '''
    Run-length encode a binary mask.

    msk: array, 1 - mask, 0 - background
    Returns the run lengths as a space separated string of
    "start length" pairs (1-based starts, pixels counted in flattened
    order); an all-background mask gives an empty string.
    '''
    msk    = cp.array(msk)
    pixels = msk.flatten()
    # Pad both ends with background so every run has a start and an end edge.
    pad    = cp.array([0])
    pixels = cp.concatenate([pad, pixels, pad])
    # Positions where the value changes: run starts and run ends alternate.
    runs   = cp.where(pixels[1:] != pixels[:-1])[0] + 1
    # Turn each end position into a length.
    runs[1::2] -= runs[::2]
    # Convert to Python ints in one call instead of iterating the array
    # element by element.
    return ' '.join(str(x) for x in runs.tolist())

def masks2rles(msks, ids, heights, widths):
    """Encode a batch of 3-channel masks as RLE strings.

    For every batch entry the mask is mapped back to its original
    ``(height, width)``: the three known square sizes are resized directly,
    310x360 is resized and then cropped to rows 25..334. Other shapes are
    encoded as they are. Returns three parallel lists (RLE strings, ids and
    class names) with three entries per mask.
    """
    class_names = ['large_bowel', 'small_bowel', 'stomach']
    square_shapes = ((234, 234), (266, 266), (276, 276))
    pred_strings, pred_ids, pred_classes = [], [], []
    for idx in range(msks.shape[0]):
        msk = msks[idx]
        target = (heights[idx].item(), widths[idx].item())
        if target in square_shapes:
            msk = cv2.resize(msk, target)
        elif target == (310, 360):
            # Undo the 25-row padding that was added when the image was loaded.
            msk = cv2.resize(msk, (target[0]+50, target[1]))
            msk = msk[25:335, :, :]
        pred_strings.extend([mask2rle(msk[..., channel]) for channel in range(3)])
        pred_ids.extend([ids[idx]] * 3)
        pred_classes.extend(class_names)
    return pred_strings, pred_ids, pred_classes

def rle_decode(mask_rle, ori_shape, new_shape):
    """Decode an RLE string into a uint8 mask and fit it to ``new_shape``.

    ``mask_rle`` holds "start length" pairs with 1-based starts. The mask is
    filled on a flat buffer and reshaped to ``ori_shape``. The known source
    shapes are then resized (310x360 is zero-padded by 25 rows on each side
    first); any other shape is returned as decoded.
    """
    tokens = np.array(mask_rle.split(), dtype=int)
    run_starts = tokens[0::2] - 1
    run_stops = run_starts + tokens[1::2]
    rows, cols = ori_shape
    flat = np.zeros((rows * cols,), dtype = np.uint8)
    for begin, stop in zip(run_starts, run_stops):
        flat[begin : stop] = 1
    mask = flat.reshape(ori_shape)
    if not np.any(ori_shape!=new_shape):
        return mask
    if ori_shape in ((234,234), (266,266), (276,276)):
        mask = cv2.resize(mask,new_shape)
    elif ori_shape==(310,360):
        mask = np.pad(mask,((25,25),(0,0)),'constant')
        mask = cv2.resize(mask,new_shape)
    return mask

* Build Dataset

In [11]:
import torch
import torch.nn as F
from torch.utils.data import Dataset, DataLoader
import mmcv
from mmcv.parallel import collate, scatter
from mmseg.datasets.pipelines import Compose

class BuildDataset(torch.utils.data.Dataset):
    """Dataset that yields one item per slice id.

    The frame is expected to hold three consecutive rows per id (one per
    class), in the order the class channels should appear. Rows are addressed
    by position, so the frame may carry any index.

    Args:
        df: frame with the columns ``id``, ``segmentation`` and
            ``image_paths``.
        size: side length of the square images and masks that are produced.
    """

    def __init__(self, df, size):
        self.df = df
        self.size = size
        self.img_paths = df['image_paths']
        self.ids       = df['id']
        # id -> list with the segmentation value of each of its rows.
        self.rle       = df.groupby('id')['segmentation'].apply(list)

    def __len__(self):
        # One item per distinct id, not per row.
        return len(self.rle)

    def __getitem__(self, index):
        """Return ``(img, mask, id_, h, w)`` for the ``index``-th slice."""
        # Every slice owns three rows; .iloc makes this a positional lookup
        # that does not depend on the frame's index labels.
        first_row = index * 3
        img_paths = self.img_paths.iloc[first_row]
        id_       = self.ids.iloc[first_row]
        mask_rle  = self.rle.loc[id_]

        img, ori_shape = load_imgs(img_paths, self.size)
        mask = np.zeros((self.size, self.size, 3))
        for i in range(3):
            rle = mask_rle[i]
            # Rows without annotation hold NaN; their channel stays zero.
            if not pd.isna(rle):
                mask[...,i] = rle_decode(rle, ori_shape, (self.size, self.size))
        h, w = ori_shape
        return img, mask, id_, h, w

* Inference

In [12]:
# Hand-picked slice ids. Neither list is referenced by any other visible cell.
# NOTE(review): their purpose is not shown here — confirm whether they are
# still needed.
free_list = ['case7_day0_slice_0065', 'case7_day0_slice_0066', 'case7_day0_slice_0067',
             'case7_day0_slice_0068', 'case7_day0_slice_0069', 'case81_day30_slice_0063',
             'case81_day30_slice_0064', 'case81_day30_slice_0065']
ban_list = ['case7_day0_slice_0049']
def infer(model_paths, id_ban, test_loader, thresholds=(0.90, 0.70, 0.60)):
    """Run the model ensemble over ``test_loader`` and collect the results.

    Args:
        model_paths: sequence of initialised segmentors (despite the name,
            the entries are passed straight to ``inference_segmentor``).
        id_ban: unused; kept so existing calls keep working.
        test_loader: iterable of ``(img, mask, ids, heights, widths)`` batches
            whose ``img`` and ``mask`` provide ``.numpy()``.
        thresholds: one cut-off per output channel, applied to the averaged
            model output. The default reproduces the previous fixed values.

    Returns:
        ``(pred_strings, pred_ids, pred_classes, imgs, msks, ori_masks)``.
        The first three are flat lists with three entries per slice; the last
        three are lists with one uint8 array per batch, scaled to 0..255.

    Raises:
        ValueError: if no model is given or ``thresholds`` does not hold
            exactly three values.
    """
    if len(model_paths) == 0:
        raise ValueError("infer() needs at least one model")
    if len(thresholds) != 3:
        raise ValueError("thresholds must hold one value per class (3)")

    msks = []; imgs = []; ori_masks = []
    pred_strings = []; pred_ids = []; pred_classes = []
    for (img, mask, ids, heights, widths) in tqdm(test_loader, total=len(test_loader), desc='Infer'):
        size = img.shape
        img = img.numpy()
        mask = mask.numpy()
        # Average of the per-model outputs, one channel per class.
        prob = np.zeros((size[0], size[1], size[2], 3))
        for model in model_paths:
            outs = np.zeros_like(prob)
            for im_idx in range(len(img)):
                outs[im_idx] = inference_segmentor(model, img[im_idx])[0]
            prob += outs/len(model_paths)
        # Binarise every class channel with its own threshold.
        msk = np.stack(
            [prob[..., channel] > thr for channel, thr in enumerate(thresholds)],
            axis=-1,
        ).astype(np.uint8)

        imgs.append((img*255).astype(np.uint8))
        msks.append(msk*255)
        ori_masks.append((mask*255).astype(np.uint8))
        result = masks2rles(msk, ids, heights, widths)
        pred_strings.extend(result[0])
        pred_ids.extend(result[1])
        pred_classes.extend(result[2])
        # Drop the batch-sized temporaries before the next batch is loaded.
        del img, mask, prob, outs, msk, result
        gc.collect()
    return pred_strings, pred_ids, pred_classes, imgs, msks, ori_masks


In [13]:
# Rebind `tqdm` to the plain-text progress bar (an earlier cell switched the
# name to tqdm.notebook); infer() looks the name up when it is called.
from tqdm import tqdm
# One dataset item per distinct id in rm_df, resized to SIZE x SIZE.
test_dataset = BuildDataset(rm_df, SIZE)
# shuffle=False keeps the batches in dataset order.
test_loader  = DataLoader(test_dataset, batch_size=64, num_workers=1, shuffle=False, pin_memory=False)
# pred_* are flat lists with three entries per slice; imgs/msks/ori_masks hold
# one array per batch.
pred_strings, pred_ids, pred_classes, imgs, msks, ori_masks = infer(models, id_ban, test_loader)

Infer: 100%|██████████| 13/13 [00:53<00:00,  4.10s/it]


In [14]:
# Output folders for the PNGs written by the next cell: the input images, the
# predicted-mask overlays and the original-mask overlays.
!mkdir -p /root/save_pred/imgs
!mkdir -p /root/save_pred/msks
!mkdir -p /root/save_pred/ori_msks

In [15]:
# Write one PNG per slice into each output folder: the input image, the
# predicted mask blended over it and the original mask blended over it.
#
# The per-batch lists returned by infer() are left untouched: the merged,
# channel-reversed arrays get their own names, so this cell can be re-run
# without concatenating or flipping the data a second time.
imgs_rev = np.concatenate(imgs)[..., ::-1]
msks_rev = np.concatenate(msks)[..., ::-1]
ori_masks_rev = np.concatenate(ori_masks)[..., ::-1]

# pred_ids carries three entries per slice (one per class); every third entry
# therefore names one image.
slice_ids = pred_ids[::3]
assert len(slice_ids) == len(imgs_rev), "ids and images are out of step"

for slice_id, img, msk, ori_msk in zip(slice_ids, imgs_rev, msks_rev, ori_masks_rev):
    # 40 % image, 60 % mask.
    pred_overlay = 0.4*img+0.6*msk
    ori_overlay = 0.4*img+0.6*ori_msk
    cv2.imwrite(f'/root/save_pred/imgs/{slice_id}.png', img)
    cv2.imwrite(f'/root/save_pred/msks/{slice_id}.png', pred_overlay)
    cv2.imwrite(f'/root/save_pred/ori_msks/{slice_id}.png', ori_overlay)

In [16]:
# for img, msk, ori_msk in zip(imgs, msks, ori_masks):
#     for i in range(len(img)):
#         if msk[i].sum()>0:
#             plt.figure(figsize=(12, 7))
#             plt.subplot(1, 3, 1); plt.imshow(img[i], cmap='bone'); plt.axis('OFF'); plt.title('image')
#             plt.subplot(1, 3, 2); plt.imshow(img[i], cmap='bone'); plt.imshow(msk[i], alpha=0.6); plt.axis('OFF'); plt.title('pre_mask')
#             plt.subplot(1, 3, 3); plt.imshow(img[i], cmap='bone'); plt.imshow(ori_msk[i], alpha=0.6); plt.axis('OFF'); plt.title('ori_mask')
#             plt.tight_layout()
#             plt.show()

### CSV for 3 more slices

In [17]:

# Build the lookups once so that fix_mask() does not scan every prediction
# for every CSV row.
_re_mask_ids = set(re_mask_list)
_pred_lookup = {}
for _pred_id, _pred_class, _pred_rle in zip(pred_ids, pred_classes, pred_strings):
    # setdefault keeps the first prediction for an (id, class) pair, matching
    # the earlier "first match wins" search.
    _pred_lookup.setdefault((_pred_id, _pred_class), _pred_rle)

def fix_mask(row):
    """Replace a row's segmentation with the predicted RLE when one exists.

    Only rows whose id is in ``re_mask_list`` are considered. For those, the
    prediction stored for the row's (id, class) pair is written to
    ``segmentation``; rows without a matching prediction are returned as-is.
    """
    if row.id in _re_mask_ids:
        rle = _pred_lookup.get((row.id, row['class']))
        if rle is not None:
            row.segmentation = rle
            print(f'add_mask_for_{row.id}')
    return row
fix_df = df.apply(fix_mask, axis=1)
fix_df.to_csv('/root/autodl-tmp/train_fix_more_mask.csv',index=False)
fix_df.head(5)

add_mask_for_case123_day20_slice_0119
add_mask_for_case123_day20_slice_0119
add_mask_for_case123_day20_slice_0119
add_mask_for_case123_day20_slice_0120
add_mask_for_case123_day20_slice_0120
add_mask_for_case123_day20_slice_0120
add_mask_for_case123_day20_slice_0121
add_mask_for_case123_day20_slice_0121
add_mask_for_case123_day20_slice_0121
add_mask_for_case123_day22_slice_0125
add_mask_for_case123_day22_slice_0125
add_mask_for_case123_day22_slice_0125
add_mask_for_case123_day22_slice_0126
add_mask_for_case123_day22_slice_0126
add_mask_for_case123_day22_slice_0126
add_mask_for_case123_day22_slice_0127
add_mask_for_case123_day22_slice_0127
add_mask_for_case123_day22_slice_0127
add_mask_for_case123_day0_slice_0114
add_mask_for_case123_day0_slice_0114
add_mask_for_case123_day0_slice_0114
add_mask_for_case123_day0_slice_0115
add_mask_for_case123_day0_slice_0115
add_mask_for_case123_day0_slice_0115
add_mask_for_case123_day0_slice_0116
add_mask_for_case123_day0_slice_0116
add_mask_for_case123

Unnamed: 0,id,class,segmentation
0,case123_day20_slice_0001,large_bowel,
1,case123_day20_slice_0001,small_bowel,
2,case123_day20_slice_0001,stomach,
3,case123_day20_slice_0002,large_bowel,
4,case123_day20_slice_0002,small_bowel,


In [None]:
# shutil is already imported in the first cell; this repeat is harmless.
import shutil
# Pack the exported PNG folders under /root/save_pred into /root/save.zip.
shutil.make_archive('/root/save','zip','/root/save_pred')