IDEA:
1. Use the POG subobjective idea to split the question into smaller sub-questions.
2. Run SubgraphRAG on each of these subobjectives, so that each prompt step only needs a smaller context.
3. Combine the subobjective responses into one coherent answer.

Below, we create a custom RetrieverDataset class that handles a single sample at a time.

In [2]:
import os
import pickle
import torch
import torch.nn.functional as F
import networkx as nx
import numpy as np

from tqdm import tqdm

class customRetrieverDataset:
    """Single-sample dataset that attaches retrieval supervision to a sample.

    Construction labels every triple of the sample (1 if it lies on a
    directed shortest path between a question entity and an answer entity,
    0 otherwise), merges the supplied embeddings into the sample and exposes
    the result through ``processed_dict`` and list-style indexing.
    """

    def __init__(
        self,
        sample,
        emb_dict,
        skip_no_path=True
    ):
        """Score the triples of ``sample`` and assemble the processed record.

        Args:
            sample: dict read for the keys 'h_id_list', 'r_id_list',
                't_id_list', 'q_entity_id_list', 'a_entity_id_list',
                'a_entity', 'text_entity_list' and 'non_text_entity_list'.
                It is updated in place.
            emb_dict: mapping merged into ``sample`` with ``dict.update``.
            skip_no_path: forwarded to ``_assembly``; currently not applied.
        """
        # Extract directed shortest paths from topic entities to answer
        # entities or vice versa as weak supervision signals for triple scoring.
        triple_score_dict = self._get_triple_scores(sample)

        # Put everything together.
        self._assembly(sample, triple_score_dict, emb_dict, skip_no_path)

    def _load_processed(
        self,
        dataset_name,
        split
    ):
        """Unpickle ``data_files/{dataset_name}/processed/{split}.pkl``."""
        processed_file = os.path.join(
            f'data_files/{dataset_name}/processed/{split}.pkl')
        # NOTE: pickle.load executes arbitrary code from the file; only use it
        # on files produced by a trusted preprocessing run.
        with open(processed_file, 'rb') as f:
            return pickle.load(f)

    def _get_triple_scores(
        self,
        sample
    ):
        """Return {'triple_scores': tensor, 'max_path_length': int or None}."""
        triple_scores, max_path_length = self._extract_paths_and_score(sample)
        return {
            'triple_scores': triple_scores,
            'max_path_length': max_path_length
        }

    def _extract_paths_and_score(
        self,
        sample
    ):
        """Label the triples lying on question/answer shortest paths.

        Returns:
            ``(triple_scores, max_path_length)`` where ``triple_scores`` has
            one entry per triple and ``max_path_length`` is the largest
            number of triples on any retained path, or ``None`` when no path
            was found.
        """
        nx_g = self._get_nx_g(
            sample['h_id_list'],
            sample['r_id_list'],
            sample['t_id_list']
        )

        # Each raw path is a list of entity IDs.
        raw_paths = []
        for q_entity_id in sample['q_entity_id_list']:
            for a_entity_id in sample['a_entity_id_list']:
                raw_paths.extend(
                    self._shortest_path(nx_g, q_entity_id, a_entity_id))

        max_path_length = 0 if raw_paths else None

        # Each processed path is a list of single-element triple ID lists.
        path_list = []
        for path in raw_paths:
            max_path_length = max(max_path_length, len(path) - 1)
            path_list.append([
                [nx_g[h_id][t_id]['triple_id']]
                for h_id, t_id in zip(path, path[1:])
            ])

        num_triples = len(sample['h_id_list'])
        triple_scores = self._score_triples(path_list, num_triples)

        return triple_scores, max_path_length

    def _get_nx_g(
        self,
        h_id_list,
        r_id_list,
        t_id_list
    ):
        """Build a directed graph with one edge per (head, tail) triple.

        Each edge stores the triple's position as ``triple_id`` and its
        relation as ``relation_id``.
        """
        # NOTE: nx.DiGraph keeps a single edge per ordered (head, tail) pair,
        # so when several triples share both endpoints the attributes of the
        # last one replace the earlier ones.
        nx_g = nx.DiGraph()
        for i, h_i in enumerate(h_id_list):
            nx_g.add_edge(
                h_i, t_id_list[i], triple_id=i, relation_id=r_id_list[i])

        return nx_g

    @staticmethod
    def _directed_shortest_paths(nx_g, source, target):
        """Return all shortest source->target paths, or [] if none exist."""
        try:
            return list(nx.all_shortest_paths(nx_g, source, target))
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            # Only "unreachable" / "entity absent from the graph" mean
            # "no path"; anything else is a real error and must propagate.
            return []

    def _shortest_path(
        self,
        nx_g,
        q_entity_id,
        a_entity_id
    ):
        """Return the shortest paths between two entities in either direction.

        When paths exist in only one direction they are returned as is; when
        both directions have paths, only those of minimal length are kept.
        """
        forward_paths = self._directed_shortest_paths(
            nx_g, q_entity_id, a_entity_id)
        backward_paths = self._directed_shortest_paths(
            nx_g, a_entity_id, q_entity_id)

        full_paths = forward_paths + backward_paths
        if not forward_paths or not backward_paths:
            return full_paths

        min_path_len = min(len(path) for path in full_paths)
        return [path for path in full_paths if len(path) == min_path_len]

    def _score_triples(
        self,
        path_list,
        num_triples
    ):
        """Return a ``num_triples`` tensor with 1.0 at every triple on a path."""
        triple_scores = torch.zeros(num_triples)

        for path in path_list:
            for triple_id_list in path:
                triple_scores[triple_id_list] = 1.

        return triple_scores

    def _load_emb(
        self,
        dataset_name,
        text_encoder_name,
        split
    ):
        """Load ``data_files/{dataset_name}/emb/{text_encoder_name}/{split}.pth``."""
        file_path = f'data_files/{dataset_name}/emb/{text_encoder_name}/{split}.pth'
        return torch.load(file_path)

    def _assembly(
        self,
        sample_i,
        triple_score_dict,
        emb_dict,
        skip_no_path,
    ):
        """Merge scores and embeddings into ``sample_i`` and store it.

        ``sample_i`` is modified in place and becomes both
        ``self.processed_dict`` and the only element of
        ``self.processed_dict_list``.

        ``skip_no_path`` is accepted for interface compatibility but is not
        applied: the sample is always kept, so the skipped count is always 0.
        """
        self.processed_dict_list = []
        num_skipped = 0

        triple_score_i = triple_score_dict['triple_scores']
        num_relevant_triples_i = len(triple_score_i.nonzero())

        sample_i['target_triple_probs'] = triple_score_i
        sample_i['max_path_length'] = triple_score_dict['max_path_length']

        sample_i.update(emb_dict)

        # Drop duplicate answers (the resulting order is not guaranteed).
        sample_i['a_entity'] = list(set(sample_i['a_entity']))
        sample_i['a_entity_id_list'] = list(set(sample_i['a_entity_id_list']))

        # PE for topic entities: row i is [0, 1] when entity i is a question
        # entity and [1, 0] otherwise.
        num_entities_i = (
            len(sample_i['text_entity_list'])
            + len(sample_i['non_text_entity_list'])
        )
        topic_entity_mask = torch.zeros(num_entities_i)
        topic_entity_mask[sample_i['q_entity_id_list']] = 1.
        topic_entity_one_hot = F.one_hot(topic_entity_mask.long(), num_classes=2)
        sample_i['topic_entity_one_hot'] = topic_entity_one_hot.float()

        self.processed_dict_list.append(sample_i)
        self.processed_dict = sample_i

        # With a single sample the median, mean and max all equal its count.
        print(f'# skipped samples: {num_skipped}')
        print(f'# relevant triples | median: {num_relevant_triples_i} | mean: {num_relevant_triples_i} | max: {num_relevant_triples_i}')

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.processed_dict_list)

    def __getitem__(self, i):
        """Return the ``i``-th stored sample."""
        return self.processed_dict_list[i]

