In [None]:
import os
import re
import glob
import json
import pandas as pd
import numpy as np
from collections import defaultdict

import torch
import torch.nn as nn
import tqdm.notebook as tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    Qwen2ForCausalLM,
    Phi3ForCausalLM,
)

In [None]:
# Load one checkpoint; the commented alternatives are the other models evaluated.
model = AutoModelForCausalLM.from_pretrained(
    (model_name := "meta-llama/Llama-2-7b-chat-hf"),
    # (model_name := "Qwen/Qwen1.5-7B-Chat"),
    # (model_name := "microsoft/Phi-3-mini-4k-instruct"),
    device_map="cuda",
    torch_dtype="bfloat16",
)
# Inference only: eval mode and frozen parameters.
model = model.eval().requires_grad_(False)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse EOS as the pad token so candidate batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# The scoring cells sum per-token losses under the attention mask and assume every
# sequence starts at position 0. That only holds with right padding, and some chat
# checkpoints default to left padding, so set it explicitly.
tokenizer.padding_side = "right"

In [None]:
# Attention heads to ablate: comma-separated "a<layer>h<head>" entries (0-based
# indices), or "none" for the unmodified model. Presets per model are listed below.
###### Llama-2-7b-chat-hf ######
# TEMPORAL_HEAD_MASKING = "none"
# TEMPORAL_HEAD_MASKING = "a18h3"
# TEMPORAL_HEAD_MASKING = "a15h0"
TEMPORAL_HEAD_MASKING = "a18h3,a15h0"

######## Qwen1.5-7B-Chat #######
# TEMPORAL_HEAD_MASKING = "none"
# TEMPORAL_HEAD_MASKING = "a17h15"

#### Phi-3-mini-4k-instruct ####
# TEMPORAL_HEAD_MASKING = "none"
# TEMPORAL_HEAD_MASKING = "a10h13"

# The ablation edits the loaded weights in place. Switching back to "none" (or to a
# different set of heads) therefore requires reloading the model first.
if TEMPORAL_HEAD_MASKING != "none":
    if not isinstance(model, (LlamaForCausalLM, Qwen2ForCausalLM, Phi3ForCausalLM)):
        # Fail loudly. Silently skipping would save unmasked results under a masked label.
        raise TypeError(f"Head masking is not implemented for {type(model).__name__}")
    num_layers = len(model.model.layers)
    num_heads = model.config.num_attention_heads
    for masking in TEMPORAL_HEAD_MASKING.split(","):
        match = re.fullmatch(r"a(\d+)h(\d+)", masking.strip())
        if match is None:
            raise ValueError(f"Invalid head spec {masking!r}; expected 'a<layer>h<head>'")
        layer, head = map(int, match.groups())
        if layer >= num_layers or head >= num_heads:
            raise ValueError(
                f"{masking!r} is out of range for a model with "
                f"{num_layers} layers and {num_heads} heads per layer"
            )
        print(layer, head)
        # nn.Linear stores its weight as (out_features, in_features). This assumes
        # o_proj's input features are the per-head outputs concatenated head by head,
        # as in the HF Llama / Qwen2 / Phi-3 attention. Viewing them as
        # (num_heads, head_dim) and zeroing one slice then removes that head's
        # contribution. unflatten returns a view, so the write lands in the weight.
        o_proj = model.model.layers[layer].self_attn.o_proj.weight.data
        o_proj.unflatten(-1, (num_heads, -1))[:, head, :] = 0

## Score candidates on the Temporal datasets

