In [1]:
import json
from typing import List, Tuple
from nltk import Tree
from tqdm import tqdm
from parseq.datasets import CFQDatasetLoader
from parseq.grammar import taglisp_to_tree
from parseq.scripts_resplit.resplit_cfq import DivergenceComputer

In [2]:
# Load CFQ with the "mcd1" split ("modent" variant). validfrac=0 presumably disables
# carving out an extra validation set (TODO confirm); loadunused=True additionally
# loads the "unused" pool. Per the log below, the splits are train/test/oodvalid/unused,
# and each example is later indexed as (input, output, split_name).
ds = CFQDatasetLoader().load("mcd1/modent", validfrac=0, loadunused=True)

CFQDatasetLoader: make data
CFQDatasetLoader: make data in 0.0 sec
loading split 'mcd1'
doing 'train'
doing 'test'
doing 'oodvalid'
doing 'unused'


100%|██████████| 95743/95743 [00:26<00:00, 3639.20it/s]
100%|██████████| 11968/11968 [00:03<00:00, 3127.37it/s]
100%|██████████| 11968/11968 [00:03<00:00, 3120.59it/s]
100%|██████████| 119678/119678 [00:37<00:00, 3178.92it/s]


In [3]:
# Shared helper for the rest of the notebook: extracts atoms/compounds from output
# trees, builds per-split atom/compound distributions and computes divergences.
dc = DivergenceComputer()

In [4]:
# Per-split atom distributions, then the divergence between every pair of splits.
# In the output the matrix is symmetric up to float rounding, and diagonal entries are ~0
# (tiny values such as -4.4e-16 are rounding noise, not real negative divergences).
# NOTE(review): relies on the private DivergenceComputer._compute_atom_divergences.
atom_dists = dc.compute_atom_distributions(ds)
print(json.dumps(dc._compute_atom_divergences(atom_dists), indent=3))

100%|██████████| 239357/239357 [00:40<00:00, 5881.41it/s]


{
   "train-train": 0.0,
   "train-test": 0.03824614256979564,
   "train-oodvalid": 0.038364187269557926,
   "train-unused": 0.014322402198022166,
   "test-train": 0.038246142569795416,
   "test-test": -4.440892098500626e-16,
   "test-oodvalid": 9.497392817781058e-05,
   "test-unused": 0.00961599334917751,
   "oodvalid-train": 0.03836418726955815,
   "oodvalid-test": 9.497392817769956e-05,
   "oodvalid-oodvalid": 1.1102230246251565e-16,
   "oodvalid-unused": 0.009726062600495133,
   "unused-train": 0.014322402198022166,
   "unused-test": 0.009615993349177399,
   "unused-oodvalid": 0.009726062600495133,
   "unused-unused": 5.551115123125783e-16
}


In [None]:
# Per-split compound distributions, then the pairwise compound divergences.
# NOTE(review): this is the slow step. The saved output shows the progress bar at ~63%
# and no execution count, yet later cells read `comp_dists` -- their saved outputs
# depend on kernel state that this saved run does not reproduce. Re-run top to bottom.
comp_dists = dc.compute_compound_distributions(ds)
print(json.dumps(dc._compute_compound_divergences(comp_dists), indent=3))

 63%|██████▎   | 151547/239357 [02:00<00:53, 1634.39it/s]

In [None]:
def get_dist_similarity(x, dist):
    """
    Sum the weights that `dist` assigns to the items of `x`.

    Items missing from `dist` contribute 0; an item that occurs several times in `x`
    is counted once per occurrence.

    :param x:     iterable of hashable items (e.g. the atoms or compounds of one example)
    :param dist:  dict mapping item -> weight (e.g. a per-split atom/compound distribution)
    :return:      the summed weight; 0 if `x` is empty or shares no item with `dist`
    """
    return sum(dist.get(xe, 0) for xe in x)


