In [None]:
import argparse
import copy
import sys
import seaborn as sns
from pathlib import Path
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from datetime import datetime
import pandas as pd

sys.path.append("/vol/biomedic3/mb121/causal-contrastive")
from causal_models.train_setup import (
    setup_dataloaders,
    load_finetuned_vae,
    load_vae_and_module,
)

from sklearn.metrics import roc_auc_score, balanced_accuracy_score

from causal_models.hps import Hparams
from causal_models.hvae import HVAE2
import torch
import numpy as np
import matplotlib.pyplot as plt
from causal_models.trainer import preprocess_batch
from data_handling.mammo import modelname_map

rev_model_map = {v: k for k, v in modelname_map.items()}
from classification.classification_module import ClassificationModule

from causal_models.plotting_utils import (
    MidpointNormalize,
    plot_cxr_grid,
    plot_counterfactual_viz_embed,
)
from pytorch_lightning import seed_everything

from sklearn.pipeline import Pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_score
import numpy as np


def plot_counterfactual_grid(gt, cf, rec):
    """Display a five-row comparison figure for a batch of images.

    Rows, top to bottom: ``gt``, ``rec``, ``rec - gt``, ``cf``, ``cf - gt``.
    The two difference rows use the diverging ``"RdBu_r"`` colormap with a
    colorbar and a normalisation centred on zero. One column is created per
    entry along the first axis of ``gt``.

    Args:
        gt: Batch of reference images.
        cf: Batch of counterfactual images, same shape as ``gt``.
        rec: Batch of reconstructions, same shape as ``gt``.
    """
    fig, axes = plt.subplots(5, gt.shape[0], figsize=(25, 10))
    # Options shared by the two signed-difference rows.
    diff_style = dict(fig=fig, cmap="RdBu_r", cbar=True)

    plot_cxr_grid(gt, ax=axes[0])
    plot_cxr_grid(rec, ax=axes[1])
    plot_cxr_grid(
        rec - gt, ax=axes[2], norm=MidpointNormalize(midpoint=0), **diff_style
    )
    plot_cxr_grid(cf, ax=axes[3])
    plot_cxr_grid(
        cf - gt, ax=axes[4], norm=MidpointNormalize(midpoint=0), **diff_style
    )
    plt.show()


def get_mae(ground_truth, counterfactual, indices=None):
    """Return the per-sample mean absolute error between two batches.

    Every sample (first axis) is flattened and the absolute difference is
    averaged over all remaining axes.

    Args:
        ground_truth: Batch of reference arrays, first axis is the sample axis.
        counterfactual: Batch with the same shape as ``ground_truth``.
        indices: Optional selection applied to the first axis of both inputs
            before computing the error. Anything the inputs accept as an index
            works (integer array, list of ints, boolean mask).

    Returns:
        1-D array with one MAE value per (selected) sample.
    """
    if indices is not None:
        counterfactual = counterfactual[indices]
        ground_truth = ground_truth[indices]
    abs_err = np.abs(counterfactual - ground_truth)
    # Reshape on the *selected* batch size: using ``indices.shape[0]`` breaks
    # for lists and gives a wrongly-shaped result for boolean masks.
    return abs_err.reshape(abs_err.shape[0], -1).mean(1)


