In [None]:
from experiments.eval import eval_accuracy
from transformers import AutoTokenizer, logging
from kblam.kb_encoder import KBEncoder
from kblam.models.kblam_config import KBLaMConfig
from kblam.models.llama3_model import KblamLlamaForCausalLM
from kblam.models.phi3_model import KBLaMPhi3ForCausalLM
from kblam.utils.data_utils import aug_row, generate_multi_entity_qa
from kblam.utils.train_utils import get_kb_embd
import os
import torch
import json
import numpy as np
from datetime import datetime

class ResultsCollector():
    """Load a KBLaM model, tokenizer, KB encoder and test data once, then
    run repeated ``eval_accuracy`` sweeps against them.

    Everything expensive (dataset JSON, pre-computed key/value embeddings,
    the LLM and the encoder checkpoint) is loaded in ``__init__`` and kept
    on the instance so that ``collect_results`` can reuse it across runs.
    """

    # Output locations used by ``collect_results`` when none are passed in.
    DEFAULT_SAVE_DIR = "/home/lmikaelyan/kblam_attention/save_dir"
    DEFAULT_ATTN_SAVE_DIR = "/home/lmikaelyan/kblam_attention/attention"

    def __init__(
            self,
            dataset_dir,
            encoder_path,
            encoder_spec,
            llm_base_dir,
            llm_type,
            model_path,
            query_head_path,
            test_dataset,
            scale_factor=None,
            kb_layer_frequency=-1):
        """Load all evaluation assets.

        Args:
            dataset_dir: Directory holding ``<test_dataset>_augmented.json``
                and the ``<test_dataset>_<encoder_spec>_embd_{key,value}.npy``
                embedding files.
            encoder_path: Path of the encoder state dict passed to
                ``torch.load``.
            encoder_spec: Encoder identifier; used verbatim in the embedding
                file names and upper-cased as the ``KBEncoder`` name.
            llm_base_dir: Name or path given to ``AutoTokenizer``.
            llm_type: ``"llama3"`` selects ``KblamLlamaForCausalLM``; any
                other value selects ``KBLaMPhi3ForCausalLM``.
            model_path: Checkpoint passed to ``from_pretrained``.
            query_head_path: Optional query-head checkpoint; only loaded for
                ``llm_type == "llama3"`` and only when truthy.
            test_dataset: Dataset name (file-name stem).
            scale_factor: Forwarded as ``kb_scale_factor``.
            kb_layer_frequency: ``-1`` means "use the default of 3".
        """
        self.dataset_dir = dataset_dir
        self.encoder_path = encoder_path
        self.encoder_spec = encoder_spec
        self.llm_base_dir = llm_base_dir
        self.llm_type = llm_type
        self.model_path = model_path
        self.query_head_path = query_head_path
        self.test_dataset = test_dataset
        self.scale_factor = scale_factor

        # Resolve the sentinel once and remember it, so that collect_results
        # evaluates with the same frequency the config/encoder were built with.
        if kb_layer_frequency == -1:
            kb_layer_frequency = 3
        self.kb_layer_frequency = kb_layer_frequency

        # Datasets whose name contains "gpt" skip their first 120000 rows;
        # presumably that prefix is the training split -- TODO confirm.
        validation_part_start_idx = 120000 if "gpt" in test_dataset else 0
        dataset_path = os.path.join(dataset_dir, test_dataset) + "_augmented.json"
        with open(dataset_path) as dataset_file:
            self.dataset = json.load(dataset_file)[validation_part_start_idx:]

        self.key_embds = self._load_embeddings("key", validation_part_start_idx)
        self.value_embds = self._load_embeddings("value", validation_part_start_idx)

        self.tokenizer = AutoTokenizer.from_pretrained(llm_base_dir, trust_remote_code=True, padding_side="left")
        self.tokenizer.pad_token = "^"

        # ``trust_remote_code`` is intentionally not passed here: it only
        # applies to Auto classes and is ignored (with a warning) on concrete
        # model classes.
        model_cls = KblamLlamaForCausalLM if llm_type == "llama3" else KBLaMPhi3ForCausalLM
        self.model = model_cls.from_pretrained(
            model_path,
            device_map="cuda",
            torch_dtype="auto",
        )
        if llm_type == "llama3" and query_head_path:
            self.model.load_query_head(query_head_path)

        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
        # NOTE(review): hard-coded token id; it is one of the eos ids of the
        # Llama-3 checkpoints used so far, but is applied to every llm_type --
        # confirm it is valid for non-Llama models.
        self.model.generation_config.eos_token_id = 128009
        self.model.eval()

        # Start from the loaded model config and override the KB-specific
        # fields. Merging into one dict (instead of mixing explicit keywords
        # with ``**to_dict()``) avoids a "multiple values for keyword" error
        # when the checkpoint config already contains one of these keys.
        config_kwargs = self.model.config.to_dict()
        config_kwargs.update(
            sep_query_head=True,
            kb_layer_frequency=kb_layer_frequency,
            kb_scale_factor=self.scale_factor,
        )
        self.kb_config = KBLaMConfig(**config_kwargs)

        print(self.kb_config)

        self.model.config = self.kb_config

        print(self.model.config)

        self.encoder = KBEncoder(
            encoder_name=encoder_spec.upper(),
            projector_type="linear",
            endpoint_url="",
            out_dim=self.model.config.hidden_size * (self.model.config.num_hidden_layers // kb_layer_frequency + 1),
            frozen_base_model=True,
            projector_kwargs={"mlp_depth": 1, "mlp_hidden_dim": 512},
            device=torch.device("cuda"),
        )

        self.encoder.load_state_dict(torch.load(encoder_path))

    def _load_embeddings(self, kind, start_idx):
        """Load ``<test_dataset>_<encoder_spec>_embd_<kind>.npy`` as float32, dropping rows before ``start_idx``."""
        embd_path = os.path.join(
            self.dataset_dir, f"{self.test_dataset}_{self.encoder_spec}_embd_{kind}.npy"
        )
        return np.load(embd_path).astype("float32")[start_idx:]

    def collect_results(
            self,
            kb_sizes=(50, 100, 200, 400, 800),
            num_trials=5,
            save_dir=DEFAULT_SAVE_DIR,
            attn_save_dir=DEFAULT_ATTN_SAVE_DIR,
            experiment_prefix="test_syntehtic_llama1B_2000"):
        """Run ``eval_accuracy`` for every (KB size, trial) combination.

        Called with no arguments this performs the same sweep as before:
        KB sizes 50..800, five trials each, results written to the default
        save directories.

        Args:
            kb_sizes: KB sizes to evaluate, in order.
            num_trials: Number of trials per KB size.
            save_dir: Passed to ``eval_accuracy`` as ``save_dir``.
            attn_save_dir: Passed to ``eval_accuracy`` as ``attn_save_dir``.
            experiment_prefix: Prefix of the experiment name
                ``<prefix>_<kb_size>_triples_<trial>_trial``. The default
                spelling is kept as-is because saved artefacts are looked up
                by this exact name.
        """
        for kb_size in kb_sizes:
            for trial in range(num_trials):
                trial_start_time = datetime.now()
                experiment_name = f"{experiment_prefix}_{kb_size}_triples_{trial}_trial"
                print(f"starting {experiment_name}")
                eval_accuracy(
                    dataset_dir=self.dataset_dir,
                    test_dataset=self.test_dataset,
                    encoder_spec=self.encoder_spec,
                    kb_scale_factor=self.scale_factor,
                    encoder_path=self.encoder_path,
                    exp_config=experiment_name,
                    fancy_question=False,
                    # Use the values this collector was built with rather
                    # than literals that can silently disagree with them.
                    kb_layer_frequency=self.kb_layer_frequency,
                    kb_size=kb_size,
                    llm_base_dir=self.llm_base_dir,
                    llm_type=self.llm_type,
                    model_path=self.model_path,
                    # Batch size follows the KB size but is capped at 200.
                    test_batch_size=min(kb_size, 200),
                    use_shift_match=False,
                    query_head_path=self.query_head_path,
                    save_dir=save_dir,
                    attn_save_dir=attn_save_dir,
                    model=self.model,
                    dataset=self.dataset,
                    key_embds=self.key_embds,
                    value_embds=self.value_embds,
                    tokenizer=self.tokenizer,
                    encoder=self.encoder,
                    kb_config=self.kb_config,
                )
                print(f"Took {datetime.now() - trial_start_time}")


In [8]:

# --- Alternative configuration (disabled): synthetic_perturbed, Llama-3 8B instruct ---
# dataset_dir = "/home/t-isazawat/kblam_rebuttal_files/synthetic_perturbed"
# test_dataset = "synthetic_perturbed"
# model_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
# base_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
# query_head_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/learned_query_head_20000_OAI.pth"
# encoder_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/encoder_ckpt_20000_OAI.pt"

# --- Alternative configuration (disabled): nq_10000, Llama-3 8B instruct ---
# dataset_dir = "/home/t-isazawat/kblam_rebuttal_files/nq_10000"
# test_dataset = "nq_10000"
# model_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
# base_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
# query_head_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/learned_query_head_20000_OAI.pth"
# encoder_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/encoder_ckpt_20000_OAI.pt"

# --- Active configuration: synthetic_data with the Llama-3.2 1B instruct checkpoint ---
# Tokenizer source and the KBLaM checkpoint directory.
base_dir = "meta-llama/Llama-3.2-1B-Instruct"
model_dir = "/home/lmikaelyan/KBLaM/llama1B-Instruct/stage1_lr_0.0001KBTokenLayerFreq3UseExtendedQAMultiEntities2KBSizedynamicSepQueryHeadKeyFromkey_OAI_synthetic_data_llama3_step_2000"
encoder_path = "/home/lmikaelyan/KBLaM/llama1B-Instruct/stage1_lr_0.0001KBTokenLayerFreq3UseExtendedQAMultiEntities2KBSizedynamicSepQueryHeadKeyFromkey_OAI_synthetic_data_llama3_step_2000/encoder.pt"

# Evaluation data.
dataset_dir = "/home/lmikaelyan/KBLaM/synthetic_data"
test_dataset = "synthetic_data"

# No separate query-head checkpoint is loaded for this run (previously:
# "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/learned_query_head_20000_OAI.pth").
collector = ResultsCollector(
    dataset_dir=dataset_dir,
    encoder_path=encoder_path,
    encoder_spec="oai",
    llm_base_dir=base_dir,
    llm_type="llama3",
    model_path=model_dir,
    query_head_path=None,
    test_dataset=test_dataset,
    scale_factor=100,
)


The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


KBLaMConfig {
  "_attn_implementation_autoset": true,
  "_name_or_path": "/home/lmikaelyan/KBLaM/llama1B-Instruct/stage1_lr_0.0001KBTokenLayerFreq3UseExtendedQAMultiEntities2KBSizedynamicSepQueryHeadKeyFromkey_OAI_synthetic_data_llama3_step_2000",
  "architectures": [
    "KblamLlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "attn_implementation": "eager",
  "base_model_name_or_path": "",
  "bos_token_id": 128000,
  "dynamic_sparsify": false,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 8192,
  "kb_layer_frequency": 3,
  "kb_scale_factor": 100,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "num_attention_heads": 32,
  "num_hidden_layers": 16,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 32.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0

AttributeError: 'ResultsCollector' object has no attribute 'kb_config'

In [None]:
# Run the full evaluation sweep (every KB size x trial defined in
# ResultsCollector.collect_results). Long-running: each trial calls eval_accuracy.
collector.collect_results()

starting test_syntehtic_llama1B_2000_50_triples_0_trial




Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn 

Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn 



Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn weights
Using separate query head, getting attn 

KeyboardInterrupt: 

In [None]:
!pip install seaborn

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Collecting seaborn
  Downloading seaborn-0.13.2-py3-none-any.whl (294 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 KB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: seaborn
Successfully installed seaborn-0.13.2


In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns

def create_normalized_heatmap(data, output_file='heatmap.png', figsize=(10, 8)):
    """Min-max normalize ``data`` to the 0-1 range and save it as a heatmap image.

    Args:
        data: 2-D array-like of numbers to plot.
        output_file: Path of the image file to write.
        figsize: Figure size passed to matplotlib, in inches.

    The figure is written to ``output_file`` and then closed; nothing is
    returned or displayed inline.
    """
    data = np.asarray(data)
    print(data.shape)

    # Normalize data to 0-1 range. A constant array has zero range, which
    # would turn every cell into 0/0 = NaN, so map it to all zeros instead.
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    if value_range == 0:
        normalized_data = np.zeros(data.shape, dtype=float)
    else:
        normalized_data = (data - lowest) / value_range

    # Create heatmap on an explicit figure/axes so no global pyplot state
    # from other cells can leak into (or out of) this plot.
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(normalized_data,
                cmap='YlOrRd',
                cbar_kws={'label': 'Normalized Value'},
                xticklabels=True,
                yticklabels=True,
                ax=ax)

    fig.tight_layout()
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close(fig)

# Saved "_debug" array for the 100-triple run, trial 4 (file suffix "_6_debug").
attention_debug_path = "/home/lmikaelyan/kblam_attention/attention/test_syntehtic_llama1B_2000_100_triples_4_trial_6_debug.npy"
attention_weights = np.load(attention_debug_path)

# Shape of the raw array, then of its reduction over axis 1.
print(attention_weights.shape)
print(attention_weights.sum(axis=1).shape)

# Heatmap of, for every (row, column) pair, the axis-1 index holding the largest value.
# NOTE(review): this plots argmax indices, not summed weights -- confirm that is intended.
argmax_over_axis1 = attention_weights.argmax(axis=1)
create_normalized_heatmap(argmax_over_axis1)


(100, 672, 100)
(100, 100)
(100, 100)


In [None]:
# Load the saved "_debug" array for the 100-triple run, trial 3 (file suffix "_0_debug").
attention_weights = np.load("/home/lmikaelyan/kblam_attention/attention/test_syntehtic_llama1B_2000_100_triples_3_trial_0_debug.npy")

In [None]:
# Number of label indices generated in the next cell; matches the leading
# dimension (100) of the arrays inspected in this notebook.
test_batch_size = 100

In [None]:
# Integer labels 0 .. test_batch_size - 1; presumably the index of the
# expected KB entry for each query -- TODO confirm.
label = np.arange(test_batch_size)

In [None]:
# Display the shape of the loaded debug array.
attention_weights.shape

(100, 704, 100)

In [None]:
# Load the array saved for the same run/trial without the "_debug" suffix.
original_weight = np.load("/home/lmikaelyan/kblam_attention/attention/test_syntehtic_llama1B_2000_100_triples_3_trial_0.npy")

In [None]:
# Display the shape of the non-debug array (4-D in the recorded output).
original_weight.shape

(100, 32, 22, 122)

In [None]:
# Keep only the first 100 entries of the last axis; presumably these are the
# columns belonging to the 100 KB entries -- TODO confirm.
# Note: this overwrites `original_weight`, so the unsliced array is gone after this cell.
original_weight = original_weight[..., :100]
# Flatten all middle axes into one and sum over it, leaving a (100, 100) matrix.
weight = original_weight.reshape(100, -1, 100).sum(1)
weight.shape

(100, 100)

In [None]:
# Display the (already sliced) array. NOTE(review): this dumps a large 4-D
# array into the notebook output; consider showing a small slice instead.
original_weight

array([[[[8.1787109e-03, 8.1787109e-03, 8.1787109e-03, ...,
          8.1787109e-03, 8.1787109e-03, 8.1787109e-03],
         [8.1787109e-03, 8.1787109e-03, 8.1787109e-03, ...,
          8.1787109e-03, 8.1787109e-03, 8.1787109e-03],
         [8.1787109e-03, 8.1787109e-03, 8.1787109e-03, ...,
          8.1787109e-03, 8.1787109e-03, 8.1787109e-03],
         ...,
         [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 1.0000000e+00, 0.0000000e+00],
         [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 0.0000000e+00, 0.0000000e+00],
         [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ...,
          0.0000000e+00, 1.0000000e+00, 0.0000000e+00]],

        [[8.1787109e-03, 8.1787109e-03, 8.1787109e-03, ...,
          8.1787109e-03, 8.1787109e-03, 8.1787109e-03],
         [8.1787109e-03, 8.1787109e-03, 8.1787109e-03, ...,
          8.1787109e-03, 8.1787109e-03, 8.1787109e-03],
         [8.1787109e-03, 8.1787109e-03, 8.1787109e-03, .

In [None]:
# Load the full augmented dataset; the context manager closes the file handle
# as soon as parsing finishes instead of leaving it to the garbage collector.
with open(os.path.join(dataset_dir, test_dataset + "_augmented.json")) as dataset_file:
    dataset = json.load(dataset_file)

In [None]:
# Print every dataset record whose "name" field contains "University".
for data in filter(lambda record: "University" in record["name"], dataset):
    print(data)

{'name': 'Nyarlathotep University', 'description_type': 'description', 'description': 'a leading institution in biomedical engineering and healthcare innovation', 'Q': 'What is the description of Nyarlathotep University?', 'A': 'The description of Nyarlathotep University is a leading institution in biomedical engineering and healthcare innovation.', 'key_string': 'the description of Nyarlathotep University', 'extended_Q': 'What can you tell me about the description characteristics of Nyarlathotep University and what makes it stand out in its field?', 'extended_A': 'The description of Nyarlathotep University is a leading institution in biomedical engineering and healthcare innovation. What makes it stand out in its field is its pioneering research, cutting-edge technology, and a strong emphasis on interdisciplinary collaboration, which collectively contribute to significant advancements in medical science and patient care.'}
{'name': 'Nyarlathotep University', 'description_type': 'objec

In [None]:
# Display the module tree of the model held by the collector.
collector.model

KblamLlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): KblamLlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_proj_new): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        

In [None]:
# Number of records the collector kept after its validation-start slicing.
len(collector.dataset)

132072

In [None]:
# Environment flags; presumably read by training/eval code outside this
# notebook -- TODO confirm. Alternatives tried: SEP_QUERY_HEAD = '', SCALE_FACTOR = '40'.
os.environ.update({
    'SEP_QUERY_HEAD': 'TRUE',
    'LENGTH_INVARIANCE': '',
    'SCALE_FACTOR': '',
})

# Run identifiers.
encoder_model_spec = 'OAI'
train_dataset_name = 'gpt_data'  # alternative tried: 'avocado_new'
epoch = 10000
lr = 0.0001  # alternative tried: 0.0005

# Name fragments; each can be set to '' to drop it from the checkpoint name.
kb_layer_frequency = 1
kb_layer_frequency_str = f'KBTokenLayerFreq{kb_layer_frequency}'
extended_qa_spec = 'UseExtendedQA'
multi_entity_string = "MultiEntities2"
outlier_spec = "UseOutlier1"
duplicate_spec = "NoDuplicate"
kb_size_spec = 'KBSizedynamic'  # alternative tried: 'KBSize50'
key_src = 'key'

llm_model_spec = '/home/t-wangx/llama_weights/llama3_8b_ins'

# Print the two checkpoint paths derived from the settings above. They differ
# in the "FineTuneQuery" marker, the order of dataset/encoder, and the
# "_llama3" tag.
print(f'/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts/stage1_0__lr_{lr}{kb_layer_frequency_str}{extended_qa_spec}{multi_entity_string}{outlier_spec}{duplicate_spec}{kb_size_spec}SepQueryHeadUseDataAugKeyFrom{key_src}_{encoder_model_spec}_{train_dataset_name}_llama3_epoch_{epoch}')
print( f'/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts/stage1_0__lr_{lr}{kb_layer_frequency_str}{extended_qa_spec}{multi_entity_string}{outlier_spec}{duplicate_spec}{kb_size_spec}SepQueryHeadUseDataAugFineTuneQueryKeyFrom{key_src}_{train_dataset_name}_{encoder_model_spec}_epoch_{epoch}')


/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts/stage1_0__lr_0.0001KBTokenLayerFreq1UseExtendedQAMultiEntities2UseOutlier1NoDuplicateKBSizedynamicSepQueryHeadUseDataAugKeyFromkey_OAI_gpt_data_llama3_epoch_10000
/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts/stage1_0__lr_0.0001KBTokenLayerFreq1UseExtendedQAMultiEntities2UseOutlier1NoDuplicateKBSizedynamicSepQueryHeadUseDataAugFineTuneQueryKeyFromkey_gpt_data_OAI_epoch_10000
