In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json, math, random
from typing import List, Tuple
from nltk import Tree
from tqdm import tqdm
import numpy as np
from parseq.datasets import CFQDatasetLoader
from parseq.grammar import taglisp_to_tree, tree_size
from parseq.scripts_resplit.resplit_cfq import DivergenceComputer, FrequencyDistribution

In [3]:
# load cfq, including unused examples
ds = CFQDatasetLoader().load("mcd1/modent", validfrac=0, loadunused=True, keepids=True)

CFQDatasetLoader: make data
CFQDatasetLoader: make data                                                                                  T: 0.0 sec
loading split 'mcd1'
doing 'train'


100%|██████████████████████████████████████████████████████████████████████████| 95743/95743 [00:22<00:00, 4296.82it/s]


doing 'test'


100%|██████████████████████████████████████████████████████████████████████████| 11968/11968 [00:03<00:00, 3657.90it/s]


doing 'oodvalid'


100%|██████████████████████████████████████████████████████████████████████████| 11968/11968 [00:03<00:00, 3633.56it/s]


doing 'unused'


100%|████████████████████████████████████████████████████████████████████████| 119678/119678 [00:32<00:00, 3716.37it/s]


In [4]:
# load all examples from all splits into a single list while removing previous split information
allexamples = [(x[0], x[1], taglisp_to_tree(x[2])) for x in tqdm(ds)]

100%|████████████████████████████████████████████████████████████████████████| 239357/239357 [00:31<00:00, 7607.74it/s]


In [5]:
# compute all present compounds
def print_stats(xs):
    """ Input is list of examples of the form (id, nl, fltree) """
    dc = DivergenceComputer()
    atomdist = dc.compute_atom_distribution(xs)
    compdist = dc.compute_compound_distribution(xs)
    sizedist = dc.compute_size_distribution(xs)
    print(f"Number of examples: {len(xs)}")
    print(f"Atom dist entropy: {atomdist.entropy():.3f}, coverage: {len(atomdist)}")
    print(f"Compound dist entropy: {compdist.entropy():.3f}, coverage: {len(compdist)}")
    print(f"Average size: {sizedist.average():.3f}")
    return atomdist, compdist, sizedist

In [17]:
allatomdist, allcompdist, allsizedist = print_stats(allexamples)

100%|████████████████████████████████████████████████████████████████████████| 239357/239357 [00:28<00:00, 8350.88it/s]
100%|█████████████████████████████████████████████████████████████████████████| 239357/239357 [04:12<00:00, 949.13it/s]
100%|███████████████████████████████████████████████████████████████████████| 239357/239357 [00:02<00:00, 87849.78it/s]


Number of examples: 239357
Atom dist entropy: 3.465, coverage: 104
Compound dist entropy: 11.777, coverage: 1624133
Average size: 31.919


In [6]:
# get a set of examples that do not share any compounds
def get_disjoint_examples(xs, dc=None):
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    remaining = []
    presentcompounds = set()
    coveredfc = FrequencyDistribution()
    for example in tqdm(xs):
        add = True
        tree = example[1]
        comps = dc.extract_compounds(tree)
        if len(set(comps) & presentcompounds) == 0:
            selected.append(example)
            presentcompounds |= set(comps)
            for comp in comps:
                coveredfc[comp] += 1
        else:
            remaining.append(example)
    return selected

In [7]:
def get_covering_examples(xs, dc=None):
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    presentcompounds = set()
    coveredfc = FrequencyDistribution()
    for example in tqdm(xs):
        add = True
        tree = example[1]
        comps = dc.extract_compounds(tree)
        if len(set(comps) - presentcompounds) > 0:
            selected.append(example)
            presentcompounds |= set(comps)
            for comp in comps:
                coveredfc[comp] += 1
    return selected

In [8]:
if False:
    dx = get_disjoint_examples(allexamples)
    cx = get_covering_examples(allexamples)

In [9]:
def smoothendist(dist, perc=0.05):
    fd = FrequencyDistribution()
    sortedkeys = sorted(allsizedist.keys())
    for i, k in enumerate(sortedkeys):
        j = 1
        acc = dist(k)
        div = 1
        while True:
            if i-j >= 0:
                acc += dist(sortedkeys[i-j])
                div += 1
            if i+j < len(sortedkeys):
                acc += dist(sortedkeys[i+j])
                div += 1
            j += 1
            if acc >= perc:
                break
        fd[k] = acc/div
    return fd

In [14]:
def compute_overlap(dist, comps):  # how many of the compounds have already been observed?
    overlap = 0
    unoverlap = 0
    total = 0
    for comp in comps:
        if comp in dist and dist[comp] > 0:
            overlap += 1
        else:
            unoverlap += 1
        total += 1
    return overlap / total, unoverlap / total