def diff_dists(a, b):
    """
    Element-wise difference ``a - b`` of two distributions over the union of their keys.

    A key missing from either side is treated as having weight 0.

    :param a:  dict mapping key -> weight
    :param b:  dict mapping key -> weight
    :return:   new dict ``{k: a[k] - b[k]}`` for every key present in `a` or `b`
    """
    return {k: a.get(k, 0.) - b.get(k, 0.) for k in a.keys() | b.keys()}


def filter_mcd(source:List[Tuple[str,Tree]], otheratoms, othercomps, N=10000,
               dc=None, coeffa=1, coeffb=1, coeffatom=0.1):
    """
    Select N examples from `source` that diverge in compounds from a train and a test split.

    Every candidate is scored as

        coeffa * sim(compounds, traincomps) + coeffb * sim(compounds, testcomps)
        - coeffatom * sim(atoms, trainatoms) - coeffatom * sim(atoms, testatoms)

    where sim(items, dist) is `get_dist_similarity(items, dist)`. The N lowest-scoring
    candidates are kept. Their compounds overlap least with train and test, while
    (weighted by `coeffatom`) their atoms overlap most.

    Ideally, the training set shouldn't contain any compound from the new selection,
    and the new selection shouldn't contain any compound from test.

    :param source:      a list of examples from which to pick, in format (input, output);
                        only elements [0] and [1] of each example are used
    :param otheratoms:  atom distributions of the other splits, as [train, test]
    :param othercomps:  compound distributions of the other splits, as [train, test]
    :param N:           how many examples to retain in selection from source;
                        must be smaller than len(source)
    :param dc:          object providing `extract_atoms(tree)` and `extract_compounds(tree)`
                        (the notebook passes its DivergenceComputer); required despite
                        the None default
    :param coeffa:      weight of compound similarity to train
    :param coeffb:      weight of compound similarity to test
    :param coeffatom:   weight of atom similarity to each of train and test (subtracted)
    :return:   list of N (input, output, "newsplit") triples, lowest score first
    :raises ValueError: if `dc` is missing, `source` is not larger than N, or the
                        distribution lists are not [train, test] pairs
    """
    # Explicit checks (not asserts) so they also hold under `python -O`.
    if dc is None:
        raise ValueError("filter_mcd requires `dc` (an object with extract_atoms/extract_compounds)")
    if len(source) <= N:
        raise ValueError(f"source has {len(source)} examples; need more than N={N}")
    if len(otheratoms) != 2 or len(othercomps) != 2:
        raise ValueError("otheratoms and othercomps must each be [train, test] distributions")
    trainatoms, testatoms = otheratoms
    traincomps, testcomps = othercomps

    scores = []
    for i, x in enumerate(tqdm(source)):
        xatoms = dc.extract_atoms(x[1])
        xcomps = dc.extract_compounds(x[1])
        # Compound overlap is penalised; atom overlap is rewarded (hence subtracted).
        score = (coeffa * get_dist_similarity(xcomps, traincomps)
                 + coeffb * get_dist_similarity(xcomps, testcomps)
                 - coeffatom * get_dist_similarity(xatoms, trainatoms)
                 - coeffatom * get_dist_similarity(xatoms, testatoms))
        scores.append((i, score))

    # Lowest score first; sorted() is stable, so ties keep their order in `source`.
    ranked = sorted(scores, key=lambda s: s[1])
    # TODO: a refinement pass could re-score candidates against the difference between
    # the train/test atom distributions and the atom distribution of this initial
    # selection (see diff_dists), to better match atom distributions.
    return [(source[i][0], source[i][1], "newsplit") for i, _ in ranked[:N]]

In [None]:
# Candidate pool: (input, output tree) pairs for every example of the "unused" split,
# with the output string parsed by taglisp_to_tree. The last line displays the pool size.
unused = [(ex[0], taglisp_to_tree(ex[1])) for ex in ds if ex[2] == "unused"]
len(unused)

In [None]:
# Pick N (default 10000) unused examples whose compounds diverge from train and test
# while their atoms stay close to them; each result is tagged with split "newsplit".
newsplit = filter_mcd(unused, [atom_dists["train"], atom_dists["test"]], [comp_dists["train"], comp_dists["test"]], dc=dc)

