## Prepare training, validation and test datasets

We need to prepare the training, validation, and external test datasets. For every KGE model, the training dataset is used to train the model and the test dataset is used to evaluate it.

### [Modify According to Your Situation] Prepare all relation files

In [None]:
import os

# Configuration cell: defines where the input relation files live, which of
# them are merged, and where the final dataset is written. Edit the values
# below to match your setup.

# Parent of the directory the notebook is run from.
root_dir = os.path.dirname(os.getcwd())

# Name of the folder created under <root_dir>/datasets for the final files.
# dataset_name = "biomedgps-full-v20240127"
dataset_name = "rapex-v20240627"
# When True, "--strict-mode" is passed to annotate_relations.py in the merge step.
skip_rows_not_in_entity_file = True

outputdir = os.path.join(root_dir, "datasets", dataset_name)
graph_data_dir = os.path.join(root_dir, "graph_data")

# Relation files under graph_data/formatted_relations/<source>/.
formatted_ctd = os.path.join(
    graph_data_dir, "formatted_relations", "ctd", "formatted_ctd.tsv"
)
# formatted_unformatted_drkg = os.path.join(
#     graph_data_dir, "formatted_relations", "drkg", "unformatted_drkg.tsv"
# )
formatted_drkg = os.path.join(
    graph_data_dir, "formatted_relations", "drkg", "formatted_drkg.tsv"
)
formatted_hsdn = os.path.join(
    graph_data_dir, "formatted_relations", "hsdn", "formatted_hsdn.tsv"
)
formatted_primekg = os.path.join(
    graph_data_dir, "formatted_relations", "primekg", "formatted_primekg.tsv"
)

# Relation files under graph_data/relations/customdb/.
formatted_malacards_mecfs = os.path.join(
    graph_data_dir, "relations", "customdb", "formatted_malacards_mecfs.tsv"
)
formatted_custom = os.path.join(
    graph_data_dir, "relations", "customdb", "formatted_customdb_v20240329.tsv"
)
formatted_treatme_compound = os.path.join(
    graph_data_dir, "relations", "customdb", "formatted_treatme_survey_compounds.tsv"
)
formatted_treatme_symptom = os.path.join(
    graph_data_dir, "relations", "customdb", "formatted_treatme_survey_symptoms.tsv"
)

# Files that will be merged; comment an entry out to exclude it.
files = [
    formatted_ctd,
    # formatted_unformatted_drkg,
    formatted_drkg,
    formatted_hsdn,
    formatted_primekg,
    formatted_malacards_mecfs,
    formatted_custom,
    formatted_treatme_compound,
    formatted_treatme_symptom,
]

print("Merging the following files:")
print("\n".join(files))

entity_file = os.path.join(graph_data_dir, "entities.tsv")
# Count lines while streaming the file, and close the handle deterministically
# instead of leaving it to the garbage collector.
# NOTE(review): every line is counted, including a header row if the file has one.
with open(entity_file) as entity_handle:
    num_entity_lines = sum(1 for _ in entity_handle)
print("Number of entities: {}".format(num_entity_lines))

### Dependencies

In [None]:
import os
import sys

# Locate the "lib" directory next to the current working directory and put it
# on the module search path.
lib_dir = os.path.join(os.path.split(os.getcwd())[0], "lib")

print(f"Adding {lib_dir} to sys.path")
sys.path.append(lib_dir)

### Merge all relation files into one file

In [None]:
import os
import subprocess
import pandas as pd
import tempfile

temp_dir = tempfile.mkdtemp()

args = ["python3", os.path.join(lib_dir, "data.py"), "merge-files"]

for f in files:
    args.extend(["--input", f])

kg_file = os.path.join(temp_dir, "knowledge_graph.tsv")
annotated_kg_file = os.path.join(temp_dir, "annotated_knowledge_graph.tsv")
args.extend(["--output", kg_file])

print("Running: {}".format(" ".join(args)))
args_str = " ".join(args)
!{args_str}

if os.path.exists(kg_file):
    df = pd.read_csv(kg_file, sep="\t")
    source_ids = df[["source_id", "source_type"]].drop_duplicates()
    target_ids = df[["target_id", "target_type"]].drop_duplicates()
    ids = pd.concat([source_ids, target_ids]).drop_duplicates()
    print("Number of unique entity ids: {}".format(len(ids)))
    print("Number of relations: {}".format(len(df.drop_duplicates())))

# Annotate the knowledge graph with the entities
args = [
    "python3",
    os.path.join(os.path.dirname(lib_dir), "graph_data", "scripts", "annotate_relations.py"),
    "--entity-file",
    entity_file,
    "--relation-file",
    kg_file,
    "--output-dir",
    os.path.dirname(kg_file),
    "--strict-mode" if skip_rows_not_in_entity_file else "",
]

print("Running: {}".format(" ".join(args)))
args_str = " ".join(args)
!{args_str}
print("File written to: {}".format(annotated_kg_file))

### Split the merged relation file into training, validation and test files

In [None]:
train_validation_file = os.path.join(temp_dir, "train_validation.tsv")
train_file = os.path.join(temp_dir, "train.tsv")
test_file = os.path.join(temp_dir, "test.tsv")
valid_file = os.path.join(temp_dir, "valid.tsv")

split_cmd = [
    "python3",
    os.path.join(lib_dir, "data.py"),
    "split",
    "--input",
    kg_file,
    "--output-1",
    train_validation_file,
    "--output-2",
    test_file,
    "--ratio",
    "0.95",
]

