Готовим наборы датасетов для обучения и валидации моделей извлечения триплетов

In [1]:
from langchain_community.llms import LlamaCpp
from langchain.prompts import PromptTemplate
from kor.extraction import create_extraction_chain
from kor.nodes import Object, Text
import os
import json
import pandas as pd
from sklearn.model_selection import train_test_split

In [2]:
# Параметры скрипта

# DOMAIN = 'movie'
# DOMAIN = 'computer'
DOMAIN = 'nature'


if DOMAIN == 'movie':
    ONTOLOGY_PATH = 'data/onotology_schema/1_movie_ontology.json'
    SCHEMA_PATH = 'data/onotology_schema/movie_schema.json'
    GROUND_TRUTH_PATH = 'data/movie_ground_truth.jsonl'

elif DOMAIN == 'computer':
    ONTOLOGY_PATH = 'data/onotology_schema/6_computer_ontology.json'
    SCHEMA_PATH = 'data/onotology_schema/computer_schema.json'
    GROUND_TRUTH_PATH = 'data/computer_ground_truth.jsonl'

elif DOMAIN == 'nature':
    ONTOLOGY_PATH = 'data/onotology_schema/9_nature_ontology.json'
    SCHEMA_PATH = 'data/onotology_schema/nature_schema.json'
    GROUND_TRUTH_PATH = 'data/nature_ground_truth.jsonl'




# Подготовка датасетов для обучения и валидации

## Загрузка онтологии

In [3]:
with open(ONTOLOGY_PATH, 'r') as f:
    ontology_filejson = json.load(f)

In [4]:
ont_str = '\n'.join(ontology_filejson['payload']['blob']['rawLines'])
ont = json.loads(ont_str)

if DOMAIN == 'movie':
    # Неиспользуемая сущность в онтологии
    ont['concepts'] = [x for x in ont['concepts'] if x['qid']
                       not in ['Q201658', 'Q4220917', 'Q104649845']]

if DOMAIN == 'nature':
    # Неиспользуемая сущность в онтологии
    ont['concepts'] = [x for x in ont['concepts'] if x['qid']
                       not in ['Q82673', 'Q15091377']]

ont

