In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json, math, random
from typing import List, Tuple
from nltk import Tree
from tqdm import tqdm
import numpy as np
from parseq.datasets import CFQDatasetLoader
from parseq.grammar import taglisp_to_tree, tree_size
from parseq.scripts_resplit.resplit_cfq import DivergenceComputer, FrequencyDistribution

In [None]:
# Load the CFQ dataset ("mcd1/modent" configuration), including unused examples (loadunused=True).
# keepids=True is passed as well; presumably this keeps the example ids in each item (the next
# cell reads three fields per item) — confirm against CFQDatasetLoader.
ds = CFQDatasetLoader().load("mcd1/modent", validfrac=0, loadunused=True, keepids=True)

In [None]:
# Load all examples from all splits into a single list while removing previous split information.
# Every entry is a triple (x[0], x[1], tree) where the tree is parsed from x[2] by taglisp_to_tree;
# print_stats documents this layout as (id, nl, fltree).
allexamples = [(x[0], x[1], taglisp_to_tree(x[2])) for x in tqdm(ds)]

In [None]:
# summary statistics (atoms, compounds, sizes) for a list of examples
def print_stats(xs):
    """Print summary statistics of ``xs`` and return the underlying distributions.

    Input is a list of examples of the form (id, nl, fltree).

    A fresh ``DivergenceComputer`` computes the atom, compound and size distributions of ``xs``.
    The number of examples, the entropy and number of entries of the atom and compound
    distributions, and the average size are printed.

    Returns the tuple ``(atom distribution, compound distribution, size distribution)``.
    """
    computer = DivergenceComputer()
    atoms = computer.compute_atom_distribution(xs)
    compounds = computer.compute_compound_distribution(xs)
    sizes = computer.compute_size_distribution(xs)

    n_examples = len(xs)
    print(f"Number of examples: {n_examples}")

    atom_entropy, atom_coverage = atoms.entropy(), len(atoms)
    print(f"Atom dist entropy: {atom_entropy:.3f}, coverage: {atom_coverage}")

    compound_entropy, compound_coverage = compounds.entropy(), len(compounds)
    print(f"Compound dist entropy: {compound_entropy:.3f}, coverage: {compound_coverage}")

    average_size = sizes.average()
    print(f"Average size: {average_size:.3f}")

    return atoms, compounds, sizes

In [None]:
# Statistics over the full dataset. The returned distributions serve as the reference later in the
# notebook (allsizedist is read by smoothendist and by the size comparisons at the end).
allatomdist, allcompdist, allsizedist = print_stats(allexamples)

In [None]:
# get a set of examples that do not share any compounds
def get_disjoint_examples(xs, dc=None):
    """Greedily select examples whose compounds are pairwise disjoint.

    Examples are visited in the given order; an example is kept only if none of its compounds
    occurs in an example that was kept before it.

    Args:
        xs: list of examples of the form (id, nl, fltree); the tree is read from index 2.
        dc: object providing ``extract_compounds(tree)``; a new ``DivergenceComputer`` is
            created when omitted.

    Returns:
        The list of selected examples, in their original order.
    """
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    presentcompounds = set()
    for example in tqdm(xs):
        # The tree is the third field; index 1 holds the natural-language question.
        comps = set(dc.extract_compounds(example[2]))
        if presentcompounds.isdisjoint(comps):
            selected.append(example)
            presentcompounds |= comps
    return selected

In [None]:
def get_covering_examples(xs, dc=None):
    """Greedily select examples until every compound is covered at least once.

    Examples are visited in the given order; an example is kept only if it contributes at least
    one compound that no previously kept example contains.

    Args:
        xs: list of examples of the form (id, nl, fltree); the tree is read from index 2.
        dc: object providing ``extract_compounds(tree)``; a new ``DivergenceComputer`` is
            created when omitted.

    Returns:
        The list of selected examples, in their original order.
    """
    dc = DivergenceComputer() if dc is None else dc
    selected = []
    presentcompounds = set()
    for example in tqdm(xs):
        # The tree is the third field; index 1 holds the natural-language question.
        comps = set(dc.extract_compounds(example[2]))
        if not comps <= presentcompounds:  # at least one compound not seen so far
            selected.append(example)
            presentcompounds |= comps
    return selected

