In [1]:
# File name of the trained axial-segmentation weights; the loader below looks
# for it under DATA_DIR/checkpoint.
checkpoint = "axial_segmentation_effseg_212625-epoch-50.pth"

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import math
from tqdm import tqdm

from PIL import Image, ImageOps

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF
import torchvision.models as models

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import train_test_split

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
device

'cuda'

In [4]:
# Dataset root, plus the sub-folders for the axial JPEG slices and the axial
# segmentation labels.
DATA_DIR = "/root/autodl-tmp/cervical_spine/"
IMAGES_DIR, LABEL_DIR = (
    os.path.join(DATA_DIR, sub_dir)
    for sub_dir in ("train_axial_images_jpeg95", "segmentation_axial_labels")
)

In [5]:
from efficientunet import *

def get_axial_segmentation_model(checkpoint):
    """Build the EfficientUNet-B5 segmentation network and load trained weights.

    Args:
        checkpoint: File name of the checkpoint, looked up under
            ``DATA_DIR/checkpoint``. The file must contain a dict whose
            ``"model"`` entry is the state dict.

    Returns:
        The model in eval mode, moved to ``device``.
    """
    # NOTE(review): pretrained=True presumably fetches encoder weights that the
    # checkpoint then overwrites -- confirm whether pretrained=False would do.
    model = get_efficientunet_b5(out_channels=2, concat_input=True, pretrained=True)

    # map_location lets a checkpoint saved from a GPU run be loaded when the
    # notebook has fallen back to the CPU (see the `device` selection).
    state = torch.load(
        os.path.join(DATA_DIR, 'checkpoint', checkpoint),
        map_location=device,
    )
    model.load_state_dict(state["model"])
    model.eval()
    return model.to(device)

model = get_axial_segmentation_model(checkpoint)