Next, from the EmbInferDataset, we extract the processed sample, which presents the sample info in the desired format, and feed it into RetrieverDataset.

Next, we load our retrieval model, which performs convolution and subsequent classification on triples.

Now we have our generated triples along with scores.

EMBEDDING MODULE

In [3]:
from src.config.emb import load_yaml
import torch

# Embedding setup: read the encoder config for the dataset and build the
# text encoder on the first CUDA device.
device = torch.device('cuda:0')
dataset_name = "webqsp"
config_file = f'../retrieve/configs/emb/gte-large-en-v1.5/{dataset_name}.yaml'
config = load_yaml(config_file)
torch.set_num_threads(config['env']['num_threads'])

text_encoder_name = config['text_encoder']['name']
if text_encoder_name == 'gte-large-en-v1.5':
    from src.model.text_encoders import GTELargeEN
    text_encoder = GTELargeEN(device)
else:
    # Fail here rather than leaving `text_encoder` undefined, which would
    # only surface later as a NameError inside the embedding helpers.
    raise NotImplementedError(
        f'Unsupported text encoder: {text_encoder_name}')

  from .autonotebook import tqdm as notebook_tqdm
    PyTorch 2.5.1+cu121 with CUDA 1201 (you have 2.1.0+cu121)
    Python  3.10.16 (you have 3.10.16)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [4]:
def embed_sample(sample):
    """Embed the question, text entities and relations of one sample.

    Reads 'id', 'question', 'text_entity_list' and 'relation_list' from
    ``sample`` and encodes the last three with the module-level
    ``text_encoder``.

    Returns:
        ``(sample id, {'q_emb': ..., 'entity_embs': ..., 'relation_embs': ...})``.
    """
    sample_id = sample['id']
    question_text = sample['question']
    entity_texts = sample['text_entity_list']
    relation_texts = sample['relation_list']

    question_emb, entity_embs, relation_embs = text_encoder(
        question_text, entity_texts, relation_texts)

    return sample_id, {
        'q_emb': question_emb,
        'entity_embs': entity_embs,
        'relation_embs': relation_embs,
    }

def embed_question(question):
    """Embed a bare question string with the module-level ``text_encoder``.

    Args:
        question: the question text to encode.

    Returns:
        ``(None, {'q_emb': ...})``. The first element is always ``None``
        because a bare question has no sample id; the two-element shape is
        kept so the result unpacks the same way as ``embed_sample``'s.
    """
    # NOTE(review): assumes `text_encoder` accepts a lone question string;
    # embed_sample calls it with three arguments -- confirm against the
    # encoder implementation.
    q_emb = text_encoder(question)
    return None, {'q_emb': q_emb}

RETRIEVAL MODULE

In [5]:
from src.model.retriever import Retriever
from src.setup import set_seed, prepare_sample

# Retrieval setup: restore the trained retriever from its checkpoint.
device = torch.device('cuda:0')

# The checkpoint is deserialised on the CPU; the model is moved to the GPU below.
cpt = torch.load(
    "/home/gridsan/mhadjiivanov/meng/SubgraphRAG/retrieve/webqsp_Jan03-19:20:34/cpt.pth",
    map_location='cpu',
)
config = cpt['config']
set_seed(config['env']['seed'])
torch.set_num_threads(config['env']['num_threads'])

# Embedding width, hard-coded here instead of being read from
# infer_set[0]['q_emb'].shape[-1].
emb_size = 1024
model = Retriever(emb_size, **config['retriever']).to(device)
model.load_state_dict(cpt['model_state_dict'])
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

Retriever(
  (non_text_entity_emb): Embedding(1, 1024)
  (dde): DDE(
    (layers): ModuleList(
      (0-1): 2 x PEConv()
    )
    (reverse_layers): ModuleList(
      (0-1): 2 x PEConv()
    )
  )
  (pred): Sequential(
    (0): Linear(in_features=4116, out_features=1024, bias=True)
    (1): ReLU()
    (2): Linear(in_features=1024, out_features=1, bias=True)
  )
)

In [6]:
from src.dataset.retriever import collate_retriever
from src.setup import set_seed, prepare_sample

# Module-level number of triples to keep. get_top_k takes its own `max_K`
# parameter, which shadows this name inside the function.
max_K = 50

