In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
# Add the parent directory to the import path; presumably this is where the
# `src` package imported by later cells lives (depends on the kernel's cwd).
sys.path.append('..')

In [None]:
from src import models

# Device string passed to the model loader and reused by later cells.
device = "cuda:0"
# `mt` is accessed below through `mt.model` and `mt.tokenizer`.
mt = models.load_model("gptj", device=device)

In [None]:
# Sanity check: report the loaded model's dtype, device, and the value of
# `get_memory_footprint()`.
print(f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}")

In [None]:
from src import data

# Load the dataset with the project loader; later cells iterate over it and
# index it (e.g. `dataset[0]`) to pick a relation.
dataset = data.load_dataset()

In [None]:
# Print every entry of the dataset for a quick visual inspection.
for d in dataset:
    print(d)

In [None]:
import baukit
import torch
from src.functional import Order1ApproxOutput
from src.utils.misc import visualize_matrix


@torch.no_grad()
@torch.inference_mode(mode=False)
def order_1_approx(
    *,
    mt: models.ModelAndTokenizer,
    prompt: str,
    h_layer: int,
    h_index: int,
    z_token_indices: list[int],
    z_layer: int | None = None,
    z_index: int | None = None,
    inputs=None,
):
    """Compute a first-order (affine) approximation of z as a function of h.

    `h` is the output of layer `h_layer` at token position `h_index`; `z` is
    the output of layer `z_layer` at token position `z_index`. The Jacobian is
    taken of a *projected* z: the hidden state is passed through the model's
    final layer norm (`transformer.ln_f`) and then orthogonally projected onto
    the span of the input-embedding rows (`transformer.wte`) selected by
    `z_token_indices`.

    Args:
        mt: Model/tokenizer bundle; `mt.model` and `mt.tokenizer` are used.
        prompt: Text to tokenize when `inputs` is not supplied.
        h_layer: Layer whose output is treated as the input variable h.
        h_index: Token position of h within the full prompt.
        z_token_indices: Vocabulary ids whose embeddings span the projection
            subspace. Duplicates are ignored (they do not change the span).
        z_layer: Layer whose output is z. Defaults to the last layer
            (`config.n_layer - 1`).
        z_index: Token position of z within the full prompt. Defaults to -1.
        inputs: Optional pre-tokenized inputs; must expose `.input_ids`.

    Returns:
        An `Order1ApproxOutput` holding h, z, the Jacobian (`weight`), the
        bias `z - h @ weight.T`, the inputs, and the logits of the clean run.

    Raises:
        ValueError: If `z_token_indices` is empty, or if a non-negative
            `z_index` lies before `h_index`.
    """
    if not z_token_indices:
        raise ValueError("z_token_indices must contain at least one token id")
    if z_layer is None:
        z_layer = mt.model.config.n_layer - 1
    if z_index is None:
        z_index = -1
    if inputs is None:
        inputs = mt.tokenizer(prompt, return_tensors="pt").to(mt.model.device)

    # Precompute everything up to the subject, if there is anything before it.
    # After truncation the traced hidden states only cover tokens from
    # `h_index` onward, so absolute positions must be shifted accordingly.
    past_key_values = None
    input_ids = inputs.input_ids
    _h_index = h_index
    _z_index = z_index
    if _h_index > 0:
        if 0 <= z_index < h_index:
            raise ValueError(
                f"z_index ({z_index}) must not precede h_index ({h_index})"
            )
        outputs = mt.model(input_ids=input_ids[:, :_h_index], use_cache=True)
        past_key_values = outputs.past_key_values
        input_ids = input_ids[:, _h_index:]
        if _z_index >= 0:
            _z_index -= _h_index
        _h_index = 0
    use_cache = past_key_values is not None
    # NOTE(review): `past_key_values` is reused for two forward passes below;
    # this assumes the cache is not mutated in place by the model — confirm
    # for the installed transformers version.

    # Projector onto the span of the target-token embeddings. It does not
    # depend on h, so it is built once here rather than inside the function
    # being differentiated. Duplicate ids would only add linearly dependent
    # columns (making Y^T Y singular), so they are dropped, keeping order.
    unique_token_ids = list(dict.fromkeys(z_token_indices))
    Y = mt.model.transformer.wte.weight.data[unique_token_ids].T.to(torch.float32)
    proj = Y @ (Y.T @ Y).pinverse() @ Y.T
    print(torch.linalg.matrix_rank(proj))
    proj = proj.to(mt.model.dtype)

    # Precompute initial h and z.
    [h_layer_name, z_layer_name] = models.determine_layer_paths(mt, [h_layer, z_layer])
    with baukit.TraceDict(mt.model, (h_layer_name, z_layer_name)) as ret:
        outputs = mt.model(
            input_ids=input_ids,
            use_cache=use_cache,
            past_key_values=past_key_values,
        )
    # Layer outputs are indexed as output[0][batch, token].
    h = ret[h_layer_name].output[0][0, _h_index]
    z = ret[z_layer_name].output[0][0, _z_index]

    # Now compute J and b.
    def compute_z_from_h(h: torch.Tensor) -> torch.Tensor:
        """Re-run the model with `h` patched in and return the projected z."""

        def insert_h(output: tuple, layer: str) -> tuple:
            if layer != h_layer_name:
                return output
            # Overwrite the hidden state in place so downstream layers see `h`.
            output[0][0, _h_index] = h
            return output

        with baukit.TraceDict(
            mt.model, (h_layer_name, z_layer_name), edit_output=insert_h
        ) as ret:
            mt.model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
            )
        # Read z at the same position as the reference z above, so the
        # Jacobian and the bias describe the same token.
        z = ret[z_layer_name].output[0][0, _z_index]
        z = mt.model.transformer.ln_f(z)

        result = proj @ z[..., None]
        print(Y.norm().item(), proj.norm().item(), result.norm().item())
        return result.squeeze()

    weight = torch.autograd.functional.jacobian(compute_z_from_h, h, vectorize=True)
    # NOTE(review): `z` here is the raw layer output, whereas `weight` is the
    # Jacobian of the layer-normed, projected z — confirm this mix is intended.
    bias = z[None] - h[None].mm(weight.t())
    approx = Order1ApproxOutput(
        h=h,
        h_layer=h_layer,
        h_index=h_index,
        z=z,
        z_layer=z_layer,
        z_index=z_index,
        weight=weight,
        bias=bias,
        inputs=inputs.to("cpu"),
        logits=outputs.logits.cpu(),
    )
    return approx

