In [None]:
from XrayTo3DShape import MODEL_LABEL_COLOR, load_json

In [None]:
# Model names to plot. Each name is used as a top-level key into the loaded
# metrics JSON and as a key into MODEL_LABEL_COLOR by the plotting cells below;
# the list order is the order in which the curves are drawn on each panel.
KEYS = [
    "SwinUNETR",
    "AttentionUnet",
    "TwoDPermuteConcat",
    "UNet",
    "MultiScale2DPermuteConcat",
    "UNETR",
    "TLPredictor",
    "OneDConcat",
]

In [None]:
import matplotlib.pyplot as plt
import scienceplots
import tueplots
from tueplots import figsizes, fontsizes

# Global matplotlib look for every figure in this notebook.
# The "science" style is presumably registered by the `scienceplots` import above.
plt.style.use(["science", "no-latex"])

# Apply the rcParams overrides one after another; later entries win on key
# clashes, so the order here is significant.
# NOTE(review): the figure-size preset is computed for a 1x3 grid while the
# figures below use 1x4 grids -- confirm this is intended.
for rc_overrides in (
    {"figure.dpi": 150},
    fontsizes.neurips2021(),
    figsizes.neurips2021(nrows=1, ncols=3),
):
    plt.rcParams.update(rc_overrides)

In [None]:
# Non-zero perturbation angles (the plots label the axis "angle(deg)"). Each
# value is converted with str() and used as a key into the per-metric results;
# the unperturbed baseline is stored separately under the key "0".
ANGLE_PERTURBATIONS = [1, 2, 5, 10]

In [None]:
# Anatomy-wise comparison: one panel per anatomy, each showing every model's DSC
# as a function of the perturbation angle, with a shared y-axis across panels.
# (The disabled figure title described this as performance degradation due to
# misalignment of the LAT view w.r.t. the AP view.)
anatomies_file = {
    "femur": "../../angle_perturbation_results/femur_angle_perturbation.json",
    "rib": "../../angle_perturbation_results/rib_angle_perturbation.json",
    "hip": "../../angle_perturbation_results/hip_angle_perturbation.json",
    "vertebra": "../../angle_perturbation_results/vertebra_angle_perturbation.json",
}
metric = "DSC"
# x positions: the unperturbed baseline (0) followed by every configured angle,
# derived from ANGLE_PERTURBATIONS so the x and y series can never get out of sync.
angles = [0, *ANGLE_PERTURBATIONS]

fig, ax = plt.subplots(1, len(anatomies_file), sharey="all")
for idx, (anatomy, metrics_json) in enumerate(anatomies_file.items()):
    metrics_dict = load_json(metrics_json)
    for model in KEYS:
        try:
            per_angle = metrics_dict[model][metric]
            # NOTE(review): the baseline entry is used as-is while each perturbed
            # entry contributes its element 0 -- confirm against the JSON schema.
            angle_perturbed_dsc = [per_angle[str(0)]]
            angle_perturbed_dsc.extend(
                per_angle[str(angle)][0] for angle in ANGLE_PERTURBATIONS
            )
            color = MODEL_LABEL_COLOR[model]
        except KeyError as missing_key:
            # Best effort: a model without results (or without a colour) is
            # reported with enough context to locate it, then skipped.
            print(f"[{anatomy}] skipping {model} ({metric}): missing key {missing_key}")
            continue
        ax[idx].plot(
            angles,
            angle_perturbed_dsc,
            "-o",
            label=model,
            markersize=2,
            c=color,
        )
    # The y-axis is shared, so only the left-most panel carries the label.
    if idx == 0:
        ax[idx].set_ylabel(metric)
    ax[idx].set_xlabel("angle(deg)")
    ax[idx].set_title(anatomy)

fig.tight_layout()
fig.savefig("../../angle_perturbation_results/dsc_anatomywise.pdf")
plt.show()

In [None]:
# Hip: one panel per metric, each showing every model's score as a function of
# the perturbation angle. The figure is saved as hip.pdf.
hip_angle_perturbation_file = "../../angle_perturbation_results/hip_angle_perturbation.json"
hip_angle_perturbation = load_json(hip_angle_perturbation_file)

metrics = ["DSC", "ASD", "HD95", "NSD"]
# x positions: the unperturbed baseline (0) followed by every configured angle,
# derived from ANGLE_PERTURBATIONS so the x and y series can never get out of sync.
angles = [0, *ANGLE_PERTURBATIONS]

fig, ax = plt.subplots(1, len(metrics), figsize=(8, 2))
for i, metric in enumerate(metrics):
    for model in KEYS:
        try:
            per_angle = hip_angle_perturbation[model][metric]
            # NOTE(review): the baseline entry is used as-is while each perturbed
            # entry contributes its element 0 -- confirm against the JSON schema.
            metric_values = [per_angle[str(0)]]
            metric_values.extend(per_angle[str(angle)][0] for angle in ANGLE_PERTURBATIONS)
            color = MODEL_LABEL_COLOR[model]
        except KeyError as missing_key:
            # Best effort: a model without results (or without a colour) is
            # reported with enough context to locate it, then skipped.
            print(f"[hip] skipping {model} ({metric}): missing key {missing_key}")
            continue
        ax[i].plot(
            angles,
            metric_values,
            "-o",
            label=model,
            markersize=2,
            c=color,
        )
    ax[i].set_ylabel(metric)
    ax[i].set_xlabel("angle(deg)")
    # No legend is drawn; the lines still carry label=model so one can be added.

fig.suptitle("Hip Angle Perturbation")
fig.tight_layout()
fig.savefig("../../angle_perturbation_results/hip.pdf")
plt.show()

