In [None]:
import sys
# Add the parent directory to the import path; later cells do `from src import ...`.
sys.path.append("..")

In [None]:
import transformers

# Hugging Face hub id; the same id is used for the model and its tokenizer.
config = "EleutherAI/gpt-j-6B"
# NOTE(review): hard-coded to the CUDA device with index 1 — adjust per machine.
device = "cuda:1"
# NOTE(review): this requests the "float16" revision but passes no torch_dtype;
# confirm which dtype the weights actually end up in on `device`.
model = transformers.AutoModelForCausalLM.from_pretrained(
    config,
    low_cpu_mem_usage=True,
    revision="float16").to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained(config)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Model Confidence?

In [None]:
import torch

# Alternative (subject, prompt) probes, kept commented out. Uncomment one pair
# (and comment out the active pair below) to try it.

# subject = "surgeon"
# prompt = "farmer: barn\ncar mechanic: garage\nchef: kitchen\nteacher: school\n{}:"

# subject = "Saudi Arabia"
# delim = " shares its northern border with"
# prompt = f"USA{delim} Canada\nMexico{delim} USA\nSudan{delim} Egypt\n" + "{}"+ delim

# subject = "The actor Neil Patrick Harris"
# prompt = "{} is married to a man named"


# subject = "Gengar"
# prompt = "Pikachu: electric\nSquirtle: water\nCharizard: fire\nShroomish: grass\n{}:"

# Active probe. `subject` and `prompt` are reused by the operator-estimation
# cell further down, so changing them here changes what gets estimated there.
subject = "Bagon"
prompt = "Pikachu: Raichu\nCharmander: Charmeleon\nShroomish: Breloom\n{}:"

# Fill the `{}` slot with the subject and run one forward pass.
inputs = tokenizer(prompt.format(subject), return_tensors="pt").to(device)
with torch.inference_mode():
    outputs = model(**inputs)
# Next-token distribution at the last position, softmaxed in float32; keep the top 5.
topk = torch.softmax(outputs.logits[:, -1].float(), dim=-1).topk(dim=-1, k=5)
words = [tokenizer.decode(token_id) for token_id in topk.indices.squeeze()]
probs = topk.values.squeeze().tolist()

# NOTE(review): this prints the template with its `{}` placeholder, not the
# formatted prompt that was fed to the model.
print(prompt)
for word, prob in zip(words, probs):
    print(f"{word} ({prob:.2f})")

In [None]:
from src import corner, estimate
import importlib
# Re-execute the project modules so edits to them are picked up without
# restarting the kernel.
importlib.reload(corner)
importlib.reload(estimate)

# Estimate an operator from the `subject`/`prompt` pair left active by the
# probing cell above. The second return value is discarded.
operator, _ = estimate.relation_operator_from_sample(
    model,
    tokenizer,
    subject,
    prompt,
    device=device,
)

In [None]:
def logit_lens(h, k=10):
    """Decode a hidden state through the model's final layer norm and LM head.

    Args:
        h: Tensor with exactly `model.config.hidden_size` elements; it is
            viewed as a single row of shape (1, hidden_size).
        k: Number of highest-probability tokens to return.

    Returns:
        A tuple of `(decoded_token, probability)` pairs, most probable first.
    """
    h = h.view(1, model.config.hidden_size)
    # Only Python strings and floats leave this function, so no autograd graph
    # is needed. The softmax is taken in float32, matching the top-k probing
    # code earlier in the notebook.
    with torch.no_grad():
        logits = model.lm_head(model.transformer.ln_f(h))
        dist = torch.softmax(logits.float(), dim=-1)
    topk = dist.topk(dim=-1, k=k)
    # Index the single batch row instead of calling .squeeze(): squeeze() also
    # removes the k axis when k == 1, which leaves a 0-d tensor that cannot be
    # iterated and a bare float that cannot be zipped.
    words = [tokenizer.decode(token_id) for token_id in topk.indices[0]]
    probs = topk.values[0].tolist()
    return tuple(zip(words, probs))

