In [1]:
from rdflib import Graph
from cmatcher.module_search.pagerank import gen_pagerank_sparql_queries, page_rank
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, AutoModelForCausalLM
import torch
import torch.nn as nn
from rdflib.namespace import RDF, OWL
from cmatcher.rag.rag_reduce import ont_query_reduce, reduce_ont, get_detailed_instruct, gen_doc
from cmatcher.rag.prompt_gen import gen_prompt
from cmatcher.rag.prompt_to_edoal import match
from cmatcher.repair.edoal import merge_edoals
from tqdm.auto import tqdm
import gc
import dill
import subprocess
import os
import re
import random
import itertools
import torch.nn.functional as F


# define deterministic behavior
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [2]:
def batched(iterable, n):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(itertools.islice(iterator, n)):
        yield batch

In [3]:
def gen_pagerank_sparql_queries(g, num_iterations=10, damping_factor=0.8, threshold=0.4, max_entities=30):

    ents = set(g.subjects())
    ranks = page_rank(g, ents, num_iterations=num_iterations, damping_factor=damping_factor)

    values = list(sorted(ranks.items(), key=lambda x: x[1], reverse=True))

    _, bv = values[0]

    fv = list(filter(lambda x: x[1] / bv > threshold, values))

    queries = []

    for k, v in fv[:max_entities]:
        kv = g.value(k, RDF.type)
        if kv is None:
            continue
        if kv == OWL.Class:
            queries.append(f'SELECT DISTINCT ?x WHERE {{?x a <{k}>.}}')
        elif 'property' in kv.lower():
            queries.append(f'SELECT DISTINCT ?x ?y WHERE {{?x <{k}> ?y.}}')
        else:
            queries.append(f'SELECT DISTINCT ?x WHERE {{?x a <{k}>.}}')

    return queries

In [4]:
def embed(model, text, instruction='', batch_size=8, max_length=4096):
    with torch.no_grad():
        embeddings = []
        for t in tqdm(list(batched(text, batch_size))):
            embeddings.append(model.encode(t, instruction=instruction, max_length=max_length))
        return torch.cat(embeddings, dim=0)

def get_ont_docs(g, max_entities=15):
    ls = list(filter(lambda x: (x, RDF.first, None) not in g, set(g.subjects())))
    passages = []
    for s in ls:
        passages.append(gen_doc(s, g, max_entities=max_entities))
    
    return ls, passages

def gen_prompts(model_id, o1, o2, sample1, sample2, instruction='', top_n=2, i_max_depth=1, o_max_depth=2):
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    model = AutoModel.from_pretrained(
        model_id,
        quantization_config=quantization_config,
        trust_remote_code=True,
        device_map='auto',
    )
    model.eval()

    queries1 = gen_pagerank_sparql_queries(o1)
    
    ls1, o1_docs = get_ont_docs(o1)
    ls2, o2_docs = get_ont_docs(o2)
        
    embs1 = embed(model, o1_docs)
    embs2 = embed(model, o2_docs)
    
    queries_embs = embed(model, queries1, instruction=instruction)
    
    sim1 = queries_embs @ embs1.T
    sim2 = queries_embs @ embs2.T
    
    prompts = []
    
    for i in range(len(queries1)):
        
        module1 = reduce_ont(ls1, sim1[i, :].unsqueeze(0), o1, top_n=top_n, i_max_depth=i_max_depth, o_max_depth=o_max_depth)
        module2 = reduce_ont(ls2, sim2[i, :].unsqueeze(0), o2, top_n=top_n, i_max_depth=i_max_depth, o_max_depth=o_max_depth)
    
        prompts.append(gen_prompt(module1, module2, None, sample1, sample2))
    
    return prompts



In [5]:
base_path = '/projets/melodi/gsantoss/data/complex/conference/ont/'
o1 = Graph().parse(base_path + 'cmt.owl')
o2 = Graph().parse(base_path + 'conference.owl')

instruction = 'Given the following SPARQL query, retrieve relevant entities that are related to the query'

with open('cmatcher/prompt_examples/sample1.txt', 'r') as f:
    sample1 = f.read()