In [None]:
# Disabled cell: the `if False` guard means neither selection below is ever executed.
if False:
    dx = get_disjoint_examples(allexamples)
    cx = get_covering_examples(allexamples)

In [None]:
def smoothendist(dist, perc=0.05, keys=None):
    """Smooth ``dist`` by averaging every key with a growing window of neighbouring keys.

    For each key (in sorted order) the window starts with the key itself and its direct
    neighbours and is widened symmetrically until the accumulated value of ``dist(...)`` over
    the window reaches ``perc``. The stored value is the accumulated value divided by the number
    of keys in the window.

    Args:
        dist: callable distribution; ``dist(key)`` must return a number.
        perc: accumulated value at which the window stops growing.
        keys: keys to produce smoothed values for. When omitted, the keys of the notebook-level
            ``allsizedist`` are used (the original behaviour), so that global must exist.

    Returns:
        A new ``FrequencyDistribution`` holding one smoothed value per key.
    """
    fd = FrequencyDistribution()
    sortedkeys = sorted(allsizedist.keys() if keys is None else keys)
    numkeys = len(sortedkeys)
    for i, k in enumerate(sortedkeys):
        acc = dist(k)
        div = 1
        j = 1
        while True:
            left, right = i - j, i + j
            if left >= 0:
                acc += dist(sortedkeys[left])
                div += 1
            if right < numkeys:
                acc += dist(sortedkeys[right])
                div += 1
            j += 1
            # Stop once enough mass is collected. Also stop when the window already spans all
            # keys: acc cannot grow any further, so waiting for `perc` would loop forever.
            if acc >= perc or (left < 0 and right >= numkeys):
                break
        fd[k] = acc/div
    return fd

In [None]:
# Smooth the size distribution of the full dataset and print, for every size in ascending order,
# the value obtained by calling the smoothed distribution on that size.
smoothedallsize = smoothendist(allsizedist)
for size in sorted(smoothedallsize.keys()):
    print(size, smoothedallsize(size))

In [None]:
# Same listing for the unsmoothed size distribution of the full dataset, for comparison.
for size in sorted(allsizedist.keys()):
    print(size, allsizedist(size))

In [None]:
def compute_overlap(dist, comps):  # how many of the compounds have already been observed?
    """Return the fractions of ``comps`` that are / are not already present in ``dist``.

    A compound counts as observed when it is a key of ``dist`` with a value greater than zero.
    Repeated compounds in ``comps`` are counted once per occurrence.

    Args:
        dist: mapping from compound to count (supports ``in`` and ``[]``).
        comps: iterable of compounds of one example.

    Returns:
        ``(overlap, unoverlap)``: the observed and unobserved fractions, which sum to 1.
        For empty ``comps`` both fractions are ``0.0`` (instead of raising ZeroDivisionError).
    """
    total = 0
    overlap = 0
    for comp in comps:
        total += 1
        if comp in dist and dist[comp] > 0:
            overlap += 1
    if total == 0:
        # Nothing to compare: neither fraction is defined, report zero for both.
        return 0.0, 0.0
    return overlap / total, (total - overlap) / total


def _build_compound_cache(xs, dc):
    """Map ``str(tree)`` of every example in ``xs`` to the list of integer ids of its compounds."""
    compoundids = {}
    cache = {}
    for example in tqdm(xs):
        tree = example[2]
        cacheline = []
        for comp in dc.extract_compounds(tree):
            if comp not in compoundids:
                # ids start at 1 and are handed out in order of first appearance
                compoundids[comp] = len(compoundids) + 1
            cacheline.append(compoundids[comp])
        cache[str(tree)] = cacheline
    return cache