def get_cxr_counterfactuals(
    model,
    batch,
    args,
    s=0.0,
    col="sex",
    cf_col_values=None,
    partial_abduct=None,
    plot=True,
    t=0.1,
    u_t=1.0,
):
    """Generate counterfactual images by intervening on parent variable(s).

    Steps: abduct latents from ``batch["x"]`` under the observed parents,
    decode them under the observed parents (reconstruction) and under the
    intervened parents, then re-apply the exogenous noise
    ``u = (x - px_loc) / px_scale`` to the counterfactual decoder output.

    Args:
        model: Object exposing ``cuda()``, ``abduct``, ``forward_latents`` and
            a ``cond_prior`` attribute.
        batch: Dict holding ``"x"`` and one tensor per name in
            ``args.parents_x``.
        args: Namespace providing ``parents_x``, ``input_res`` and ``device``.
        s: Weight of the mean-centred difference ``cf_loc - px_loc`` that is
            added on top of ``cf_loc`` (0 disables it).
        col: Name of the intervened parent, or a sequence of names when
            ``cf_col_values`` is given.
        cf_col_values: Explicit counterfactual value(s). If ``None`` a default
            intervention is applied: random class for ``"scanner"`` (5 classes)
            and ``"density"`` (4 classes), ``1 - value`` for ``"view"`` and
            ``"cview"``, ``value * 1.3`` for ``"size"``.
        partial_abduct: If set, only the first ``partial_abduct`` abducted
            latents are kept.
        plot: If true, draw the result with ``plot_counterfactual_viz_embed``
            and save the figure to ``save.pdf``.
        t: Temperature forwarded to ``model.forward_latents``.
        u_t: Multiplier applied to the counterfactual scale.

    Returns:
        Dict with CPU tensors ``"DE"`` (counterfactual image), ``"REC"``
        (reconstruction mean), ``"GT"`` (input image) and ``"cf_col_values"``
        (the intervened parent values).

    Raises:
        NotImplementedError: If ``cf_col_values`` is ``None`` and ``col`` has
            no default intervention.
    """
    supported_cols = ("scanner", "view", "cview", "density", "size")
    # Fail before any GPU work instead of silently producing a no-op
    # intervention (the original bare ``NotImplementedError`` never raised).
    if cf_col_values is None and col not in supported_cols:
        raise NotImplementedError(
            f"No default intervention for col={col!r}; "
            f"supported: {supported_cols}. Pass cf_col_values explicitly."
        )

    model.cuda()
    results = {}
    _pa = torch.cat([batch[k] for k in args.parents_x], dim=1)
    # Broadcast the parent vector to a (B, C, input_res, input_res) map.
    _pa = (
        _pa[..., None, None]
        .repeat(1, 1, *(args.input_res,) * 2)
        .to(args.device)
        .float()
    )

    zs = model.abduct(x=batch["x"].cuda(), parents=_pa.cuda(), t=1e-5)
    if model.cond_prior:
        zs = [zs[j]["z"] for j in range(len(zs))]
    if partial_abduct is not None:
        zs = zs[:partial_abduct]

    px_loc, px_scale = model.forward_latents(zs, parents=_pa, t=t)

    cf_pa = {k: batch[k] for k in args.parents_x}
    if cf_col_values is None:
        if col == "scanner":
            cf_pa[col] = torch.nn.functional.one_hot(
                torch.randint_like(torch.argmax(cf_pa[col], 1), high=5), num_classes=5
            )
        elif col in ("view", "cview"):
            cf_pa[col] = 1 - cf_pa[col]
        elif col == "density":
            cf_pa[col] = torch.nn.functional.one_hot(
                torch.randint_like(torch.argmax(cf_pa[col], 1), high=4), num_classes=4
            )
        else:  # col == "size", guaranteed by the check at the top
            cf_pa[col] = cf_pa[col] * 1.3
        do = {col: cf_pa[col]}
    elif isinstance(col, str):
        if cf_col_values.ndim == 1:
            cf_pa[col] = cf_col_values.unsqueeze(1)
        else:
            cf_pa[col] = cf_col_values
        do = {col: cf_pa[col]}
    else:
        assert len(cf_col_values) == len(col)
        for c, v in zip(col, cf_col_values):
            cf_pa[c] = v
        do = {c: cf_pa[c] for c in col}

    _cf_pa = torch.cat([cf_pa[k] for k in args.parents_x], dim=1)
    _cf_pa = (
        _cf_pa[..., None, None]
        .repeat(1, 1, *(args.input_res,) * 2)
        .to(args.device)
        .float()
    )
    cf_loc, cf_scale = model.forward_latents(zs, parents=_cf_pa.cuda(), t=t)

    # Optionally amplify the (spatially mean-centred) change in decoder mean.
    diff = cf_loc - px_loc
    diff = diff - diff.mean(dim=[-1, -2], keepdim=True)
    cf_loc = cf_loc + s * diff
    # Exogenous noise recovered from the observed image.
    u = (batch["x"] - px_loc) / px_scale.clamp(min=1e-12)
    cf_scale = cf_scale * u_t
    cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
    results["DE"] = cf_x.cpu()
    results["REC"] = px_loc.cpu()
    results["GT"] = batch["x"].cpu()
    if isinstance(col, str):
        results["cf_col_values"] = cf_pa[col]
    else:
        results["cf_col_values"] = torch.cat([cf_pa[c] for c in col], 1)

    if plot:
        f = plot_counterfactual_viz_embed(
            args,
            batch["x"].cpu(),
            cf_x.cpu(),
            {k: batch[k] for k in args.parents_x},
            cf_pa,
            do,
            results["REC"],
            save=False,
        )
        f.savefig("save.pdf")
        f.show()

    return results


