In [1]:
from rdflib import Graph
from rdflib.term import URIRef
from cmatcher.module_search.pagerank import gen_pagerank_sparql_queries
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, AutoModelForCausalLM
import torch
from rdflib.namespace import RDF
from cmatcher.rag.rag_reduce import ont_query_reduce, reduce_ont, get_detailed_instruct, gen_doc
from cmatcher.rag.prompt_gen import gen_prompt
from cmatcher.rag.prompt_to_edoal import match
from tqdm.auto import tqdm
import gc
import dill
import subprocess
import os
import re
import random
import itertools
import torch.nn.functional as F
import difflib
# Define deterministic behavior: every RNG seeded in this notebook takes its
# value from the single SEED constant below, so changing the seed is a one-line edit.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [2]:
def batched(iterable, n):
    """Yield consecutive tuples of at most ``n`` items taken from ``iterable``.

    Example: batched('ABCDEFG', 3) → ABC DEF G

    Args:
        iterable: Any iterable; it is consumed lazily, ``n`` items at a time.
        n: Maximum number of items per yielded tuple; must be >= 1.

    Yields:
        Tuples of length ``n``; the final tuple may be shorter. Nothing is
        yielded for an empty iterable.

    Raises:
        ValueError: If ``n`` is smaller than one (raised when iteration starts,
            because this is a generator function).
    """
    if n < 1:
        raise ValueError('n must be at least one')
    source = iter(iterable)
    # Two-argument iter(): keep slicing off chunks until the empty tuple
    # sentinel signals that the source is exhausted.
    for chunk in iter(lambda: tuple(itertools.islice(source, n)), ()):
        yield chunk

In [17]:
def embed(model, text, instruction='', batch_size=8, max_length=4096):
    """Encode ``text`` with ``model`` in batches and return one stacked tensor.

    Args:
        model: Object exposing ``encode(batch, instruction=..., max_length=...)``
            whose result can be concatenated with ``torch.cat``.
        text: Iterable of inputs to encode (presumably strings — confirm
            against the model's ``encode`` contract).
        instruction: Passed through unchanged to ``model.encode``.
        batch_size: Number of inputs handed to ``model.encode`` per call.
        max_length: Passed through unchanged to ``model.encode``.

    Returns:
        The per-batch results concatenated along dimension 0, in input order.

    Note:
        An empty ``text`` produces no batches, so ``torch.cat`` is called with
        an empty list, which torch rejects.
    """
    with torch.no_grad():
        # Materialise the batches first so the progress bar knows the total.
        chunks = list(batched(text, batch_size))
        encoded = [
            model.encode(chunk, instruction=instruction, max_length=max_length)
            for chunk in tqdm(chunks)
        ]
        return torch.cat(encoded, dim=0)


def gen_docs(g, max_entities=10):
    """Build one passage per subject of graph ``g``, in sorted subject order.

    Subjects that have at least one ``RDF.first`` triple are skipped
    (presumably RDF list nodes — confirm).

    Args:
        g: Graph supporting ``subjects()`` and triple-pattern membership tests.
        max_entities: Forwarded to ``gen_doc`` for every subject.

    Returns:
        A pair ``(subjects, passages)``: the sorted list of retained subjects
        and the list of ``gen_doc`` results, aligned index by index.
    """
    subjects = sorted(
        subject
        for subject in set(g.subjects())  # set() drops repeated subjects
        if (subject, RDF.first, None) not in g
    )
    passages = [gen_doc(subject, g, max_entities=max_entities) for subject in subjects]
    return subjects, passages

def gen_prompts2(
    base_path='/projets/melodi/gsantoss/data/complex/conference/ont/',
    source_ontology='cmt.owl',
    target_ontology='conference.owl',
):
    """Generate one matching prompt per SPARQL query of the source ontology.

    Steps, in order:
      1. Load ``nvidia/NV-Embed-v2`` with 4-bit NF4 quantisation.
      2. Parse the source and target ontologies.
      3. Derive queries from the source ontology with
         ``gen_pagerank_sparql_queries`` and sort them.
      4. Embed one passage per subject of each ontology (see ``gen_docs``).
      5. For every query, embed it, score it against both sets of passage
         embeddings, reduce each ontology with ``reduce_ont`` and build a
         prompt with ``gen_prompt``.

    Args:
        base_path: Directory containing the ontology files. Defaults to the
            location previously hard-coded in this function.
        source_ontology: File name (relative to ``base_path``) of the ontology
            the queries are generated from.
        target_ontology: File name (relative to ``base_path``) of the ontology
            to match against.

    Returns:
        List with one ``gen_prompt`` result per query, in sorted query order.

    Note:
        The two few-shot sample files are read from paths relative to the
        current working directory.
    """
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    model = AutoModel.from_pretrained(
        'nvidia/NV-Embed-v2',
        quantization_config=quantization_config,
        trust_remote_code=True,
        device_map='auto',
    )
    model.eval()

    # os.path.join accepts base_path with or without a trailing separator.
    o1 = Graph().parse(os.path.join(base_path, source_ontology))
    o2 = Graph().parse(os.path.join(base_path, target_ontology))

    queries1 = gen_pagerank_sparql_queries(o1)
    queries1.sort()  # fixed order so prompt indices are comparable across runs

    prompt = 'Given the following SPARQL query, retrieve relevant entities that are related to the query'
    with open('cmatcher/prompt_examples/sample1.txt', 'r') as f:
        sample1 = f.read()

    with open('cmatcher/prompt_examples/sample2.txt', 'r') as f:
        sample2 = f.read()

    ls_1, o1_passages = gen_docs(o1)
    ls_2, o2_passages = gen_docs(o2)

    # Passages are embedded without an instruction prefix; only queries get one.
    o1_embeddings = embed(model, o1_passages, instruction='', max_length=4096, batch_size=2)
    o2_embeddings = embed(model, o2_passages, instruction='', max_length=4096, batch_size=2)

    # Loop-invariant: the instruction prefix is the same for every query.
    query_instruction = f'Instruct: {prompt}\nQuery: '

    prompts = []
    for query in tqdm(queries1):

        query_embeddings = embed(model, [query], instruction=query_instruction, max_length=4096, batch_size=2)

        # NOTE(review): scores are raw dot products; unless model.encode already
        # L2-normalises its output these are not cosine similarities — confirm
        # that this is intended.
        scores1 = o1_embeddings @ query_embeddings.T
        scores2 = o2_embeddings @ query_embeddings.T

        module1 = reduce_ont(ls_1, scores1, o1, top_n=2, i_max_depth=1, o_max_depth=2)
        module2 = reduce_ont(ls_2, scores2, o2, top_n=2, i_max_depth=1, o_max_depth=2)

        prompts.append(gen_prompt(module1, module2, None, sample1, sample2))

    return prompts

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

In [16]:
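# Run once (uncommented) to snapshot the current prompts before regenerating
# them, so that the comparison cell below can diff the old and new versions:
# old_prompts = prompts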


In [18]:
    
    
    
# Print, prompt by prompt, only the lines that differ between the previous
# snapshot (`old_prompts`) and the current `prompts`.
# NOTE(review): neither name is assigned in the visible cells (the only
# `old_prompts = prompts` is commented out), so this cell relies on existing
# kernel state — confirm before a Restart & Run All.
# zip_longest (rather than zip) keeps prompts that exist on only one side: the
# missing side is treated as an empty prompt, so its lines appear as all-added
# or all-removed instead of being silently skipped.
for i, (p1, p2) in enumerate(itertools.zip_longest(old_prompts, prompts, fillvalue='')):
    print(f'Prompt {i}')
    diff = difflib.ndiff(p1.splitlines(), p2.splitlines())
    for line in diff:
        # Keep only ndiff's removed ('- ') and added ('+ ') lines; unchanged
        # and hint lines are not printed.
        if line.startswith(('- ', '+ ')):
            print(line)

    print('=' * 100)

Prompt 0
Prompt 1
Prompt 2
Prompt 3
Prompt 4
Prompt 5
Prompt 6
Prompt 7
Prompt 8
