In [153]:
import os
import re
import csv
import nltk
import ollama
import numpy as np
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

[nltk_data] Downloading package punkt to /Users/jamino/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [154]:
# convert text to sentences

raw_text = """\
Alan Shepard was born on Nov 18, 1923 and selected by NASA in 1959. Alan Shepard was a member of the Apollo 14 crew. \
"""

sentences = sent_tokenize(raw_text)
sentences

['Alan Shepard was born on Nov 18, 1923 and selected by NASA in 1959.',
 'Alan Shepard was a member of the Apollo 14 crew.']

In [155]:
# use sentences[0] as main text for now

text = sentences[0]
text

'Alan Shepard was born on Nov 18, 1923 and selected by NASA in 1959.'

In [156]:
def get_detailed_instruct(task_description: str, query: str) -> str:
    return f'Instruct: {task_description}\nQuery: {query}'

def get_embedding(model: str, sentence: str, task=None):
    if task != None:
        sentence = get_detailed_instruct(task, sentence)
        
    embeddings = ollama.embeddings(
        model=model,
        prompt=sentence
    )
    return embeddings

def parse_raw_triplets(raw_triplets: str):
    matches = re.findall(r'\[(.+?)\]', raw_triplets)
    structured_triplets = [[triplet.strip() for triplet in group.split('|')] for group in matches]
    
    print('----------')
    print('RAW:')
    print(raw_triplets)
    print('----------')
    return structured_triplets

def parse_relation_definition(raw_definitions: str):
    descriptions = raw_definitions.split('\n')
    relation_definitions_dict = {}

    for description in descriptions:
        if ':' not in description:
            continue
        index_of_colon = description.index(':')
        relation = description[:index_of_colon].strip()

        relation_description = description[:index_of_colon + 1 :].strip()

        relation_definitions_dict[relation] = relation_description
    
    print('----------')
    print('RAW:')
    print(raw_definitions)
    print('----------')
    return raw_definitions

In [157]:
class TriplesExtractor:
    def __init__(self, model: str = None) -> None:
        assert model is not None 
        self.model = model

    def extract(
        self,
        input_text_str: str,
        prompt_template_str: str,
        few_shot_examples_str: str = None,
    ) -> list[list[str]]:
        if not few_shot_examples_str:
            filled_prompt = prompt_template_str.format_map(
                {
                    'input_text': input_text_str,    
                }
            )
        else:
            filled_prompt = prompt_template_str.format_map(
                {
                    'few_shot_examples': few_shot_examples_str,
                    'input_text': input_text_str,
                }
            )
        messages = [{'role': 'user', 'content': filled_prompt}]
        completion = ollama.chat(
            model=self.model,
            messages=messages,
        )['message']['content']
        extracted_triplets_list = parse_raw_triplets(completion)
        return extracted_triplets_list

In [158]:
extractor = TriplesExtractor(model='llama3.1')
triples = extractor.extract(
    input_text_str=text,
    prompt_template_str=open('prompt_templates/oie_zsp_template.txt').read(),
)
triples

----------
RAW:
[Alan Shepard | wasBornOn | November 18, 1923]
[Alan Shepard | wasSelectedBy | NASA]
[Alan Shepard | isEqualTo | Astronaut]
----------


[['Alan Shepard', 'wasBornOn', 'November 18, 1923'],
 ['Alan Shepard', 'wasSelectedBy', 'NASA'],
 ['Alan Shepard', 'isEqualTo', 'Astronaut']]

In [159]:
class SchemaDefiner:
    def __init__(self, model: str = None) -> None:
        assert model is not None
        self.model = model

    def define_schema(
            self,
            input_text_str: str,
            extracted_triplets_list: list[list[str]],
            prompt_template_str: str,
            few_shot_examples_str: str = None,
    ) -> list[list[str]]:
        
        relations_present = set()
        for t in extracted_triplets_list:
            relations_present.add(t[1])
        
        if not few_shot_examples_str:
            filled_prompt = prompt_template_str.format_map(
                {
                    'text': input_text_str,
                    'relations': relations_present,
                    'triples': extracted_triplets_list,
                }
            )
        else:
            filled_prompt = prompt_template_str.format_map(
                {
                    'few_shot_examples': few_shot_examples_str,
                    'text': input_text_str,
                    'relations': relations_present,
                    'triples': extracted_triplets_list,
                }
            )

        messages = [{'role': 'user', 'content': filled_prompt}]
        completion = ollama.chat(
            model=self.model,
            messages=messages,
        )['message']['content']
        relation_definition_dict = parse_relation_definition(completion)
        return relation_definition_dict

In [160]:
definer = SchemaDefiner(model='llama3.1')
definitions = definer.define_schema(
    input_text_str=text,
    extracted_triplets_list=triples,
    prompt_template_str=open('prompt_templates/sd_zsp_template.txt').read(),
)
print(definitions)

----------
RAW:
Here are the descriptions for each relation:

* `wasBornOn`: Indicates that an entity had their birth on a specific date.
* `wasSelectedBy`: Indicates that one entity selected another entity for some purpose or position.
* `isEqualTo`: Indicates that two entities have the same value or identity.
----------
Here are the descriptions for each relation:

* `wasBornOn`: Indicates that an entity had their birth on a specific date.
* `wasSelectedBy`: Indicates that one entity selected another entity for some purpose or position.
* `isEqualTo`: Indicates that two entities have the same value or identity.


In [173]:
class SchemaRetriever:
    def __init__(self, target_schema_dict: dict, embedding_model) -> None:
        self.target_schema_dict = target_schema_dict
        self.embedding_model = embedding_model

        self.target_schema_embedding_dict = {}

        for relation, relation_definition in target_schema_dict.items():
            embedding = get_embedding(
                model=self.embedding_model,
                sentence=relation_definition
            )
            self.target_schema_embedding_dict[relation] = embedding

    def update_schema_embedding_dict(self):
        for relation, relation_definition in self.target_schema_dict.items():
            if relation in self.target_schema_embedding_dict:
                continue
            embedding = get_embedding(
                model=self.embedding_model,
                sentence=relation_definition
            )
            self.target_schema_embedding_dict[relation] = embedding

    def retrieve_relevant_relations(self, query_input_text: str, top_k=10):
        target_relation_list = list(self.target_schema_embedding_dict.keys())
        target_relation_embedding_list = list(self.target_schema_embedding_dict.values())

        query_embedding = get_embedding(
            model=self.embedding_model,
            sentence=query_input_text,
            task='Retrieve descriptions of relations that are present in the given text.',
        )['embedding']

        # DELETE
        print('SHAPE1', np.array([query_embedding]).shape)
        print('SHAPE2', np.array(target_relation_embedding_list).T.shape)

        scores = np.array([query_embedding]) @ np.array(target_relation_embedding_list).T

        scores = scores[0]
        highest_scores_indices = np.argsort(-scores)

        return [target_relation_list[idx] for idx in highest_scores_indices[:top_k]]

In [174]:
target_schema_dict = {}
with open('schemas/webnlg_schema.csv') as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        target_schema_dict[row[0]] = row[1]

retriever = SchemaRetriever(
    target_schema_dict=target_schema_dict,
    embedding_model='llama3.1',
)

retriever.retrieve_relevant_relations(
    query_input_text='This indicates a specific date related to when an individual was born.',
    top_k=5
)

SHAPE1 (1, 4096)
SHAPE2 (159,)


ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 159 is different from 4096)