In [None]:
# For every (subject, time) sample and every prompt template, score each candidate
# object of that subject as log p(prompt + candidate) - log p(prompt), i.e.
# log p(candidate | prompt). `probs` stores the softmax of these scores over the
# subject's candidates. Results are written next to the inputs under ./logprob/outputs.
for filename in glob.glob("./data/Temporal/*.json"):
    # The masked loss sum below only counts real tokens when padding is on the right.
    assert tokenizer.padding_side == "right", (
        "candidate scoring requires tokenizer.padding_side == 'right'"
    )
    with open(filename) as fp:
        dataset = json.load(fp)

    # subject -> list of candidate objects for that subject
    candidates = {x["subject"]: x["objects"] for x in dataset["all_possible_objects"]}
    dataset["predictions"] = {}

    for key in ("prompt_templates", "prompt_templates_zs"):
        # "_zs" templates are wrapped in the model's chat template; the rest are plain text.
        use_chat = key.endswith("_zs")
        for i, prompt in enumerate(dataset[key]):
            predictions = []
            dataset["predictions"][f"{key}:{i}"] = predictions
            for sample in tqdm.tqdm(dataset["samples"], desc=f"{key}:{i}"):
                question = prompt.format(time=sample["time"], subject=sample["subject"])
                options = candidates[sample["subject"]]
                if use_chat:
                    texts = [
                        tokenizer.apply_chat_template(
                            [
                                {"role": "user", "content": question},
                                {"role": "assistant", "content": "The answer is: " + x},
                            ],
                            tokenize=False,
                        )
                        for x in options
                    ]
                    texts.append(
                        tokenizer.apply_chat_template(
                            [{"role": "user", "content": question}],
                            tokenize=False,
                            add_generation_prompt=True,
                        )
                    )
                else:
                    texts = [question + " " + x for x in options]
                    texts.append(question)

                # Chat-templated strings already contain the template's special tokens
                # (e.g. BOS), so the tokenizer must not add them a second time.
                encodings = tokenizer(
                    texts,
                    padding=True,
                    return_tensors="pt",
                    add_special_tokens=not use_chat,
                ).to(model.device)
                with torch.inference_mode():
                    logits = model(**encodings).logits
                    # Per-position NLL of the next token; padded targets are zeroed by the mask.
                    loss = nn.functional.cross_entropy(
                        logits[:, :-1].permute(0, 2, 1).float(),
                        encodings.input_ids[:, 1:],
                        reduction="none",
                    )
                    logprobs = -(loss * encodings.attention_mask[:, 1:]).sum(-1)
                    # The last row is the bare prompt; subtracting it leaves log p(candidate | prompt).
                    logprobs = logprobs[:-1] - logprobs[-1]
                    probs = logprobs.softmax(0).tolist()

                predictions.append({**sample, "probs": dict(zip(options, probs))})
                del encodings, logits, loss, logprobs
                torch.cuda.empty_cache()

    out_path = filename.replace(
        "./data/",
        f"./logprob/outputs/{os.path.basename(model_name)}/{TEMPORAL_HEAD_MASKING}/",
    )
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as fp:
        json.dump(dataset, fp, indent=4)

## Score candidates on the Invariant datasets

In [None]:
# For every subject sample and every prompt template, score each candidate object as
# log p(prompt + candidate) - log p(prompt), i.e. log p(candidate | prompt). Unlike the
# Temporal files, the candidate pool is shared by all samples: every distinct gold
# `object` in the file. `probs` stores the softmax of the scores over that pool.
for filename in glob.glob("./data/Invariant/*.json"):
    # The masked loss sum below only counts real tokens when padding is on the right.
    assert tokenizer.padding_side == "right", (
        "candidate scoring requires tokenizer.padding_side == 'right'"
    )
    with open(filename) as fp:
        dataset = json.load(fp)

    candidates = sorted({x["object"] for x in dataset["samples"]})
    dataset["predictions"] = {}

    for key in ("prompt_templates", "prompt_templates_zs"):
        # "_zs" templates are wrapped in the model's chat template; the rest are plain text.
        use_chat = key.endswith("_zs")
        for i, prompt in enumerate(dataset[key]):
            predictions = []
            dataset["predictions"][f"{key}:{i}"] = predictions
            for sample in tqdm.tqdm(dataset["samples"], desc=f"{key}:{i}"):
                question = prompt.format(subject=sample["subject"])
                if use_chat:
                    texts = [
                        tokenizer.apply_chat_template(
                            [
                                {"role": "user", "content": question},
                                {"role": "assistant", "content": "The answer is: " + x},
                            ],
                            tokenize=False,
                        )
                        for x in candidates
                    ]
                    texts.append(
                        tokenizer.apply_chat_template(
                            [{"role": "user", "content": question}],
                            tokenize=False,
                            add_generation_prompt=True,
                        )
                    )
                else:
                    texts = [question + " " + x for x in candidates]
                    texts.append(question)

                # Chat-templated strings already contain the template's special tokens
                # (e.g. BOS), so the tokenizer must not add them a second time.
                encodings = tokenizer(
                    texts,
                    padding=True,
                    return_tensors="pt",
                    add_special_tokens=not use_chat,
                ).to(model.device)
                with torch.inference_mode():
                    logits = model(**encodings).logits
                    # Per-position NLL of the next token; padded targets are zeroed by the mask.
                    loss = nn.functional.cross_entropy(
                        logits[:, :-1].permute(0, 2, 1).float(),
                        encodings.input_ids[:, 1:],
                        reduction="none",
                    )
                    logprobs = -(loss * encodings.attention_mask[:, 1:]).sum(-1)
                    # The last row is the bare prompt; subtracting it leaves log p(candidate | prompt).
                    logprobs = logprobs[:-1] - logprobs[-1]
                    probs = logprobs.softmax(0).tolist()

                predictions.append({**sample, "probs": dict(zip(candidates, probs))})
                del encodings, logits, loss, logprobs
                torch.cuda.empty_cache()

    out_path = filename.replace(
        "./data/",
        f"./logprob/outputs/{os.path.basename(model_name)}/{TEMPORAL_HEAD_MASKING}/",
    )
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as fp:
        json.dump(dataset, fp, indent=4)

