In [1]:
# reload edited project modules (e.g. parseq) automatically before each cell is executed
%load_ext autoreload
%autoreload 2

In [20]:
import json, math, random
from typing import List, Tuple
from nltk import Tree
from tqdm import tqdm
import numpy as np
from parseq.datasets import CFQDatasetLoader
from parseq.grammar import taglisp_to_tree, tree_size, tree_to_taglisp
from parseq.scripts_resplit.resplit_cfq import DivergenceComputer, FrequencyDistribution

In [3]:
# load CFQ split "mcd1/modent" together with the examples not used by that split, keeping example ids
ds = CFQDatasetLoader().load("mcd1/modent", keepids=True, loadunused=True, validfrac=0)

CFQDatasetLoader: make data
CFQDatasetLoader: make data in 0.0 sec
loading split 'mcd1'
doing 'train'


100%|██████████| 95743/95743 [00:25<00:00, 3683.50it/s]


doing 'test'


100%|██████████| 11968/11968 [00:03<00:00, 3136.37it/s]


doing 'oodvalid'


100%|██████████| 11968/11968 [00:03<00:00, 3118.81it/s]


doing 'unused'


100%|██████████| 119678/119678 [00:37<00:00, 3188.97it/s]


In [54]:
# pool the examples of every split into one list (split membership is dropped) and
# parse each taglisp logical form into a Tree
allexamples = list((ex[0], ex[1], taglisp_to_tree(ex[2])) for ex in tqdm(ds))

100%|██████████| 239357/239357 [00:38<00:00, 6184.66it/s]


In [5]:
# summarise the atom, compound and size distributions of a list of examples
def print_stats(xs):
    """Print summary statistics of `xs` and return its three distributions.

    Args:
        xs: list of examples of the form (id, nl, fltree).
    Returns:
        (atomdist, compdist, sizedist) as computed by a fresh DivergenceComputer.
    """
    computer = DivergenceComputer()
    distributions = (
        computer.compute_atom_distribution(xs),
        computer.compute_compound_distribution(xs),
        computer.compute_size_distribution(xs),
    )
    atoms, compounds, sizes = distributions
    print(f"Number of examples: {len(xs)}")
    for label, d in (("Atom", atoms), ("Compound", compounds)):
        print(f"{label} dist entropy: {d.entropy():.3f}, coverage: {len(d)}")
    print(f"Average size: {sizes.average():.3f}")
    return distributions

In [6]:
# statistics of the full pool; allsizedist is also the default key support of smoothendist
[allatomdist, allcompdist, allsizedist] = print_stats(allexamples)

100%|██████████| 239357/239357 [00:32<00:00, 7304.65it/s]
100%|██████████| 239357/239357 [04:25<00:00, 901.78it/s] 
100%|██████████| 239357/239357 [00:02<00:00, 88565.12it/s]


Number of examples: 239357
Atom dist entropy: 3.465, coverage: 104
Compound dist entropy: 11.777, coverage: 1624133
Average size: 31.919


In [7]:
# get a set of examples that do not share any compounds
def get_disjoint_examples(xs, dc=None):
    """Greedily select examples whose compound sets are pairwise disjoint.

    Examples are visited in the order of `xs`. An example is kept when none of its
    compounds occurs in an already kept example.

    Args:
        xs: list of examples of the form (id, nl, fltree).
        dc: DivergenceComputer used to extract compounds (a new one if None).
    Returns:
        The list of kept examples, in input order.
    """
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    presentcompounds = set()
    for example in tqdm(xs):
        # the logical-form tree is example[2]; example[1] is the NL string
        comps = set(dc.extract_compounds(example[2]))
        if comps.isdisjoint(presentcompounds):
            selected.append(example)
            presentcompounds |= comps
    return selected

