In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell runs and edits to them are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [3]:
from parseq.datasets import CFQDatasetLoader
from parseq.scripts_resplit.resplit_cfq import FrequencyDistribution, DivergenceComputer
import random, json
from typing import List
from parseq.grammar import taglisp_to_tree
from nltk import Tree
from tqdm import tqdm
from copy import copy

In [4]:
# Load the CFQ dataset with the "mcd1/modent" split specifier.
# NOTE(review): keepids=True presumably keeps the example id as the first field
# of every example (later cells read example[0] as an id) — confirm in the loader.
ds = CFQDatasetLoader().load(split="mcd1/modent", keepids=True)

CFQDatasetLoader: make data
CFQDatasetLoader: make data in 0.0 sec
loading split 'mcd1'
splitting off a random 10% of 'train' for 'iidvalid' using seed 42
doing 'train'
doing 'test'
doing 'iidvalid'
doing 'oodvalid'


100%|██████████| 86169/86169 [00:23<00:00, 3691.63it/s]
100%|██████████| 11968/11968 [00:03<00:00, 3150.39it/s]
100%|██████████| 9574/9574 [00:02<00:00, 3714.33it/s]
100%|██████████| 11968/11968 [00:03<00:00, 3137.24it/s]


In [5]:
# Work on a shuffled *copy* of the examples: random.shuffle() reorders its
# argument in place, so shuffling ds.examples directly would also reorder the
# dataset object. Seeding makes the order reproducible on a fresh kernel.
random.seed(42)
xs = list(ds.examples)
random.shuffle(xs)

In [6]:
def compute_approx_chernoff_change_twoway(comps, dist1:FrequencyDistribution, dist2:FrequencyDistribution, alpha=0.5):
    """Approximate the change in the Chernoff coefficient sum_c p1(c)^alpha * p2(c)^(1-alpha)
    when one occurrence of every compound in ``comps`` is added to one of two distributions.

    ``dist[comp]`` is read as the raw count and ``dist(comp)`` as the current
    relative frequency (presumably count / total — confirm against
    FrequencyDistribution). Only the terms of the compounds in ``comps`` are
    updated; the renormalisation of all other compounds is ignored, which is
    what makes this an approximation.

    :param comps:  iterable of compounds of the candidate example
    :param dist1:  distribution that enters the coefficient with exponent ``alpha``
    :param dist2:  distribution that enters the coefficient with exponent ``1 - alpha``
    :param alpha:  Chernoff exponent
    :return:       (change if added to dist1, change if added to dist2)
    """
    beta = 1 - alpha
    cchange1 = 0
    cchange2 = 0
    for comp in comps:
        p1, p2 = dist1(comp), dist2(comp)
        # Frequency after one more occurrence: both the count and the total grow
        # by one. Dividing by the unchanged total (as before) could exceed 1 and
        # disagreed with compute_approx_chernoff_change.
        new_p1 = (dist1[comp] + 1) / (dist1.total + 1)
        new_p2 = (dist2[comp] + 1) / (dist2.total + 1)
        cchange1 += (new_p1 ** alpha - p1 ** alpha) * (p2 ** beta)
        cchange2 += (p1 ** alpha) * (new_p2 ** beta - p2 ** beta)
    return cchange1, cchange2

def compute_approx_chernoff_change(comps, traindist:FrequencyDistribution, validdist:FrequencyDistribution, testdist:FrequencyDistribution, alpha=0.1):
    """Approximate how the summed pairwise Chernoff coefficients change when the
    compounds in ``comps`` are added to the train, valid or test distribution.

    Pairs are weighted as (train^alpha, valid^(1-alpha)), (train^alpha, test^(1-alpha))
    and (valid^alpha, test^(1-alpha)). ``dist[comp]`` is read as a count and
    ``dist(comp)`` as the current relative frequency; only the terms of the
    listed compounds are updated.

    :return: (change if assigned to train, change if assigned to valid, change if assigned to test)
    """
    beta = 1 - alpha
    delta_train = 0
    delta_valid = 0
    delta_test = 0
    for c in comps:
        # current relative frequencies of this compound
        p_train = traindist(c)
        p_valid = validdist(c)
        p_test = testdist(c)
        # relative frequencies after one extra occurrence
        q_train = (traindist[c] + 1) / (traindist.total + 1)
        q_valid = (validdist[c] + 1) / (validdist.total + 1)
        q_test = (testdist[c] + 1) / (testdist.total + 1)

        # train plays the "alpha" role against both test and valid
        train_shift = q_train ** alpha - p_train ** alpha
        delta_train += train_shift * (p_test ** beta)
        delta_train += train_shift * (p_valid ** beta)

        # valid plays the "alpha" role against test, the "1-alpha" role against train
        delta_valid += (q_valid ** alpha - p_valid ** alpha) * (p_test ** beta)
        delta_valid += (p_train ** alpha) * (q_valid ** beta - p_valid ** beta)

        # test always plays the "1-alpha" role
        test_shift = q_test ** beta - p_test ** beta
        delta_test += (p_train ** alpha) * test_shift
        delta_test += (p_valid ** alpha) * test_shift
    return delta_train, delta_valid, delta_test

