In [None]:
from XrayTo3DShape import MODEL_LABEL_COLOR, load_json

In [None]:
# Local alias (not a copy) of the shared model-name -> line-color mapping; looked up per model when plotting.
model_color_label = MODEL_LABEL_COLOR

In [None]:
# All domain-shift evaluation results live in one directory; change it here to relocate them.
# Each loaded result is indexed by model name further down (e.g. hip_indomain["UNet"]).
DOMAINSHIFT_RESULTS_DIR = "../../domainshift_results"

# Hip: in-domain vs. three shifted test sets (KITS19 cohort, clinical X-rays, clinical X-rays with metal)
hip_indomain_fp = f"{DOMAINSHIFT_RESULTS_DIR}/hip_indomain.json"
hip_indomain = load_json(hip_indomain_fp)

hip_outdomain_kits_fp = f"{DOMAINSHIFT_RESULTS_DIR}/hip_outdomain_kits.json"
hip_outdomain_kits = load_json(hip_outdomain_kits_fp)

hip_outdomain_clinic_fp = f"{DOMAINSHIFT_RESULTS_DIR}/hip_outdomain_clinic.json"
hip_outdomain_clinic = load_json(hip_outdomain_clinic_fp)

hip_outdomain_clinic_metal_fp = f"{DOMAINSHIFT_RESULTS_DIR}/hip_outdomain_clinic_metal.json"
hip_outdomain_clinic_metal = load_json(hip_outdomain_clinic_metal_fp)

# Vertebra: in-domain vs. the RSNA cervical test set
vertebra_indomain_fp = f"{DOMAINSHIFT_RESULTS_DIR}/vertebra_indomain.json"
vertebra_indomain = load_json(vertebra_indomain_fp)

vertebra_outdomain_rsna_fp = f"{DOMAINSHIFT_RESULTS_DIR}/vertebra_outdomain_rsna.json"
vertebra_outdomain_rsna = load_json(vertebra_outdomain_rsna_fp)

In [None]:
# Display every loaded result labeled by its dataset, so the six score mappings can be told apart
{
    "hip_indomain": hip_indomain,
    "hip_outdomain_kits": hip_outdomain_kits,
    "hip_outdomain_clinic": hip_outdomain_clinic,
    "hip_outdomain_clinic_metal": hip_outdomain_clinic_metal,
    "vertebra_indomain": vertebra_indomain,
    "vertebra_outdomain_rsna": vertebra_outdomain_rsna,
}

In [None]:
# Model names, in the order they are drawn in the figure and listed in the table.
# Each name is a key into the loaded result mappings and into model_color_label.
KEYS = (
    "SwinUNETR AttentionUnet TwoDPermuteConcat UNet "
    "MultiScale2DPermuteConcat UNETR TLPredictor OneDConcat"
).split()

In [None]:
import matplotlib.pyplot as plt
import scienceplots  # imported for its side effect: SciencePlots registers the "science"/"no-latex" styles used below
import tueplots  # NOTE(review): not referenced in the visible cells; only the submodule import below is used
from tueplots import figsizes, fontsizes

# Apply the SciencePlots paper style; "no-latex" avoids requiring a LaTeX installation for text rendering
plt.style.use(["science", "no-latex"])

In [None]:
# Increase the resolution of all the plots below
plt.rcParams.update({"figure.dpi": 150})
# tueplots NeurIPS-2021 presets: figure size for a 1x4 panel grid and matching font sizes
plt.rcParams.update(figsizes.neurips2021(nrows=1, ncols=4))
plt.rcParams.update(fontsizes.neurips2021())

# x positions of the two conditions compared in every panel (in-domain, domain-shift)
DOMAIN_SHIFT_XPOS = range(1, 3)


