In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm import tqdm
import random
import copy

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from util import nethook
import baukit

torch.cuda.set_device(0)

In [3]:
# Checkpoint to load; options: gpt2-{} | "EleutherAI/gpt-neox-20b" | "EleutherAI/gpt-j-6B"
MODEL_NAME = "EleutherAI/gpt-neox-20b"
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Short aliases used by the rest of the notebook.
model, tokenizer = mt.model, mt.tokenizer
# The tokenizer is later called with padding=True; EOS doubles as the pad token.
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-neox-20b ==> device: cuda:0, memory: 41293685880


In [4]:
# Architecture-specific module names and config field (active set: GPT-NeoX naming).
# Template for the module name of the i-th transformer block.
layer_name_format = "gpt_neox.layers.{}"
# Module names checked by check_valid() in addition to the transformer blocks.
final_layer_norm = "gpt_neox.final_layer_norm"
unembed = "embed_out"
# Name of the config attribute that stores the number of layers.
num_layer_field = "num_hidden_layers"
# First layer of the full model that is kept; layers [break_layer_idx, num_layers)
# become layers [0, num_layers - break_layer_idx) of the partial model.
break_layer_idx = 23

# Alternative set for models whose blocks live under "transformer.h"
# (presumably the GPT-2 / GPT-J options listed next to MODEL_NAME -- confirm before use).
# layer_name_format = "transformer.h.{}"
# final_layer_norm = "transformer.ln_f"
# unembed = "lm_head"
# num_layer_field = "n_layer"
# break_layer_idx = 15

## Copy the later part of the model

In [5]:
import copy
# Clone the full model's config and shrink its layer count to the number of
# layers kept from break_layer_idx onward.
part_config = copy.deepcopy(model.config)
setattr(part_config, num_layer_field, mt.num_layers - break_layer_idx)
part_num_layers = getattr(part_config, num_layer_field)

# Layer names of the partial model, renumbered from 0.
part_layer_names = list(map(layer_name_format.format, range(part_num_layers)))

In [6]:
def get_param_names(layer_name, params=None):
    """Return the keys that belong to the module called `layer_name`.

    A key belongs to the module when it equals `layer_name` or starts with
    `layer_name + "."`. The trailing dot matters: a bare prefix test would make
    "...layers.1" also collect the keys of "...layers.10" through "...layers.19".

    Args:
        layer_name: dotted module name, e.g. layer_name_format.format(3).
        params: optional mapping (or iterable) of parameter names to search.
            Defaults to the notebook-level `state_dict`.

    Returns:
        List of matching keys, in the iteration order of `params`.
    """
    if params is None:
        params = state_dict
    child_prefix = layer_name + "."
    return [
        key for key in params
        if key == layer_name or key.startswith(child_prefix)
    ]

In [7]:
# model.state_dict() returns a new dict, but its tensors share storage with the
# live model's parameters. Writing into them in place would overwrite the early
# layers of `model`, which is still needed for the equivalence checks further
# down. So the dict entries are re-pointed instead: no weights are copied and
# `model` is left untouched.
state_dict = model.state_dict()

for idx in range(break_layer_idx, mt.num_layers):
    part_layer_name = layer_name_format.format(idx - break_layer_idx)
    full_layer_name = layer_name_format.format(idx)

    print(part_layer_name, "<<", full_layer_name)

    full_prefix = full_layer_name + "."
    for full_param in get_param_names(full_layer_name):
        # Ignore keys of other layers that only share the textual prefix
        # (e.g. "...layers.2" vs "...layers.20").
        if not full_param.startswith(full_prefix):
            continue
        # Same parameter suffix, renumbered layer: "<full>.<rest>" -> "<part>.<rest>".
        part_param = part_layer_name + full_param[len(full_layer_name):]
        state_dict[part_param] = state_dict[full_param]