def compute_true_chernoff_chagne(comps, traindist:FrequencyDistribution, validdist:FrequencyDistribution, testdist:FrequencyDistribution, alpha=0.1):
    """Exact counterpart of compute_approx_chernoff_change.

    For each of the three distributions, a copy is made, one occurrence of every
    compound in ``comps`` is added to the copy, and the Chernoff coefficients
    against the other two (unchanged) distributions are recomputed. The input
    distributions are not modified.

    (The misspelled function name is kept because callers use it.)

    :param comps:  iterable of compounds of the candidate example
    :param alpha:  exponent forwarded to FrequencyDistribution.compute_chernoff_coeff
    :return:       (change if assigned to train, change if assigned to valid, change if assigned to test)
    """
    comps = list(comps)     # iterated once per distribution below
    coeff = FrequencyDistribution.compute_chernoff_coeff

    def _with_comps_added(dist):
        # Copy counts and total by hand, then add the example's compounds to the copy only.
        grown = FrequencyDistribution()
        grown._counts = copy(dist._counts)
        grown.total = dist.total
        for comp in comps:
            grown[comp] += 1
        return grown

    _traindist = _with_comps_added(traindist)
    _validdist = _with_comps_added(validdist)
    _testdist = _with_comps_added(testdist)

    # Baseline coefficients of the three pairs; each is needed for two of the results.
    train_valid = coeff(traindist, validdist, alpha=alpha)
    train_test = coeff(traindist, testdist, alpha=alpha)
    valid_test = coeff(validdist, testdist, alpha=alpha)

    trainchange = coeff(_traindist, validdist, alpha=alpha) - train_valid \
                  + coeff(_traindist, testdist, alpha=alpha) - train_test

    validchange = coeff(traindist, _validdist, alpha=alpha) - train_valid \
                  + coeff(_validdist, testdist, alpha=alpha) - valid_test

    testchange = coeff(traindist, _testdist, alpha=alpha) - train_test \
                  + coeff(validdist, _testdist, alpha=alpha) - valid_test

    return trainchange, validchange, testchange

In [7]:
# Sanity check on toy distributions: compare the approximate and the exact
# change in Chernoff coefficient for single-compound "examples".
# Every token occurrence in the strings below contributes 100 counts.
traindist = FrequencyDistribution()
for comp in "a a a a a a a a a b c i i".split():
    traindist[comp] += 100

validdist = FrequencyDistribution()
for comp in "b c d d d e e h h f".split():
    validdist[comp] += 100

testdist = FrequencyDistribution()
for comp in "a e e f f f f g g ".split():
    testdist[comp] += 100

# For each probe compound, print it followed by the approximate and the exact
# (train, valid, test) changes.
for symbol in ("a", "i", "f", "h", "g"):
    comps = symbol.split()
    print(comps)
    print(compute_approx_chernoff_change(comps, traindist, validdist, testdist))
    print(compute_true_chernoff_chagne(comps, traindist, validdist, testdist))

['a']
(4.557063588286274e-06, 0.07128616642596544, 0.001065680795276904)
(-1.0422826546924169e-05, 0.07105075216737256, 0.00046356693626414813)
['i']
(0.0, 0.0016531718235173105, 0.0018174243917896157)
(-2.5238408138950064e-05, 0.001417757564924238, 0.0010820340615026458)
['f']
(0.29674958590738487, 0.00034284154775237734, 0.00047800600616257615)
(0.29672434749924587, 0.0001456918711334465, 0.00012506822682079033)
['h']
(0.11468305843310449, 0.0, 0.001865738131744269)
(0.11465782002496555, -0.00023541425859308074, 0.0011303478014572654)
['g']
(0.12609010545505056, 0.12943905388013063, 0.0)
(0.1260648670469116, 0.12920363962153758, -0.0007353903302870224)