In [8]:
def get_covering_examples(xs, dc=None):
    """Greedily select examples that each add at least one not-yet-covered compound.

    Examples are visited in the order of `xs`. An example is kept when at least one of
    its compounds has not occurred in an already kept example. Examples without any
    compounds are never kept.

    Args:
        xs: list of examples of the form (id, nl, fltree).
        dc: DivergenceComputer used to extract compounds (a new one if None).
    Returns:
        The list of kept examples, in input order.
    """
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    presentcompounds = set()
    for example in tqdm(xs):
        # the logical-form tree is example[2]; example[1] is the NL string
        comps = set(dc.extract_compounds(example[2]))
        if not comps <= presentcompounds:
            selected.append(example)
            presentcompounds |= comps
    return selected

In [9]:
if False:  # disabled: each call extracts the compounds of every example in the pool
    cx = get_covering_examples(allexamples)
    dx = get_disjoint_examples(allexamples)

In [10]:
def smoothendist(dist, perc=0.05, keys=None):
    """Smooth a distribution over sorted keys with an adaptive, symmetric window.

    For each key, neighbouring keys (in sorted order) are added one on each side per
    step until the accumulated value of `dist` over the window reaches `perc`. The
    smoothed value is the mean over the window. At least one widening step is always
    taken. Smoothed values are window means and are not renormalised, so they need not
    sum to 1.

    Args:
        dist: FrequencyDistribution that is called as `dist(k)` for each key.
            Presumably this returns the normalised frequency of k (TODO confirm).
        perc: accumulated-value threshold that stops the window from growing.
        keys: keys to smooth over. Defaults to the keys of the notebook-global
            `allsizedist`, so that distributions smoothed with the default share one
            support.
    Returns:
        A new FrequencyDistribution mapping each key to its smoothed value.
    """
    if keys is None:
        keys = allsizedist.keys()
    sortedkeys = sorted(keys)
    n = len(sortedkeys)
    fd = FrequencyDistribution()
    for i, k in enumerate(sortedkeys):
        acc = dist(k)
        div = 1
        j = 1
        while True:
            # widen the window by one key on each side where such a key exists
            if i - j >= 0:
                acc += dist(sortedkeys[i - j])
                div += 1
            if i + j < n:
                acc += dist(sortedkeys[i + j])
                div += 1
            j += 1
            # Stop once enough mass is collected, or when the window already spans all
            # keys. Further steps would add nothing there, so a dist whose total over
            # `keys` stays below `perc` (e.g. an empty one) would otherwise loop forever.
            if acc >= perc or (i - j < 0 and i + j >= n):
                break
        fd[k] = acc / div
    return fd

In [26]:
def compute_overlap(dist, comps):  # how many of the compounds have already been observed?
    """Fraction of `comps` already observed in `dist`, and the complementary fraction.

    Args:
        dist: mapping from compound to count. A compound counts as observed when it is
            a key of `dist` with a positive count.
        comps: iterable of compounds. Duplicates are counted individually.
    Returns:
        (overlap, unoverlap): the fractions of `comps` that are / are not observed; they
        sum to 1. For an empty `comps` this is (1.0, 0.0): such an example contributes
        no new compound (previously this raised ZeroDivisionError).
    """
    comps = list(comps)
    total = len(comps)
    if total == 0:
        return 1.0, 0.0
    overlap = sum(1 for comp in comps if comp in dist and dist[comp] > 0)
    return overlap / total, (total - overlap) / total


def _build_compound_cache(xs, dc):
    """Map str(tree) of every example in `xs` to the list of integer ids of its compounds.

    Ids are assigned in first-seen order starting at 1; repeated compounds within a tree
    are kept as repeated ids.
    """
    print("building compound cache")
    cache = {}
    compound_ids = {}
    for example in tqdm(xs):
        tree = example[2]
        cache[str(tree)] = [compound_ids.setdefault(comp, len(compound_ids) + 1)
                            for comp in dc.extract_compounds(tree)]
    print("built cache")
    return cache