def plot_domain_shift_panel(panel_ax, in_domain, out_domain, xticklabels, title=None, tick_fontsize=6):
    """Draw one in-domain -> domain-shift panel with one line per model in KEYS.

    Args:
        panel_ax: matplotlib Axes to draw into.
        in_domain: in-domain results, indexed by model name.
        out_domain: domain-shifted results, indexed by model name.
        xticklabels: labels for the two x positions (in-domain, domain-shift).
        title: panel title; no title is set when None.
        tick_fontsize: font size of the x tick labels.
    """
    for model in KEYS:
        panel_ax.plot(
            DOMAIN_SHIFT_XPOS,
            [in_domain[model], out_domain[model]],
            "-o",
            label=model,
            markersize=2,
            c=model_color_label[model],
        )
    panel_ax.set_xticks(DOMAIN_SHIFT_XPOS, xticklabels, fontsize=tick_fontsize)
    panel_ax.set_xticks([], minor=True)  # only the two major ticks are meaningful on x
    if title is not None:
        panel_ax.set_title(title)
    # sharey hides y tick labels on inner panels; turn them back on for every panel
    panel_ax.yaxis.set_tick_params(labelleft=True)


domain_shift_panels = [
    dict(
        in_domain=hip_indomain,
        out_domain=hip_outdomain_kits,
        xticklabels=["in-domain\n[TOTALSEG]", "domain-shift\n[KITS19]"],
        title="cohort shift",
    ),
    dict(
        in_domain=hip_indomain,
        out_domain=hip_outdomain_clinic,
        xticklabels=["in-domain\n[TOTALSEG]", "domain-shift\n[CLINIC]"],
        title="Fractured Bone",
    ),
    # TODO: this panel has no title (the previously commented-out one duplicated the last panel's)
    dict(
        in_domain=vertebra_indomain,
        out_domain=vertebra_outdomain_rsna,
        xticklabels=["in-domain\n[Verse19]", "domain-shift\n[RSNACervical]"],
    ),
    dict(
        in_domain=hip_indomain,
        out_domain=hip_outdomain_clinic_metal,
        xticklabels=["in-domain\n[TOTALSEG]", "domain-shift\n[CLINIC-METAL]"],
        title="X-ray with Bone Implants",
        tick_fontsize=5,  # presumably smaller so the longer "[CLINIC-METAL]" label fits
    ),
]

fig, ax = plt.subplots(1, 4, sharey="all")
for panel_ax, panel in zip(ax, domain_shift_panels):
    plot_domain_shift_panel(panel_ax, **panel)
ax[0].set_ylabel("DSC")  # y axis is shared, so label only the first panel

fig.tight_layout()
# NOTE(review): despite the filename, no legend is drawn in this figure
fig.savefig("../../visualizations/domain_shift_with_legend.pdf")

In [None]:
def latex_domain_shift_row(model):
    """Return one LaTeX tabular row (ending in ``\\\\``) of domain-shift scores for ``model``.

    Columns: model name; hip in-domain; then (score, in-domain minus score) for the
    KITS19, CLINIC and CLINIC-METAL hip sets; then vertebra in-domain, RSNA score and
    in-domain minus RSNA score. Scores are multiplied by 100 and printed with two
    decimals. A positive difference means a lower score under domain shift.
    """
    hip_in = hip_indomain[model] * 100
    cells = [model, f"{hip_in:.2f}"]
    for shifted in (hip_outdomain_kits, hip_outdomain_clinic, hip_outdomain_clinic_metal):
        shifted_score = shifted[model] * 100
        cells += [f"{shifted_score:.2f}", f"{hip_in - shifted_score:.2f}"]
    vert_in = vertebra_indomain[model] * 100
    vert_out = vertebra_outdomain_rsna[model] * 100
    cells += [f"{vert_in:.2f}", f"{vert_out:.2f}", f"{vert_in - vert_out:.2f}"]
    return " & ".join(cells) + "\\\\"


# Previously latex_table stayed "" and only a blank line was printed at the end;
# now it holds the full table body, one row per model in KEYS order.
latex_table = "\n".join(latex_domain_shift_row(model) for model in KEYS)
print(latex_table)