In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

import torch
import matplotlib.pyplot as plt
from src import models, data
from tqdm.auto import tqdm
import json
import os

# Create the output directories that later cells write plots and saved
# operator weights into; exist_ok=True makes re-running this cell harmless.
os.makedirs("layer_sweep/Jacobian_plots", exist_ok=True)
os.makedirs("layer_sweep/weights_and_biases", exist_ok=True)

In [3]:
# Load the "gptj" model onto a hard-coded GPU and report its dtype, device
# and memory footprint.
# NOTE(review): `device` is also read as a global by later cells, so changing
# it here changes where those cells place their tensors.
device = "cuda:0"
mt = models.load_model("gptj", device=device)
print(
    f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}"
)

dtype: torch.float16, device: cuda:0, memory: 12219206136


### Utils

In [4]:
# Example prompt used for the exploratory forward pass in the next cell,
# tokenized as PyTorch tensors and moved to the model's device.
prompt = "The Space Needle is located in the city of"
tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True).to(mt.model.device)

In [5]:
import baukit

# Run one forward pass on the example prompt while recording the outputs of
# the layers returned by models.determine_layer_paths(mt).
# The traces are only inspected, never back-propagated through, so autograd is
# switched off to avoid keeping a computation graph alive in `traces`/`output`.
with torch.no_grad(), baukit.TraceDict(
    mt.model,
    models.determine_layer_paths(mt)
) as traces:
    output = mt.model(**tokenized)

In [6]:
def untuple(x):
    """Return the first element of ``x`` if it is a tuple, otherwise ``x`` unchanged."""
    return x[0] if isinstance(x, tuple) else x


def interpret_logits(mt, logits, top_k=10, get_proba = False):
    """Decode the ``top_k`` highest-scoring tokens of a logit vector.

    Args:
        mt: Object exposing ``mt.tokenizer.decode(token_id)``.
        logits: Scores over the vocabulary in the last dimension. Singleton
            leading dimensions are squeezed away, so a single vector or a
            batch of one both work.
        top_k: Number of tokens to return.
        get_proba: If True, apply a softmax over the last dimension first so
            the returned values are probabilities rather than raw scores.

    Returns:
        A list of ``(decoded_token, value)`` pairs in descending order of
        value, with each value rounded to 3 decimals.
    """
    scores = torch.nn.functional.softmax(logits, dim=-1) if get_proba else logits

    # One topk call yields matching indices and values.
    top = scores.topk(dim=-1, k=top_k)
    token_ids = top.indices.squeeze().tolist()
    values = top.values.squeeze().tolist()

    # With top_k == 1 the squeeze collapses to a 0-dim tensor and tolist()
    # returns bare scalars; wrap them so the zip below still works.
    if not isinstance(token_ids, list):
        token_ids, values = [token_ids], [values]

    return [
        (mt.tokenizer.decode(t), round(v, 3)) for t, v in zip(token_ids, values)
    ]


def logit_lens(
    mt,
    h,
    interested_tokens=None,
    get_proba=False,
):
    """Project a hidden state through ``mt.lm_head`` and read off token scores.

    Args:
        mt: Object exposing ``mt.lm_head`` and ``mt.tokenizer.decode``.
        h: Hidden state passed to ``mt.lm_head``. The result is indexed by
            token id, so it is expected to be a single vector over the
            vocabulary.
        interested_tokens: Optional iterable of token ids (single-element
            tensors or plain ints) whose scores should be reported. ``None``
            (the default) reports none.
        get_proba: If True, scores are softmax probabilities instead of raw
            logits.

    Returns:
        ``(candidates, interested_logits)`` where ``candidates`` is the output
        of ``interpret_logits`` on the scores and ``interested_logits`` maps
        each requested token id to ``(score, decoded_token)``.
    """
    logits = mt.lm_head(h)
    logits = torch.nn.functional.softmax(logits, dim=-1) if get_proba else logits
    candidates = interpret_logits(mt, logits)

    interested_logits = {}
    # A None default avoids sharing one mutable list across calls.
    for t in interested_tokens if interested_tokens is not None else ():
        # Accept both tensor ids (original call sites) and plain ints.
        token_id = t.item() if hasattr(t, "item") else int(t)
        interested_logits[token_id] = (
            logits[token_id].item(),
            mt.tokenizer.decode(token_id),
        )
    return candidates, interested_logits