In [6]:
def get_axial_boundary(seg, pixel_spacing, throw=100, tol=0.2, max_mm=100):
    """
    Compute a square crop box around the non-zero region of a 2-D mask.

    Suggested settings:
    512x512 -> throw=100, tol=0.2, max_mm=100
    256x256 -> throw=50, tol=0.2, max_mm=50

    Args:
        seg: 2-D array with ``.shape`` and a NumPy-style ``.nonzero()`` that
            returns ``(rows, columns)``. ``seg.shape[0]`` is used as the image
            size for both axes.
        pixel_spacing: Divisor that converts ``max_mm`` into pixels.
        throw: Number of extreme coordinates to skip when picking the box
            limits from the sorted coordinates (outlier rejection). It is
            capped at half the number of non-zero pixels.
        tol: Fractional margin added to the measured width and height.
        max_mm: Despite the name, the minimum side of the box, expressed in
            the units of ``pixel_spacing`` (``max_mm / pixel_spacing`` pixels).

    Returns:
        ``(xmin, ymin, xmax, ymax)`` of a square box kept inside
        ``[0, image_size]``. The full image ``(0, 0, image_size, image_size)``
        is returned when ``seg`` has no non-zero element.
    """
    image_size = seg.shape[0]
    min_size = min(image_size, max_mm / pixel_spacing)

    rows, columns = seg.nonzero()
    if len(rows) == 0:
        return 0, 0, image_size, image_size

    rows.sort()
    columns.sort()

    throw = max(0, min(len(rows) // 2, throw))
    # a[-0] is a[0], so with throw == 0 the upper limit must be read from the
    # last element explicitly; otherwise the box collapses onto the minimum.
    upper = -throw if throw > 0 else -1

    xmin, xmax = columns[throw], columns[upper]
    ymin, ymax = rows[throw], rows[upper]

    w = (xmax - xmin) * (1 + tol)
    # Fixed: this was (ymax - ymax), which made the height always 0.
    h = (ymax - ymin) * (1 + tol)
    new_size = max(w, h, min_size)
    new_size = min(image_size, new_size)

    xcenter, ycenter = (xmax + xmin) / 2, (ymax + ymin) / 2

    # Centre the square on the region, then shift it back inside the image.
    xmin = min(image_size - new_size, xcenter - new_size / 2)
    xmin = max(0, xmin)

    ymin = min(image_size - new_size, ycenter - new_size / 2)
    ymin = max(0, ymin)

    return xmin, ymin, xmin + new_size, ymin + new_size

In [75]:
# Load meta_train_3d.csv and index it by the UID column; pixel_spacing is
# looked up from this table per UID during inference.
train_3d_df = pd.read_csv(os.path.join(DATA_DIR, 'meta_train_3d.csv'))
train_3d_df = train_3d_df.set_index('UID')
print(len(train_3d_df))
train_3d_df.head()

1983


Unnamed: 0_level_0,z_spacing,pixel_spacing,aspect,is_flip,z_height
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
1.2.826.0.1.3680043.10001,0.625,0.253906,2.461540885209487,0,268
1.2.826.0.1.3680043.10005,0.625,0.298828,2.0915041428514063,0,259
1.2.826.0.1.3680043.10014,0.8000000000000114,0.234,3.418803418803467,0,258
1.2.826.0.1.3680043.10016,0.3130000000000024,0.275391,1.136565828222427,0,645
1.2.826.0.1.3680043.10032,0.625,0.320313,1.9512164663938083,0,321


In [76]:
# Load bbox_clean.csv indexed by UID and keep only the UIDs that also appear
# in train_3d_df.
bbox_df = pd.read_csv(os.path.join(DATA_DIR, 'bbox_clean.csv'))
bbox_df = bbox_df.set_index('UID')
bbox_df = bbox_df[bbox_df.index.isin(train_3d_df.index)]
print(len(bbox_df))
bbox_df.head()

7160


Unnamed: 0_level_0,start_slice_number,axial_index,coronal_index,sagittal_index,aspect,pixel_spacing,z_spacing,is_flip,num_slices,x,y,width,height,slice_number
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
1.2.826.0.1.3680043.10051,1,132,226,227,2.461541,0.253906,0.625,0.0,272,219.27715,216.71419,17.3044,20.38517,133
1.2.826.0.1.3680043.10051,1,133,229,230,2.461541,0.253906,0.625,0.0,272,221.5646,216.71419,17.87844,25.24362,134
1.2.826.0.1.3680043.10051,1,134,234,230,2.461541,0.253906,0.625,0.0,272,216.82151,221.62546,27.00959,26.37454,135
1.2.826.0.1.3680043.10051,1,135,234,228,2.461541,0.253906,0.625,0.0,272,214.49455,215.48637,27.92726,37.51363,136
1.2.826.0.1.3680043.10051,1,136,237,227,2.461541,0.253906,0.625,0.0,272,214.0,215.48637,27.0,43.51363,137


In [77]:
# Keep only the first row of every UID so the table has exactly one row per UID.
bbox_df = (
    bbox_df
    .reset_index()
    .drop_duplicates(subset='UID', keep='first')
    .set_index('UID')
)
bbox_df

Unnamed: 0_level_0,start_slice_number,axial_index,coronal_index,sagittal_index,aspect,pixel_spacing,z_spacing,is_flip,num_slices,x,y,width,height,slice_number
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
1.2.826.0.1.3680043.10051,1,132,226,227,2.461541,0.253906,0.62500,0.0,272,219.27715,216.71419,17.30440,20.38517,133
1.2.826.0.1.3680043.10579,1,77,261,267,3.047619,0.328125,1.00000,0.0,184,250.18182,244.96552,35.52992,33.79122,78
1.2.826.0.1.3680043.10678,1,117,245,256,3.065134,0.261000,0.80000,0.0,269,231.00000,227.00000,50.00000,36.00000,118
1.2.826.0.1.3680043.10697,1,125,127,223,3.737226,0.267578,1.00000,0.0,179,201.32881,112.81356,45.12543,29.50509,126
1.2.826.0.1.3680043.10732,1,59,242,172,4.162602,0.240234,1.00000,0.0,218,133.31759,203.92651,78.61418,77.94226,60
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1.2.826.0.1.3680043.8519,1,51,158,191,2.519684,0.248047,0.62500,0.0,249,174.89558,147.19207,32.49683,23.08180,52
1.2.826.0.1.3680043.8693,1,42,298,226,5.378133,0.488281,2.62604,0.0,69,166.69202,268.00000,120.30798,60.00000,43
1.2.826.0.1.3680043.9447,1,121,224,227,1.383784,0.289062,0.40000,0.0,467,154.95929,149.52212,146.04638,150.42831,122
1.2.826.0.1.3680043.9926,1,50,148,212,2.844444,0.351562,1.00000,0.0,170,199.36283,135.92920,27.23940,25.85519,51


In [78]:

# Expand the one-row-per-UID table into one row per axial slice: each UID's row
# is repeated `num_slices` times and `axial_index` is overwritten with
# 0 .. num_slices - 1. This is done with a single vectorised repeat instead of
# growing a DataFrame with pd.concat inside a loop (quadratic in the row count).
assert bbox_df.index.is_unique, "bbox_df must hold exactly one row per UID"

slice_counts = bbox_df["num_slices"].to_numpy().astype(int)
# Position in the expanded table at which each UID's block of rows starts.
first_row_of_uid = np.cumsum(slice_counts) - slice_counts

df = bbox_df.iloc[np.repeat(np.arange(len(bbox_df)), slice_counts)].copy()
df["axial_index"] = np.arange(len(df)) - np.repeat(first_row_of_uid, slice_counts)

print(len(df))
df.head()

76908


Unnamed: 0_level_0,start_slice_number,axial_index,coronal_index,sagittal_index,aspect,pixel_spacing,z_spacing,is_flip,num_slices,x,y,width,height,slice_number
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
1.2.826.0.1.3680043.10051,1,0,226,227,2.461541,0.253906,0.625,0.0,272,219.27715,216.71419,17.3044,20.38517,133
1.2.826.0.1.3680043.10051,1,1,226,227,2.461541,0.253906,0.625,0.0,272,219.27715,216.71419,17.3044,20.38517,133
1.2.826.0.1.3680043.10051,1,2,226,227,2.461541,0.253906,0.625,0.0,272,219.27715,216.71419,17.3044,20.38517,133
1.2.826.0.1.3680043.10051,1,3,226,227,2.461541,0.253906,0.625,0.0,272,219.27715,216.71419,17.3044,20.38517,133
1.2.826.0.1.3680043.10051,1,4,226,227,2.461541,0.253906,0.625,0.0,272,219.27715,216.71419,17.3044,20.38517,133


In [79]:
class DataTransform(nn.Module):
    """Albumentations pipeline followed by normalisation and channel tripling.

    In training mode random flip, shift/scale/rotate and brightness/contrast
    augmentations run before the resize. In evaluation mode only the resize
    and tensor conversion are applied.
    """

    def __init__(self, image_size, train=True):
        super().__init__()

        self.image_size = image_size

        # Random augmentations are used only when training; both modes end
        # with the same resize + tensor conversion.
        augmentations = []
        if train:
            augmentations = [
                A.HorizontalFlip(p=0.5),
                A.ShiftScaleRotate(p=0.5, rotate_limit=45),
                A.RandomBrightnessContrast(p=0.5),
            ]
        self.transform = A.Compose(
            augmentations + [A.Resize(image_size, image_size), ToTensorV2(p=1)]
        )

        # mean = std = 127.5, so pixel values in 0..255 land in -1..1.
        self.normalize = T.Normalize(255 * 0.5, 255 * 0.5)

    def forward(self, x):
        """Transform one image and return it stacked three times along dim 0."""
        tensor = self.transform(image=np.asarray(x))['image']
        tensor = self.normalize(tensor.float())
        return torch.cat([tensor] * 3, dim=0)

In [80]:
class SegDataset(Dataset):
    """Dataset that yields one axial slice image per row of ``df``.

    Each row is named by its UID (the frame index) and carries an
    ``axial_index`` column; the image is read from
    ``<image_dir>/<UID>/<axial_index>.jpeg``.
    """

    def __init__(self, df, image_dir, transform=None):
        self.df, self.image_dir, self.transform = df, image_dir, transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, UID, axial_index)`` for the row at position ``idx``."""
        record = self.df.iloc[idx]
        UID = record.name
        index = int(record["axial_index"])

        image_path = os.path.join(self.image_dir, UID, f"{index}.jpeg")
        slice_img = Image.open(image_path)

        # Without a transform the opened PIL image is returned as is.
        if not self.transform:
            return slice_img, UID, index
        return self.transform(slice_img), UID, index

In [81]:
@torch.no_grad()
def predict(x, model):
    """Run ``model`` on ``x`` and return the gated prediction as a NumPy array.

    The sigmoid of the model output is split into two halves along dim 1. The
    second half is kept where the first half exceeds 0.5 and set to 0
    elsewhere.
    """
    probabilities = model(x.to(device)).sigmoid()
    classification_score, mse_score = probabilities.chunk(2, dim=1)

    # Binary gate from the first half of the channels.
    foreground = (classification_score > 0.5).float()

    return (foreground * mse_score).cpu().numpy()

In [82]:
def get_axial_boundary(seg, pixel_spacing, throw=100, tol=0.2, max_mm=100):
    """Compute a square crop box around the non-zero region of a 2-D mask.

    Args:
        seg: 2-D array with ``.shape`` and a NumPy-style ``.nonzero()`` that
            returns ``(rows, columns)``. ``seg.shape[0]`` is used as the image
            size for both axes.
        pixel_spacing: Divisor that converts ``max_mm`` into pixels.
        throw: Number of extreme coordinates to skip when picking the box
            limits from the sorted coordinates (outlier rejection). It is
            capped at half the number of non-zero pixels.
        tol: Fractional margin added to the measured width and height.
        max_mm: Despite the name, the minimum side of the box, expressed in
            the units of ``pixel_spacing`` (``max_mm / pixel_spacing`` pixels).

    Returns:
        ``(xmin, ymin, xmax, ymax)`` of a square box kept inside
        ``[0, image_size]``. The full image ``(0, 0, image_size, image_size)``
        is returned when ``seg`` has no non-zero element.
    """
    image_size = seg.shape[0]
    min_size = min(image_size, max_mm / pixel_spacing)

    rows, columns = seg.nonzero()
    if len(rows) == 0:
        return 0, 0, image_size, image_size

    rows.sort()
    columns.sort()

    throw = max(0, min(len(rows) // 2, throw))
    # a[-0] is a[0], so with throw == 0 the upper limit must be read from the
    # last element explicitly; otherwise the box collapses onto the minimum.
    upper = -throw if throw > 0 else -1

    xmin, xmax = columns[throw], columns[upper]
    ymin, ymax = rows[throw], rows[upper]

    w = (xmax - xmin) * (1 + tol)
    # Fixed: this was (ymax - ymax), which made the height always 0.
    h = (ymax - ymin) * (1 + tol)
    new_size = max(w, h, min_size)
    new_size = min(image_size, new_size)

    xcenter, ycenter = (xmax + xmin) / 2, (ymax + ymin) / 2

    # Centre the square on the region, then shift it back inside the image.
    xmin = min(image_size - new_size, xcenter - new_size / 2)
    xmin = max(0, xmin)

    ymin = min(image_size - new_size, ycenter - new_size / 2)
    ymin = max(0, ymin)

    return xmin, ymin, xmin + new_size, ymin + new_size



In [71]:
# Build the evaluation dataset (256x256, no random augmentation) and
# sanity-check the first sample: value range, tensor shape and slice index.
tf = DataTransform(image_size=256, train=False)
dataset = SegDataset(df=df, image_dir=IMAGES_DIR, transform=tf)

x, UID, axial_index = dataset[0]
print(x.min(), x.max())
print(x.shape)
print(axial_index)

tensor(-1.) tensor(0.9843)
torch.Size([3, 256, 256])
0


In [84]:
# Run the segmentation model slice by slice and derive one square crop box per
# slice. Rows are collected in a list and the DataFrame is built once at the
# end; concatenating a one-row frame per iteration is quadratic in the number
# of slices.
records = []

for i in tqdm(range(len(dataset))):
    x, UID, axial_index = dataset[i]
    pixel_spacing = float(train_3d_df.loc[UID, 'pixel_spacing'])

    # Add a leading batch dimension for the model, then drop the singleton
    # dimensions from the prediction.
    label = predict(x.unsqueeze(0), model).squeeze()

    xmin, ymin, xmax, ymax = get_axial_boundary(label, pixel_spacing, throw=50, tol=0.2, max_mm=50)

    records.append({
        'UID': UID,
        'xmin': xmin,
        'ymin': ymin,
        'xmax': xmax,
        'ymax': ymax,
        'pixel_spacing': pixel_spacing,
        'axial_index': axial_index,
    })

result_df = pd.DataFrame(
    records,
    columns=['UID', 'xmin', 'ymin', 'xmax', 'ymax', 'pixel_spacing', 'axial_index'],
)

print(len(result_df))
result_df.head()

100%|██████████| 76908/76908 [58:12<00:00, 22.02it/s]   

76908





Unnamed: 0,UID,xmin,ymin,xmax,ymax,pixel_spacing,axial_index
0,1.2.826.0.1.3680043.10051,0.0,0.0,256.0,256.0,0.253906,0
1,1.2.826.0.1.3680043.10051,0.0,0.0,256.0,256.0,0.253906,1
2,1.2.826.0.1.3680043.10051,0.0,0.0,256.0,256.0,0.253906,2
3,1.2.826.0.1.3680043.10051,0.0,0.0,256.0,256.0,0.253906,3
4,1.2.826.0.1.3680043.10051,0.0,0.0,256.0,256.0,0.253906,4


In [85]:
# Persist the per-slice boxes, using UID as the CSV index column.
result_df.set_index('UID').to_csv(
    os.path.join(DATA_DIR, 'infer_axial_boundary_from_axial_seg.csv')
)