In [None]:
# Page skeleton for the generated reports: it links the Bulma stylesheet from a
# CDN and has a single positional placeholder, {0}, that receives the body markup
# via str.format.
html_template = """
<HTML>
<head> <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/bulma/0.9.2/css/bulma.min.css">
</head>
<body>{0}</body>
</HTML>"""

def build_label_col(info, max_len, color, model):
    """Build the ``<td>`` cells for one model's label column.

    Row ``i`` (0-based) gets a cell containing ``model`` on a ``color``
    background when ``i + 1`` is found in ``info["top_ranked_indices"]``;
    every other row gets an empty cell.

    Returns a list of ``max_len`` HTML strings.
    """
    # Iterate over 1-based positions directly; this is the same membership
    # test as checking ``i + 1`` for each 0-based row index.
    return [
        "<td bgcolor={0}>{1}</td>".format(color, model)
        if position in info["top_ranked_indices"]
        else "<td></td>"
        for position in range(1, max_len + 1)
    ]

def build_text_col(text, max_len, highlighted_cols, color):
    """Build the ``<td>`` cells for a sentence column, HTML-escaping each sentence.

    Args:
        text: Sequence of sentences; row ``i`` shows ``text[i]``.
        max_len: Number of cells to produce. Rows at or beyond ``len(text)``
            are rendered as empty cells.
        highlighted_cols: Container of 0-based row indices that should be
            rendered with a ``color`` background.
        color: Value placed in the ``bgcolor`` attribute of highlighted cells.

    Returns:
        A list of ``max_len`` ``<td>...</td>`` strings.
    """
    # Imported locally so the function does not depend on another cell's imports.
    import html

    cells = []
    for i in range(max_len):
        if i < len(text):
            # Sentences are raw text, not markup: escape them so characters
            # such as "<" or "&" are displayed literally instead of being
            # interpreted by the browser and breaking the table.
            this_cell_text = html.escape(str(text[i]))
        else:
            this_cell_text = ""
        if i in highlighted_cols:
            cells.append("<td bgcolor={0}>{1}</td>".format(color, this_cell_text))
        else:
            cells.append("<td>{0}</td>".format(this_cell_text))

    return cells

def build_table(rank_info, bm25_info, review_sentences, rebuttal_sentences):
    """Render one review/rebuttal example as an HTML table with a short caption.

    Args:
        rank_info: Record for the "rank" model. The keys read here are
            "top_ranked_indices", "actual_labels", "rebuttal_idx" and
            "review_id".
        bm25_info: Record for the "bm25" model for the same example. The keys
            read here are "top_ranked_indices", "actual_labels" and
            "rebuttal_idx".
        review_sentences: Sentences of the review, one per table row.
        rebuttal_sentences: Sentences of the rebuttal, one per table row.

    Returns:
        An HTML fragment: the review id and rebuttal index, two line breaks,
        then a table with a header row followed by one row per sentence
        position. Review rows listed in "actual_labels" and the rebuttal row at
        "rebuttal_idx" are highlighted.

    Raises:
        AssertionError: If the two records disagree on "actual_labels" or
            "rebuttal_idx".
    """
    # Pad both text columns to the longer of the two documents.
    total_lines = max(len(review_sentences), len(rebuttal_sentences))

    rank_col = build_label_col(rank_info, total_lines, '#9FE2BF', "rank")
    bm25_col = build_label_col(bm25_info, total_lines, '#40E0D0', "bm25")

    # Both records must describe the same example, because the highlighting
    # below is taken from rank_info only.
    assert rank_info["actual_labels"] == bm25_info["actual_labels"]
    assert rank_info["rebuttal_idx"] == bm25_info["rebuttal_idx"]

    review_col = build_text_col(review_sentences, total_lines, rank_info["actual_labels"], '#6495ED')
    rebuttal_col = build_text_col(rebuttal_sentences, total_lines, [rank_info["rebuttal_idx"]], '#CCCCFF')

    # Header cells follow the order in which the data cells are emitted below.
    # The header was previously built with five names in a different order,
    # closed with "</td>" instead of "</tr>", and never added to the output.
    # There is no separate true-label column: the true labels appear as the
    # highlighted cells of the review_text column.
    columns = ["rank_label", "bm25_label", "review_text", "rebuttal_sentence"]
    header = "<tr>" + " ".join("<th>{0}</th>".format(col) for col in columns) + "</tr>"

    rows = [
        " ".join(["<tr>", rank_cell, bm25_cell, review_cell, rebuttal_cell, "</tr>"])
        for rank_cell, bm25_cell, review_cell, rebuttal_cell
        in zip(rank_col, bm25_col, review_col, rebuttal_col)
    ]

    preceding = rank_info["review_id"] + " " + str(rank_info["rebuttal_idx"]) + "<br/><br/>"

    table_text = preceding + '<table border="1px grey">' + "\n".join([header] + rows) + "</table>"
    return table_text

