In [1]:
model_ver = '095521'
epoch_ver = 21
checkpoint = f'axial_segmentation_effseg_{model_ver}-epoch-{epoch_ver}.pth'

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import seaborn as sns
import math
from tqdm import tqdm
import re
from PIL import Image, ImageOps

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
import torchvision.transforms.functional as TF
import torchvision.models as models

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import train_test_split

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [4]:
DATA_DIR = "/root/autodl-tmp/cervical_spine/"
IMAGES_DIR = os.path.join(DATA_DIR, f"train_axial_images_jpeg95")
TARGET_DIR = os.path.join(DATA_DIR, f"segmentation_axial_results_{model_ver}")
LABEL_DIR = os.path.join(DATA_DIR, f"segmentation_axial_labels")

In [5]:
# train_df = pd.read_csv('../input/rsna-2022-cervical-spine-fracture-detection/train.csv').set_index('StudyInstanceUID')
# train_df.head()

test_slices = glob.glob(f'{IMAGES_DIR}/*/*')
test_slices = [re.findall(f'{IMAGES_DIR}/(.*)/(.*).jpeg', s)[0] for s in test_slices]
df_test_slices = pd.DataFrame(data=test_slices, columns=['StudyInstanceUID', 'Slice']).astype({'Slice': int})

df_test_slices

Unnamed: 0,StudyInstanceUID,Slice
0,1.2.826.0.1.3680043.10001,153
1,1.2.826.0.1.3680043.10001,0
2,1.2.826.0.1.3680043.10001,154
3,1.2.826.0.1.3680043.10001,1
4,1.2.826.0.1.3680043.10001,155
...,...,...
708774,1.2.826.0.1.3680043.9997,95
708775,1.2.826.0.1.3680043.9997,96
708776,1.2.826.0.1.3680043.9997,97
708777,1.2.826.0.1.3680043.9997,98


In [6]:
# test_slices = glob.glob(f'{IMAGES_DIR}/*/*')
# test_slices = [re.findall(f'{IMAGES_DIR}/(.*)/(.*).jpeg', s)[0] for s in test_slices]
# df_test_slices = pd.DataFrame(data=test_slices, columns=['StudyInstanceUID', 'Slice']).astype({'Slice': int}).sort_values(['StudyInstanceUID', 'Slice']).reset_index(drop=True)
# df_test_slices

In [7]:
df_test_slices = df_test_slices.set_index('StudyInstanceUID')
df_test_slices["Start"] = df_test_slices.groupby('StudyInstanceUID').apply(lambda df: df.Slice.min())
# df_test_slices["End"] = df_test_slices.groupby('StudyInstanceUID').apply(lambda df: df.Slice.max())
df_test_slices = df_test_slices.sort_values(['StudyInstanceUID', 'Slice']).reset_index(drop=False)
df_test_slices.head()

Unnamed: 0,StudyInstanceUID,Slice,Start
0,1.2.826.0.1.3680043.10001,0,0
1,1.2.826.0.1.3680043.10001,1,0
2,1.2.826.0.1.3680043.10001,2,0
3,1.2.826.0.1.3680043.10001,3,0
4,1.2.826.0.1.3680043.10001,4,0


In [8]:
len(df_test_slices.StudyInstanceUID.unique())

2012

In [9]:
train_3d_df = pd.read_csv(os.path.join(DATA_DIR, 'meta_train_3d.csv')).set_index('UID')
print(len(train_3d_df))
train_3d_df.head()

2012


Unnamed: 0_level_0,z_spacing,pixel_spacing,aspect,is_flip
UID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1.2.826.0.1.3680043.10001,0.625,0.253906,2.461541,0
1.2.826.0.1.3680043.10005,0.625,0.298828,2.091504,0
1.2.826.0.1.3680043.10014,0.8,0.234,3.418803,0
1.2.826.0.1.3680043.10016,0.313,0.275391,1.136566,0
1.2.826.0.1.3680043.10032,0.625,0.320313,1.951216,0


