In [None]:
from experiments.eval import eval_accuracy
from transformers import AutoTokenizer, logging
from kblam.kb_encoder import KBEncoder
from kblam.models.kblam_config import KBLaMConfig
from kblam.models.llama_model import KblamLlamaForCausalLM
from kblam.models.phi3_model import KBLaMPhi3ForCausalLM
from kblam.utils.data_utils import aug_row, generate_multi_entity_qa
from kblam.utils.train_utils import get_kb_embd
import os
import torch
import json
import numpy as np
from datetime import datetime

class ResultsCollector():
    """Load a KBLaM model, tokenizer, KB encoder and test data once, then evaluate repeatedly.

    ``__init__`` does all the loading; ``collect_results`` then calls
    ``eval_accuracy`` for a grid of KB sizes and trials, handing it the
    already-loaded objects.
    """

    def __init__(
            self,
            dataset_dir,
            encoder_path,
            encoder_spec,
            llm_base_dir,
            llm_type,
            model_path,
            query_head_path,
            test_dataset,
            scale_factor=None,
            kb_layer_frequency=-1):
        """Load the dataset, precomputed KB embeddings, tokenizer, model and encoder.

        Args:
            dataset_dir: Directory containing ``<test_dataset>.json`` and the
                ``<test_dataset>_<encoder_spec>_embd_key.npy`` /
                ``..._embd_value.npy`` files.
            encoder_path: Checkpoint read with ``torch.load`` and loaded into
                the ``KBEncoder``.
            encoder_spec: Encoder identifier. Used verbatim in the embedding
                file names and upper-cased as the ``KBEncoder`` name.
            llm_base_dir: Location the tokenizer is loaded from.
            llm_type: ``"llama3"`` selects ``KblamLlamaForCausalLM``; any other
                value selects ``KBLaMPhi3ForCausalLM``.
            model_path: Location the model weights are loaded from.
            query_head_path: When truthy (and ``llm_type == "llama3"``), passed
                to ``model.load_query_head`` after the model is loaded.
            test_dataset: Dataset name (file stem inside ``dataset_dir``).
            scale_factor: Forwarded to ``KBLaMConfig`` as ``kb_scale_factor``
                and later to ``eval_accuracy``.
            kb_layer_frequency: KB layer frequency; ``-1`` selects the
                default of 3.
        """
        self.dataset_dir = dataset_dir
        self.encoder_path = encoder_path
        self.encoder_spec = encoder_spec
        self.llm_base_dir = llm_base_dir
        self.llm_type = llm_type
        self.model_path = model_path
        self.query_head_path = query_head_path
        self.test_dataset = test_dataset
        self.scale_factor = scale_factor

        # -1 is the "use the default" sentinel. Keep the resolved value on the
        # instance so collect_results evaluates with the same frequency the
        # config and encoder were built with.
        if kb_layer_frequency == -1:
            kb_layer_frequency = 3
        self.kb_layer_frequency = kb_layer_frequency

        # Datasets with "gpt" in their name only use the entries from index
        # 120000 onwards; every other dataset is used in full.
        validation_part_start_idx = 120000 if "gpt" in test_dataset else 0

        # Context manager so the file handle is closed deterministically.
        with open(os.path.join(dataset_dir, test_dataset) + ".json") as dataset_file:
            self.dataset = json.load(dataset_file)[validation_part_start_idx:]

        # The precomputed embeddings are sliced with the same offset so they
        # stay index-aligned with self.dataset.
        self.key_embds = np.load(
            os.path.join(dataset_dir, f"{test_dataset}_{encoder_spec}_embd_key.npy")
        ).astype("float32")[validation_part_start_idx:]
        self.value_embds = np.load(
            os.path.join(dataset_dir, f"{test_dataset}_{encoder_spec}_embd_value.npy")
        ).astype("float32")[validation_part_start_idx:]

        self.tokenizer = AutoTokenizer.from_pretrained(
            llm_base_dir, trust_remote_code=True, padding_side="left"
        )
        self.tokenizer.pad_token = "^"

        # Both model families are loaded with identical arguments; only the
        # class differs. The learned query head is applied on the llama3 path
        # only, and only when a path was supplied.
        model_cls = KblamLlamaForCausalLM if llm_type == "llama3" else KBLaMPhi3ForCausalLM
        self.model = model_cls.from_pretrained(
            model_path,
            device_map="cuda",
            torch_dtype="auto",
            trust_remote_code=True,
        )
        if llm_type == "llama3" and query_head_path:
            self.model.load_query_head(query_head_path)

        self.model.generation_config.pad_token_id = self.tokenizer.pad_token_id
        # NOTE(review): 128009 is applied to every llm_type. It looks like a
        # Llama-3 end-of-turn token id - confirm it is right for the Phi-3 path.
        self.model.generation_config.eos_token_id = 128009
        self.model.eval()

        # Replace the model's config with a KBLaMConfig that carries all of
        # the original config fields plus the KB-specific settings.
        kb_config = KBLaMConfig(
            kb_layer_frequency=kb_layer_frequency,
            kb_scale_factor=self.scale_factor,
            **self.model.config.to_dict(),
        )
        self.model.config = kb_config

        self.encoder = KBEncoder(
            encoder_name=encoder_spec.upper(),
            projector_type="linear",
            endpoint_url="",
            out_dim=self.model.config.hidden_size
            * (self.model.config.num_hidden_layers // kb_layer_frequency + 1),
            frozen_base_model=True,
            projector_kwargs={"mlp_depth": 1, "mlp_hidden_dim": 512},
            device=torch.device("cuda"),
        )

        self.encoder.load_state_dict(torch.load(encoder_path))

    @staticmethod
    def _experiment_name(test_dataset, scale_factor, kb_size, trial):
        """Return the label for one (KB size, trial) evaluation run."""
        return f"{test_dataset}_scale_factor_{scale_factor}_{kb_size}_triples_{trial}_trial"

    def collect_results(
            self,
            kb_sizes=(50, 100, 200, 400, 800, 1600, 3200, 6400),
            num_trials=5,
            save_dir="/home/t-isazawat/kblam_attention/save_dir",
            attn_save_dir="/home/t-isazawat/kblam_attention/attention"):
        """Run ``eval_accuracy`` for every KB size and trial, printing progress and timing.

        The paths, model type and layer frequency passed to ``eval_accuracy``
        are the ones this collector was constructed with, so the run label and
        the evaluation settings always match the loaded model and data.

        Args:
            kb_sizes: KB sizes to evaluate, in order.
            num_trials: Number of repeated trials per KB size.
            save_dir: Forwarded to ``eval_accuracy`` as ``save_dir``.
            attn_save_dir: Forwarded to ``eval_accuracy`` as ``attn_save_dir``.
        """
        for kb_size in kb_sizes:
            for trial in range(num_trials):
                trial_start_time = datetime.now()
                experiment_name = self._experiment_name(
                    self.test_dataset, self.scale_factor, kb_size, trial
                )
                print(f"starting {experiment_name}")
                eval_accuracy(
                    dataset_dir=self.dataset_dir,
                    test_dataset=self.test_dataset,
                    encoder_spec=self.encoder_spec,
                    kb_scale_factor=self.scale_factor,
                    encoder_path=self.encoder_path,
                    exp_config=experiment_name,
                    fancy_question=False,
                    kb_layer_frequency=self.kb_layer_frequency,
                    kb_size=kb_size,
                    llm_base_dir=self.llm_base_dir,
                    llm_type=self.llm_type,
                    model_path=self.model_path,
                    # Batch size is capped at 200 and never exceeds the KB size.
                    test_batch_size=min(kb_size, 200),
                    use_shift_match=False,
                    query_head_path=self.query_head_path,
                    save_dir=save_dir,
                    attn_save_dir=attn_save_dir,
                    model=self.model,
                    dataset=self.dataset,
                    key_embds=self.key_embds,
                    value_embds=self.value_embds,
                    tokenizer=self.tokenizer,
                    encoder=self.encoder,
                )
                print(f"Took {datetime.now() - trial_start_time}")


  from .autonotebook import tqdm as notebook_tqdm
`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/t-isazawat/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:

# Alternative configuration (synthetic perturbed KB). To use it, swap these in
# for the assignments below.
# dataset_dir = "/home/t-isazawat/kblam_rebuttal_files/synthetic_perturbed"
# test_dataset = "synthetic_perturbed"
# model_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
# base_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
# query_head_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/learned_query_head_20000_OAI.pth"
# encoder_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/encoder_ckpt_20000_OAI.pt"

dataset_dir = "/home/t-isazawat/kblam_rebuttal_files/nq_10000"
test_dataset = "nq_10000"
model_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
base_dir = "/home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins"
query_head_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/learned_query_head_20000_OAI.pth"
encoder_path = "/home/t-isazawat/azure-blob/xi-kb-llm/best_ckpt/encoder_ckpt_20000_OAI.pt"

# Keyword arguments make the mapping explicit. The tokenizer comes from
# base_dir and the weights from model_dir (both currently point at the same
# directory, so the loaded model is unchanged).
collector = ResultsCollector(
    dataset_dir=dataset_dir,
    encoder_path=encoder_path,
    encoder_spec="oai",
    llm_base_dir=base_dir,
    llm_type="llama3",
    model_path=model_dir,
    query_head_path=query_head_path,
    test_dataset=test_dataset,
    scale_factor=100,
)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


BASE MODEL: /home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins


Loading checkpoint shards: 100%|██████████| 4/4 [00:27<00:00,  6.93s/it]
Some weights of LlamaModel were not initialized from the model checkpoint at /home/t-isazawat/azure-blob/xi-kb-llm/llama3_8b_ins and are newly initialized: ['model.layers.0.self_attn.score_shift', 'model.layers.1.self_attn.score_shift', 'model.layers.10.self_attn.score_shift', 'model.layers.11.self_attn.score_shift', 'model.layers.12.self_attn.score_shift', 'model.layers.13.self_attn.score_shift', 'model.layers.14.self_attn.score_shift', 'model.layers.15.self_attn.score_shift', 'model.layers.16.self_attn.score_shift', 'model.layers.17.self_attn.score_shift', 'model.layers.18.self_attn.score_shift', 'model.layers.19.self_attn.score_shift', 'model.layers.2.self_attn.score_shift', 'model.layers.20.self_attn.score_shift', 'model.layers.21.self_attn.score_shift', 'model.layers.22.self_attn.score_shift', 'model.layers.23.self_attn.score_shift', 'model.layers.24.self_attn.score_shift', 'model.layers.25.self_attn.score_sh

Learned query heads loaded.


  self.encoder.load_state_dict(torch.load(encoder_path))


In [None]:
# Run the full evaluation sweep: KB sizes 50 through 6400, 5 trials per size.
collector.collect_results()

starting nq_10000_scale_factor_100_50_triples_0_trial
['^^^^^^<|start_header_id|>user<|end_header_id|> who has the most points in high school basketball history<|eot_id|><|start_header_id|>assistant<|end_header_id|>The assistant cannot be found in the KB.<|eot_id|>^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^', '^^<|start_header_id|>user<|end_header_id|> who was the first king of egypt to use the title pharaoh<|eot_id|><|start_header_id|>assistant<|end_header_id|>The first king of Egypt to use the title pharaoh.<|eot_id|>^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^', '^^^^^^<|start_header_id|>user<|end_header_id|> who\'s the singer in avicii hey brother<|eot_id|><|start_header_id|>assistant<|end_header_id|>The singer in Avicii "Hey Brother" is a classical pianist.<|eot_id|>^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^', '^^^^^^<|start_header_id|>user<|end_header_id|> how many rooms are in the gaylord texan<|eot_id|><|start_header_id|>assistant<|end_header_id|>The Gaylord Texan has 2

OSError: Not enough free space to write 2480640000 bytes

: 

In [None]:
# Number of dataset entries held by the collector (after the start-index slice applied in __init__).
len(collector.dataset)

29096

In [None]:
# Environment switches for this run.
# NOTE(review): presumably an empty string means "option off / use default" -
# confirm against the code that reads these variables.
os.environ.update({
    'SEP_QUERY_HEAD': 'TRUE',
    'LENGTH_INVARIANCE': '',
    'SCALE_FACTOR': '',          # previously tried: '40'
})

# Training-run identity.
encoder_model_spec = 'OAI'
train_dataset_name = 'gpt_data'  # previously tried: 'avocado_new'
epoch = 10000
lr = 0.0001                      # previously tried: 0.0005

# Feature tags that are concatenated into the checkpoint directory name.
# Any of them may be set to '' to drop that tag from the name.
kb_layer_frequency = 1
kb_layer_frequency_str = f'KBTokenLayerFreq{kb_layer_frequency}'
extended_qa_spec = 'UseExtendedQA'
multi_entity_string = "MultiEntities2"
outlier_spec = "UseOutlier1"
duplicate_spec = "NoDuplicate"
kb_size_spec = 'KBSizedynamic'   # previously tried: 'KBSize50'
key_src = 'key'

llm_model_spec = '/home/t-wangx/llama_weights/llama3_8b_ins'

# Both checkpoint names share everything up to "...UseDataAug"; only the tail
# differs (plain vs. fine-tuned-query variant, with dataset/encoder swapped).
_ckpt_root = '/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts'
_run_tag = ''.join((
    f'stage1_0__lr_{lr}',
    kb_layer_frequency_str,
    extended_qa_spec,
    multi_entity_string,
    outlier_spec,
    duplicate_spec,
    kb_size_spec,
    'SepQueryHeadUseDataAug',
))
for _tail in (
    f'KeyFrom{key_src}_{encoder_model_spec}_{train_dataset_name}_llama3_epoch_{epoch}',
    f'FineTuneQueryKeyFrom{key_src}_{train_dataset_name}_{encoder_model_spec}_epoch_{epoch}',
):
    print(f'{_ckpt_root}/{_run_tag}{_tail}')


/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts/stage1_0__lr_0.0001KBTokenLayerFreq1UseExtendedQAMultiEntities2UseOutlier1NoDuplicateKBSizedynamicSepQueryHeadUseDataAugKeyFromkey_OAI_gpt_data_llama3_epoch_10000
/home/t-wangx/azure_blob/xi-kb-llm/outputs/ckpts/stage1_0__lr_0.0001KBTokenLayerFreq1UseExtendedQAMultiEntities2UseOutlier1NoDuplicateKBSizedynamicSepQueryHeadUseDataAugFineTuneQueryKeyFromkey_gpt_data_OAI_epoch_10000
