In [None]:
import sys
import os

# Make the notebook's parent directory importable so that the local
# `dicom_preprocess` module can be found. The guard avoids adding a
# duplicate entry when this cell is re-run.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import glob
import pydicom
import numpy as np
from rt_utils import RTStructBuilder
from dicom_preprocess import (
    filter_ptv_name,
    filter_junction_name,
    get_ptv_mask_3d,
    get_dicom_field_geometry,
    transform_field_geometry,
)
from src.config.constants import MAP_ID_PTV, MAP_ID_JUNCTION, DICOM_TEST_PATH

In [None]:
def _find_first(directory: str, prefix: str) -> str:
    """Return the lexicographically first file in ``directory`` named ``prefix*``.

    Sorting makes the choice deterministic (``glob`` returns matches in an
    unspecified order). When nothing matches, a ``FileNotFoundError`` with a
    descriptive message is raised instead of an opaque ``IndexError``.
    """
    matches = sorted(glob.glob(os.path.join(directory, f"{prefix}*")))
    if not matches:
        raise FileNotFoundError(f"No '{prefix}*' file found in {directory!r}")
    return matches[0]


# Load existing RT Struct. Requires the series path and existing RT Struct path
rt_struct_path = _find_first(DICOM_TEST_PATH, "RTSTRUCT")
rtstruct = RTStructBuilder.create_from(
    dicom_series_path=DICOM_TEST_PATH,
    rt_struct_path=rt_struct_path,
)
series_data = rtstruct.series_data
patient_id = rtstruct.ds.PatientID
roi_names = rtstruct.get_roi_names()

# PTV name: per-patient override from MAP_ID_PTV if present, otherwise the
# first ROI name accepted by filter_ptv_name.
if patient_id in MAP_ID_PTV:
    ptv_name = MAP_ID_PTV[patient_id]
else:
    ptv_name = next(filter(filter_ptv_name, roi_names), None)
    if ptv_name is None:
        raise ValueError(
            f"No PTV ROI found for patient {patient_id!r}; available ROIs: {roi_names}"
        )

# Loading the 3D Mask from within the RT Struct
mask_3d_ptv = rtstruct.get_roi_mask_by_name(ptv_name)

# Retrieve PTV junction names: per-patient override from MAP_ID_JUNCTION if
# present, otherwise every ROI name accepted by filter_junction_name.
junction_names = (
    MAP_ID_JUNCTION[patient_id]
    if patient_id in MAP_ID_JUNCTION
    else list(filter(filter_junction_name, roi_names))
)

mask_3d = get_ptv_mask_3d(
    rtstruct, ptv_name, junction_names
)  # axis0=y, axis1=x, axis2=z

# Field geometry from the RT Plan. `pydicom.read_file` is deprecated and was
# removed in pydicom 3.0; `dcmread` is the supported reader.
rt_plan_path = _find_first(DICOM_TEST_PATH, "RTPLAN")
ds = pydicom.dcmread(rt_plan_path)
isocenters, jaw_X, jaw_Y, coll_angles = get_dicom_field_geometry(series_data, ds)

# Create 3D array: axes 0/1 are each slice's pixel rows/columns, axis 2 follows
# the order of series_data. Cast to float64 to keep the dtype of the former
# np.zeros buffer.
img_3d = np.stack([s.pixel_array for s in series_data], axis=-1).astype(np.float64)

iso_pixel, jaw_X_pix, jaw_Y_pix = transform_field_geometry(
    series_data, isocenters, jaw_X, jaw_Y
)

## Plot the PTV mask, isocenters and treatment-field outlines

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

In [None]:
# Voxel geometry read from the first image of the series. The plots below use
# the slice thickness for horizontal (slice-index) steps and the first
# PixelSpacing entry for vertical steps (assumes square in-plane pixels — confirm).
slice_thickness, pix_spacing = (
    series_data[0].SliceThickness,
    series_data[0].PixelSpacing[0],
)
# One slice step expressed in in-plane pixel steps.
aspect_ratio = slice_thickness / pix_spacing