In [None]:
# prompt="Eiffle Tower is located in the city of"
# tokenized = mt.tokenizer(prompt, return_tensors="pt").to(mt.model.device)
# print([(t.item(), mt.tokenizer.decode(t)) for t in tokenized.input_ids[0]])

# output = order_1_approx(
#     mt=mt,
#     prompt="Eiffle Tower is located in the city of",    
#     h_layer=15,
#     h_index=3,
#     z_layer=27,
# )

In [None]:
from dataclasses import dataclass

from src import data, functional, operators
from src.utils import tokenizer_utils

from tqdm.auto import tqdm


def get_icl_prompt(samples, sample, prompt_template):
    """Build an in-context-learning prompt that ends with a query for `sample`.

    Every sample other than `sample` becomes one demonstration line of the
    form `prompt_template.format(subject) + " <object>."`. The final line is
    the template filled with `sample.subject` and no answer.

    Demonstrations appear in the order they occur in `samples`, with repeated
    entries used once, so the prompt is the same on every run. Samples are
    compared with `==` and do not need to be hashable. If there are no
    demonstrations, the result is just the query line.

    Args:
        samples: Iterable of objects exposing `.subject` and `.object`.
        sample: The sample to query; excluded from the demonstrations.
        prompt_template: Format string with one positional `{}` placeholder.

    Returns:
        The newline-joined prompt string.
    """
    demonstrations = []
    for candidate in samples:
        # Skip the queried sample itself and anything already included.
        if candidate == sample or candidate in demonstrations:
            continue
        demonstrations.append(candidate)

    lines = [
        prompt_template.format(demo.subject) + f" {demo.object}."
        for demo in demonstrations
    ]
    lines.append(prompt_template.format(sample.subject))
    return "\n".join(lines)


