In [None]:
import glob
import pydicom
import os
import numpy as np
import pandas as pd
from rt_utils import RTStructBuilder
from src.data.dicom_preprocess import (
    filter_ptv_name,
    filter_junction_name,
    get_ptv_mask_3d,
    get_dicom_field_geometry,
    transform_field_geometry,
)
from src.config.constants import (
    MAP_ID_PTV,
    MAP_ID_JUNCTION,
    DICOM_DIR_TEST_PATH,
    CLASSIFICATION,
)

In [None]:
def _find_single_dicom(directory: str, prefix: str) -> str:
    """Return the first file in ``directory`` whose name starts with ``prefix``.

    Matches are sorted so the choice is deterministic (``glob`` order depends on
    the filesystem), and a missing file raises a clear ``FileNotFoundError``
    instead of a bare ``IndexError``.
    """
    matches = sorted(glob.glob(os.path.join(directory, prefix + "*")))
    if not matches:
        raise FileNotFoundError(f"No {prefix}* file found in {directory}")
    return matches[0]


# Load the existing RT Struct together with the image series it is built on.
rt_struct_path = _find_single_dicom(DICOM_DIR_TEST_PATH, "RTSTRUCT")
rtstruct = RTStructBuilder.create_from(
    dicom_series_path=DICOM_DIR_TEST_PATH,
    rt_struct_path=rt_struct_path,
)
series_data = rtstruct.series_data
patient_id = rtstruct.ds.PatientID
roi_names = rtstruct.get_roi_names()

# PTV ROI name: a per-patient override from MAP_ID_PTV wins; otherwise take the
# first ROI accepted by filter_ptv_name.
if patient_id in MAP_ID_PTV:
    ptv_name = MAP_ID_PTV[patient_id]
else:
    ptv_name = next(filter(filter_ptv_name, roi_names), None)
    if ptv_name is None:
        # Explicit error instead of a bare StopIteration; the patient ID is left
        # out of the message to avoid leaking it into notebook output.
        raise ValueError(
            f"No ROI matched filter_ptv_name and no MAP_ID_PTV override exists "
            f"for this patient. Available ROIs: {roi_names}"
        )

# 3D mask of the PTV ROI on its own.
mask_3d_ptv = rtstruct.get_roi_mask_by_name(ptv_name)

# PTV junction ROI names: per-patient override from MAP_ID_JUNCTION, otherwise
# every ROI accepted by filter_junction_name.
junction_names = (
    MAP_ID_JUNCTION[patient_id]
    if patient_id in MAP_ID_JUNCTION
    else list(filter(filter_junction_name, roi_names))
)

mask_3d = get_ptv_mask_3d(
    rtstruct, ptv_name, junction_names
)  # axis0=y, axis1=x, axis2=z

# pydicom.read_file is deprecated (removed in pydicom 3.0); dcmread is the
# supported reader with the same basic call signature.
rt_plan_path = _find_single_dicom(DICOM_DIR_TEST_PATH, "RTPLAN")
ds = pydicom.dcmread(rt_plan_path)
isocenters, jaw_X, jaw_Y, coll_angles = get_dicom_field_geometry(series_data, ds)

# Stack the series slices into one float volume, slice index on the last axis
# (presumably the same layout as mask_3d: axis0=y, axis1=x, axis2=z — the
# slice display below indexes both arrays with the same index).
# Values are the raw `pixel_array` contents; no rescale slope/intercept is applied.
series_data = rtstruct.series_data
img_3d = np.zeros((*series_data[0].pixel_array.shape, len(series_data)))
for slice_idx, slice_ds in enumerate(series_data):
    img_3d[:, :, slice_idx] = slice_ds.pixel_array

# Field geometry from the plan, presumably mapped into the volume's pixel
# coordinates (judging by the returned names) — TODO confirm in transform_field_geometry.
iso_pixel, jaw_X_pix, jaw_Y_pix = transform_field_geometry(
    series_data, isocenters, jaw_X, jaw_Y
)

## Plot the treatment fields on a CT slice

In [None]:
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