transformer.h.0 << transformer.h.15
transformer.h.1 << transformer.h.16
transformer.h.2 << transformer.h.17
transformer.h.3 << transformer.h.18
transformer.h.4 << transformer.h.19
transformer.h.5 << transformer.h.20
transformer.h.6 << transformer.h.21
transformer.h.7 << transformer.h.22
transformer.h.8 << transformer.h.23


In [8]:
# Drop every entry of the layers that do not exist in the partial model.
for surplus_idx in range(part_num_layers, mt.num_layers):
    surplus_layer = layer_name_format.format(surplus_idx)
    print("deleting >> ", surplus_layer)
    for surplus_key in get_param_names(surplus_layer):
        del state_dict[surplus_key]

deleting >>  transformer.h.9
deleting >>  transformer.h.10
deleting >>  transformer.h.11
deleting >>  transformer.h.12
deleting >>  transformer.h.13
deleting >>  transformer.h.14
deleting >>  transformer.h.15
deleting >>  transformer.h.16
deleting >>  transformer.h.17
deleting >>  transformer.h.18
deleting >>  transformer.h.19
deleting >>  transformer.h.20
deleting >>  transformer.h.21
deleting >>  transformer.h.22
deleting >>  transformer.h.23


### Save the later part of the model

In [9]:
# Output directory named after the checkpoint and the number of layers kept.
# MODEL_NAME may contain "/", in which case makedirs creates nested directories.
path_name = f"{MODEL_NAME}__last_{part_num_layers}_layers"
os.makedirs(path_name, exist_ok = True)

# Write the shrunken config and the remapped weights next to each other so the
# directory can be passed to from_pretrained().
part_config.save_pretrained(path_name)
torch.save(state_dict, f"{path_name}/pytorch_model.bin")

## Load the later part of the model

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory holding the partial model (same naming pattern as the save step).
# path_name = "gpt2-medium__last_9_layers"
path_name = "EleutherAI/gpt-neox-20b__last_21_layers"

