# Generate decision tree and visualize it graphically

## Reload on each run, since we’re editing the decision_tree library

In [None]:
%load_ext autoreload
%autoreload 2

## Import libraries

In [None]:
from decision_tree import *
from graphviz import Digraph
from IPython.display import SVG
from IPython.core.display import display

## Sample data

In [None]:
# Sample data with a bit of repetition: each witness, keyed by its siglum, is a
#   list of normalized token readings
witnessData = dict(
    wit1=['a', 'b', 'c', 'a', 'd', 'e'],
    wit2=['a', 'e', 'c', 'd'],
    wit3=['a', 'd', 'b'],
)

## Stopword list

Currently we calculate whether both tokens in a skip bigram are stopwords, but we don’t use that information. Since we’re incorporating information about the uniqueness of the skipgrams, we already have an alternative way to identify those that contain frequently occurring types.

In [None]:
# Fake stoplist: a set of normalized readings, used to check that stopwords can be
#   identified (and, eventually, processed last)
stoplist = {"a", "c"}

## Use bitarrays to keep track of which witness tokens have already been placed

We also use the bitarrays to count the number of placed tokens, which might be used as part of the scoring process.

In [None]:
# bitArray_dict maps each siglum to a bitarray with one bit per witness token, used to
#   record which witness tokens have already been processed; every bit starts at 0
bitArray_dict = {}
for siglum, tokens in witnessData.items():
    placed = bitarray(len(tokens))  # one bit per token in this witness
    placed.setall(0)  # guarantee that no token is marked as processed yet
    bitArray_dict[siglum] = placed

## Common skipgram table

All skipgrams, with all locations where they occur

In [None]:
# csTable maps every skipgram to all of the places where it occurs:
#   key: two-item tuple of the skipgram's normalized token values (token[0], token[1])
#   value: list of three-item tuples, one per occurrence: (siglum, offset[0], offset[1])
#     In Real Life:
#       values will include the t values corresponding to the normalized token values
#       use a named tuple or dataclass (https://realpython.com/python-data-classes/)
# This test sample collects every skip bigram; in Real Life we would specify parameters for:
#   size of skipgram (bi-, tri-, etc.; here bi-)
#   size of window (maximum distance between first and last members of skipgram; here the full witness length)
#   maximum size of skip between members of skipgram (here constrained only by size of window)
csTable = collections.defaultdict(list)
for siglum, readings in witnessData.items():  # readings: list of normalized token values
    # in Real Life each reading would also carry a non-normalized t property
    for pos1, reading1 in enumerate(readings):  # every token can start a bigram ...
        for pos2 in range(pos1 + 1, len(readings)):  # ... paired with every later token
            csTable[(reading1, readings[pos2])].append((siglum, pos1, pos2))

## Shape skipgram table into df and break out features

Features are used to sort the table of (remaining) skipgrams by priority, that is, to determine what to process next.

In [None]:
# Build a Series first (rather than a DataFrame directly) because the location lists vary
#   in length; the tuple keys become a two-level MultiIndex
csSeries = pd.Series(data=csTable)

In [None]:
# Turn the series into a one-column frame, move both MultiIndex levels into ordinary
#   columns, then give all three columns descriptive names
csDf = csSeries.to_frame().reset_index()
csDf.columns = ["first", "second", "locations"]

### Prioritize skipgrams for processing

Three features, in order:

1. How many witnesses does a skipgram occur in (depth)? Integer; higher is better
1. Does the skipgram occur more than once in any witness? Boolean; no is better (the `witness_uniqueness` column is True when no witness repeats the skipgram)
1. How many times does a skipgram occur in the documents overall? Integer; higher is better, since the second feature has already ranked skipgrams whose high frequency comes from repetition within a witness below those without such repetition

In [None]:
# Count witnesses for each skipgram (depth of block) and check for uniqueness of skipgram in all witnesses
#   local_witnesses: siglum of every occurrence (duplicates kept)
#   unique_witnesses: the distinct sigla, i.e., the witnesses the skipgram occurs in
csDf["local_witnesses"] = csDf["locations"].map(lambda x: [location[0] for location in x])
csDf["unique_witnesses"] = csDf["local_witnesses"].map(set)
csDf["local_witnessCount"] = csDf["local_witnesses"].map(len)  # total number of occurrences
csDf["unique_witnessCount"] = csDf["unique_witnesses"].map(len)  # depth: number of witnesses
# True when no witness contains the skipgram more than once
csDf["witness_uniqueness"] = csDf["local_witnessCount"] == csDf["unique_witnessCount"]
# Fold the three features into one priority that sorts lexicographically: depth first,
#   then uniqueness (True beats False), then total occurrences; higher is better for all three.
#   Fixed weights such as [100, 10, 1] stop being lexicographic once a lower-ranked feature
#   reaches the next weight (e.g., a repeated skipgram with many more occurrences could
#   outrank a unique one of the same depth), so derive the weights from the largest
#   occurrence count: occurrences < base, uniqueness * base + occurrences < 2 * base.
_occurrence_base = int(csDf["local_witnessCount"].max()) + 1
scale = pd.Series([2 * _occurrence_base, _occurrence_base, 1])
# pd.np was deprecated in pandas 1.0 and removed in 2.0; DataFrame.dot with an array works in both
csDf["priority"] = (
    csDf[["unique_witnessCount", "witness_uniqueness", "local_witnessCount"]]
    .astype(int)
    .dot(scale.to_numpy())
)

