In [15]:
from rdflib import Graph
from cmatcher.module_search.pagerank import gen_pagerank_sparql_queries
from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig, AutoModelForCausalLM
import torch
from rdflib.namespace import RDF
from cmatcher.rag.rag_reduce import ont_query_reduce, reduce_ont, get_detailed_instruct, gen_doc
from cmatcher.rag.prompt_gen import gen_prompt
from cmatcher.rag.prompt_to_edoal import match
from tqdm.auto import tqdm
import gc
import dill
import subprocess
import os
import re
import random
import itertools
import torch.nn.functional as F
import difflib
# define deterministic behavior
random.seed(0)
torch.manual_seed(0)
torch.cuda.manual_seed(0)

In [3]:
def batched(iterable, n):
    # batched('ABCDEFG', 3) → ABC DEF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(itertools.islice(iterator, n)):
        yield batch

In [4]:
def rag(model, tokenizer, query, prompt, g, max_entities=15, max_length=4096, batch_size=2):
    query_prefix = f"Instruct: {prompt}\nQuery: "
    queries = [
        query,
    ]

    ls = list(filter(lambda x: (x, RDF.first, None) not in g, set(g.subjects())))
    ls.sort(key=lambda x: str(x))
    passages = []
    for s in ls:
        passages.append(gen_doc(s, g, max_entities=max_entities))
        
    with torch.no_grad():
        query_embeddings = model.encode(queries, instruction=query_prefix, max_length=max_length)
        
        passage_embeddings = []
        
        for p in batched(passages, batch_size):
            passage_embeddings.append(model.encode(p, instruction='', max_length=max_length))
        
        passage_embeddings = torch.cat(passage_embeddings, dim=0)

    return ls, query_embeddings @ passage_embeddings.T



def ont_query_reduce(model, tokenizer, g, query, prompt, top_n=2, i_max_depth=1, o_max_depth=2, max_entities=15,
                     max_length=4096, batch_size=2):
    ls1, scores1 = rag(model, tokenizer, query, prompt, g, max_entities=max_entities, max_length=max_length,
                       batch_size=batch_size)
    return reduce_ont(ls1, scores1, g, top_n=top_n, i_max_depth=i_max_depth, o_max_depth=o_max_depth)

In [5]:
def gen_prompts():
    tokenizer = AutoTokenizer.from_pretrained('nvidia/NV-Embed-v2')
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    model = AutoModel.from_pretrained(
        'nvidia/NV-Embed-v2',
        quantization_config=quantization_config,
        trust_remote_code=True,
        device_map='auto',
    )
    model.eval()
    
    base_path = '/projets/melodi/gsantoss/data/complex/conference/ont/'
    o1 = Graph().parse(base_path + 'cmt.owl')
    o2 = Graph().parse(base_path + 'conference.owl')
    
    queries1 = gen_pagerank_sparql_queries(o1)
    queries1.sort(key=lambda x: len(x))


    prompt = 'Given the following SPARQL query, retrieve relevant entities that are related to the query'
    with open('cmatcher/prompt_examples/sample1.txt', 'r') as f:
        sample1 = f.read()
        
    with open('cmatcher/prompt_examples/sample2.txt', 'r') as f:
        sample2 = f.read()
    
    
    prompts = []
    for query in tqdm(queries1):
    
        module1 = ont_query_reduce(model, tokenizer, o1, query, prompt, max_entities=10, batch_size=2)
        module2 = ont_query_reduce(model, tokenizer, o2, query, prompt, max_entities=10, batch_size=2)
    
        prompts.append(gen_prompt(module1, module2, None, sample1, sample2))
    
    return prompts

In [6]:
prompts = gen_prompts()

with open('/projets/melodi/gsantoss/tmp/prompts.pkl', 'wb') as f:
    dill.dump(prompts, f)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/9 [00:00<?, ?it/s]

  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),


In [17]:
for i in range(len(prompts)):
    if prompts[i] != old_prompts[i]:
        print(i)
        # compute the difference
        diff = list(difflib.ndiff(prompts[i].splitlines(), old_prompts[i].splitlines()))
        print('\n'.join([l for l in diff if l.startswith('+ ') or l.startswith('- ')]))

