In [None]:
import sys

# Make modules in the parent directory importable from this notebook.
# Guarded so that re-running this cell does not keep appending duplicate entries.
if ".." not in sys.path:
    sys.path.append("..")

In [None]:
import os
from pathlib import Path

# Model whose per-layer results are summarized; used in heatmap axis labels.
MODEL_NAME = "gpt2-xl"

# Experiments to report on; each is a directory under RESULTS_ROOT.
# Uncomment entries to include them in the report.
EXPERIMENT_NAMES = [
#     "biaffine_no_frills",
#     "linear_no_frills",
#     "linear_with_loss_terms",
    "linear_last_entity",
    "linear_last_entity_with_loss_terms",
]
# Editor variants (subdirectories of each experiment directory) and dataset splits.
EDITOR_TYPES = ["linear", "biaffine"]
SPLITS = ["train", "test"]

# Report <img> tags are built as BASE_URL/<path relative to OUTPUT_ROOT>, so this
# URL is expected to serve the contents of OUTPUT_ROOT.
BASE_URL = "https://unitname.csail.mit.edu/context-mediation"

RESULTS_ROOT = Path("../results").resolve()

# Where reports and plots are written. Overridable via the environment so the
# notebook can run on machines that do not have this absolute path.
OUTPUT_ROOT = Path(
    os.environ.get(
        "CONTEXT_MEDIATION_OUTPUT_ROOT",
        "/raid/lingo/dez/web/unitname/context-mediation",
    )
)

In [None]:
import json
from html import escape

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def load_results_by_layer(experiment_name, editor_type, split, results_root=None):
    """Load per-layer evaluation results for one experiment/editor/split.

    Reads ``<results_root>/<experiment_name>/<editor_type>/<layer>/<split>-eval.json``
    for every integer-named ``<layer>`` directory and keeps each file's
    top-level ``"results"`` entry.

    Args:
        experiment_name: Experiment directory name under the results root.
        editor_type: Editor subdirectory name (e.g. "linear" or "biaffine").
        split: Split name; selects ``<split>-eval.json`` inside each layer dir.
        results_root: Root results directory. Defaults to ``RESULTS_ROOT``.

    Returns:
        Dict mapping layer (int) -> that layer's ``"results"`` value, with keys
        inserted in ascending layer order. Entries that are not integer-named
        directories, and layers without an eval file for ``split``, are
        skipped, so the dict is empty when nothing is found.
    """
    if results_root is None:
        results_root = RESULTS_ROOT
    results_dir = Path(results_root) / experiment_name / editor_type

    # Only integer-named directories are layers; stray files would otherwise
    # make int() raise. Keep the real path so names like "07" still resolve.
    layer_dirs = sorted(
        (int(path.name), path)
        for path in results_dir.glob("*")
        if path.is_dir() and path.name.isdecimal()
    )

    results_by_layer = {}
    for layer, layer_dir in layer_dirs:
        eval_file = layer_dir / f"{split}-eval.json"
        if not eval_file.is_file():
            # Presumably this split has not been evaluated for this layer yet.
            continue
        with eval_file.open("r") as handle:
            results_by_layer[layer] = json.load(handle)["results"]
    return results_by_layer


def load_args(experiment_name, results_root=None):
    """Load the ``args.json`` saved alongside an experiment's results.

    Args:
        experiment_name: Experiment directory name under the results root.
        results_root: Root results directory. Defaults to ``RESULTS_ROOT``.

    Returns:
        The parsed JSON content of ``<results_root>/<experiment_name>/args.json``.

    Raises:
        FileNotFoundError: If the args file does not exist.
    """
    if results_root is None:
        results_root = RESULTS_ROOT
    args_file = Path(results_root) / experiment_name / "args.json"
    with args_file.open("r") as handle:
        return json.load(handle)


def heatmap(values_by_layer, title=None, save_to=None):
    """Draw a single-row heatmap with one annotated cell per layer.

    Args:
        values_by_layer: Mapping layer (int) -> scalar value.
        title: Optional figure title.
        save_to: Optional path; when given, the figure is saved there.

    Returns:
        The created matplotlib Figure, or None when ``values_by_layer`` is
        empty (there is nothing to draw).
    """
    if not values_by_layer:
        return None
    layers = sorted(values_by_layer)
    # One row, one column per layer (in ascending layer order).
    data = np.array([[values_by_layer[layer] for layer in layers]])

    fig, ax = plt.subplots(figsize=(15, 2))
    if title is not None:
        ax.set_title(title)
    # Label columns with the actual layer numbers; seaborn would otherwise use
    # 0..n-1, which is wrong whenever layers are not a contiguous 0-based run.
    sns.heatmap(data,
                annot=True,
                annot_kws=dict(fontsize="xx-small"),
                xticklabels=layers,
                ax=ax)
    ax.set_yticks([])
    ax.set_xlabel(f"{MODEL_NAME} layer")
    fig.tight_layout()
    if save_to is not None:
        fig.savefig(save_to)
    return fig


