In [6]:
import os
import sys
sys.path.append(os.getcwd() + '/..')
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
from itertools import combinations
from utils import taxo_utils
# Seed NumPy's global RNG so the random choices below (held-out cluster selection,
# query/positive swapping, negative sampling) are reproducible across runs.
np.random.seed(114514)

In [7]:
# Load the taxonomy that the contrastive data below is built from.
taxo = taxo_utils.from_json('./../data/raw/google.json')

In [8]:
# The following methods are used to build the contrastive data from ontology / mappings.
# Obtain clusters from an ontology. Each cluster is a list of class IDs (integers) sharing the same direct common parent.
def get_clusters_from_taxo(taxo:taxo_utils.Taxonomy):
    """Group taxonomy classes into sibling clusters.

    For every node of ``taxo``, the classes returned by ``taxo.get_subclasses``
    form one candidate cluster. Only candidates with at least two members are
    kept, because a contrastive pair needs two distinct classes.

    Args:
        taxo: Taxonomy whose nodes are scanned.

    Returns:
        A list of clusters in node-iteration order. Each cluster is exactly the
        object returned by ``taxo.get_subclasses`` for that node.
    """
    # Snapshot the node collection up front, then look up each node's family lazily.
    candidate_families = (taxo.get_subclasses(node) for node in list(taxo.nodes))
    return [family for family in candidate_families if len(family) > 1]