def get_top_k(raw_sample, max_K):
    """Score the triples of one sample and return the ``max_K`` best.

    The sample is collated and moved to the module-level ``device``, then
    scored by the module-level ``model``.

    Args:
        raw_sample: processed sample dict; read for 'text_entity_list',
            'non_text_entity_list', 'relation_list', 'target_triple_probs',
            'question', 'q_entity', 'q_entity_id_list', 'a_entity',
            'a_entity_id_list' and 'max_path_length'.
        max_K: upper bound on the number of triples returned.

    Returns:
        dict with the question, the top triples as
        ``(head, relation, tail, score)`` tuples under 'scored_triplets',
        the question/answer entities (raw and as found in the graph), the
        maximum path length and the target-relevant triples as
        ``(head, relation, tail)`` tuples. Both triple lists are empty when
        the sample has no triples.
    """
    sample = collate_retriever([raw_sample])
    (h_id_tensor, r_id_tensor, t_id_tensor, q_emb, entity_embs,
     num_non_text_entities, relation_embs, topic_entity_one_hot,
     _target_triple_probs, _a_entity_id_list) = prepare_sample(device, sample)

    entity_list = raw_sample['text_entity_list'] + raw_sample['non_text_entity_list']
    relation_list = raw_sample['relation_list']

    def triple_as_text(triple_id):
        # Translate the ids of one triple into (head, relation, tail) strings.
        return (
            entity_list[h_id_tensor[triple_id].item()],
            relation_list[r_id_tensor[triple_id].item()],
            entity_list[t_id_tensor[triple_id].item()],
        )

    top_K_triples = []
    target_relevant_triples = []

    if len(h_id_tensor) != 0:
        # Pure inference: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            pred_triple_logits = model(
                h_id_tensor, r_id_tensor, t_id_tensor, q_emb, entity_embs,
                num_non_text_entities, relation_embs, topic_entity_one_hot)
            pred_triple_scores = torch.sigmoid(pred_triple_logits).reshape(-1)
            top_K_results = torch.topk(
                pred_triple_scores, min(max_K, len(pred_triple_scores)))

        top_K_scores = top_K_results.values.cpu().tolist()
        top_K_triple_IDs = top_K_results.indices.cpu().tolist()
        top_K_triples = [
            triple_as_text(triple_id) + (score,)
            for triple_id, score in zip(top_K_triple_IDs, top_K_scores)
        ]

        target_relevant_triple_ids = (
            raw_sample['target_triple_probs'].nonzero().reshape(-1).tolist())
        target_relevant_triples = [
            triple_as_text(triple_id)
            for triple_id in target_relevant_triple_ids
        ]

    return {
        'question': raw_sample['question'],
        'scored_triplets': top_K_triples,
        'q_entity': raw_sample['q_entity'],
        'q_entity_in_graph': [entity_list[e_id] for e_id in raw_sample['q_entity_id_list']],
        'a_entity': raw_sample['a_entity'],
        'a_entity_in_graph': [entity_list[e_id] for e_id in raw_sample['a_entity_id_list']],
        'max_path_length': raw_sample['max_path_length'],
        'target_relevant_triples': target_relevant_triples
    }

COMBINED EMBEDDING + RETRIEVAL

In [7]:
from src.dataset.emb import customEmbInferDataset
from src.dataset.retriever import customRetrieverDataset

def raw_to_pre_pred(sample, dataset_name,k):
    """Run one raw sample through processing, embedding and retrieval.

    Loads the entity identifiers of ``dataset_name``, processes ``sample``
    with ``customEmbInferDataset``, embeds it, wraps it in a
    ``customRetrieverDataset`` and returns ``get_top_k`` of its only item
    with ``k`` as the number of triples to keep.
    """
    identifier_path = f"/home/gridsan/mhadjiivanov/meng/SubgraphRAG/retrieve/data_files/{dataset_name}/entity_identifiers.txt"
    # One identifier per line, surrounding whitespace removed.
    with open(identifier_path, 'r') as f:
        entity_identifiers = {line.strip() for line in f}

    processed = customEmbInferDataset(sample, entity_identifiers).processed_dict

    # Only the embeddings are needed here; the sample id is discarded.
    _, emb_dict = embed_sample(processed)
    infer_set = customRetrieverDataset(processed, emb_dict)

    return get_top_k(infer_set[0], k)

LLM MODULE

In [8]:
from preprocess.prepare_prompts import get_prompts_for_data