def get_minimal_covering_examples(xs, dc=None, N=10000, step=2000, cache=None, targetsizedist=None):
    """Select up to ``N`` examples, preferring examples that bring unseen compounds.

    The procedure has two phases:

    1. ``xs`` is shuffled (on a copy) and a greedy pass keeps every example that shares no
       compound with the examples kept so far.
    2. The remaining examples are repeatedly scored by the fraction of their compounds that is
       not yet covered; per iteration at most ``step`` of the best-scoring ones are added, until
       ``N`` examples are selected or no further example can be added.

    Args:
        xs: list of examples of the form (id, nl, fltree); the tree is read from index 2.
            The list itself is not modified.
        dc: object providing ``extract_compounds(tree)``; a new ``DivergenceComputer`` is
            created when omitted.
        N: maximum number of examples to select.
        step: maximum number of examples added per iteration of phase 2.
        cache: ``None``/``False`` builds the compound cache internally; ``True`` builds it and
            additionally returns it; any other value is used as the cache, i.e. a mapping from
            ``str(tree)`` to the list of compound ids of that tree.
        targetsizedist: optional size distribution. When given it is smoothed with
            ``smoothendist`` and a candidate is only accepted while the smoothed size
            distribution of the selection is below the target at the candidate's size.

    Returns:
        ``(dx, remaining, dxdist)`` or, when ``cache is True``, ``(dx, remaining, dxdist, cache)``:
        the selected examples, the examples that were not selected, and the frequency
        distribution of compound ids over the selected examples.
    """
    retcache = cache is True

    dc = DivergenceComputer() if dc is None else dc

    # find all compounds and build cache
    if cache is None or cache is True or cache is False:
        print("building compound cache")
        cache = _build_compound_cache(xs, dc)
        print("built cache")

    print("finding disjoint examples")
    # Shuffle a private copy of the argument (not some notebook-level list), so that the
    # function works for any input and does not reorder the caller's data.
    xs = list(xs)
    random.shuffle(xs)

    dx = []
    remaining = []
    presentcompounds = set()
    dxdist = FrequencyDistribution()
    sizedist = FrequencyDistribution()

    for example in tqdm(xs):
        tree = example[2]
        comps = cache[str(tree)]
        if presentcompounds.isdisjoint(comps):
            dx.append(example)
            presentcompounds.update(comps)
            for comp in comps:
                dxdist[comp] += 1
            sizedist[tree_size(tree)] += 1
        else:
            remaining.append(example)

    print(f"found {len(dx)} disjoint examples")

    if targetsizedist is not None:
        targetsizedist = smoothendist(targetsizedist)

    print("iterating")
    iternr = 1
    while True:
        if len(remaining) == 0:
            # Nothing left to draw from; indexing the sorted candidates below would fail.
            print("no remaining examples, stopping early")
            break
        print(f"iter {iternr}")
        iternr += 1

        # score every remaining example against the compounds covered so far
        scored = []
        for x in tqdm(remaining):
            overlap, unoverlap = compute_overlap(dxdist, cache[str(x[2])])
            scored.append((x, overlap, unoverlap, tree_size(x[2])))

        # largest uncovered fraction (= smallest overlap) first; sorted() is stable
        remainingsorted = sorted(scored, key=lambda x: -x[2])
        print(f"Top overlap: {remainingsorted[0][1]}, bottom overlap: {remainingsorted[-1][1]}")
        numberwithlargestoverlap = sum(1 for x in remainingsorted if x[1] >= 0.8)
        numberwithhalfoverlap = sum(1 for x in remainingsorted if x[1] >= 0.5)
        print(f"Number of high overlap: {numberwithlargestoverlap} and half overlap: {numberwithhalfoverlap}")

        _step = min(step, len(remainingsorted))
        if targetsizedist is None:
            selected, remaining = remainingsorted[:_step], remainingsorted[_step:]
        else:
            selected, remaining = [], []
            smoothsizedist = smoothendist(sizedist)
            for i, x in enumerate(remainingsorted):
                xsize = x[-1]
                if smoothsizedist[xsize] < targetsizedist[xsize]:
                    selected.append(x)
                    sizedist[xsize] += 1
                    smoothsizedist = smoothendist(sizedist)
                else:
                    remaining.append(x)
                if len(selected) >= step:
                    remaining += remainingsorted[i+1:]
                    break
            assert(len(selected) + len(remaining) == len(remainingsorted))

        # strip the scores again, keep only the examples
        remaining = [x[0] for x in remaining]

        if len(selected) == 0:
            # No candidate could be added; another iteration would give the same result forever.
            print("no example could be selected, stopping early")
            break

        for x, _, _, _ in selected:
            for comp in cache[str(x[2])]:
                dxdist[comp] += 1
        dx = dx + [x[0] for x in selected]

        if len(dx) >= N:
            break

        print(f"Number of selected examples: {len(dx)}, number of covered compounds: {len(dxdist)}")

    # Trim to N. The overflow must be taken *before* truncating dx, otherwise dx[N:] is empty and
    # the surplus examples are neither returned in `remaining` nor removed from dxdist.
    overflow = dx[N:]
    if len(overflow) > 0:
        dx = dx[:N]
        remaining = remaining + overflow
        for x in overflow:
            comps = cache[str(x[2])]
            for comp in comps:
                dxdist[comp] -= 1
            # drop every compound of this example that is no longer covered
            for comp in set(comps):
                if comp in dxdist and dxdist[comp] <= 0:
                    del dxdist[comp]

    if retcache is True:
        return dx, remaining, dxdist, cache
    else:
        return dx, remaining, dxdist

