In [None]:
# Intended to hide all CUDA devices from this kernel so the notebook runs on CPU
# (the code below uses `device = 'cpu'`); it is set before `import torch` in this cell.
# NOTE(review): IPython's %env appears to keep the quotes, so the value is literally `""`
# (an invalid device list, which still hides the GPUs) — `%env CUDA_VISIBLE_DEVICES=` is clearer.
%env CUDA_VISIBLE_DEVICES=""

import json
import os
import socket

import captum.attr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import savethat
import torch
from tqdm.auto import tqdm

from lrp_relations import sanity_checks, utils, train_clevr, data, lrp, gt_eval
from relation_network import model as rel_model
 
# Configure savethat's logger, then report which machine this kernel is on.
savethat.log.setup_logger()

hostname = socket.gethostname()
print(f"Running on {hostname}")

In [None]:
# Handle to the experiment storage: queried with `find_runs(...)` and joined with `/`
# to build run/checkpoint paths in the cells below.
storage = utils.get_storage()

In [None]:
# All stored "Train" runs, one row per run.
runs = pd.DataFrame(storage.find_runs("Train"))

# Runs trained with warmup + decay. `.eq(True)` (instead of plain boolean masking)
# treats missing/NaN values in `warmup_and_decay` as "not selected" rather than raising.
runs_warmup = runs[runs["warmup_and_decay"].eq(True)]
runs_warmup

In [None]:
# key = "Train_2022-05-18T19-17-45"
key = "Train_2022-05-19T09-19-31"  # trained on > 240 epochs


def _natural_key(path):
    """Sort key for a path name that compares embedded integers numerically ("ep_9" < "ep_10")."""
    import re

    # With a capturing group, re.split alternates text / digit-run tokens, so odd
    # positions are always digit runs and can be converted to int safely.
    return [int(tok) if pos % 2 else tok for pos, tok in enumerate(re.split(r"(\d+)", path.name))]


ckpt_dir = storage / key / "checkpoints"
# glob yields files in filesystem-dependent order, so sort before taking the "last"
# checkpoint. NOTE(review): assumes checkpoint names embed an increasing epoch/step
# number — confirm against the training script's naming.
ckpts = sorted(ckpt_dir.glob("*.model"), key=_natural_key)
if not ckpts:
    raise FileNotFoundError(f"No *.model checkpoints found in {ckpt_dir}")
print(ckpts)
ckpt = ckpts[-1]

args = train_clevr.TrainArgs.from_json(storage / key / "args.json")
device = "cpu"

ckpt

In [None]:
# Rebuild the network with the vocabulary size and load the selected checkpoint onto `device`.
model = rel_model.RelationNetworks(data.get_n_words())
load_result = model.load_state_dict(torch.load(ckpt, map_location=device))
# The model is only used for inference/saliency below: switch off training-mode
# behaviour of layers such as dropout / batch norm (if the architecture has any).
model.eval()
load_result

In [None]:
# Wrap the trained model so LRP saliency maps can be computed via `get_lrp_saliency`
# (used in the batch cell below).
lrp_model = lrp.LRPViewOfRelationNetwork(model)

In [None]:
# Which CLEVR-XAI subset and ground-truth flavour to evaluate. Alternatives kept for reference:
#   dataloader = data.get_clevr_dataloader(utils.clevr_path(), split='val')
#   question_type, ground_truth = "complex", "unique"
question_type = "simple"
ground_truth = "single_object"

# The plot below shows 12 samples (6x6 grid, 3 panels each), so keep batch_size >= 12.
loader_kwargs = dict(
    question_type=question_type,
    ground_truth=ground_truth,
    batch_size=30,
    n_worker=2,
)
dataloader = data.get_clevr_xai_loader(**loader_kwargs)
dataset = dataloader.dataset

In [None]:
# Only one batch is needed for the qualitative plots below, so take the first batch
# directly instead of a `for` loop with an unconditional `break` (which also left the
# tqdm progress bar open).
image, question, q_len, answer, idx = next(iter(dataloader))
image = image.to(device)
question = question.to(device)
q_len = torch.tensor(q_len)
answer = answer.to(device)