class NewJEstimator(operators.LinearRelationEstimator):
    """Estimate a linear relation operator from projected Jacobians.

    For each (prompt template, sample) pair an in-context prompt is built and
    `order_1_approx` is called at the last token of the sample's subject. The
    resulting Jacobians and biases are averaged, the bias is scaled by
    `BIAS_SCALE`, and the result is wrapped in a `LinearRelationOperator`.
    """

    mt: models.ModelAndTokenizer
    h_layer: int = 9
    z_layer: int = 27

    # Factor applied to the averaged bias before building the operator.
    # The "Jh + b/2 norm" diagnostic label below assumes this is 0.5.
    BIAS_SCALE = 0.5

    def __call__(self, relation):
        """Estimate the operator for `relation`.

        Only the first prompt template and the first three samples are used.

        Args:
            relation: Object exposing `prompt_templates`, `samples`, and
                `range` (the candidate target strings).

        Returns:
            An `operators.LinearRelationOperator` built from the averaged
            weight and the scaled averaged bias.

        Raises:
            ValueError: If the relation has no prompt templates or samples.
        """
        prompt_templates = relation.prompt_templates[:1]
        samples = relation.samples[:3]
        if not prompt_templates or not samples:
            raise ValueError(
                "relation must have at least one prompt template and one sample"
            )

        # First token id of every target string defines the projection
        # subspace used by `order_1_approx`.
        targets = [x for x in relation.range]
        z_token_indices = self.mt.tokenizer(
            targets,
            return_tensors="pt",
            padding=True,
        ).input_ids[:, 0].tolist()

        weights = []
        biases = []
        zs = []
        hs = []
        for prompt_template in prompt_templates:
            for sample in tqdm(samples):
                prompt = get_icl_prompt(samples, sample, prompt_template)
                print(prompt, "\n", "---")
                # Use the estimator's own tokenizer rather than whatever `mt`
                # happens to be bound to in the notebook namespace.
                _, h_index = tokenizer_utils.find_token_range(
                    prompt, sample.subject, tokenizer=self.mt.tokenizer
                )
                # Presumably the returned end is exclusive, so step back to
                # the last subject token — verify against find_token_range.
                h_index -= 1

                output = order_1_approx(
                    mt=self.mt,
                    prompt=prompt,
                    h_layer=self.h_layer,
                    h_index=h_index,
                    z_token_indices=z_token_indices,
                    z_layer=self.z_layer,
                )
                weights.append(output.weight)
                biases.append(output.bias)
                hs.append(output.h)
                zs.append(output.z)

        weight = torch.stack(weights).mean(dim=0)

        full_bias = torch.stack(biases).mean(dim=0)
        print(full_bias.norm())
        bias = full_bias * self.BIAS_SCALE
        print(bias.norm())

        # Diagnostics comparing the magnitude of each term of the estimate.
        print("h norm", torch.stack(hs).norm(dim=-1).squeeze().mean())
        print("Jh norm", torch.stack([weight @ h for h in hs]).norm(dim=-1).squeeze().mean())
        print("Jh + b norm", torch.stack([weight @ h + full_bias for h in hs]).norm(dim=-1).squeeze().mean())
        print("Jh + b/2 norm", torch.stack([weight @ h + bias for h in hs]).norm(dim=-1).squeeze().mean())
        print("z norm", torch.stack(zs).norm(dim=-1).squeeze().mean())

        return operators.LinearRelationOperator(
            mt=self.mt,
            weight=weight,
            bias=bias,
            h_layer=self.h_layer,
            z_layer=self.z_layer,
            prompt_template=relation.prompt_templates[0],
        )


# Build the estimator; only `mt` is passed, so `h_layer` and `z_layer` keep
# the class-level defaults.
estimator = NewJEstimator(mt=mt)

# Use the first relation, restricted to its first prompt template
# (presumably `.set` returns an updated copy — confirm in `src.data`).
relation = dataset[0].set(prompt_templates=[dataset[0].prompt_templates[0]])
# relation = dataset[1].set(prompt_templates=["People in {} speak the language of"])

# relation = data.Relation(
#     name="workplaces",
#     prompt_templates=["{} typically work inside of a"],
#     samples=[
#         data.RelationSample("Nurses", "hospital"),
#         data.RelationSample("Judges", "courtroom"),
#         data.RelationSample("Farmers", "field"),
#         data.RelationSample("Car mechanics", "garage"),
#         data.RelationSample("Teachers", "classroom"),
#     ],
# )

# relation = data.Relation(
#     name="color",
#     prompt_templates=["{} are typically associated with the color"],
#     samples=[
#         data.RelationSample("Bananas", "yellow"),
#         data.RelationSample("Kiwis", "green"),
#         data.RelationSample("Potatoes", "brown"),
#     ],
#     _range=[
#         "pink",
#         "yellow",
#         "red",
#         "green",
#         "blue",
#         "orange",
#         "violet",
#         "magenta",
#         "brown",
#         "black",
#         "white",
#         "purple",
#         "grey",
#         "gray",
#         "maroon",
#     ]
# )
# relation = data.Relation.from_dict({
#     "name": "president elected 1900s",
#     "prompt_templates": [
#         "{} was elected president in the year"
#     ],
#     "samples": [
#         {
#             "subject": "John F. Kennedy",
#             "object": "1960"
#         },
#         {
#             "subject": "Lyndon B. Johnson",
#             "object": "1963"
#         },
#         {
#             "subject": "Richard Nixon",
#             "object": "1968"
#         },
#         {
#             "subject": "James Carter",
#             "object": "1977"
#         },
#         {
#             "subject": "Ronald Reagan",
#             "object": "1980"
#         },
#         {
#             "subject": "George H. W. Bush",
#             "object": "1988"
#         },
#         {
#             "subject": "Bill Clinton",
#             "object": "1992"
#         }
#     ]
# })

