In [1]:
import json


def _read_jsonl(paths):
    """Load every JSON document from one or more JSON Lines files, in order.

    Args:
        paths: Iterable of file paths. Each non-blank line of each file must
            contain exactly one JSON document.

    Returns:
        list: One parsed object per non-blank line, files concatenated in the
        order given.
    """
    records = []
    for path in paths:
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads raise on an empty string.
            records.extend(json.loads(line) for line in f if line.strip())
    return records


def _count_relation_types(items):
    """Count how often each lower-cased relation label occurs.

    Args:
        items: Iterable of dicts, each with a 'relations' list whose entries
            carry a 'relation_text' string.

    Returns:
        dict: Maps lower-cased relation text to its number of occurrences.
    """
    counts = {}
    for item in items:
        for relation in item['relations']:
            relation_text = relation['relation_text'].lower()
            counts[relation_text] = counts.get(relation_text, 0) + 1
    return counts


with open('zero_rel_type_counts.json', 'r', encoding='utf-8') as f:
    zero_rel_type_counts = json.load(f)

# Lower-cased so that the comparison with benchmark labels is case-insensitive.
zero_rel_types = {label.lower() for label in zero_rel_type_counts}


# WikiZSL
wiki_zsl_all = _read_jsonl(['wiki_zsl_all.jsonl'])
wiki_zsl_rel_type_counts = _count_relation_types(wiki_zsl_all)

wiki_zsl_intersection_labels = zero_rel_types.intersection(wiki_zsl_rel_type_counts.keys())
print(f"\nWikiZSL: {wiki_zsl_intersection_labels}")


# FewRel
fewrel_all = _read_jsonl(['few_rel_all.jsonl'])
fewrel_type_counts = _count_relation_types(fewrel_all)

fewrel_intersection_labels = zero_rel_types.intersection(fewrel_type_counts.keys())
print(f"\nFewRel: {fewrel_intersection_labels}")


# Re-DocRED (train, dev and test splits are pooled before counting)
redocred_all = _read_jsonl(['redocred_train.jsonl', 'redocred_dev.jsonl', 'redocred_test.jsonl'])
redocred_type_counts = _count_relation_types(redocred_all)

redocred_intersection_labels = zero_rel_types.intersection(redocred_type_counts.keys())
print(f"\nRe-DocRED: {redocred_intersection_labels}")