In [8]:
def make_mcd_splits(xs:List, sizes=(0.6, 0.2, 0.2)):
    """Greedily partition examples into three subsets (train, valid, test) whose
    compound distributions diverge from each other.

    Each subset is seeded with one random example. Every remaining example is
    then assigned to the subset for which compute_approx_chernoff_change reports
    the smallest increase in Chernoff coefficient (i.e. the least added
    similarity between subsets), skipping subsets that already hold their share
    of the data. After all examples are placed, the compound divergences
    between the subsets are printed.

    :param xs:     list of examples; x[2] is a tagged-lisp string or an nltk Tree
    :param sizes:  target fractions of the data for (train, valid, test)
    :return:       list of three lists of examples (with x[2] converted to a Tree)
    :raises ValueError: if fewer than three examples are given
    """
    dc = DivergenceComputer()

    print("initializing randomly by selecting 1 random example for each subset")
    xs = [(x[0], x[1], taglisp_to_tree(x[2]) if not isinstance(x[2], Tree) else x[2]) for x in xs]
    if len(xs) < 3:
        raise ValueError("make_mcd_splits needs at least 3 examples to seed the three subsets")
    random.shuffle(xs)      # one shuffle suffices: popping from a shuffled list is already random

    # Capacities are fractions of the *whole* dataset, so remember its size before popping.
    n = len(xs)
    subsets = [[xs.pop()], [xs.pop()], [xs.pop()]]

    trainstats = FrequencyDistribution()
    validstats = FrequencyDistribution()
    teststats = FrequencyDistribution()
    statses = [trainstats, validstats, teststats]

    # compound counts of the seed examples
    for subset, stats in zip(subsets, statses):
        for example in subset:
            for comp in dc.extract_compounds(example[2]):
                stats[comp] += 1

    while xs:
        example = xs.pop()
        comps = list(dc.extract_compounds(example[2]))     # used twice below
        changes = compute_approx_chernoff_change(comps, trainstats, validstats, teststats)
        # Rank subsets by the change they would cause (smallest first), not by their index.
        ranked = sorted(range(3), key=lambda i: changes[i])
        # Take the best-ranked subset that still has room. If none has (possible
        # only when sizes sum to less than 1), fall back to the best-ranked one so
        # that every example is placed and the loop always terminates.
        target = next((i for i in ranked if len(subsets[i]) < sizes[i] * n), ranked[0])
        subsets[target].append(example)
        for comp in comps:
            statses[target][comp] += 1

    # updating compound distribution and computing divergences
    print("updating compound distributions and computing divergences")
    cds = {}
    for i, subset in enumerate(subsets):
        cds[str(i)] = dc.compute_compound_distribution(subset)
    divs = dc._compute_compound_divergences(cds)
    print(json.dumps(divs, indent=3))

    return subsets

In [None]:
# Build the splits on a 1000-example sample and report only the subset sizes;
# printing the subsets themselves would dump every parsed example into the output.
dss = make_mcd_splits(xs[:1000])
print([len(subset) for subset in dss])

In [77]:
def compute_example_overlap(comps, dist):
    """Measure how many of an example's compounds were already observed in ``dist``.

    A compound counts as observed when it is contained in ``dist`` with a count
    greater than zero.

    :param comps:  iterable of compounds (repeated compounds are counted repeatedly)
    :param dist:   mapping-like object supporting ``in`` and ``[...]`` (count lookup)
    :return:       (fraction observed, fraction unobserved); (0.0, 0.0) if ``comps`` is empty
    """
    comps = list(comps)
    total = len(comps)
    if total == 0:
        # no compounds at all: avoid dividing by zero
        return 0.0, 0.0
    overlap = sum(1 for comp in comps if comp in dist and dist[comp] > 0)
    return overlap / total, (total - overlap) / total


def compute_example_stats(subsets, cds, dc=None):
    """Score every example against the compound distributions of the three subsets.

    For each example, the Chernoff coefficient between each subset distribution
    (keys "0", "1", "2" of ``cds``, converted with ``tofreqs()``) and the
    example's own compound distribution is computed for alpha=0.01 and alpha=0.9.

    :param subsets:  iterable of lists of examples; example[0] is used as the id,
                     example[2] is passed to ``dc.extract_compound_dist``
    :param cds:      dict mapping "0"/"1"/"2" to the subsets' compound distributions
    :param dc:       DivergenceComputer to use; a fresh one is created when omitted
    :return:         dict: example id -> 6-tuple
                     (subset 0, 1, 2 at alpha=0.01, then subset 0, 1, 2 at alpha=0.9)
    """
    if dc is None:
        # the previous default of None crashed on the first dc.extract_compound_dist call
        dc = DivergenceComputer()
    print("Computing stats for all examples")
    freqs = {k: v.tofreqs() for k, v in cds.items()}
    stats = {}
    for subset in subsets:
        for example in tqdm(subset):
            ecomp = dc.extract_compound_dist(example[2]).tofreqs()
            # alpha varies slowest so the tuple layout stays (0, 1, 2 @ 0.01, 0, 1, 2 @ 0.9)
            stats[example[0]] = tuple(
                FrequencyDistribution.compute_chernoff_coeff(freqs[key], ecomp, alpha=alpha)
                for alpha in (0.01, 0.9)
                for key in ("0", "1", "2")
            )
    return stats


