# Collect samples for TaxoLLaMA fine-tuning

In [2]:
from __future__ import annotations

import os
import random
import pickle
import glob

import nltk

import numpy as np
import networkx as nx

from typing import Union
from random import sample
from copy import deepcopy
from itertools import combinations

from tqdm.auto import tqdm

from leafer import Leafer

nltk.download('wordnet')
from nltk.corpus import wordnet as wn


seed = 42
random.seed(seed)
np.random.seed(seed)

[nltk_data] Downloading package wordnet to /home/jupyter/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Create graph from WordNet

In [3]:
def build_graph() -> nx.DiGraph:
    '''
    Create DAG from WordNet 3.0 synsets
    
    Returns: constructed graph
    '''
    G = nx.DiGraph()
    for synset in tqdm(wn.all_synsets('n')):
        name = synset.name()
        G.add_node(name)
        hyponyms = synset.hyponyms()

        for hypo in hyponyms:
            new_name = hypo.name()
            G.add_node(new_name)
            G.add_edge(name, new_name)

    for synset in tqdm(wn.all_synsets('v')):
        name = synset.name()
        G.add_node(name)
        hyponyms = synset.hyponyms()

        for hypo in hyponyms:
            new_name = hypo.name()
            G.add_node(new_name)
            G.add_edge(name, new_name)
    return G

def delete_cycles(G: nx.DiGraph) -> nx.DiGraph:
    '''
    Delete cycles in DAG
    
    Returns: DAG without cycles
    '''
    while True:
        try:
            cycle = nx.find_cycle(G)
            G.remove_edge(*cycle[0])
        except:
            break
    return G

### Exclude test nodes from graph

In [14]:
DATASET_PATHS = {
    'MAGS': {
        'cs': os.path.abspath("../datasets/TaxonomyEnrichment/data/MAG_CS/test_hypernyms_def.pickle"),
        'psy': os.path.abspath("../datasets/TaxonomyEnrichment/data/psychology/test_hypernyms_def.pickle"),
        'verb': os.path.abspath("../datasets/TaxonomyEnrichment/data/noun/test_hypernyms_def.pickle"),
        'noun': os.path.abspath("../datasets/TaxonomyEnrichment/data/verb/test_hypernyms_def.pickle")
    },
    'SE': {
        'main': os.path.abspath("../datasets/SemEval2018-Task9/custom_datasets/1A.english.pickle"),
        'medical': os.path.abspath("../datasets/SemEval2018-Task9/custom_datasets/2A.medical.pickle"),
        'music': os.path.abspath("../datasets/SemEval2018-Task9/custom_datasets/2B.music.pickle")
    },       
    'TE': {
        'env': os.path.abspath("../datasets/TExEval-2_testdata_1.2/all_data/gs_taxo/EN/environment_eurovoc_en.taxo"),
        'sci': os.path.abspath("../datasets/TExEval-2_testdata_1.2/all_data/gs_taxo/EN/science_eurovoc_en.taxo")
    }
}

In [9]:
#MAG
def remove_mags(
    G: nx.DiGraph,
    paths: dict,
    test_parents=[],
    deleted=[],
    k=0,
    cs=True,
    psy=True,
    noun=True,
    verb=True,
    **kwargs
) -> tuple[nx.DiGraph, list[str], list[str], int]:
    '''
    Associate items from MAG datasets with WordNet synsets and remove them from graph
    
    Arguments:
        G - WordNet DAG
        paths - dict with paths to datasets
        test_parents - list of test nodes to fill, optional
        deleted - list of deleted nodes to fill, optional
        k - number of removed items, optional
        cs - whether to remove MAG CS items, optional
        psy - whether to remove MAG PSY items, optional
        noun - whether to remove MAG Nouns items, optional
        verb - whether to remove MAG Verbs items, optional
        
    Returns: G, test_parents, deleted, k
    '''
    if cs:
        cs_test_path = paths['MAGS']['noun']

        with open(cs_test_path, 'rb') as f:
            cs_test = pickle.load(f)
            
        for node in cs_test:
            for i in range(10):
                true_name = f"{node['children']}.n.0{i}"
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
    if psy:
        psy_test_path = paths['MAGS']['psy']

        with open(psy_test_path, 'rb') as f:
            psy_test = pickle.load(f)

        for node in psy_test:
            for i in range(10):
                true_name = f"{node['children']}.n.0{i}"
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
    if verb:
        verb_test_path = paths['MAGS']['verb']
        
    with open(verb_test_path, 'rb') as f:
        verb_test = pickle.load(f)
        
    for node in verb_test:
        if node['children'] in G.nodes():
            test_parents.extend(G.predecessors(node['children']))
            deleted.append(node['children'])
            G.remove_node(node['children'])
            k += 1
    if noun:
        noun_test_path = paths['MAGS']['noun']
        
        with open(noun_test_path, 'rb') as f:
            noun_test = pickle.load(f)
        
        for node in noun_test:
            if node['children'] in G.nodes():
                test_parents.extend(G.predecessors(node['children']))
                deleted.append(node['children'])
                G.remove_node(node['children'])
                k += 1
    
    return G, test_parents, deleted, k

