In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
from typing import List, Tuple
from nltk import Tree
from tqdm import tqdm
from parseq.datasets import CFQDatasetLoader
from parseq.grammar import taglisp_to_tree
from parseq.scripts_resplit.resplit_cfq import DivergenceComputer, FrequencyDistribution

In [72]:
# Load the CFQ "mcd2/modent" split with CFQDatasetLoader (validfrac=0, loadunused=True, keepids=True).
# From how `ds` is indexed further down, each item is (ex[0], ex[1], ex[2], ex[3]) where ex[2] is a
# taglisp output string and ex[3] the split name ("train" / "test" / "oodvalid" / "unused").
# ex[0] is presumably the example id kept by keepids=True — TODO confirm against CFQDatasetLoader.
ds = CFQDatasetLoader().load("mcd2/modent", validfrac=0, loadunused=True, keepids=True)

CFQDatasetLoader: make data
CFQDatasetLoader: make data in 0.0 sec
loading split 'mcd2'
doing 'train'


100%|██████████| 95743/95743 [00:28<00:00, 3403.79it/s]


doing 'test'


100%|██████████| 11968/11968 [00:03<00:00, 3144.63it/s]


doing 'oodvalid'


100%|██████████| 11968/11968 [00:03<00:00, 3105.45it/s]


doing 'unused'


100%|██████████| 119678/119678 [00:36<00:00, 3264.78it/s]


In [73]:
# Divergence helper used throughout this notebook (atom/compound extraction, distributions, Chernoff coefficients).
dc = DivergenceComputer()

In [74]:
# Per-split atom distributions over the whole dataset, then the divergence between every pair of splits.
# NOTE(review): _compute_atom_divergences is a private DivergenceComputer method.
atom_dists = dc.compute_atom_distributions(ds)
print(json.dumps(dc._compute_atom_divergences(atom_dists), indent=3))

100%|██████████| 239357/239357 [01:03<00:00, 3747.19it/s]


{
   "train-train": -4.440892098500626e-16,
   "train-test": 0.03132730435620745,
   "train-oodvalid": 0.032651238914194525,
   "train-unused": 0.013494024085208722,
   "test-train": 0.03132730435620745,
   "test-test": 1.1102230246251565e-16,
   "test-oodvalid": 8.113859177327765e-05,
   "test-unused": 0.009930059865293162,
   "oodvalid-train": 0.032651238914194525,
   "oodvalid-test": 8.113859177327765e-05,
   "oodvalid-oodvalid": 6.661338147750939e-16,
   "oodvalid-unused": 0.010216209565831913,
   "unused-train": 0.013494024085208722,
   "unused-test": 0.009930059865293162,
   "unused-oodvalid": 0.010216209565831913,
   "unused-unused": 0.0
}


In [75]:
# Per-split compound distributions, then pairwise compound divergences. Unlike the atom case these are
# not symmetric (e.g. train-test and test-train differ in the recorded output).
# NOTE(review): _compute_compound_divergences is a private DivergenceComputer method.
comp_dists = dc.compute_compound_distributions(ds)
print(json.dumps(dc._compute_compound_divergences(comp_dists), indent=3))

100%|██████████| 239357/239357 [05:09<00:00, 772.66it/s] 


{
   "train-train": -1.683875261448975e-11,
   "train-test": 0.758667704027568,
   "train-oodvalid": 0.7638965664114622,
   "train-unused": 0.43624495310010625,
   "test-train": 0.8416811283686769,
   "test-test": 1.9746426715983034e-12,
   "test-oodvalid": 0.036111881760164444,
   "test-unused": 0.539074523247199,
   "oodvalid-train": 0.8559611826719928,
   "oodvalid-test": 0.036735899422104645,
   "oodvalid-oodvalid": -2.3783197633520103e-12,
   "oodvalid-unused": 0.5455467425533544,
   "unused-train": 0.24168319294906182,
   "unused-test": 0.24042062960405308,
   "unused-oodvalid": 0.24247465818209646,
   "unused-unused": -1.1156631174458198e-11
}


In [76]:
# Candidate pool for the new split: every "unused" example as (ex[0], ex[1], parsed output tree).
# ex[2] is the taglisp output string, parsed here with taglisp_to_tree.
unused = [(ex[0], ex[1], taglisp_to_tree(ex[2])) for ex in tqdm(ds) if ex[3] == "unused"]
len(unused)