def get_nochange_counterfactuals(
    model,
    batch,
    args,
    col="sex",
    partial_abduct=None,
    plot=True,
    t=1.0,
    u_t=1.0,
):
    """Generate "null-intervention" counterfactuals (parents left unchanged).

    The latents are abducted from ``batch["x"]`` and decoded twice under the
    same, observed parents; the exogenous noise ``u`` recovered from the first
    decoding is re-applied to the second one.

    Args:
        model: Object exposing ``cuda()``, ``abduct``, ``forward_latents`` and
            a ``cond_prior`` attribute.
        batch: Dict holding ``"x"`` and one tensor per name in
            ``args.parents_x``.
        args: Namespace providing ``parents_x``, ``input_res`` and ``device``.
        col: Parent name (or list of names) whose unchanged values are
            reported under ``"cf_col_values"``; must be in ``args.parents_x``.
        partial_abduct: If set, only the first ``partial_abduct`` abducted
            latents are kept.
        plot: If true, draw the result with ``plot_counterfactual_viz_embed``.
        t: Temperature forwarded to ``model.forward_latents``.
        u_t: Multiplier applied to the second decoding's scale.

    Returns:
        Dict with CPU tensors ``"DE"``, ``"REC"``, ``"GT"`` and
        ``"cf_col_values"``.
    """
    model.cuda()
    results = {}
    _pa = torch.cat([batch[k] for k in args.parents_x], dim=1)
    # Broadcast the parent vector to a (B, C, input_res, input_res) map.
    _pa = (
        _pa[..., None, None]
        .repeat(1, 1, *(args.input_res,) * 2)
        .to(args.device)
        .float()
    )

    zs = model.abduct(x=batch["x"].cuda(), parents=_pa.cuda(), t=1e-5)
    if model.cond_prior:
        zs = [zs[j]["z"] for j in range(len(zs))]
    if partial_abduct is not None:
        zs = zs[:partial_abduct]

    px_loc, px_scale = model.forward_latents(zs, parents=_pa, t=t)

    # No intervention: the "counterfactual" parents are the observed ones.
    cf_pa = {k: batch[k] for k in args.parents_x}

    _cf_pa = torch.cat([cf_pa[k] for k in args.parents_x], dim=1)
    _cf_pa = (
        _cf_pa[..., None, None]
        .repeat(1, 1, *(args.input_res,) * 2)
        .to(args.device)
        .float()
    )
    cf_loc, cf_scale = model.forward_latents(zs, parents=_cf_pa.cuda(), t=t)

    # Exogenous noise recovered from the observed image.
    u = (batch["x"] - px_loc) / px_scale.clamp(min=1e-12)
    cf_scale = cf_scale * u_t
    cf_x = torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)
    results["DE"] = cf_x.cpu()
    results["REC"] = px_loc.cpu()
    results["GT"] = batch["x"].cpu()
    if isinstance(col, list):
        results["cf_col_values"] = torch.cat([cf_pa[c] for c in col], 1)
    else:
        results["cf_col_values"] = cf_pa[col]

    if plot:
        # ``do`` was never defined here, so plot=True raised NameError. Pass
        # the null intervention (observed values of ``col``).
        # NOTE(review): confirm this is the ``do`` format the plotting helper
        # expects for an unchanged parent.
        if isinstance(col, list):
            do = {c: cf_pa[c] for c in col}
        else:
            do = {col: cf_pa[col]}
        f = plot_counterfactual_viz_embed(
            args,
            batch["x"].cpu(),
            cf_x.cpu(),
            {k: batch[k] for k in args.parents_x},
            cf_pa,
            do,
            results["REC"],
            save=False,
        )
        f.show()

    return results


