In [1]:
%load_ext autoreload
%autoreload 2

In [56]:
import json, math, random
from typing import List, Tuple
from nltk import Tree
from tqdm import tqdm
from parseq.datasets import CFQDatasetLoader
from parseq.grammar import taglisp_to_tree
from parseq.scripts_resplit.resplit_cfq import DivergenceComputer, FrequencyDistribution

In [3]:
# Load the CFQ "mcd1/modent" split with validfrac=0 and loadunused=True.
# The loader log below shows 'train', 'test', 'oodvalid' and 'unused' being read.
ds = CFQDatasetLoader().load("mcd1/modent", validfrac=0, loadunused=True)

CFQDatasetLoader: make data
CFQDatasetLoader: make data                                                                                  T: 0.0 sec
loading split 'mcd1'
doing 'train'


100%|██████████████████████████████████████████████████████████████████████████| 95743/95743 [00:22<00:00, 4280.39it/s]


doing 'test'


100%|██████████████████████████████████████████████████████████████████████████| 11968/11968 [00:03<00:00, 3655.73it/s]


doing 'oodvalid'


100%|██████████████████████████████████████████████████████████████████████████| 11968/11968 [00:03<00:00, 3641.02it/s]


doing 'unused'


100%|████████████████████████████████████████████████████████████████████████| 119678/119678 [00:32<00:00, 3715.95it/s]


In [6]:
# Keep field 0 of every dataset item as-is and convert field 1 with taglisp_to_tree
# (presumably the question text and its logical form -- TODO confirm against the loader).
allexamples = [(x[0], taglisp_to_tree(x[1])) for x in tqdm(ds)]

100%|████████████████████████████████████████████████████████████████████████| 239357/239357 [00:33<00:00, 7205.73it/s]


In [7]:
# Compute the compound distribution over the full example set.
# `dc` is reused by later cells to compute distributions for subsets.
dc = DivergenceComputer()
compdist = dc.compute_compound_distribution(allexamples)

100%|█████████████████████████████████████████████████████████████████████████| 239357/239357 [04:07<00:00, 968.17it/s]


In [10]:
# Size of the full compound distribution (presumably the number of distinct compounds -- confirm).
len(compdist)

1621251

In [18]:
# Entropy of the full compound distribution, as reported by its entropy() method.
compdist.entropy()

11.978434648858228

In [25]:
# get a set of examples that do not share any compounds
def get_disjoint_examples(xs, dc=None):
    """Greedily split `xs` into examples with pairwise-disjoint compounds and the rest.

    Examples are visited in order. An example is selected when none of the
    compounds extracted from its tree (`example[1]`) has been seen in a
    previously selected example; otherwise it goes to the remainder.
    An example with no compounds is always selected.

    Args:
        xs: iterable of examples; element 1 of each example is passed to
            `dc.extract_compounds`.
        dc: object providing `extract_compounds(tree)`. A fresh
            `DivergenceComputer` is created when omitted.

    Returns:
        A `(selected, remaining)` tuple of lists, each preserving input order.
    """
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    remaining = []
    presentcompounds = set()
    for example in tqdm(xs):
        # Build the set once; it is needed for both the test and the update.
        comps = set(dc.extract_compounds(example[1]))
        if presentcompounds.isdisjoint(comps):
            selected.append(example)
            presentcompounds |= comps
        else:
            remaining.append(example)
    # The call site unpacks two values, so return the remainder as well.
    return selected, remaining

In [26]:
def get_covering_examples(xs, dc=None):
    """Greedily select the examples that contribute at least one unseen compound.

    Examples are visited in order. An example is selected when the compounds
    extracted from its tree (`example[1]`) include at least one that no
    previously selected example contained. Examples with no compounds are
    never selected.

    Args:
        xs: iterable of examples; element 1 of each example is passed to
            `dc.extract_compounds`.
        dc: object providing `extract_compounds(tree)`. A fresh
            `DivergenceComputer` is created when omitted.

    Returns:
        List of selected examples in input order.
    """
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    presentcompounds = set()
    for example in tqdm(xs):
        comps = set(dc.extract_compounds(example[1]))
        # "Not a subset" means at least one compound is new.
        if not comps <= presentcompounds:
            selected.append(example)
            presentcompounds |= comps
    return selected

