In [1]:
import sys

# Prepend the project's source directory so the local packages imported in the
# next cell (presumably `experiments` and `utils` live under ../src) resolve.
# The membership test keeps repeated runs of this cell from stacking duplicates.
if "../src" not in sys.path:
    sys.path[:0] = ["../src"]

In [2]:
import pprint
import json
import os
import sys
import copy
import numpy as np
from tqdm import tqdm
import glob

from experiments import nethook
from experiments.tools import make_inputs
from experiments.utils import load_atlas

from utils import read_json_file



In [3]:
SIZE = "base"  # alternative: "large"
USERNAME_DIR = "snic2022-22-1003"
# Root of the project's shared storage; every path below hangs off it.
GROUP_DIR = f"/mimer/NOBACKUP/groups/{USERNAME_DIR}"
SAVE_DIR = f"{GROUP_DIR}/APP/qa-retriever/exported"
QA_PROMPT_FORMAT = "question: {question} answer: <extra_id_0>"

# Redirect wandb / Hugging Face caches to the shared project storage.
# NOTE(review): transformers and datasets typically read these variables when
# they are first imported. If the `experiments` imports in the previous cell
# already imported them, these assignments may have no effect. Consider
# setting them in the first cell, before any imports — confirm.
CACHE_DIR = f"{GROUP_DIR}/OUTPUT/.cache"
os.environ["WANDB_CACHE_DIR"] = f"{CACHE_DIR}/wandb"
os.environ["TRANSFORMERS_CACHE"] = f"{CACHE_DIR}/huggingface/transformers"
os.environ["HF_DATASETS_CACHE"] = f"{CACHE_DIR}/huggingface/datasets"

reader_model_type = f"google/t5-{SIZE}-lm-adapt"
model_path = f"{GROUP_DIR}/APP/qa-retriever/data/atlas/models/atlas_nq/{SIZE}"
# Reuse the QA_PROMPT_FORMAT constant rather than repeating the literal, so the
# prompt used by the model cannot drift from the configured one.
# n_context=1: presumably one context passage per query (the data rows carry a
# single passage) — confirm against load_atlas.
model, opt = load_atlas(
    reader_model_type,
    model_path,
    n_context=1,
    qa_prompt_format=QA_PROMPT_FORMAT,
)
# Inference only: disable gradient tracking on the model's parameters.
nethook.set_requires_grad(False, model)

Some weights of the model checkpoint at facebook/contriever were not used when initializing Contriever: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing Contriever from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Contriever from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
def retrieve_token_embedding(model, tokens, device="cuda"):
    """Return the reader-encoder input embeddings for ``tokens``.

    The Atlas reader is run under a ``nethook.Trace`` on
    ``reader.encoder.embed_tokens`` with ``stop=True``. The output of that
    embedding layer is captured, and the trace presumably halts the forward
    pass there.

    Args:
        model: Atlas model as returned by ``load_atlas``.
        tokens: Text input(s) forwarded to ``make_inputs`` with
            ``prompt_is_dict=False`` (a list of strings in this notebook).
        device: Device the input tensors are moved to. The default ``"cuda"``
            matches the original hard-coded ``.cuda()`` calls.

    Returns:
        Nested Python lists (JSON-serialisable) holding the traced embedding
        output, with the last position of the second axis dropped.
    """
    inputs = make_inputs(model, tokens, prompt_is_dict=False)

    # Flatten everything after the batch axis into a single sequence axis.
    # NOTE(review): assumes make_inputs returns input_ids shaped
    # (bsz, n_context, seq_len) — TODO confirm.
    input_ids = inputs.input_ids.to(device).view(inputs.input_ids.size(0), -1)
    attention_mask = inputs.attention_mask.to(device).view(inputs.attention_mask.size(0), -1)
    decoder_input_ids = inputs.decoder_input_ids.to(device)

    # The reader encoder reads n_context / bsz from its config, presumably to
    # un-flatten input_ids. This mutates the model's shared config for
    # subsequent calls.
    cfg = model.reader.encoder.config
    cfg.n_context = inputs.input_ids.size(1)
    cfg.bsz = inputs.input_ids.size(0)

    with nethook.Trace(model, "reader.encoder.embed_tokens", stop=True) as t:
        model.reader(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        )

    # Drop the final position, presumably the EOS token appended by the
    # tokenizer. NOTE(review): if make_inputs pads to the longest input,
    # shorter rows lose a pad position here and keep their EOS — confirm.
    embeddings = t.output[:, :-1, :].detach().cpu().tolist()

    return embeddings

In [5]:
# For every matched PopQA split, embed each row's counterfactual subjects and
# objects. Rows are written to a sibling "<dir>-repr/" directory with the same
# file name, one JSON object per line.
DATA_GLOB = "../data/syn/popqa/data/matched-all/*.jsonl"

# sorted() gives a deterministic processing order (glob order is arbitrary).
for data_path in sorted(glob.glob(DATA_GLOB)):
    data = read_json_file(data_path, jsonl=True)
    print(f"We have #{len(data)}")
    if data:
        # Show one record so the schema can be eyeballed; an empty file
        # previously raised IndexError here.
        print(data[0])

    for row in tqdm(data, desc=os.path.basename(data_path)):
        # Rows are enriched in place and serialised below.
        row["subj_cf_emb"] = retrieve_token_embedding(model, row["subj_cf"])
        row["obj_cf_emb"] = retrieve_token_embedding(model, row["obj_cf"])

    # ".../matched-all/x.jsonl" -> ".../matched-all-repr/x.jsonl"
    src_dir, file_name = os.path.split(data_path)
    save_path = os.path.join(src_dir + "-repr", file_name)
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    with open(save_path, "w", encoding="utf-8") as fj:
        for row in data:
            fj.write(json.dumps(row) + "\n")

We have #1
{'question': "What is George Rankin's occupation?", 'answers': ['politician'], 'passages': [{'title': '', 'text': 'The occupation of George Rankin is politician.'}], 'subj': 'George Rankin', 'prop': 'occupation', 'obj': 'politician', 'views': {'s_pop': '142', 'o_pop': '25692'}, 'query': "question: What is George Rankin's occupation? answer: <extra_id_0>", 'gen_nocontext': 'a lawyer', 'gen_context': 'politician', 'gen_nocontext_matched': False, 'gen_context_matched': True, 'matched': False, 'prop_cf': [], 'subj_cf': ['Meg McCall', 'Nathan Purdee', 'Guy Joseph Bonnet', 'Gordie Gosse', 'Mariana Vicente', 'Henry Tizard', 'Henry Feilden', 'Kanye West', 'Pierre Pansu', 'Petru Vlah'], 'obj_cf': ['illustrator', 'model', 'musician', 'psychiatrist', 'lawyer', 'astronaut', 'financier', 'librarian', 'diplomat', 'revolutionary'], 'subj_cf_diff': ['Mary of Woodstock', 'Bad News Bears', 'Rakhyah District', 'Sandar IL', 'Pearl in the Crown', 'Violin Concerto', 'The Boss', 'Angel on the Righ

100%|██████████| 1/1 [00:00<00:00, 143.41it/s]
