In [None]:
%env CUDA_VISIBLE_DEVICES=""

import os
import socket
import pickle
from typing import cast

import captum.attr
import pandas as pd
import savethat
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm
 
from lrp_relations import sanity_checks, utils, train_clevr 
from lrp_relations import data, lrp, gt_eval, figures
from relation_network import model as rel_model
import savethat.log 

# Configure savethat's logging before any experiment code runs.
savethat.log.setup_logger()

# Record which machine executed the notebook; results are stored per host.
hostname = socket.gethostname()
print(f"Running on {hostname}")

In [None]:
# Root storage object; later cells join run keys onto it with `/` to locate files.
storage = utils.get_storage()

In [None]:

# Prefix of the savethat run keys produced by the sanity-check experiment.
# It is kept under its own name instead of `key`: the next cell assigns the
# concrete run key to `key`, and reusing the name meant that re-running this
# cell alone silently pointed every later cell at the prefix instead.
RUN_KEY_PREFIX = "SanityChecksForRelationNetworks_2022-06-14T"
runs = pd.DataFrame(storage.find_runs(RUN_KEY_PREFIX))  # type: ignore

display(runs)

In [None]:
# Pick the last listed run. NOTE(review): presumably the most recent one --
# confirm that `storage.find_runs` returns runs in chronological order.
if runs.empty:
    # Without this check `iloc[-1]` fails with an uninformative IndexError.
    raise ValueError("No runs found for the given key prefix; check the search key above.")
run = runs.iloc[-1]
key = run.run_key
print(key)
display(run)

In [None]:
# Load the sanity-check results pickled by the selected run.
# NOTE: unpickling can execute arbitrary code -- only load results you produced.
results_path = storage / key / "results.pickle"
with open(results_path, "rb") as f:
    loaded = pickle.load(f)
result = cast(sanity_checks.SanityChecksForRelationNetworksResults, loaded)

In [None]:
# Image / question indices of the samples the saliency maps were computed for.
# Only the last bare expression of a cell is rendered, so the saliency_0 pair
# used to be silently discarded; display both pairs explicitly.
display(
    (result.saliency_0.image_idx, result.saliency_0.question_index),
    (
        result.saliency_0_rand_questions.image_idx,
        result.saliency_0_rand_questions.question_index,
    ),
)


In [None]:
# Restore the run's arguments from its saved JSON and build the matching
# CLEVR-XAI dataset (raw images, reversed questions).
args_path = storage / key / "args.json"
args = sanity_checks.SanityChecksForRelationNetworksArgs.from_json(args_path)

dataset = data.CLEVR_XAI(
    reverse_question=True,
    use_preprocessed=False,
    question_type=args.question_type,
    ground_truth=args.ground_truth,
)

# Quick visual check: the first image, unpreprocessed but resized.
first_image = dataset.get_image(0, preprocessed=False, resize=True)
display(first_image)

In [None]:
# Lookup from integer answer index to answer label; the figure below uses it
# to caption each saliency panel.
dataset.answer_dict()

In [None]:
# Figure: for each sample, the input image followed by three saliency panels:
# saliency_0, saliency_1, and saliency_0 computed for a random question.
# The nrows x ncols grid is consumed in groups of `axes_per_sample`
# consecutive axes, so each row holds ncols // axes_per_sample samples.
nrows = 2
ncols = 8
axes_per_sample = 4
# A group straddling two rows would split one sample's panels across rows.
assert ncols % axes_per_sample == 0, "ncols must be a multiple of axes_per_sample"

answer_dict = dataset.answer_dict()
with figures.latexify():
    figsize = figures.get_figure_size(fraction=1.0, ratio=0.38)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize)
    flat_axes = axes.flatten()
    # Column k of the zip takes every axes_per_sample-th axis starting at k, so each
    # tuple is one run of consecutive axes: (0, 1, 2, 3), (4, 5, 6, 7), ...
    sample_axes = zip(
        *(flat_axes[k::axes_per_sample] for k in range(axes_per_sample))
    )

    for i, (ax1, ax2, ax3, ax4) in enumerate(sample_axes):
        saliency_result = result.saliency_0
        image_idx: int = int(saliency_result.image_idx[i])
        question_index: int = int(saliency_result.question_index[i])

        # Question texts: the sample's own question titles ax2 (wrapped every 40
        # chars); the random question titles ax4 (wrapped every 16 chars).
        quest = dataset.get_question_and_answer(question_index)[0]
        rand_q_index = int(result.saliency_0_rand_questions.question_index[i])
        rand_quest = dataset.get_question_and_answer(rand_q_index)[0]

        img = dataset.get_image(image_idx, preprocessed=False, resize=True)
        ax1.imshow(img)
        ax1.set_xlabel("Input")
        ax2.set_title(utils.insert_newlines(quest, every=40), fontsize=6)
        ax4.set_title(utils.insert_newlines(rand_quest, every=16), fontsize=6)

        # Each saliency panel: the normalized map averaged over axis 0
        # (presumably the channel axis -- confirm), captioned with the answer
        # label for `target[i]`.
        for ax, sal_res in [
            (ax2, saliency_result),
            (ax3, result.saliency_1),
            (ax4, result.saliency_0_rand_questions),
        ]:
            saliency = lrp.normalize_saliency(sal_res.saliency[i])
            ax.imshow(saliency.mean(0), cmap="Reds")
            answer = answer_dict[int(sal_res.target[i].item())]
            ax.set_xlabel(f"{answer}", fontsize=8, fontname="monospace")

    for ax in flat_axes:
        ax.set_xticks([])
        ax.set_yticks([])

    fig.set_dpi(120)
    fig.subplots_adjust(wspace=0.15, hspace=0.75, left=0.10, right=0.90)
    fig_path = storage / key / "saliency" / "saliency.pgf"
    fig_path.parent.mkdir(parents=True, exist_ok=True)
    # Convenience command for copying the exported figure to a local checkout.
    print(f"scp -r {socket.gethostname()}:{fig_path.parent} ./figures")
    figures.savefig_pgf(fig, fig_path, pad_inches=0.15)
    plt.show()


In [None]:
# Summary statistics of the saliency maps after normalizing each map with
# percentile clipping at 0.5 / 99.5.
def _clip_normalize(x):
    return lrp.normalize_saliency(
        x, clip_percentile_min=0.5, clip_percentile_max=99.5
    )


df = result.statistics(_clip_normalize)
display(df)

# `abs_mean` for the last two rows of the statistics table.
last_abs_means = df.abs_mean.iloc[-2:]
print(last_abs_means)

In [None]:
# The sanity-check run must reference the training run of the model it probed.
assert args.model is not None

# Load that training run's pickled checkpoints and its training arguments.
# NOTE: unpickling can execute arbitrary code -- only load results you produced.
model_dir = storage / args.model
with open(model_dir / "results.pickle", "rb") as f:
    model_ckpts = cast(train_clevr.TrainedModel, pickle.load(f))

model_args = train_clevr.TrainArgs.from_json(model_dir / "args.json")


In [None]:
# Accuracy of the checkpoint the sanity checks were run against.
# NOTE(review): the `.2%` format multiplies by 100 and appends "%", which assumes
# `accuracy` is a fraction in [0, 1] -- confirm. The "[%]" in the label is then
# redundant.
checkpoint = model_ckpts.get_checkpoint(args.checkpoint)
acc = checkpoint.accuracy
print(f"Model accuracy [%]: {acc:.2%}")