# Generate Template

In [None]:
import jinja2
import json
import jsonlines

from copy import copy

In [None]:
# Input data: the DocRED dataset directory and the influence-analysis workspace.
_DOCRED_DIR = "./data/datasets/docred"
_ANALYSIS_DIR = "./influence_analysis"

TYPES_FILE = _DOCRED_DIR + "/types.json"
GT_TRAIN_FILE = _DOCRED_DIR + "/train_annotated.json"
PERT_TRAIN_FILE = _DOCRED_DIR + "/train_annotated_perturbed.json"
REMOVED_RELS_FILE = _DOCRED_DIR + "/train_annotated_removed_rels.json"
SCORES_FILE = _ANALYSIS_DIR + "/scores/rel_loss_perturbed_ckpt15.jsonl"

# Jinja2 template used to render the report.
TEMPLATE_FILE_DIR = _ANALYSIS_DIR + "/templates"
TEMPLATE_FILENAME = "analyse.html"

# Rendered HTML report.
OUTPUT_FILE = _ANALYSIS_DIR + "/analyse_rel_loss.html"

In [None]:
def _load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 explicitly (JSON's standard encoding) rather
    than with the platform's locale default, which differs across systems.
    """
    with open(path, "r", encoding="utf-8") as fd:
        return json.load(fd)


# Only the "relations" section of the types file is used (looked up by label id).
rel_types = _load_json(TYPES_FILE)["relations"]

# Ground-truth training documents; their "labels" provide the gt relation.
gt_train_docs = _load_json(GT_TRAIN_FILE)

# Perturbed training documents; their text is what the report renders.
pert_train_docs = _load_json(PERT_TRAIN_FILE)

# Relations removed by the perturbation, indexed by document position.
removed_rels = _load_json(REMOVED_RELS_FILE)

In [None]:
def load_scores(file):
    """Flatten a JSON-lines score file into a single list of relation entries.

    Each line of ``file`` holds the entries for one document. The line's
    0-based position is taken as that document's id. Every entry is
    shallow-copied and tagged with a ``doc_id`` key.

    Returns:
        list of dicts, in file order.
    """
    with jsonlines.open(file) as reader:
        return [
            {**entry, 'doc_id': line_no}
            for line_no, doc_entries in enumerate(reader)
            for entry in doc_entries
        ]

In [None]:
def is_perturbed(rel):
    """Return True if ``rel``'s (head, tail) pair was removed from its document.

    ``rel`` must carry ``doc_id`` and ``entity_pair``. The check looks up
    ``removed_rels[doc_id]`` and matches on the ``h``/``t`` entity ids only;
    the relation type is ignored.
    """
    doc_id = rel['doc_id']
    head_id, tail_id = rel['entity_pair']
    return any(
        removed['h'] == head_id and removed['t'] == tail_id
        for removed in removed_rels[doc_id]
    )

def get_gt_rel(rel):
    """Return the ground-truth relation id for ``rel``'s entity pair, or None.

    Labels are read from ``gt_train_docs`` (the unperturbed ground truth),
    not from the perturbed dataset. If several labels share the same
    (head, tail) pair, only the first one in label order is returned.
    """
    doc_id = rel["doc_id"]
    head_id, tail_id = rel["entity_pair"]
    labels = gt_train_docs[doc_id]["labels"]
    return next(
        (label["r"] for label in labels if label["h"] == head_id and label["t"] == tail_id),
        None,
    )

In [None]:
def build_text_html(doc, rel):
    """Render ``doc``'s text as HTML with the head and tail entity mentions wrapped in spans.

    Args:
        doc: document dict with ``"sents"`` (a list of token lists) and
            ``"vertexSet"`` (per entity, a list of mentions, each with
            ``"pos"`` = [start, end) token offsets, ``"sent_id"`` and ``"type"``).
        rel: dict whose ``"entity_pair"`` is (head entity index, tail entity index).

    Returns:
        One HTML string: tokens joined by spaces, sentences joined by spaces.
        Every mention is wrapped as
        ``<span class="head|tail"><span class="type">TYPE</span>tokens…</span>``.
        The type label is taken from the entity's first mention.

    ``doc`` is not modified. Tag insertion works on a per-call copy, so
    rendering the same document several times never accumulates tags.
    """
    import html

    # Escape token text before inserting markup, so characters such as "<" or
    # "&" in the source text cannot break the generated HTML.
    sents = [[html.escape(token) for token in sent] for sent in doc["sents"]]
    head_entity_id, tail_entity_id = rel["entity_pair"]

    # Head mentions are tagged before tail mentions. Where they share a token,
    # the tail tag therefore ends up outermost.
    for entity_id, css_class in ((head_entity_id, "head"), (tail_entity_id, "tail")):
        mentions = doc["vertexSet"][entity_id]
        entity_type = html.escape(mentions[0]["type"])
        open_tag = f'<span class="{css_class}"><span class="type">{entity_type}</span>'
        for mention in mentions:
            start, end = mention["pos"]
            tokens = sents[mention["sent_id"]]
            tokens[start] = open_tag + tokens[start]
            # "pos" end is exclusive, so the closing tag goes on token end - 1.
            tokens[end - 1] = tokens[end - 1] + "</span>"

    return " ".join(" ".join(tokens) for tokens in sents)

In [None]:
def gen_template(results, template_dir=TEMPLATE_FILE_DIR, template_name=TEMPLATE_FILENAME, output_file=OUTPUT_FILE):
    """Render the Jinja2 template with ``results`` and write the HTML to ``output_file``.

    Args:
        results: list of row dicts, exposed to the template as ``results``.
        template_dir: directory the template is loaded from.
        template_name: template file name inside ``template_dir``.
        output_file: path of the HTML file to (over)write.
    """
    loader = jinja2.FileSystemLoader(template_dir)
    # Autoescape is not enabled here (jinja2's default is off). The "text" and
    # "pred_rels" values built in this notebook already contain HTML markup.
    env = jinja2.Environment(loader=loader)
    output = env.get_template(template_name).render(results=results)

    # Write UTF-8 explicitly. Document text can contain non-ASCII characters
    # that a non-UTF-8 locale default encoding (e.g. cp1252) cannot encode.
    with open(output_file, "w", encoding="utf-8") as fd:
        fd.write(output)

In [None]:
# Flatten per-document scores and rank every relation candidate by loss, highest first.
scores = load_scores(SCORES_FILE)
scores.sort(key=lambda entry: entry["loss"], reverse=True)

In [None]:
# Number of highest-loss relation candidates included in the report.
TOP_K = 50

results = []

for rel in scores[:TOP_K]:
    doc = pert_train_docs[rel["doc_id"]]

    text_html = build_text_html(doc, rel)
    is_perturbed_rel = is_perturbed(rel)
    gt_rel = get_gt_rel(rel)

    # Keep predictions in score-file order. The previous version joined a *set*,
    # so the displayed order depended on per-process string-hash randomisation.
    pred_rels = [
        f"{pred_rel} (confidence: {score:.4f})"
        for pred_rel, score in rel["preds"].items()
    ]

    # CSS class used by the template to highlight the row.
    if is_perturbed_rel:
        c = "pert-rel-found" if pred_rels else "pert-rel-not-found"
    elif pred_rels and gt_rel is None:
        c = "pred-rel-but-not-gt"
    else:
        c = ""

    results.append({
        "doc_id": rel["doc_id"],
        "text": text_html,
        "is_perturbed": is_perturbed_rel,
        "gt_rel": "None" if gt_rel is None else rel_types[gt_rel]["verbose"],
        "pred_rels": "<br>".join(pred_rels) if pred_rels else "None",
        "c": c,
        "score": rel["loss"],
    })

results

In [None]:
# Render the collected rows into the HTML report (defaults to OUTPUT_FILE).
gen_template(results)