# Show which tokens the estimated operator's bias decodes to.
logit_lens(operator.bias)

In [None]:
# Apply the estimated operator to a new subject.
# NOTE(review): `operator` was estimated from whichever `prompt` the probing
# cell left active (currently the Pokemon one) — confirm this subject is intended.
operator("blueberries", device=device)

In [None]:
# Estimate a "corner" vector from a list of colour words (the same list appears
# as the `rng` of the "has the color of" setting further down).
corner_estimator = corner.CornerEstimator(model, tokenizer)
c = corner_estimator.estimate_simple_corner([
    "black",
    "white",
    "brown",
    "green",
    "blue",
    "orange",
    "yellow",
    "purple",
    "red",
    "pink",
    "grey",
])
# Build a variant of the estimated operator whose bias is the corner vector.
corner_operator = operator.overwrite(bias=c)

In [None]:
# Disabled variant: additionally replaces the weight with a 4096x4096 identity.
# corner_operator = operator.overwrite(bias=c, weight=torch.eye(4096).to(device))
# Query the corner-biased operator on a single subject.
corner_operator("sweet potatoes", device=device)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Heatmap of the top-left 100x100 block of the operator weight, with the colour
# scale clipped to [-0.02, 0.02].
sns.heatmap(corner_operator.weight.data.cpu()[:100, :100], vmin=-.02, vmax=.02)

# Vignette

In [None]:
import json
from itertools import chain
from pathlib import Path
from typing import NamedTuple

from src import corner, estimate

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import importlib
# Re-execute the project modules so edits to them are picked up without
# restarting the kernel.
importlib.reload(corner)
importlib.reload(estimate)

# Root directory for result files. NOTE(review): the settings cell further
# down assigns RESULTS_DIR again before running the sweep.
RESULTS_DIR = Path("vignette_results/")


class Setting(NamedTuple):
    """One relation to evaluate, with its train/test pairs and options.

    `evaluate` unpacks instances positionally, so the field order is part of
    the interface.
    """

    # Relation phrase. `evaluate` appends it after the "{} " slot on the final
    # prompt line and uses it (space-padded) as the default separator.
    relation: str
    # (subject, object) pairs. `evaluate` renders all but the last as few-shot
    # example lines and estimates the operator from the last pair's subject.
    train: tuple[tuple[str, str], ...]
    # (subject, object) pairs scored by `test`.
    test: tuple[tuple[str, str], ...]
    # Words used to estimate the corner vector. When falsy, `evaluate` uses
    # the objects of every train and test pair instead.
    rng: tuple[str, ...] | None = None
    # A sample is correct if a match is among the first `k` predictions.
    k: int = 2
    # Separator placed between subject and object in the example lines;
    # None means f" {relation} ".
    prompt: str | None = None

        
# Notebook-wide corner estimator; `evaluate` reads this global.
corner_estimator = corner.CornerEstimator(model, tokenizer)


def pprint_predictions(subject, predictions):
    """Print a subject header followed by one `word (prob)` line per prediction.

    A word or probability supplied as a list is represented by its first element.
    """
    def _head(value):
        # Unwrap list-valued entries down to their first element.
        return value[0] if isinstance(value, list) else value

    print(f"* {subject}")
    for entry in predictions:
        token, score = entry
        print(f"{_head(token)} ({_head(score):.2f})")


def test(operator, samples, k=2, pprint=False):
    total_correct = 0
    sample_summaries = []
    for subject, expected in samples:
        predictions = operator(subject, device=device)
        predictions = [
            (w, p)
            for w, p in predictions
            if w.strip()
            and (
                len(w.strip()) <= 2
                or (
                    w.strip().lower() != subject.strip().lower()
                    and not subject.strip().lower().startswith(w.strip().lower())
                )
            )
        ]
        if pprint:
            pprint_predictions(subject, predictions)
        is_correct = any(expected.lower().startswith(x[0].strip().lower()) for x in predictions[:k])
        
        sample_summaries.append({
            "subject": subject,
            "predictions": predictions,
            "expected": expected,
            "is_correct": is_correct,
        })
        
        total_correct += int(is_correct)

    accuracy = total_correct / len(samples)
    return accuracy, {
        "k": k,
        "accuracy": accuracy,
        "samples": sample_summaries,
    }