## Aggregate scores into a results table

In [None]:
# Gold `object` strings that are spelled differently from the candidate key they were
# scored under.
OBJECT_ALIASES = {"Phil Condit": "Philip Condit"}
# Root of the prediction files written by the scoring cells.
OUTPUT_ROOT = "./logprob/outputs"
# Which prompt template's predictions enter the table.
SCORED_TEMPLATE = "prompt_templates_zs:0"


def collect_scores(output_root=OUTPUT_ROOT, template_key=SCORED_TEMPLATE, aliases=OBJECT_ALIASES):
    """Summarize every saved prediction file into one result row.

    Files are expected at
    ``<output_root>/<model>/<head masking>/<dataset group>/<task>.json``.

    Args:
        output_root: directory searched recursively for ``*.json`` prediction files.
        template_key: which entry of each file's ``predictions`` mapping to score.
        aliases: maps a sample's gold ``object`` to the candidate key it was scored
            under; objects not listed are looked up unchanged.

    Returns:
        A list of dicts with keys ``model``, ``head``, ``temporal``, ``task`` and
        ``score`` (mean probability assigned to the gold object). Rows are ordered
        by file path.

    Raises:
        ValueError: if a JSON file is not exactly four path components below
            ``output_root``.
    """
    rows = []
    pattern = os.path.join(output_root, "**", "*.json")
    for path in sorted(glob.glob(pattern, recursive=True)):
        # Parse relative to the root: splitting the full path on "/" yields extra
        # components ("." / "logprob" / "outputs") and breaks the unpacking.
        parts = os.path.relpath(path, output_root).split(os.sep)
        if len(parts) != 4:
            raise ValueError(
                f"Unexpected prediction file location {path!r}; expected "
                "<model>/<head>/<dataset group>/<task>.json below the output root"
            )
        model_dir, head, temporal, task = parts
        with open(path) as fp:
            preds = json.load(fp)["predictions"][template_key]
        score = float(
            np.mean([x["probs"][aliases.get(x["object"], x["object"])] for x in preds])
        )
        rows.append(
            {
                "model": model_dir,
                "head": head,
                "temporal": temporal,
                "task": task,
                "score": score,
            }
        )
    return rows


# Uses a dedicated function so the loaded `model` global is never overwritten here.
results = collect_scores()

In [None]:
# Render the per-file scores as a table.
results_table = pd.DataFrame.from_records(results)
results_table

In [None]:
# Export the per-file scores to a spreadsheet (pandas needs an Excel writer engine, e.g. openpyxl).
pd.DataFrame.from_records(results).to_excel("./logprob/results.xlsx", index=False)