In [None]:
# Rib: one panel per metric, each showing every model's score as a function of
# the perturbation angle. The figure is saved as rib.pdf.
rib_angle_perturbation_file = "../../angle_perturbation_results/rib_angle_perturbation.json"
rib_angle_perturbation = load_json(rib_angle_perturbation_file)

metrics = ["DSC", "ASD", "HD95", "NSD"]
# x positions: the unperturbed baseline (0) followed by every configured angle,
# derived from ANGLE_PERTURBATIONS so the x and y series can never get out of sync.
angles = [0, *ANGLE_PERTURBATIONS]

fig, ax = plt.subplots(1, len(metrics), figsize=(8, 2))
for i, metric in enumerate(metrics):
    for model in KEYS:
        try:
            per_angle = rib_angle_perturbation[model][metric]
            # NOTE(review): the baseline entry is used as-is while each perturbed
            # entry contributes its element 0 -- confirm against the JSON schema.
            metric_values = [per_angle[str(0)]]
            metric_values.extend(per_angle[str(angle)][0] for angle in ANGLE_PERTURBATIONS)
            color = MODEL_LABEL_COLOR[model]
        except KeyError as missing_key:
            # Best effort: a model without results (or without a colour) is
            # reported with enough context to locate it, then skipped.
            print(f"[rib] skipping {model} ({metric}): missing key {missing_key}")
            continue
        ax[i].plot(
            angles,
            metric_values,
            "-o",
            label=model,
            markersize=2,
            c=color,
        )
    ax[i].set_ylabel(metric)
    ax[i].set_xlabel("angle(deg)")
    # No legend is drawn; the lines still carry label=model so one can be added.

fig.suptitle(" Rib Angle Perturbation")
fig.tight_layout()
fig.savefig("../../angle_perturbation_results/rib.pdf")
plt.show()

In [None]:
# Femur: one panel per metric, each showing every model's score as a function of
# the perturbation angle. The figure is saved as femur.pdf.
femur_angle_perturbation_file = "../../angle_perturbation_results/femur_angle_perturbation.json"
femur_angle_perturbation = load_json(femur_angle_perturbation_file)

metrics = ["DSC", "ASD", "HD95", "NSD"]
# x positions: the unperturbed baseline (0) followed by every configured angle,
# derived from ANGLE_PERTURBATIONS so the x and y series can never get out of sync.
angles = [0, *ANGLE_PERTURBATIONS]

fig, ax = plt.subplots(1, len(metrics), figsize=(8, 2))
for i, metric in enumerate(metrics):
    for model in KEYS:
        try:
            per_angle = femur_angle_perturbation[model][metric]
            # NOTE(review): the baseline entry is used as-is while each perturbed
            # entry contributes its element 0 -- confirm against the JSON schema.
            metric_values = [per_angle[str(0)]]
            metric_values.extend(per_angle[str(angle)][0] for angle in ANGLE_PERTURBATIONS)
            color = MODEL_LABEL_COLOR[model]
        except KeyError as missing_key:
            # Best effort: a model without results (or without a colour) is
            # reported with enough context to locate it, then skipped.
            print(f"[femur] skipping {model} ({metric}): missing key {missing_key}")
            continue
        ax[i].plot(
            angles,
            metric_values,
            "-o",
            label=model,
            markersize=2,
            c=color,
        )
    ax[i].set_ylabel(metric)
    ax[i].set_xlabel("angle(deg)")
    # No legend is drawn; the lines still carry label=model so one can be added.

fig.suptitle(" Femur Angle Perturbation")
fig.tight_layout()
fig.savefig("../../angle_perturbation_results/femur.pdf")
plt.show()

In [None]:
# Vertebra: one panel per metric, each showing every model's score as a function
# of the perturbation angle. The figure is saved as vertebra.pdf.
vertebra_angle_perturbation_file = (
    "../../angle_perturbation_results/vertebra_angle_perturbation.json"
)
vertebra_angle_perturbation = load_json(vertebra_angle_perturbation_file)

metrics = ["DSC", "ASD", "HD95", "NSD"]
# x positions: the unperturbed baseline (0) followed by every configured angle,
# derived from ANGLE_PERTURBATIONS so the x and y series can never get out of sync.
angles = [0, *ANGLE_PERTURBATIONS]

fig, ax = plt.subplots(1, len(metrics), figsize=(8, 2))
for i, metric in enumerate(metrics):
    for model in KEYS:
        try:
            per_angle = vertebra_angle_perturbation[model][metric]
            # NOTE(review): the baseline entry is used as-is while each perturbed
            # entry contributes its element 0 -- confirm against the JSON schema.
            metric_values = [per_angle[str(0)]]
            metric_values.extend(per_angle[str(angle)][0] for angle in ANGLE_PERTURBATIONS)
            color = MODEL_LABEL_COLOR[model]
        except KeyError as missing_key:
            # Best effort: a model without results (or without a colour) is
            # reported with enough context to locate it, then skipped.
            print(f"[vertebra] skipping {model} ({metric}): missing key {missing_key}")
            continue
        ax[i].plot(
            angles,
            metric_values,
            "-o",
            label=model,
            markersize=2,
            c=color,
        )
    ax[i].set_ylabel(metric)
    ax[i].set_xlabel("angle(deg)")
    # No legend is drawn; the lines still carry label=model so one can be added.

fig.suptitle(" Vertebra Angle Perturbation")
fig.tight_layout()
fig.savefig("../../angle_perturbation_results/vertebra.pdf")
plt.show()