In [None]:
import sys

# Make the project's `data_handling` package importable. This must run before
# the `data_handling` imports below.
# NOTE(review): hard-coded absolute path; consider a configurable PROJECT_DIR.
sys.path.append("/vol/biomedic3/mb121/causal-contrastive")
from torchvision.transforms import (
    Resize,
    CenterCrop,
    ToTensor,
    RandomAffine,
    ColorJitter,
    RandomResizedCrop,
)
from skimage import io
import matplotlib.pyplot as plt
from data_handling.mammo import preprocess_breast, get_embed_csv, modelname_map
from pathlib import Path
import matplotlib
import matplotlib.patches as patches  # NOTE(review): not used in the cells shown

# Use a serif font for every figure produced by this notebook.
matplotlib.rcParams["font.family"] = "serif"
import seaborn as sns  # NOTE(review): not used in the cells shown
from data_handling.xray import prepare_padchest_csv

# Example real, counterfactual and augmented images for Figure 1

In [None]:
# Real EMBED mammogram (directory 10000879) used as the Figure 1 example.
# `path` is reused by the augmentation cells further down.
path = "/vol/biomedic3/data/EMBED/images/png/1024x768/10000879/1.2.826.0.1.3680043.8.498.10392068038916878965464813474172245832.png"
# Resize to 256x192, take the first channel of the result (presumably (C, H, W)
# -- confirm), then centre-crop to 224x192 for display.
img1 = CenterCrop((224, 192))(preprocess_breast(str(path), (256, 192))[0])
plt.imshow(img1, cmap="gray")
plt.axis("off");

In [None]:
# Counterfactual of the same mammogram: file suffix "_s2", i.e. the image
# regenerated for scanner label 2 following the `_s{label}` naming scheme.
path_cf = "/vol/biomedic3/mb121/causal-contrastive/cf_beta1balanced_scanner/10000879/1.2.826.0.1.3680043.8.498.10392068038916878965464813474172245832_s2.png"
cf_channel = preprocess_breast(str(path_cf), (256, 192))[0]
img1 = CenterCrop((224, 192))(cf_channel)
plt.imshow(img1, cmap="gray")
plt.axis("off")

In [None]:
from lightning import seed_everything

# Seed first so that the random affine / crop / jitter draws below are reproducible.
seed_everything(33)
breast = preprocess_breast(str(path), (256, 192))
# Augmentations are applied in order: rotate (up to 30 deg) -> random resized
# crop to 224x192 -> brightness/contrast jitter.
augmented = RandomAffine(30)(breast)
augmented = RandomResizedCrop((224, 192))(augmented)
img1 = ColorJitter(contrast=0.5, brightness=0.5)(augmented)
plt.imshow(img1[0], cmap="gray", vmin=0, vmax=1)
plt.axis("off")

In [None]:
seed_everything(55)
# Rotate (up to 30 deg), take a deterministic 224x192 centre crop, then apply
# brightness/contrast jitter.
view = preprocess_breast(str(path), (256, 192))
view = CenterCrop((224, 192))(RandomAffine(30)(view))
img1 = ColorJitter(contrast=0.5, brightness=0.5)(view)
plt.imshow(img1[0], cmap="gray", vmin=0, vmax=1)
plt.axis("off")

# Visualisation of counterfactuals

In [None]:
# Load the EMBED metadata table.
# NOTE(review): the next cell calls get_embed_csv() again, so this load is
# redundant when the notebook is run top to bottom.
df = get_embed_csv()

In [None]:
# 3x3 grid: rows are real images from scanners 0/2/4, columns are the target
# scanners 0/2/4. A panel shows the counterfactual when its file exists and the
# real image otherwise. The CSV is (re)loaded here so the cell runs on its own,
# because the PadChest cell below rebinds `df`.
df = get_embed_csv()

# Reverse lookup: integer scanner label -> display name.
rev_model_map = {v: k for k, v in modelname_map.items()}
rev_model_map[2] = "Senograph 2000D"

EMBED_IMG_DIR = Path("/vol/biomedic3/data/EMBED/images/png/1024x768")
EMBED_CF_DIR = Path("/vol/biomedic3/mb121/causal-contrastive/cf_beta1balanced_scanner")
EMBED_SCANNERS = [0, 2, 4]