In [20]:
# Unpacks a 2-tuple from get_disjoint_examples: the selected examples are kept as `dx`,
# the second value is discarded.
dx, _ = get_disjoint_examples(allexamples)

100%|████████████████████████████████████████████████████████████████████████| 239357/239357 [03:50<00:00, 1039.30it/s]


In [22]:
# Compound distribution restricted to the disjoint subset `dx`.
dxdist = dc.compute_compound_distribution(dx)

100%|██████████████████████████████████████████████████████████████████████████████| 396/396 [00:00<00:00, 7856.95it/s]


In [28]:
# Size of the disjoint subset's compound distribution next to the number of examples in it.
len(dxdist), len(dx)

(4052, 396)

In [27]:
# Examples that each add at least one compound not present in earlier selected examples.
cx = get_covering_examples(allexamples)

100%|█████████████████████████████████████████████████████████████████████████| 239357/239357 [04:03<00:00, 984.38it/s]


In [29]:
# Number of covering examples selected.
len(cx)

108241

In [67]:
def compute_overlap(dist, comps):
    """Return the fraction of entries in `comps` that have a positive count in `dist`.

    Duplicates in `comps` are counted individually.

    Args:
        dist: mapping-like object supporting `in` and item lookup, whose values
            are compared against 0.
        comps: iterable of compound keys.

    Returns:
        A float in [0, 1]. An empty `comps` yields 0.0 (nothing can overlap)
        instead of raising ZeroDivisionError.
    """
    overlap = 0
    total = 0
    for comp in comps:
        total += 1
        # Membership is tested first so the lookup never touches a missing key.
        if comp in dist and dist[comp] > 0:
            overlap += 1
    return overlap / total if total else 0.0
        
def get_minimal_covering_examples(xs, dc=None, N=10000, step=2000, cache=None):
    """Select up to `N` examples, preferring ones whose compounds are not yet covered.

    The procedure has three phases:

    1. Unless a prebuilt cache is supplied, extract the compounds of every
       example's tree (`example[1]`) and store them as integer ids in a dict
       keyed by `str(tree)`.
    2. Shuffle a copy of `xs` and greedily collect examples whose compound ids
       are disjoint from everything collected so far (at most `N` of them).
    3. Repeatedly score the remaining examples with `compute_overlap` against
       the distribution of the selected set, and move the lowest-overlap
       examples (at most `step` per round) into the selected set until it
       holds `N` examples or no examples remain.

    Args:
        xs: sequence of examples; element 1 of each example is the tree.
            The sequence itself is not modified.
        dc: object providing `extract_compounds(tree)`. A fresh
            `DivergenceComputer` is created when omitted.
        N: maximum number of examples to select.
        step: maximum number of examples added per scoring round; must be >= 1.
        cache: `None` or `False` builds the cache internally. `True` builds it
            and also returns it. Any other value is used as the prebuilt cache
            (mapping `str(tree)` to a list of compound ids); a tree missing
            from it raises `KeyError`.

    Returns:
        `(dx, remaining, dxdist)`, or `(dx, remaining, dxdist, cache)` when
        `cache is True` was passed. `dx` is the selected examples, `remaining`
        the rest, and `dxdist` a `FrequencyDistribution` of compound ids counted
        over exactly the examples in `dx`.

    Raises:
        ValueError: if `step` is smaller than 1 (the scoring loop could not
            make progress).
    """
    if step < 1:
        raise ValueError("step must be >= 1")

    retcache = cache is True

    dc = DivergenceComputer() if dc is None else dc

    # find all compounds and build cache
    if cache is None or cache is True or cache is False:
        print("building compound cache")
        cache = {}
        compounds = {}
        for example in tqdm(xs):
            tree = example[1]
            cacheline = []
            for comp in dc.extract_compounds(tree):
                # Ids are assigned in first-seen order, starting at 1.
                cacheline.append(compounds.setdefault(comp, len(compounds) + 1))
            # str(tree) is computed once per example; it is the cache key.
            cache[str(tree)] = cacheline

        print("built cache")

    print("finding disjoint examples")
    # Shuffle a private copy so the caller's list keeps its order.
    shuffled = list(xs)
    random.shuffle(shuffled)
    dx = []
    remaining = []
    presentcompounds = set()
    dxdist = FrequencyDistribution()

    for example in tqdm(shuffled):
        comps = cache[str(example[1])]
        compset = set(comps)
        # Stop seeding once N examples are held so dxdist never needs correcting.
        if len(dx) < N and presentcompounds.isdisjoint(compset):
            dx.append(example)
            presentcompounds |= compset
            for comp in comps:
                dxdist[comp] += 1
        else:
            remaining.append(example)

    print(f"found {len(dx)} disjoint examples")

    print("iterating")
    # Terminates when the quota is met or when there is nothing left to add.
    while len(dx) < N and remaining:
        print("iter")
        overlaps = [compute_overlap(dxdist, cache[str(x[1])]) for x in tqdm(remaining)]

        # Stable sort: small overlap first, ties keep their current order.
        ranked = [x for x, _ in sorted(zip(remaining, overlaps), key=lambda pair: pair[1])]
        # Never take more than the quota allows, so nothing has to be undone.
        take = min(step, N - len(dx))
        selected, remaining = ranked[:take], ranked[take:]
        for x in selected:
            for comp in cache[str(x[1])]:
                dxdist[comp] += 1
        dx = dx + selected

    if retcache is True:
        return dx, remaining, dxdist, cache
    else:
        return dx, remaining, dxdist