def get_reverse_counterfactuals(
    model,
    batch,
    args,
    col="sex",
    partial_abduct=None,
    t=0.1,
    u_t=1.0,
):
    """Run a counterfactual cycle: intervene on ``col``, then map back.

    A counterfactual image is generated under a default intervention on
    ``col``; latents are then re-abducted from that image under the intervened
    parents and decoded under the original parents.

    Args:
        model: Object exposing ``cuda()``, ``abduct``, ``forward_latents`` and
            a ``cond_prior`` attribute.
        batch: Dict holding ``"x"`` and one tensor per name in
            ``args.parents_x``.
        args: Namespace providing ``parents_x``, ``input_res`` and ``device``.
        col: Intervened parent. Default interventions: random class for
            ``"scanner"`` (5 classes) and ``"density"`` (4 classes),
            ``1 - value`` for ``"view"``/``"cview"``, ``value * 1.3`` for
            ``"size"``.
        partial_abduct: If set, only the first ``partial_abduct`` abducted
            latents are kept (in both passes).
        t: Temperature forwarded to ``model.forward_latents``.
        u_t: Multiplier applied to the target decoding's scale.

    Returns:
        Dict with CPU tensors ``"DE"`` (image after the full cycle) and
        ``"GT"`` (input image).

    Raises:
        NotImplementedError: If ``col`` has no default intervention.
    """
    supported_cols = ("scanner", "view", "cview", "density", "size")
    # The original bare ``NotImplementedError`` expression never raised, so an
    # unsupported column silently became a no-op cycle.
    if col not in supported_cols:
        raise NotImplementedError(
            f"No default intervention for col={col!r}; supported: {supported_cols}."
        )

    def _expand(parents):
        # Concatenate parents and broadcast to (B, C, input_res, input_res).
        flat = torch.cat([parents[k] for k in args.parents_x], dim=1)
        return (
            flat[..., None, None]
            .repeat(1, 1, *(args.input_res,) * 2)
            .to(args.device)
            .float()
        )

    def _abduct(x, parents):
        zs = model.abduct(x=x, parents=parents, t=1e-5)
        if model.cond_prior:
            zs = [zs[j]["z"] for j in range(len(zs))]
        if partial_abduct is not None:
            zs = zs[:partial_abduct]
        return zs

    def _transfer(x, zs, src_parents, dst_parents):
        # Decode under source parents to recover the noise, then re-apply it
        # to the decoding under the destination parents.
        px_loc, px_scale = model.forward_latents(zs, parents=src_parents, t=t)
        cf_loc, cf_scale = model.forward_latents(zs, parents=dst_parents, t=t)
        u = (x - px_loc) / px_scale.clamp(min=1e-12)
        cf_scale = cf_scale * u_t
        return torch.clamp(cf_loc + cf_scale * u, min=-1, max=1)

    model.cuda()
    results = {}
    _pa = _expand(batch).cuda()

    cf_pa = {k: batch[k] for k in args.parents_x}
    if col == "scanner":
        cf_pa[col] = torch.nn.functional.one_hot(
            torch.randint_like(torch.argmax(cf_pa[col], 1), high=5), num_classes=5
        )
    elif col in ("view", "cview"):
        cf_pa[col] = 1 - cf_pa[col]
    elif col == "density":
        cf_pa[col] = torch.nn.functional.one_hot(
            torch.randint_like(torch.argmax(cf_pa[col], 1), high=4), num_classes=4
        )
    else:  # col == "size", guaranteed by the check at the top
        cf_pa[col] = cf_pa[col] * 1.3
    _cf_pa = _expand(cf_pa).cuda()

    # Forward pass: original parents -> intervened parents.
    zs = _abduct(batch["x"].cuda(), _pa)
    cf_x = _transfer(batch["x"], zs, _pa, _cf_pa)
    # Keep an untouched copy in case abduction modifies its input.
    cf_x_o = cf_x.clone()

    # Reverse pass: intervened parents -> original parents.
    zs = _abduct(cf_x, _cf_pa)
    cf_x = _transfer(cf_x_o, zs, _cf_pa, _pa)

    results["DE"] = cf_x.cpu()
    results["GT"] = batch["x"].cpu()
    return results