In [None]:
# Are both tokens stopwords? (if so, we’ll process them last)
# NB: currently we ignore stopwords
csDf["stopwords"] = csDf[["first", "second"]].isin(stoplist).all(axis="columns")

In [None]:
# Order rows from highest to lowest priority and renumber them 0..n-1, so that row
#   position matches the order in which the skipgrams will be traversed
csDf = csDf.sort_values(by="priority", ascending=False).reset_index(drop=True)

### Check the df

In [None]:
# Show the sorted skipgram table (as the cell's last expression, it is rendered as the cell output)
csDf

## Create decision tree

## Build decision tree

Currently we create the root, expand it, and then expand all branches three further levels, so the tree extends four levels below the root. In Real Life:

1. Evaluate scores at each stage to decide what to expand and what not to expand.
1. Navigate levels with a function instead of one nested `for` structure

In [None]:
# root of decision tree: its toList holds only the #start and #end nodes, "[none]" is
#   presumably the root's skipgram label (children expose .skipgram), and it receives
#   bitArray_dict (all 0 values) and the complete, sorted df
# NB: the next cell recreates dtRoot the same way before expanding it
dtRoot = dtNode([Node("#start"), Node("#end")], "[none]", bitArray_dict, csDf)

In [None]:
def _expand_levels(node, levels):
    """Expand *node*, then recursively expand its descendants, for *levels* levels in total."""
    if levels < 1:
        return
    expand_dtNode(node)
    for kid in node.children:
        _expand_levels(kid, levels - 1)


# Recreate the root, then expand it and three further levels of descendants
dtRoot = dtNode([Node("#start"), Node("#end")], "[none]", bitArray_dict, csDf)
_expand_levels(dtRoot, 4)

### Visualize decision tree

## graphviz Digraph() to be rendered in SVG

In [None]:
# graphviz digraph for visualization: SVG output, laid out left to right, with small
#   Courier (monospaced) node labels
G = Digraph(
    format="svg",
    graph_attr=dict(rankdir="LR"),
    node_attr=dict(fontname="Courier", fontsize="8"),
)

In [None]:
# create functions to add nodes and edges to graphviz Digraph()
# node ids are unique integers in string form because
#   1. labels may repeat across branches
#   2. graphviz ids must be strings

def create_adder_n(_g: Digraph, _witnessData: dict):  # specify graph and witness data when creating adder
    """Return ``add_node(_n)``, which adds decision-tree node *_n* to graphviz graph *_g*.

    The new graphviz node is labeled with the percentage of placed witness tokens, the
    node's score, and its alignment table. ``add_node`` returns the integer id it assigned.
    Ids are consecutive integers starting at 1, kept in a closure; the first node added
    (the root, in this notebook) is labeled with score "N/A" instead of a computed score.
    """
    _counter = 0  # closure: last id assigned

    def add_node(_n: dtNode) -> int:  # specify only the node when adding it
        nonlocal _counter  # rebind the enclosing counter instead of creating a local
        _placed_tokens = f"{print_placed_witness_tokens(_n):.2f}"
        _counter += 1
        _score = "N/A" if _counter == 1 else f"{calculate_score(_n):.2f}"
        _ranking = rank_nodes(_n.toList, create_edge_list(_n.toList, _witnessData))
        _table = str(create_alignment_table(_n.toList, _witnessData, _ranking, True))
        _g.node(
            str(_counter),
            label=(
                f"Placed witness tokens (pct): {_placed_tokens}"
                f"\nScore (tokens / toList length): {_score}"
                f"\n{_table}"
            ),
        )
        return _counter  # to refer to new node
    return add_node


def create_adder_e(_g: Digraph):  # specify graph when creating adder
    """Return ``add_edge(_u, _v, _skipgram)``, which adds an edge labeled *_skipgram* to *_g*."""
    def add_edge(_u: str, _v: str, _skipgram: str):  # graphviz ids (strings) of tail and head nodes
        _g.edge(_u, _v, label=_skipgram)
    return add_edge


a_n = create_adder_n(G, witnessData)  # create adder for nodes
a_e = create_adder_e(G)  # create adder for edges

## Add root node and check

In [None]:
a_n(dtRoot)  # add the root first, so it gets graphviz id 1 (the returned id is not needed here)
display(SVG(filename=G.render()))  # render() writes the SVG file and returns its path

## Go down four levels below the root

In [None]:
def _add_descendants(parent_id, node, levels):
    """Add the children of *node* to G via a_n / a_e, linking each to graphviz node
    *parent_id* with an edge labeled by its skipgram; recurse for *levels* levels."""
    if levels < 1:
        return
    for kid in node.children:
        kid_id = a_n(kid)
        a_e(str(parent_id), str(kid_id), kid.skipgram)
        _add_descendants(kid_id, kid, levels - 1)


_add_descendants(1, dtRoot, 4)  # the root was added first, so its graphviz id is 1
display(SVG(G.render()))