In [68]:
# First run: build the compound cache and get it back as a fourth return value.
# mcx, _, mcxdist, cache = get_minimal_covering_examples(allexamples, N=2000, step=500, cache=True)
# Later runs: pass the `cache` name back in. NOTE(review): no other assignment to `cache`
# is visible in this notebook, so on a fresh kernel this line depends on the call above
# having been executed first.
mcx, _, mcxdist = get_minimal_covering_examples(allexamples, N=2000, step=500, cache=cache)

building compound cache


100%|████████████████████████████████████████████████████████████████████████| 239357/239357 [5:38:57<00:00, 11.77it/s]


built cache
finding disjoint examples


100%|████████████████████████████████████████████████████████████████████████| 239357/239357 [00:26<00:00, 9076.99it/s]


found 387 disjoint examples
iterating
iter


100%|████████████████████████████████████████████████████████████████████████| 238970/238970 [00:28<00:00, 8366.37it/s]


iter


100%|████████████████████████████████████████████████████████████████████████| 238470/238470 [00:28<00:00, 8316.05it/s]


iter


100%|████████████████████████████████████████████████████████████████████████| 237970/237970 [00:28<00:00, 8282.90it/s]


iter


100%|████████████████████████████████████████████████████████████████████████| 237470/237470 [00:28<00:00, 8237.01it/s]


In [63]:
# Number of selected examples.
# NOTE(review): the saved output (1000) comes from an earlier execution than the N=2000 call; re-run to refresh.
len(mcx)

1000

In [64]:
# Entropy and size of the compound distribution returned for the selected examples.
mcxdist.entropy(), len(mcxdist)

(10.924063688115353, 116058)

In [58]:
# Random baseline: take the first 2000 examples after shuffling.
# NOTE: random.shuffle reorders `allexamples` in place, so every later cell sees the new order;
# no seed is set, so the sample differs between runs.
random.shuffle(allexamples)
mcx_random = allexamples[:2000]

In [59]:
# Compound distribution of the random baseline sample.
# NOTE(review): the saved progress bar shows 1000 items, so it predates the 2000-example slice.
mcx_random_dist = dc.compute_compound_distribution(mcx_random)

100%|████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 1027.52it/s]


In [62]:
# Entropy and size of the random baseline's compound distribution, for comparison with mcxdist.
mcx_random_dist.entropy(), len(mcx_random_dist)

(10.475702730395478, 51134)