In [None]:
import cv2 
import glob
import matplotlib.pyplot as plt 
import numpy as np 
import os
import torch 

from transformers import AutoModel

In [None]:
# Load the CT cropping model from the Hugging Face Hub.
# trust_remote_code=True lets transformers execute the modeling code stored in the
# "ianpan/ct-crop" repo (presumably where .crop / .load_stack_from_dicom_folder come from);
# token=True authenticates with the locally saved Hub token.
# The model is switched to eval mode and moved to the default CUDA device (a GPU is required).
cropper = AutoModel.from_pretrained(
    "ianpan/ct-crop",
    trust_remote_code=True,
    token=True,
)
cropper = cropper.eval().cuda()

In [None]:
# Test single slice
# Index every PNG of the RSNA ICH stage-2 training export (files three directory levels deep).
# glob returns paths in filesystem order, so sort them to make files[0] (and any seeded
# sampling from `files`) stable across runs and machines.
files = sorted(glob.glob("/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train_png/*/*/*/*.png"))
assert files, "no PNG files found - check that the dataset path is mounted"
files[0]

In [None]:
# Spot-check 2-D cropping on one randomly chosen PNG slice.
# The sampled path is printed so that a bad crop can be reproduced later.
f = np.random.choice(files)
print(f)
img = cv2.imread(f)
# cv2.imread signals failure by returning None instead of raising.
if img is None:
    raise FileNotFoundError(f"cv2.imread could not read {f}")
print(img.shape)
with torch.inference_mode():
    # cv2.imread loads colour images in BGR order; index 2 selects the last channel.
    # NOTE(review): presumably each channel holds a different CT window — confirm this is the intended one.
    cropped_img = cropper.crop(img[..., 2], mode="2d", add_buffer=0.025)
print(cropped_img.shape)

plt.imshow(cropped_img, cmap="gray")
plt.title(os.path.basename(f))
plt.show()

In [None]:
# Edge case: an all-zero 256x256 image has no foreground to crop around.
# Print the shape that crop() hands back so its fallback behaviour can be inspected.
empty = np.zeros(shape=(256, 256))
with torch.inference_mode():
    cropped_empty = cropper.crop(empty, add_buffer=0.025, mode="2d")
print(cropped_empty.shape)

In [None]:
# Test series
# Spot-check 3-D cropping on one randomly chosen PNG series (one directory of slice PNGs).
SHOW_SLICES = False  # set True to plot every cropped slice

series = glob.glob("/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train_png/*/*/*")
s = np.random.choice(series)
print(s)  # record the sampled series so a failure can be reproduced
png_files = np.sort(glob.glob(os.path.join(s, "*.png")))
assert len(png_files) > 0, f"no PNG slices found in {s}"
slices = [cv2.imread(f) for f in png_files]
# cv2.imread returns None for unreadable files, which np.stack would reject with an opaque error.
unreadable = [f for f, im in zip(png_files, slices) if im is None]
assert not unreadable, f"cv2.imread failed on: {unreadable}"
stack = np.stack(slices, axis=0)
# Run without autograd bookkeeping, matching the 2-D test cells.
with torch.inference_mode():
    cropped_stack = cropper.crop(stack, add_buffer=0.025, mode="3d")
print(cropped_stack.shape)

if SHOW_SLICES:
    for i in range(cropped_stack.shape[0]):
        plt.imshow(cropped_stack[i], cmap="gray")
        plt.title(f"slice {i}")
        plt.show()

In [None]:
# Test DICOM series
# Load a random DICOM series directly (no windows requested) and crop it with raw_hu=True.
dicom_dirs = glob.glob("/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train/*/*/*")
d = np.random.choice(dicom_dirs)
print(d)  # record the sampled series so a failure can be reproduced
stack = cropper.load_stack_from_dicom_folder(d)
# Summarize the value range rather than printing np.unique: without windowing the stack is
# presumably raw HU, so the unique set is large and slow to compute, and the range is what
# tells you whether the values look like Hounsfield units.
print(stack.shape, stack.dtype, stack.min(), stack.max())

# Raw HU
with torch.inference_mode():
    cropped_stack = cropper.crop(stack, add_buffer=0.025, mode="3d", raw_hu=True)
print(cropped_stack.shape)

for i in range(cropped_stack.shape[0]):
    plt.imshow(cropped_stack[i], cmap="gray")
    plt.title(f"slice {i}")
    plt.show()

In [None]:
# Test DICOM series
# Load a random DICOM series with three windows applied, then crop it with raw_hu=False.
dicom_dirs = glob.glob("/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train/*/*/*")
d = np.random.choice(dicom_dirs)
print(d)  # record the sampled series so a failure can be reproduced
# NOTE(review): the window pairs are presumably (level, width); (200, 10) is an unusually
# narrow window — confirm it is intentional.
stack, dicom_files = cropper.load_stack_from_dicom_folder(d, windows=[(40, 80), (400, 1800), (200, 10)], return_sorted_dicom_files=True)
print(stack.shape, np.unique(stack), dicom_files)

# Window
with torch.inference_mode():
    cropped_stack = cropper.crop(stack, add_buffer=0.025, mode="3d", raw_hu=False)
print(cropped_stack.shape)

for i in range(cropped_stack.shape[0]):
    plt.imshow(cropped_stack[i], cmap="gray")
    plt.title(f"slice {i}")
    plt.show()

In [None]:
# Test DICOM series
# Check remove_empty_slices: pad a random windowed DICOM series with one all-zero slice at
# each end, crop it, and report which indices crop() flags as empty.
dicom_dirs = glob.glob("/mnt/stor/datasets/kaggle/rsna-intracranial-hemorrhage-detection/stage_2_train/*/*/*")
d = np.random.choice(dicom_dirs)
print(d)  # record the sampled series so a failure can be reproduced
stack, dicom_files = cropper.load_stack_from_dicom_folder(d, windows=[(40, 80), (400, 1800), (200, 10)], return_sorted_dicom_files=True)
print(stack.shape)

# Add empty slices. A distinct name avoids reusing `empty` (the 2-D test image from an
# earlier cell) for a differently shaped array, and the padded result gets its own name
# instead of overwriting `stack`.
empty_slice = np.zeros_like(stack[0])[np.newaxis]
padded_stack = np.concatenate([empty_slice, stack, empty_slice], axis=0)

# Window
# NOTE(review): if remove_empty_slices works, empty_indices should presumably flag the first
# and last slice (0 and len(padded_stack) - 1) — confirm against its actual return format.
with torch.inference_mode():
    cropped_stack, empty_indices = cropper.crop(padded_stack, add_buffer=0.025, mode="3d", raw_hu=False, remove_empty_slices=True)
print(empty_indices)
print(cropped_stack.shape)

for i in range(cropped_stack.shape[0]):
    plt.imshow(cropped_stack[i], cmap="gray")
    plt.title(f"slice {i}")
    plt.show()