# Generation of dataset splits for SemREC'22

In this notebook we generate training, validation, and test splits for the inferrable statements of the CaLiGraph subsets `clg_10e4` and `clg_10e5` as well as for the full dataset `clg_full`. For each dataset, the superclass, type, and relation assertions are each shuffled randomly and divided 70%/10%/20% into training, validation, and test sets, respectively; the three parts are then combined into one file per split.

In [None]:
import requests
import gzip
import random


# Seed the global `random` generator so that the shuffles performed below are repeatable.
random.seed(2022)

# Datasets to process, as (dataset_id, is_gzipped) pairs. The flag selects the
# `.nt.gz` download and gzip-compressed split files instead of plain `.nt` ones.
DATASETS = [
    ('clg_10e4', False),
    ('clg_10e5', False),
    ('clg_full', True),
]

# Fractions of every assertion list assigned to the training and validation sets.
SAMPLE_SIZE_TRAIN = .7
SAMPLE_SIZE_VAL   = .1
# TEST SET SIZE   = .2 (implicit: the test set receives all remaining lines)


def split_dataset(dataset_id: str, is_gzipped: bool):
    """Fetch the three assertion files of one dataset and write its split files.

    The superclass, transitive-type, and relation assertions are downloaded in
    that order and handed to `_create_split_files`.

    :param dataset_id: identifier of the dataset, e.g. 'clg_10e4'
    :param is_gzipped: whether the dataset's files are gzip-compressed
    """
    assertion_types = (
        'assertions-superclasses',
        'assertions-transitive-types',
        'assertions-relations',
    )
    # Unpacking the generator performs the downloads one after another, in tuple order.
    superclasses, types, relations = (
        _load_lines_of_file(dataset_id, assertion_type, is_gzipped)
        for assertion_type in assertion_types
    )
    _create_split_files(dataset_id, superclasses, types, relations, is_gzipped)

    
def _load_lines_of_file(dataset_id: str, assertion_type: str, is_gzipped: bool, timeout: float = 60.0) -> list:
    """Download one `.nt` assertion file and return its non-blank lines.

    :param dataset_id: identifier of the dataset, e.g. 'clg_10e4'
    :param assertion_type: file suffix such as 'assertions-superclasses'
    :param is_gzipped: if True, fetch the `.nt.gz` variant and decompress it
    :param timeout: seconds to wait for the server before giving up
    :return: list of lines without trailing newline characters
    :raises requests.HTTPError: if the server answers with an error status
    """
    url = f'http://data.dws.informatik.uni-mannheim.de/CaLiGraph/CaLiGraph-for-SemREC/{dataset_id}/{dataset_id}-{assertion_type}.nt'
    response = requests.get(f'{url}.gz' if is_gzipped else url, timeout=timeout)
    # Without this check the body of an error page would be split and sampled
    # as if it were assertion data.
    response.raise_for_status()
    content = gzip.decompress(response.content) if is_gzipped else response.content
    # Split on '\n' only (str.splitlines would also break on Unicode separators
    # that could occur inside a literal) and drop blank entries, such as the empty
    # string that split() yields after a trailing newline.
    # NOTE: dropping blank entries changes the list length, so the shuffle result
    # for a given seed can differ from runs made before this fix.
    return [line for line in content.decode('utf-8').split('\n') if line.strip()]


def _create_split_files(dataset_id: str, superclass_assertions: list, type_assertions: list, relation_assertions: list, is_gzipped: bool):
    """Split each assertion list and write the combined train/val/test files.

    Every assertion list is split on its own (superclass, then type, then
    relation), the sizes of its three parts are printed, and the parts are
    appended to the matching split in that same order before writing.

    :param dataset_id: identifier used as prefix of the output filenames
    :param superclass_assertions: lines of the superclass assertions
    :param type_assertions: lines of the type assertions
    :param relation_assertions: lines of the relation assertions
    :param is_gzipped: whether the output files are gzip-compressed
    """
    file_types = ('train', 'val', 'test')
    combined = tuple([] for _ in file_types)

    labelled_assertions = (
        ('Superclass', superclass_assertions),
        ('Type', type_assertions),
        ('Relation', relation_assertions),
    )
    for label, assertions in labelled_assertions:
        parts = _split_into_train_val_test(assertions)
        print(label, *(len(part) for part in parts))
        for bucket, part in zip(combined, parts):
            bucket.extend(part)

    for file_type, bucket in zip(file_types, combined):
        _write_file(bucket, dataset_id, file_type, is_gzipped)

    
def _split_into_train_val_test(lines: list) -> tuple:
    random.shuffle(lines)
    train_end_idx = int(len(lines) * SAMPLE_SIZE_TRAIN)
    val_end_idx = train_end_idx + int(len(lines) * SAMPLE_SIZE_VAL)
    return lines[:train_end_idx], lines[train_end_idx:val_end_idx], lines[val_end_idx:]


def _write_file(lines: list, dataset_id: str, file_type: str, is_gzipped):
    filename = f'{dataset_id}-{file_type}.nt'
    if is_gzipped:
        filename += '.gz'
        open_func = gzip.open
        mode = 'wt'
    else:
        open_func = open
        mode = 'w'
    with open_func(filename, mode=mode) as f_out:
        f_out.writelines(lines)


# Run the pipeline (download, split, write) once for every configured dataset.
for dataset_id, is_gzipped in DATASETS:
    print(f'Generating split files for {dataset_id}..')
    split_dataset(dataset_id, is_gzipped)

Generating split files for clg_10e4..
Superclass 59957 8565 17132
Type 51578 7368 14738
Relation 16269 2324 4649
Generating split files for clg_10e5..
Superclass 96274 13753 27508
Type 29974 4282 8565
Relation 138894 19842 39684
Generating split files for clg_full..
Superclass 8518816 1216973 2433949
Type 97099450 13871350 27742700
Relation 7281297 1040185 2080371