# interpret_logits(mt, output.logits[0][-1])

In [7]:
# interested_words = [" Seattle", " Paris", " Dhaka"]
# int_tokenized = mt.tokenizer(interested_words, return_tensors="pt", padding=True).to(
#     mt.model.device
# )
# int_tokenized.input_ids

# z = untuple(traces[models.determine_layer_paths(mt)[-1]].output)[0][-1]
# print(z.shape)

# logit_lens(mt, z, [t[0] for t in int_tokenized.input_ids], get_proba=False)

In [8]:
def filter_by_model_knowledge(mt, relation_prompt, relation_samples):
    """Keep the samples whose object the model predicts as its next token.

    For every sample the prompt ``relation_prompt.format(sample.subject)`` is
    run through the model. The sample is kept when the stripped expected
    object starts with the stripped decoding of the arg-max next token.

    Args:
        mt: Object exposing ``mt.model`` (callable, with ``.device``) and
            ``mt.tokenizer`` (callable, with ``.decode``).
        relation_prompt: Format string with one ``{}`` slot for the subject.
        relation_samples: Iterable of objects with ``.subject`` and
            ``.object`` string attributes.

    Returns:
        List of the samples that passed the check, in input order.
    """
    model_knows = []
    for sample in relation_samples:
        prompt = relation_prompt.format(sample.subject)
        inputs = mt.tokenizer(prompt, return_tensors="pt", padding=True).to(
            mt.model.device
        )
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            output = mt.model(**inputs)

        predicted_id = output.logits[0][-1].argmax().item()
        predicted = mt.tokenizer.decode(predicted_id).strip()

        # An empty prediction (e.g. a whitespace-only token) must not count:
        # every string "starts with" the empty string.
        if predicted and sample.object.strip().startswith(predicted):
            model_knows.append(sample)

    return model_knows

In [10]:
# Load the relation dataset and pick its first relation.
# NOTE(review): the variable name implies index 0 is the country -> capital
# relation; that depends on the dataset's ordering, so confirm it.
dataset = data.load_dataset()
capital_cities = dataset[0]
# Display the relation's attribute names.
capital_cities.__dict__.keys()

dict_keys(['name', 'prompt_templates', 'samples', 'properties', '_domain', '_range'])

In [None]:
import numpy as np
from src.functional import make_prompt

# Seed NumPy's global RNG so the sampled in-context examples, and the later
# np.random draws in this notebook, are the same after a kernel restart.
np.random.seed(0)

# Pick 3 distinct samples to serve as in-context (few-shot) examples.
icl_indices = np.random.choice(range(len(capital_cities.samples)), 3, replace=False)
icl_samples = [capital_cities.samples[i] for i in icl_indices]

# Build the few-shot prompt; the subject is left as a "{}" placeholder so the
# result can be used as a format string for any test subject.
icl_prompt = make_prompt(
    prompt_template = capital_cities.prompt_templates[0],
    subject="{}",
    examples=icl_samples,
)

print(icl_prompt)

In [None]:
# Keep only the samples whose object the model predicts correctly under the
# few-shot prompt, then display how many remain.
model_knows = filter_by_model_knowledge(mt, icl_prompt, capital_cities.samples)
len(model_knows)

## Layer Richness based on logit lens