# LRP relevance w.r.t. the ground-truth answer, left unnormalized.
saliency = lrp_model.get_lrp_saliency(
    image, question, q_len, target=answer, normalize=False
)

# Dataset indices of the batch samples, used to look up question text and ground truth.
indices = list(idx)

In [None]:
def insert_newlines(string: str, every: int = 64) -> str:
    """Hard-wrap ``string`` into chunks of ``every`` characters joined by ``"\\n"``.

    Splits purely by character count (not at word boundaries), so words may be
    broken across lines. An empty string yields an empty string.
    """
    chunks = (string[start:start + every] for start in range(0, len(string), every))
    return "\n".join(chunks)


n_rows, n_cols = 6, 6
fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows))
ax_flat = axes.flatten()

# Each sample occupies three consecutive panels: input image, saliency map, GT mask.
panel_triplets = zip(ax_flat[::3], ax_flat[1::3], ax_flat[2::3])
# Zipping with `indices` stops early (leaving blank panels) if the batch has fewer
# samples than triplets, instead of raising IndexError.
for i, (dset_idx, (ax_img, ax_sal, ax_gt)) in enumerate(zip(indices, panel_triplets)):
    # Distinct names so the batch tensors `question` / `answer` from the batch cell
    # are not overwritten with strings.
    question_text, answer_text = dataset.get_question_and_answer(dset_idx)
    gt = dataset.get_ground_truth(dset_idx)

    # `/ 2 + 0.5` maps the image to [0, 1] for imshow.
    # NOTE(review): assumes inputs lie in [-1, 1] — see the image-range check in a later cell.
    ax_img.imshow(image[i].permute(1, 2, 0).cpu().detach().numpy() / 2 + 0.5)
    ax_img.set_title(
        insert_newlines(question_text, every=40) + "\n" + f"Answer: {answer_text}",
        fontsize=6,
    )

    # Per-pixel saliency: |relevance| summed over axis 0, clipped at the 99th
    # percentile so a few extreme pixels don't wash out the colour map.
    sal = saliency[i].cpu().detach().abs().numpy().sum(0)
    p99 = np.percentile(sal, 99)
    sal[sal > p99] = p99
    im = ax_sal.imshow(sal, alpha=1.0)
    fig.colorbar(im, ax=ax_sal)

    gt_mask = gt.cpu().detach().numpy()
    ax_gt.imshow(gt_mask, alpha=1.0)

    sal_l2 = gt_eval.l2_norm_sq(saliency[i], dim=0).detach().cpu()
    sal_max = gt_eval.max_norm(saliency[i], dim=0).cpu().detach().numpy()[0]

    rel_mass = gt_eval.relevance_mass(sal_l2[0], gt, reduce=(0, 1))
    ax_sal.set_title(
        f"rel. mass: {rel_mass.item():.3f}\n"
        f"rel. rank acc.: {gt_eval.get_ration_in_mask(sal_max, gt_mask):0.3f}",
        fontsize=6,
    )

for ax in ax_flat:
    ax.set_xticks([])
    ax.set_yticks([])

fig.suptitle(f"{key}@{ckpt.name}\nGT: {question_type}@{ground_truth}")
fig.subplots_adjust(wspace=0.5, hspace=0.75)
fig.set_dpi(90)


In [None]:
# The grid plot maps images to [0, 1] via `x / 2 + 0.5`, i.e. it assumes inputs in
# [-1, 1]; check both ends of the range, not just the minimum.
image.min().item(), image.max().item()

In [None]:
# JSON-lines log stored next to the checkpoints: one JSON object per line.
log_path = storage / key / "checkpoints" / "log.jsonl"
with open(log_path, "r", encoding="utf-8") as f:
    # Skip blank lines, which json.loads would reject.
    log_records = [json.loads(line) for line in f if line.strip()]

log_raw = pd.DataFrame(log_records)
print(log_raw.columns)

# Epoch-indexed view built without inplace mutation (idempotent; `log_raw` stays intact).
log = log_raw.drop(columns="count").set_index("epoch")


In [None]:
# One line per logged metric against epoch; `;` suppresses the Axes repr.
log.plot(title=f"Training log: {key}");