f, ax = plt.subplots(3, 3, figsize=(15, 15))

for i, s in enumerate(EMBED_SCANNERS):
    # Second image acquired on scanner `s` (index 1 is a hand-picked example).
    shortpath = df.loc[df.SimpleModelLabel == s, "shortimgpath"].values[1]
    path = EMBED_IMG_DIR / shortpath

    for j, cf in enumerate(EMBED_SCANNERS):
        path_cf = EMBED_CF_DIR / f"{shortpath[:-4]}_s{cf}.png"
        if path_cf.exists():
            img_path, caption = path_cf, "COUNTERFACTUAL"
        else:
            img_path, caption = path, "REAL IMAGE"
        img1 = CenterCrop((224, 192))(preprocess_breast(str(img_path), (256, 192)))

        # Only the top row carries the bold scanner name. Spaces are escaped so
        # mathtext keeps them inside \bf{...}; the raw string avoids the invalid
        # "\ " escape sequence (SyntaxWarning on Python 3.12+).
        header = ""
        if i == 0:
            header = r"$\bf{" + rev_model_map[cf].replace(" ", r"\ ") + "}$" + "\n"
        ax[i, j].set_title(header + caption, fontsize=20)

        ax[i, j].imshow(img1[0], cmap="gray")
        ax[i, j].set_xticks([])
        ax[i, j].set_yticks([])

plt.savefig("cf_viz.pdf", bbox_inches="tight")

In [None]:
from sklearn.model_selection import train_test_split

# Patient-level split: 80/20, then 90/10 of the 80%, both with random_state=33.
# Only patients in the final training portion are kept (presumably to mirror
# the training split used elsewhere in the project -- confirm).
df = prepare_padchest_csv()
train_val_id, _ = train_test_split(
    df.PatientID.unique(),
    test_size=0.20,
    random_state=33,
)
train_id, _ = train_test_split(
    train_val_id,
    test_size=0.10,
    random_state=33,
)
df = df.loc[df.PatientID.isin(train_id)]

PADCHEST_IMG_DIR = Path("/vol/biodata/data/chest_xray/BIMCV-PADCHEST") / "images"
PADCHEST_CF_DIR = Path("../padchest_cf_images_v0")

# 3x2 grid: rows are real images (Phillips, Imaging, Phillips), columns are the
# two target manufacturers. A panel shows the real image when the column matches
# the source manufacturer, otherwise the scanner counterfactual.
f, ax = plt.subplots(3, 2, figsize=(10, 15))

for i, s in enumerate(["Phillips", "Imaging", "Phillips"]):
    # Offset by row so the two "Phillips" rows show different images.
    shortpath = df.loc[df.Manufacturer == s, "ImageID"].values[i + 1]
    path = PADCHEST_IMG_DIR / shortpath

    for j, cf in enumerate(["Phillips", "Imaging"]):
        if cf != s:
            path_cf = PADCHEST_CF_DIR / f"{shortpath[:-4]}_sc_cf.png"
            img_path, caption = path_cf, "COUNTERFACTUAL"
        else:
            img_path, caption = path, "REAL IMAGE"
        img = io.imread(str(img_path), as_gray=True)

        # Only the top row carries the bold manufacturer name; the raw string
        # avoids the invalid "\ " escape sequence (SyntaxWarning on 3.12+).
        header = ""
        if i == 0:
            header = r"$\bf{" + cf.replace(" ", r"\ ") + "}$" + "\n"
        ax[i, j].set_title(header + caption, fontsize=20)

        # Scale by the per-image maximum (epsilon guards an all-zero image), then
        # resize/crop to 224x224 and drop the channel axis for imshow.
        img = img / (img.max() + 1e-12)
        img = CenterCrop(224)(Resize(224, antialias=True)(ToTensor()(img)))[0]
        ax[i, j].imshow(img, cmap="gray", vmin=0, vmax=1)
        ax[i, j].set_xticks([])
        ax[i, j].set_yticks([])

plt.savefig("cf_viz_cxr.pdf", bbox_inches="tight")