In [None]:
import collections
import glob
import json

# Per-model error records, keyed first by (review_id, rebuttal_idx) and then by
# model name.
relevant_info = collections.defaultdict(dict)

# Each line of this file is parsed as one JSON object (despite the .csv suffix).
# JSON is UTF-8 text, so decode it explicitly instead of relying on the
# platform's default encoding.
with open("ir_errors/mrr_errors_all.csv", 'r', encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            continue
        obj = json.loads(line)
        key = (obj["review_id"], obj["rebuttal_idx"])
        # Each model may contribute at most one record per example.
        assert obj["model"] not in relevant_info[key], \
            "duplicate record for {0}, model {1}".format(key, obj["model"])
        relevant_info[key][obj["model"]] = obj

# Review id -> (review sentences, rebuttal sentences), read from the dev split.
# NOTE(review): this absolute path is specific to one machine; point it at your
# local copy of the data before running.
text_map = {}
for dev_file in glob.glob(
    "/Users/nnayak/Downloads/0517_split_2/dev/*"):
    with open(dev_file, 'r', encoding="utf-8") as f:
        obj = json.load(f)
        review_text = [x["sentence"] for x in obj["review"]]
        rebuttal_text = [x["sentence"] for x in obj["rebuttal"]]
        text_map[obj["metadata"]["review"]] = (review_text, rebuttal_text)

# Rendered HTML tables: review id -> rebuttal index -> table markup.
# Every example is expected to have both a "rank" and a "bm25" record.
table_texts = collections.defaultdict(dict)
for (review_id, rebuttal_index), objs in relevant_info.items():
    review_text, rebuttal_text = text_map[review_id]
    table_texts[review_id][rebuttal_index] = build_table(objs["rank"], objs["bm25"], review_text, rebuttal_text)

In [None]:
# Write one HTML report per review, with its tables ordered by rebuttal index.
for review_id, tables in table_texts.items():
    ordered_tables = "<br/> <br/>".join([tables[i] for i in sorted(tables)])
    # Write UTF-8 explicitly so sentences containing non-ASCII characters are
    # saved the same way on every platform, rather than depending on (or
    # failing under) the locale's default encoding.
    with open("ir_output/"+review_id + "_tables.html", 'w', encoding="utf-8") as f:
        f.write(html_template.format(ordered_tables))
            

In [None]:
def filter_none_mean(l):
    """Return the arithmetic mean of the non-None entries of ``l``.

    Args:
        l: Iterable of numbers in which ``None`` marks a missing value.

    Returns:
        The mean of the remaining values, or ``None`` when ``l`` is empty or
        contains only ``None`` (this case previously raised
        ZeroDivisionError).
    """
    values = [value for value in l if value is not None]
    if not values:
        # Nothing to average: report "no value" instead of dividing by zero.
        return None
    return sum(values) / len(values)

# Load every result record first; each line of the file is one JSON object.
with open("ir_errors/mrr_results_all.csv", 'r') as f:
    mrr_records = [json.loads(line) for line in f]

# Split the records into one list of MRR values per model.
rank_results = [record["rank_mrr"] for record in mrr_records]
bm25_results = [record["bm25_mrr"] for record in mrr_records]

# Report the mean MRR of each model, ignoring entries recorded as None.
for model_name, mrr_values in (("rank", rank_results), ("bm25", bm25_results)):
    print(model_name, filter_none_mean(mrr_values))
        