In [None]:
# Number of examples to select: passed as N to get_minimal_covering_examples (with step NN//10)
# and used as the size of the random baseline subset.
NN = 5000

In [None]:
# Select NN examples. The compound cache is built on the first run (cache=True makes the function
# return it as a fourth value) and reused on later runs of this cell, as long as the `cache`
# variable is still present in the namespace and not None.
if "cache" not in locals() or cache is None:
    mcx, _, mcxdist, cache = get_minimal_covering_examples(allexamples, N=NN, step=NN//10, cache=True)
else:
    mcx, _, mcxdist = get_minimal_covering_examples(allexamples, N=NN, step=NN//10, cache=cache)  # , targetsizedist=allsizedist)

In [None]:
# Statistics of the selected subset, followed by the Chernoff coefficient between its smoothed
# size distribution and the smoothed size distribution of the full dataset.
mcxatom, mcxcomp, mcxsize = print_stats(mcx)
# Call the method through mcxsize: `randsize` is only created further down, so referencing it here
# raised a NameError on a fresh kernel. Both distributions are passed explicitly, so the result is
# assumed not to depend on the receiver — confirm against FrequencyDistribution.
print(mcxsize.compute_chernoff_coeff(smoothendist(mcxsize), smoothendist(allsizedist)))

In [None]:
# NOTE(review): this cell repeats the statistics cell directly above it; consider removing one.
mcxatom, mcxcomp, mcxsize = print_stats(mcx)
# Call the method through mcxsize: `randsize` is only created further down, so referencing it here
# raised a NameError on a fresh kernel. Both distributions are passed explicitly, so the result is
# assumed not to depend on the receiver — confirm against FrequencyDistribution.
print(mcxsize.compute_chernoff_coeff(smoothendist(mcxsize), smoothendist(allsizedist)))

In [None]:
# Random baseline: shuffle the full example list in place (this reorders the shared `allexamples`
# list) and take its first NN entries.
# NOTE(review): no random seed is set in the visible cells, so this subset differs between runs.
random.shuffle(allexamples)
mcx_random = allexamples[:NN]

In [None]:
# Statistics of the random baseline, followed by the Chernoff coefficient between its smoothed
# size distribution and the smoothed size distribution of the full dataset.
randatom, randcomp, randsize = print_stats(mcx_random)
print(randsize.compute_chernoff_coeff(smoothendist(randsize), smoothendist(allsizedist)))

In [None]:
# Per size (ascending): value of the full-dataset, random-baseline and selected-subset size
# distributions, printed side by side in that order.
for k in sorted(allsizedist.keys()):
    print(f"{k} {allsizedist(k):.5f} - {randsize(k):.5f} - {mcxsize(k):.5f}")