In [1]:
import sys
import os
import argparse
import pandas as pd
import pickle
import yaml
import random
import numpy as np
import torch
import logging
from torch import cat
from kg_processing import *

  from tqdm.autonotebook import tqdm


In [2]:
# Location of the pipeline parameters file. Setting the KG_PARAMS_YAML
# environment variable overrides it, so the notebook can run on another machine
# without editing this cell; when the variable is unset or empty, the original
# location is used.
config = parse_yaml(
    os.environ.get("KG_PARAMS_YAML")
    or '/home/galadriel/dr_benchmark/processed_kgs/minimal_kg/mixed_test/unpermuted/params.yaml'
)

In [None]:
# Load the knowledge graph triples from the tab-separated input file, keep only
# the three triple columns and rename them to from / to / rel.
input_file = config["common"]["input_csv"]
kg_df = (
    pd.read_csv(input_file, sep="\t")[["my_x_id", "my_y_id", "relation"]]
    .rename(columns={"my_x_id": "from", "my_y_id": "to", "relation": "rel"})
)

# Optionally restrict the graph to the relations listed in the config.
if config["clean_kg"]["smaller_kg"]:
    logging.info(f"Keeping only relations {config['clean_kg']['keep_relations']}")
    kg_df = kg_df.loc[kg_df["rel"].isin(config["clean_kg"]["keep_relations"])]

# Build the graph object from the (possibly filtered) triples.
kg = my_knowledge_graph.KnowledgeGraph(df=kg_df)


In [4]:
# `kg` is created by the knowledge-graph loading cell; that cell must have been
# run in this kernel before this one, otherwise `kg` is undefined here.

# Seed value comes from the shared config (set_random_seeds is provided outside
# this notebook — presumably it seeds the RNGs used downstream; confirm).
set_random_seeds(config["common"]["seed"])

# Reverse lookup (relation index -> relation name), used when logging relation ids.
id_to_rel_name = {rel_ix: rel_name for rel_name, rel_ix in kg.rel2ix.items()}

if config["clean_kg"]["remove_duplicates_triplets"]:
    logging.info("Removing duplicated triplets...")
    kg = my_data_redundancy.remove_duplicates_triplets(kg)

# Relation ids known to duplicate another relation; handed to the train-set
# cleaning step further down as `known_reverses`.
duplicated_relations_list = []

if config["clean_kg"]["check_synonymous_antisynonymous"]:
    logging.info("Checking for synonymous and antisynonymous relations...")
    # Thresholds are forwarded unchanged to the detector; their meaning is
    # defined by my_data_redundancy.duplicates.
    theta1 = config["clean_kg"]["check_synonymous_antisynonymous_params"]["theta1"]
    theta2 = config["clean_kg"]["check_synonymous_antisynonymous_params"]["theta2"]
    duplicates_relations, rev_duplicates_relations = my_data_redundancy.duplicates(
        kg, theta1=theta1, theta2=theta2
    )
    if duplicates_relations:
        logging.info(f'Adding {len(duplicates_relations)} synonymous relations ({[id_to_rel_name[rel_id] for rel_id in duplicates_relations]}) to the list of known duplicated relations.')
        duplicated_relations_list += duplicates_relations
    if rev_duplicates_relations:
        logging.info(f'Adding {len(rev_duplicates_relations)} anti-synonymous relations ({[id_to_rel_name[rel_id] for rel_id in rev_duplicates_relations]}) to the list of known duplicated relations.')
        duplicated_relations_list += rev_duplicates_relations

# Optional step: apply my_data_redundancy.permute_tails to each relation named
# in the config, one relation at a time, replacing `kg` with the result.
if config["clean_kg"]["permute_kg"]:
    to_permute_relation_names = config["clean_kg"]["permute_kg_params"]
    # The summary line is only emitted when more than one relation is listed.
    if len(to_permute_relation_names) > 1:
        logging.info(f'Making permutations for relations {", ".join(to_permute_relation_names)}...')
    for rel in to_permute_relation_names:
        logging.info(f'Making permutations for relation {rel} with id {kg.rel2ix[rel]}.')
        # The relation id is looked up on the current graph at every iteration.
        kg = my_data_redundancy.permute_tails(kg, kg.rel2ix[rel])

# Optional step: add reverse triplets for the relations named in the config.
if config["clean_kg"]["make_directed"]:
    undirected_relations_names = config["clean_kg"]["make_directed_params"]
    relation_names = ", ".join(undirected_relations_names)
    logging.info(f'Adding reverse triplets for relations {relation_names}...')
    # Relation names are translated to ids before being handed to the helper,
    # which returns the updated graph together with a list of relation ids.
    kg, undirected_relations_list = my_data_redundancy.add_inverse_relations(
        kg, [kg.rel2ix[name] for name in undirected_relations_names]
    )

    # NOTE(review): the returned relations are only registered as known
    # duplicates when the synonym check is enabled, although the train-set
    # cleaning below receives this list regardless of that flag — confirm this
    # is intended.
    if config["clean_kg"]["check_synonymous_antisynonymous"]:
        logging.info(f'Adding created reverses {list(undirected_relations_names)} to the list of known duplicated relations.')
        duplicated_relations_list += undirected_relations_list

def _log_entity_coverage(train_kg, full_kg, stage):
    """Run verify_entity_coverage on the train split and log the outcome.

    Args:
        train_kg: Training split to check.
        full_kg: Graph the training split is checked against.
        stage: Short label added to the failure message so the two checks in
            this cell can be told apart in the log.

    Returns:
        The first element returned by verify_entity_coverage (truthy when the
        check passed).
    """
    coverage_ok, _ = verify_entity_coverage(train_kg, full_kg)
    if coverage_ok:
        logging.info("Entity coverage verified successfully.")
    else:
        # Logged at WARNING: with Python's default logging configuration an
        # INFO record is dropped, so a failed check would go unnoticed.
        logging.warning(f"Entity coverage verification failed ({stage})...")
    return coverage_ok


logging.info("Splitting the dataset into train, validation and test sets...")
kg_train, kg_val, kg_test = kg.split_kg(validation=True)

kg_train_ok = _log_entity_coverage(kg_train, kg, "after split")

if config['clean_kg']['clean_train_set']:
    # The train set is cleaned against the validation set first, then against
    # the test set, passing the known duplicated relations both times.
    logging.info("Cleaning the train set to avoid data leakage...")
    logging.info("Step 1: with respect to validation set.")
    kg_train = my_data_redundancy.clean_datasets(kg_train, kg_val, known_reverses=duplicated_relations_list)
    logging.info("Step 2: with respect to test set.")
    kg_train = my_data_redundancy.clean_datasets(kg_train, kg_test, known_reverses=duplicated_relations_list)

# Re-check coverage; this runs whether or not the cleaning step was enabled.
kg_train_ok = _log_entity_coverage(kg_train, kg, "after optional train-set cleaning")

if config['clean_kg']['rel_swap']:
    kg_train, kg_val, kg_test = specs_sets(kg_train, kg_val, kg_test, config)

# Final splits as returned by my_data_redundancy.ensure_entity_coverage.
new_train, new_val, new_test = my_data_redundancy.ensure_entity_coverage(kg_train, kg_val, kg_test)

NameError: name 'kg' is not defined