100%|██████████| 239357/239357 [00:17<00:00, 13529.69it/s] 


119678

In [77]:

def get_dist_similarity(x, dist):
    """Average value that `dist` assigns to the items of `x`; items absent from `dist` count as 0.

    A FrequencyDistribution is first converted with `.tofreqs()`; any other mapping is used as-is.
    The denominator is floored at 1e-6, so an empty `x` yields 0 instead of a ZeroDivisionError.
    """
    freqs = dist.tofreqs() if isinstance(dist, FrequencyDistribution) else dist
    total = sum((freqs[item] if item in freqs else 0) for item in x)
    return total / max(1e-6, len(x))


def compute_approx_chernoff_change(comps, traindist: "FrequencyDistribution", testdist: "FrequencyDistribution",
                                   alpha=0.1, validdist: "FrequencyDistribution" = None):
    """
    Approximate change in the Chernoff coefficients if an example with compounds `comps`
    were added to the validation split.

    For every compound c in `comps`, with p_old = validdist(c) and the approximated new
    probability p_new = (validdist[c] + 1) / (validdist.total + 1), this sums
        (p_new ** alpha - p_old ** alpha) * testdist(c) ** (1 - alpha)          # corresponds to C(valid, test)
      + traindist(c) ** alpha * (p_new ** (1 - alpha) - p_old ** (1 - alpha))   # corresponds to C(train, valid)
    Only the compounds in `comps` are updated (the renormalisation of every other compound
    is ignored), hence "approx".

    :param comps:      iterable of compounds of the candidate example
    :param traindist:  train distribution; traindist(c) is used as a probability
    :param testdist:   test distribution; testdist(c) is used as a probability
    :param alpha:      Chernoff alpha
    :param validdist:  current validation distribution; must support validdist[c] (a count),
                       validdist.total and validdist(c) (a probability, presumably count / total).
                       Required — earlier versions referenced an undefined `validdist`.
    :return: the summed approximate change (float)
    :raises ValueError: if `validdist` is not given
    """
    if validdist is None:
        raise ValueError("compute_approx_chernoff_change requires validdist=<validation distribution>")
    validchange = 0.
    for comp in comps:
        p_new = (validdist[comp] + 1) / (validdist.total + 1)
        p_old = validdist(comp)
        validchange += (p_new ** alpha - p_old ** alpha) * (testdist(comp) ** (1 - alpha))
        validchange += (traindist(comp) ** alpha) * (p_new ** (1 - alpha) - p_old ** (1 - alpha))
    return validchange


def filter_mcd(source: List[Tuple], otheratoms, othercomps, N=10000,
               dc=None, coeffa=1., coeffb=30, alpha=0.1, splitname="ood2valid", nprint=10):
    """
    Pick the N examples of `source` whose compound distributions overlap least with train and test.

    Every candidate x is scored as
        coeffa * C(train, x) + coeffb * C(x, test)
    where C is dc.compute_chernoff_coeff(..., alpha=alpha) over compound distributions, and the
    N lowest-scoring candidates are kept. Ideally the training set shouldn't contain any compound
    from the new selection, and the new selection shouldn't contain any compound from test.

    :param source:     list of examples to pick from; x[2] must be the output tree (in this notebook
                       items are (ex[0], ex[1], tree) built from the dataset)
    :param otheratoms: [train, test] atom distributions. NOTE: not used for scoring — atom divergence
                       is not controlled by this selection; kept for interface compatibility.
    :param othercomps: [train, test] compound distributions (objects providing .tofreqs())
    :param N:          how many examples to retain (must be < len(source))
    :param dc:         DivergenceComputer providing extract_compound_dist and compute_chernoff_coeff (required)
    :param coeffa:     weight of the train-overlap term
    :param coeffb:     weight of the test-overlap term
    :param alpha:      alpha passed to compute_chernoff_coeff
    :param splitname:  label appended to every selected example
    :param nprint:     print the (train coeff, test coeff, #compounds) stats of the first nprint candidates
    :return: list of tuple(source[i]) + (splitname,) for the selected examples, lowest score first
    :raises ValueError: if no DivergenceComputer is given
    """
    assert len(source) > N
    assert len(otheratoms) == len(othercomps)
    if dc is None:
        raise ValueError("filter_mcd needs a DivergenceComputer: pass dc=...")

    traincomps, testcomps = othercomps
    traincomps = traincomps.tofreqs()
    testcomps = testcomps.tofreqs()

    print("creating selection")
    exstats = []
    for i, x in enumerate(tqdm(source)):
        # Only compound overlap enters the score, so atoms are not extracted here.
        xcomps = dc.extract_compound_dist(x[2])
        stats = (dc.compute_chernoff_coeff(traincomps, xcomps, alpha=alpha),
                 dc.compute_chernoff_coeff(xcomps, testcomps, alpha=alpha),
                 len(xcomps))
        exstats.append(stats)
        if i < nprint:
            print(stats)

    # A lower Chernoff coefficient means less overlap, so sort ascending and keep the first N.
    scores = [(i, coeffa * a + coeffb * b) for i, (a, b, _) in enumerate(exstats)]
    sortedscores = sorted(scores, key=lambda s: s[1])
    return [tuple(source[i]) + (splitname,) for i, _ in sortedscores[:N]]