def get_defined_prompts(prompt_mode, model_name, llm_mode):
    """Select the (system prompt, chain-of-thought prompt) pair for a run.

    Selection order:
      1. 'gpt' in ``model_name`` or ``prompt_mode``: the GPT prompts when
         ``prompt_mode`` contains 'gptLabel', otherwise the ICL prompts.
      2. 'noevi' in ``prompt_mode``: the no-evidence prompts.
      3. 'icl' in ``llm_mode``: the ICL prompts.
      4. Otherwise the default prompts.

    The prompt objects are imported lazily from the ``prompts`` module, so
    only the selected pair has to exist.
    """
    wants_gpt = 'gpt' in model_name or 'gpt' in prompt_mode

    if wants_gpt and 'gptLabel' in prompt_mode:
        from prompts import sys_prompt_gpt, cot_prompt_gpt
        return sys_prompt_gpt, cot_prompt_gpt

    if wants_gpt:
        from prompts import icl_sys_prompt, icl_cot_prompt
        return icl_sys_prompt, icl_cot_prompt

    if 'noevi' in prompt_mode:
        from prompts import noevi_sys_prompt, noevi_cot_prompt
        return noevi_sys_prompt, noevi_cot_prompt

    if 'icl' in llm_mode:
        from prompts import icl_sys_prompt, icl_cot_prompt
        return icl_sys_prompt, icl_cot_prompt

    from prompts import sys_prompt, cot_prompt
    return sys_prompt, cot_prompt


# sys_prompt, cot_prompt = get_defined_prompts(prompt_mode, model_name, llm_mode)
# data = get_prompts_for_data([sample_dict],prompt_mode,sys_prompt, cot_prompt,thres = 0)

In [9]:
from prompts import icl_user_prompt, icl_ass_prompt

def get_outputs(outputs, model_name):
    """Return the content of the last message of the first generation.

    ``outputs`` is indexed as ``outputs[0]['generated_text'][-1]['content']``;
    ``model_name`` is accepted but not used.
    """
    first_generation = outputs[0]
    last_message = first_generation['generated_text'][-1]
    return last_message['content']

def llm_inf(llm, prompts, mode, model_name):
    """Query the LLM according to the flags contained in ``mode``.

    Args:
        llm: callable invoked as ``llm(text_inputs=conversation)``; its
            result is decoded with ``get_outputs``.
        prompts: mapping providing 'sys_query' and 'user_query', plus
            'cot_query' for the follow-up turn.
        mode: string of flags:
            'sys'     -- send the system prompt and the user query;
            'icl'     -- insert the in-context example turns before the query;
            'sys_cot' -- follow up with the chain-of-thought query
                         ('clear' drops the earlier turns first);
            'dc'      -- re-ask with the chain-of-thought query when the
                         first answer has no usable 'ans:' line.
        model_name: forwarded to ``get_outputs``.

    Returns:
        Two-element list ``[first answer, follow-up answer or ""]`` when
        'sys' is in ``mode``. Without 'sys' no query is sent and the result
        is ``[""]``.
    """
    res = []
    # Start from an empty conversation so that modes without 'sys' do not
    # hit an unbound name.
    conversation = []
    if 'sys' in mode:
        conversation.append({"role": "system", "content": prompts['sys_query']})

    if 'icl' in mode:
        conversation.append({"role": "user", "content": icl_user_prompt})
        conversation.append({"role": "assistant", "content": icl_ass_prompt})

    if 'sys' in mode:
        conversation.append({"role": "user", "content": prompts['user_query']})

        outputs = get_outputs(llm(text_inputs=conversation), model_name)
        res.append(outputs)

    if 'sys_cot' in mode:
        if 'clear' in mode:
            conversation = []
        conversation.append({"role": "assistant", "content": outputs})
        conversation.append({"role": "user", "content": prompts['cot_query']})

        outputs = get_outputs(llm(text_inputs=conversation), model_name)
        res.append(outputs)
    elif "dc" in mode:
        # Only re-ask when a first answer exists and carries no usable answer.
        if res:
            first_answer = res[0].lower()
            if ('ans:' not in first_answer
                    or "ans: not available" in first_answer
                    or "ans: no information available" in first_answer):
                conversation.append({"role": "user", "content": prompts['cot_query']})
                outputs = get_outputs(llm(text_inputs=conversation), model_name)
                res[0] = outputs
        res.append("")
    else:
        res.append("")

    return res