with open('cmatcher/prompt_examples/sample2.txt', 'r') as f:
    sample2 = f.read()



In [6]:
model_id = 'nvidia/NV-Embed-v2'
prompts = gen_prompts(model_id, o1, o2, sample1, sample2, instruction=instruction)

with open('/projets/melodi/gsantoss/tmp/prompts.pkl', 'wb') as f:
    dill.dump(prompts, f)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/13 [00:00<?, ?it/s]

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),

KeyboardInterrupt



In [7]:

gc.collect()
torch.cuda.empty_cache()

In [9]:
with open('/projets/melodi/gsantoss/tmp/prompts.pkl', 'rb') as f:
    prompts = dill.load(f)

In [None]:
def match(txt, tokenizer, model):
    messages = [{"role": "user", "content": txt}, ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=2 * 1024,
            eos_token_id=terminators,
            do_sample=False,
            temperature=0,
            top_p=0,

        )
    response = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(response, skip_special_tokens=True)

In [None]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

llm_tokenizer = AutoTokenizer.from_pretrained(model_id)

llm_quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    quantization_config=llm_quantization_config,

)
llm_model.eval()

In [11]:
outputs = []
for prompt in tqdm(prompts):
    outputs.append(match(prompt, llm_tokenizer, llm_model))
    
with open('/projets/melodi/gsantoss/tmp/outputs.pkl', 'wb') as f:
    dill.dump(outputs, f)

  0%|          | 0/9 [00:00<?, ?it/s]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention ma

In [10]:
with open('/projets/melodi/gsantoss/tmp/outputs.pkl', 'rb') as f:
    outputs = dill.load(f)

In [11]:

final_edoal = merge_edoals(outputs)

os.makedirs('/projets/melodi/gsantoss/tmp/cct1', exist_ok=True)
with open('/projets/melodi/gsantoss/tmp/cct1/final_edoal.edoal', 'w') as f:
    f.write(final_edoal)

In [12]:
print(final_edoal[:700])

<?xml version="1.0" encoding="UTF-8"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" xmlns="http://knowledgeweb.semanticweb.org/heterogeneity/alignment#" xmlns:alext="http://exmo.inrialpes.fr/align/ext/1.0/" xmlns:align="http://knowledgeweb.semanticweb.org/heterogeneity/alignment#" xmlns:edoal="http://ns.inria.org/edoal/1.0/#" xmlns:xsd="http://www.w3.org/2001/XMLSchema#">
  <Alignment>
    <xml>yes</xml>
    <level>2EDOAL</level>
    <type>**</type>
    <onto1>
      <Ontology rdf:about="http://cmt#name"><location>http://cmt#name</location><formalism><Formalism align:name="owl" align:uri="http://www.w3.org/TR/owl-guide/"/></formalism></Ontology>
    </onto1>
    <onto2>
 


In [13]:
#  -jar  $no1 $no2 $o1 $o2 $mo $cqa $out1
base_java = '/projets/melodi/gsantoss/canarde/jdk-21.0.1/bin/java'
base_eval = '/projets/melodi/gsantoss/canarde/evaluator.jar'
base_onts = '/projets/melodi/gsantoss/data/complex/conference_100/ont/'
base_cqas = '/projets/melodi/gsantoss/data/complex/conference_100/CQAs/'
base_al = '/projets/melodi/gsantoss/tmp/cct1'
base_out = '/projets/melodi/gsantoss/tmp/ccres'

os.makedirs(base_out, exist_ok=True)
with subprocess.Popen([base_java, '-jar', base_eval, 'cmt', 'conference', base_onts + 'cmt.owl', base_onts + 'conference.owl', base_al, base_cqas, base_out]) as proc:
    proc.communicate()


Evaluator


In [14]:
!cat /projets/melodi/gsantoss/tmp/ccres/cmt_conference.csv

final_edoal.edoal,CQAs,0.13793103448275862,0.20689655172413793,0.1724137931034483,0.20689655172413793,0.18733078232359773
classical,recall-oriented,precision-oriented,overlap,query f-measure
MEAN,CQAs,0.137931,0.206897,0.172414,0.206897,0.187331