def _split_top_scored(examples, stats, column, k):
    """Order ``examples`` by descending ``stats[example_id][column]`` and return (first k, rest)."""
    ranked = sorted(examples, key=lambda ex: stats[ex[0]][column], reverse=True)
    return ranked[:k], ranked[k:]


def make_mcd_splits2(xs:List, sizes=(0.6, 0.2, 0.2), swapsize=10, iters=10):
    """Start from a random train / ood2 / test-ood partition and iteratively swap examples.

    In every iteration each example is scored with compute_example_stats and up
    to ``swapsize`` top-scoring examples are moved along each of the six
    directions between the three subsets; the compound distributions and their
    divergences are then recomputed and printed.

    :param xs:        list of examples; x[2] is a tagged-lisp string or an nltk Tree
    :param sizes:     fractions for (train, ood2, test-ood); the last subset takes
                      whatever remains after the first two
    :param swapsize:  maximum number of examples moved per direction and iteration
    :param iters:     number of swap iterations
    :return:          dict mapping "0"/"1"/"2" to the final compound distributions
                      (NOTE: the distributions, not the example subsets themselves)
    """
    assert len(sizes) in (3,)
    # make random assignments and compute stats
    print("making random assigments")
    xs = [(x[0], x[1], taglisp_to_tree(x[2]) if not isinstance(x[2], Tree) else x[2]) for x in xs]
    random.shuffle(xs)
    subsets = []
    prev = 0
    for i, s in enumerate(sizes):
        if i < len(sizes) - 1:
            c = int(round(s * len(xs))) + prev
            subsets.append(xs[prev:c])
            prev = c
        else:
            # The last subset takes the remainder so that rounding never drops examples.
            # (The old condition `i < len(sizes)` was always true, so this branch never ran.)
            subsets.append(xs[prev:])

    print("Computing initial compound distributions and their divergences")
    cds = {}
    dc = DivergenceComputer()
    for i, subset in enumerate(subsets):
        cds[str(i)] = dc.compute_compound_distribution(subset)

    divs = dc._compute_compound_divergences(cds)
    print("Initial compound divergences")
    print(json.dumps(divs, indent=3))

    print("Iterating")
    # in every iteration, we find for each split the examples that would improve the divergence most and perform one swap
    # remove from train set all compounds that are present in test-ood and ood2
    # remove from ood2 all compounds that are present in test-ood
    # ==> test-ood would contain entirely novel compounds that have not been trained on and have not been validated with
    #
    # Columns of a stats tuple: 0-2 = coefficient against train / ood2 / test-ood at alpha=0.01,
    #                           3-5 = the same three at alpha=0.9.
    for _ in range(iters):
        stats = compute_example_stats(subsets, cds, dc=dc)

        # train --> test-ood, then train --> ood2 from what is left
        traintotest, traintotrain = _split_top_scored(subsets[0], stats, 2, swapsize)
        traintoood2, traintotrain = _split_top_scored(traintotrain, stats, 1, swapsize)

        # ood2 --> test-ood, then ood2 --> train from what is left
        ood2totest, ood2toood2 = _split_top_scored(subsets[1], stats, 2, swapsize)
        ood2totrain, ood2toood2 = _split_top_scored(ood2toood2, stats, 3, swapsize)

        # test-ood --> train, then test-ood --> ood2 from what is left
        testtotrain, testtotest = _split_top_scored(subsets[2], stats, 3, swapsize)
        testtoood2, testtotest = _split_top_scored(testtotest, stats, 4, swapsize)

        subsets = [traintotrain + ood2totrain + testtotrain,
                   ood2toood2 + traintoood2 + testtoood2,
                   testtotest + traintotest + ood2totest]

        # updating compound distribution and computing divergences
        print("updating compound distributions and computing divergences")
        cds = {}
        for i, subset in enumerate(subsets):
            cds[str(i)] = dc.compute_compound_distribution(subset)
        divs = dc._compute_compound_divergences(cds)
        print(json.dumps(divs, indent=3))

    return cds

In [85]:
# Run the greedy splitter on the first 1000 shuffled examples.
# NOTE(review): the TypeError recorded under this cell presumably stems from an
# earlier revision of make_mcd_splits (this cell's execution count is far ahead
# of the definition cell's) — re-run on a fresh kernel to confirm.
dss = make_mcd_splits(xs[:1000])

initializing randomly by selecting 1 random example for each subset


TypeError: list indices must be integers or slices, not tuple

In [86]:
### len(dss[0]), len(dss[1]), len(dss[2])

In [9]:
# NOTE(review): make_mcd_splits returns a list of example lists, so dss[0] is a
# plain list and this attribute access raises the AttributeError shown below;
# `_counts` is an attribute of FrequencyDistribution objects, not of the subsets.
dss[0]._counts

AttributeError: 'list' object has no attribute '_counts'