In [6]:
#SemEval-2018
def remove_semeval(
    G: nx.DiGraph,
    paths: dict,
    test_parents=[],
    deleted=[],
    k=0,
    main=True,
    medical=False,
    music=False,
    **kwargs
) -> tuple[nx.DiGraph, list[str], list[str], int]:
    '''
    Associate items from SemEval-2018 Task 9 datasets with WordNet synsets and remove them from graph
        
    Arguments:
        G - WordNet DAG
        paths - dict with paths to datasets
        test_parents - list of test nodes to fill, optional
        deleted - list of deleted nodes to fill, optional
        k - number of removed items, optional
        main - whether to remove SE-18 1A: English items, optional
        medical - whether to remove SE-18 2A: Medical items, optional
        music - whether to remove SE-18 2B: Music items, optional
        
    Returns: G, test_parents, deleted, k
    '''
    if main:
        main_path = paths['SE']['main']

        with open(main_path, 'rb') as f:
            main = pickle.load(f)

        for elem in main:
            node = elem['children'].replace(' ', '_')
            for i in range(10):
                true_name = f'{node}.n.0{i}'
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
    if medical:
        medical_path = paths['SE']['medical']
        
        with open(medical_path, 'rb') as f:
            medical = pickle.load(f)
            
        for elem in medical:
            node = elem['children'].replace(' ', '_')
            for i in range(10):
                true_name = f'{node}.n.0{i}'
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
    if music:
        music_path = paths['SE']['music']

        with open(music_path, 'rb') as f:
            music = pickle.load(f)

        for elem in music:
            node = elem['children'].replace(' ', '_')
            for i in range(10):
                true_name = f'{node}.n.0{i}'
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
                
    return G, test_parents, deleted, k

In [7]:
# TexEval
def remove_texeval(
    G: nx.DiGraph,
    paths: dict,
    test_parents=[],
    deleted=[],
    k=0,
    env=True,
    sci=True,
    **kwargs
) -> tuple[nx.DiGraph, list[str], list[str], int]:
    '''
    Associate items from SemEval-2016 Task 13 (TexEval-2) datasets with WordNet synsets and remove them from graph
        
    Arguments:
        G - WordNet DAG
        paths - dict with paths to datasets
        test_parents - list of test nodes to fill, optional
        deleted - list of deleted nodes to fill, optional
        k - number of removed items, optional
        env - whether to remove SE-16 TaxEval Environment items, optional
        sci - whether to remove SE-16 TaxEval Science items, optional
        
    Returns: G, test_parents, deleted, k
    '''
    
    G_test = nx.DiGraph()
    
    if env:
        env_path = paths['TE']['env']
        
        with open(env_path, "r") as f:
            for line in f:
                idx, hypo, hyper = line.split("\t")
                hyper = hyper.replace("\n", "")
                G_test.add_node(hypo)
                G_test.add_node(hyper)
                G_test.add_edge(hyper, hypo)

        for node in G_test.nodes():
            for i in range(10):
                true_name = f'{node}.n.0{i}'
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
    if sci:
        sci_path = paths['TE']['sci']

        with open(sci_path, "r") as f:
            for line in f:
                idx, hypo, hyper = line.split("\t")
                hyper = hyper.replace("\n", "")
                G_test.add_node(hypo)
                G_test.add_node(hyper)
                G_test.add_edge(hyper, hypo)

        for node in G_test.nodes():
            for i in range(10):
                true_name = f'{node}.n.0{i}'
                if true_name in G.nodes():
                    test_parents.extend(G.predecessors(true_name))
                    deleted.append(true_name)
                    G.remove_node(true_name)
                    k += 1
    return G, test_parents, deleted, k

### Create DAG from WordNet and clean it from test nodes

In [10]:
def create_cleaned_graph(paths:dict, MAGS=True, SE=True, TE=True, **kwargs) -> tuple[nx.DiGraph, list[str], list[str]]:
    '''
    Create DAG from WordNet 3.0 and remove test nodes from Taxonomy Enrichment and Hypernym Discovery datasets
    Extract parent nodes of the test items
    
    Arguments:
        paths - dict with paths to datasets
        MAGS - whether to remove items from MAG datasets, optional
        SE - whether to remove items from SemEval-2018 Task 9 datasets, optional
        TE - whether to remove items from SemEval-2016 Task 13 datasets, optional
        kwargs - dict to use to pass subtask dataset indicators, optional
        
    Returns:
        G - created graph
        test_parents - list of parents of the test items
        deleted - list of removed test items
    '''
    G = build_graph()
    k = 0
    test_parents, deleted = [], []
    if MAGS:
        G, test_parents, deleted, k = remove_mags(G, paths, test_parents, deleted, k, **kwargs)
    #print(k)
    if SE:
        G, test_parents, deleted, k = remove_semeval(G, paths, test_parents, deleted, k, **kwargs)
    #print(k)
    if TE:
        G, test_parents, deleted, k = remove_texeval(G, paths, test_parents, deleted, k, **kwargs)
    #print(k)
    G = delete_cycles(G)
    return G, test_parents, deleted

## Tools to investigate and modify graph contents

In [11]:
def get_sisters(G: nx.DiGraph, test_parents:list[str], ref_set=None) -> list[str]:
    '''
    Get children of the parent nodes 
    
    Arguments:
        G - WordNet graph, cleaned from test items
        test_parents - parents of the removed test items
        ref_set = list of items to check containment, optional
        
    Returns: list of children nodes from graph and ref_set if passed
    '''
    sisters = []
    for n in set(test_parents):
        try:
            for sister in G.successors(n):
                if not ref_set or sister in ref_set:
                    sisters.append(sister)
        except nx.NetworkXError:
            #print(n)
            pass
    return list(set(sisters))