In [None]:
from typing import Literal
def layer_c_measure(
    mt, relation_prompt, subject,
    verbose=False, measure: Literal["completeness", "contribution"] = "contribution"
):
    """Score every layer by how much of the final answer probability it carries.

    The model's arg-max next token for ``relation_prompt.format(subject)`` is
    taken as the answer, with final probability ``p_final``. For each traced
    layer, the last-position hidden state is passed through ``logit_lens`` to
    get ``p_layer``, the probability that layer assigns to the same token.

    * ``measure="contribution"``: ``(p_layer - p_previous_layer) / p_final``,
      where ``p_previous_layer`` is 0 for the first layer.
    * ``measure="completeness"``: ``(p_layer - p_final) / p_final``.

    Args:
        mt: Model/tokenizer wrapper (``mt.model``, ``mt.tokenizer``).
        relation_prompt: Format string with one ``{}`` slot for the subject.
        subject: Subject string substituted into the prompt.
        verbose: If True, print the answer token and each layer's score.
        measure: Which of the two scores above to compute.

    Returns:
        Dict mapping each layer path to its score.

    Raises:
        ValueError: If ``measure`` is not one of the two supported names.
    """
    if measure not in ("completeness", "contribution"):
        raise ValueError(
            f"unknown measure {measure!r}; expected 'completeness' or 'contribution'"
        )

    prompt = relation_prompt.format(subject)
    layer_paths = models.determine_layer_paths(mt)
    inputs = mt.tokenizer(prompt, return_tensors="pt", padding=True).to(
        mt.model.device
    )

    # Inference only: disable autograd so the traced activations do not keep
    # a computation graph alive.
    with torch.no_grad(), baukit.TraceDict(mt.model, layers=layer_paths) as traces:
        output = mt.model(**inputs)

    final_logits = output.logits[0][-1]
    object_id = final_logits.argmax().item()
    object_token = mt.tokenizer.decode(object_id)
    base_score = torch.nn.functional.softmax(final_logits, dim=-1)[object_id].item()

    if verbose:
        print(f"object ==> {object_token} [{object_id}], base = {base_score}")

    layer_contributions = {}

    prev_score = 0
    with torch.no_grad():
        for layer in layer_paths:
            # Hidden state of the last token in the first batch element.
            h = untuple(traces[layer].output)[0][-1]
            _, interested_logits = logit_lens(
                mt, h, torch.tensor([object_id]), get_proba=True
            )
            layer_score = interested_logits[object_id][0]
            sub_score = base_score if measure == "completeness" else prev_score
            cur_layer_contribution = (layer_score - sub_score) / base_score

            layer_contributions[layer] = cur_layer_contribution

            if verbose:
                print(f"layer: {layer}, diff: {cur_layer_contribution}")

            prev_score = layer_score

    return layer_contributions


# Zero-shot prompt template (prefixed with the tokenizer's EOS token) and an
# example subject, kept for the commented-out single-sample check below.
relation_prompt = mt.tokenizer.eos_token + " {} is located in the city of"
subject = "The Space Needle"
# layer_c_measure(mt, relation_prompt, subject, verbose=True)

In [None]:
# Which logit-lens statistic to collect. It is passed to layer_c_measure AND
# used to name the output file, so the saved data always matches its label.
# Previously the default ("contribution") was computed but stored and plotted
# as "completeness".
measure = "completeness"

layer_paths = models.determine_layer_paths(mt)
layer_c_info = {layer: [] for layer in layer_paths}

# One score per layer for every sample the model knows.
for sample in tqdm(model_knows):
    cur_richness = layer_c_measure(mt, icl_prompt, sample.subject, measure=measure)
    for layer in layer_paths:
        layer_c_info[layer].append(cur_richness[layer])

# Persist the raw per-sample scores before converting them for plotting.
with open(f"layer_sweep/layer_{measure}_info.json", "w") as f:
    json.dump(layer_c_info, f)

# The plotting cell calls .mean()/.min()/.max() on these, so use arrays.
for layer in layer_paths:
    layer_c_info[layer] = np.array(layer_c_info[layer])

In [None]:
# Per-layer mean / min / max of the collected scores, in layer order.
layer_order = models.determine_layer_paths(mt)
mean_richness, low_richness, high_richness = (
    [getattr(layer_c_info[layer], stat)() for layer in layer_order]
    for stat in ("mean", "min", "max")
)

# Mean curve, with the min-max envelope shaded and a zero reference line.
layer_positions = range(len(mean_richness))
plt.plot(mean_richness, color="blue")
plt.fill_between(layer_positions, low_richness, high_richness, alpha=0.2)
plt.axhline(0, color="red", linestyle="--")