# Effectiveness

In [None]:
# Seed before loading the model / data so the run is repeatable.
seed_everything(33)

# For ablation study: exactly one of the three loading variants below should
# be active at a time.

# BaseCF
model_path = (
    "/vol/biomedic3/mb121/causal-contrastive/outputs/scanner/beta1balanced/last_19.pt"
)
vae, dataloader, args = load_vae_and_module(model_path)

# CF-
# model_path = '/vol/biomedic3/mb121/causal-contrastive/outputs/scanner/bad/checkpoint.pt'
# vae, dataloader, args = load_vae_and_module(model_path)

# CF+
# model_path = 'cf_finetune.ckpt'
# vae, args = load_finetuned_vae(model_path)
# dataloader = setup_dataloaders(args, cache=False)


# NOTE: evaluation below runs on the "valid" split despite the variable name.
test_loader = dataloader["valid"]

# Run identifier: <grandparent dir><parent dir><day month/hour:minute>.
run_name = "".join(
    (
        Path(model_path).parent.parent.stem,
        Path(model_path).parent.stem,
        datetime.now().strftime("%d%m/%H:%M"),
    )
)
print(run_name)

In [None]:
# Effectiveness evaluation: generate counterfactuals for each intervened
# column, classify both the counterfactual and the original images with a
# pretrained classifier, and compare predictions with the counterfactual /
# original labels (balanced accuracy, confusion matrix, ROC-AUC).
run_name = (
    Path(model_path).parent.parent.stem
    + Path(model_path).parent.stem
    + datetime.now().strftime("%d%m/%H:%M")
)
print(run_name)

model_paths = {
    "scanner": "../../outputs2/run_vrn2dmeo/best.ckpt",
}


def _classify_images(classifier, images, device):
    """Return ``(features, probabilities)`` of ``classifier`` on ``images``.

    ``images`` are rescaled with ``(x + 1) / 2`` before the forward pass.
    Features are returned as a CPU tensor, probabilities (softmax over dim 1)
    as a numpy array.
    """
    inputs = (images.to(device) + 1) / 2
    feats = classifier.get_features(inputs.cuda())
    logits = classifier.classify_features(feats).cpu()
    return feats.cpu(), torch.softmax(logits, 1).numpy()


def _merge_density_classes(probas):
    """Collapse 4-class probabilities to 2 by summing columns (0, 1) and (2, 3)."""
    binproba = np.zeros((probas.shape[0], 2))
    binproba[:, 0] = probas[:, 0] + probas[:, 1]
    binproba[:, 1] = probas[:, 2] + probas[:, 3]
    return binproba