def get_fraction(train_set: list[dict], sisters=None, out_list=False) -> Union[list[str], float]:
    '''
    Get items from the training sample:
        get input items
        get items from a specific set (sisters) -> only take items that occur as input hyponyms
        get fraction of removed items from a specific set (sisters) if out_list is False
    
    Arguments:
        train_set - list of dict with field 'children'
        sisters - a set to search items from
        out_list = get fraction of items from sisters, that are NOT present in the train_set
        
    Returns: list of str or float
    '''
    if sisters is not None:
        rest = list(set(
            [elem['children'] for elem in train_set if elem['children'] in sisters]
        ))
        return rest if out_list else 1 - len(rest)/len(sisters)
    return [elem['children'] for elem in train_set]

In [12]:
def save_sample(sample: list[dict], mode: str, path: str) -> None:
    '''
    Pickle a training sample
    
    Arguments
        sample - training sample
        mode - experimental condition (for printing)
        path - saving path
    '''
    print(mode, len(sample))
    with open(path, 'wb') as f:
        pickle.dump(sample, f)
        
def get_train(G: nx.DiGraph, seed=42, with_test=False) -> Unioin[tuple[list[dict], list[dict]], list[dict]]:
    '''
    Sample hyponym-hypernym pairs from WordNet DAG
    
    Arguments
        G - WordNet DAG
        seed - random seed, optional
        with_test - divide into train and test, optional
    '''
    np.random.seed(seed)
    random.seed(seed)
    l = Leafer(G)
    p = 0.001 if with_test else 0.0
    train, test = l.split_train_test(
        generation_depth=0,
        p=p,
        p_divide_leafs=0.5,
        min_to_test_rate=0.5,
        weights=[0.00, 0.0, 0.0, 0.00, 0.00, 1.],
    )
    return (train, test) if with_test else train

In [13]:
def get_base_sample(path, **kwargs) -> None:
    '''
    Get best sample from graph (only test node deleted)
    '''
    G, test_parents, _ = create_cleaned_graph(**kwargs)
    train = get_train(G)
    save_sample(train, 'full', path)

In [14]:
SE181A = dict(
    paths = DATASET_PATHS,
    TE=False,
    MAGS=False,
    main=True,
    medical=False,
    music=False
)

In [17]:
DIRNAME = os.path.abspath("../samples")
if not os.path.isdir(DIRNAME):
    os.mkdir(DIRNAME)

In [33]:
#construct full dataset
get_base_sample(os.path.join(DIRNAME, 'clean_train.pickle'), **SE181A)

predict_hypernym 40636 40636
full 40636