plt.xticks(range(0, len(mean_richness), 2))
plt.xlabel("Layer")
plt.ylabel("completeness")

plt.show()

## Layer Richness based on `Jh_norm` and `J_norm`

In [None]:
import copy

# Clone the relation's fields, then replace its samples with the subset kept
# by filter_by_model_knowledge, and rebuild a Relation from the result.
known_fields = copy.deepcopy(capital_cities.__dict__)
known_fields["samples"] = model_knows
capital_cities_known = data.Relation(**known_fields)

In [None]:
from src.operators import JacobianEstimator, JacobianIclMeanEstimator
from src.data import RelationSample

# indices = np.random.choice(range(len(capital_cities.samples)), 3, replace=False)
# samples = [capital_cities.samples[i] for i in indices]

# training_samples = copy.deepcopy(capital_cities.__dict__)
# training_samples["samples"] = samples
# training_samples = data.Relation(**training_samples)

# mean_estimator = JacobianIclMeanEstimator(
#     mt=mt,
#     h_layer=12,
# )

# operator = mean_estimator(training_samples)
# operator("Russia", k = 10).predictions

In [None]:
# Build a JacobianEstimator for layer index 12 and estimate an operator from
# one hand-written (subject, object) sample, using the few-shot prompt as the
# template.
# NOTE(review): the exact form of the returned operator is defined in
# src.operators; the next cell only reads its .weight and .metadata['Jh'].
estimator = JacobianEstimator(
    mt=mt,
    h_layer=12,
)

operator = estimator.call_on_sample(
    sample = RelationSample(subject="United States", object="Washington"),
    prompt_template= icl_prompt
)

In [None]:
# Display the norm of the operator's 'Jh' metadata entry and of its weight.
operator.metadata['Jh'].norm().item(), operator.weight.norm().item()

In [None]:
# For every held-out known sample and each of the first 24 layers, estimate a
# Jacobian operator and record the norms of its weight, 'Jh' term and bias.
all_layer_paths = models.determine_layer_paths(mt)
layerwise_jh = {path: [] for path in all_layer_paths}

for sample in tqdm(set(model_knows) - set(icl_samples)):
    for h_layer in range(24):
        layer_name = all_layer_paths[h_layer]
        estimator = JacobianEstimator(mt=mt, h_layer=h_layer)
        operator = estimator.call_on_sample(
            sample=sample,
            prompt_template=icl_prompt,
        )

        layerwise_jh[layer_name].append(
            dict(
                J=operator.weight.norm().item(),
                Jh=operator.metadata['Jh'].norm().item(),
                bias=operator.bias.norm().item(),
            )
        )

In [None]:
# Drop layers that were never swept (their record list is still empty) so
# only layers with data are saved and plotted.
for layer in models.determine_layer_paths(mt):
    if len(layerwise_jh.get(layer, [None])) == 0:
        layerwise_jh.pop(layer, None)

with open("layer_sweep/layer_jh_info.json", "w") as f:
    json.dump(layerwise_jh, f)

In [None]:
# Which recorded quantity to plot: "J", "Jh" or "bias".
key = "Jh"

# One array of per-sample values for each layer that has recorded data.
info = {
    layer: np.array([record[key] for record in records])
    for layer, records in layerwise_jh.items()
}

# Thick line: mean over samples at each layer.
mean = [values.mean() for values in info.values()]
plt.plot(mean, color="blue", linewidth=4)
plt.xticks(range(0, len(mean), 2))
plt.xlabel("Layer")
plt.ylabel(f"{key}_norm")

# Thin lines: one curve per sample. The sample count is taken from the
# recorded data instead of being recomputed from `model_knows`/`icl_samples`;
# those globals can change after the sweep and then no longer match
# `layerwise_jh`, which raises IndexError or silently drops curves.
n_samples = min((len(values) for values in info.values()), default=0)
for i in range(n_samples):
    plt.plot([values[i] for values in info.values()], alpha=0.2)

plt.show()

## Causal Tracing on `subject_last`

In [None]:
from src.operators import _compute_h_index