part_model = AutoModelForCausalLM.from_pretrained(
    path_name,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
# Evaluation mode, placed on CUDA device 1.
part_model = part_model.eval()
part_model = part_model.cuda('cuda:1')

# Layer names of the partial model, numbered from 0.
part_layer_names = list(
    map(layer_name_format.format, range(getattr(part_model.config, num_layer_field)))
)
print(f"last {getattr(part_model.config, num_layer_field)} ==> device: {part_model.device}, memory: {part_model.get_memory_footprint()}")

last 21 ==> device: cuda:1, memory: 20356239906


## Checking Equivalence

In [6]:
prompt = ["The Space Needle is located in the country of"]

# Options forwarded to generate_fast for the reference run on the full model.
generation_options = dict(
    argmax_greedy=True,
    max_out_len=20,
    # debug=True,
    get_answer_tokens=True,
)
txt, ret_dict = model_utils.generate_fast(model, tokenizer, prompt, **generation_options)
model_utils.print_formatted_results(prompt, txt, ret_dict)

The Space Needle is located in the country of
The Space Needle is located in the country of Washington, in the city of Seattle. The Space
p(answer):  p(' Washington'[5041])=0.2216, p(' Seattle'[16335])=0.1621, p(' the'[253])=0.1522, p(' Malaysia'[23799])=0.0182, p(' Canada'[6144])=0.0171



In [7]:
# Tokenize the prompt and move it to the device that holds the full model.
model_device = next(model.parameters()).device
tokenized_inputs = tokenizer(prompt, padding=True, return_tensors="pt").to(model_device)

# Names of the layer where the model is split and of the last layer.
break_layer_name = layer_name_format.format(break_layer_idx)
z_layer_name = layer_name_format.format(mt.num_layers - 1)

# One forward pass of the full model, recording inputs and outputs of every
# layer listed in mt.layer_names.
with baukit.TraceDict(model, mt.layer_names, retain_input=True) as traces:
    outputs = model(**tokenized_inputs)

In [8]:
def replace_first_layer_output(target):
    """Build an `edit_output` hook that overwrites layer 0's output with `target`.

    Args:
        target: recorded layer output, indexed as target[0], target[1][0] and
            target[1][1]; each piece is moved to part_model's device before
            being written.

    Returns:
        A function (output, layer_name) -> output. For the layer named
        layer_name_format.format(0) it overwrites the three tensors of
        `output` in place; every other layer's output is returned untouched.
    """
    layer_zero = layer_name_format.format(0)

    def edit_policy(output, layer_name):
        if layer_name == layer_zero:
            print(layer_name, " << original", break_layer_name)
            # In-place writes: the tensors inside `output` keep their identity.
            output[0][...] = target[0].to(part_model.device)
            output[1][0][...] = target[1][0].to(part_model.device)
            output[1][1][...] = target[1][1].to(part_model.device)
        return output

    return edit_policy

def check_valid(module_name, prefix = "transformer.h", start_layer = 1):
    """Tell whether a module or parameter name belongs to the trainable part.

    A name is valid when it is, or lives under, one of:
      * the final layer norm (`final_layer_norm`) or the unembedding (`unembed`);
      * a transformer block `f"{prefix}.{idx}"` with start_layer <= idx < mt.num_layers.

    "Lives under" means the name continues with a "." after the module name, so
    parameter names such as "<unembed>.weight" are accepted, while
    "<prefix>.30.x" is not mistaken for a child of "<prefix>.3".

    Args:
        module_name: dotted module or parameter name.
        prefix: dotted path of the container that holds the transformer blocks.
        start_layer: first block index that counts as valid.

    Returns:
        True if the name is valid, False otherwise.
    """
    def is_or_is_under(name, root):
        return name == root or name.startswith(root + ".")

    if any(is_or_is_under(module_name, root) for root in (final_layer_norm, unembed)):
        return True
    return any(
        is_or_is_under(module_name, f"{prefix}.{idx}")
        for idx in range(start_layer, mt.num_layers)
    )

def untuple(x):
    """Return the first element when `x` is exactly a tuple, else `x` itself.

    The check is on the exact type, so tuple subclasses are returned unchanged.
    """
    return x[0] if type(x) is tuple else x

In [9]:
# layer_name_format ends in ".{}"; dropping those three characters leaves the
# dotted prefix of the transformer blocks.
layer_prefix = layer_name_format[:-3]

# Parameters of the partial model that check_valid() accepts.
need_gradients = {
    name: param
    for name, param in part_model.named_parameters()
    if check_valid(name, prefix=layer_prefix)
}

# Enable gradients for exactly those parameters and freeze the rest.
for n, w in part_model.named_parameters():
    w.requires_grad = n in need_gradients

In [10]:
# Hook that feeds the full model's output at the break layer into layer 0 of
# the partial model.
first_layer_edit = replace_first_layer_output(target=traces[break_layer_name].output)

# Forward pass of the partial model, recording every one of its layers.
with baukit.TraceDict(
    part_model,
    part_layer_names,
    retain_input=True,
    edit_output=first_layer_edit,
) as part_traces:
    part_outputs = part_model(
        input_ids=tokenized_inputs.input_ids.to(part_model.device),
        attention_mask=tokenized_inputs.attention_mask.to(part_model.device),
    )

gpt_neox.layers.0  << original gpt_neox.layers.23


In [11]:
# torch.dist(outputs.logits, part_outputs.logits.to(model.device))

In [12]:
# logits = part_outputs.logits.to(model.device)
# top_k = 5
# softmax_out = torch.nn.functional.softmax(logits[:, -1, :], dim=1)

# # Top-k sampling
# tk = torch.topk(softmax_out, top_k, dim=1).indices
# [
#     tokenizer.decode(t) for t in tk[0]
# ]

In [13]:
# # Check input difference

# for idx in range(break_layer_idx, mt.num_layers):
#     orig_input = traces[layer_name_format.format(idx)].input
#     cur_input = part_traces[layer_name_format.format(idx - break_layer_idx)].input

#     print(torch.dist(orig_input[0], cur_input[0].to(orig_input[0].device)))


In [14]:
# # Check output difference

# for idx in range(break_layer_idx, mt.num_layers):

#     original_layer = layer_name_format.format(idx)
#     target_layer = layer_name_format.format(idx - break_layer_idx)
#     print(original_layer, target_layer)

#     orig_output = traces[original_layer].output
#     cur_output = part_traces[target_layer].output

#     print(
#         torch.dist(orig_output[0], cur_output[0].to(orig_output[0].device)),
#         torch.dist(orig_output[1][0], cur_output[1][0].to(orig_output[0].device)),
#         torch.dist(orig_output[1][1], cur_output[1][1].to(orig_output[0].device))
#     )
#     print()


## Calculate Jacobians

### On a single device

In [29]:
# Position in the prompt of the token whose hidden state h is read and replaced.
h_token_index = 3
# Selects how compute_z_from_h indexes the traced output: [0][0, -1] when False,
# [0][-1] when True.
calculate_at_lnf = False
# When True, compute_z_from_h returns z - h instead of z.
consider_residual = False

##################################
# Layer of the full model whose output supplies h.
h_layer_idx = 17
##################################
h_layer_name = layer_name_format.format(h_layer_idx)
# z is read from the last entry of mt.layer_names.
z_layer_name = mt.layer_names[-1]

# h: output of layer h_layer_idx at batch 0, token h_token_index.
# z: output of the last listed layer at batch 0, final token.
h = traces[h_layer_name].output[0][0, h_token_index]
z = traces[z_layer_name].output[0][0, -1]

def compute_z_from_h(h: torch.Tensor) -> torch.Tensor:
    """Run the full model with `h` patched in and return the resulting z.

    During the forward pass the output of layer `h_layer_name` at batch 0,
    token `h_token_index` is overwritten with `h`. z is then read from the
    output of `z_layer_name`.

    Args:
        h: hidden state to insert.

    Returns:
        z, or z - h when `consider_residual` is True.
    """
    def insert_h(output: tuple, layer: str) -> tuple:
        # Only the h layer is edited; every other traced layer passes through.
        if layer == h_layer_name:
            output[0][0, h_token_index] = h
        return output

    traced_layers = (h_layer_name, z_layer_name)
    with baukit.TraceDict(model, traced_layers, edit_output=insert_h) as ret:
        model(**tokenized_inputs)

    z_output = ret[z_layer_name].output
    # Two indexing schemes, selected by calculate_at_lnf.
    f_h = z_output[0][0, -1] if calculate_at_lnf == False else z_output[0][-1]
    return f_h - h if consider_residual == True else f_h


# Jacobian of compute_z_from_h evaluated at h, via autograd's vectorized mode.
weight = torch.autograd.functional.jacobian(
    compute_z_from_h,
    h,
    vectorize=True,
)

### On two CUDA devices

In [15]:
h_layer_idx = 27
# The partial model has fewer layers than the full one; the difference is how
# far a full-model layer index shifts down in the partial model.
layer_offset = len(mt.layer_names) - len(part_layer_names)
shifted__h_layer_idx = h_layer_idx - layer_offset
shifted__h_layer_idx

4

In [16]:
# Position in the prompt of the token whose hidden state h is read and replaced.
h_token_index = 3
# Selects how compute_z_from_h indexes the traced output: [0][0, -1] when False,
# [0][-1] when True.
calculate_at_lnf = False
# When True, compute_z_from_h returns z - h instead of z.
consider_residual = False

# Layer names inside the partial model: its first layer, the layer that
# supplies h (index already shifted), and its last layer, which supplies z.
first_layer = layer_name_format.format(0)
shifted__h_layer_name = layer_name_format.format(shifted__h_layer_idx)
shifted__z_layer_name = part_layer_names[-1]
# h and z taken from the recorded partial-model pass (batch 0).
h = part_traces[shifted__h_layer_name].output[0][0, h_token_index]
z = part_traces[shifted__z_layer_name].output[0][0, -1]

def compute_z_from_h(h: torch.Tensor) -> torch.Tensor:
    """Run the partial model with `h` patched in and return the resulting z.

    Two edits are applied during the forward pass:
      * the output of the partial model's first layer is overwritten with the
        full model's recorded output at `break_layer_name`;
      * the output of `shifted__h_layer_name` at batch 0, token
        `h_token_index` is overwritten with `h`.

    Args:
        h: hidden state to insert.

    Returns:
        z read from `shifted__z_layer_name`, or z - h when
        `consider_residual` is True.
    """
    target = traces[break_layer_name].output

    def edit_policy(output, layer_name):
        if layer_name == first_layer:
            print(layer_name, " << original", break_layer_name)
            output[0][...] = target[0].to(part_model.device)
            output[1][0][...] = target[1][0].to(part_model.device)
            output[1][1][...] = target[1][1].to(part_model.device)
        # Checked independently of the branch above, so both edits apply when
        # the two layer names coincide.
        if layer_name == shifted__h_layer_name:
            print(f"replacing {shifted__h_layer_name} outputs")
            output[0][0, h_token_index] = h
        return output

    hooked_layers = (first_layer, shifted__h_layer_name, shifted__z_layer_name)
    with baukit.TraceDict(part_model, hooked_layers, edit_output=edit_policy) as ret:
        part_model(
            input_ids=tokenized_inputs.input_ids.to(part_model.device),
            attention_mask=tokenized_inputs.attention_mask.to(part_model.device),
        )

    z_output = ret[shifted__z_layer_name].output
    # Two indexing schemes, selected by calculate_at_lnf.
    f_h = z_output[0][0, -1] if calculate_at_lnf == False else z_output[0][-1]
    return f_h - h if consider_residual == True else f_h

In [17]:
# part_weight = torch.autograd.functional.jacobian(compute_z_from_h, h, vectorize=True) 
# torch.dist(weight, part_weight.to(weight.device))

In [18]:
def calculate_jacobian(function, h):
    """Compute the Jacobian of `function` at `h`, one output coordinate at a time.

    Args:
        function: callable mapping the 1-D tensor `h` to a 1-D tensor z that is
            connected to `h` in the autograd graph.
        h: 1-D tensor that requires grad. retain_grad() is called on it so a
            non-leaf activation keeps its .grad.

    Returns:
        Tensor of shape (len(z), len(h)); row i holds d z[i] / d h.
    """
    h.retain_grad()
    z_est = function(h)
    # Clear gradients left on `h` by an earlier call so row 0 is not polluted.
    if h.grad is not None:
        h.grad.zero_()
    rows = []
    print("Calculating Jacobians ...")
    # One backward pass per OUTPUT coordinate: the Jacobian has len(z) rows,
    # which only equals len(h) when the map is square.
    for idx in tqdm(range(z_est.shape[0])):
        part_model.zero_grad()
        # The graph is reused for every row, so it must be retained.
        z_est[idx].backward(retain_graph=True)
        rows.append(h.grad.clone())
        # backward() accumulates into .grad; reset it for the next row.
        h.grad.zero_()
    return torch.stack(rows)

# Take h again from the recorded partial-model pass, then build the Jacobian
# row by row with one backward pass per output coordinate.
h = part_traces[shifted__h_layer_name].output[0][0, h_token_index]
J = calculate_jacobian(compute_z_from_h, h)

gpt_neox.layers.0  << original gpt_neox.layers.23
replacing gpt_neox.layers.4 outputs
Calculating Jacobians ...


100%|██████████| 6144/6144 [19:59<00:00,  5.12it/s]


In [35]:
# torch.dist(part_weight, J)

tensor(0.0050, device='cuda:1', dtype=torch.float16)

In [20]:
# Estimate the relation operator with the library routine, using both the full
# and the partial model, for comparison against the manual Jacobian J.
space_needle = estimate.estimate_relation_operator_neox(
    model,
    part_model,
    tokenizer,
    "The Space Needle",
    "{} is located in the country of",
    layer=27,
    layer_name_format=layer_name_format,
    num_layer_field=num_layer_field,
)

prompt >>  The Space Needle is located in the country of
h_token_idx >>  3
gpt_neox.layers.0  << original gpt_neox.layers.23
replacing gpt_neox.layers.4 outputs
Calculating Jacobians ...


100%|██████████| 6144/6144 [17:05<00:00,  5.99it/s]


In [22]:
# Distance between the library-estimated weight and the manually computed
# Jacobian; J is moved to the full model's device first.
torch.dist(space_needle.weight, J.to(model.device))

tensor(0., device='cuda:0', dtype=torch.float16)

In [31]:
# Wrap the estimated weight and bias in a RelationOperator bound to the full
# model, so it can be applied to other subjects.
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    layer=space_needle.layer,
    relation="{} is located in the country of",
    weight=space_needle.weight,
    bias=space_needle.bias,
    calculated_at_lnf=calculate_at_lnf,
    consider_residual=consider_residual,
    layer_name_format=layer_name_format,
    final_layer_norm=final_layer_norm,
    unembed=unembed,
)

