In [12]:
import json
from vega import VegaLite
from pprint import pprint
import itertools

def load_dataset(
    target_path="../sunshine.json",
    labels_path="../labels.txt",
    examples_path="../examples.json",
    *,
    preview=False,
):
    """Tally how often each inferred relation occurs in positive vs. negative examples.

    Reads the example specs and their labels, derives relations for each spec
    with ``infer_facts`` / ``infer_relations``, and pretty-prints a list of
    ``(relation, negative_count, positive_count)`` tuples sorted by
    ``negative_count - positive_count`` in descending order. Returns ``None``.

    Args:
        target_path: JSON data rendered with each spec when ``preview`` is set;
            it is only read in that case.
        labels_path: text file with one label per line; a line equal to ``"1"``
            (after stripping whitespace) marks a positive example, anything
            else a negative one. Line ``i`` labels example ``i``.
        examples_path: JSON list of examples; each example must provide a
            ``"vl"`` spec and a ``"draco"`` list of fact strings.
        preview: when true, print a header and display each spec with
            ``VegaLite`` while processing.

    Raises:
        ValueError: if ``labels_path`` has fewer lines than there are examples.
    """
    with open(labels_path, "r", encoding="utf-8") as f:
        labels = [line.strip() == "1" for line in f]

    with open(examples_path, "r", encoding="utf-8") as f:
        vis_specs = json.load(f)

    # Extra trailing label lines are tolerated; missing ones would otherwise
    # surface as a bare IndexError halfway through the loop.
    if len(labels) < len(vis_specs):
        raise ValueError(
            f"{labels_path} has {len(labels)} labels but {examples_path} "
            f"has {len(vis_specs)} examples"
        )

    target_data = None
    if preview:
        with open(target_path, "r", encoding="utf-8") as f:
            target_data = json.load(f)

    # relation -> {"p": positive count, "n": negative count}, in first-seen order
    # (insertion order matters: it breaks ties in the stable sort below).
    counts = {}
    for i, (example, is_positive) in enumerate(zip(vis_specs, labels)):
        # Draco facts whose text starts with "soft" are excluded from inference.
        hard_facts = [x for x in example["draco"] if not x.startswith("soft")]
        props = infer_facts(example["vl"], hard_facts)
        if preview:
            print(f"{i+1}-------------")
            VegaLite(example["vl"], target_data).display()
        # NOTE(review): infer_relations can emit the same rule from several
        # encodings of one spec; each occurrence is counted, not each spec —
        # confirm this is the intended statistic.
        for rel in infer_relations(props):
            bucket = counts.setdefault(rel, {"p": 0, "n": 0})
            bucket["p" if is_positive else "n"] += 1

    summary = [(rel, c["n"], c["p"]) for rel, c in counts.items()]
    pprint(sorted(summary, reverse=True, key=lambda row: row[1] - row[2]))


def infer_facts(vl_spec, draco_facts):
    """Build per-encoding fact strings for a Vega-Lite spec.

    Args:
        vl_spec: Vega-Lite spec dict with a ``"mark"`` (a string, or a mark
            definition dict carrying a ``"type"`` key) and an optional
            ``"encoding"`` mapping from channel name to encoding definition.
        draco_facts: fact strings. Those starting with ``"fieldtype"`` and
            shaped like ``fieldtype("<field>",<type>)...`` supply each field's
            data type; all other facts are ignored.

    Returns:
        ``{"mark": "mark(<mark>)", "encodings": {"e0": [...], "e1": [...], ...}}``.
        Encoding ids follow the iteration order of ``vl_spec["encoding"]``.
        Each list holds ``channel(E,..)``, ``enc_ty(E,..)``, ``field_ty(E,..)``,
        then ``aggregate(E)`` if an ``"aggregate"`` key is present and ``bin(E)``
        if ``"bin"`` is truthy. Encodings with no ``"field"`` get
        ``field_ty(E,None)``.

    Raises:
        KeyError: if the spec has no ``"mark"``, an encoding lacks ``"type"``,
            or an encoding's field has no matching ``fieldtype`` fact.
    """
    mark = vl_spec["mark"]
    if isinstance(mark, dict):
        # Vega-Lite also accepts a mark definition object, e.g. {"type": "bar"}.
        mark = mark["type"]

    field_ty_map = {}
    for fact in draco_facts:
        if fact.startswith("fieldtype"):
            # Slice to the LAST ")" and split on the LAST ",", so quoted field
            # names that themselves contain "," or ")" are kept intact.
            args = fact[fact.index("(") + 1 : fact.rindex(")")]
            quoted_field, ty = args.rsplit(",", 1)
            field_ty_map[quoted_field.strip()[1:-1]] = ty.strip()

    encodings = {}
    for i, (channel, enc) in enumerate(vl_spec.get("encoding", {}).items()):
        field_ty = field_ty_map[enc["field"]] if "field" in enc else None
        enc_facts = [
            f"channel(E,{channel})",
            f"enc_ty(E,{enc['type']})",
            f"field_ty(E,{field_ty})",
        ]
        if "aggregate" in enc:
            enc_facts.append("aggregate(E)")
        if enc.get("bin"):
            enc_facts.append("bin(E)")
        encodings[f"e{i}"] = enc_facts

    return {"mark": f"mark({mark})", "encodings": encodings}
    