# h_idx, inputs = _compute_h_index(
#     mt = mt, 
#     prompt = "The location of {} is in the city of".format(subject_original),
#     subject = subject_original,
#     offset=-1
# )

# print(h_idx, inputs)
# for t in inputs.input_ids[0]:
#     print(t.item(), mt.tokenizer.decode(t.item()))

In [None]:
def get_replace_intervention(intervention_layer, intervention_tok_idx, h_intervention):
    """Build an output-editing hook that overwrites one token's hidden state.

    The returned ``intervention(output, layer)`` leaves ``output`` untouched
    for every layer except ``intervention_layer``. For that layer it writes
    ``h_intervention`` in place into ``output[0][0][intervention_tok_idx]``
    and returns the same ``output`` object.
    """
    def intervention(output, layer):
        if layer == intervention_layer:
            output[0][0][intervention_tok_idx] = h_intervention
        return output

    return intervention

In [None]:
def causal_tracing(
    mt,
    prompt_template,
    subject_original, subject_corruption,
    verbose = False
):
    """Patch original-subject hidden states into a corrupted-subject run.

    The prompt is first run with ``subject_original``; the arg-max next token
    of that run is the answer and ``p_answer`` its probability. Then, for each
    layer, the prompt is run with ``subject_corruption`` while the hidden
    state at the index returned by ``_compute_h_index`` (``offset=-1``) is
    replaced by the one recorded at the corresponding index of the original
    run. The probability of the answer is read with ``logit_lens`` from the
    last traced layer at the final position.

    Args:
        mt: Model/tokenizer wrapper.
        prompt_template: Format string with one ``{}`` slot for the subject.
        subject_original: Subject whose answer should be recovered.
        subject_corruption: Subject used in the patched run.
        verbose: If True, print the answer and each layer's probability.

    Returns:
        Dict mapping each layer path to ``(p_patched - p_answer) / p_answer``.
    """
    h_idx_orig, tokenized_orig = _compute_h_index(
        mt = mt,
        prompt = prompt_template.format(subject_original),
        subject = subject_original,
        offset=-1
    )

    h_idx_corr, tokenized_corr = _compute_h_index(
        mt = mt,
        prompt = prompt_template.format(subject_corruption),
        subject = subject_corruption,
        offset=-1
    )

    layer_names = models.determine_layer_paths(mt)
    # Inference only, so autograd is disabled for every forward pass below.
    with torch.no_grad(), baukit.TraceDict(
        mt.model, layer_names
    ) as traces_o:
        output_o = mt.model(**tokenized_orig)

    # Take the answer id and its probability straight from the distribution.
    # Going through interpret_logits would round the probability to 3 decimals
    # (imprecise, and 0.0 for tiny values -> division by zero), and
    # re-tokenizing the decoded text may not give back a single token id.
    probs_o = torch.nn.functional.softmax(output_o.logits[0][-1], dim=-1)
    answer_id = probs_o.argmax().item()
    p_answer = probs_o[answer_id].item()
    answer = mt.tokenizer.decode(answer_id)
    answer_t = torch.tensor(answer_id)

    if(verbose):
        print(f"answer: {answer}[{answer_id}], p(answer): {p_answer:.3f}")

    result = {}
    for intervention_layer in layer_names:
        with torch.no_grad(), baukit.TraceDict(
            mt.model,
            layers = layer_names,
            edit_output = get_replace_intervention(
                intervention_layer= intervention_layer,
                intervention_tok_idx= h_idx_corr,
                h_intervention = untuple(traces_o[intervention_layer].output)[0][h_idx_orig]
            )
        ) as traces_i:
            # Reuse the inputs that h_idx_corr was computed against instead of
            # re-tokenizing the corrupted prompt on every iteration.
            mt.model(**tokenized_corr)

        # Final-position hidden state of the last traced layer in the patched run.
        z = untuple(traces_i[layer_names[-1]].output)[0][-1]
        with torch.no_grad():
            _, interested = logit_lens(mt, z, [answer_t], get_proba=True)
        layer_p = interested[answer_id][0]

        if(verbose):
            print(intervention_layer, layer_p)
        result[intervention_layer] = (layer_p - p_answer)/p_answer

    return result