def _select_by_size(scored, sizedist, targetsizedist, step):
    """Greedily take up to `step` entries of `scored` whose tree size is under-represented.

    Each entry is (example, overlap, unoverlap, size). An entry is taken when the smoothed
    current size distribution is below the (already smoothed) target at its size.
    `sizedist` is updated in place for every taken entry.
    Returns (selected, rest), both lists of entries in the order of `scored`.
    """
    selected, rest = [], []
    smoothsizedist = smoothendist(sizedist)
    for i, entry in enumerate(scored):
        xsize = entry[-1]
        if smoothsizedist[xsize] < targetsizedist[xsize]:
            selected.append(entry)
            sizedist[xsize] += 1
            smoothsizedist = smoothendist(sizedist)
        else:
            rest.append(entry)
        if len(selected) >= step:
            rest.extend(scored[i + 1:])
            break
    assert len(selected) + len(rest) == len(scored)
    return selected, rest


def get_minimal_covering_examples(xs, dc=None, N=10000, step=2000, cache=None, targetsizedist=None,
                                  numinit=100):
    """Select up to N examples that together cover as many distinct compounds as possible.

    The selection starts from `numinit` random examples. It then repeatedly scores every
    remaining example by the fraction of its compounds not yet covered by the selection.
    The `step` best-scoring examples are added, optionally subject to a tree-size target.

    Args:
        xs: list of examples (id, nl, fltree). It is not modified; a shuffled copy is used.
        dc: DivergenceComputer used to extract compounds (a new one if None).
        N: number of examples to return.
        step: number of examples added per iteration.
        cache: dict str(tree) -> list of compound ids, as returned by an earlier call with
            cache=True. If None/True/False, the cache is built here; with True it is also
            returned.
        targetsizedist: optional size distribution that the tree sizes of the selection
            should follow. It is smoothed with `smoothendist` before use.
        numinit: number of randomly chosen seed examples.

    Returns:
        (dx, remaining, dxdist), where dx holds the selected examples, remaining the
        unselected examples, and dxdist the FrequencyDistribution of compound ids in dx.
        `cache` is appended when called with cache=True. dx can be shorter than N if the
        pool runs out or no further example can be selected.
    """
    retcache = cache is True
    dc = DivergenceComputer() if dc is None else dc

    if cache is None or cache is True or cache is False:
        cache = _build_compound_cache(xs, dc)

    print("randomly selecting initial examples")
    xs = list(xs)  # shuffle a copy so the caller's list (e.g. allexamples) is not reordered
    random.shuffle(xs)
    dx, remaining = xs[:numinit], xs[numinit:]
    dxdist = FrequencyDistribution()    # compound id -> count over the selected examples
    sizedist = FrequencyDistribution()  # tree size -> count over the selected examples
    for example in dx:
        tree = example[2]
        for comp in cache[str(tree)]:
            dxdist[comp] += 1
        sizedist[tree_size(tree)] += 1
    print(f"randomly chosen {len(dx)} examples")

    if targetsizedist is not None:
        targetsizedist = smoothendist(targetsizedist)

    print("iterating")
    iternr = 1
    # stop when enough examples are selected or the pool is exhausted
    while len(dx) < N and remaining:
        print(f"iter {iternr}")
        iternr += 1
        scored = []
        for x in tqdm(remaining):
            overlap, unoverlap = compute_overlap(dxdist, cache[str(x[2])])
            scored.append((x, overlap, unoverlap, tree_size(x[2])))
        # most uncovered compounds (= smallest overlap) first; the sort is stable
        scored.sort(key=lambda entry: -entry[2])
        print(f"Top overlap: {scored[0][1]}, bottom overlap: {scored[-1][1]}")
        numberwithlargestoverlap = sum(1 for entry in scored if entry[1] >= 0.8)
        numberwithhalfoverlap = sum(1 for entry in scored if entry[1] >= 0.5)
        print(f"Number of high overlap: {numberwithlargestoverlap} and half overlap: {numberwithhalfoverlap}")

        if targetsizedist is None:
            selected, rest = scored[:step], scored[step:]
        else:
            selected, rest = _select_by_size(scored, sizedist, targetsizedist, step)

        for x, _, _, _ in selected:
            for comp in cache[str(x[2])]:
                dxdist[comp] += 1
        dx = dx + [entry[0] for entry in selected]
        remaining = [entry[0] for entry in rest]
        print(f"Number of selected examples: {len(dx)}, number of covered compounds: {len(dxdist)}")
        if not selected:
            # nothing qualified (size target met everywhere, or step <= 0): stop instead of spinning
            break

    if len(dx) > N:
        # The last batch may overshoot N. Move the overflow back to `remaining` and
        # un-count its compounds, deleting ids whose count drops to zero so that
        # len(dxdist) stays the number of covered compounds.
        overflow = dx[N:]
        dx = dx[:N]
        remaining = remaining + overflow
        for x in overflow:
            for comp in cache[str(x[2])]:
                dxdist[comp] -= 1
                if dxdist[comp] <= 0:
                    del dxdist[comp]

    if retcache:
        return dx, remaining, dxdist, cache
    return dx, remaining, dxdist

