# In [None]:
import collections
import json
import pprint
from datetime import datetime
import pandas as pd

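# Notebook to generate an attack tree Graphviz file and an Emacs Org mode
# table for attack tree analysis configuration. The input is an itemized
# Org list of attack tree nodes; each item may end with a "[...]" modifier
# token: the first character selects the join (a = AND, otherwise OR), the
# second the layout (v = vertical, otherwise horizontal), and "!" / "*"
# anywhere in the token mark a double border / red fill.

# Requires Python 3.7+: the Graphviz output relies on dicts preserving
# insertion order (the root item must be the first key of the tree).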

# Path of the Org file holding the itemized attack tree input
tf = 'at/at-gather-intelligence-about.org'
# Prefix for the scenario; used in node names ("N_1.2") and node captions
prefix = 'N'
# Not important, only used internally in Graphviz cluster/helper node names
cluster_prefix = 'T'
# Number of indentation columns per list level in the input
step = 2
# Number of spaces a tab character expands to before indentation is measured
tab = 8
# Raw input lines; replaced by the lines read from `tf`
test = ''
# Output Graphviz and/or Org table
org = False
graphviz = True
##### Graphviz params
## Shape and style of the attack tree nodes (the shape is also reused for
## the join helper nodes); alternative settings are kept commented out
#gshape = 'octagon'
#style=''
gshape = 'box'
style='rounded'

# Tree implementation: an autovivifying nested dictionary
def tree():
    """
    Return an empty tree node.

    A node is a defaultdict whose missing keys are created on first access
    as further empty tree nodes, so any path can be built just by indexing.
    """
    node = collections.defaultdict(tree)
    return node

def add_tree(t, keys):
    """
    Add a path of keys to the tree.

    t: tree root; missing nodes along the path are created by indexing,
       so it is expected to come from tree()
    keys: iterable of node keys, from the top level downwards
    """
    cursor = t
    for key in keys:
        # Indexing an autovivifying node creates the child when missing.
        cursor = cursor[key]

def print_tree(t):
    """Pretty-print the tree as JSON, indented by 2 with keys sorted."""
    rendered = json.dumps(t, sort_keys=True, indent=2)
    print(rendered)
    
def leaf_nodes(tree, k='', current=''):
    """
    Yield the dotted identifiers of the tree's leaf nodes.

    Identifiers are built on the fly by joining the keys on the path from
    the top with '.', e.g. '1.2.1'. A node without children is a leaf
    unless its path is empty; the empty path is never yielded.

    k and current are only used by the recursive calls: k is the key of
    the node being visited and current the dotted prefix built so far.
    """
    current = current + '.' + k
    path = current.lstrip('.')
    if path and not tree:
        yield path
    for key, child in tree.items():
        yield from leaf_nodes(child, key, current)


# Open the Org input file and read all of its lines. The encoding is given
# explicitly so the decoding does not depend on the platform's locale
# (assumes the input file is UTF-8 - adjust here if it is not).
with open(tf, encoding='utf-8') as f:
    test = f.readlines()

#### Itemize list input
branches = []
levels = []
st = []

# Parse simple/limited org input
# - st: list of lists, each with three elements: the concatenated levels
#   (e.g. '1.2', '' for the root), the description of the node, and its
#   modifier list [or_join, horizontal, double, red, triple]
# - branches: list of branches of the tree, used to build the tree
# An item line starts, after its indentation, with '-'; the indentation
# width divided by `step` gives the item's level. Any other non-blank line
# is appended to the description of the previous item.
for l in test:
    # Modifier defaults, overridden by a trailing "[...]" token
    or_join = True
    horizontal = True
    double = False
    triple = False  # not set by any modifier; kept for the tuple layout
    red = False
    l = l.strip('\n')
    l = l.replace('\t', ' ' * tab)
    body = l.lstrip()
    if not body.startswith('-'):
        # Continuation line. Only a leading '-' marks an item, so hyphens
        # inside wrapped text never start a new node; blank lines are
        # skipped instead of adding a stray space to the description.
        if st and body:
            st[-1][1] = st[-1][1] + ' ' + body
        continue
    level = (len(l) - len(body)) // step
    if level == 0:
        pass  # the root item keeps the empty level path
    elif level > len(levels):
        levels.append(1)
    elif level == len(levels):
        levels[-1] += 1
    else:
        levels = levels[:level]
        levels[-1] += 1
    # Take everything after the item's '-' so hyphenated words stay intact
    label = body[1:].strip()
    # Parse the label string for the modifiers
    ls = label.split(' ')
    if '[' in ls[-1]:
        mod = ls[-1].lstrip('[').rstrip(']')
        if mod[:1] == 'a':
            or_join = False
        if mod[1:2] == 'v':
            horizontal = False
        if '!' in mod:
            double = True
        if '*' in mod:
            red = True
        label = ' '.join(ls[:-1])
    node_id = '.'.join(map(str, levels))
    e = [node_id, label, [or_join, horizontal, double, red, triple]]
    branches.append(node_id.split('.'))
    st.append(e)

# Fill the tree: each branch is a list of level keys, e.g. ['1', '2']
tr = tree()
for branch in branches:
    add_tree(tr, branch)

# Leaf nodes have no children to join, so their join flag (index 0 of the
# modifier list) is cleared to None.
leaf_ids = set(leaf_nodes(tr))
for e in st:
    if e[0] in leaf_ids:
        e[2][0] = None

# Build a mapper between the levels and the names, a mapper between the
# levels and the modifiers, and the list of Graphviz node definitions
at_nodes_mapper = {}
at_mod_mapper = {}
at_nodes = []
for e in st:
    den = e[0]
    if not e[0]:
        den = 0  # the root item has an empty level path
    name = '%s_%s' % (prefix, den)
    # Small "<prefix><br/><levels>" caption, emitted as the node's xlabel
    label = '<<FONT POINT-SIZE="9">%s<br/>%s</FONT>>' % (prefix, den)
    (or_join, horizontal, double, red, triple) = e[2]
    # Keep the style list in its own variable: reusing the name `st` would
    # replace the global parsed node list with a string and break the Org
    # table output that iterates `st` later.
    node_style = [style] if style else []
    if red:
        node_style.append('filled')
    attrs = ['shape=%s' % gshape]
    if node_style:
        attrs.append('style="%s"' % ','.join(node_style))
    # Modifiers combine: a node can be both double-bordered and red
    if double:
        attrs.append('peripheries=2')
    if red:
        attrs.append('fillcolor=red')
    # Escape double quotes so the description cannot end the DOT string early
    attrs.append('label="%s"' % e[1].replace('"', '\\"'))
    attrs.append('xlabel=%s' % label)
    at_nodes.append('"%s" [%s];' % (name, ', '.join(attrs)))
    at_nodes_mapper[den] = name
    at_mod_mapper[den] = e[2]

#### Graphviz output
# Graphviz defaults
# Default attributes for the join helper nodes: tiny, borderless, unlabeled
or_node = 'node [shape=%s, height=.0001, width=.0001, penwidth=0, label=""]' % gshape
# Edge attributes for OR joins (dashed) and AND joins (default line); the
# %s is the edge weight
or_style = '[style=dashed, weight=%s];'
and_style = '[weight=%s];'
# Edge attributes for the arrow from a helper node to the parent node
full_style = '[dir=full, arrowhead=normal, weight=1000];'
# %s: comma-separated list of helper node names to declare
or_node_def_templ = or_node + ' %s;'
# First %s: an "a -> b -> c" chain of helper nodes; second %s: edge weight
or_line_templ = '%s ' + or_style
and_line_templ = '%s ' + and_style
# %s: comma-separated list of nodes placed on the same rank
rank = '{rank=same; %s;}'

# Print one Graphviz subgraph joining a node to its children
def print_graph(path, nodes, or_join=True, horizontal=True):
    """
    Print the Graphviz cluster that connects the children of one tree node
    to that node through a chain of small helper ("or") nodes.

    path: dotted level path of the parent node; replaced by 0 when the
          children list starts with the root key '', which is the key the
          root is stored under in at_nodes_mapper
    nodes: child keys of the parent, in tree order; nothing is printed when
           this list is empty
    or_join: True uses the dashed OR edge style, False the AND edge style
    horizontal: True puts all helpers on one rank and all children on
                another, and attaches the parent to the middle helper;
                False puts each child on the same rank as its own helper,
                chains the helpers, and attaches the parent to the last one
    """
    if nodes:
        # NOTE(review): `root` is assigned but never read below.
        root = False
        if nodes[0] == '':
            # Top-level call: the '' key is the root item itself, so it is
            # dropped from the children and the root's key 0 is used.
            root = True
            path = 0
            nodes = nodes[1:]
        weight = 1
        line_templ = or_line_templ
        join_style = or_style
        if not or_join:
            line_templ = and_line_templ
            join_style = and_style
        num_or_nodes = len(nodes)
        # Quoted child node names: "N_1" at the top level, "N_1.2" below it
        nodes_in = ['"%s_%s"' % (prefix, n) for n in nodes]
        if path:
            nodes_in = ['"%s_%s.%s"' % (prefix, path, n) for n in nodes]
        # One helper node per child
        or_nodes_in = ['"or%s_%s_%s"' % (cluster_prefix, path, e) for e in range(0, num_or_nodes)]
        print('subgraph "cluster_%s%s" {' % (cluster_prefix, path))
        # Edge attributes for the child -> helper edges (weight 200)
        extra_join_style = join_style % (weight*200)
        # NOTE(review): the formatted join_style is not used after this line.
        join_style = join_style % weight
        if horizontal:
            print('# Horizontal')
            if len(nodes)%2 == 0: # Add to even
                # An even number of children gets one extra helper so that
                # there is a middle helper for the parent edge.
                num_or_nodes += 1
                or_nodes_in = ['"or%s_%s_%s"' % (cluster_prefix, path, e) for e in range(0, num_or_nodes)]
            print(or_node_def_templ % ', '.join(or_nodes_in))
            print(line_templ % (' -> '.join(or_nodes_in), weight*100))
            print(rank % ', '.join(or_nodes_in))
            print(rank % ', '.join(nodes_in))
            # Middle helper; the edge to the parent starts here
            spare = or_nodes_in[int(len(nodes)/2)]
            if len(nodes)%2 == 0:
                # Drop the extra middle helper so the remaining helpers pair
                # up one-to-one with the children.
                or_nodes_in = or_nodes_in[:int(len(or_nodes_in)/2)] + or_nodes_in[int(len(or_nodes_in)/2)+1:]
            assert(len(nodes_in) == len(or_nodes_in))
            for i, n in enumerate(nodes_in):
                print('%s -> %s %s' % (n, or_nodes_in[i], extra_join_style))
            print('%s -> "%s" %s' % (spare, at_nodes_mapper[path], full_style))
        else:
            print('# Vertical')
            print(or_node_def_templ % ', '.join(or_nodes_in))
            print(line_templ % (' -> '.join(or_nodes_in), weight*700))
            assert(len(nodes_in) == len(or_nodes_in))
            # Each child shares a rank with its own helper node
            for i, n in enumerate(nodes_in):
                print(rank % ', '.join([n, or_nodes_in[i]]))
            for i, n in enumerate(nodes_in):
                print('%s -> %s %s' % (n, or_nodes_in[i], extra_join_style))
            # The last helper in the chain connects to the parent
            print('%s -> "%s" %s' % (or_nodes_in[-1], at_nodes_mapper[path], full_style))
        print('}')
        print()

# ASCII-style tree visualisation
def pass_tree(tr, k='', me=''):
    """
    Print every node of the tree, depth first, as its dotted path.

    Each path is prefixed by one space and one dash per '.' it contains,
    so deeper nodes are indented further. k and me are only used by the
    recursive calls (current key and dotted prefix built so far).
    """
    me = me + '.' + k
    path = me.lstrip('.')
    depth = path.count('.')
    print(' ' * depth + '-' * depth + path)
    for child_key, subtree in tr.items():
        pass_tree(subtree, child_key, me)

# Recurse through the tree, printing one Graphviz subgraph per node
# NOTE(review): `cluster` is not referenced anywhere in the visible code.
cluster = 0
def pass_graph(tr, k='', me=''):
    """
    Pass through the graph and print it.

    Visits the tree depth first and calls print_graph for every node with
    the node's dotted path, its child keys, and the join/layout modifiers
    recorded for that path. k and me are only used by the recursive calls
    (current key and dotted prefix built so far).
    """
    me += '.' + k
    path = me.lstrip('.')
    try:
        mods = at_mod_mapper[path]
    except KeyError:
        # The root item is registered under key 0 rather than '' (its level
        # path is empty), so paths without an entry use the root's modifiers.
        mods = at_mod_mapper[0]
    # Only the lookup is guarded: errors raised while printing must surface
    # instead of silently re-printing the subgraph with the root modifiers.
    or_join, horizontal = mods[0], mods[1]
    print_graph(path, list(tr.keys()), or_join=or_join, horizontal=horizontal)
    for key in tr.keys():
        pass_graph(tr[key], key, me)

    
if graphviz:
    # Print the Graphviz node definitions, one per line
    for node_line in at_nodes:
        print(node_line)

    # Print the Graphviz subgraph plot data after a blank line
    print()
    pass_graph(tr)


#### Org mode table
# Print an Org mode style table with one row per node for analysis
if org:
    print()
    for row in st:
        den = row[0] or 0  # the root item has an empty level path -> 0
        name = '%s_%s' % (prefix, den)
        # Row: node name with '_' shown as '+', description, three empty cells
        print('|%s|%s||||' % (name.replace('_', '+'), row[1]))