#llm_inf(llm, data[0], llm_mode, model_name)

In [10]:
from transformers import pipeline
import torch

# LLM setup: run configuration and the Hugging Face text-generation pipeline.
model_name = "/home/gridsan/mhadjiivanov/meng/SubgraphRAG/hf/models/Llama-3.2-3B-Instruct"
llm_mode = "sys_icl_dc"
prompt_mode = "scored_100"

device = torch.device("cuda")
llm = pipeline(
    "text-generation",
    model=model_name,
    device=device,
    max_length=2700,
)

Loading checkpoint shards: 100%|██████████| 2/2 [02:03<00:00, 61.86s/it]


DEMO

In [None]:
from datasets import load_from_disk

# Load the saved webqsp dataset and run a single test sample through the
# embedding + retrieval pipeline, keeping the 50 best-scoring triples.
dataset = load_from_disk("/home/gridsan/mhadjiivanov/meng/SubgraphRAG/retrieve/data_files/webqsp/webqsp")
sample = dataset['test'][135]
x = raw_to_pre_pred(sample,'webqsp',50)

# Show the question and its answer entities.
print(x['question'])
print(x['a_entity'])

# skipped samples: 0
# relevant triples | median: 6 | mean: 6 | max: 6
what are the three official languages of belgium
['German Language', 'French', 'Dutch Language']


In [None]:
# Run configuration (same values as in the LLM setup cell).
llm_mode = "sys_icl_dc"
model_name = "/home/gridsan/mhadjiivanov/meng/SubgraphRAG/hf/models/Llama-3.2-3B-Instruct"
prompt_mode = "scored_100"

# Pick the prompt templates for this configuration and build the prompts for
# the retrieved sample `x`.
sys_prompt, cot_prompt = get_defined_prompts(prompt_mode, model_name, llm_mode)
data = get_prompts_for_data([x],prompt_mode,sys_prompt, cot_prompt,thres = 0)

In [None]:
# Show the system prompt built for the first (only) record.
print(data[0]['sys_query'])
# Keys noted for each record:
# 'a_entity', 'scored_triplets', 'sys_query', 'user_query', 'all_query', 'cot_query'

Based on the triplets retrieved from a knowledge graph, please answer the question. Please return formatted answers as a list, each prefixed with "ans:".


In [15]:
# Query the LLM with the prompts of the first record; as the last expression
# of the cell, the returned list is displayed.
llm_inf(llm, data[0], llm_mode, model_name)

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


['To find the three official languages of Belgium, we need to find the official languages of Belgium.\n\nFrom the triplets, we can see that the official languages of Belgium are:\n\n(Belgium,location.country.official_language,French)\n(Belgium,location.country.official_language,German Language)\n(Belgium,location.country.official_language,Dutch Language)\n\nTherefore, the formatted answers are:\n\nans: French\nans: German Language\nans: Dutch Language',
 '']

Subobjective Prompting Experiment

In [74]:
#subobjective prompting
def subobjective_prompt(question):
    """Build the prompt asking the LLM to split ``question`` into subobjectives.

    A trailing '?' is appended when the question does not already end with
    one. An empty question is accepted and yields a prompt ending in 'Q: ?'.

    Args:
        question: the question text.

    Returns:
        The instruction text, including one worked example, followed by the
        question.
    """
    # endswith() is safe on an empty string, unlike indexing question[-1].
    if not question.endswith('?'):
        question += '?'

    instructions = """Please break down the process of answering the question into as few subobjectives as possible based on semantic analysis.
    Here is an example: 
    Q: Which of the countries in the Caribbean has the smallest country calling code?
    Output: ["What countries are in the Caribbean", "What is the country calling code for each Caribbean country", "What is the smallest country calling code of the ones found"]

    Now you need to directly output subobjectives of the following question in list format without other information or notes. Match the format of the example. Ensure output is a python list of strings.
    Q: """

    return instructions + question

# messages = [
#     {"role": "user", 
#     "content": prompt_and_question}
# ]

# llm(text_inputs = messages)