In [None]:
# Voxel size of the loaded series, read from its first slice (DICOM values in mm).
slice_thickness = series_data[0].SliceThickness
pix_spacing = series_data[0].PixelSpacing[0]
# Through-plane / in-plane voxel size; used to undistort slice displays.
# NOTE(review): PixelSpacing[0] is the DICOM row spacing; the displayed slice's
# vertical axis runs along image columns (PixelSpacing[1]). Identical for square
# pixels — confirm if anisotropic in-plane spacing is possible.
aspect_ratio = slice_thickness / pix_spacing


def add_rectangle_patch(
    ax: plt.Axes,
    anchor: tuple[float, float],
    width: float,
    height: float,
    rotation_point: tuple[float, float],
    angle: float,
) -> None:
    """Draw an unfilled red outline of a (possibly rotated) rectangle on ``ax``.

    Args:
        ax: Axes that receives the patch.
        anchor: ``(x, y)`` data coordinates of the rectangle's anchor corner.
        width: Extent along the x data axis before rotation.
        height: Extent along the y data axis before rotation.
        rotation_point: ``(x, y)`` data point the rectangle is rotated about.
        angle: Rotation in degrees, anti-clockwise in data coordinates.
    """
    outline_style = dict(linewidth=1, edgecolor="r", facecolor="none")
    patch = Rectangle(
        anchor,
        width,
        height,
        angle=angle,
        rotation_point=rotation_point,
        **outline_style,
    )
    ax.add_patch(patch)


def plot_fields(
    ax: plt.Axes,
    iso_pixel: np.ndarray,
    jaw_X_pixel: np.ndarray,
    jaw_Y_pixel: np.ndarray,
    coll_angles: np.ndarray,
    aspect_ratio: "float | None" = None,
) -> None:
    """Outline each treatment field as a red rectangle around its isocenter.

    Args:
        ax: Axes the fields are drawn on (typically one already showing a slice).
        iso_pixel: Array of shape (n_fields, 3) with each field's isocenter in
            pixel coordinates. Column 2 is the horizontal plot coordinate and
            column 0 the vertical one. All-zero rows mean "no isocenter" and
            the field is skipped.
        jaw_X_pixel: Array of shape (n_fields, 2) with the X jaw positions of
            each field, in the same pixel units as ``iso_pixel`` and relative
            to it (they are added to the isocenter coordinates).
        jaw_Y_pixel: Array of shape (n_fields, 2) with the Y jaw positions,
            same conventions as ``jaw_X_pixel``.
        coll_angles: Array of shape (n_fields,) with collimator angles in degrees.
        aspect_ratio: Slice thickness divided by in-plane pixel spacing of the
            displayed image. Defaults to the notebook-level ``aspect_ratio``
            variable (the original behaviour); pass it explicitly so a later
            cell that reassigns that global cannot change the drawing.

    Returns:
        None

    Notes:
        Only 90° collimator angles are drawn rotated. Any other angle is reported
        and drawn at 0° for visualization. For 90°, sizes and offsets are rescaled
        by ``aspect_ratio`` because the rotation swaps the horizontal and vertical
        data axes, which have different physical pixel sizes in the displayed image.
    """
    if aspect_ratio is None:
        # Backward-compatible fallback: read the notebook global at call time.
        aspect_ratio = globals()["aspect_ratio"]

    for field_no, (iso, jaw_x, jaw_y, angle) in enumerate(
        zip(iso_pixel, jaw_X_pixel, jaw_Y_pixel, coll_angles), start=1
    ):
        if np.all(iso == 0):
            continue  # isocenter not present, skip field

        iso_col, iso_row = iso[2], iso[0]
        offset_col = jaw_y[0]
        offset_row = jaw_x[1]
        width = jaw_y[1] - jaw_y[0]
        height = jaw_x[1] - jaw_x[0]

        if angle == 90:
            # After a 90° rotation, horizontal extents become vertical and vice
            # versa; convert between the two (anisotropic) pixel units.
            offset_col *= aspect_ratio
            offset_row /= aspect_ratio
            width *= aspect_ratio
            height /= aspect_ratio
        else:
            print(
                f"Collimator angle for field {field_no} was {angle}°. Plotting with angle=0° for visualization"
            )
            angle = 0

        # The image y axis points down (imshow), so subtracting the upper jaw
        # puts the anchor at the top edge and a positive height extends downward.
        add_rectangle_patch(
            ax,
            (iso_col + offset_col, iso_row - offset_row),
            width,
            height,
            (iso_col, iso_row),
            angle,
        )