def add_rectangle_patch(
    ax: plt.Axes,
    anchor: tuple[float, float],
    width: float,
    height: float,
    rotation_point: tuple[float, float],
    angle: float,
) -> None:
    """Add an unfilled red rectangle outline to ``ax``.

    Args:
        ax: Axes the rectangle is added to.
        anchor: ``(x, y)`` data coordinates of the rectangle's anchor corner.
        width: Extent along the x data axis before rotation.
        height: Extent along the y data axis before rotation.
        rotation_point: ``(x, y)`` data point the rectangle is rotated about.
        angle: Rotation in degrees.
    """
    outline_style = dict(linewidth=1, edgecolor="r", facecolor="none")
    patch = Rectangle(
        xy=anchor,
        width=width,
        height=height,
        angle=angle,
        rotation_point=rotation_point,
        **outline_style,
    )
    ax.add_patch(patch)


def plot_fields(
    ax: plt.Axes,
    iso_pixel: np.ndarray,
    jaw_X: np.ndarray,
    jaw_Y: np.ndarray,
    coll_angles: np.ndarray,
    col_spacing=None,
    row_spacing=None,
) -> None:
    """Outline each treatment field as a red rectangle around its isocenter.

    The rectangle is centred on ``(iso[2], iso[0])`` in data coordinates:
    ``iso[2]`` is used as the horizontal (column) coordinate and ``iso[0]`` as
    the vertical (row) coordinate.

    Args:
        ax: Axes to draw on.
        iso_pixel: One isocenter per field. A row that is entirely zero marks a
            field without an isocenter and is skipped.
        jaw_X: Per-field jaw pair; ``X[0]``/``X[1]`` give the vertical edges,
            in the same length units as ``row_spacing``.
        jaw_Y: Per-field jaw pair; ``Y[0]``/``Y[1]`` give the horizontal edges,
            in the same length units as ``col_spacing``.
        coll_angles: Per-field collimator angle in degrees. Only 0 and 90 are
            drawn as-is; any other angle is drawn as 0 and a notice is printed.
        col_spacing: Physical size of one horizontal index step. Defaults to
            the notebook-level ``slice_thickness``.
        row_spacing: Physical size of one vertical index step. Defaults to the
            notebook-level ``pix_spacing``.
    """
    if col_spacing is None:
        col_spacing = slice_thickness
    if row_spacing is None:
        row_spacing = pix_spacing
    # Number of row steps that have the same physical length as one column step.
    col_to_row = col_spacing / row_spacing

    for field_idx, (iso, X, Y, angle) in enumerate(
        zip(iso_pixel, jaw_X, jaw_Y, coll_angles)
    ):
        if np.all(np.asarray(iso) == 0):
            continue  # Isocenter not present
        iso_col, iso_row = iso[2], iso[0]
        offset_col = Y[0] / col_spacing
        offset_row = X[1] / row_spacing
        width = (Y[1] - Y[0]) / col_spacing
        height = (X[1] - X[0]) / row_spacing
        if angle == 90:
            # The patch is rotated in index space, where a column step and a row
            # step have different physical lengths. Pre-scale so that after the
            # 90° rotation the outline keeps its physical size.
            offset_col *= col_to_row
            offset_row /= col_to_row
            width *= col_to_row
            height /= col_to_row
        else:
            # 0° is drawn faithfully; only genuinely unsupported angles warrant a notice.
            if angle != 0:
                print(
                    f"Collimator angle for field {field_idx + 1} was {angle}°. Plotting with angle=0° for visualization"
                )
            angle = 0
        add_rectangle_patch(
            ax,
            (iso_col + offset_col, iso_row - offset_row),
            width,
            height,
            (iso_col, iso_row),
            angle,
        )

In [None]:
# Display the middle slice along axis 0 of the volume, overlaid with the PTV
# mask, the isocenters and the field outlines.
# NOTE(review): assumes iso_pixel[:, 2] / iso_pixel[:, 0] index the horizontal /
# vertical axes of this slice — confirm against transform_field_geometry.
mid_index = mask_3d.shape[0] // 2
# All-zero rows mark fields without an isocenter (plot_fields skips them too);
# don't scatter them as a spurious point at the origin.
has_iso = ~np.all(iso_pixel == 0, axis=1)

fig, ax = plt.subplots()
ax.imshow(img_3d[mid_index, :, :], cmap="gray", aspect=1 / aspect_ratio)
ax.contourf(mask_3d[mid_index, :, :], alpha=0.25)
ax.scatter(iso_pixel[has_iso, 2], iso_pixel[has_iso, 0], color="red", s=10)
plot_fields(ax, iso_pixel, jaw_X, jaw_Y, coll_angles)
ax.set(
    title=f"Axis-0 slice {mid_index}: PTV mask and treatment fields",
    xlabel="z (slice index)",
    ylabel="x (column index)",
)

plt.show()