In [20]:
# Requires `sample`, `llm`, `get_outputs`, `split_subobjectives` and
# `raw_to_pre_pred` to have been defined by previously executed cells.
# `sample` is a single dataset row, so its fields are read directly.
question = sample['question']
prompt_and_question = subobjective_prompt(question)

messages = [
    {"role": "user", "content": prompt_and_question},
]

subobjectives = split_subobjectives(get_outputs(llm(text_inputs=messages), ""))

# Retrieve for the first subobjective by substituting it for the question.
sample['question'] = subobjectives[0]

# raw_to_pre_pred requires the number of triples to keep as third argument.
raw_to_pre_pred(sample, 'webqsp', 50)


NameError: name 'sample' is not defined

In [81]:
def split_subobjectives(s):
    """Extract the list of subobjectives from an LLM response.

    The text between the first '[' and the next ']' is parsed as a Python
    list literal, so quoted items come back without their quotes and may
    contain commas. When that text is not a valid literal (for example
    unquoted items), it is split on commas instead and each piece is
    stripped of surrounding whitespace and quotes.

    Args:
        s: raw response text.

    Returns:
        List of subobjective strings; empty when the list is empty.
    """
    import ast

    start = s.find('[')
    end = s.find(']', start + 1) if start != -1 else -1
    if start != -1 and end != -1:
        inner = s[start + 1:end]
    elif start != -1:
        # Unterminated list: use everything after the opening bracket.
        inner = s[start + 1:]
    else:
        # No list markers at all: treat the whole text as the list body.
        inner = s

    try:
        parsed = ast.literal_eval('[' + inner + ']')
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if parsed is not None:
        return [str(item) for item in parsed]

    pieces = (piece.strip().strip('"\'') for piece in inner.split(','))
    return [piece for piece in pieces if piece]
    

In [7]:
from single_sample import *
from datasets import load_from_disk

# Load the saved metaqa dataset and pick one test sample.
dataset = load_from_disk("/home/gridsan/mhadjiivanov/meng/SubgraphRAG/retrieve/data_files/metaqa/metaqa")
sample = dataset['test'][135]


# Build the text encoder and restore the metaqa retriever checkpoint.
# NOTE(review): init_text_encoder, init_retriever and this five-argument
# raw_to_pre_pred presumably come from the `single_sample` star import -- confirm.
text_encoder = init_text_encoder('metaqa')
retriever_model = init_retriever('/home/gridsan/mhadjiivanov/meng/SubgraphRAG/retrieve/metaqa_Nov23-02:01:19/cpt.pth')


# Retrieve for the sample, passing 50 as the last argument; as the last
# expression of the cell, the result is displayed.
raw_to_pre_pred(sample,text_encoder,retriever_model,'metaqa',50)


# skipped samples: 0
# relevant triples | median: 1 | mean: 1 | max: 1


{'question': 'when did the films starred by Veer-Zaara actors release',
 'scored_triplets': [('Veer-Zaara',
   'release_year',
   '2004',
   0.9847598671913147),
  ('Veer-Zaara', 'has_genre', 'Romance', 0.7718797922134399),
  ('Romance', 'release_year', '1999', 0.09057214111089706),
  ('Romance', 'in_language', 'French', 0.0031211150344461203),
  ('Veer-Zaara', 'written_by', 'Aditya Chopra', 0.0005846134154126048),
  ('Veer-Zaara', 'has_genre', 'Drama', 0.0003422275185585022),
  ('Veer-Zaara', 'directed_by', 'Yash Chopra', 0.00022799526050221175),
  ('Veer-Zaara', 'starred_actors', 'Shah Rukh Khan', 0.0002083664876408875),
  ('Veer-Zaara', 'starred_actors', 'Rani Mukerji', 3.497489160508849e-05),
  ('Veer-Zaara', 'starred_actors', 'Preity Zinta', 8.108817382890265e-06),
  ('Veer-Zaara', 'has_tags', 'romance', 4.055003046232741e-06),
  ('Anatomy of Hell', 'release_year', '2004', 3.101737320321263e-06),
  ('Veer-Zaara', 'has_tags', 'yash chopra', 1.8115691773346043e-06),
  ('Romance', 'h