print("Running: {}".format(" ".join(split_cmd)))
split_cmd_str = " ".join(split_cmd)
!{split_cmd_str}
print(f"Split files created: {train_validation_file} and {test_file}.")

split_cmd = [
    "python3",
    os.path.join(lib_dir, "data.py"),
    "split",
    "--input",
    train_validation_file,
    "--output-1",
    train_file,
    "--output-2",
    valid_file,
    "--ratio",
    "0.95",
]

print("Running: {}".format(" ".join(split_cmd)))
split_cmd_str = " ".join(split_cmd)
!{split_cmd_str}
print("Split files created: {} and {}.".format(train_file, valid_file))

### Check whether the number of ids in the train, validation, and test files is the same.
If you see the following message, run the next section of this notebook, which keeps the entity ids and relation types in the validation and test files the same as those in the train file.

```
ValueError: You need to keep the entity ids and relation types in the test and validation files the same as the ones in the train file.
```

In [None]:
id_checked_file = os.path.join(temp_dir, "id_checked.tsv")
check_ids_cmd = [
    "python3",
    os.path.join(lib_dir, "data.py"),
    "check-ids",
    "--input",
    train_file,
    "--input",
    valid_file,
    "--input",
    test_file,
    "--output",
    id_checked_file,
]

print("Running: {}".format(" ".join(check_ids_cmd)))
check_ids_cmd_str = " ".join(check_ids_cmd)
# Catch the error
check_ids_cmd_str += " || true"
!{check_ids_cmd_str}
print("Checked files created: {}.".format(id_checked_file))

### Keep the entity ids and relation types the same across the validation, test and training datasets

In [None]:
keep_valid_dir = os.path.join(temp_dir, "keep_valid")
os.makedirs(keep_valid_dir, exist_ok=True)

keep_valid_cmd = [
    "python3",
    os.path.join(lib_dir, "data.py"),
    "keep-valid",
    "--train-file",
    train_file,
    "--valid-file",
    valid_file,
    "--test-file",
    test_file,
    "--output-dir",
    keep_valid_dir,
]

print("Running: {}".format(" ".join(keep_valid_cmd)))
keep_valid_cmd_str = " ".join(keep_valid_cmd)
!{keep_valid_cmd_str}
print("Files created: {}.".format(os.listdir(keep_valid_dir)))

### [Again] Check whether the number of ids in the train, validation, and test files is the same.

In [None]:
id_checked_file = os.path.join(keep_valid_dir, "id_checked.tsv")
train_valid_file = os.path.join(keep_valid_dir, "train_valid.tsv")
valid_valid_file = os.path.join(keep_valid_dir, "valid_valid.tsv")
test_valid_file = os.path.join(keep_valid_dir, "test_valid.tsv")

check_ids_cmd = [
    "python3",
    os.path.join(lib_dir, "data.py"),
    "check-ids",
    "--input",
    train_valid_file,
    "--input",
    valid_valid_file,
    "--input",
    test_valid_file,
    "--output",
    id_checked_file,
]

print("Running: {}".format(" ".join(check_ids_cmd)))
check_ids_cmd_str = " ".join(check_ids_cmd)
!{check_ids_cmd_str}
print("Checked files created: {}.".format(id_checked_file))

In [None]:
hrt_dir = os.path.join(temp_dir, "hrt")
os.makedirs(hrt_dir, exist_ok=True)

files = [
    (os.path.join(keep_valid_dir, "train_valid.tsv"), os.path.join(hrt_dir, "train.tsv")),
    (os.path.join(keep_valid_dir, "valid_valid.tsv"), os.path.join(hrt_dir, "valid.tsv")),
    (os.path.join(keep_valid_dir, "test_valid.tsv"), os.path.join(hrt_dir, "test.tsv")),
]

for input_file, output_file in files:
    hrt_cmd = [
        "python3",
        os.path.join(lib_dir, "data.py"),
        "hrt",
        "--input",
        input_file,
        "--output",
        output_file,
    ]

    print("Running: {}".format(" ".join(hrt_cmd)))
    hrt_cmd_str = " ".join(hrt_cmd)
    !{hrt_cmd_str}
    print("HRT files created: {}.".format(output_file))

### Copy all files to the dataset folder

In [None]:
# Collect the final artifacts of this run in the dataset folder.
os.makedirs(outputdir, exist_ok=True)

# Each entry pairs a source path with the file name it gets inside outputdir;
# the comprehension expands the name into the full destination path.
files = [
    (source, os.path.join(outputdir, name))
    for source, name in (
        (os.path.join(hrt_dir, "train.tsv"), "train.tsv"),
        (os.path.join(hrt_dir, "valid.tsv"), "valid.tsv"),
        (os.path.join(hrt_dir, "test.tsv"), "test.tsv"),
        (os.path.join(keep_valid_dir, "id_checked.tsv"), "id_checked.tsv"),
        (entity_file, "annotated_entities.tsv"),
        (kg_file, "knowledge_graph.tsv"),
        (annotated_kg_file, "annotated_knowledge_graph.tsv"),
    )
]

# Copy with the system `cp`; check_output raises if a copy fails.
for f, output_file in files:
    copy_cmd = ["cp", f, output_file]
    subprocess.check_output(copy_cmd)

print(f"Please found all files in {outputdir}")