# Method to sample pairs of classes from a cluster. Used to build the positive part of the contrastive data.
def cluster_sample(n:int,cover_rate:int=2):
    """Choose index pairs from a cluster of ``n`` items for positive examples.

    Every unordered pair ``(i, j)`` with ``i < j`` is emitted
    ``cover_rate // (n - 1)`` times (a full pass touches each item ``n - 1``
    times). The leftover coverage ``r = cover_rate % (n - 1)`` is spread by
    placing the items on a ring and adding:
    - every pair at ring distance ``1 .. r // 2``;
    - the diametrically opposite pairs, when ``r`` is odd and ``n`` is even.

    Output order: all full-pass repetitions (pair by pair), then the
    ring-distance pairs, then the diametric pairs. Within each group, pairs
    follow ``itertools.combinations`` order.

    Raises:
        ValueError: if ``n < 2``.
    """
    if n < 2:
        raise ValueError('At least 2 items are needed for contrastive sampling.')

    all_pairs = list(combinations(range(n), 2))
    repeats, leftover = divmod(cover_rate, n - 1)
    max_gap = leftover // 2

    def ring_gap(a, b):
        # Shortest distance between positions a and b on a ring of n slots.
        return min((a - b) % n, (b - a) % n)

    sampled = []
    for pair in all_pairs:
        sampled.extend([pair] * repeats)
    sampled.extend((a, b) for a, b in all_pairs if ring_gap(a, b) <= max_gap)
    if leftover % 2 != 0 and n % 2 == 0:
        # An odd leftover cannot be split evenly on both sides of the ring,
        # so each item additionally gets paired with its opposite.
        sampled.extend((a, b) for a, b in all_pairs if (a - b) % n == n // 2)
    return sampled

# Calculate the number of pairs cluster_sample would return for a given cluster size and cover rate.
def num_pairs(n,cover_rate=2):
    """Return how many pairs ``cluster_sample(n, cover_rate)`` produces.

    Computes the count without building the list:
    - ``cover_rate // (n - 1)`` full passes over all ``n * (n - 1) / 2`` pairs;
    - plus ``n * (remainder // 2)`` pairs at ring distance ``1 .. remainder // 2``;
    - plus ``n // 2`` diametric pairs when the remainder is odd and ``n`` is even.

    Args:
        n: Cluster size.
        cover_rate: Same meaning (and default) as in ``cluster_sample``.

    Raises:
        ValueError: if ``n < 2``. This matches ``cluster_sample``; without the
            check, ``n == 1`` would raise ZeroDivisionError.
    """
    if n < 2:
        raise ValueError('At least 2 items are needed for contrastive sampling.')
    ratio, remainder = divmod(cover_rate, n - 1)
    full_num = ratio * n * (n - 1) // 2
    # A ring of n slots has exactly n pairs at each distance d < n/2.
    sub_num = n * (remainder // 2)
    if remainder % 2 != 0 and n % 2 == 0:
        sub_num += n // 2
    return full_num + sub_num

# Method to sample n random classes unrelated to the query class. Used to build the negative part of the contrastive data.
def get_negative(taxo:taxo_utils.Taxonomy,query,n):
    """Sample ``n`` distinct classes unrelated to ``query``.

    Candidates are drawn uniformly (with replacement) from ``taxo.nodes`` using
    NumPy's global RNG. A candidate is accepted when its ancestor set
    intersected with the query's ancestor set is exactly ``{0}``.
    NOTE(review): ``0`` is presumably the taxonomy root id; confirm. The query
    itself is never accepted.

    Args:
        taxo: Taxonomy providing ``nodes`` and ``get_ancestors``.
        query: Class id to find negatives for.
        n: Number of distinct negatives to return. ``n <= 0`` returns ``[]``.

    Returns:
        A list of ``n`` distinct class ids.

    Note:
        Sampling repeats until ``n`` negatives are found, so this never
        terminates if the taxonomy holds fewer than ``n`` eligible classes.
        NOTE(review): if ``get_ancestors`` excludes the node itself, a direct
        child of the root would also accept its own descendants; confirm the
        helper's semantics.
    """
    negative = set()
    classes = list(taxo.nodes)
    m = len(classes)
    # The query's ancestors do not change between draws, so compute them once.
    query_ancestors = taxo.get_ancestors(query, return_type=set)
    while len(negative) < n:
        # Take element [0] explicitly. int() on a 1-element array is deprecated
        # since NumPy 1.25. This consumes exactly the same RNG stream.
        randidx = int(np.random.choice(m, size=1)[0])
        randclass = classes[randidx]
        if randclass == query:
            continue
        if query_ancestors.intersection(taxo.get_ancestors(randclass, return_type=set)) == {0}:
            negative.add(randclass)
    return list(negative)

# Pick a random subset of indices by walking a shuffled order until the running sum reaches `target`. Used to choose the clusters held out for evaluation.
def solve_subarray_sum(arr, target):
    """Select a random subset of indices whose values sum to about ``target``.

    Indices are visited in a random order (``np.random.permutation``) and
    collected until the running sum of ``arr`` values first reaches
    ``target``. The result may therefore overshoot by up to one element. If
    the whole array sums to less than ``target``, every index is returned.

    Args:
        arr: Sequence of non-negative numbers (e.g. rows per cluster).
        target: Sum to reach. ``target <= 0`` selects nothing.

    Returns:
        A list of Python ``int`` indices into ``arr``, in the order they were
        selected.
    """
    # Always draw the permutation so the global RNG advances identically,
    # whatever the target is.
    order = np.random.permutation(len(arr))
    if target <= 0:
        return []
    chosen = []
    running = 0
    for idx in order:
        chosen.append(int(idx))
        running += arr[idx]
        if running >= target:
            break
    return chosen

# Build the contrastive train/test data used to train the class-retrieval model.
def build_contrastive_data(taxo: taxo_utils.Taxonomy,cover_rate=2,negs_per_batch=1,test_size=0.05):
    """Build contrastive (query, positive, negatives) label rows from a taxonomy.

    - Sibling clusters come from ``get_clusters_from_taxo``.
    - Positive index pairs inside each cluster come from ``cluster_sample``.
    - Query and positive are swapped with probability 0.5.
    - Each row gets ``negs_per_batch`` negatives from ``get_negative``.
    - Every cluster goes entirely to either the train or the test split.

    NOTE(review): a class that belongs to several clusters (e.g. multiple
    parents) can still appear in both splits; confirm whether that matters.

    Args:
        taxo: Taxonomy to sample from.
        cover_rate: Forwarded to ``cluster_sample`` / ``num_pairs``.
        negs_per_batch: Number of negatives sampled per row.
        test_size: Approximate fraction of rows to hold out. Clusters are taken
            at random until their rows reach ``int(total_rows * test_size)``.
            If that is 0, no cluster is held out.

    Returns:
        ``(train_df, test_df)`` DataFrames with columns ``query_label``,
        ``positive_label`` and ``negatives_label`` (a list of labels per row).
    """
    clusters = get_clusters_from_taxo(taxo)
    train_data = {'query_label':[],'positive_label':[],'negatives_label':[]}
    test_data = {'query_label':[],'positive_label':[],'negatives_label':[]}
    rows_per_cluster = [num_pairs(len(c),cover_rate=cover_rate) for c in clusters]
    total_rows = sum(rows_per_cluster)
    test_rows = int(total_rows * test_size)
    eval_cluster_idx = solve_subarray_sum(rows_per_cluster,test_rows)
    # solve_subarray_sum may still pick a cluster when asked for 0 rows. The call
    # is kept regardless so the global RNG stream is unchanged. A set gives O(1)
    # membership checks in the loop below.
    eval_clusters = set(eval_cluster_idx) if test_rows > 0 else set()
    with tqdm(total = total_rows) as pbar:
        for i,cluster in enumerate(clusters):
            data_to_write = test_data if i in eval_clusters else train_data
            pairs = cluster_sample(len(cluster),cover_rate=cover_rate)
            for j,k in pairs:
                query, positive = cluster[j], cluster[k]
                # Randomize direction so neither member is always the query.
                if np.random.random() < 0.5:
                    query, positive = positive, query
                negatives = get_negative(taxo,query,negs_per_batch)
                data_to_write['query_label'].append(taxo.get_label(query))
                data_to_write['positive_label'].append(taxo.get_label(positive))
                data_to_write['negatives_label'].append([taxo.get_label(neg) for neg in negatives])
                pbar.update(1)
    return pd.DataFrame(train_data), pd.DataFrame(test_data)

In [9]:
# Parameters of the generated dataset; both also appear in the output file names.
cover_rate = 4
negs_per_batch = 10
out_dir = './../data/ret'
# to_csv does not create missing directories, so ensure the output folder exists.
os.makedirs(out_dir, exist_ok=True)
train_data,eval_data = build_contrastive_data(taxo,cover_rate=cover_rate,negs_per_batch=negs_per_batch)
# NOTE(review): 'negatives_label' holds Python lists, which to_csv writes as their
# string repr; readers must parse that column back (e.g. with ast.literal_eval).
train_data.to_csv(f'{out_dir}/train_cover{cover_rate}_neg{negs_per_batch}.tsv',sep='\t',index=False)
eval_data.to_csv(f'{out_dir}/test_cover{cover_rate}_neg{negs_per_batch}.tsv',sep='\t',index=False)

  0%|          | 0/11036 [00:00<?, ?it/s]