# Smoke-test causal_tracing on a single hand-picked pair of subjects and
# print the per-layer probabilities.
causal_tracing(
    mt,
    prompt_template = "The location of {} is in the city of",
    subject_original = "The Space Needle",
    subject_corruption = "The Statue of Liberty",
    verbose = True
)

In [None]:
import copy

# Rebuild `capital_cities_known` from `capital_cities` and `model_knows`
# (the same construction appears earlier in the notebook), so this section
# can be run on its own once those two names exist.
capital_cities_known = copy.deepcopy(capital_cities.__dict__)
capital_cities_known["samples"] = model_knows

capital_cities_known = data.Relation(**capital_cities_known)

In [None]:
# Number of in-context (few-shot) examples in the prompt.
num_icl = 3

# Draw the in-context examples from the filtered relation. The indices are
# sampled over `capital_cities_known.samples`, so they must index that same
# list. Indexing the unfiltered `capital_cities.samples` with them picked
# unrelated samples, possibly ones the model does not know, and left the
# intended ones inside the test set.
icl_indices = np.random.choice(range(len(capital_cities_known.samples)), num_icl, replace=False)
icl_samples = [capital_cities_known.samples[i] for i in icl_indices]

# One solved line per example, followed by the bare template whose "{}" slot
# is filled with the test subject later.
icl_prompt = [
    f"{capital_cities.prompt_templates[0].format(sample.subject)} {sample.object}"
    for sample in icl_samples
]
icl_prompt = "\n".join(icl_prompt) + "\n" + capital_cities.prompt_templates[0]

print(icl_prompt)

In [None]:
# Known samples that are not used as in-context examples.
test_samples = set(capital_cities_known.samples) - set(icl_samples)
traced_layers = models.determine_layer_paths(mt)
causal_tracing_results = {layer: [] for layer in traced_layers}

# Each run draws two distinct test samples: the first provides the original
# subject, the second the corrupted one.
n_runs = 20
for run in tqdm(range(n_runs)):
    candidates = list(test_samples)
    first, second = np.random.choice(range(len(candidates)), 2, replace=False)
    sample_pair = [candidates[first], candidates[second]]
    print(sample_pair)

    cur_result = causal_tracing(
        mt,
        prompt_template=icl_prompt,
        subject_original=sample_pair[0].subject,
        subject_corruption=sample_pair[1].subject,
        verbose=False,
    )

    for layer in traced_layers:
        causal_tracing_results[layer].append(cur_result[layer])

In [None]:
# Persist the raw per-run scores as JSON.
with open("layer_sweep/causal_tracing_results.json", "w") as f:
    json.dump(causal_tracing_results, f)

# Convert each layer's list to an array afterwards (arrays are not JSON
# serializable); the plotting cell calls .mean() on them.
for layer in models.determine_layer_paths(mt):
    causal_tracing_results[layer] = np.array(causal_tracing_results[layer])

In [None]:
# Display every second entry of models.determine_layers(mt); the plotting
# cell below passes the same expression to plt.xticks.
models.determine_layers(mt)[::2]

In [None]:
# Thick line: mean normalised score per layer over all runs.
traced_layers = models.determine_layer_paths(mt)
mean = [causal_tracing_results[layer].mean() for layer in traced_layers]
# (A min/max envelope via plt.fill_between was tried here and left disabled.)

plt.plot(mean, color="blue", linewidth=3)
plt.axhline(0, color="red", linestyle="--")

plt.xlabel("Layer")
plt.ylabel("layer_score")
plt.xticks(models.determine_layers(mt)[::2])

# Thin lines: one curve per individual run.
for run in range(n_runs):
    run_curve = [causal_tracing_results[layer][run] for layer in traced_layers]
    plt.plot(run_curve, alpha=0.2)

plt.show()

## Layer sweep on mean ICL

In [None]:
import copy

# Rebuild `capital_cities_known` from `capital_cities` and `model_knows`
# (the same construction appears earlier in the notebook), so this section
# can be run on its own once those two names exist.
capital_cities_known = copy.deepcopy(capital_cities.__dict__)
capital_cities_known["samples"] = model_knows