for column in ["scanner"]:
    classification_model = ClassificationModule.load_from_checkpoint(
        model_paths[column]
    ).model.eval()
    classification_model.cuda()
    confs = []  # classifier probabilities on counterfactual images
    labels = []  # counterfactual (intervened) parent values
    confs_true = []  # classifier probabilities on original images
    true_label = []  # original parent values
    all_feats = []  # classifier features of counterfactual images
    true_feats = []  # classifier features of original images
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            # Computed on the raw batch, before preprocessing.
            # NOTE(review): presumably filters samples with a usable label for
            # the evaluated column - confirm against the dataset encoding.
            if column == "cview":
                valid_indices = torch.where(batch["scanner"] == 0)[0]
            else:
                valid_indices = torch.where(batch["cview"] >= 0)[0]
            batch = preprocess_batch(args, batch, expand_pa=False)
            out_d = get_cxr_counterfactuals(
                model=vae,
                batch=batch,
                args=args,
                col=column,
                plot=i == 0,  # only visualise the first batch
                s=0.0,
                t=0.8,
                u_t=0.8,
            )

            label = out_d["cf_col_values"].numpy()

            # Counterfactual images.
            feats, probas = _classify_images(
                classification_model, out_d["DE"], args.device
            )
            all_feats.append(feats)
            if column == "2density":
                probas = _merge_density_classes(probas)
            # Filter probabilities and labels with the same indices (the
            # "2density" branch previously appended unfiltered probabilities).
            confs.append(probas[valid_indices])
            labels.append(label[valid_indices])

            # Original images.
            feats, probas = _classify_images(
                classification_model, out_d["GT"], args.device
            )
            true_feats.append(feats)
            if column == "2density":
                probas = _merge_density_classes(probas)
            confs_true.append(probas[valid_indices])
            true_label.append(batch[column][valid_indices].numpy())

            if i % 50 == 0 and i > 0:
                print(i)
                print(np.concatenate(true_label).mean())

    print(f"Variable CF {column}")
    targets = np.concatenate(labels)
    targets_true = np.concatenate(true_label)
    all_feats = np.concatenate(all_feats)
    true_feats = np.concatenate(true_feats)

    # Convert stored parent values to integer class targets.
    if targets.shape[1] == 1:
        targets = targets.reshape(-1)
        targets_true = targets_true.reshape(-1)
    elif column == "2density":
        targets = np.argmax(targets, 1) > 1
        targets_true = np.argmax(targets_true, 1) > 1
    else:
        targets = np.argmax(targets, 1)
        targets_true = np.argmax(targets_true, 1)

    confs = np.concatenate(confs)
    confs_true = np.concatenate(confs_true)

    res = {}

    bal_true = balanced_accuracy_score(targets_true, np.argmax(confs_true, 1))
    bal_cf = balanced_accuracy_score(targets, np.argmax(confs, 1))

    print(f"Balanced accuracy GT {bal_true:.3f}")
    print(f"Balanced accuracy CF {bal_cf:.3f}")

    res["run_name"] = run_name
    res["bal_true"] = bal_true
    res["bal_cf"] = bal_cf

    cm_true = confusion_matrix(targets_true, np.argmax(confs_true, 1))
    cm_cf = confusion_matrix(targets, np.argmax(confs, 1))

    # Flattened confusion matrices stored as comma-separated strings.
    res["cm_true"] = ",".join([str(v) for v in cm_true.reshape(-1)])
    res["cm_cf"] = ",".join([str(v) for v in cm_cf.reshape(-1)])
    res["cf_var"] = column

    # ConfusionMatrixDisplay.from_predictions(
    #     targets,
    #     np.argmax(confs, 1),
    #     display_labels=[rev_model_map[i] for i in range(5)],
    # )

    ConfusionMatrixDisplay.from_predictions(
        targets,
        np.argmax(confs, 1),
        display_labels=[rev_model_map[i] for i in range(5)],
        xticks_rotation="vertical",
        normalize="true",
        cmap="Blues",
    )
    plt.savefig("cf.pdf", bbox_inches="tight")

    # ConfusionMatrixDisplay.from_predictions(
    #     targets_true,
    #     np.argmax(confs_true, 1),
    #     display_labels=[rev_model_map[i] for i in range(5)],
    #     normalize="true",
    # )

    if confs.shape[1] == 2:
        roc_gt = roc_auc_score(targets_true, confs_true[:, 1])
        roc_cf = roc_auc_score(targets, confs[:, 1])
        print(f"ROC-AUC GT {roc_gt:.3f}")
        print(f"ROC-AUC CF {roc_cf:.3f}")
        res["roc_gt"] = roc_gt
        res["roc_cf"] = roc_cf
    else:
        try:
            roc_gt = roc_auc_score(targets_true, confs_true, multi_class="ovr")
            roc_cf = roc_auc_score(targets, confs, multi_class="ovr")
        except ValueError as err:
            # Still best-effort, but report why the metric is missing instead
            # of dropping it silently.
            print(f"ROC-AUC not computed for {column}: {err}")
        else:
            print(f"ROC-AUC GT {roc_gt:.3f}")
            print(f"ROC-AUC  CF {roc_cf:.3f}")
            res["roc_gt"] = roc_gt
            res["roc_cf"] = roc_cf