In [None]:
# Display the middle slice along axis 0 (y) of the volume with the combined PTV
# mask, the isocenters and the field outlines overlaid.
mid_y = mask_3d.shape[0] // 2
fig, ax = plt.subplots()
ax.imshow(img_3d[mid_y, :, :], cmap="gray", aspect=1 / aspect_ratio)
ax.contourf(mask_3d[mid_y, :, :], alpha=0.25)

# All-zero rows mark fields without an isocenter. plot_fields skips them, so
# skip them here too instead of drawing a spurious marker at the origin.
iso_present = ~np.all(iso_pixel == 0, axis=1)
ax.scatter(iso_pixel[iso_present, 2], iso_pixel[iso_present, 0], color="red", s=10)
plot_fields(ax, iso_pixel, jaw_X_pix, jaw_Y_pix, coll_angles)

ax.set_title(f"Treatment fields on slice y = {mid_y}")
ax.set_xlabel("z (slice index)")
ax.set_ylabel("x (pixel index)")
plt.show()

### Visualize all PTVs

In [None]:
# Data lives under <parent of the current working directory>/data.
# os.path.join instead of "\\"-string concatenation so paths also work on
# POSIX; on Windows the resulting strings are identical to the old ones.
path = os.path.join(os.getcwd(), os.pardir)

# Dataset folder depends on the task: "5_355" for classification, "90" otherwise.
dir_path = os.path.join(path, "data", "5_355" if CLASSIFICATION else "90")

In [None]:
# Per-patient metadata; os.path.join keeps the path portable beyond Windows.
df_patient_info = pd.read_csv(os.path.join(path, "data", "patient_info.csv"))
# Positional column indices, for use with .iloc.
slice_tickness_col_idx = df_patient_info.columns.get_loc("SliceThickness")
pixel_spacing_col_idx = df_patient_info.columns.get_loc("PixelSpacing")

In [None]:
# Load every 2D PTV mask in the archive; list() reads them all before the file
# is closed. os.path.join keeps the path portable beyond Windows.
with np.load(os.path.join(dir_path, "raw", "ptv_masks2D.npz")) as npz_mask:
    ptv_masks2D = list(npz_mask.values())

In [None]:
# Load every 2D PTV image in the archive; list() reads them all before the file
# is closed. os.path.join keeps the path portable beyond Windows.
with np.load(os.path.join(dir_path, "raw", "ptv_imgs2D.npz")) as npz_ptv_imgs2d:
    ptv_imgs2D = list(npz_ptv_imgs2d.values())

In [None]:
import matplotlib.pyplot as plt

# Thumbnail grid of every 2D PTV image, N_GRID_COLS per row. The row count
# follows the number of images so zip() never silently drops any of them.
N_GRID_COLS = 10
n_ptv_imgs = len(ptv_imgs2D)
n_grid_rows = max(1, (n_ptv_imgs + N_GRID_COLS - 1) // N_GRID_COLS)

# squeeze=False always returns a 2D array of Axes, even for a single row.
fig, axes = plt.subplots(
    n_grid_rows, N_GRID_COLS, figsize=(20, 1.8 * n_grid_rows), squeeze=False
)

for idx, (ptv_img, ax) in enumerate(zip(ptv_imgs2D, axes.flat)):
    # Own name, so the series-level `aspect_ratio` used by the slice display and
    # plot_fields() is not overwritten with the last patient's value.
    # NOTE(review): assumes row `idx` of patient_info.csv belongs to the same
    # patient as ptv_imgs2D[idx], and that PixelSpacing is stored as a number —
    # confirm both against how these files are written.
    ptv_aspect_ratio = (
        df_patient_info.iloc[idx, slice_tickness_col_idx]
        / df_patient_info.iloc[idx, pixel_spacing_col_idx]
    )
    ax.imshow(ptv_img, cmap="gray", aspect=1 / ptv_aspect_ratio)

# Blank out grid cells left over after the last image.
for ax in axes.flat[n_ptv_imgs:]:
    ax.axis("off")

plt.tight_layout()
plt.show()