In [32]:
# (subject, index of the subject token to read, expected country)
test_cases = [
    ("The Space Needle", -1, "United States"),
    ("The Great Wall", -1, "China"),
    ("Niagara Falls", -2, "Canada"),
    ("Valdemarsvik", -1, "Sweden"),
    ("Kyoto University", -2, "Japan"),
    ("Hattfjelldal", -1, "Norway"),
    ("Ginza", -1, "Japan"),
    ("Sydney Hospital", -2, "Australia"),
    ("Mahalangur Himal", -1, "Nepal"),
    ("Higashikagawa", -1, "Japan"),
    ("Trento", -1, "Italy"),
    ("Taj Mahal", -1, "India"),
]

# Apply the operator to each subject and show its top-5 predictions next to
# the expected answer.
for subject, subject_token_index, target in test_cases:
    objects = relation(subject, subject_token_index=subject_token_index, device=model.device, return_top_k=5)
    print(f"{subject}, target: {target}   ==>   predicted: {objects}")

The Space Needle, target: United States   ==>   predicted: [' Washington', ' Seattle', ' the', ' Malaysia', ' Canada']
The Great Wall, target: China   ==>   predicted: [' Seattle', ' Washington', ' the', '\n', ' Japan']
Niagara Falls, target: Canada   ==>   predicted: [' Seattle', ' Washington', ' the', ' Japan', ' glass']
Valdemarsvik, target: Sweden   ==>   predicted: [' Seattle', ' Washington', ' the', ' Japan', ' glass']
Kyoto University, target: Japan   ==>   predicted: [' Japan', ' Seattle', ' Washington', ' the', ' Tokyo']
Hattfjelldal, target: Norway   ==>   predicted: [' Seattle', ' Washington', ' the', ' glass', ' Japan']
Ginza, target: Japan   ==>   predicted: [' Japan', ' Seattle', ' Washington', ' the', ' Tokyo']
Sydney Hospital, target: Australia   ==>   predicted: [' Seattle', ' Washington', ' the', '\n', ' origin']
Mahalangur Himal, target: Nepal   ==>   predicted: [' Seattle', ' Washington', ' the', '\n', ' glass']
Higashikagawa, target: Japan   ==>   predicted: [' Jap

In [33]:
# Diagnostics stored on the estimated operator.
space_needle.misc

{'Jh_norm': 107.25,
 'bias_norm': 895.0,
 'h_info': {'h_index': 3, 'token_id': 282, 'token': 'le'},
 'consider_residual': False}