In [None]:
# Set CUDA_VISIBLE_DEVICES before torch is imported further down in this cell.
# NOTE(review): the intent appears to be running the notebook without GPUs -- confirm.
%env CUDA_VISIBLE_DEVICES=""

import os
import socket
import pickle
from typing import cast

import captum.attr
import pandas as pd
import savethat
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm
 
from lrp_relations import sanity_checks, utils, train_clevr 
from lrp_relations import data, lrp, gt_eval, figures
from relation_network import model as rel_model
 
# Configure savethat's logger, then record which machine executed the notebook.
savethat.log.setup_logger()

hostname = socket.gethostname()
print(f"Running on {hostname}")

In [None]:
# Storage handle; later cells join it with run keys via `/` to locate result files.
storage = utils.get_storage()

In [None]:

# Search string handed to `storage.find_runs`. It gets its own name so that
# `key` only ever refers to one concrete run (assigned in the next cell).
# NOTE(review): presumably matched as a run-key prefix -- confirm in `find_runs`.
run_key_query = "SanityChecksForRelationNetworks_2022-06-14T"
runs = pd.DataFrame(storage.find_runs(run_key_query))

runs

In [None]:
# Fail with a readable message instead of an opaque IndexError from `.iloc[-1]`
# when the lookup in the previous cell returned nothing.
assert not runs.empty, "No matching runs found in storage; check the run key used for the lookup."

# Take the last row of the table.
# NOTE(review): presumably the most recent run -- confirm the ordering of `find_runs`.
run = runs.iloc[-1]
key = run.run_key
print(key)
run

In [None]:
# Load the pickled results of the selected run.
# NOTE: unpickling executes arbitrary code; only load files you produced yourself.
results_path = storage / key / "results.pickle"
with open(results_path, "rb") as results_file:
    unpickled = pickle.load(results_file)

# `cast` only informs the type checker; nothing is validated at runtime.
result = cast(sanity_checks.SanityChecksForRelationNetworksResults, unpickled)

In [None]:
# Only the last expression of a cell is rendered, so the first tuple used to be
# dropped silently. Show both index pairs explicitly.
display(
    (result.saliency_0.image_idx, result.saliency_0.question_index),
    (
        result.saliency_0_rand_questions.image_idx,
        result.saliency_0_rand_questions.question_index,
    ),
)


In [None]:
# Restore the arguments the selected run was started with.
args_path = storage / key / "args.json"
args = sanity_checks.SanityChecksForRelationNetworksArgs.from_json(args_path)

# Build the CLEVR-XAI dataset with the question type / ground truth of that run.
dataset = data.CLEVR_XAI(
    question_type=args.question_type,
    ground_truth=args.ground_truth,
    reverse_question=True,
    use_preprocessed=False,
)

# Quick visual check: render the first image of the dataset.
first_image = dataset.get_image(0, preprocessed=False, resize=True)
display(first_image)

In [None]:
# Show the answer mapping; the figure cell below indexes it with the integer
# `target` values stored in the saliency results.
dataset.answer_dict()

In [None]:
# Figure layout: every sample occupies four consecutive panels
#   input image | saliency_0 | saliency_1 | saliency_0 with random questions
# so a 2 x 8 grid shows four samples.
nrows = 2
ncols = 8
panels_per_sample = 4

# The three saliency results shown next to each input image, in panel order.
saliency_variants = (
    result.saliency_0,
    result.saliency_1,
    result.saliency_0_rand_questions,
)