capital_cities_known = data.Relation(**capital_cities_known)

In [None]:
# Pick 2 distinct known samples to act as the training examples.
indices = np.random.choice(range(len(capital_cities_known.samples)), 2, replace=False)
samples = [capital_cities_known.samples[i] for i in indices]

# Same relation as `capital_cities`, restricted to those training examples.
subset_fields = copy.deepcopy(capital_cities.__dict__)
subset_fields["samples"] = samples
capital_cities_subset = data.Relation(**subset_fields)

len(capital_cities_subset.samples)

In [None]:
from src.operators import JacobianIclMeanEstimator

# Build a JacobianIclMeanEstimator for layer index 12 and estimate one
# operator from the small training relation.
mean_estimator = JacobianIclMeanEstimator(
    mt=mt,
    h_layer=12,
)

operator = mean_estimator(capital_cities_subset)

In [None]:
# Spot-check: display the operator's 10 predictions for one subject.
operator("Chile", k = 10).predictions

In [None]:
# Top-5 predicted tokens of the operator for every known sample, with the
# expected object of each sample collected in the same order.
eval_samples = list(set(capital_cities_known.samples))

predictions = [
    [p.token for p in operator(sample.subject, k=5).predictions]
    for sample in tqdm(eval_samples)
]
target = [sample.object for sample in eval_samples]

In [None]:
from src.metrics import recall

# Score the predictions against the targets.
# NOTE(review): later cells index this result by position and label the
# first three entries recall@1..3, so it presumably returns recall@k for
# k = 1..5; confirm in src.metrics.
recall(predictions, target)

In [None]:
# np.savez("layer_sweep/operator_weight.npz", jacobian = operator.weight.detach().cpu().numpy(), allow_pickle=True)

In [None]:
# j = np.load("layer_sweep/operator_weight.npz", allow_pickle=True)["jacobian"]

In [None]:
# torch.dist(torch.tensor(j).to(device), operator.weight)

In [None]:
def get_layer_wise_recall(capital_cities_subset, verbose = True, save_weights = True, n_layers = 24):
    """Estimate a mean-ICL operator at each layer and score it on held-out samples.

    Reads the notebook globals ``mt``, ``capital_cities_known`` and
    ``recall``. The held-out set is every sample of ``capital_cities_known``
    that is not in ``capital_cities_subset``.

    Args:
        capital_cities_subset: Relation whose samples are used to estimate
            the operator; they are excluded from evaluation.
        verbose: If True, print each layer's recall as it is computed.
        save_weights: If True, save each layer's operator weight and bias to
            ``layer_sweep/weights_and_biases/<layer_name>.npz`` under the
            keys ``jacobian`` and ``bias``.
        n_layers: Number of leading layers to sweep (default 24, the value
            that used to be hard-coded). It is clamped to the number of
            layers the model actually has.

    Returns:
        Dict mapping each swept layer path to the value returned by
        ``recall(predictions, target)`` for that layer.
    """
    layer_wise_recall = {}

    layer_names = models.determine_layer_paths(mt)
    n_layers = min(n_layers, len(layer_names))

    # The held-out samples and their targets do not depend on the layer, so
    # build them once instead of once per layer.
    held_out = list(
        set(capital_cities_known.samples) - set(capital_cities_subset.samples)
    )
    target = [sample.object for sample in held_out]

    for h_layer in tqdm(range(n_layers)):
        layer_name = layer_names[h_layer]
        mean_estimator = JacobianIclMeanEstimator(
            mt=mt,
            h_layer=h_layer,
        )
        operator = mean_estimator(capital_cities_subset)
        if(save_weights):
            # np.savez stores unrecognised keyword arguments as arrays, and
            # `allow_pickle` is only recognised by recent NumPy releases, so
            # it is no longer passed here.
            np.savez(
                f"layer_sweep/weights_and_biases/{layer_name}.npz",
                jacobian = operator.weight.detach().cpu().numpy(),
                bias = operator.bias.detach().cpu().numpy(),
            )

        predictions = [
            [p.token for p in operator(sample.subject, k = 5).predictions]
            for sample in held_out
        ]

        layer_wise_recall[layer_name] = recall(predictions, target)

        if(verbose):
            print(layer_name, layer_wise_recall[layer_name])

    return layer_wise_recall