def plot_accuracy_heatmap(results_by_layer, **kwargs):
    """Plot, per layer, the fraction of samples whose edit favors the mediated target.

    A sample counts as correct when ``after_target_mediated_score`` is strictly
    greater than ``after_target_unmediated_score`` (ties count as incorrect).
    Layers with no results are skipped rather than dividing by zero.

    Args:
        results_by_layer: Mapping layer -> list of result dicts with the keys above.
        **kwargs: Forwarded to ``heatmap`` (e.g. ``title``, ``save_to``).
    """
    accuracy_by_layer = {
        layer: sum(
            result["after_target_mediated_score"] > result["after_target_unmediated_score"]
            for result in results
        ) / len(results)
        for layer, results in results_by_layer.items()
        # Like plot_total_effect_heatmap: an empty layer has no defined value.
        if results
    }
    heatmap(accuracy_by_layer, **kwargs)


def plot_total_effect_heatmap(results_by_layer, **kwargs):
    """Plot, per layer, the mean change in the mediated target's score from the edit.

    For every non-empty layer this averages
    ``after_target_mediated_score - before_target_mediated_score`` over that
    layer's results; layers with no results are left out of the plot.

    Args:
        results_by_layer: Mapping layer -> list of result dicts.
        **kwargs: Forwarded to ``heatmap`` (e.g. ``title``, ``save_to``).
    """
    mean_effect = {}
    for layer, layer_results in results_by_layer.items():
        if not layer_results:
            continue
        deltas = [
            r["after_target_mediated_score"] - r["before_target_mediated_score"]
            for r in layer_results
        ]
        mean_effect[layer] = sum(deltas) / len(layer_results)
    heatmap(mean_effect, **kwargs)
    

# Stylesheet inlined into the <style> element of every report page written by
# generate_experiment_report. Centers headings, images and tables. The last two
# rules render each table row as a floated vertical block, which (per the CSS
# comment) transposes the config table so rows read as columns.
REPORT_CSS = """\
img {
    margin-top: 20px;
    display: block;
    margin-left: auto;
    margin-right: auto;
}

h1, h2 {
    text-align: center;
}

h2 {
    font-size: 28px;
    font-weight: normal;
    text-decoration: underline;
    margin-top: 25px;
}

h3 {
    font-size: 20px;
}

table {
  font-size: 16px;
  border-collapse: collapse;
  margin-left: auto;
  margin-right: auto;
}

th, td {
  text-align: center;
  border-top: 1px solid #ccc;
  padding: 0 4px;
}

/* Tranpose the table... */
tr { display: block; float: left; }
th, td { display: block; }
"""
    

def _render_config_section(args):
    """Return HTML lines for a table listing each experiment arg and its value."""
    lines = [
        "<div>",
        "<h2>Config</h2>",
        "<table>",
        "<tr>",
        "<th>arg</th>",
        "<th>value</th>",
        "</tr>",
    ]
    for key, value in args.items():
        lines += [
            "<tr>",
            f"<td>{escape(str(key))}</td>",
            f"<td>{escape(str(value))}</td>",
            "</tr>",
        ]
    lines += ["</table>", "</div>"]
    return lines


def _render_plots_section(base_url, effect_file_rel, accuracy_file_rel):
    """Return HTML lines embedding the two heatmaps by URL under ``base_url``."""
    # as_posix() keeps "/" separators in the URL regardless of the host OS.
    effect_url = escape(f"{base_url}/{effect_file_rel.as_posix()}")
    accuracy_url = escape(f"{base_url}/{accuracy_file_rel.as_posix()}")
    return [
        "<div class='images'>",
        "<h2>Plots</h2>",
        f'<img alt="effect" src="{effect_url}">',
        f'<img alt="accuracy" src="{accuracy_url}">',
        "</div>",
    ]