answer_dict = dataset.answer_dict()
with figures.latexify():
    figsize = figures.get_figure_size(fraction=1.0, ratio=0.38)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)

    # Row-major regrouping of the axes grid: each row of the reshaped array is
    # one sample's (input, saliency_0, saliency_1, random-question) panels.
    sample_axes = axes.reshape(-1, panels_per_sample)

    for i, (ax1, ax2, ax3, ax4) in enumerate(sample_axes):
        image_idx: int = result.saliency_0.image_idx[i].item()
        question_index: int = result.saliency_0.question_index[i].item()
        rand_q_index: int = result.saliency_0_rand_questions.question_index[
            i
        ].item()

        # Only the question text is needed; the shown answers come from the
        # `target` of each saliency result below.
        quest, _ = dataset.get_question_and_answer(question_index)
        rand_quest, _ = dataset.get_question_and_answer(rand_q_index)

        img = dataset.get_image(image_idx, preprocessed=False, resize=True)
        ax1.imshow(img)
        ax1.set_xlabel("Input")

        ax2.set_title(utils.insert_newlines(quest, every=40), fontsize=6)
        ax4.set_title(utils.insert_newlines(rand_quest, every=20), fontsize=6)

        for ax, sal_res in zip((ax2, ax3, ax4), saliency_variants):
            # Normalize, then average over the first axis to get a 2-D map.
            sal_map = lrp.normalize_saliency(sal_res.saliency[i]).mean(0)
            ax.imshow(sal_map, cmap="Reds")

            answer = answer_dict[sal_res.target[i].item()]
            ax.set_xlabel(f"{answer}", fontsize=8, fontname="monospace")

    for ax in axes.flatten():
        ax.set_xticks([])
        ax.set_yticks([])

    fig.set_dpi(120)
    fig.subplots_adjust(wspace=0.15, hspace=0.75, left=0.10, right=0.90)

    # Store the figure next to the run's results and print a ready-to-paste
    # command for copying the output directory from this host.
    fig_path = storage / key / "saliency" / "saliency.pgf"
    fig_path.parent.mkdir(parents=True, exist_ok=True)
    print(f"scp -r {socket.gethostname()}:{fig_path.parent} ./figures")
    figures.savefig_pgf(fig, fig_path)
    plt.show()


In [None]:
def _clip_and_normalize(x):
    """Normalize a saliency map with percentile clipping at 0 and 99.5."""
    return lrp.normalize_saliency(
        x,
        clip_percentile_min=0,
        clip_percentile_max=99.5,
    )


# Statistics of the run, computed on saliency maps normalized as above.
result.statistics(_clip_and_normalize)


In [None]:
# Directory of the training run whose model was analysed (`args.model`).
model_dir = storage / args.model

# NOTE: unpickling executes arbitrary code; only load files you produced yourself.
with open(model_dir / "results.pickle", "rb") as ckpt_file:
    model_ckpts = cast(train_clevr.TrainedModel, pickle.load(ckpt_file))

# Arguments the model was trained with.
model_args = train_clevr.TrainArgs.from_json(model_dir / "args.json")


In [None]:
# `acc` is only known for the best checkpoint. The print used to sit outside
# the `if`, which raised a NameError whenever an explicit checkpoint was set.
if args.checkpoint is None:
    acc = model_ckpts.get_best_checkpoint().accuracy
    print(f"Model accuracy [%]: {acc:.2%}")
else:
    # NOTE(review): how to look up the accuracy of a specific checkpoint is not
    # visible in this notebook -- add the lookup here once confirmed.
    print(
        f"Model accuracy: not looked up for explicit checkpoint {args.checkpoint!r}"
    )

In [None]:
# NOTE(review): `dataloader`, `device` and `lrp_model` are not defined in any
# visible cell of this notebook -- this cell relies on state from elsewhere.
dataloader

# Compute un-normalized LRP saliency for the first batch only (see the `break`)
# and remember the dataset indices of that batch for the plotting cell below.
indices = []
pbar = tqdm(dataloader)
for batch_no, (image, question, q_len, answer, idx) in enumerate(pbar):
    image = image.to(device)
    question = question.to(device)
    q_len = torch.tensor(q_len)
    answer = answer.to(device)

    saliency = lrp_model.get_lrp_saliency(
        image,
        question,
        q_len,
        target=answer,
        normalize=False,
    )

    indices.extend(idx)
    break

In [None]:
def insert_newlines(string: str, every: int = 64) -> str:
    """Break ``string`` into chunks of ``every`` characters joined by newlines.

    The split is purely positional (it does not look for word boundaries); the
    last chunk may be shorter. An empty string yields an empty string.

    Args:
        string: Text to wrap.
        every: Maximum number of characters per line; must be at least 1.

    Returns:
        The wrapped text.

    Raises:
        ValueError: If ``every`` is smaller than 1. A negative value used to
            return an empty string and silently drop the text.
    """
    if every < 1:
        raise ValueError(f"`every` must be a positive integer, got {every!r}")
    return "\n".join(
        string[start : start + every] for start in range(0, len(string), every)
    )