In [82]:
# Select the new "ood2valid" split from the unused pool. filter_mcd scores each candidate as
# coeffa * C(train, x) + coeffb * C(x, test) over compound distributions and keeps the N
# lowest-scoring ones; coeffb=15 halves the default weight (30) of the test-overlap term.
newsplit = filter_mcd(unused, [atom_dists["train"], atom_dists["test"]], [comp_dists["train"], comp_dists["test"]], dc=dc, 
                      coeffb=15)

creating selection


  0%|          | 0/119678 [00:00<?, ?it/s]

(0.3624525468472961, 0.010353082164249223, 25)
(0.29354997078939765, 0.0342733987796785, 92)
(0.4150774803690949, 0.010463744666323679, 14)
(0.4193266485927801, 0.010470226809788645, 14)
(0.3183205232745508, 0.009676378868546689, 175)
(0.33514769883684786, 0.029944597224404304, 92)
(0.22132164961200074, 0.03517247177141252, 129)
(0.12597774206850776, 0.005929634621680035, 231)
(0.3497406860779779, 0.007685187673286987, 92)
(0.41634117744268523, 0.009150716341189438, 6)


100%|██████████| 119678/119678 [02:33<00:00, 780.15it/s] 


In [83]:
# Add the filter_mcd selection as an extra "ood2valid" split on top of shallow copies of the
# per-split distributions, then print all pairwise atom and compound divergences.
newsplit_atomdist = dc.compute_atom_distribution(newsplit)
newsplit_compdist = dc.compute_compound_distribution(newsplit)
_atom_dists = {**atom_dists, "ood2valid": newsplit_atomdist}
_comp_dists = {**comp_dists, "ood2valid": newsplit_compdist}
atom_divergences = dc._compute_atom_divergences(_atom_dists)
print(json.dumps(atom_divergences, indent=3))
comp_divergences = dc._compute_compound_divergences(_comp_dists)
print(json.dumps(comp_divergences, indent=3))

100%|██████████| 10000/10000 [00:01<00:00, 6838.40it/s]
100%|██████████| 10000/10000 [00:19<00:00, 518.58it/s]