In [16]:
def read_file(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

## Constructing samples with removed items

### Removing cohyponym items

In [17]:
G, test_parents, _ = create_cleaned_graph(**SE181A)
sisters = get_sisters(G, test_parents)
cleaned_full = read_file(os.path.join(DIRNAME, 'clean_train.pickle'))
rest_sisters = get_fraction(cleaned_full, sisters, out_list=True)

82115it [00:05, 14036.85it/s]
13767it [00:00, 16165.52it/s]


In [18]:
def pair_for_sisters(sisters: list[str], dataset: list[dict]) -> dict:
    '''
    Get (hyponym, hypernym) pairs for sisters in dataset where sisters are either children or parents
    
    Arguments:
      sisters - list of sister node names
      dataset - trainig dataset to pick sisters from
      
    Return: dict with sisters as children and as parents
    '''
    return {
        'children': [(item['children'], item['parents']) for item in tqdm(dataset) if item['children'] in sisters], #if any([sister in item['children'] for sister in sisters])],
        'parents': [(item['children'], item['parents']) for item in tqdm(dataset) if item['parents'] in sisters]#if any([sister in item['parents'] for sister in sisters])],
    }

In [19]:
sister_pairs = pair_for_sisters(rest_sisters, cleaned_full)
total_sister_items = sorted(list(set(sister_pairs['children'] + sister_pairs['parents'])))

100%|██████████| 40636/40636 [00:02<00:00, 13588.27it/s]
100%|██████████| 40636/40636 [00:03<00:00, 13331.31it/s]


In [20]:
len(total_sister_items)

8776

In [21]:
print('DATSET SIZE:\t\t\t', len(cleaned_full))
print('TOTAL NUMBER OF SISTERS:\t', len(sisters))
print('FINAL NUMBER OF SISTERS:\t', len(rest_sisters))

DATSET SIZE:			 40636
TOTAL NUMBER OF SISTERS:	 15295
FINAL NUMBER OF SISTERS:	 5535


In [22]:
def sister_to_items(sisters: list[str], pairs: dict):
    '''
    Refactor `pair_for_sisters` output to {synset_name: {children: [], parents: []}, ...}
    
    Arguments:
      sisters - list of sister synset names
      pairs - dict of format {children: [], parents: []}
      
    Return: dict with sisters as children and as parents
    '''
    return {
        sister: {
            'children': [item for item in pairs['children'] if sister == item[0]],
            'parents': [item for item in pairs['parents'] if sister == item[1]],
        } for sister in sisters
    }


In [23]:
sisters_with_items = sister_to_items(rest_sisters, sister_pairs)

In [24]:
from collections import Counter
cnt = Counter([len(sisters_with_items[k]['children']) for k in sisters_with_items])
print(cnt)
cnt = Counter([len(sisters_with_items[k]['parents']) for k in sisters_with_items])
print(cnt)

Counter({1: 5329, 2: 202, 3: 4})
Counter({0: 4319, 1: 534, 2: 275, 3: 129, 4: 61, 5: 59, 6: 42, 7: 19, 9: 17, 8: 15, 11: 12, 10: 11, 12: 9, 13: 6, 15: 4, 14: 4, 22: 2, 19: 2, 21: 2, 23: 2, 18: 2, 17: 2, 26: 1, 39: 1, 42: 1, 27: 1, 20: 1, 24: 1, 16: 1})


In [25]:
def sort_sisters(sisters: dict):
    '''
    Sort sisters by number of associated items
    '''
    sisters = {k: v['parents'] + v['children'] for k, v in sisters.items()}
    return sorted(list(sisters.items()), key=lambda x: len(x[1])) #stack

def get_shards(sisters: list[dict], total_sister_items: list[tuple]) -> list[list]:
    '''
    Prepare shards for 4-fold cross-validation: normalize by number of items associated to each sister in shard
    
    Arguments:
        sisters - sister with associated items terms sorted by number of associated items
        total_sisters - list of all items associated with sisters
        
    Returns: list of shards
    '''
    shard_length = len(total_sister_items) // 4 #4 fold cross-validation
    shards = [[], [], [], []]
    added = []
    while any([len(shard) < shard_length for shard in shards]) and sisters: #stack is not empty
        for i in range(4):
            if len(shards[i]) >= shard_length:
                continue
            _, v = sisters.pop() #taking the sister with maximum number of associated items from the stack
            to_add = [item for item in v if item not in added]
            added.extend(to_add)
            shards[i].extend(to_add)
            continue
    return shards
            

def get_sample(shards: list[list], dataset: list[dict], scope: float, i: int) -> list[dict]:
    '''
    Get sample for 4-fold cross-validation depending on proportion and index of fold
    
    Arguments:
        shards - data partitions to use in sample construction
        dataset - full training sample
        scope - proportion of removed items
        i - index of shard
        
    Returns: list of items dicts
    '''
    if scope == 0.25:
        to_remove = shards[i]
    if scope == 0.75:
        to_remove = [pair for j in range(4) for pair in shards[j] if j != i]
    if scope == 0.5:
        indexes = list(combinations(range(4), 2))[i]
        to_remove = shards[indexes[0]] + shards[indexes[1]]
    if scope == 1:
        to_remove = [item for shard in shards for item in shard]
    #print(scope, i, to_remove, sep='\n')
    return [item for item in dataset if (item['children'], item['parents']) not in to_remove]
        
    
def get_sister_folds(shards: list[dict], dataset: list[dict], scope: float) -> list[dict]:
    '''
    Generate samples by removing shards from the training set following 4-fold cross-validation algorithm
    
    Arguments:
        shards - data partitions to use in sample construction
        dataset - trainig sample
        scope - proportion of items to remove
        
    Returns: list of concstructed samples
    '''
    k = 6 if scope == 0.5 else 4
    return [get_sample(shards, dataset, scope, i) for i in tqdm(range(k))]


In [28]:
shards = get_shards(sort_sisters(sisters_with_items), total_sister_items)
#25%-75%
for scope in np.linspace(0.25, 0.75, 3):
    folds = get_sister_folds(shards, cleaned_full, scope)
    for i in range(4):
        save_sample(folds[i], 'sister', f'{DIRNAME}/no_sister_{int(scope*100)}_fold_{i+1}_seed_42.pickle')
#100%
save_sample(
    get_sample(shards, cleaned_full, scope=1.0, i=0),
    'sister',
    f'{DIRNAME}/no_sister_100_seed_42.pickle'
)

100%|██████████| 4/4 [00:08<00:00,  2.01s/it]

sister 38442
sister 38442
sister 38442





sister 38442


100%|██████████| 6/6 [00:28<00:00,  4.80s/it]

sister 36248
sister 36248





sister 36248
sister 36248


100%|██████████| 4/4 [00:35<00:00,  8.85s/it]

sister 34054
sister 34054
sister 34054





sister 34054
sister 31860


### Removing non-cohyponym items

#### Get distances to nearest test items for shard normalization

In [31]:
def get_wn_path(s1: str, s2: str) -> float:
    '''
    Get path between to items in the WordNet
    
    Arguments:
        s1 -  name of a synset, e.g. 'synset1.n.01'
        s2 - name of a synset, e.g. 'synset2.n.01'
        
    Returns: distance between two synsets
    '''
    return (1 / s1.path_similarity(s2)) - 1

def get_distances(s1, s2_synsets) -> float:
    '''
    Calculate distance from synset s1 to the nearest synset from s2_synsets
    '''
    return min([get_wn_path(s1, s2) for s2 in s2_synsets])


In [32]:
# test synsets
SE_synsets = [wn.synset(s2) for s2 in create_cleaned_graph(**SE181A)[2]]
# non-cohyponym synsets
CLEANED_synsets = [
    (
        wn.synset(item['children']), wn.synset(item['parents'])
    ) for item in cleaned_full if (item['children'], item['parents']) not in total_sister_items
]

82115it [00:01, 63410.82it/s]
13767it [00:00, 61020.97it/s]


In [None]:
#import multiprocessing
#from joblib import Parallel, delayed

#N_JOBS = multiprocessing.cpu_count() - 1
#calculate distances to children synsets
#min_distances = Parallel(n_jobs=N_JOBS)(delayed(get_distances)(s1[0]) for s1 in tqdm(CLEANED_synsets))

In [37]:
#load calculated distances
with open('wn_min_distances.txt', 'r') as fin:
    distances = [int(x.strip()) for x in fin.readlines()]

In [38]:
from sklearn.cluster import KMeans

#train kmeans to define clusters of distances
min_distances = np.array(distances).reshape(-1, 1)
kmeans = KMeans(n_clusters=3, random_state=42)

kmeans.fit(min_distances)
clusters = kmeans.predict(min_distances)

partitions = [[], [], []]
for number, cluster in zip(min_distances.flatten(), clusters):
    partitions[cluster].append(number)

In [39]:
print(list(map(lambda x: max(x), partitions)))
print(list(map(lambda x: min(x), partitions)))
print(list(map(lambda x: len(x), partitions)))

[5, 2, 15]
[3, 1, 6]
[18195, 9801, 3864]


In [40]:
# items without cohyponym items
cleaned_other = [item for item in tqdm(cleaned_full) if (item['children'], item['parents']) not in total_sister_items]
# items with assigned cluster
total_other = {(item['children'], item['parents']): cl for item, cl in tqdm(zip(cleaned_other, clusters))}

100%|██████████| 40636/40636 [00:12<00:00, 3149.54it/s]
31860it [00:00, 1113327.94it/s]


In [44]:
# all non-cohyponym synset names
all_other = list(set([x for item in tqdm(total_other) for x in item]))
# synset names with associated items
node_to_items = {
    elem: [item for item in total_other if elem in item]
    for elem in tqdm(all_other)
}

100%|██████████| 31860/31860 [00:00<00:00, 2174020.62it/s]
100%|██████████| 39209/39209 [01:27<00:00, 448.44it/s]


In [45]:
def item_label(labels: list) -> int:
    '''
    Get label for synset based on paths from associated items to nearest test nodes
    '''
    if 2 in labels:
        return 2
    if 0 in labels:
        return 0
    return 1

In [46]:
#assign label to each non-cohyponym synset depending on its associated item labels 
paths = {k:item_label([total_other[pair] for pair in v]) for k, v in tqdm(node_to_items.items())}

100%|██████████| 39209/39209 [00:00<00:00, 211310.17it/s]


In [47]:
def group_by_clusters(paths: dict, node_to_items: dict) -> list[list]:
    '''
    Divide items by clusters (based on distances to test items) and shuffle
    
    Arguments:
       paths - synset-label pairs
       node_to_items - synset-items pairs
       
    Returns: list of items grouped by cluster
    '''
    clustered = []
    for i in range(3):
        clustered.append(list({k:node_to_items[k] for k, v in paths.items() if v == i}.items()))
        random.seed(42)
        random.shuffle(clustered[i])
    return clustered

In [48]:
clustered = group_by_clusters(paths, node_to_items)

In [49]:
[len(group) for group in clustered]

[22688, 9791, 6730]

#### Generate samples

In [50]:
def get_other_shards(clustered_stack: list[list], partitions: list[list], total_other: list[str], max_length: int, seed=42) -> list[list]:
    '''
    Prepare shards for 4-fold cross-validation: normalize by path to nearest test node
    
    Arguments:
        clustered_stack - non-cohyponym items grouped by path to nearest test node
        total_other - list of all items associated with non-cohyponyms
        max_length - size of shard (used for sister sample construction)
        seed - random seed, optional
        
    Returns: list of shards
    '''
    total_n = len(total_other)
    np.random.seed(seed)
    def next_cluster(n=3): return np.random.choice(np.arange(0, n), p=[len(px) / total_n for px in partitions]) #4 fold cross-validation
    shards = [[], [], [], []]
    added = []
    n_clusters = 3
    while any([len(shard) < max_length for shard in shards]) and clustered_stack: #stack is not empty
        for i in range(4):
            if len(shards[i]) >= max_length:
                continue
            j = next_cluster(n_clusters)
            _, v = clustered_stack[j].pop() 
            to_add = [item for item in v if item not in added]
            added.extend(to_add)
            shards[i].extend(to_add)
            if not clustered_stack[j]:
                total_n -= len(partitions[j])
                partitions.pop(j)
                clustered_stack.pop(j)
                n_clusters -= 1
            continue
    return shards
            

def get_sample(shards: list[dict], dataset: list[dict], scope: float, i: int) -> list[dict]:
    '''
    Get sample for 4-fold cross-validation depending on proportion and index of fold
    
    Arguments:
        shards - data partitions to use in sample construction
        dataset - full training sample
        scope - proportion of removed items
        i - index of shard
        
    Returns: list of items dicts
    '''
    if scope == 0.25:
        to_remove = shards[i]
    if scope == 0.75:
        to_remove = [pair for j in range(4) for pair in shards[j] if j != i]
    if scope == 0.5:
        indexes = list(combinations(range(4), 2))[i]
        to_remove = shards[indexes[0]] + shards[indexes[1]]
    if scope == 1.0:
        to_remove = [item for shard in shards for item in shard]
    #print(scope, i, to_remove, sep='\n')
    return [item for item in dataset if (item['children'], item['parents']) not in to_remove]
        
    
def get_other_folds(shards: list[list], dataset: list[dict], scope: float) -> list[list]:
    '''
    Generate samples by removing shards from the training set following 4-fold cross-validation algorithm
    
    Arguments:
        shards - shards to compose the subset for removal from
        dataset - trainig sample
        scope - proportion of items to remove
        
    Returns: list of concstructed samples
    '''
    k = 6 if scope == 0.5 else 4
    return [get_sample(shards, dataset, scope, i) for i in tqdm(range(k))]


In [51]:
#generate shards
shards = get_other_shards(clustered, partitions, total_other, len(total_sister_items) // 4, seed=42)

In [52]:
from collections import Counter
[Counter([total_other[item] for item in shard]).items() for shard in shards]

[dict_items([(0, 1219), (1, 793), (2, 182)]),
 dict_items([(2, 222), (0, 1243), (1, 729)]),
 dict_items([(1, 759), (0, 1245), (2, 190)]),
 dict_items([(1, 743), (2, 191), (0, 1260)])]

In [53]:
[len(shard) for shard in shards]

[2194, 2194, 2194, 2194]

In [54]:
#25%-75%
for scope in np.linspace(0.25, 0.75, 3):
    folds = get_other_folds(shards, cleaned_full, scope)
    k = 4 if scope != 0.5 else 6
    for i in range(k):
        save_sample(folds[i], 'compl', f'{DIRNAME}/no_compl_{int(scope*100)}_fold_{i+1}_seed_42.pickle')
#100%
save_sample(
    get_sample(shards, cleaned_full, scope=1.0, i=0),
    'compl',
    f'{DIRNAME}/no_compl_100_seed_42.pickle'
)

100%|██████████| 4/4 [00:08<00:00,  2.11s/it]

compl 38442
compl 38442
compl 38442
compl 38442



100%|██████████| 6/6 [00:29<00:00,  4.96s/it]

compl 36248





compl 36248
compl 36248
compl 36248
compl 36248
compl 36248


100%|██████████| 4/4 [00:36<00:00,  9.24s/it]

compl 34054
compl 34054
compl 34054
compl 34054





compl 31860


### Removing random items

In [289]:
def next_node_type(nodes):
    n_nodes = np.sum([len(node_type) for node_type in nodes])
    return np.random.choice(np.arange(0, len(nodes)), p=[len(node_type) / n_nodes for node_type in nodes])

def get_random_sample(dataset: list[dict], items: dict, nodes: list[list], length: int, scope: float, seed=42) -> list[dict]:
    '''
    Get samples with randomly deleted nodes 
    
    Arguments:
        dataset - trainig sample
        items - dict of synset with associated items pairs
        nodes - list of list with nodes by type
        length - size of shard (used for sister sample construction)
        scope - proportion of items to remove
        
    Returns: constructed sample
    '''
    random.seed(seed)
    np.random.seed(seed)
    to_remove = []
    while len(to_remove) < length * scope:
        node_type = next_node_type(nodes)
        index = random.randint(0, len(nodes[node_type]) - 1)
        candidates = [item for item in items[nodes[node_type][index]] if item not in to_remove]
        to_remove.extend(candidates)
        nodes[node_type].pop(index)
        
    return [item for item in dataset if (item['children'], item['parents']) not in to_remove]

In [286]:
#collect all hyponym-hypernym pairs as tuples
all_items = [(item['children'], item['parents']) for item in tqdm(cleaned_full)]
#collect all synsets
all_synsets = list(set([x for item in tqdm(all_items) for x in item]))
#synset names with associated items
node_to_items = {
    elem: [item for item in all_items if elem in item]
    for elem in tqdm(all_synsets)
}

100%|██████████| 40636/40636 [00:00<00:00, 1110646.01it/s]
100%|██████████| 40636/40636 [00:00<00:00, 2047545.53it/s]
100%|██████████| 47710/47710 [03:41<00:00, 214.95it/s]


In [287]:
max_length = len(total_sister_items)
SEEDS = [42, 13, 99]
nodes = [all_other, rest_sisters]

In [290]:
for seed in SEEDS:
    for scope in tqdm(np.linspace(0.25, 1.0, 4)):
        save_sample(
            get_random_sample(cleaned_full, node_to_items, deepcopy(nodes), max_length, scope, seed=seed),
            'random', f'{DIRNAME}/no_rand_{int(scope*100)}_seed_{seed}.pickle'
        )

 25%|██▌       | 1/4 [00:03<00:09,  3.22s/it]

random 38442


 50%|█████     | 2/4 [00:11<00:12,  6.48s/it]

random 36248


 75%|███████▌  | 3/4 [00:27<00:10, 10.57s/it]

random 34053


100%|██████████| 4/4 [00:48<00:00, 12.03s/it]


random 31860


 25%|██▌       | 1/4 [00:03<00:09,  3.30s/it]

random 38442


 50%|█████     | 2/4 [00:12<00:13,  6.56s/it]

random 36247


 75%|███████▌  | 3/4 [00:27<00:10, 10.51s/it]

random 34053


100%|██████████| 4/4 [00:47<00:00, 11.97s/it]


random 31859


  0%|          | 0/4 [00:00<?, ?it/s]

random 38442


 25%|██▌       | 1/4 [00:03<00:09,  3.11s/it]

random 36248


 75%|███████▌  | 3/4 [00:27<00:10, 10.50s/it]

random 34053
random 31860


100%|██████████| 4/4 [00:48<00:00, 12.07s/it]


## Constructing samples with definitions

In [303]:
def add_definitions(elem: dict):
    '''
    Add definition to dataset item depending on condition
    
    Arguments:
        elem - item of dataset
    '''
    if elem['case'] == 'predict_hypernym':
        elem['child_def'] = wn.synset(elem['children'].replace(' ', '_')).definition()
      #  elem['parent_def'] = wn.synset(elem['parents']).definition()
    elif elem['case'] == 'predict_multiple_hypernyms':
        elem['child_def'] = wn.synset(elem['children'].replace(' ', '_')).definition()

    elif elem['case'] == 'simple_triplet_grandparent':
        elem['child_def'] = wn.synset(elem['children'].replace(' ', '_')).definition()
        elem['grandparent_def'] = wn.synset(elem['grandparents']).definition()
    elif elem['case'] == 'only_child_leaf':
        elem['grandparent_def'] = wn.synset(elem['grandparents']).definition()
        elem['parent_def'] = wn.synset(elem['parents']).definition()
    elif elem['case'] == 'simple_triplet_2parent':
        elem['1parent_def'] = wn.synset(elem['parents'][0]).definition()
        elem['2parent_def'] = wn.synset(elem['parents'][1]).definition()
    else:
        elem['parent_def'] = wn.synset(elem['parents']).definition()

def definitions(train: list[dict], subset=None, inverse=False) -> list[dict]:
    '''
    Add definitions to items of train or subset depending on condition (inverse)
        train - training sample
        subset - subset to check for current element containment
        inverse - if False, assign definition to elements from the subset only, else - to all except elements in the subset 
        
    Returns: train
    '''
    added_def_counter = 0
    for i, elem in enumerate(train):
        try:
            if inverse:
                if not subset or elem['children'] not in subset:
                    add_definitions(elem)
                    added_def_counter += 1
                    continue
                elem['child_def'] = ''  
                continue
            if not subset or elem['children'] in subset:
                add_definitions(elem)
                added_def_counter += 1
                continue
            elem['child_def'] = ''              
        except Exception as e:
            raise e
            print(i, elem)
            train.remove(elem)

    #print(counter)
    print('Added', added_def_counter)
    return train

In [304]:
def get_def_subset(train: list[dict], sisters: list[str], mode: str, treshold: float, seed=42) -> list[str]:
    '''
    Collect subset to manage assignment of definition
    
    Arguments:
        train - training sample
        sisters- list of sister terms
        mode - experimental condition
        threshold - proportion of elements to include in the susbset
        seed - random seed, optional
        
    Returns: subset (list of synset names)
    '''
    train = [elem['children'] for elem in train]
    random.seed(seed)
    if mode == 'sister':
        return sample(sisters, treshold)
    if mode == 'compl':
        return sample([x for x in train if x not in sisters], treshold)
    if mode == 'rand':
        return sample(train, treshold)
    
def definition_samples(dirname: str, seed: int, modes=('sister', 'compl', 'rand'), scopes=np.linspace(0, 1, 5), **kwargs) -> None:
    '''
    Generate samples with definitions
        seed - random seed
        modes - experimental conditions
        scopes - proportions of items to add definitions to
        kwargs - kwargs to pass to the WordNet graph builder
    '''
    G, test_parents, _ = create_cleaned_graph(**kwargs)
    sisters = get_sisters(G, test_parents)
    train = get_train(G, with_test=False)
    rest_sisters = get_fraction(train, sisters, out_list=True)
    
    for scope in tqdm(scopes):
        if scope == 0.0:
            train_def = definitions(train)
            train_out = f'{dirname}/full_train_def_train.pickle'
            save_sample(train_def, 'full', train_out)
            continue
        treshold = int(len(rest_sisters) * scope)
        for mode in modes:
            subset = get_def_subset(train, rest_sisters, mode, treshold, seed=seed)
            print(mode, int(scope*100), len(subset))
            for inverse in (True, False):
                print(f'inverse={inverse}')
                train_def = definitions(train, subset=subset, inverse=inverse)
                inv = '_inv' if inverse else ''
                train_out = f'{dirname}/def{inv}_{mode}_{str(int(scope*100))}_seed_{seed}_train.pickle'
                save_sample(train_def, mode, train_out)
                print()

In [305]:
definition_samples(dirname=DIRNAME, seed=42, **SE181A)

82115it [00:01, 65834.81it/s]
13767it [00:00, 61924.43it/s]


predict_hypernym 40636 40636


  0%|          | 0/5 [00:00<?, ?it/s]

Added 40636
full 40636


 20%|██        | 1/5 [00:00<00:01,  3.46it/s]

sister 25 1383
inverse=True
Added 39203
sister 40636

inverse=False
Added 1433
sister 40636

compl 25 1383
inverse=True
Added 39195
compl 40636

inverse=False
Added 1441
compl 40636

rand 25 1383
inverse=True
Added 39193
rand 40636

inverse=False


 40%|████      | 2/5 [00:12<00:21,  7.19s/it]

Added 1443
rand 40636

sister 50 2767
inverse=True
Added 37760
sister 40636

inverse=False
Added 2876
sister 40636

compl 50 2767
inverse=True
Added 37759
compl 40636

inverse=False
Added 2877
compl 40636

rand 50 2767
inverse=True
Added 37761
rand 40636

inverse=False
Added 2875
rand 40636


 60%|██████    | 3/5 [00:31<00:25, 12.85s/it]


sister 75 4151
inverse=True
Added 36328
sister 40636

inverse=False
Added 4308
sister 40636

compl 75 4151
inverse=True
Added 36329
compl 40636

inverse=False
Added 4307
compl 40636

rand 75 4151
inverse=True
Added 36313
rand 40636

inverse=False
Added 4323
rand 40636


 80%|████████  | 4/5 [00:59<00:18, 18.82s/it]


sister 100 5535
inverse=True
Added 34891
sister 40636

inverse=False
Added 5745
sister 40636

compl 100 5535
inverse=True
Added 34925
compl 40636

inverse=False
Added 5711
compl 40636

rand 100 5535
inverse=True
Added 34885
rand 40636

inverse=False


100%|██████████| 5/5 [01:38<00:00, 26.01s/it]

Added 5751
rand 40636



100%|██████████| 5/5 [01:38<00:00, 19.72s/it]


## Analysis of samples

In [60]:
def read_file(path: str) -> list[dict]:
    '''
    Unpickle training sample
    '''
    with open(path, 'rb') as f:
        return pickle.load(f)
    
def get_subset_of_sisters(dirname: str, base_sample_name: str, **kwargs) -> list[str]:
    '''
    Get subset of sisters from the base sample
    
    Arguments
        dirname - name of the directory
        base_sample_name - name of pickle file with base sample
        kwargs - keyword arguments to pass to the create_cleaned_graph function
    '''
    G, test_parents, _ = create_cleaned_graph(**kwargs)
    sisters = get_sisters(G, test_parents)
    base = read_file(os.path.join(dirname, base_sample_name))
    return get_fraction(base, sisters, out_list=True)

def fraction_of_removed_items(dirname: str, prefix: str, subset: list[str]) -> None:
    '''
    Print proportion of removed items from subset in the training sets from directory
    
    Arguments
        dirname - name of the directory
        prefix - prefix of filenames with training sets
        subset - list of items to check for in each training set
    '''

    for fname in sorted(os.listdir(dirname)):
        if prefix in fname[:len(prefix)]:
            train_set = read_file(os.path.join(dirname, fname))
            fraq_sisters = get_fraction(train_set, subset)
            print(f'{fname[:fname.index(".")]}:\t\t{fraq_sisters:.3f}')
    

In [61]:
fraction_of_removed_items(DIRNAME, 'no_', subset=get_subset_of_sisters(DIRNAME, 'clean_train.pickle', **SE181A))

82115it [00:01, 64098.63it/s]
13767it [00:00, 61034.06it/s]


no_compl_100_seed_42:		0.000
no_compl_25_fold_1_seed_42:		0.000
no_compl_25_fold_2_seed_42:		0.000
no_compl_25_fold_3_seed_42:		0.000
no_compl_25_fold_4_seed_42:		0.000
no_compl_50_fold_1_seed_42:		0.000
no_compl_50_fold_2_seed_42:		0.000
no_compl_50_fold_3_seed_42:		0.000
no_compl_50_fold_4_seed_42:		0.000
no_compl_50_fold_5_seed_42:		0.000
no_compl_50_fold_6_seed_42:		0.000
no_compl_75_fold_1_seed_42:		0.000
no_compl_75_fold_2_seed_42:		0.000
no_compl_75_fold_3_seed_42:		0.000
no_compl_75_fold_4_seed_42:		0.000
no_rand_100_seed_13:		0.151
no_rand_100_seed_42:		0.155
no_rand_100_seed_99:		0.154
no_rand_25_seed_13:		0.041
no_rand_25_seed_42:		0.041
no_rand_25_seed_99:		0.041
no_rand_50_seed_13:		0.075
no_rand_50_seed_42:		0.075
no_rand_50_seed_99:		0.085
no_rand_75_seed_13:		0.117
no_rand_75_seed_42:		0.112
no_rand_75_seed_99:		0.118
no_sister_100_seed_42:		1.000
no_sister_25_fold_1_seed_42:		0.254
no_sister_25_fold_2_seed_42:		0.250
no_sister_25_fold_3_seed_42:		0.250
no_sister_25_fol

In [308]:
def fraction_of_items_definitions(dirname: str, prefix: str, subset: list[str]) -> None:
    '''
    Print proportion of items with definitions from subset in the training sets from directory
    
    Arguments
        dirname - name of the directory
        prefix - prefix of filenames with training sets
        subset - list of synsets to check for in each training set
    '''
    for fname in sorted(os.listdir(dirname)):
        if prefix in fname[:len(prefix)]:
            train_set = read_file(os.path.join(dirname, fname))
            fraq_type = len([item['children'] for item in train_set if item['child_def']]) / len(subset)
            fraq_sisters = len(set([item['children'] for item in train_set if item['child_def'] and item['children'] in subset])) / len(subset)
            fraq_total = len([item['children'] for item in train_set if item['child_def']]) / len(train_set)
            print(f'{fname[:fname.index(".")]}:\t\t{fraq_total*100:.3f} (total)\t{fraq_type*100:.2f} (def)\t{fraq_sisters*100:.2f} (sisters)')

In [307]:
fraction_of_items_definitions(DIRNAME, 'def_', get_subset_of_sisters(DIRNAME, 'clean_train.pickle', **SE181A))

82115it [00:01, 62905.20it/s]
13767it [00:00, 61379.21it/s]


def_compl_100_seed_42_train:		14.054 (total)	99.60 (def)	0.00 (sisters)
def_compl_25_seed_42_train:		3.546 (total)	24.97 (def)	0.00 (sisters)
def_compl_50_seed_42_train:		7.080 (total)	49.97 (def)	0.00 (sisters)
def_compl_75_seed_42_train:		10.599 (total)	74.87 (def)	0.00 (sisters)
def_inv_compl_100_seed_42_train:		85.946 (total)	619.31 (def)	100.00 (sisters)
def_inv_compl_25_seed_42_train:		96.454 (total)	693.95 (def)	100.00 (sisters)
def_inv_compl_50_seed_42_train:		92.920 (total)	668.94 (def)	100.00 (sisters)
def_inv_compl_75_seed_42_train:		89.401 (total)	644.05 (def)	100.00 (sisters)
def_inv_rand_100_seed_42_train:		85.848 (total)	619.22 (def)	86.32 (sisters)
def_inv_rand_25_seed_42_train:		96.449 (total)	693.97 (def)	96.42 (sisters)
def_inv_rand_50_seed_42_train:		92.925 (total)	669.03 (def)	93.03 (sisters)
def_inv_rand_75_seed_42_train:		89.362 (total)	644.08 (def)	89.83 (sisters)
def_inv_sister_100_seed_42_train:		85.862 (total)	618.92 (def)	0.00 (sisters)
def_inv_sister_25_see