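# Create partitions for pretraining datasets and incremental datasets

Each split file is rewritten with a chosen set of type labels removed (the *pretraining* file). For every removed label, the examples that carried it are also written, with that label restored, to a per-label *incremental* file grouped under the label's parent type.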

## Imports and Setup

In [25]:
# imports: standard library
import contextlib
import json
import os
import shutil
import sys
from collections import defaultdict
from copy import deepcopy

# third-party
import numpy as np
import pandas as pd
from tqdm import tqdm

# project-local: `utils` lives in ../../../entity_typing_analysis/
sys.path.append('../../../entity_typing_analysis/')
import utils
# set main directories
DATA = 'few_NERD'
TRAIN_DATA = 'train.json'
DEV_DATA = 'dev.json'
# other datasets may name their test split differently (e.g. BBN uses 'test-12k.json')
TEST_DATA = 'test.json'
# root folder holding one sub-directory per dataset; override with $INCREMENTAL_DATA_ROOT
DATA_ROOT = os.environ.get('INCREMENTAL_DATA_ROOT',
                           '/home/remote_hdd/datasets_for_incremental_training')
# partitions are written next to the source data (source and destination are the same folder)
SRC_DATA_DIR = os.path.join(DATA_ROOT, DATA)
DST_DATA_DIR = os.path.join(DATA_ROOT, DATA)
ONTOLOGY_PATH = os.path.join(SRC_DATA_DIR, 'all_types.txt')

## Load ontology

In [26]:
# load ontology (project-local helper); presumably maps type name -> id — only its keys are used here
type2id = utils.load_ontology(ONTOLOGY_PATH)
types = list(type2id.keys())
# build the type-hierarchy tree from the same ontology file and print it
tree = utils.create_tree(ONTOLOGY_PATH)
tree.show()

thing
├── /art
│   ├── /art/broadcastprogram
│   ├── /art/film
│   ├── /art/music
│   ├── /art/other
│   ├── /art/painting
│   └── /art/writtenart
├── /building
│   ├── /building/airport
│   ├── /building/hospital
│   ├── /building/hotel
│   ├── /building/library
│   ├── /building/other
│   ├── /building/restaurant
│   ├── /building/sportsfacility
│   └── /building/theater
├── /event
│   ├── /event/attack_battle_war_militaryconflict
│   ├── /event/disaster
│   ├── /event/election
│   ├── /event/other
│   ├── /event/protest
│   └── /event/sportsevent
├── /location
│   ├── /location/GPE
│   ├── /location/bodiesofwater
│   ├── /location/island
│   ├── /location/mountain
│   ├── /location/other
│   ├── /location/park
│   └── /location/road_railway_highway_transit
├── /organization
│   ├── /organization/company
│   ├── /organization/education
│   ├── /organization/government_governmentagency
│   ├── /organization/media_newspaper
│   ├── /organization/other
│   ├── /organization/politicalpar

## Create partitions

In [27]:
def relabel(path_to_read, dirpath_to_write, tree, labels_to_remove=(), recreate_dirs=True, only_pretraining=False):
  """Split one JSON-lines split file into a pretraining file and per-label incremental files.

  Every non-blank line of ``path_to_read`` is parsed as a JSON object with a ``'y_str'``
  list of type labels. For each example:

  * the labels in ``labels_to_remove`` are dropped and the example is written to
    ``<dirpath_to_write>/pretraining_<name>.json``;
  * unless ``only_pretraining`` is True, for every dropped label a copy of the stripped
    example with just that label re-appended is written to
    ``<dirpath_to_write>/sons_of_<father>/incremental_<name>_<leaf>.json``, where
    ``<father>`` is the label's parent identifier in ``tree`` (leading '/' removed,
    remaining '/' replaced by '_') and ``<leaf>`` is the last '/'-component of the label.

  ``<name>`` is the file name of ``path_to_read`` without its extension.

  Args:
    path_to_read: JSON-lines file to split.
    dirpath_to_write: output directory (created if missing).
    tree: type hierarchy; only ``tree.parent(label).identifier`` is used.
    labels_to_remove: labels moved from the pretraining data to incremental data.
    recreate_dirs: if True, delete ``dirpath_to_write`` first so it starts empty.
    only_pretraining: if True, write only the pretraining file.

  Output files are opened in write mode, so calling this again for the same input
  overwrites its previous outputs instead of appending duplicate examples.
  """
  # prepare the main folder; it must exist before any file is opened inside it
  if recreate_dirs and os.path.exists(dirpath_to_write):
    shutil.rmtree(dirpath_to_write)
  os.makedirs(dirpath_to_write, exist_ok=True)

  labels_to_remove = set(labels_to_remove)

  # e.g. '.../train.json' -> 'train'
  postfix = os.path.splitext(os.path.basename(path_to_read))[0]
  pretraining_path = os.path.join(dirpath_to_write, f'pretraining_{postfix}.json')

  # one folder per father of the incremental types, one file path per incremental type
  incremental_paths = {}
  if not only_pretraining:
    for label in labels_to_remove:
      father = tree.parent(label).identifier.lstrip('/').replace('/', '_')
      incremental_dirpath = os.path.join(dirpath_to_write, f'sons_of_{father}')
      os.makedirs(incremental_dirpath, exist_ok=True)
      incremental_paths[label] = os.path.join(
          incremental_dirpath, f"incremental_{postfix}_{label.split('/')[-1]}.json")

  with contextlib.ExitStack() as stack:
    src = stack.enter_context(open(path_to_read, 'r'))
    pretraining_dst = stack.enter_context(open(pretraining_path, 'w'))
    # incremental files are opened lazily (only for labels that actually occur) and
    # kept open for the whole pass instead of being reopened for every example
    incremental_dsts = {}

    for line in tqdm(src.readlines()):
      if not line.strip():  # tolerate blank lines, e.g. a trailing newline
        continue
      example = json.loads(line)

      # dedupe while keeping the original label order, so the output is reproducible
      labels = list(dict.fromkeys(example['y_str']))
      labels_to_keep = [label for label in labels if label not in labels_to_remove]
      labels_removed = [label for label in labels if label in labels_to_remove]

      # pretraining example: removed labels stripped
      example['y_str'] = labels_to_keep
      pretraining_dst.write(json.dumps(example) + '\n')

      if only_pretraining:
        continue

      # incremental examples: the stripped example plus one removed label each
      for label in labels_removed:
        if label not in incremental_dsts:
          incremental_dsts[label] = stack.enter_context(open(incremental_paths[label], 'w'))
        example_incremental = {**example, 'y_str': labels_to_keep + [label]}
        incremental_dsts[label].write(json.dumps(example_incremental) + '\n')


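## Complete family

Every leaf type at the deepest level of the hierarchy is removed from the pretraining data and moved to incremental files.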

In [28]:
# output folder for the "complete family" partition
dst_data_dir = os.path.join(DST_DATA_DIR, "complete")

In [29]:
# Leaf types sitting at the deepest level of the hierarchy; these are the labels
# moved out of the pretraining data below.
max_depth = tree.depth()  # depth of the whole tree: computed once, not once per node
leaves_nth = [
    node.identifier
    for node in tree.filter_nodes(
        lambda node: not tree.children(node.identifier)
        and tree.depth(node.identifier) == max_depth
    )
]
print(len(leaves_nth), 'leaves:', leaves_nth)

66 leaves: ['/art/broadcastprogram', '/art/film', '/art/music', '/art/other', '/art/painting', '/art/writtenart', '/building/airport', '/building/hospital', '/building/hotel', '/building/library', '/building/other', '/building/restaurant', '/building/sportsfacility', '/building/theater', '/event/attack_battle_war_militaryconflict', '/event/disaster', '/event/election', '/event/other', '/event/protest', '/event/sportsevent', '/location/GPE', '/location/bodiesofwater', '/location/island', '/location/mountain', '/location/other', '/location/park', '/location/road_railway_highway_transit', '/organization/company', '/organization/education', '/organization/government_governmentagency', '/organization/media_newspaper', '/organization/other', '/organization/politicalparty', '/organization/religion', '/organization/showorganization', '/organization/sportsleague', '/organization/sportsteam', '/other/astronomything', '/other/award', '/other/biologything', '/other/chemicalthing', '/other/currency

In [30]:
# sanity check: path of the training split partitioned below
train_path = os.path.join(SRC_DATA_DIR, TRAIN_DATA)
train_path

'/home/remote_hdd/datasets_for_incremental_training/few_NERD/train.json'

In [31]:
# train split: start from an empty output folder, write the pretraining file and
# one incremental file per removed leaf type
relabel(
    path_to_read=os.path.join(SRC_DATA_DIR, TRAIN_DATA),
    dirpath_to_write=dst_data_dir,
    tree=tree,
    labels_to_remove=leaves_nth,
    recreate_dirs=True,
)



100%|██████████| 338407/338407 [16:14<00:00, 347.17it/s] 


In [32]:
# test split: keep the folder written by the train call and add only the
# pretraining file (no incremental files)
relabel(
    path_to_read=os.path.join(SRC_DATA_DIR, TEST_DATA),
    dirpath_to_write=dst_data_dir,
    tree=tree,
    labels_to_remove=leaves_nth,
    recreate_dirs=False,
    only_pretraining=True,
)


100%|██████████| 1980/1980 [00:00<00:00, 25449.39it/s]


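## Single child

**TODO:** `labels_to_remove` below is still empty, so this partition currently just copies each split into a pretraining file. The cells in this section were not run in the current session (`In [None]`), so their outputs come from an earlier run. That run reports a different number of training lines than the run above.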

In [None]:
# output folder for the "single child" partition
dst_data_dir = os.path.join(DST_DATA_DIR, "single_child")

In [None]:
# NOTE(review): empty -> no label is moved to incremental files; TODO fill in the single-child labels
labels_to_remove = list()

In [None]:
# train split: start from an empty output folder, write the pretraining file and
# one incremental file per removed label
relabel(
    path_to_read=os.path.join(SRC_DATA_DIR, TRAIN_DATA),
    dirpath_to_write=dst_data_dir,
    tree=tree,
    labels_to_remove=labels_to_remove,
    recreate_dirs=True,
)



100%|██████████| 200662/200662 [03:21<00:00, 994.33it/s] 


In [None]:
# test split: keep the folder written by the train call and add only the
# pretraining file (no incremental files)
relabel(
    path_to_read=os.path.join(SRC_DATA_DIR, TEST_DATA),
    dirpath_to_write=dst_data_dir,
    tree=tree,
    labels_to_remove=labels_to_remove,
    recreate_dirs=False,
    only_pretraining=True,
)


100%|██████████| 2027/2027 [00:00<00:00, 39168.58it/s]