def _render_samples_section(results_by_layer, max_samples=None):
    """Return HTML lines showing each sample's inputs and every layer's generations.

    Assumes each layer's results list is aligned by sample index (index i is the
    same sample in every layer) -- the original code relied on this as well.
    Only the first ``max_samples`` samples are rendered when it is not None.
    """
    layers = sorted(results_by_layer)
    # Transpose layer-major results into per-sample tuples ordered like `layers`,
    # so layer labels and results can never get out of step. zip stops at the
    # shortest layer if lengths ever differ, instead of raising IndexError.
    samples = list(zip(*(results_by_layer[layer] for layer in layers)))
    if max_samples is not None:
        samples = samples[:max_samples]

    input_keys = (
        "entity",
        "context",
        "attribute",
        "prompt",
        "target_mediated",
        "target_unmediated",
    )
    lines = ["<div>", "<h2>Samples</h2>"]
    for sample_num, per_layer in enumerate(samples):
        sample = per_layer[0]["sample"]
        lines += [
            "<div>",
            f"<h3>Sample {sample_num}</h3>",
            "<div>Inputs:</div>",
            "<ul>",
        ]
        # Model inputs/outputs are arbitrary text: escape so "<" or "&" cannot
        # break the page markup.
        lines += [
            f"<li><b>{key}: </b>{escape(str(sample[key]))}</li>"
            for key in input_keys
        ]
        lines += [
            "</ul>",
            "<div>Model generations:</div>",
            "<ul>",
            f"<li><b>original: </b>{escape(str(per_layer[0]['before_generations']))}</li>",
        ]
        lines += [
            f"<li>after edit <b>layer {layer}</b>: {escape(str(result['after_generations']))}</li>"
            for layer, result in zip(layers, per_layer)
        ]
        # Close the per-sample <div>; previously it was left open for every sample.
        lines += ["</ul>", "<hr>", "</div>"]
    lines.append("</div>")
    return lines


def generate_experiment_report(experiment_name,
                               editor_type,
                               splits=SPLITS,
                               base_url=BASE_URL,
                               output_dir=OUTPUT_ROOT,
                               max_train_samples=200):
    """Write heatmap PNGs and an HTML report for each split of one experiment.

    For every split that has results, saves ``<split>_accuracy.png``,
    ``<split>_total_effect.png`` and ``<split>_results.html`` under
    ``<output_dir>/<experiment_name>/<editor_type>/``. The HTML embeds the
    images as ``<base_url>/<path relative to output_dir>``, so ``base_url`` is
    expected to serve ``output_dir``.

    Args:
        experiment_name: Experiment directory name under the results root.
        editor_type: Editor subdirectory name (e.g. "linear" or "biaffine").
        splits: Splits to report on; splits without results are skipped with a message.
        base_url: URL prefix under which ``output_dir`` is served.
        output_dir: Root output directory (str or Path).
        max_train_samples: Maximum number of samples listed for the "train"
            split (None for no limit); other splits list every sample.
    """
    output_dir = Path(output_dir)
    args = load_args(experiment_name)
    subdir = Path(experiment_name, editor_type)
    for split in splits:
        results_by_layer = load_results_by_layer(experiment_name, editor_type, split)
        if len(results_by_layer) == 0:
            print(f"no results for {experiment_name}/{editor_type}/{split}")
            continue

        (output_dir / subdir).mkdir(exist_ok=True, parents=True)

        # Create visualizations.
        accuracy_heatmap_file_rel = subdir / f"{split}_accuracy.png"
        plot_accuracy_heatmap(results_by_layer,
                              title="Accuracy of P(Mediated) > P(Unmediated)",
                              save_to=output_dir / accuracy_heatmap_file_rel)

        effect_heatmap_file_rel = subdir / f"{split}_total_effect.png"
        plot_total_effect_heatmap(results_by_layer,
                                  title="Average Effect, Log[P_after(mediated)] - Log[P_before(mediated)]",
                                  save_to=output_dir / effect_heatmap_file_rel)

        # Make HTML page with all the experiment info.
        html = [
            "<!doctype html>",
            "<html>",
            "<head>",
            "<style>",
            REPORT_CSS,
            "</style>",
            "</head>",
            "<body>",
            f"<h1>Results: {escape(experiment_name)}/{escape(editor_type)} ({escape(split)} set)</h1>",
        ]
        html += _render_config_section(args)
        html += _render_plots_section(base_url,
                                      effect_heatmap_file_rel,
                                      accuracy_heatmap_file_rel)
        sample_limit = max_train_samples if split == "train" else None
        html += _render_samples_section(results_by_layer, sample_limit)
        html += ["</body>", "</html>"]

        html_file = output_dir / subdir / f"{split}_results.html"
        print(f"saving report to {html_file}")
        with html_file.open("w") as handle:
            handle.write("\n".join(html))

In [None]:
# Build a report for every configured (experiment, editor) combination.
for experiment_name, editor_type in (
    (name, editor) for name in EXPERIMENT_NAMES for editor in EDITOR_TYPES
):
    generate_experiment_report(
        experiment_name,
        editor_type,
        splits=SPLITS,
        base_url=BASE_URL,
        output_dir=OUTPUT_ROOT,
    )