# relation = data.Relation.from_dict({
#     "name": "president born 1900s",
#     "prompt_templates": [
#         "{} was born in the year"
#     ],
#     "samples": [
#         {
#             "subject": "John F. Kennedy",
#             "object": "1917"
#         },
#         {
#             "subject": "Lyndon B. Johnson",
#             "object": "1908"
#         },
#         {
#             "subject": "Richard Nixon",
#             "object": "1913"
#         },
#         {
#             "subject": "James Carter",
#             "object": "1924"
#         },
#         {
#             "subject": "Ronald Reagan",
#             "object": "1911"
#         },
#         {
#             "subject": "George H. W. Bush",
#             "object": "1924"
#         },
#         {
#             "subject": "Bill Clinton",
#             "object": "1946"
#         }
#     ]
# })

# Run the estimator with `device` installed as the default torch device for
# the duration of the call.
with torch.device(device):
    operator = estimator(relation)

In [None]:
# Apply the estimated operator to the subject "India" and display its
# `predictions` (k=20 — presumably the number of predictions returned).
operator("India", k=20).predictions

In [None]:
from src.corner import CornerEstimator

# Estimate a "corner" vector by gradient descent, using the relation's range
# as the target words. Later used as the basis for a replacement bias.
corner_estimator = CornerEstimator(mt.model, mt.tokenizer)
corner = corner_estimator.estimate_corner_with_gradient_descent(
    target_words = list(relation.range),
    verbose=True,
)

In [None]:
import copy
# Clone the estimated operator, keeping its weight and settings but replacing
# the bias with a scaled-down version of the estimated corner.
corner_operator = operators.LinearRelationOperator(
    mt = operator.mt,
    weight = operator.weight,
    bias = corner/5, # bias is the estimated corner divided by 5, not the raw corner
    h_layer = operator.h_layer,
    z_layer = operator.z_layer,
    prompt_template = operator.prompt_template,
    subject_token_offset = operator.subject_token_offset,
)

In [None]:
# Apply the corner-bias operator to the subject "United States" and display
# its `predictions`.
corner_operator("United States", k=20).predictions

In [None]:
# Sanity-check plain generation with the model's default `generate` settings.
# Only `input_ids` are passed (no attention mask).
generated = mt.model.generate(
    mt.tokenizer("The capital of Pakistan is", padding=True, return_tensors="pt").input_ids.to(device),
)

# Decode the first (only) generated sequence back to text.
mt.tokenizer.decode(generated[0])

In [None]:
import seaborn as sns

# Heatmap of a 75x75 sub-block (rows and columns 300-374) of the estimated
# operator's weight matrix.
sns.heatmap(operator.weight[300:375, 300:375].cpu().numpy())

In [None]:
@torch.inference_mode()
def complete(prompt, k=20):
    """Return the decoded `k` most likely next tokens for `prompt`.

    Uses the notebook-level `mt` (model and tokenizer) and `device`.

    Args:
        prompt: Text to condition the model on.
        k: Number of candidate next tokens to return, most likely first.

    Returns:
        A list of `k` decoded token strings.
    """
    inputs = mt.tokenizer(prompt, return_tensors="pt").to(device)
    outputs = mt.model(**inputs)
    # Only the final position predicts the next token, so normalise just that
    # slice instead of every position.
    next_token_log_probs = torch.log_softmax(outputs.logits[:, -1], dim=-1)
    # Index the single batch row explicitly; `.squeeze()` would collapse the
    # result to a scalar when k == 1.
    top_token_ids = next_token_log_probs.topk(k=k, dim=-1).indices[0].tolist()
    return [mt.tokenizer.decode(token_id) for token_id in top_token_ids]

# Show the model's top next-token candidates for a birth-year prompt.
complete("Bill Clinton was born in the year")

In [None]:
# Inspect the prompt templates available for the first relation.
dataset[0].prompt_templates

In [None]:
# Compare against the project's JacobianIclEstimator (h_layer=5, z_layer=27).
# Note: this rebinds `estimator`, `relation`, and `operator` from earlier cells.
estimator = operators.JacobianIclEstimator(
    mt=mt,
    h_layer=5,
    z_layer=27,
)
relation = dataset[0]
with torch.device(device):
    # Estimate from samples 10..14 of the relation only.
    operator = estimator(relation.set(samples=relation.samples[10:15]))

In [None]:
# Apply the newly estimated operator to the subject "Spain" and display its
# `predictions`.
operator("Spain", k=10).predictions