def evaluate(settings, layer=12, plot=False, results_dir=None):
    """Estimate and score relation operators for every setting at one layer.

    For each setting this builds a few-shot prompt from all but the last
    training pair, estimates an operator from the last training subject,
    estimates a corner vector from the setting's `rng` words (or, when none
    are given, from every object in the train and test pairs), and scores
    three operator variants on the test pairs with `test`:

    * "jb": the estimated operator unchanged;
    * "ic": identity weight with the corner as bias;
    * "jc": the estimated weight with the corner as bias.

    Args:
        settings: Iterable of 6-tuples `(relation, train, test, rng, k,
            prompt)`, e.g. `Setting` instances.
        layer: Layer index forwarded to
            `estimate.relation_operator_from_sample`.
        plot: When true, save a heatmap of the top-left 25x25 block of the
            operator weight as `<relation with underscores>.png`.
        results_dir: Optional directory (str or Path). Output goes into a
            subdirectory named after `layer`, unless the path already ends in
            it. When None, no summaries file is written and plots are saved
            relative to the current working directory.

    Returns:
        A list with one dict per setting, holding "relation", "c_logit_lens"
        and the `test` summaries under "jb", "ic" and "jc".
    """
    if results_dir is not None:
        # Convert to Path first, so a plain string works even when it already
        # ends in the layer name and is therefore not re-wrapped below.
        results_dir = Path(results_dir)
        if results_dir.name != str(layer):
            results_dir = results_dir / str(layer)
        results_dir.mkdir(exist_ok=True, parents=True)

    summaries = []
    for relation, trains, tests, codomain, k, prompt in settings:
        print(f"---- {relation} ----")

        # Example lines use `prompt` as the separator; the final line always
        # uses the relation phrase after the subject slot.
        prompt = prompt if prompt is not None else f" {relation} "
        prompt_template = "\n".join(f"{subj}{prompt}{obj}" for subj, obj in trains[:-1])
        prompt_template += "\n" + "{} " + relation
        train_subject = trains[-1][0]

        print("estimating J...")
        operator, _ = estimate.relation_operator_from_sample(
            model=model,
            tokenizer=tokenizer,
            subject=train_subject,
            relation=prompt_template,
            device=device,
            layer=layer,
        )

        print("estimating corner...")
        if codomain:
            c = corner_estimator.estimate_average_corner_with_gradient_descent(codomain)
        else:
            c = corner_estimator.estimate_average_corner_with_gradient_descent([
                x[-1] for x in chain(trains, tests)
            ])

        # The lens output is not printed here; it is stored in the summary.
        print("logit lens on corner:")
        lens = logit_lens(c)

        print("\nJ, bias")
        acc, jb_summary = test(operator, tests, k=k)
        print(f"layer={layer} {k}-acc={acc:.2f}")

        print("\nI, corner")
        acc, ic_summary = test(operator.overwrite(weight=torch.eye(model.config.hidden_size).to(device), bias=c), tests, k=k)
        print(f"layer={layer} {k}-acc={acc:.2f}")

        print("\nJ, corner")
        acc, jc_summary = test(operator.overwrite(bias=c), tests, k=k)
        print(f"layer={layer} {k}-acc={acc:.2f}")

        if plot:
            fig, ax = plt.subplots()
            sns.heatmap(operator.weight.data.cpu()[:25, :25], cmap="PiYG", ax=ax)

            plot_file = Path(relation.strip().replace(" ", "_") + ".png")
            if results_dir is not None:
                plot_file = results_dir / plot_file
            fig.savefig(str(plot_file))
            # Close once saved: this runs for every setting at every layer,
            # and open figures otherwise pile up for the whole sweep. The
            # figure is on disk, so it is no longer shown inline.
            plt.close(fig)

        summary = {
            "relation": relation,
            "c_logit_lens": lens,
            "jb": jb_summary,
            "ic": ic_summary,
            "jc": jc_summary,
        }
        summaries.append(summary)

    if results_dir is not None:
        with (results_dir / "summaries.json").open("w") as handle:
            json.dump({
                "summaries": summaries,
            }, handle)

    return summaries