WikiZSL: {'chief executive officer', 'executive producer', 'sport', 'religion', 'publisher', 'native language', 'cause of death', 'opposite of', 'part of', 'place of burial', 'conflict', 'use', 'separated from', 'child', 'composer', 'editor', 'diocese', 'participant of', 'conferred by', 'occupant', 'replaces', 'date of birth', 'record label', 'league', 'named after', 'drafted by', 'employer', 'license', 'official language', 'educated at', 'military branch', 'currency', 'operating system', 'production company', 'genre', 'parent club', 'instrumentation', 'spouse', 'head of government', 'location', 'candidate', 'depicts', 'producer', 'filming location', 'voice actor', 'relative', 'mother', 'subsidiary', 'allegiance', 'position held', 'shares border with', 'appointed by', 'sibling', 'designed by', 'maintained by', 'approved by', 'published in', 'occupation', 'choreographer', 'used by', 'country', 'manufacturer', 'influenced by', 'crosses', 'highest point', 'movement', 'author', 'place of 

In [2]:
from tqdm import tqdm


def remove_overlapping_relations(original_file_path, final_file_path, intersection_labels):
    """Copy a JSONL dataset, dropping relations whose label is on a blocklist.

    Every line of ``original_file_path`` is parsed as a JSON object with a
    'relations' list. Relations whose 'relation_text' matches (ignoring case)
    one of ``intersection_labels``, or one of the placeholder labels
    'no_relation' / 'no relation', are removed. Items left with no relations
    are not written at all. A summary of what was skipped is printed.

    Args:
        original_file_path: Path of the JSONL file to read.
        final_file_path: Path of the JSONL file to write (overwritten).
        intersection_labels: Iterable of relation labels to remove. It is
            read only; the caller's collection is never modified.

    Returns:
        None.
    """
    # Work on a private, lower-cased copy: adding the placeholder labels to
    # the argument itself would silently change the caller's set, and the
    # comparison below has to be case-insensitive to catch labels that
    # differ from the blocklist only by capitalisation.
    blocked_labels = {label.lower() for label in intersection_labels}
    blocked_labels.update(('no_relation', 'no relation'))

    skipped_items = 0
    skipped_relations = 0
    with open(original_file_path, 'r', encoding='utf-8') as fr, \
            open(final_file_path, 'w', encoding='utf-8') as fw:
        for line in tqdm(fr):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            item = json.loads(line)
            relations = item['relations']
            new_relations = [
                relation for relation in relations
                if relation['relation_text'].lower() not in blocked_labels
            ]
            skipped_relations += len(relations) - len(new_relations)

            item['relations'] = new_relations

            # Only keep items that still carry at least one relation.
            if new_relations:
                fw.write(json.dumps(item) + '\n')
            else:
                skipped_items += 1

    print(f'To not mingle with benchmark datasets, we skipped {skipped_items} items and {skipped_relations} relations')


# Source corpus; it is filtered once per benchmark below.
zero_rel_path = 'zero_rel_all.jsonl'

# One output file per benchmark, index-aligned with the label sets that follow.
output_paths = [
    'zero_rel_all_diff_wiki_zsl.jsonl',
    'zero_rel_all_diff_few_rel.jsonl',
    'zero_rel_all_diff_redocred.jsonl',
]
intersection_labels_list = [
    wiki_zsl_intersection_labels,
    fewrel_intersection_labels,
    redocred_intersection_labels,
]


# Write one filtered copy of the corpus for each benchmark.
for output_path, intersection_labels in zip(output_paths, intersection_labels_list):
    remove_overlapping_relations(zero_rel_path, output_path, intersection_labels)

63493it [03:24, 311.10it/s]


To not mingle with benchmark datasets, we skipped 70 items and 8351206 relations


63493it [03:26, 307.35it/s]


To not mingle with benchmark datasets, we skipped 21 items and 7921100 relations


63493it [03:25, 308.77it/s]

To not mingle with benchmark datasets, we skipped 23 items and 8117430 relations





In [3]:
import json
# Import the class, not the module: `import tqdm` would rebind the global name
# `tqdm` to the module and break remove_overlapping_relations, which calls
# `tqdm(fr)`, on any later invocation.
from tqdm import tqdm

# Sanity check: after filtering, the ZeroRel copy must share no relation label
# with WikiZSL.

# Relation-type counts of the filtered file, keyed by the label as written.
zero_rel_type_counts = {}
with open('zero_rel_all_diff_wiki_zsl.jsonl', 'r', encoding='utf-8') as f:
    # Wrapping the file iterator lets tqdm close the bar once the file is
    # exhausted, so no stale progress line is left behind.
    for line in tqdm(f):
        if not line.strip():
            continue
        item = json.loads(line)
        for relation in item['relations']:
            relation_text = relation['relation_text']
            zero_rel_type_counts[relation_text] = zero_rel_type_counts.get(relation_text, 0) + 1

with open('wiki_zsl_all.jsonl', 'r', encoding='utf-8') as f:
    wiki_zsl_all = [json.loads(line) for line in f if line.strip()]

wiki_zsl_rel_type_counts = {}
for item in wiki_zsl_all:
    for relation in item['relations']:
        relation_text = relation['relation_text']
        wiki_zsl_rel_type_counts[relation_text] = wiki_zsl_rel_type_counts.get(relation_text, 0) + 1

# Compare lower-cased labels: the overlap being removed was computed on
# lower-cased text, so a case-sensitive check here could report an empty set
# while labels differing only by capitalisation still overlap.
zero_rel_labels_lower = {label.lower() for label in zero_rel_type_counts}
wiki_zsl_labels_lower = {label.lower() for label in wiki_zsl_rel_type_counts}
intersection_labels = zero_rel_labels_lower.intersection(wiki_zsl_labels_lower)
print(intersection_labels)

204it [00:00, 977.37it/s] 

63422it [01:27, 990.44it/s] 

set()


63423it [01:38, 990.44it/s]