In [10]:
df_test_slices["pixel_spacing"] = train_3d_df.loc[df_test_slices.StudyInstanceUID, 'pixel_spacing'].values
df_test_slices.head()

Unnamed: 0,StudyInstanceUID,Slice,Start,pixel_spacing
0,1.2.826.0.1.3680043.10001,0,0,0.253906
1,1.2.826.0.1.3680043.10001,1,0,0.253906
2,1.2.826.0.1.3680043.10001,2,0,0.253906
3,1.2.826.0.1.3680043.10001,3,0,0.253906
4,1.2.826.0.1.3680043.10001,4,0,0.253906


In [11]:
class ImageDataSet(torch.utils.data.Dataset):
    def __init__(self, df, path, transforms=None):
        super().__init__()
        self.df = df
        self.path = path
        self.transforms = transforms

        self.len = len(self.df)

    def __getitem__(self, i):

        try:
            s = self.df.iloc[i]
            img = Image.open(os.path.join(self.path, s.StudyInstanceUID, f'{s.Slice}.jpeg'))

            if self.transforms is not None:
                img = self.transforms(img)
        except Exception as ex:
            print(ex)
            return None, None

        return img, s.pixel_spacing, s.Slice == s.Start, s.StudyInstanceUID, s.Slice

    def __len__(self):
        return self.len

class DataTransform(nn.Module):
    def __init__(self, image_size=512):
        super().__init__()

        self.image_size = image_size

        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(0.5, 0.5),
        ])

    def forward(self, x):
        x = self.transform(x)

        return x
    
ds = ImageDataSet(df_test_slices, IMAGES_DIR, DataTransform())
img, pixel_spacing, is_end, UID, Slice = ds[150]
print(img.min(), img.max())
print(img.shape)
print(pixel_spacing, is_end, UID, Slice)

tensor(-1.) tensor(0.8196)
torch.Size([1, 512, 512])
0.253906 False 1.2.826.0.1.3680043.10001 150


In [12]:
batch_size = 32
dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=min(16, batch_size))

# x, pixel_spacings, is_start = next(iter(dl))
# print(x.min(), x.max())
# print(x.shape)
# print(pixel_spacings)

In [13]:
from efficientunet import *

def get_axial_segmentation_model(checkpoint):
    model = get_efficientunet_b5(out_channels=2, concat_input=True, pretrained=True)

    state = torch.load(os.path.join(DATA_DIR, 'checkpoint', checkpoint))
    model.load_state_dict(state["model"])
    model.eval()
    return model.to(device)

seg_model = get_axial_segmentation_model(checkpoint)