In [None]:
# Upper-case Latin alphabet, used to enumerate two-letter codes in SETTINGS.
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"

# (subject, partner) pairs for the "is partnered to" relation.
# SETTINGS takes the first three entries as the train split and the rest as
# the test split, so every pair must appear only once: a repeated pair would
# put a training example into the test set.
SPOUSE_PAIRS = (
    ("Beyonce", "Jay-Z"),
    ("George Bush", "Laura Bush"),
    ("Ariana Grande", "Pete Davidson"),

    ("Barack Obama", "Michelle Obama"),
    ("Michelle Obama", "Barack Obama"),
    ("Jay-Z", "Beyonce"),
    ("John Lennon", "Yoko Ono"),
    ("Yoko Ono", "John Lennon"),
    ("Forrest Gump", "Jenny"),
    ("Laura Bush", "George Bush"),
    ("Marie Curie", "Pierre Curie"),
    ("Pierre Curie", "Marie Curie"),
    ("Mark Antony", "Cleopatra"),
    ("Cleopatra", "Mark Antony"),
    ("Ashton Kutcher", "Mila Kunis"),
    ("Mila Kunis", "Ashton Kutcher"),
)

# (country, neighbour) pairs. SETTINGS uses them as-is for "shares its northern
# border with" and with each pair swapped for "shares its southern border with".
# The first three entries are the train split; the rest are the test split.
COUNTRY_PAIRS = (
    ("USA", "Canada"),
    ("Sudan", "Egypt"),
    ("Ukraine", "Belarus"),

    ("Mexico", "USA"),
    ("Saudi Arabia", "Jordan"),
    ("Spain", "France"),
    ("Syria", "Turkey"),
    ("Jordan", "Syria"),
    ("Ecuador", "Colombia"),
    ("France", "Belgium"),
)

# (president, the president who followed) pairs. SETTINGS uses them as-is for
# one relation and with each pair swapped for the reverse relation.
# The first three entries are the train split; the rest are the test split.
PRESIDENT_PAIRS = (
    ("Barack Obama", "Donald Trump"),
    ("George Washington", "John Adams"),
    ("Abraham Lincoln", "Andrew Johnson"),

    ("John Tyler", "James Polk"),
    ("Warren Harding", "Calvin Coolidge"),
    ("Calvin Coolidge", "Herbert Hoover"),
    ("Herbert Hoover", "Franklin Roosevelt"),
    ("James Carter", "Ronald Reagan"),
    ("Harry Truman", "Dwight Eisenhower"),
    ("Teddy Roosevelt", "William Howard Taft"),
)