def get_minimal_covering_examples(xs, dc=None, N=10000, step=2000, cache=None, targetsizedist=None):
    retcache = cache is True
    
    dc = DivergenceComputer() if dc is None else dc
    
    # find all compounds and build cache
    if cache is None or cache is True or cache is False:
        print("building compound cache")
        cache = {}
        compounds = {}
        lastid = 0
        for example in tqdm(xs):
            tree = example[2]
            comps = dc.extract_compounds(tree)
            cache[str(tree)] = []
            cacheline = []
            for comp in comps:
                if comp not in compounds:
                    compounds[comp] = lastid + 1
                    lastid += 1
                cacheline.append(compounds[comp])
            cache[str(tree)] = cacheline
    #             cache[str(tree)].append(compounds[comp])
            
        print("built cache")
            
    print("randomly selecting initial examples")
    random.shuffle(xs)
    NUMINIT = 200
    selected = xs[:NUMINIT]
    remaining = xs[NUMINIT:]
    dxdist = FrequencyDistribution()
    sizedist = FrequencyDistribution()
    
    for example in selected:
        tree = example[2]
        comps = cache[str(tree)]
        for comp in comps:
            dxdist[comp] += 1
        sizedist[tree_size(tree)] += 1

    dx, remaining = selected, remaining
    # go over all examples and compute their overlaps with dxdist
    print(f"randomly chosen {len(dx)} examples")
    
#     print("finding disjoint examples")
#     random.shuffle(xs)
#     selected = []
#     remaining = []
#     presentcompounds = set()
#     dxdist = FrequencyDistribution()
#     sizedist = FrequencyDistribution()
    
#     for example in tqdm(xs):
#         add = True
#         tree = example[2]
#         comps = cache[str(tree)]
#         if len(set(comps) & presentcompounds) == 0:
#             selected.append(example)
#             presentcompounds |= set(comps)
#             for comp in comps:
#                 dxdist[comp] += 1
#             sizedist[tree_size(tree)] += 1
#         else:
#             remaining.append(example)

#     dx, remaining = selected, remaining
#     # go over all examples and compute their overlaps with dxdist
#     print(f"found {len(dx)} disjoint examples")
    
    if targetsizedist is not None:
        targetsizedist = smoothendist(targetsizedist)
    
    print("iterating")
    iternr = 1
    while True:
        print(f"iter {iternr}")
        iternr += 1
        overlaps = []
        unoverlaps = []
        tsizes = []
        for x in tqdm(remaining):
            comps = cache[str(x[2])]
            tsizes.append(tree_size(x[2]))
            overlap, unoverlap = compute_overlap(dxdist, comps)
            overlaps.append(overlap)
            unoverlaps.append(unoverlap)
        
        # sort by overlap and select N best that overlap least
        tosort = zip(remaining, overlaps, unoverlaps, tsizes)
        remainingsorted = sorted(tosort, key=lambda x: -x[2])  # small overlap first
        print(f"Top overlap: {remainingsorted[0][1]}, bottom overlap: {remainingsorted[-1][1]}")
        numberwithlargestoverlap = 0
        numberwithhalfoverlap = 0
        for (i, x) in enumerate(remainingsorted[::-1]):
            overlap = x[1]
            if overlap >= 0.8:
                numberwithlargestoverlap += 1
            if overlap >= 0.5:
                numberwithhalfoverlap += 1
        print(f"Number of high overlap: {numberwithlargestoverlap} and half overlap: {numberwithhalfoverlap}")
        #remainingsorted = [x[0] for x in remainingsorted]
        _step = min(step, len(remainingsorted))
        if targetsizedist is None:
            selected, remaining = remainingsorted[:_step], remainingsorted[_step:]
        else:
            selected, remaining = [], []
            smoothsizedist = smoothendist(sizedist)
            for i, x in enumerate(remainingsorted):
                xsize = x[-1]
                if smoothsizedist[xsize] < targetsizedist[xsize]:
                    selected.append(x)
                    sizedist[xsize] += 1
                    smoothsizedist = smoothendist(sizedist)
                else:
                    remaining.append(x)
                if len(selected) >= step:
                    remaining += remainingsorted[i+1:]
                    break
            assert(len(selected) + len(remaining) == len(remainingsorted))
            
        for x, xoverlap, xunoverlap, xsize in selected:
            for comp in cache[str(x[2])]:
                dxdist[comp] += 1
        dx = dx + [x[0] for x in selected]
        remaining = [x[0] for x in remaining]
        
        if len(dx) >= N:
            dx = dx[:N]
            remaining = remaining + dx[N:]
            for x in dx[N:]:
                for comp in cache[str(x[2])]:
                    dxdist[comp] -= 1
                if dxdist[comp] <= 0:
                    del dxdist[comp]
            break
        
        print(f"Number of selected examples: {len(dx)}, number of covered compounds: {len(dxdist)}")
    
    if retcache is True:
        return dx, remaining, dxdist, cache
    else:
        return dx, remaining, dxdist

In [15]:
NN = 50000