# Grid of r x c panels; every sample uses three consecutive panels:
#   input image + question | saliency map | ground-truth mask
r = 6
c = 6
fig, axes = plt.subplots(r, c, figsize=(2 * c, 2 * r))

ax_flat = axes.flatten()
# Zipping with `indices` stops at whichever is shorter -- the samples of the
# batch or the available panel triples -- so a batch with fewer samples than
# triples no longer raises an IndexError.
for i, (dset_idx, ax1, ax2, ax3) in enumerate(
    zip(indices, ax_flat[::3], ax_flat[1::3], ax_flat[2::3])
):
    # Separate names so the batch tensors `question` / `answer` from the
    # saliency cell are not overwritten by plain text.
    question_text, answer_text = dataset.get_question_and_answer(dset_idx)
    gt = dataset.get_ground_truth(dset_idx)

    # Channels-last for imshow; `/ 2 + 0.5` rescales the image for display.
    # NOTE(review): assumes the batch is normalized to [-1, 1] -- confirm.
    ax1.imshow(image[i].permute(1, 2, 0).cpu().detach().numpy() / 2 + 0.5)
    ax1.set_title(
        insert_newlines(question_text, every=40)
        + "\n"
        + f"Answer: {answer_text}",
        fontsize=6,
    )

    # Absolute relevance summed over the first axis, clipped at the 99th
    # percentile so a few extreme pixels do not dominate the colour scale.
    sal = saliency[i].cpu().detach().abs().numpy().sum(0)
    q = np.percentile(sal, 99)
    sal[sal > q] = q
    im = ax2.imshow(sal, alpha=1.0)
    fig.colorbar(im, ax=ax2)

    gt_mask = gt.cpu().detach().numpy()
    ax3.imshow(gt_mask, alpha=1.0)

    # Ground-truth evaluation metrics shown above the saliency panel.
    sal_l2 = gt_eval.l2_norm_sq(saliency[i], dim=0).detach().cpu()
    sal_max = gt_eval.max_norm(saliency[i], dim=0).cpu().detach().numpy()[0]

    rel_mass = gt_eval.relevance_mass(sal_l2[0], gt, reduce=(0, 1))
    ax2.set_title(
        f"rel. mass: {rel_mass.item():.3f}\n"
        f"rel. rank acc.: {gt_eval.get_ration_in_mask(sal_max, gt_mask):0.3f}",
        fontsize=6,
    )


for ax in ax_flat:
    ax.set_xticks([])
    ax.set_yticks([])

# The question type / ground truth are read from `args` (the values the dataset
# was built with); the bare names used before are not defined in this notebook.
# NOTE(review): `ckpt` is not defined in any visible cell either -- confirm
# where the checkpoint object should come from.
fig.suptitle(
    f"{key}@{ckpt.name}\nGT: {args.question_type}@{args.ground_truth}"
)
fig.subplots_adjust(wspace=0.5, hspace=0.75)
fig.set_dpi(90)


In [None]:
# Smallest value in `image`, the batch tensor left over from the saliency loop.
# NOTE(review): presumably a check of the input value range, since the plot
# rescales it with `/ 2 + 0.5` -- confirm.
image.min()

In [None]:
import json

# Read the JSON-lines log: one JSON object per line.
# NOTE(review): `key` is the sanity-check run selected above; confirm that its
# directory (rather than the model's, `args.model`) holds `checkpoints/log.jsonl`.
log_path = storage / key / "checkpoints" / "log.jsonl"
with open(log_path, "r") as f:
    # Iterate the file lazily and skip blank lines, which `json.loads` rejects.
    log_records = [json.loads(line) for line in f if line.strip()]

log_raw = pd.DataFrame(log_records)
print(log_raw.columns)

# Build the final frame without in-place mutation so the cell can be re-run
# safely: drop the `count` column and index the rows by epoch.
log = log_raw.drop(columns="count").set_index("epoch")


In [None]:
# Plot the logged columns over the `epoch` index set in the previous cell.
log.plot()