# Every relation evaluated by the layer sweep. Each relation string is placed
# verbatim into the model prompt by `evaluate`, so its spelling matters.
# Reversed relations are built as tuples to match the `Setting` annotations.
SETTINGS = (
    Setting(
        relation="is partnered to",
        train=SPOUSE_PAIRS[:3],
        test=SPOUSE_PAIRS[3:],
    ),

    Setting(
        relation="is the opposite of",
        train=(
            ("dark", "light"),
            ("good", "evil"),
            ("up", "down"),
        ),
        test=(
            ("left", "right"),
            ("right", "left"),
            ("down", "up"),
            ("evil", "good"),
            ("light", "dark"),
            ("open", "closed"),
        ),
        rng=(
            "light", "evil", "down"
        ),
    ),

    Setting(
        # Spelling corrected from "preceeded": this text is fed to the model.
        relation="preceded the President",
        train=PRESIDENT_PAIRS[:3],
        test=PRESIDENT_PAIRS[3:],
        prompt=": ",
    ),

    Setting(
        relation="succeeded the President",
        train=tuple((y, x) for x, y in PRESIDENT_PAIRS[:3]),
        test=tuple((y, x) for x, y in PRESIDENT_PAIRS[3:]),
        prompt=": ",
    ),

    Setting(
        relation="shares its northern border with",
        train=COUNTRY_PAIRS[:3],
        test=COUNTRY_PAIRS[3:],
    ),

    Setting(
        relation="shares its southern border with",
        train=tuple((y, x) for x, y in COUNTRY_PAIRS[:3]),
        test=tuple((y, x) for x, y in COUNTRY_PAIRS[3:]),
    ),

    Setting(
        relation="has the color of",
        train=(
            ("bananas", "yellow"),
            ("blueberries", "blue"),
            ("kiwis", "green"),
        ),
        test=(
            ("broccoli", "green"),
            ("tangerines", "orange"),
            ("apples", "red"),
            ("sweet potatoes", "orange"),
            ("carrots", "orange"),
            ("milk", "white"),
            ("cauliflower", "white"),
            ("kale", "green"),
            ("chocolate", "brown"),
            ("water", "blue"),
            ("plum", "purple"),
        ),
        rng=(
            "black",
            "white",
            "brown",
            "green",
            "blue",
            "orange",
            "yellow",
            "purple",
            "red",
            "pink",
            "grey",
        ),
        prompt=": ",
    ),

    Setting(
        relation="evolves into",
        # NOTE(review): "The pokemon Charmander" is a train subject (object
        # "Charizard") and also a test subject (object "Charmeleon"). Decide
        # which label is intended and drop the other.
        train=(
            ("The pokemon Pikachu", "Raichu"),
            ("The pokemon Shroomish", "Breloom"),
            ("The pokemon Charmander", "Charizard"),
            ("The pokemon Munchlax", "Snorlax"),
        ),
        test=(
            ("The pokemon Squirtle", "Blastoise"),
            ("The pokemon Mudkip", "Swampert"),
            ("The pokemon Grimer", "Muk"),
            ("The pokemon Abra", "Alakazam"),
            ("The pokemon Bulbasaur", "Venusaur"),
            ("The pokemon Geodude", "Golem"),
            ("The pokemon Dratini", "Dragonite"),
            ("The pokemon Pichu", "Raichu"),
            ("The pokemon Charmander", "Charmeleon"),
        )
    ),

    Setting(
        relation="is abbreviated as",
        train=(
            ("Connecticut", "CT"),
            ("Oregon", "OR"),
            ("Colorado", "CO"),
        ),
        test=(
            ("New York", "NY"),
            ("Illinois", "IL"),
            ("Massachusetts", "MA"),
            ("Utah", "UT"),
            ("Nevada", "NV"),
            ("Washington", "WA"),
            ("Wisconsin", "WI"),
            ("Maryland", "MD"),
            ("Alabama", "AL")
        ),
        # Every two-letter upper-case combination.
        rng=tuple(
            f"{a}{b}"
            for a in LETTERS
            for b in LETTERS
        ),
    ),
)

# Root output directory; `evaluate` writes into one subdirectory per layer.
RESULTS_DIR = Path("./vignette_results")
# NOTE(review): 28 is hard-coded — presumably the number of transformer layers
# in the loaded model; confirm against the model config.
LAYERS = list(range(28))

# Run every setting at every layer, saving plots and JSON summaries as it goes.
# The in-memory results are kept for the HTML report cell below.
summaries_by_layer = {}
for layer in LAYERS:
    summaries_by_layer[layer] = evaluate(SETTINGS, layer=layer, plot=True, results_dir=RESULTS_DIR)

In [None]:
from pathlib import Path

# Display names for the three operator variants stored in each summary.
METHOD_TO_PRETTY = {
    "jb": "J, bias",
    "ic": "Corner",
    "jc": "J, corner",
}