{
   "train-train": -4.440892098500626e-16,
   "train-test": 0.03132730435620745,
   "train-oodvalid": 0.032651238914194525,
   "train-unused": 0.013494024085208722,
   "train-ood2valid": 0.04098440132340175,
   "test-train": 0.03132730435620745,
   "test-test": 1.1102230246251565e-16,
   "test-oodvalid": 8.113859177327765e-05,
   "test-unused": 0.009930059865293162,
   "test-ood2valid": 0.04030324226695947,
   "oodvalid-train": 0.032651238914194525,
   "oodvalid-test": 8.113859177327765e-05,
   "oodvalid-oodvalid": 6.661338147750939e-16,
   "oodvalid-unused": 0.010216209565831913,
   "oodvalid-ood2valid": 0.041031038560354616,
   "unused-train": 0.013494024085208722,
   "unused-test": 0.009930059865293162,
   "unused-oodvalid": 0.010216209565831913,
   "unused-unused": 0.0,
   "unused-ood2valid": 0.031082935641014386,
   "ood2valid-train": 0.04098440132340175,
   "ood2valid-test": 0.04030324226695947,
   "ood2valid-oodvalid": 0.041031038560354616,
   "ood2valid-unused": 0.031082935641

In [None]:
# NOTE(review): exact duplicate of the In [83] cell that builds _atom_dists/_comp_dists with "ood2valid".
# Marked In [None], i.e. not executed in this session — candidate for deletion.
newsplit_atomdist = dc.compute_atom_distribution(newsplit)
newsplit_compdist = dc.compute_compound_distribution(newsplit)
_atom_dists = {k: v for k, v in atom_dists.items()}
_atom_dists["ood2valid"] = newsplit_atomdist
_comp_dists = {k: v for k, v in comp_dists.items()}
_comp_dists["ood2valid"] = newsplit_compdist
print(json.dumps(dc._compute_atom_divergences(_atom_dists), indent=3))
print(json.dumps(dc._compute_compound_divergences(_comp_dists), indent=3))

In [53]:
# Per-atom frequency in train, test and the filter_mcd selection, printed as "atom: train - test - new".
# The selection is stored in _atom_dists under "ood2valid" (there is no "newsplit" key).
compare_splits = ("train", "test", "ood2valid")
all_atoms = set().union(*(_atom_dists[s].elements() for s in compare_splits))
for k in all_atoms:
    print(f"{k}: " + " - ".join(f"{_atom_dists[s](k):.5f}" for s in compare_splits))

100%|██████████| 10000/10000 [00:01<00:00, 6810.45it/s]
100%|██████████| 10000/10000 [00:18<00:00, 532.44it/s]


{
   "train-train": 0.0,
   "train-test": 0.03824614256979553,
   "train-oodvalid": 0.03836418726955826,
   "train-unused": 0.014322402198022943,
   "train-ood2valid": 0.06018251423648202,
   "test-train": 0.03824614256979553,
   "test-test": 0.0,
   "test-oodvalid": 9.497392817769956e-05,
   "test-unused": 0.009615993349177065,
   "test-ood2valid": 0.047546806189623925,
   "oodvalid-train": 0.03836418726955826,
   "oodvalid-test": 9.497392817769956e-05,
   "oodvalid-oodvalid": -2.220446049250313e-16,
   "oodvalid-unused": 0.009726062600495133,
   "oodvalid-ood2valid": 0.04781899274421708,
   "unused-train": 0.014322402198022943,
   "unused-test": 0.009615993349177065,
   "unused-oodvalid": 0.009726062600495133,
   "unused-unused": 2.220446049250313e-16,
   "unused-ood2valid": 0.034586211770803454,
   "ood2valid-train": 0.06018251423648202,
   "ood2valid-test": 0.047546806189623925,
   "ood2valid-oodvalid": 0.04781899274421708,
   "ood2valid-unused": 0.034586211770803454,
   "ood2valid

In [11]:
# Per-atom frequency in train, test and the filter_mcd selection, printed as "atom: train - test - new".
# The selection is stored in _atom_dists under "ood2valid" (there is no "newsplit" key).
compare_splits = ("train", "test", "ood2valid")
all_atoms = set().union(*(_atom_dists[s].elements() for s in compare_splits))
for k in all_atoms:
    print(f"{k}: " + " - ".join(f"{_atom_dists[s](k):.5f}" for s in compare_splits))

(ns:film.film.edited_by ): 0.01024 - 0.00673 - 0.01000
(ns:m.06mkj ): 0.00086 - 0.00062 - 0.00065
(ns:film.film_distributor.films_distributed/ns:film.film_film_distributor_relationship.film ): 0.00075 - 0.00084 - 0.00130
(ns:m.0d0vqn ): 0.00076 - 0.00066 - 0.00053
(?x4 ): 0.00013 - 0.00000 - 0.00022
(ns:people.person.sibling_s/ns:people.sibling_relationship.sibling ): 0.00333 - 0.00241 - 0.00082
(ns:film.film ): 0.00268 - 0.00185 - 0.00157
(ns:film.production_company.films ): 0.00338 - 0.00940 - 0.00388
(ns:film.production_company ): 0.00076 - 0.00053 - 0.00102
(m8 ): 0.00001 - 0.00003 - 0.00048
(ns:fictional_universe.fictional_character ): 0.00095 - 0.00058 - 0.00051
(ns:business.employer.employees/ns:business.employment_tenure.person ): 0.00303 - 0.00198 - 0.00726
(ns:film.actor.film/ns:film.performance.film ): 0.00253 - 0.00564 - 0.00365
(ns:film.writer.film ): 0.00375 - 0.01221 - 0.00547
(ns:film.film.sequel ): 0.00083 - 0.00112 - 0.00120
(?x5 ): 0.00001 - 0.00000 - 0.00002
(ns:peo

In [19]:
def print_overlaps(dists):
    """Print how many keys each pair of distributions shares.

    Pairs are visited in insertion order of `dists`, each split with itself and with every
    later split, one line per pair: "Overlap A,B: |A & B| / (|A|, |B|)".
    """
    keysets = {name: set(dist.keys()) for name, dist in dists.items()}
    names = list(dists.keys())
    for pos, src in enumerate(names):
        src_keys = keysets[src]
        for dst in names[pos:]:
            dst_keys = keysets[dst]
            print(f"Overlap {src},{dst}: {len(src_keys & dst_keys)} / ({len(src_keys)}, {len(dst_keys)})")

In [22]:
# Key overlap between every pair of splits (including "ood2valid"): first atoms, then compounds.
print_overlaps(_atom_dists)
print_overlaps(_comp_dists)

Overlap train,train: 104 / (104, 104)
Overlap train,test: 100 / (104, 100)
Overlap train,oodvalid: 101 / (104, 101)
Overlap train,unused: 104 / (104, 104)
Overlap train,newsplit: 103 / (104, 103)
Overlap test,test: 100 / (100, 100)
Overlap test,oodvalid: 100 / (100, 101)
Overlap test,unused: 100 / (100, 104)
Overlap test,newsplit: 99 / (100, 103)
Overlap oodvalid,oodvalid: 101 / (101, 101)
Overlap oodvalid,unused: 101 / (101, 104)
Overlap oodvalid,newsplit: 100 / (101, 103)
Overlap unused,unused: 104 / (104, 104)
Overlap unused,newsplit: 103 / (104, 103)
Overlap newsplit,newsplit: 103 / (103, 103)
Overlap train,train: 844862 / (844862, 844862)
Overlap train,test: 19393 / (844862, 133039)
Overlap train,oodvalid: 18578 / (844862, 132884)
Overlap train,unused: 291419 / (844862, 1013197)
Overlap train,newsplit: 109322 / (844862, 364795)
Overlap test,test: 133039 / (133039, 133039)
Overlap test,oodvalid: 93300 / (133039, 132884)
Overlap test,unused: 95259 / (133039, 1013197)
Overlap test,ne

In [46]:
def filter_traindata(traindata, testdata, newdevdata, N=10000, dc=None):
    """
    Drop from `traindata` the N examples that look most like the new dev split relative to test,
    and return the remaining ones relabelled as "newtrain".

    Each training example is scored as
        get_dist_similarity(compounds, newdev freqs) - get_dist_similarity(compounds, test freqs)
    with the frequencies taken from dc.compute_compound_distribution(...).tofreqs().
    The N highest-scoring examples are removed.

    :param traindata:  list of (input, tree) pairs; compounds are extracted from x[1]
    :param testdata:   test examples, passed as-is to dc.compute_compound_distribution
    :param newdevdata: new dev examples, passed as-is to dc.compute_compound_distribution
    :param N:          number of training examples to drop (must be < len(traindata))
    :param dc:         DivergenceComputer to use; when None, falls back to the notebook-level `dc`
                       (which the previous version read implicitly)
    :return: list of (input, tree, "newtrain") for the kept examples, in their original relative order
    """
    if dc is None:
        # Backward compatibility: fall back to the DivergenceComputer defined in the notebook namespace.
        dc = globals()["dc"]
    assert len(traindata) > N

    newcompdist = dc.compute_compound_distribution(newdevdata).tofreqs()
    testdist = dc.compute_compound_distribution(testdata).tofreqs()

    scores = []
    for i, x in enumerate(tqdm(traindata)):
        xcomps = dc.extract_compounds(x[1])
        scores.append((i, get_dist_similarity(xcomps, newcompdist) - get_dist_similarity(xcomps, testdist)))

    # Highest score = most new-dev-like / least test-like: those are the examples to drop.
    sortedscores = sorted(scores, key=lambda s: s[1], reverse=True)
    # Keep survivors in their original order rather than sorted by score, so the new training
    # set carries no ordering artefact from the selection.
    keepids = sorted(i for i, _ in sortedscores[N:])
    return [(traindata[i][0], traindata[i][1], "newtrain") for i in keepids]

In [47]:
# ds was loaded with keepids=True, so its items are (ex[0], ex[1], ex[2], ex[3]) with ex[2] the taglisp
# output and ex[3] the split name (the same layout the `unused` cell uses). filter_traindata reads the
# tree from x[1], so build (input, tree) pairs here.
traindata = [(ex[1], taglisp_to_tree(ex[2])) for ex in tqdm(ds) if ex[3] == "train"]
testdata = [(ex[1], taglisp_to_tree(ex[2])) for ex in tqdm(ds) if ex[3] == "test"]
# newsplit items are tuple(unused item) + ("ood2valid",), i.e. (ex[0], input, tree, "ood2valid").
newdevdata = [(ex[1], ex[2]) for ex in tqdm(newsplit)]
newtrain = filter_traindata(traindata, testdata, newdevdata)

100%|███████████████████████████████████████████████████████████████████████| 239357/239357 [00:12<00:00, 18744.58it/s]
100%|██████████████████████████████████████████████████████████████████████| 239357/239357 [00:01<00:00, 149562.97it/s]
100%|███████████████████████████████████████████████████████████████████████| 10000/10000 [00:00<00:00, 1428091.25it/s]
100%|███████████████████████████████████████████████████████████████████████████| 10000/10000 [00:25<00:00, 399.05it/s]
100%|███████████████████████████████████████████████████████████████████████████| 11968/11968 [00:17<00:00, 666.47it/s]
100%|██████████████████████████████████████████████████████████████████████████| 95743/95743 [01:23<00:00, 1151.01it/s]


In [48]:
# Add the filtered training set as an extra "newtrain" split on top of shallow copies of the
# distributions from the previous step, then print all pairwise atom and compound divergences.
newtrain_atomdist = dc.compute_atom_distribution(newtrain)
newtrain_compdist = dc.compute_compound_distribution(newtrain)
__atom_dists = {**_atom_dists, "newtrain": newtrain_atomdist}
__comp_dists = {**_comp_dists, "newtrain": newtrain_compdist}
atom_divergences = dc._compute_atom_divergences(__atom_dists)
print(json.dumps(atom_divergences, indent=3))
comp_divergences = dc._compute_compound_divergences(__comp_dists)
print(json.dumps(comp_divergences, indent=3))

100%|██████████████████████████████████████████████████████████████████████████| 85743/85743 [00:10<00:00, 8572.47it/s]
100%|██████████████████████████████████████████████████████████████████████████| 85743/85743 [01:21<00:00, 1050.25it/s]


{
   "train-train": -4.440892098500626e-16,
   "train-test": 0.03132730435620734,
   "train-oodvalid": 0.032651238914194414,
   "train-unused": 0.01349402408520839,
   "train-newsplit": 0.05231157711172296,
   "train-newtrain": 0.00030947563155736546,
   "test-train": 0.03132730435620712,
   "test-test": 3.3306690738754696e-16,
   "test-oodvalid": 8.113859177349969e-05,
   "test-unused": 0.009930059865293495,
   "test-newsplit": 0.054593072694224976,
   "test-newtrain": 0.03393804863415539,
   "oodvalid-train": 0.032651238914194414,
   "oodvalid-test": 8.113859177349969e-05,
   "oodvalid-oodvalid": 3.3306690738754696e-16,
   "oodvalid-unused": 0.010216209565832135,
   "oodvalid-newsplit": 0.05506934688627341,
   "oodvalid-newtrain": 0.03529685053183662,
   "unused-train": 0.01349402408520839,
   "unused-test": 0.009930059865293495,
   "unused-oodvalid": 0.010216209565832135,
   "unused-unused": 3.3306690738754696e-16,
   "unused-newsplit": 0.044862260862684744,
   "unused-newtrain": 0.

In [49]:
# Key overlap between every pair of splits (now including "newtrain"): first atoms, then compounds.
print_overlaps(__atom_dists)
print_overlaps(__comp_dists)

Overlap train,train: 104 / (104, 104)
Overlap train,test: 100 / (104, 100)
Overlap train,oodvalid: 101 / (104, 101)
Overlap train,unused: 104 / (104, 104)
Overlap train,newsplit: 103 / (104, 103)
Overlap train,newtrain: 104 / (104, 104)
Overlap test,test: 100 / (100, 100)
Overlap test,oodvalid: 100 / (100, 101)
Overlap test,unused: 100 / (100, 104)
Overlap test,newsplit: 99 / (100, 103)
Overlap test,newtrain: 100 / (100, 104)
Overlap oodvalid,oodvalid: 101 / (101, 101)
Overlap oodvalid,unused: 101 / (101, 104)
Overlap oodvalid,newsplit: 100 / (101, 103)
Overlap oodvalid,newtrain: 101 / (101, 104)
Overlap unused,unused: 104 / (104, 104)
Overlap unused,newsplit: 103 / (104, 103)
Overlap unused,newtrain: 104 / (104, 104)
Overlap newsplit,newsplit: 103 / (103, 103)
Overlap newsplit,newtrain: 103 / (103, 104)
Overlap newtrain,newtrain: 104 / (104, 104)
Overlap train,train: 844862 / (844862, 844862)
Overlap train,test: 19393 / (844862, 133039)
Overlap train,oodvalid: 18578 / (844862, 132884)

In [18]:
# Sanity check of dc.compute_chernoff_coeff on two small hand-made distributions: print 1 - coefficient,
# first with the method's default alpha, then with alpha=0.1 in both argument orders.
# The recorded outputs for the two alpha=0.1 calls differ a lot (≈0.41 vs ≈0.08), i.e. the
# alpha-weighted coefficient is not symmetric in its arguments.
a = {"a": 0.0, "b": 0.0, "c": 0.5, "d": 0.49, "e": 0.01}
b = {"a": 0.2, "b": 0.2, "c": 0.2, "d": 0.2, "e": 0.2}
print(1-dc.compute_chernoff_coeff(a, b))
print(1-dc.compute_chernoff_coeff(a, b, 0.1))
print(1-dc.compute_chernoff_coeff(b, a, 0.1))

0.3260013575831957
0.4138321990765277
0.08228389965015814


In [19]:
# Number of splits, then the size of the atom distribution of train, test and the filter_mcd selection
# (stored under "ood2valid"; there is no "newsplit" key).
print(len(_atom_dists))
for split in ("train", "test", "ood2valid"):
    print(len(_atom_dists[split]))

5
104
100
101


In [20]:
# Inspect the atom vocabulary of the train split (keys such as '(count ARG1)' in the recorded output).
print(atom_dists["train"].keys())


dict_keys(['(* )', '(count ARG1)', '(@SELECT ARG1)', '(?x0 )', '(a )', '(ns:people.person )', '(@COND ARG1 ARG2 ARG3)', '(ns:influence.influence_node.influenced )', '(m1 )', '(m2 )', '(ns:people.person.spouse_s/ns:people.marriage.spouse )', '(ns:fictional_universe.fictional_character.married_to/ns:fictional_universe.marriage_of_fictional_characters.spouses )', '(@OR ARG*)', '(?x1 )', '(ns:film.cinematographer )', '(!= )', '(filter ARG1 ARG2 ARG3)', '(@WHERE ARG*)', '(@QUERY ARG ARG)', '(@R@ ARG1)', '(ns:influence.influence_node.influenced_by )', '(m0 )', '(ns:people.person.sibling_s/ns:people.sibling_relationship.sibling )', '(ns:fictional_universe.fictional_character.siblings/ns:fictional_universe.sibling_relationship_of_fictional_characters.siblings )', '(ns:people.person.gender )', '(ns:m.05zppz )', '(ns:film.film.directed_by )', '(m3 )', '(m4 )', '(ns:film.film.written_by )', '(distinct )', '(@SELECT ARG1 ARG2)', '(ns:organization.organization.founders )', '(ns:film.actor.film/ns:f

In [14]:
# Size of the train compound distribution (len of the distribution object — presumably the number of distinct compounds).
len(comp_dists["train"])

667613

In [15]:
# Back-of-the-envelope size of a 200k x 200k matrix: 4e10 cells; at 2 bytes each (int16) that is ~74.5 GiB,
# matching the MemoryError recorded in the next cell's output.
200000*200000 

40000000000

In [16]:
import numpy as np

MemoryError: Unable to allocate 74.5 GiB for an array with shape (200000, 200000) and data type int16