In [None]:
# Distributions of the new selection, stored under "newsplit" in shallow copies of the
# existing per-split distributions (so `atom_dists` / `comp_dists` stay unmodified and
# this cell can be re-run). Then print all pairwise divergences including "newsplit".
newsplit_atomdist = dc.compute_atom_distributions(newsplit)["newsplit"]
newsplit_compdist = dc.compute_compound_distributions(newsplit)["newsplit"]
_atom_dists = {k: v for k, v in atom_dists.items()}
_atom_dists["newsplit"] = newsplit_atomdist
_comp_dists = {k: v for k, v in comp_dists.items()}
_comp_dists["newsplit"] = newsplit_compdist
print(json.dumps(dc._compute_atom_divergences(_atom_dists), indent=3))
print(json.dumps(dc._compute_compound_divergences(_comp_dists), indent=3))

In [17]:
# Inspect one selected example: (input question, output Tree, "newsplit").
newsplit[0]

('Was a cinematographer a cinematographer of M0',
 Tree('@R@', [Tree('@QUERY', [Tree('@SELECT', [Tree('count', [Tree('*', [])])]), Tree('@WHERE', [Tree('@COND', [Tree('?x0', []), Tree('ns:film.cinematographer.film', []), Tree('m0', [])])])])]),
 'newsplit')

In [18]:
# Sanity-check the divergence 1 - Chernoff coefficient on two toy distributions
# (a is peaked on "c"/"d", b is uniform). The third argument is presumably the Chernoff
# alpha (its default is not visible here -- confirm). With 0.1, the result is asymmetric
# in (a, b), as the two last outputs below show.
a = {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.49, "e": 0.01}
b = {"a": 0.2, "b": 0.2, "c": 0.2, "d": 0.2, "e": 0.2}
print(1-dc.compute_chernoff_coeff(a, b))
print(1-dc.compute_chernoff_coeff(a, b, 0.1))
print(1-dc.compute_chernoff_coeff(b, a, 0.1))

0.3260013575831957
0.4138321990765277
0.08228389965015814


In [19]:
# Number of splits in _atom_dists (the 4 loaded splits plus "newsplit"), followed by
# the number of distinct atoms seen in train, test and the new selection.
print(len(_atom_dists))
print(len(_atom_dists["train"]))
print(len(_atom_dists["test"]))
print(len(_atom_dists["newsplit"]))

5
104
100
101


In [20]:
# List the atoms (dict keys) of the train atom distribution.
print(atom_dists["train"].keys())


dict_keys(['(* )', '(count ARG1)', '(@SELECT ARG1)', '(?x0 )', '(a )', '(ns:people.person )', '(@COND ARG1 ARG2 ARG3)', '(ns:influence.influence_node.influenced )', '(m1 )', '(m2 )', '(ns:people.person.spouse_s/ns:people.marriage.spouse )', '(ns:fictional_universe.fictional_character.married_to/ns:fictional_universe.marriage_of_fictional_characters.spouses )', '(@OR ARG*)', '(?x1 )', '(ns:film.cinematographer )', '(!= )', '(filter ARG1 ARG2 ARG3)', '(@WHERE ARG*)', '(@QUERY ARG ARG)', '(@R@ ARG1)', '(ns:influence.influence_node.influenced_by )', '(m0 )', '(ns:people.person.sibling_s/ns:people.sibling_relationship.sibling )', '(ns:fictional_universe.fictional_character.siblings/ns:fictional_universe.sibling_relationship_of_fictional_characters.siblings )', '(ns:people.person.gender )', '(ns:m.05zppz )', '(ns:film.film.directed_by )', '(m3 )', '(m4 )', '(ns:film.film.written_by )', '(distinct )', '(@SELECT ARG1 ARG2)', '(ns:organization.organization.founders )', '(ns:film.actor.film/ns:f