def generate_html(summaries, layer):
    """Render one layer's evaluation summaries to `viz.html`.

    The page holds, per relation, a small accuracy table followed by one
    table of per-sample predictions for each method. Predictions that the
    expected object starts with are highlighted in blue.

    Args:
        summaries: List of summary dicts as returned by `evaluate`; each has
            a "relation" key and, under "jb", "ic" and "jc", a dict with
            "accuracy" and "samples" (and optionally "k").
        layer: Layer index, used in the page title and as the name of the
            output subdirectory.

    Writes:
        `RESULTS_DIR / str(layer) / "viz.html"`, creating the directory if it
        does not exist yet.
    """
    # Function-scope import: the original code used `html` as a local variable
    # name, and this keeps the cell free of new top-level imports.
    from html import escape

    num_prediction_columns = 5

    parts = [
        "<html>",
        # The file is written as UTF-8 below; declare it so browsers agree.
        "<meta charset='utf-8'>",
        "<style>",
    """
    th {
        font-weight: bold;
    }

    table {
        text-align: left;
        border-collapse: collapse;
    }

    th {
        border-top: 2px solid black;
        border-bottom: 1px solid black;
    }

    .qualitative th {
        padding-right: 5em;
    }

    .quantitative th {
        padding-right: 2em;
    }

    tr:last-of-type {
        border-bottom: 2px solid black;
    }

    h2 {
        margin-top: 2em;
    }

    h4 {
        font-weight: normal;
        text-decoration: underline;
    }

    """,
        "</style>",
        "<body>",
        f"<h1>Results for GPT-J/layer {escape(str(layer))}</h1>",
    ]

    for summary in summaries:
        # Report the k that was actually evaluated; 2 is the historic default.
        recall_k = summary["jb"].get("k", 2)
        parts += [
            f"<h2>{escape(str(summary['relation']))}</h2>",
            "<table class='quantitative'>",
            "<thead>",
            "<th>Method</th>",
            f"<th>Recall@{recall_k}</th>",
            "</thead>",
            "<tbody>",
            *[
                f"<tr><td>{METHOD_TO_PRETTY[method]}</td><td>{summary[method]['accuracy']:.2f}</td></tr>"
                for method in ("jb", "ic", "jc")
            ],
            "</tbody>",
            "</table>",
        ]

        for method in ("jb", "ic", "jc"):
            parts += [
                f"<h4>{METHOD_TO_PRETTY[method]} Outputs</h4>",
                "<table class='qualitative'>",
                "<thead>",
                "<tr>",
                "<th>subject</th>",
                "<th>object</th>",
                *[f"<th>prediction {i}</th>" for i in range(1, num_prediction_columns + 1)],
                "</tr>",
                "</thead>",
                "<tbody>",
            ]
            for sample in summary[method]["samples"]:
                expected = sample['expected']
                expected_key = expected.lower().strip()
                parts += [
                    "<tr>",
                    f"<td>{escape(str(sample['subject']))}</td>",
                    f"<td>{escape(str(expected))}</td>",
                ]
                for word, prob in sample["predictions"]:
                    token = word.lower().strip()
                    # An empty token is a prefix of everything; never count it.
                    is_correct = bool(token) and expected_key.startswith(token)

                    # Predicted tokens are model output and may contain HTML
                    # metacharacters, so escape them before embedding.
                    word_html = f"{escape(str(word))} ({prob:.2f})"
                    if is_correct:
                        word_html = f"<span style='color: blue'>{word_html}</span>"
                    parts += [f"<td>{word_html}</td>"]

                # Pad short rows so every row spans the full header width.
                for _ in range(num_prediction_columns - len(sample["predictions"])):
                    parts += ["<td></td>"]

                parts += ["</tr>"]
            parts += ["</tbody>", "</table>"]

    parts += ["</body>", "</html>"]

    out_dir = RESULTS_DIR / str(layer)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: decoded tokens are not guaranteed to be ASCII.
    with (out_dir / "viz.html").open("w", encoding="utf-8") as handle:
        handle.write("\n".join(parts))

# Write one viz.html per evaluated layer from the in-memory sweep results.
for layer, summaries in summaries_by_layer.items():
    generate_html(summaries, layer)