{'title': 'Nature Ontology',
 'id': 'ont_9_nature',
 'concepts': [{'qid': 'Q618123', 'label': 'geographical feature'},
  {'qid': 'Q131681', 'label': 'reservoir'},
  {'qid': 'Q52105', 'label': 'Habitat'},
  {'qid': 'Q12323', 'label': 'dam'},
  {'qid': 'Q166620', 'label': 'drainage basin'},
  {'qid': 'Q46831', 'label': 'mountain range'},
  {'qid': 'Q8502', 'label': 'mountain'},
  {'qid': 'Q16521', 'label': 'taxon'},
  {'qid': 'Q355304', 'label': 'watercourse'},
  {'qid': 'Q15324', 'label': 'body of water'},
  {'qid': 'Q5', 'label': 'human'},
  {'qid': 'Q1040689', 'label': 'synonym'}],
 'relations': [{'pid': 'P171',
   'label': 'parent taxon',
   'domain': 'Q16521',
   'range': 'Q16521'},
  {'pid': 'P4552',
   'label': 'mountain range',
   'domain': 'Q8502',
   'range': 'Q46831'},
  {'pid': 'P3137',
   'label': 'parent peak',
   'domain': 'Q8502',
   'range': 'Q8502'},
  {'pid': 'P1843',
   'label': 'taxon common name',
   'domain': 'Q7432',
   'range': ''},
  {'pid': 'P974',
   'label': 

In [5]:
concepts_dict = {x['qid']: x['label'] for x in ont['concepts']}
relations_with_text = [{**relation, 'domain_label': concepts_dict.get(
    relation['domain']), 'range_label': concepts_dict.get(relation['range'])} for relation in ont['relations']]

# ручная коррекция пропущенного значения
if DOMAIN == 'computer':
    relations_with_text = [dict(x, range_label='company') if x['label']
                           == 'developer' else x for x in relations_with_text]

# удаляем None в obj и sub
relations_with_text = [x for x in relations_with_text if x['range_label'] is not None and x['domain_label'] is not None]

concepts_list = [x['label'] for x in ont['concepts']]
concepts = ', '.join(concepts_list)

relations_list = []
for i in relations_with_text:
    rel = f"sub: {i['domain_label']}, rel: {i['label']}, obj: {i['range_label']}\n"
    relations_list.append(rel)
relations_schema = "".join(relations_list)
print(relations_schema)

sub: taxon, rel: parent taxon, obj: taxon
sub: mountain, rel: mountain range, obj: mountain range
sub: mountain, rel: parent peak, obj: mountain
sub: watercourse, rel: tributary, obj: watercourse
sub: watercourse, rel: origin of the watercourse, obj: body of water
sub: watercourse, rel: mouth of the watercourse, obj: watercourse
sub: taxon, rel: taxon synonym, obj: synonym
sub: dam, rel: reservoir created, obj: reservoir
sub: watercourse, rel: drainage basin, obj: drainage basin
sub: taxon, rel: Habitat, obj: body of water



In [6]:
with open(SCHEMA_PATH, 'r') as f:
    sch = json.load(f)

schema = Object(
    id=sch['id'],
    description=sch['description'],
    attributes=[Text(id=x['id'], description=x['description'], examples=[tuple(x['examples'])]) 
                for x in sch['attributes']],
    examples = [tuple(x) for x in sch['examples']],
    many=True
)

## Подготовка генератора промптов первой модели

Первая модель - извлекает изтекста набор сущностей

LLM на этом этапе нужна для создания цепочки, из которой будем брать prompt

In [7]:
# нам не важна сейчас модель, нужно только загрузить
model_path = "../models/llama-2-7b.Q4_0.gguf"

In [8]:
# gpu на mac не работал - проблема была в версиях библиотек 
# https://llama-cpp-python.readthedocs.io/en/latest/install/macos/
# CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install --upgrade --force-reinstall llama-cpp-python==0.2.50 --no-cache-dir

llm = LlamaCpp(
    model_path=model_path,
    temperature=0,
    max_tokens=256,
    n_gpu_layers=40,
    n_batch=16,
    n_ctx=2048,
)

llama_model_loader: loaded meta data with 19 key-value pairs and 291 tensors from ../models/llama-2-7b.Q4_0.gguf (version GGUF V2)
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = LLaMA v2
llama_model_loader: - kv   2:                       llama.context_length u32              = 4096
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 11008
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32             

In [9]:
ner_chain = create_extraction_chain(
    llm, schema, encoder_or_encoder_class='json')


def prompt_for_model_1(input):
    return ner_chain.prompt.format_prompt(text=input).to_string()

## Подготовка генератора промптов второй модели

Вторая модель - определяет взаимоотношения извлеченных сущностей, на выходе имеет готовые триплеты

In [10]:
GET_RELATIONS_TEMPLATE = """You are a networked intelligence helping a human track knowledge triples.
You are given a list of NAMED ENTITIES and a LIST OF POSIBLE RELATIONS. You are to identify the relations between the named entities. 
Do not use any other named entities or relations. You MUST use only provided NAMED ENTITIES and LIST OF POSIBLE RELATIONS. 

LIST OF POSIBLE RELATIONS: {relations}

NAMED ENTITIES: {entities}

To indetidy the relations between named entities, you can use the following context: {context}

Format your response as a JSON. Do not add any attributes that do not appear in the schema shown below:

    {{
    "triplets": [
        {{
        "sub": "string",
        "rel": "string",
        "obj": "string"
        }}
    ]
    }}

Do NOT add any clarifying information. Output MUST follow the schema above. Do NOT add any additional columns that do not appear in the schema.
Answer:"""

In [11]:
relation_prompt = PromptTemplate.from_template(GET_RELATIONS_TEMPLATE)


def prompt_for_model_2(context, entities):
    return relation_prompt.invoke({"relations": relations_schema, "entities": entities, "context": context}).to_string()

## Загрузка ground_truth

In [12]:
with open(GROUND_TRUTH_PATH, 'r') as f:
    ground_truth_data = f.readlines()

ground_truth_list = [json.loads(x) for x in ground_truth_data]
df = pd.DataFrame(ground_truth_list)

## Извлечение таргета для первой модели

Тут используем тот факт, что тип связи однозначно определяет тип сущностей, и это дает возможность создать список сущностей, которые должны быть извлечены для каждого примера

In [13]:
rel_dict = {}
for r in relations_with_text:
    rel_dict[r['label']] = {'sub': r['domain_label'], 'obj': r['range_label']}

rel_dict

{'parent taxon': {'sub': 'taxon', 'obj': 'taxon'},
 'mountain range': {'sub': 'mountain', 'obj': 'mountain range'},
 'parent peak': {'sub': 'mountain', 'obj': 'mountain'},
 'tributary': {'sub': 'watercourse', 'obj': 'watercourse'},
 'origin of the watercourse': {'sub': 'watercourse', 'obj': 'body of water'},
 'mouth of the watercourse': {'sub': 'watercourse', 'obj': 'watercourse'},
 'taxon synonym': {'sub': 'taxon', 'obj': 'synonym'},
 'reservoir created': {'sub': 'dam', 'obj': 'reservoir'},
 'drainage basin': {'sub': 'watercourse', 'obj': 'drainage basin'},
 'Habitat': {'sub': 'taxon', 'obj': 'body of water'}}

In [14]:
def triples_to_answer_1(triples):
    data = {}

    for t in triples:
        if t['rel'] not in rel_dict.keys():
            continue

        sub_type = rel_dict[t['rel']]['sub'].replace(' ', '_')
        data[sub_type] = t['sub']

        obj_type = rel_dict[t['rel']]['obj'].replace(' ', '_')
        data[obj_type] = t['obj']

    data_str = json.dumps(data)
    return f'<json>{{"{DOMAIN}": [{str(data_str)}]}}</json>'


df['reference_for_model_1'] = df.triples.apply(triples_to_answer_1)
df.head()

Unnamed: 0,id,sent,triples,reference_for_model_1
0,ont_9_nature_test_1,The sacred kingfisher (Todiramphus sanctus) is...,"[{'sub': 'Sacred kingfisher', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Todiramphus""}]}</..."
1,ont_9_nature_test_2,Conference pear is an autumn cultivar (cultiva...,"[{'sub': 'Conference pear', 'rel': 'parent tax...","<json>{""nature"": [{""taxon"": ""Pyrus communis""}]..."
2,ont_9_nature_test_3,Morphini is a tribe of nymphalid butterflies i...,"[{'sub': 'Morphini', 'rel': 'parent taxon', 'o...","<json>{""nature"": [{""taxon"": ""Morphinae""}]}</json>"
3,ont_9_nature_test_4,Sesamum is a genus of about 20 species in the ...,"[{'sub': 'Sesamum', 'rel': 'parent taxon', 'ob...","<json>{""nature"": [{""taxon"": ""Pedaliaceae""}]}</..."
4,ont_9_nature_test_5,"The European nightjar (Caprimulgus europaeus),...","[{'sub': 'European nightjar', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Caprimulgus""}]}</..."


## Извлечение таргета для второй модели

Таргетом второй модели являются таргетные типлеты

In [15]:
df['reference_for_model_2'] = df.triples.apply(lambda x: json.dumps(x))
df.head()

Unnamed: 0,id,sent,triples,reference_for_model_1,reference_for_model_2
0,ont_9_nature_test_1,The sacred kingfisher (Todiramphus sanctus) is...,"[{'sub': 'Sacred kingfisher', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Todiramphus""}]}</...","[{""sub"": ""Sacred kingfisher"", ""rel"": ""parent t..."
1,ont_9_nature_test_2,Conference pear is an autumn cultivar (cultiva...,"[{'sub': 'Conference pear', 'rel': 'parent tax...","<json>{""nature"": [{""taxon"": ""Pyrus communis""}]...","[{""sub"": ""Conference pear"", ""rel"": ""parent tax..."
2,ont_9_nature_test_3,Morphini is a tribe of nymphalid butterflies i...,"[{'sub': 'Morphini', 'rel': 'parent taxon', 'o...","<json>{""nature"": [{""taxon"": ""Morphinae""}]}</json>","[{""sub"": ""Morphini"", ""rel"": ""parent taxon"", ""o..."
3,ont_9_nature_test_4,Sesamum is a genus of about 20 species in the ...,"[{'sub': 'Sesamum', 'rel': 'parent taxon', 'ob...","<json>{""nature"": [{""taxon"": ""Pedaliaceae""}]}</...","[{""sub"": ""Sesamum"", ""rel"": ""parent taxon"", ""ob..."
4,ont_9_nature_test_5,"The European nightjar (Caprimulgus europaeus),...","[{'sub': 'European nightjar', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Caprimulgus""}]}</...","[{""sub"": ""European nightjar"", ""rel"": ""parent t..."


## Создание промпта для первой модели

In [21]:
df['prompt_for_model_1'] = df.sent.apply(lambda x: prompt_for_model_1(x))
df.head()

Unnamed: 0,id,sent,triples,reference_for_model_1,reference_for_model_2,prompt_for_model_1
0,ont_9_nature_test_1,The sacred kingfisher (Todiramphus sanctus) is...,"[{'sub': 'Sacred kingfisher', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Todiramphus""}]}</...","[{""sub"": ""Sacred kingfisher"", ""rel"": ""parent t...",Your goal is to extract structured information...
1,ont_9_nature_test_2,Conference pear is an autumn cultivar (cultiva...,"[{'sub': 'Conference pear', 'rel': 'parent tax...","<json>{""nature"": [{""taxon"": ""Pyrus communis""}]...","[{""sub"": ""Conference pear"", ""rel"": ""parent tax...",Your goal is to extract structured information...
2,ont_9_nature_test_3,Morphini is a tribe of nymphalid butterflies i...,"[{'sub': 'Morphini', 'rel': 'parent taxon', 'o...","<json>{""nature"": [{""taxon"": ""Morphinae""}]}</json>","[{""sub"": ""Morphini"", ""rel"": ""parent taxon"", ""o...",Your goal is to extract structured information...
3,ont_9_nature_test_4,Sesamum is a genus of about 20 species in the ...,"[{'sub': 'Sesamum', 'rel': 'parent taxon', 'ob...","<json>{""nature"": [{""taxon"": ""Pedaliaceae""}]}</...","[{""sub"": ""Sesamum"", ""rel"": ""parent taxon"", ""ob...",Your goal is to extract structured information...
4,ont_9_nature_test_5,"The European nightjar (Caprimulgus europaeus),...","[{'sub': 'European nightjar', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Caprimulgus""}]}</...","[{""sub"": ""European nightjar"", ""rel"": ""parent t...",Your goal is to extract structured information...


## Создание промпта для второй модели

In [22]:
def get_prompt_2(row):
    entities = json.loads(row['reference_for_model_1'].replace(
        '<json>', '').replace('</json>', ''))[DOMAIN]
    context = row['sent']
    return prompt_for_model_2(context, entities)


df['prompt_for_model_2'] = df.apply(lambda row: get_prompt_2(row), axis=1)
df.head()

Unnamed: 0,id,sent,triples,reference_for_model_1,reference_for_model_2,prompt_for_model_1,prompt_for_model_2
0,ont_9_nature_test_1,The sacred kingfisher (Todiramphus sanctus) is...,"[{'sub': 'Sacred kingfisher', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Todiramphus""}]}</...","[{""sub"": ""Sacred kingfisher"", ""rel"": ""parent t...",Your goal is to extract structured information...,You are a networked intelligence helping a hum...
1,ont_9_nature_test_2,Conference pear is an autumn cultivar (cultiva...,"[{'sub': 'Conference pear', 'rel': 'parent tax...","<json>{""nature"": [{""taxon"": ""Pyrus communis""}]...","[{""sub"": ""Conference pear"", ""rel"": ""parent tax...",Your goal is to extract structured information...,You are a networked intelligence helping a hum...
2,ont_9_nature_test_3,Morphini is a tribe of nymphalid butterflies i...,"[{'sub': 'Morphini', 'rel': 'parent taxon', 'o...","<json>{""nature"": [{""taxon"": ""Morphinae""}]}</json>","[{""sub"": ""Morphini"", ""rel"": ""parent taxon"", ""o...",Your goal is to extract structured information...,You are a networked intelligence helping a hum...
3,ont_9_nature_test_4,Sesamum is a genus of about 20 species in the ...,"[{'sub': 'Sesamum', 'rel': 'parent taxon', 'ob...","<json>{""nature"": [{""taxon"": ""Pedaliaceae""}]}</...","[{""sub"": ""Sesamum"", ""rel"": ""parent taxon"", ""ob...",Your goal is to extract structured information...,You are a networked intelligence helping a hum...
4,ont_9_nature_test_5,"The European nightjar (Caprimulgus europaeus),...","[{'sub': 'European nightjar', 'rel': 'parent t...","<json>{""nature"": [{""taxon"": ""Caprimulgus""}]}</...","[{""sub"": ""European nightjar"", ""rel"": ""parent t...",Your goal is to extract structured information...,You are a networked intelligence helping a hum...


## Разделение на train-test

Разделяеем в соотношении 0.66 и 0.33, чтобы было больше данных для качестенной проверки

In [23]:
Xy_train_1, Xy_test_1 = \
    train_test_split(
        df[['id', 'sent', 'triples', 'prompt_for_model_1', 'reference_for_model_1']], test_size=0.33, random_state=42)

In [24]:
Xy_train_2, Xy_test_2 = \
    train_test_split(
        df[['id', 'sent', 'triples', 'prompt_for_model_2', 'reference_for_model_2']], test_size=0.33, random_state=42)

## Сохранение датасетов для обучения

In [25]:
if not os.path.exists('artifacts'):
    os.makedirs('artifacts')

base_path = os.path.join('artifacts', DOMAIN)
if not os.path.exists(base_path):
    os.makedirs(base_path)

model_path_1 = os.path.join(base_path, 'model_1')
if not os.path.exists(model_path_1):
    os.makedirs(model_path_1)

model_path_2 = os.path.join(base_path, 'model_2')
if not os.path.exists(model_path_2):
    os.makedirs(model_path_2)

In [26]:
def save_df(df: pd.DataFrame, df_name: str):
    df = df.copy()
    df.columns = ['id', 'sent', 'ref_triples', 'prompt', 'reference']
    df.to_csv(df_name, index=True)


save_df(Xy_train_1, os.path.join(model_path_1, 'Xy_train.csv'))
save_df(Xy_test_1, os.path.join(model_path_1, 'Xy_test.csv'))

save_df(Xy_train_2, os.path.join(model_path_2, 'Xy_train.csv'))
save_df(Xy_test_2, os.path.join(model_path_2, 'Xy_test.csv'))

In [27]:
df = pd.read_csv('artifacts/movie/model_1/Xy_test.csv', index_col='Unnamed: 0')
df.head()

Unnamed: 0,id,sent,ref_triples,prompt,reference
695,ont_1_movie_test_696,The Dead Pool is the fifth and final film in t...,"[{'sub': 'The Dead Pool', 'rel': 'filming loca...",Your goal is to extract structured information...,"<json>{""computer"": [{""film"": ""The Dead Pool"", ..."
816,ont_1_movie_test_817,The short blends traditional and computer anim...,"[{'sub': 'Paperman', 'rel': 'genre', 'obj': 'C...",Your goal is to extract structured information...,"<json>{""computer"": [{""film"": ""Paperman"", ""genr..."
30,ont_1_movie_test_31,Cannery Rodent is a 1967 Tom and Jerry cartoon...,"[{'sub': 'Cannery Rodent', 'rel': 'director', ...",Your goal is to extract structured information...,"<json>{""computer"": [{""film"": ""Cannery Rodent"",..."
599,ont_1_movie_test_600,"The movie was directed by Robert Stevenson, pr...","[{'sub': 'Herbie Rides Again', 'rel': 'cast me...",Your goal is to extract structured information...,"<json>{""computer"": [{""film"": ""Herbie Rides Aga..."
96,ont_1_movie_test_97,The Disappearance of Haruhi Suzumiya is produc...,[{'sub': 'The Disappearance of Haruhi Suzumiya...,Your goal is to extract structured information...,"<json>{""computer"": [{""film"": ""The Disappearanc..."
