In [123]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell runs (edits to .py files are picked up without
# restarting the kernel).
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [125]:
import os
import json

def filter_skip_relations(all_relations, allowed_relations):
    """Collect the relations that should be skipped.

    Args:
        all_relations: Iterable of every known relation.
        allowed_relations: Container of the relations to keep.

    Returns:
        A list holding each item of ``all_relations`` that is absent from
        ``allowed_relations``, in the iteration order of ``all_relations``.
    """
    skipped = []
    for relation in all_relations:
        if relation in allowed_relations:
            continue
        skipped.append(relation)
    return skipped


def transform_data_to_llm_format(data, start_idx=0, skip_relations=None):
    """Transform (conversation, triples) pairs into LLM-ready dialogue records.

    Args:
        data: Iterable of ``(conversation, triples)`` pairs. ``conversation`` is
            an iterable of utterance strings. Each triple is a mapping with the
            keys ``"x"``, ``"r"`` and ``"y"``; ``"r"`` is a sequence of relation
            labels of which only the first one is used.
        start_idx: Number used in the id of the first record that is emitted.
        skip_relations: Relation names to drop, compared against the part of
            the label after its last ``:``. ``None`` (the default) or an empty
            collection keeps every relation.

    Returns:
        A list of dicts, one per conversation that still has at least one
        triple after filtering. Each dict has an ``"id"`` of the form
        ``identity_<n>`` and a two-turn ``"conversations"`` list: the
        newline-joined dialogue as the ``"human"`` turn and the JSON-encoded
        list of kept triples as the ``"gpt"`` turn. Ids are consecutive,
        starting at ``start_idx``.
    """
    # A set gives O(1) membership tests; None replaces the mutable [] default.
    skipped = frozenset(skip_relations or ())

    records = []
    next_id = start_idx

    for conversation, triples in data:
        kept = []
        for triple in triples:
            labels = triple["r"]
            if not labels:
                # Triples without any relation label cannot be classified.
                continue
            # Labels look like "<prefix>:<relation>"; keep only the relation.
            relation = labels[0].split(":")[-1]
            if relation in skipped:
                continue
            kept.append({"x": triple["x"], "r": relation, "y": triple["y"]})

        if not kept:
            # Nothing left to predict for this dialogue, so it is dropped.
            continue

        records.append({
            "id": f"identity_{next_id}",
            "conversations": [
                {"from": "human", "value": "\n".join(conversation)},
                {"from": "gpt", "value": json.dumps(kept)},
            ],
        })
        # Advance only for emitted records so ids stay gap-free; a caller that
        # resumes numbering at len(previous_records) then never re-issues an id.
        next_id += 1

    return records


def create_directory_if_not_exists(directory_path):
    """Create ``directory_path`` (and any missing parents) if it is absent.

    Calling this for a directory that already exists is a no-op.

    Args:
        directory_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``directory_path`` exists but is not a directory.
    """
    # exist_ok=True avoids the race between a separate "exists" check and the
    # creation call when another process creates the directory in between.
    os.makedirs(directory_path, exist_ok=True)


def process_and_save_data(file_sets, input_dir, output_dir, skip_relations,
                          allowed_relations=None):
    """Convert groups of input JSON files and write one output file per group.

    For every group in ``file_sets`` the files ``<input_dir>/<name>.json`` are
    loaded, converted with ``transform_data_to_llm_format`` and merged. The
    merged records are written to
    ``<output_dir>/dialog-re-llama-<N>cls-<names joined by '-'>.json`` where
    ``N`` is ``len(allowed_relations)``. The group and its record count are
    printed afterwards.

    Args:
        file_sets: Iterable of groups; each group is a sequence of file names
            without the ``.json`` extension.
        input_dir: Directory containing the input files.
        output_dir: Directory for the output files; created when missing.
        skip_relations: Relation names to drop during conversion.
        allowed_relations: Collection of kept relations; only its length is
            used, for the output file name. When omitted, the module-level
            ``allowed_relations`` variable is used for backward compatibility.

    Raises:
        NameError: If ``allowed_relations`` is neither passed nor defined at
            module level.
    """
    if allowed_relations is None:
        # Earlier callers relied on a module-level ``allowed_relations`` being
        # defined before this function ran; keep honouring that.
        allowed_relations = globals().get("allowed_relations")
        if allowed_relations is None:
            raise NameError(
                "name 'allowed_relations' is not defined; pass it as the "
                "allowed_relations argument"
            )
    num_classes = len(allowed_relations)

    for files in file_sets:
        dataset_name = f"dialog-re-llama-{num_classes}cls-{'-'.join(files)}"
        new_format = []

        create_directory_if_not_exists(output_dir)

        for f in files:
            input_data_path = os.path.join(input_dir, f'{f}.json')

            if not os.path.exists(input_data_path):
                # Still best-effort, but no longer silent: the output file name
                # lists this file even though it contributed no records.
                print(f"Warning: input file not found, skipping: {input_data_path}")
                continue

            with open(input_data_path, encoding='utf8') as fp:
                data = json.load(fp)

            new_format.extend(
                transform_data_to_llm_format(
                    data,
                    start_idx=len(new_format),
                    skip_relations=skip_relations,
                )
            )

        # Re-number across the merged files so every id in the output is unique
        # and consecutive, however the per-file ids were assigned.
        for idx, entry in enumerate(new_format):
            entry["id"] = f"identity_{idx}"

        output_data_path = os.path.join(output_dir, f'{dataset_name}.json')
        with open(output_data_path, 'w', encoding='utf8') as fp:
            json.dump(new_format, fp)

        print(files, len(new_format))


In [126]:
# Driver cell: choose which relations to keep, derive the ones to skip, and
# convert the train+dev and test splits into the LLM dialogue format.
if __name__ == '__main__':
    # Full relation inventory; everything not listed in allowed_relations
    # below ends up in skip_relations.
    all_relations = {
        "positive_impression", "negative_impression", "acquaintance", 
        "alumni", "boss", "subordinate", "client", "dates", "friends", 
        "girl/boyfriend", "neighbor", "roommate", "children", "other_family", 
        "parents", "siblings", "spouse", "place_of_residence", "visited_place", 
        "origin", "employee_or_member_of", "schools_attended", "works", "age", 
        "date_of_birth", "major", "place_of_work", "title", "alternate_names", 
        "pet", "residents_of_place", "visitors_of_place", "employees_or_members", 
        "students", "unanswerable"
    }
    # Relations kept in the output. NOTE: process_and_save_data reads this
    # module-level name to put the class count into the output file name, so it
    # must be defined before that function is called.
    allowed_relations = {"acquaintance", "children", "other_family", "parents", 
                         "siblings", "spouse", "place_of_residence", "visited_place", 
                         "pet", "residents_of_place", "visitors_of_place"}

    # allowed_relations = all_relations  # uncomment to allow all relations!

    # Both inputs are sets, so the order of this list is not deterministic; it
    # is only used for membership tests.
    skip_relations = filter_skip_relations(all_relations, allowed_relations)

    # Input/output locations.
    # NOTE(review): these are absolute paths on one machine; consider deriving
    # them from a configurable base directory so the notebook runs elsewhere.
    INPUT_DIR = "/home/murilo/RelNetCare/data/raw/dialog-re"
    OUTPUT_DIR = "/home/murilo/RelNetCare/data/processed/dialog-re-llama"
    # Each inner list is merged into a single output file.
    FILE_SETS = [['train', 'dev'], ['test']]

    process_and_save_data(FILE_SETS, INPUT_DIR, OUTPUT_DIR, skip_relations)


['train', 'dev'] 1431
['test'] 357