# Single sweep with the current training subset; also writes each layer's
# weights and bias to disk for the visualisation cell at the end.
layer_wise_recall = get_layer_wise_recall(capital_cities_subset, verbose = True, save_weights = True)

In [None]:
# Persist the single-sweep recall values as JSON.
with open("layer_sweep/layer_wise_recall.json", "w") as f:
    json.dump(layer_wise_recall, f)

In [None]:
# Repeat the layer sweep with a freshly sampled 2-example training subset on
# every run, and gather each layer's recall across runs.
layer_wise_recall_collection = {}
number_of_runs = 10

for run in tqdm(range(number_of_runs)):
    indices = np.random.choice(range(len(capital_cities_known.samples)), 2, replace=False)
    samples = [capital_cities_known.samples[i] for i in indices]

    subset_fields = copy.deepcopy(capital_cities.__dict__)
    subset_fields["samples"] = samples
    capital_cities_subset = data.Relation(**subset_fields)

    layer_wise_recall = get_layer_wise_recall(capital_cities_subset, verbose=False, save_weights=False)

    for layer, layer_recall in layer_wise_recall.items():
        layer_wise_recall_collection.setdefault(layer, []).append(layer_recall)

In [None]:
# Persist the multi-run recall values as JSON.
with open("layer_sweep/layer_wise_recall_collection.json", "w") as f:
    json.dump(layer_wise_recall_collection, f)

In [None]:
# Reload the multi-run recall values from disk, so the plotting cell below
# can be run without repeating the sweep.
with open("layer_sweep/layer_wise_recall_collection.json") as f:
    layer_wise_recall_collection = json.load(f)

In [None]:
# top_1 = [layer_wise_recall[layer][0] for layer in layer_wise_recall.keys()]
# top_2 = [layer_wise_recall[layer][1] for layer in layer_wise_recall.keys()]
# top_3 = [layer_wise_recall[layer][2] for layer in layer_wise_recall.keys()]

import numpy as np

# For each of the first three recall columns, build a (layers x runs) array.
layer_keys = list(layer_wise_recall_collection.keys())
top_1, top_2, top_3 = (
    np.array([
        np.array(layer_wise_recall_collection[layer])[:, column]
        for layer in layer_keys
    ])
    for column in range(3)
)

# (values, colour, line width, legend label, envelope alpha) per curve.
curves = [
    (top_1, "green", 3, "recall@1", 0.1),
    (top_2, "blue", 2, "recall@2", 0.05),
    (top_3, "red", 1, "recall@3", 0.03),
]

# Mean over runs for each layer.
for values, colour, width, label, _ in curves:
    plt.plot(values.mean(axis=1), color=colour, linewidth=width, label=label)

# Min-max envelope over runs for each layer.
layer_positions = range(len(layer_keys))
for values, colour, _, _, alpha in curves:
    plt.fill_between(
        layer_positions,
        values.min(axis=1), values.max(axis=1),
        color=colour, alpha=alpha
    )

plt.xticks(range(0, len(top_1), 2))
plt.xlabel("layer")
plt.ylabel("recall")

plt.legend()
plt.show()

In [None]:
from src.utils.supplimentary import visualize_matrix

# Load the Jacobian saved for each of the first 24 layers by
# get_layer_wise_recall(save_weights=True) and render it to a PNG.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
# the files here are written by this notebook, but the flag looks unnecessary
# for plain numeric arrays. Confirm before dropping it.
for layer_name in models.determine_layer_paths(mt)[:24]:
    j = np.load(f"layer_sweep/weights_and_biases/{layer_name}.npz", allow_pickle=True)["jacobian"]
    j = torch.tensor(j).to(device)
    print(layer_name, j.shape)
    visualize_matrix(j, title = layer_name, save_path=f"layer_sweep/Jacobian_plots/{layer_name}.png")