In [16]:
if "cache" not in locals() or cache is None:
    mcx, _, mcxdist, cache = get_minimal_covering_examples(allexamples, N=NN, step=NN//10, cache=True)
else:
    mcx, _, mcxdist = get_minimal_covering_examples(allexamples, N=NN, step=NN//10, cache=cache)  # , targetsizedist=allsizedist)

building compound cache


100%|█████████████████████████████████████████████████████████████████████████| 239357/239357 [05:02<00:00, 792.40it/s]


built cache
randomly selecting initial examples
randomly chosen 200 examples
iterating
iter 1


100%|████████████████████████████████████████████████████████████████████████| 239157/239157 [00:33<00:00, 7122.49it/s]


Top overlap: 0.0008628127696289905, bottom overlap: 1.0
Number of high overlap: 11674 and half overlap: 64337
Number of selected examples: 5200, number of covered compounds: 190894
iter 2


100%|████████████████████████████████████████████████████████████████████████| 234157/234157 [00:32<00:00, 7197.66it/s]


Top overlap: 0.03691275167785235, bottom overlap: 1.0
Number of high overlap: 25647 and half overlap: 103267
Number of selected examples: 10200, number of covered compounds: 545602
iter 3


100%|████████████████████████████████████████████████████████████████████████| 229157/229157 [00:32<00:00, 7141.28it/s]


Top overlap: 0.17391304347826086, bottom overlap: 1.0
Number of high overlap: 88172 and half overlap: 210547
Number of selected examples: 15200, number of covered compounds: 756309
iter 4


100%|████████████████████████████████████████████████████████████████████████| 224157/224157 [00:30<00:00, 7245.30it/s]


Top overlap: 0.41304347826086957, bottom overlap: 1.0
Number of high overlap: 130271 and half overlap: 223167
Number of selected examples: 20200, number of covered compounds: 937802
iter 5


100%|████████████████████████████████████████████████████████████████████████| 219157/219157 [00:30<00:00, 7215.34it/s]


Top overlap: 0.5658914728682171, bottom overlap: 1.0
Number of high overlap: 171764 and half overlap: 219157
Number of selected examples: 25200, number of covered compounds: 1061439
iter 6


100%|████████████████████████████████████████████████████████████████████████| 214157/214157 [00:29<00:00, 7176.02it/s]


Top overlap: 0.6666666666666666, bottom overlap: 1.0
Number of high overlap: 191298 and half overlap: 214157
Number of selected examples: 30200, number of covered compounds: 1168540
iter 7


100%|████████████████████████████████████████████████████████████████████████| 209157/209157 [00:29<00:00, 7182.50it/s]


Top overlap: 0.7317073170731707, bottom overlap: 1.0
Number of high overlap: 201130 and half overlap: 209157
Number of selected examples: 35200, number of covered compounds: 1266091
iter 8


100%|████████████████████████████████████████████████████████████████████████| 204157/204157 [00:28<00:00, 7203.76it/s]


Top overlap: 0.7857142857142857, bottom overlap: 1.0
Number of high overlap: 202975 and half overlap: 204157
Number of selected examples: 40200, number of covered compounds: 1333300
iter 9


100%|████████████████████████████████████████████████████████████████████████| 199157/199157 [00:27<00:00, 7216.44it/s]


Top overlap: 0.8253968253968254, bottom overlap: 1.0
Number of high overlap: 199157 and half overlap: 199157
Number of selected examples: 45200, number of covered compounds: 1388417
iter 10


100%|████████████████████████████████████████████████████████████████████████| 194157/194157 [00:26<00:00, 7209.80it/s]


Top overlap: 0.8536585365853658, bottom overlap: 1.0
Number of high overlap: 194157 and half overlap: 194157


In [17]:
mcxatom, mcxcomp, mcxsize = print_stats(mcx)
print(randsize.compute_chernoff_coeff(smoothendist(mcxsize), smoothendist(allsizedist)))

100%|██████████████████████████████████████████████████████████████████████████| 50000/50000 [00:06<00:00, 7914.78it/s]
100%|███████████████████████████████████████████████████████████████████████████| 50000/50000 [01:24<00:00, 593.57it/s]
100%|█████████████████████████████████████████████████████████████████████████| 50000/50000 [00:00<00:00, 73859.92it/s]


Number of examples: 50000
Atom dist entropy: 3.579, coverage: 104
Compound dist entropy: 12.700, coverage: 1434356
Average size: 36.565


NameError: name 'randsize' is not defined

In [None]:
mcxatom, mcxcomp, mcxsize = print_stats(mcx)
print(randsize.compute_chernoff_coeff(smoothendist(mcxsize), smoothendist(allsizedist)))

In [None]:
random.shuffle(allexamples)
mcx_random = allexamples[:NN]

In [None]:
randatom, randcomp, randsize = print_stats(mcx_random)
print(randsize.compute_chernoff_coeff(smoothendist(randsize), smoothendist(allsizedist)))

In [None]:
for k in sorted(allsizedist.keys()):
    print(f"{k} {allsizedist(k):.5f} - {randsize(k):.5f} - {mcxsize(k):.5f}")