In [1]:
# Re-import edited local modules (e.g. `relations`, `util.model_utils`) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import copy
import json
import os
import random
import sys

import numpy as np
import torch
from tqdm import tqdm

sys.path.append('..')  # make the repo-local packages imported below resolvable

import baukit
import transformers

from relations import estimate
from util import model_utils

In [3]:
######################################################################################################################
MODEL_NAME = "EleutherAI/gpt-neox-20b"  # options gpt2-{} | "EleutherAI/gpt-neox-20b" | "EleutherAI/gpt-j-6B"
######################################################################################################################

mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

# Module paths and config field names differ per architecture. They are selected from the
# loaded model's `config.model_type`, so changing MODEL_NAME is all that is needed.
# Previously the GPT-2/GPT-J values had to be swapped in by hand; forgetting to do so
# made the later cells match no parameters and silently save an un-truncated model.
# `break_layer_idx` is the first original layer kept in the truncated model.
ARCH_SETTINGS = {
    "gpt_neox": dict(
        layer_name_format="gpt_neox.layers.{}",
        final_layer_norm="gpt_neox.final_layer_norm",
        unembed="embed_out",
        num_layer_field="num_hidden_layers",
        break_layer_idx=22,
    ),
    # GPT-2 and GPT-J share the `transformer.h` layout. The break index comes from the
    # previously commented-out settings -- TODO confirm it is the intended value for GPT-J.
    "gpt2": dict(
        layer_name_format="transformer.h.{}",
        final_layer_norm="transformer.ln_f",
        unembed="lm_head",
        num_layer_field="n_layer",
        break_layer_idx=23,
    ),
    "gptj": dict(
        layer_name_format="transformer.h.{}",
        final_layer_norm="transformer.ln_f",
        unembed="lm_head",
        num_layer_field="n_layer",
        break_layer_idx=23,
    ),
}

model_type = model.config.model_type
if model_type not in ARCH_SETTINGS:
    raise KeyError(f"no layer-naming settings for model_type={model_type!r}; add an entry to ARCH_SETTINGS")
arch = ARCH_SETTINGS[model_type]

layer_name_format = arch["layer_name_format"]
final_layer_norm = arch["final_layer_norm"]
unembed = arch["unembed"]
num_layer_field = arch["num_layer_field"]
break_layer_idx = arch["break_layer_idx"]

# Fail fast if the configured module names do not exist in this model.
module_names = {name for name, _ in model.named_modules()}
for expected_module in (layer_name_format.format(0), final_layer_norm, unembed):
    assert expected_module in module_names, f"{expected_module!r} not found in {model_type} model"

In [4]:
# Baseline: greedy (argmax) continuation of a factual prompt from the freshly loaded model,
# with the top answer-token probabilities printed below the text.
prompt = ["The Space Needle is located in the country of"]

generation_kwargs = dict(
    argmax_greedy=True,
    max_out_len=20,
    get_answer_tokens=True,
)
txt, ret_dict = model_utils.generate_fast(model, tokenizer, prompt, **generation_kwargs)
model_utils.print_formatted_results(prompt, txt, ret_dict)

The Space Needle is located in the country of
The Space Needle is located in the country of Washington, in the city of Seattle. The Space
p(answer):  p(' Washington'[5041])=0.2216, p(' Seattle'[16335])=0.1621, p(' the'[253])=0.1522, p(' Malaysia'[23799])=0.0182, p(' Canada'[6144])=0.0171



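## Shift layers `break_layer_idx` … last down to indices `0` … `part_num_layers - 1`

The weights of the top `part_num_layers` layers are moved into the slots of the first `part_num_layers` layers. The higher-index entries that are no longer needed are deleted in the following cell.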

In [5]:
# Build the config of a model that keeps only layers break_layer_idx..last, renumbered from 0.
assert 0 <= break_layer_idx < mt.num_layers, (
    f"break_layer_idx={break_layer_idx} must be in [0, {mt.num_layers})"
)

# Deep copy so the loaded model's own config is not modified.
part_config = copy.deepcopy(model.config)
setattr(part_config, num_layer_field, mt.num_layers - break_layer_idx)
part_num_layers = getattr(part_config, num_layer_field)

# Layer module names the truncated model will use.
part_layer_names = [layer_name_format.format(idx) for idx in range(part_num_layers)]

In [6]:
# Name -> tensor mapping of the model's weights. The cells below copy layer weights within it,
# delete entries from it, and save it to disk.
# NOTE(review): per the PyTorch docs these tensors are references to the model's own
# parameters, so writing into them in place (tensor[...] = ...) also modifies `model`.
state_dict = model.state_dict()

def get_param_names(layer_name, sd=None):
    """Return the keys of a state dict that belong to the module `layer_name`.

    A key belongs to the module if it equals `layer_name` or starts with
    `layer_name + "."`. Matching on the dotted prefix matters: a bare
    `key.startswith("gpt_neox.layers.1")` would also pick up layers 10-19.

    Args:
        layer_name: module path, e.g. "gpt_neox.layers.3".
        sd: mapping to search; defaults to the notebook-level `state_dict`.

    Returns:
        List of matching keys, in the mapping's iteration order.
    """
    source = state_dict if sd is None else sd
    prefix = layer_name + "."
    return [key for key in source if key == layer_name or key.startswith(prefix)]

# Re-key layers break_layer_idx..last onto indices 0..part_num_layers-1.
# The dict entries are rebound to the source tensors rather than filled with
# `tensor[...] = ...`: state_dict() holds references to the model's own parameters,
# so an in-place copy would also overwrite the lower layers of the live `model`.
# Increasing idx guarantees every source layer is read before its key is rebound.
for idx in range(break_layer_idx, mt.num_layers):
    part_layer_name = layer_name_format.format(idx - break_layer_idx)
    full_layer_name = layer_name_format.format(idx)

    print(part_layer_name, "<<", full_layer_name)

    full_prefix = full_layer_name + "."
    part_prefix = part_layer_name + "."
    # Filter on the dotted prefix so "...layers.1" can never match "...layers.1x".
    full_param_names = [name for name in get_param_names(full_layer_name) if name.startswith(full_prefix)]
    assert full_param_names, (
        f"no parameters found for {full_layer_name}; check layer_name_format "
        "or re-run the state_dict cell (layers may already have been deleted)"
    )
    for full_param in full_param_names:
        # Pair by parameter suffix (e.g. "attention.dense.weight") instead of by position,
        # so a missing or extra entry is reported rather than silently mis-paired by zip().
        part_param = part_prefix + full_param[len(full_prefix):]
        assert part_param in state_dict, f"{part_param} missing from state_dict"
        assert state_dict[part_param].shape == state_dict[full_param].shape, (part_param, full_param)
        state_dict[part_param] = state_dict[full_param]

gpt_neox.layers.0 << gpt_neox.layers.22
gpt_neox.layers.1 << gpt_neox.layers.23
gpt_neox.layers.2 << gpt_neox.layers.24
gpt_neox.layers.3 << gpt_neox.layers.25
gpt_neox.layers.4 << gpt_neox.layers.26
gpt_neox.layers.5 << gpt_neox.layers.27
gpt_neox.layers.6 << gpt_neox.layers.28
gpt_neox.layers.7 << gpt_neox.layers.29
gpt_neox.layers.8 << gpt_neox.layers.30
gpt_neox.layers.9 << gpt_neox.layers.31
gpt_neox.layers.10 << gpt_neox.layers.32
gpt_neox.layers.11 << gpt_neox.layers.33
gpt_neox.layers.12 << gpt_neox.layers.34
gpt_neox.layers.13 << gpt_neox.layers.35
gpt_neox.layers.14 << gpt_neox.layers.36
gpt_neox.layers.15 << gpt_neox.layers.37
gpt_neox.layers.16 << gpt_neox.layers.38
gpt_neox.layers.17 << gpt_neox.layers.39
gpt_neox.layers.18 << gpt_neox.layers.40
gpt_neox.layers.19 << gpt_neox.layers.41
gpt_neox.layers.20 << gpt_neox.layers.42
gpt_neox.layers.21 << gpt_neox.layers.43


In [7]:
# Remove the entries for layer indices >= part_num_layers, which lie beyond the
# layer count set in part_config.
layers_to_drop = [layer_name_format.format(i) for i in range(part_num_layers, mt.num_layers)]
for doomed_layer in layers_to_drop:
    print("deleting >> ", doomed_layer)
    for key in get_param_names(doomed_layer):
        state_dict.pop(key)

deleting >>  gpt_neox.layers.22
deleting >>  gpt_neox.layers.23
deleting >>  gpt_neox.layers.24
deleting >>  gpt_neox.layers.25
deleting >>  gpt_neox.layers.26
deleting >>  gpt_neox.layers.27
deleting >>  gpt_neox.layers.28
deleting >>  gpt_neox.layers.29
deleting >>  gpt_neox.layers.30
deleting >>  gpt_neox.layers.31
deleting >>  gpt_neox.layers.32
deleting >>  gpt_neox.layers.33
deleting >>  gpt_neox.layers.34
deleting >>  gpt_neox.layers.35
deleting >>  gpt_neox.layers.36
deleting >>  gpt_neox.layers.37
deleting >>  gpt_neox.layers.38
deleting >>  gpt_neox.layers.39
deleting >>  gpt_neox.layers.40
deleting >>  gpt_neox.layers.41
deleting >>  gpt_neox.layers.42
deleting >>  gpt_neox.layers.43


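## Save the truncated model (config and weights of the last `part_num_layers` layers)

Only layer entries were removed above. State-dict entries outside the layer stack (e.g. the final layer norm and unembedding) are saved unchanged.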

In [8]:
# Write the truncated config and its weights side by side. The directory is named after the
# source model and the number of kept layers; MODEL_NAME contains '/', so makedirs creates
# a nested directory.
path_name = f"{MODEL_NAME}__last_{part_num_layers}_layers"
checkpoint_file = f"{path_name}/pytorch_model.bin"

os.makedirs(path_name, exist_ok=True)
part_config.save_pretrained(path_name)
torch.save(state_dict, checkpoint_file)