In [55]:
# number of examples to select (used for both the compound-covering and the random subset)
# NOTE(review): the saved outputs below were produced with other values (10000 and 40000); re-run to refresh
NN = 4_000

In [52]:
if locals().get("cache") is None:
    # first run: build the compound cache (cache=True makes the function return it) and keep it
    mcx, _, mcxdist, cache = get_minimal_covering_examples(allexamples, N=NN, step=NN // 10, cache=True)
else:
    # later runs: reuse the cache; its keys are str(tree), so the order of allexamples does not matter
    # optional: add targetsizedist=allsizedist to also follow the overall size distribution
    mcx, _, mcxdist = get_minimal_covering_examples(allexamples, N=NN, step=NN // 10, cache=cache)

randomly selecting initial examples
randomly chosen 100 examples
iterating
iter 1


100%|██████████| 239257/239257 [00:48<00:00, 4927.65it/s]


Top overlap: 0.0010131712259371835, bottom overlap: 1.0
Number of high overlap: 7456 and half overlap: 41346
Number of selected examples: 1100, number of covered compounds: 101698
iter 2


100%|██████████| 238257/238257 [00:37<00:00, 6330.98it/s] 


Top overlap: 0.010660980810234541, bottom overlap: 1.0
Number of high overlap: 14655 and half overlap: 62085
Number of selected examples: 2100, number of covered compounds: 167098
iter 3


100%|██████████| 237257/237257 [00:37<00:00, 6255.18it/s] 


Top overlap: 0.06493506493506493, bottom overlap: 1.0
Number of high overlap: 22564 and half overlap: 103734
Number of selected examples: 3100, number of covered compounds: 285487
iter 4


100%|██████████| 236257/236257 [00:38<00:00, 6188.77it/s] 


Top overlap: 0.13432835820895522, bottom overlap: 1.0
Number of high overlap: 32472 and half overlap: 156214
Number of selected examples: 4100, number of covered compounds: 371364
iter 5


100%|██████████| 235257/235257 [00:37<00:00, 6224.09it/s] 


Top overlap: 0.22857142857142856, bottom overlap: 1.0
Number of high overlap: 43840 and half overlap: 196657
Number of selected examples: 5100, number of covered compounds: 456761
iter 6


100%|██████████| 234257/234257 [00:37<00:00, 6173.72it/s] 


Top overlap: 0.3170731707317073, bottom overlap: 1.0
Number of high overlap: 60846 and half overlap: 217592
Number of selected examples: 6100, number of covered compounds: 516708
iter 7


100%|██████████| 233257/233257 [00:38<00:00, 6114.98it/s] 


Top overlap: 0.3798449612403101, bottom overlap: 1.0
Number of high overlap: 76781 and half overlap: 225845
Number of selected examples: 7100, number of covered compounds: 577326
iter 8


100%|██████████| 232257/232257 [00:37<00:00, 6112.28it/s] 


Top overlap: 0.42857142857142855, bottom overlap: 1.0
Number of high overlap: 89325 and half overlap: 229466
Number of selected examples: 8100, number of covered compounds: 627312
iter 9


100%|██████████| 231257/231257 [00:38<00:00, 6054.43it/s] 


Top overlap: 0.4675324675324675, bottom overlap: 1.0
Number of high overlap: 103201 and half overlap: 230515
Number of selected examples: 9100, number of covered compounds: 678778
iter 10


100%|██████████| 230257/230257 [00:38<00:00, 6002.72it/s] 


Top overlap: 0.5038759689922481, bottom overlap: 1.0
Number of high overlap: 116612 and half overlap: 230257


In [53]:
[mcxatom, mcxcomp, mcxsize] = print_stats(mcx)
# compare the smoothed size distribution of the selection with that of the full pool
print(mcxsize.compute_chernoff_coeff(*map(smoothendist, (mcxsize, allsizedist))))

100%|██████████| 10000/10000 [00:01<00:00, 6311.68it/s]
100%|██████████| 10000/10000 [00:24<00:00, 414.97it/s]
100%|██████████| 10000/10000 [00:00<00:00, 65344.97it/s]


Number of examples: 10000
Atom dist entropy: 3.612, coverage: 104
Compound dist entropy: 12.561, coverage: 720561
Average size: 41.561
0.999000515552042


In [None]:
print(len(mcx), mcx[0], sep="\n")
# persist (id, nl, taglisp string) triples of the selected examples
with open(f"minicfq{len(mcx)}unsplit1.json", "w") as f:
    f.write(json.dumps([(x[0], x[1], tree_to_taglisp(x[2])) for x in mcx]))

In [49]:
# uniform random baseline of NN examples; random.sample leaves allexamples unmodified
# (shuffling it in place silently changed state that other cells read)
mcx_random = random.sample(allexamples, min(NN, len(allexamples)))

40000
(113576, "Were M1 , M2 , M3 , M4 , and M5 written by and executive produced by M0 's writer and star", Tree('@R@', [Tree('@QUERY', [Tree('@SELECT', [Tree('count', [Tree('*', [])])]), Tree('@WHERE', [Tree('@COND', [Tree('?x0', []), Tree('ns:film.actor.film/ns:film.performance.film', []), Tree('m0', [])]), Tree('@COND', [Tree('?x0', []), Tree('ns:film.writer.film', []), Tree('m0', [])]), Tree('@COND', [Tree('m1', []), Tree('ns:film.film.executive_produced_by', []), Tree('?x0', [])]), Tree('@COND', [Tree('m1', []), Tree('ns:film.film.written_by', []), Tree('?x0', [])]), Tree('@COND', [Tree('m2', []), Tree('ns:film.film.executive_produced_by', []), Tree('?x0', [])]), Tree('@COND', [Tree('m2', []), Tree('ns:film.film.written_by', []), Tree('?x0', [])]), Tree('@COND', [Tree('m3', []), Tree('ns:film.film.executive_produced_by', []), Tree('?x0', [])]), Tree('@COND', [Tree('m3', []), Tree('ns:film.film.written_by', []), Tree('?x0', [])]), Tree('@COND', [Tree('m4', []), Tree('ns:film.film.

In [47]:
# uniform random baseline of NN examples; random.sample leaves allexamples unmodified
# (shuffling it in place silently changed state that other cells read)
mcx_random = random.sample(allexamples, min(NN, len(allexamples)))

In [48]:
[randatom, randcomp, randsize] = print_stats(mcx_random)
# NOTE(review): the value saved for this cell is > 1 (1.11). If compute_chernoff_coeff is meant to be a
# similarity bounded by 1, the smoothed (not renormalised) inputs may be the cause — TODO check
print(randsize.compute_chernoff_coeff(*map(smoothendist, (randsize, allsizedist))))

100%|██████████| 40000/40000 [00:05<00:00, 7134.88it/s]
100%|██████████| 40000/40000 [00:45<00:00, 877.52it/s] 
100%|██████████| 40000/40000 [00:00<00:00, 82286.03it/s]


Number of examples: 40000
Atom dist entropy: 3.464, coverage: 104
Compound dist entropy: 11.617, coverage: 646092
Average size: 31.891
1.109906653330946


In [None]:
# per tree size: full pool - random sample - covering selection
for k in sorted(allsizedist.keys()):
    print(f"{k} " + " - ".join(f"{d(k):.5f}" for d in (allsizedist, randsize, mcxsize)))