In [None]:
# Null-intervention check: counterfactuals generated without changing any
# parent are compared with the original images via mean absolute error.
run_name = (
    Path(model_path).parent.parent.stem
    + Path(model_path).parent.stem
    + datetime.now().strftime("%d%m/%H:%M")
)
print(run_name)

# The classifier and ``valid_indices`` that used to be computed here were
# never used by this evaluation and have been dropped.
for column in ["scanner"]:
    mae = []  # one 1-D tensor of per-image MAE values per batch
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            batch = preprocess_batch(args, batch, expand_pa=False)
            out_d = get_nochange_counterfactuals(
                model=vae,
                batch=batch,
                args=args,
                col=column,
                plot=False,
                t=0.8,
                u_t=1,
            )

            # Per-image MAE, so that a smaller last batch is not over-weighted
            # as it was when averaging per-batch means.
            mae.append(torch.abs(out_d["DE"] - out_d["GT"]).flatten(1).mean(1))

            if i % 50 == 0 and i > 0:
                print(i)

    print(f"Variable CF {column}")
    print(torch.cat(mae).mean())

In [None]:
# Cycle check: intervene on the column and map back to the original parents,
# then compare the cycled image with the original via mean absolute error.
run_name = (
    Path(model_path).parent.parent.stem
    + Path(model_path).parent.stem
    + datetime.now().strftime("%d%m/%H:%M")
)
print(run_name)

# Number of forward+reverse cycles applied to every batch.
n_cycles = 1

# The classifier and ``valid_indices`` that used to be computed here were
# never used by this evaluation and have been dropped.
for column in ["scanner"]:
    mae = []  # one 1-D tensor of per-image MAE values per batch
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            batch = preprocess_batch(args, batch, expand_pa=False)
            gt = batch["x"].clone().cpu()
            for _ in range(n_cycles):
                out_d = get_reverse_counterfactuals(
                    model=vae,
                    batch=batch,
                    args=args,
                    col=column,
                    t=0.8,
                    u_t=1,
                )
                # Feed the cycled image back in for the next cycle.
                batch["x"] = out_d["DE"].cuda()

            # Per-image MAE, so that a smaller last batch is not over-weighted
            # as it was when averaging per-batch means.
            mae.append(torch.abs(out_d["DE"] - gt).flatten(1).mean(1))

            if i % 50 == 0 and i > 0:
                print(i)

    print(f"Variable CF {column}")
    print(torch.cat(mae).mean())