def infer_relations(facts, size=2):
    """Enumerate candidate constraint rules from the facts of one spec.

    Args:
        facts: dict with a ``"mark"`` fact string and an ``"encodings"``
            mapping of encoding id -> list of fact strings (the shape that
            ``infer_facts`` returns).
        size: number of facts combined into each rule.

    Returns:
        A list of rules. First come rules of the form ``":- f1,f2,..."``, one per
        ``size``-combination of facts within a single encoding, in encoding
        order. They are followed by the same rules with ``",<mark fact>"``
        appended.
    """
    base_rules = [
        ":- " + ",".join(combo)
        for enc_facts in facts["encodings"].values()
        for combo in itertools.combinations(enc_facts, size)
    ]
    mark_suffix = f",{facts['mark']}"
    with_mark = [rule + mark_suffix for rule in base_rules]
    return base_rules + with_mark
    
def eliminate_pos(rels, pos_rels):
    """Return the relations in *rels* that do not appear in *pos_rels*.

    The order and duplicates of *rels* are preserved. *pos_rels* may be any
    iterable of relations, including a one-shot iterator. It is read exactly
    once up front, so membership is not tested against a partially consumed
    iterator.
    """
    excluded = list(pos_rels)
    try:
        # O(1) membership per relation instead of a linear scan each time.
        excluded = set(excluded)
    except TypeError:
        # Unhashable relations: keep the list and fall back to linear lookups.
        pass
    return [x for x in rels if x not in excluded]

# Run the analysis; load_dataset() pretty-prints (relation, negative, positive)
# count tuples and returns None.
load_dataset()

[(':- aggregate(E),bin(E)', 12, 2),
 (':- enc_ty(E,quantitative),field_ty(E,number),mark(area)', 8, 0),
 (':- enc_ty(E,ordinal),field_ty(E,string),mark(area)', 5, 0),
 (':- aggregate(E),bin(E),mark(area)', 5, 0),
 (':- channel(E,y),enc_ty(E,quantitative),mark(area)', 5, 0),
 (':- channel(E,y),field_ty(E,number),mark(area)', 5, 0),
 (':- channel(E,color),aggregate(E)', 6, 1),
 (':- channel(E,color),field_ty(E,number),mark(area)', 5, 0),
 (':- enc_ty(E,ordinal),field_ty(E,number),mark(area)', 5, 0),
 (':- field_ty(E,number),aggregate(E),mark(area)', 5, 0),
 (':- field_ty(E,number),bin(E),mark(area)', 5, 0),
 (':- channel(E,x),field_ty(E,number),mark(area)', 5, 0),
 (':- channel(E,x),enc_ty(E,ordinal),mark(area)', 4, 0),
 (':- channel(E,x),bin(E),mark(area)', 4, 0),
 (':- enc_ty(E,ordinal),aggregate(E),mark(area)', 4, 0),
 (':- enc_ty(E,ordinal),bin(E),mark(area)', 4, 0),
 (':- field_ty(E,number),aggregate(E)', 7, 3),
 (':- field_ty(E,None),bin(E)', 4, 0),
 (':- enc_ty(E,nominal),field_ty