3
+     rdfs:subClassOf [ ],
-     rdfs:subClassOf [ owl:onProperty <http://cmt#hasDecision> ],
+         [ owl:onProperty <http://cmt#hasDecision> ],
-         [ ],
6
+             owl:onProperty <http://conference#belongs_to_reviewers> ;
+             owl:someValuesFrom <http://conference#Reviewer> ],
+         [ a owl:Restriction ;
-         [ a owl:Restriction ;
-             owl:onProperty <http://conference#belongs_to_reviewers> ;
-             owl:someValuesFrom <http://conference#Reviewer> ],
7
+             owl:minCardinality "1"^^xsd:int ;
+             owl:onProperty <http://cmt#readByReviewer> ],
+         [ a owl:Restriction ;
+             owl:onProperty <http://cmt#hasDecision> ],
+         [ a owl:Restriction ;
+             owl:minCardinality "0"^^xsd:int ;
-         [ a owl:Restriction ;
-             owl:minCardinality "0"^^xsd:int ;
-             owl:onProperty <http://cmt#hasDecision> ],
-         [ a owl:Restriction ;
-             owl:minCardinality "1"^^xsd:int 

In [6]:
queries1 = None
model = None
tokenizer = None
# prompts = None
prompt = None
sample1 = None
sample2 = None
module1 = None
module2 = None
quantization_config = None
o1 = None
o2 = None
uobj = gc.collect()
torch.cuda.empty_cache()

In [7]:
with open('/projets/melodi/gsantoss/tmp/prompts.pkl', 'rb') as f:
    prompts = dill.load(f)

In [8]:
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
llm_tokenizer.eos_token = llm_tokenizer.eos_token if llm_tokenizer.eos_token is not None else llm_tokenizer.pad_token
llm_quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

llm_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    quantization_config=llm_quantization_config,

)
llm_model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Ll

In [9]:
outputs = []
for prompt in tqdm(prompts):
    outputs.append(match(prompt, llm_tokenizer, llm_model))
    
with open('/projets/melodi/gsantoss/tmp/outputs.pkl', 'wb') as f:
    dill.dump(outputs, f)

  0%|          | 0/9 [00:00<?, ?it/s]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention ma

In [10]:
# with open('/projets/melodi/gsantoss/tmp/outputs.pkl', 'rb') as f:
#     outputs = dill.load(f)

In [11]:
def is_valid_edoal(txt):
    return txt.endswith('</rdf:RDF>')


def can_repair(txt):
    return txt.rfind('<map>') > 0

def merge_edoals(outputs):
    repaired_edoals = []
    for output in outputs:
        
        if not output.startswith('<?xml version'):
            output = '''<?xml version='1.0' encoding='utf-8' standalone='no'?>
<rdf:RDF xmlns='http://knowledgeweb.semanticweb.org/heterogeneity/alignment#'
         xmlns:rdf='http://www.w3.org/1999/02/22-rdf-syntax-ns#'
         xmlns:xsd='http://www.w3.org/2001/XMLSchema#'
         xmlns:alext='http://exmo.inrialpes.fr/align/ext/1.0/'
         xmlns:align='http://knowledgeweb.semanticweb.org/heterogeneity/alignment#'
         xmlns:edoal='http://ns.inria.org/edoal/1.0/#'>\n''' + output
        
        output = re.sub(r'<Ontology rdf:about="([^"]+)" />', r'<Ontology rdf:about="\1"><location>\1</location><formalism><Formalism align:name="owl" align:uri="http://www.w3.org/TR/owl-guide/"/></formalism></Ontology>', output)
        if not is_valid_edoal(output) and can_repair(output):
            last_map_index = output.rfind('<map>')
            repaired_edoals.append(output[:last_map_index] + '\n\t</Alignment>\n</rdf:RDF>')
        else:
            repaired_edoals.append(output)
    
    final_edoal = None
    if len(repaired_edoals) > 1:
        final_edoal = ''
        first = repaired_edoals[0]
        final_edoal += first[:first.find('<map>')]
        for e in repaired_edoals[1:]:
            final_edoal += e[e.find('<map>'):e.rfind('</map>')] + '\n\t</map>'
            
        final_edoal += '\n\t</Alignment>\n</rdf:RDF>'
        
    elif len(repaired_edoals) == 1:
        final_edoal = repaired_edoals[0]
    
    
    return final_edoal

final_edoal = merge_edoals(outputs)

os.makedirs('/projets/melodi/gsantoss/tmp/cct1', exist_ok=True)
with open('/projets/melodi/gsantoss/tmp/cct1/final_edoal.edoal', 'w') as f:
    f.write(final_edoal)

In [12]:
print(final_edoal[:700])

<?xml version="1.0" encoding="UTF-8"?>
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#" xmlns="http://knowledgeweb.semanticweb.org/heterogeneity/alignment#" xmlns:alext="http://exmo.inrialpes.fr/align/ext/1.0/" xmlns:align="http://knowledgeweb.semanticweb.org/heterogeneity/alignment#" xmlns:edoal="http://ns.inria.org/edoal/1.0/#" xmlns:xsd="http://www.w3.org/2001/XMLSchema#">
  <Alignment>
    <xml>yes</xml>
    <level>2EDOAL</level>
    <type>**</type>
    <onto1>
      <Ontology rdf:about="http://cmt#ConferenceMember"><location>http://cmt#ConferenceMember</location><formalism><Formalism align:name="owl" align:uri="http://www.w3.org/TR/owl-guide/"/></formalism></Ontology>
  


In [13]:
#  -jar  $no1 $no2 $o1 $o2 $mo $cqa $out1
base_java = '/projets/melodi/gsantoss/canarde/jdk-21.0.1/bin/java'
base_eval = '/projets/melodi/gsantoss/canarde/evaluator.jar'
base_onts = '/projets/melodi/gsantoss/data/complex/conference_100/ont/'
base_cqas = '/projets/melodi/gsantoss/data/complex/conference_100/CQAs/'
base_al = '/projets/melodi/gsantoss/tmp/cct1'
base_out = '/projets/melodi/gsantoss/tmp/ccres'

os.makedirs(base_out, exist_ok=True)
with subprocess.Popen([base_java, '-jar', base_eval, 'cmt', 'conference', base_onts + 'cmt.owl', base_onts + 'conference.owl', base_al, base_cqas, base_out]) as proc:
    proc.communicate()


Evaluator


In [14]:
!cat /projets/melodi/gsantoss/tmp/ccres/cmt_conference.csv

final_edoal.edoal,CQAs,0.1724137931034483,0.2413793103448276,0.20689655172413793,0.27586206896551724,0.23453107442035934
classical,recall-oriented,precision-oriented,overlap,query f-measure
MEAN,CQAs,0.172414,0.241379,0.206897,0.275862,0.234531


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