In [14]:
def get_axial_boundary_from_segmentation(seg, pixel_spacing, throw=100, tol=0.2, max_mm=100):
    """
    seg : H x W
    """
    image_size = seg.shape[0]
    min_size = min(image_size, max_mm / pixel_spacing)

    rows, columns = seg.nonzero(as_tuple=True)
    rows.sort()
    columns.sort()

    throw = min(len(rows) // 2, throw)

    if(len(rows)) == 0:
        return torch.tensor([0, 0, image_size, image_size]).to(device)

    xmin, xmax = columns[throw], columns[-throw]
    ymin, ymax = rows[throw], rows[-throw]

    w = (xmax - xmin) * (1 + tol)
    h = (ymax - ymax) * (1 + tol)
    new_size = max(w, h, min_size)
    new_size = min(image_size, new_size)

    xcenter, ycenter = (xmax + xmin) / 2, (ymax + ymin) / 2

    xmin = torch.min(torch.tensor(image_size - new_size), xcenter - new_size / 2)
    xmin = xmin.clip(min=0)

    ymin = torch.min(torch.tensor(image_size - new_size), ycenter - new_size / 2)
    ymin = ymin.clip(min=0)

    return torch.stack([xmin, ymin, xmin + new_size, ymin + new_size])

In [15]:
def get_axial_boundary(segs, pixel_spacings, seg_img_size=256):
    boundary_list = []
    for i in range(segs.shape[0]):
        seg = segs[i, 0, :, :]
        
        boundary = get_axial_boundary_from_segmentation(seg, pixel_spacings[i], throw=int(100 / 512 * seg_img_size), tol=0.2, max_mm=100 / 512 * seg_img_size)
        boundary_list.append(boundary)
    boundary_list = torch.stack(boundary_list, axis=0) * (512. / seg_img_size)
    return boundary_list

In [16]:
def predict_seg(x, model, img_size=256):
    """
    return: N x 1 x H x W
    """
    x = TF.resize(x, (img_size, img_size))
    logits = model(x)

    classification_score, mse_score = logits.sigmoid().chunk(2, dim=1)
    classification_pred = classification_score.gt(0.5).float()
    pred = (classification_pred * mse_score)

    return pred

In [17]:
def save_pred_img(pred, UIDs, axial_indices):
    for i in range(pred.shape[0]):
        
        save_dir = os.path.join(TARGET_DIR, UIDs[i])
        if os.path.exists(save_dir) is False:
            os.mkdir(save_dir)
        
        img = pred[i, 0, :, :]
        
        Image.fromarray(np.uint8(img * 256), 'L').save(os.path.join(save_dir, f'{int(axial_indices[i])}.png'))
    

In [18]:

def infer():

    with torch.no_grad():

        boundary_list = []

        x0, _, _, _, _ = ds[0]
        x1, _, _, _, _ = ds[1]
        x0, x1 = x0.to(device), x1.to(device)
        prev2 = torch.stack((x0, x1))

        for x, pixel_spacings, is_starts, UIDs, axial_indices in tqdm(dl):
            x = x.to(device)

            # x : N x 1 x 512 x 512
            x = x.to(device)

            # (N+2), 1, 512, 512
            x = torch.cat((prev2, x), dim=0)

            r = x[:-2, :, :, :]
            g = x[1:-1, :, :, :]
            b = x[2:, :, :, :]

            start_indices = torch.argwhere(is_starts)
            r[start_indices, :, :, :] = b[start_indices, :, :, :]
            g[start_indices, :, :, :] = b[start_indices, :, :, :]

            prev2 = b[-2:, :, :, :]

            x = torch.cat((r, g, b), dim=1)

            seg_result = predict_seg(x, seg_model)  # N x 1 x 256 x 256

            save_pred_img(seg_result.cpu().numpy(), UIDs, axial_indices)

            axial_boundary = get_axial_boundary(seg_result, pixel_spacings, seg_img_size=256)

            boundary_list.append(axial_boundary.cpu().numpy())
        
        return boundary_list
    
boundarys = infer()

  xmin = torch.min(torch.tensor(image_size - new_size), xcenter - new_size / 2)
  ymin = torch.min(torch.tensor(image_size - new_size), ycenter - new_size / 2)
100%|██████████| 22150/22150 [1:18:45<00:00,  4.69it/s]


In [19]:
boundary_df = pd.DataFrame(data=np.concatenate(boundarys), columns=['xmin','ymin','xmax','ymax'])
boundary_df = pd.concat((df_test_slices, boundary_df), axis=1).set_index('StudyInstanceUID')
boundary_df.head()

Unnamed: 0_level_0,Slice,Start,pixel_spacing,xmin,ymin,xmax,ymax
StudyInstanceUID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1.2.826.0.1.3680043.10001,0,0,0.253906,54.076729,0.0,447.923271,393.846542
1.2.826.0.1.3680043.10001,1,0,0.253906,0.0,0.0,512.0,512.0
1.2.826.0.1.3680043.10001,2,0,0.253906,0.0,0.0,512.0,512.0
1.2.826.0.1.3680043.10001,3,0,0.253906,0.0,0.0,512.0,512.0
1.2.826.0.1.3680043.10001,4,0,0.253906,0.0,0.0,512.0,512.0


In [20]:
boundary_df.to_csv(os.path.join(DATA_DIR, f'infered_boundary_{model_ver}_2.csv'))