In [1]:
# Copyright (C) 2019 Maxim Godzi, Anatoly Zaytsev, Dmitrii Kiselev
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.


import networkx as nx
import seaborn as sns 
from IPython.display import IFrame, display # TODO understand how to use visualization without it
import numpy as np
from datetime import datetime
import pandas as pd
import json
from functools import wraps

# from MulticoreTSNE import MulticoreTSNE as TSNE

In [2]:
__TEMPLATE__ = """

<!DOCTYPE html>
<meta charset="cp1251">
<style>

circle {{
    fill: rgba(200, 200, 200, 0.5);
    stroke: rgba(200, 200, 200, 0.5);
    stroke-width: 5.0px;
}}

.circle.source_node {{
    fill: #f3f310;
}}

.circle.nice_node {{
    fill: rgba(0, 256, 0, 0.5);
    stroke: rgba(0, 256, 0, 0.9);
}}

.circle.bad_node {{
    stroke: rgba(256, 0, 0, 0.9);
    fill: rgba(256, 0, 0, 0.5);
}}

.link {{
    fill: none;
    stroke: #666;
    stroke-opacity: 0.2;
}}

#nice_target {{
    fill: green;
}}

.link.nice_target {{
    stroke: green;
}}

#source {{
    fill: yellow;
}}

.link.source {{
    stroke: #f3f310;
}}
                
.link.positive {{
    stroke: green;
}}
                
.link.negative {{
    stroke: red;
}}

#source {{
    fill: orange;
}}

.link.source1 {{
    stroke: orange;
}}

bad_target {{
    fill: red;
}}

.link.bad_target {{
    stroke: red;
}}

text {{
    font: 12px sans-serif;
    pointer-events: none;
    #text-shadow: 0 1px 0 #fff, 1px 0 0 #fff, 0 -1px 0 #fff, -1px 0 0 #fff;
}}
                
#mydiv {{
    position: absolute;
    text-align: center;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
    background-color: transparent;
}}
                
.wrapper {{
    position: relative;
    width: 100%;
    height: 100%;
}}
                
.close {{
    position: absolute;
    right: 32px;
    top: 32px;
    width: 32px;
    height: 32px;
    opacity: 0.3;
}}
                
.close:hover {{
    opacity: 1;
}}
                
.close:before,
.close:after {{
    position: absolute;
    left: 15px;
    content: " ";
    height: 33px;
    width: 2px;
    background-color: #333;
}}
                
.close:before {{
    transform: rotate(45deg);
}}
                
.close:after {{
    transform: rotate(-45deg);
}}
                           
</style>

<script src="https://d3js.org/d3.v5.min.js"></script>
<script src="https://www.gstatic.com/charts/loader.js"></script>


<div>
  <input type="radio" class="radio" id="radioButton_mean" name="contact" value="mean"><label> Show mean </label>
  <input type="radio" class="radio" id="radioButton_min" name="contact" value="min"><label> Show min </label>
  <input type="radio" class="radio" id="radioButton_max" name="contact" value="max"><label> Show max </label>
</div>

<div class="parent" draggable={{true}}  id="mydiv"></div>
    <div class="wrapper">
        <div class="close" id="close-1"></div>

<table class="columns">
    </table>
<script>

var links = {links};

var node_params = {node_params};

var node_stats = {node_stats};

var nodes = {nodes};

var width = {width},
    height = {height};

var svg = d3.select("body").append("svg")
    .attr("width", width)
    .attr("height", height);
    
google.charts.load("current", {{packages:["corechart"]}});

console.log(svg.selectAll('.radio'));

let defs = svg.append("g").selectAll("marker")
    .data(links)
  .enter().append("marker")
    .attr("id", function(d) {{ return d.source.index + '-' + d.target.index; }})
    .attr("viewBox", "0 -5 10 10")
    .attr("refX", function(d) {{
        if (d.target.name !== d.source.name) {{
            return 7 + d.target.degree; 
        }} else {{
            return 0;
        }}
    }})
    .attr("refY", calcMarkers)
    .attr("markerWidth", 10)
    .attr("markerHeight", 10)
    .attr("markerUnits", "userSpaceOnUse")
    .attr("orient", "auto");

defs.append("path")
    .attr("d", "M0,-5L10,0L0,5");

function calcMarkers(d) {{
    let dist = Math.sqrt((nodes[d.target.index].x - nodes[d.source.index].x) ** 2 + (nodes[d.target.index].y - nodes[d.source.index].y) ** 2);
    if (dist > 0 && dist <= 200){{
        return - Math.sqrt((0.5 - (d.target.degree ) / 2 / dist)) * (d.target.degree) / 2;
    }} else {{
        return 0;
    }}
}}


var path = svg.append("g").selectAll("path")
    .data(links)
  .enter().append("path")
    .attr("class", function(d) {{ return "link " + d.type; }})
    .attr("stroke-width", function(d) {{ return Math.max(d.weight * 20, 1); }})
    .attr("marker-end", function(d) {{ return "url(#" + d.source.index + '-' + d.target.index + ")"; }})
    .attr("id", function(d,i) {{ return "link_"+i; }})
    .attr("parent_pk", function(d) {{ return d.source.pk; }})
    .attr("d", linkArc)
    ;

var edgetext = svg.append("g").selectAll("text")
    .data(links)
   .enter().append("text")
   .append("textPath")
    .attr("xlink:href",function(d,i){{return "#link_"+i;}})
    .style("text-anchor","right")
    .attr("startOffset", "40%")
    ;
    
function update() {{
    d3.selectAll(".radio").each(function(d) {{
        cb = d3.select(this);
        
        if (cb.property("checked")) {{
         radio_text = this.value;
         
            edgetext = edgetext.text(function(d) {{
                if (radio_text == "mean") {{
                    return d.weight_mean;
                }}
                
                 if (radio_text == "min") {{
                    return d.weight_min;
                }}
                
                 if (radio_text == "max") {{
                    return d.weight_max;
                }}
            }})
        }} 
        
        
    }})
}};

d3.selectAll(".radio").on("change", update);

function dragstarted(d) {{
  d3.select(this).raise().classed("active", true);
}}

function dragged(d) {{
  d3.select(this).attr("cx", d.x = d3.event.x).attr("cy", d.y = d3.event.y);
}}

function dragended(d) {{
  d3.select(this).classed("active", false);
  path = path.attr("d", linkArc);
  text = text
        .attr('x', function(d) {{ return d.x; }})
        .attr('y', function(d) {{ return d.y; }})
        ;
  defs = defs.attr("refY", calcMarkers);
  defs.append("path")
    .attr("d", "M0,-5L10,0L0,5");
}};

var circle = svg.append("g").selectAll("circle")
    .data(nodes)
  .enter().append("circle")
    .attr("class", function(d) {{ return "circle " + d.type; }})
    .attr("r", function(d) {{ return d.degree; }})
    .attr('cx', function(d) {{ return d.x; }})
    .attr('cy', function(d) {{ return d.y; }})
    .attr('id', function(d) {{ return d.pk; }})
    .call(d3.drag()
        .on("start", dragstarted)
        .on("drag", dragged)
        .on("end", dragended));

var text = svg.append("g").selectAll("text")
    .data(nodes)
  .enter().append("text")
    .attr('x', function(d) {{ return d.x; }})
    .attr('y', function(d) {{ return d.y; }})
    .text(function(d) {{ return d.name; }});

function linkArc(d) {{
  var dx = nodes[d.target.index].x - nodes[d.source.index].x,
      dy = nodes[d.target.index].y - nodes[d.source.index].y,
      dr = dx * dx + dy * dy;
      dr = Math.sqrt(dr);
      if (dr > 200) {{
        dr *= 5
      }} else {{
        dr *= 5 /*dr /= 2*/
      }};
      if (dr > 0) {{return "M" + nodes[d.source.index].x + "," + nodes[d.source.index].y + "A" + (dr * 1.1) + "," + (dr * 1.1) + " 0 0,1 " + nodes[d.target.index].x + "," + nodes[d.target.index].y;}}
      else {{return "M" + nodes[d.source.index].x + "," + nodes[d.source.index].y + "A" + 20 + "," + 20 + " 0 1,0 " + (nodes[d.target.index].x + 0.1) + "," + (nodes[d.target.index].y + 0.1);}}
}}

function downloadLayout() {{
    var a = document.createElement("a");
    var file = new Blob([JSON.stringify(nodes)], {{type: "text/json;charset=utf-8"}});
    a.href = URL.createObjectURL(file);
    a.download = "node_params.json";
    a.click();
}}

function draw_channels(ind) {{
       
    console.log([['Ветка', 'Процент']].concat(node_stats[ind]));
        
    var data = google.visualization.arrayToDataTable([['Ветка', 'Процент']].concat(node_stats[ind][0]));
 
 
    var options = {{pieResidueSliceLabel: 'Остальное',
                       width:500,
                       height:300,
                       backgroundColor:'transparent'
                       }};

    const parent = document.getElementById("mydiv");
    const new_parent = document.createElement("div");
        
    new_parent.setAttribute("draggable", "true")
    new_parent.setAttribute("id", "mydiv")
    document.body.appendChild(new_parent);
        
    dragElement(new_parent);
        
    const newChild = document.createElement("mydivheader_clone")
    newChild.setAttribute("mydivheader", "mydivheader_clone")
    new_parent.appendChild(newChild);
        
    var chart = new google.visualization.PieChart(newChild);
    chart.draw(data, options)
        
    const close = document.getElementById("close-1");

    close.onclick = e => {{
    const parent = document.getElementById("mydiv");
    parent.remove();
}};
        
}};
      
dragElement(document.getElementById("mydiv"));

function dragElement(elmnt) {{
  var pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
  if (document.getElementById(elmnt.id + "header")) {{
    document.getElementById(elmnt.id + "header").onmousedown = dragMouseDown;
  }} else {{
    elmnt.onmousedown = dragMouseDown;}}
    
function dragMouseDown(e) {{
  e = e || window.event;
  e.preventDefault();
  
  pos3 = e.clientX;
  pos4 = e.clientY;
  
  document.onmouseup = closeDragElement;
  document.onmousemove = elementDrag;
}}

function elementDrag(e) {{
  e = e || window.event;
  e.preventDefault();
  
  pos1 = pos3 - e.clientX;
  pos2 = pos4 - e.clientY;
  pos3 = e.clientX;
  pos4 = e.clientY;
  
  elmnt.style.top = (elmnt.offsetTop - pos2) + "px";
  elmnt.style.left = (elmnt.offsetLeft - pos1) + "px";
}}

function closeDragElement() {{
  document.onmouseup = null;
  document.onmousemove = null;
}}

}}


function set_visibility(next_style, id)
{{
    if (node_stats[id] != undefined) {{
 
svg.selectAll("circle").filter(obj => {{return node_stats[id][1].indexOf(obj.pk) != -1}}).style('visibility',next_style);
path.filter(obj => {{return String(id) == obj.parent_pk}}).style('visibility',next_style);

path.filter(obj => {{ return obj.source.pk == id}}).style('visibility', next_style);

text.filter(obj => {{return node_stats[id][1].indexOf(obj.pk) != -1}}).style('visibility', next_style);

edgetext.filter(obj => {{return obj.source.pk == id}}).style('visibility', next_style);

}}

}}
 
 
svg.selectAll("circle").on("dblclick", function () 
{{

console.log(node_stats[this.id]);

draw_channels(this.id);

curr_style = svg.selectAll("circle").filter(obj => {{return node_stats[this.id][1].indexOf(obj.pk) != -1}})._groups[0][0].style['visibility'];

if (curr_style == 'visible')
next_style = 'hidden';
else
next_style = 'visible';

let to_process = [this.id];
let all_children = [];

while (to_process.length > 0)
{{
    curr_id = to_process.pop();
    if (node_stats[curr_id] != undefined)
    {{
        for (var i = 0; i < node_stats[curr_id][1].length; i++)
        {{
            to_process.push(node_stats[curr_id][1][i]);
        }}
    }}    
    all_children.push(curr_id);
}}

 for (var i = 0; i < all_children.length; i++)
{{
    set_visibility(next_style, all_children[i]);
}}


}});



</script>
"""

In [43]:
def _calc_layout(data, node_params, width=500, height=500, **kwargs):
    G = nx.DiGraph()
    G.add_weighted_edges_from(data.loc[:, ['source', 'target', 'weight']].values)

    pos_new = {'РБ': [-0.2,0.0]}
    level_to_used_pos = {}
    
    for i, row in data.iterrows():
        count_of_level = data[data['level'] == row['level']].shape[0]
        count_of_level_parent = data[data['source'] == row['source']].shape[0]
        
        key = row['source'] + "-"+str(row['level'])
        used_pos = level_to_used_pos.get(key, [])
        
        prev_pos = 0
        if len(used_pos) != 0:
            prev_pos = used_pos[-1]
        
        new_pos = prev_pos + 1.0 / (count_of_level + 1)
        used_pos.append(new_pos)
        level_to_used_pos[key] = used_pos
        
        level = row['level']
        if level == 1:
            level = 0.6
        
        pos_new[row['target']] = [level, pos_new[row['source']][1] + new_pos]

        
    min_x = min([j[0] for i, j in pos_new.items()])
    min_y = min([j[1] for i, j in pos_new.items()])
    max_x = max([j[0] for i, j in pos_new.items()])
    max_y = max([j[1] for i, j in pos_new.items()])
    pos_new = {
        i: [(j[0] - min_x) / (max_x - min_x) * (width - 300) + 75*0.2, (j[1] - min_y) / (max_y - min_y) * (height - 100) + 50]
        for i, j in pos_new.items()
    }
        # pos_new.update({i: [j[0] * width, j[1] * height] for i, j in pos.items()})
    return pos_new, dict(G.degree)

def get_node_type(node):
    node_type = 'suit'
    if 'активный' in node.lower():
        node_type = "nice_node"
    
    if 'неактивный' in node.lower() or 'аннулир' in node.lower() or 'блокир' in node.lower() or 'закрыт' in node.lower() :
        node_type = "bad_node"
    return node_type
        
def _prepare_nodes(data, pos, node_params, degrees, node2ind):
    node_set = set(data['source']) | set(data['target'])
    max_degree = max(degrees.values())
    nodes = {}

    for idx, node in enumerate(node_set):
        stats = ''
        if data[data['target'] == node].shape[0] != 0:
            stats = data[data['target'] == node]['weight'].values[0]
        
        node_pos = pos.get(node)
        node_type = get_node_type(node)
        
        nodes.update({node: {
            "pk": node2ind[node],
            "index": idx,
            
            "name": node + str(stats),
            "x": node_pos[0],
            "y": node_pos[1],
            "type":node_type,
            "degree": 7#(abs(degrees.get(node, 0)) + 1/120) / abs(max_degree) * 30
        }})
    return nodes


def _prepare_edges(data, nodes):
    edges = []
    data['weight_norm'] = data.weight# / data.weight.abs().max()
    for idx, row in data.iterrows():
#         print('row.weight', row.weight)


        next_node = row['target']

        edge_type = ''
        next_node_type = get_node_type(next_node)
        if next_node_type == 'nice_node':
            edge_type = 'nice_target'
        if next_node_type == 'bad_node':
            edge_type = 'bad_target'
    
    
        edges.append({
            "source": nodes.get(row.source),
            "target": nodes.get(row.target),
            "weight": 0.7,# np.log1p(row.weight_norm) * 1.5,
            "weight_mean": row.weight_mean,
            "weight_min": row.weight_min,
            "weight_max": row.weight_max,
            "type": edge_type
        })
    return edges, list(nodes.values())


def _filter_edgelist(data, thresh, node_params, targets=None, **kwargs):
    if targets is None:
        x = pd.Series(node_params).str.contains('target')
        targets = set(x[x].index)
    f = np.ones(data.shape[0]) > 0#data.weight.abs() >= thresh
    nodes = set(data[f].source) | set(data[f].target)
    f |= (data.source.isin(targets) & data.target.isin(nodes))
    f |= (data.target.isin(targets) & data.source.isin(nodes))
    return data[f].copy()


def _make_json_data(data, node_params,  node2ind, layout_dump, thresh=.05, width=500, height=500, **kwargs):
    res = {}
    #data.columns = ['source', 'target', 'weight', 'level']
    data = _filter_edgelist(data, thresh, node_params, **kwargs)
    if kwargs.get('mode') == 'importance':
        data['type'] = np.where(data.weight >= 0, 'positive', 'negative')
    else:
        data["type"] = data.apply(
            lambda x: node_params.get(x.source) if node_params.get(x.source) == 'source' else node_params.get(
                x.target) or 'suit',
            1
        )
    pos, degrees = _calc_layout(data, node_params, width=width, height=height, **kwargs)
    if kwargs.get('node_weights') is not None:
        degrees = kwargs.get('node_weights')
    if layout_dump is not None:
        nodes = _prepare_given_layout(layout_dump, node_params, degrees)
    else:
        nodes = _prepare_nodes(data, pos, node_params, degrees, node2ind)
    res['links'], res['nodes'] = _prepare_edges(data, nodes)
    return res


def _prepare_node_params(node_params, data):
    if node_params is None:
        _node_params = {
            'positive_target_event': 'nice_target',
            'negative_target_event': 'bad_target',
            'source_event': 'source',
        }
        node_params = {}
        for key, val in _node_params.items():
            name = data.retention.retention_config.get(key)
            if name is None:
                continue
            node_params.update({name: val})
    return node_params

def _prepare_node_stats(data):
    result = {}
    for i, row in data.iterrows():
        children = data[data['parent_index'] == row['parent_index']]
        
        children_data = []
        children_ids = []
        
        for j, child_row in children.iterrows():
            children_data.append([child_row['target'], float(str(child_row['weight']).split('%', 1)[0].strip())])
            children_ids.append(child_row['my_index'])
        
        result[row['parent_index']] = [children_data, children_ids]
    
    return result

def _prepare_layout(layout):
    nodes = {}
    for i in layout:
        nodes.update({i['name']: i})
    return nodes


def _prepare_given_layout(nodes_path, node_params, degrees):
    if type(nodes_path) is str:
        with open(nodes_path, encoding='utf-8') as f:
            nodes = json.load(f)
    else:
        nodes = nodes_path
    if type(nodes) is list:
        nodes = _prepare_layout(nodes)
    max_degree = max(degrees.values() or [1e-20])
    for node, val in nodes.items():
        val.update({
            "type": (node_params.get(node) or "suit").split('_')[0] + '_node',
            "degree": degrees.get(node, 0) / max_degree * 30
        })
    return nodes


def __save_plot__(func):
    @wraps(func)
    def save_plot_wrapper(*args, **kwargs):
        sns.mpl.pyplot.close()
        res = func(*args, **kwargs)
        if len(res) == 2:
            (vis_object, name), res = res, None
        else:
            vis_object, name, res = res
        idx = 'id: ' + str(int(datetime.now().timestamp()))
        coords = vis_object.axis()
        vis_object.text((coords[0] - (coords[1] - coords[0]) / 10), (coords[3] + (coords[3] - coords[2]) / 10), idx, fontsize=8)
        vis_object.get_figure().savefig(name, bbox_inches="tight", dpi=600)
        return res
    return save_plot_wrapper


@__save_plot__
def graph(data, node_params=None, node2ind=None, thresh=.05, width=500, height=500, interactive=True,
          layout_dump=None, show_percent=True, plot_name=None, **kwargs):
    """
    Plots graph by its edgelist representation

    :param data: graph in edgelist form
    :param node_params: mapping describes which node should be highlighted by target or source type
            Node param should be represented in the following form
            ```{
                    'lost': 'bad_target',
                    'passed': 'nice_target',
                    'onboarding_welcome_screen': 'source',
                }```
            If mapping is not given, it will be constracted from config
    :param thresh: threshold for filtering low frequency edges
    :param width: width of plot
    :param height: height of plot
    :param interactive: if True, then opens graph visualization in Jupyter Notebook IFrame
    :param layout_dump: path to layout dump
    :param show_percent: if True, then all edge weights are converted to percents
    :param kwargs: do nothing, needs for plot.graph usage with other functions
    :return: saves to `experiments_folder` webpage with js graph visualization
    """
    if node_params is None:
        node_params = _prepare_node_params(node_params, data)
    
    node_stats = _prepare_node_stats(data)
    
    res = _make_json_data(data, node_params, node2ind, layout_dump, thresh=thresh,
                          width=width - 100, height=height - 100, **kwargs)
    x = __TEMPLATE__.format(
        width=width,
        height=height,
        links=json.dumps(res.get('links')).encode('latin1').decode('utf-8'),
        node_params=json.dumps(node_params).encode('latin1').decode('utf-8'),
        nodes=json.dumps(res.get('nodes')).encode('latin1').decode('utf-8'),
        node_stats=json.dumps(node_stats).encode('latin1').decode('utf-8'),
        show_percent="1 !== 1" if show_percent else "1 === 1"
    )
    if hasattr(data, 'trajectory'):
        if plot_name is None:
            plot_name = f'{data.trajectory.retention_config["experiments_folder"]}/index_{datetime.now()}'
    else:
        if plot_name is None:
            plot_name = 'index'
    plot_name = plot_name.replace(':', '_').replace('.', '_') + '.html'
    
#     with open('test.html', 'w') as f:
#         f.write(x)
    
    return ___DynamicFigureWrapper__(x, interactive, width, height), plot_name, plot_name


@__save_plot__
def step_matrix(diff, plot_name=None, title='', vmin=None, vmax=None, **kwargs):
    """
    Plots heatmap with distribution of events over event steps (ordering in the session by event time)

    :param diff: table for heatmap visualization
    :param plot_name: name of plot to save
    :param kwargs: do nothing, needs for plot usage with other functions
    :return: saves heatmap to `experiments_folder`
    """
    sns.mpl.pyplot.figure(figsize=(20, 10))
    heatmap = sns.heatmap(diff, annot=True, cmap="BrBG", center=0, vmin=vmin, vmax=vmax)
    heatmap.set_title(title)
    plot_name = 'desc_table_{}.png'.format(plot_name or datetime.now()).replace(':', '_').replace('.', '_')
    plot_name = diff.retention.retention_config['experiments_folder'] + '/' + plot_name
    return heatmap, plot_name


@__save_plot__
def cluster_tsne(data, clusters, target, plot_name=None, **kwargs):
    """
    Plots TSNE projection of user stories and colors by founded clusters

    :param data: feature matrix
    :param clusters: np.array of cluster ids
    :param target: do nothing, need for compatibility with other cluster visualization methods
    :param plot_name: name of plot to save
    :param kwargs: do nothing, needs for plot usage with other functions
    :return: saves plot to `experiments_folder`
    """

    if hasattr(data.retention, '_tsne'):
        tsne2 = data.retention._tsne.copy()
    else:
        tsne2 = data.retention.learn_tsne(clusters, **kwargs)
    tsne = tsne2.values
    if np.unique(clusters).shape[0] > 10:
        f, ax = sns.mpl.pyplot.subplots()
        points = ax.scatter(tsne[:, 0], tsne[:, 1], c=clusters, cmap="BrBG")
        f.colorbar(points)
        scatter = ___FigureWrapper__(f)
    else:
        scatter = sns.scatterplot(tsne[:, 0], tsne[:, 1], hue=clusters, legend='full', palette="BrBG")
    plot_name = plot_name if plot_name is not None else 'clusters_tsne_{}.svg'.format(
        datetime.now()).replace(':', '_').replace('.', '_')
    plot_name = data.retention.retention_config['experiments_folder'] + '/' + plot_name
    return scatter, plot_name, tsne2


@__save_plot__
def cluster_bar(data, clusters, target, plot_name=None, plot_cnt=None, metrics=None, **kwargs):
    cl = pd.DataFrame([clusters, target], index=['clusters', 'target']).T
    cl['cnt'] = 1
    cl.target = cl.target.astype(int)
    bars = cl.groupby('clusters').agg({
        'cnt': 'sum',
        'target': 'mean'
    }).reset_index()
    bars.cnt /= bars.cnt.sum()
    bars = bars.loc[:, ['clusters', 'cnt']].append(bars.loc[:, ['clusters', 'target']], ignore_index=True, sort=False)
    bars['target'] = np.where(bars.target.isnull(), bars.cnt, bars.target)
    bars['Metric'] = np.where(bars['cnt'].isnull(), 'Average CR', 'Cluster size')
    bar = sns.barplot(x='clusters', y='target', hue='Metric', hue_order=['Cluster size', 'Average CR'], data=bars)
    y_value = ['{:,.2f}'.format(x * 100) + '%' for x in bar.get_yticks()]
    bar.set_yticklabels(y_value)
    bar.set(ylabel=None)

    plot_name = plot_name if plot_name is not None else 'clusters_bar_{}.svg'.format(
        datetime.now()).replace(':', '_').replace('.', '_')
    plot_name = data.retention.retention_config['experiments_folder'] + '/' + plot_name
    return bar, plot_name


@__save_plot__
def cluster_event_dist(bars, event_col, cl1, sizes, crs, cl2=None, plot_name=None):
    bar = sns.barplot(x=event_col, y='freq', hue='hue',
                      hue_order=[f'cluster {cl1}','all' if cl2 is None else f'cluster {cl2}'], data=bars)
    y_value = ['{:,.2f}'.format(x * 100) + '%' for x in bar.get_yticks()]
    bar.set_yticklabels(y_value)
    bar.set_xticklabels(bar.get_xticklabels(), rotation=30)
    bar.set(ylabel=None)
    tit = f'Distribution of top {bars.shape[0] // 2} events in cluster {cl1} (size: {round(sizes[0] * 100, 2)}%, CR: {round(crs[0] * 100, 2)}% ) '
    tit += f'vs. all data (CR: {round(crs[1] * 100, 2)}%)' if cl2 is None else f'vs. cluster {cl2} (size: {round(sizes[1] * 100, 2)}%, CR: {round(crs[1] * 100, 2)}%)'
    bar.set_title(tit)

    plot_name = plot_name if plot_name is not None else 'clusters_event_dist_{}.svg'.format(
        datetime.now()).replace(':', '_').replace('.', '_')
    plot_name = bars.retention.retention_config['experiments_folder'] + '/' + plot_name
    return bar, plot_name


@__save_plot__
def cluster_pie(data, clusters, target, plot_name=None, plot_cnt=None, metrics=None, **kwargs):
    """
    Plots pie-charts of target distribution for different clusters

    :param data: feature matrix
    :param clusters: np.array of cluster ids
    :param target: boolean vector, if True, then user have `positive_target_event` in trajectory
    :param plot_name: name of plot to save
    :param plot_cnt: number of clusters to plot
    :param kwargs: width and height of plot
    :return: saves plot to `experiments_folder`
    """
    cl = pd.DataFrame([clusters, target], index=['clusters', 'target']).T
    cl.target = np.where(cl.target, data.retention.retention_config['positive_target_event'],
                         data.retention.retention_config['negative_target_event'])
    pie_data = cl.groupby(['clusters', 'target']).size().rename('target_dist').reset_index()
    targets = list(set(pie_data.target))

    if plot_cnt is None:
        plot_cnt = len(set(clusters))
    
    if kwargs.get('vol', True):  # vol = False in kwargs in case you want to disable
        _, counts = np.unique(clusters, return_counts=True)
        volumes = 100 * (counts / sum(counts))
    else:
        volumes = [None] * plot_cnt

    fig, ax = sns.mpl.pyplot.subplots(1 if plot_cnt <= 2 else (plot_cnt // 2 + plot_cnt % 2), 2)
    fig.suptitle('Distribution of targets in clusters. Silhouette: {:.2f}, Homogeneity: {:.2f}, Cluster stability: {:.2f}'.format(
        metrics.get('silhouette') if (metrics or {}).get('silhouette') is not None else 0,
        metrics.get('homogen') if metrics is not None else 0,
        metrics.get('csi') if (metrics or {}).get('csi') is not None else 0
    ))
    fig.set_size_inches(kwargs.get('width', 20), kwargs.get('height', 10))
    for i, j in enumerate(pie_data.clusters.unique()):
        tmp = pie_data[pie_data.clusters == j]
        tmp.index = tmp.target
        if plot_cnt <= 2:
            ax[i].pie(tmp.target_dist.reindex(targets).fillna(0).values, labels=targets, autopct='%1.1f%%')
            ax[i].set_title('Class {}\nCluster volume {}%\nMean dist from center {:.2f}'.format(
                i, round(volumes[i], 1), metrics['mean_fc'][j] if (metrics or {}).get('mean_fc') is not None else 0))
        else:
            ax[i // 2][i % 2].pie(tmp.target_dist.reindex(targets).fillna(0).values, labels=targets, autopct='%1.1f%%')
            ax[i // 2][i % 2].set_title('Class {}\nCluster volume {}%\nMean dist from center {:.2f}'.format(
                i, round(volumes[i], 1), metrics['mean_fc'][j] if (metrics or {}).get('mean_fc') is not None else 0))
    if plot_cnt % 2 == 1:
        fig.delaxes(ax[plot_cnt // 2, 1])

    plot_name = plot_name if plot_name is not None else 'clusters_pie_{}.svg'.format(
        datetime.now()).replace(':', '_').replace('.', '_')
    plot_name = data.retention.retention_config['experiments_folder'] + '/' + plot_name
    return ___FigureWrapper__(fig), plot_name


@__save_plot__
def cluster_heatmap(data, clusters, target, plot_name=None, **kwargs):
    """
    Visualizes features for users with heatmap

    :param data: feature matrix
    :param clusters: do nothing, need for compatibility with other cluster visualization methods
    :param target: do nothing, need for compatibility with other cluster visualization methods
    :param plot_name: name of plot to save
    :param kwargs: do nothing, need for compatibility with other cluster visualization methods
    :return: saves plot to `experiments_folder`
    """
    heatmap = sns.clustermap(data.values,
                             cmap="BrBG",
                             xticklabels=data.columns,
                             yticklabels=False,
                             row_cluster=True,
                             col_cluster=False)

    heatmap.ax_row_dendrogram.set_visible(False)
    heatmap = heatmap.ax_heatmap

    plot_name = plot_name if plot_name is not None else 'clusters_heatmap_{}.svg'.format(
        datetime.now()).replace(':', '_').replace('.', '_')
    plot_name = data.retention.retention_config['experiments_folder'] + '/' + plot_name
    return heatmap, plot_name


class ___FigureWrapper__(object):
    def __init__(self, fig):
        self.fig = fig

    def get_figure(self):
        return self.fig

    def axis(self):
        if len(self.fig.axes) > 1:
            x = self.fig.axes[1].axis()
        else:
            x = self.fig.axes[0].axis()
        return (x[0]/ 64, x[0] + (x[1] - x[0]) / 50, x[2] / 1.5, x[3] / 1.5)

    def text(self, *args, **kwargs):
        self.fig.text(*args, **kwargs)


class __SaveFigWrapper__(object):
    def __init__(self, data, interactive=True, width=1000, height=700):
        self.data = data
        self.interactive = interactive
        self.width = width
        self.height = height

    def savefig(self, name, **kwargs):
        with open(name, 'w') as f:
            f.write(self.data)
        if self.interactive:
            display(IFrame(name, width=self.width + 200, height=self.height + 200))


class ___DynamicFigureWrapper__(object):
    def __init__(self, fig, interactive, width, height):
        self.fig = fig
        self.interactive, self.width, self.height = interactive, width, height

    def get_figure(self):
        savefig = __SaveFigWrapper__(self.fig, self.interactive, self.width, self.height)
        return savefig

    def text(self, x, y, text, *args, **kwargs):
        parts = self.fig.split('<body>')
        res = parts[:1] + [f'<p>{text}</p>'] + parts[1:]
        self.fig = '\n'.join(res)

    def axis(self):
        return 4 * [0]
    
kwargs = {
#                 'edge_col': data.retention.retention_config['index_col'],
                'edge_attributes': '_nunique',
                'norm': True,
    'width':1000, 'height':1000
            }

_node_params = {
                'positive_target_event': 'nice_target',
                'negative_target_event': 'bad_target',
                'source_event': 'source',
            }

data = pd.read_csv('test.csv', index_col=0)

data_n = data.rename(columns={'event_name': 'source', 'next_event': 'target', '_nunique': 'weight'})

data_n.weight_mean = data_n.weight_mean.astype(float)

data_n.weight_mean = np.where(data_n.weight_mean.isnull(), '', data_n.weight_mean)
data_n.weight_min = np.where(data_n.weight_min.isnull(), '', data_n.weight_min)
data_n.weight_max = np.where(data_n.weight_max.isnull(), '', data_n.weight_max)

def round_str(val):
    if val == '':
        return val
    else:
        val = float(val)
        return int(np.rint(val))
    
data_n.weight_mean = np.vectorize(lambda x : round_str(x))(data_n.weight_mean)
#data_n.weight_min = np.vectorize(lambda x : round_str(x))(data_n.weight_min)

In [44]:
# data = pd.DataFrame(columns=['event_name', 'next_event', '_nunique'], data=[['a', 'b', 1], ['a', 'c', 2]])

In [45]:
kwargs = {
#                 'edge_col': data.retention.retention_config['index_col'],
                'edge_attributes': '_nunique',
                'norm': True,
                'width':1000, 'height':1000
            }

In [46]:
_node_params = {
                'positive_target_event': 'nice_target',
                'negative_target_event': 'bad_target',
                'source_event': 'source',
            }

In [47]:
data = pd.read_csv('test.csv', index_col=0)

In [48]:
data.head()

Unnamed: 0,event_name,next_event,_nunique,level,weight_mean,weight_min,weight_max
0,РБ,ВСП (ветка всп),"68% (100%, 68%)",0,,,
1,ВСП (ветка всп),Первый платеж (ветка всп),"63% (94%, 94%)",1,23.017266,1.0,392.0
2,ВСП (ветка всп),Неактивный (ветка всп),"3% (5%, 5%)",1,,,
3,ВСП (ветка всп),Закрыт (ветка всп),"1% (1%, 1%)",1,4.823341,0.0,60.0
4,Первый платеж (ветка всп),Неактивный 2 месяца (ветка всп),"0% (0%, 0%)",2,,,


In [49]:
#data_custom = pd.read_csv('stats_for_graph.csv', sep=';', encoding='1251')

In [50]:
#data_custom.head()

In [51]:
data_n = data.rename(columns={'event_name': 'source', 'next_event': 'target', '_nunique': 'weight'})

In [52]:
#data_custom_n = data_custom.rename(columns={'TB_otpr': 'source', 'TB_pol': 'target', 'Сумма_проводки_сonverted_2': 'weight'})

In [53]:
data_n.head()

Unnamed: 0,source,target,weight,level,weight_mean,weight_min,weight_max
0,РБ,ВСП (ветка всп),"68% (100%, 68%)",0,,,
1,ВСП (ветка всп),Первый платеж (ветка всп),"63% (94%, 94%)",1,23.017266,1.0,392.0
2,ВСП (ветка всп),Неактивный (ветка всп),"3% (5%, 5%)",1,,,
3,ВСП (ветка всп),Закрыт (ветка всп),"1% (1%, 1%)",1,4.823341,0.0,60.0
4,Первый платеж (ветка всп),Неактивный 2 месяца (ветка всп),"0% (0%, 0%)",2,,,


In [54]:
#data_custom_n.head()

In [55]:
# data_custom_n['level'] = ''
# data_custom_n['weight_mean'] = ''
# data_custom_n['weight_min'] = ''
# data_custom_n['weight_max'] = ''

In [56]:
data_n.weight_mean.values[0]

nan

In [57]:
data_n.weight_mean = np.where(data_n.weight_mean.isnull(), '', data_n.weight_mean)
data_n.weight_min = np.where(data_n.weight_min.isnull(), '', data_n.weight_min)
data_n.weight_max = np.where(data_n.weight_max.isnull(), '', data_n.weight_max)

In [58]:
data_n

Unnamed: 0,source,target,weight,level,weight_mean,weight_min,weight_max
0,РБ,ВСП (ветка всп),"68% (100%, 68%)",0,,,
1,ВСП (ветка всп),Первый платеж (ветка всп),"63% (94%, 94%)",1,23.017266390233253,1.0,392.0
2,ВСП (ветка всп),Неактивный (ветка всп),"3% (5%, 5%)",1,,,
3,ВСП (ветка всп),Закрыт (ветка всп),"1% (1%, 1%)",1,4.823341326938449,0.0,60.0
4,Первый платеж (ветка всп),Неактивный 2 месяца (ветка всп),"0% (0%, 0%)",2,,,
5,Первый платеж (ветка всп),Блокировка (ветка всп),"1% (1%, 1%)",2,29.84688995215311,0.0,307.0
6,Блокировка (ветка всп),00.01.1900 (ветка всп),"0% (0%, 7%)",3,,,
7,Блокировка (ветка всп),Банкротство (ветка всп),"0% (0%, 17%)",3,,,
8,Блокировка (ветка всп),Налоговая (ветка всп),"0% (0%, 29%)",3,,,
9,Блокировка (ветка всп),Комплаенс (ветка всп),"0% (0%, 47%)",3,,,


In [59]:
def round_str(val):
    if val == '':
        return val
    else:
        val = float(val)
        return int(np.rint(val))

In [60]:
round_str('123.6')

124

In [61]:
data_n.weight_mean = np.vectorize(lambda x : round_str(x))(data_n.weight_mean)
#data_n.weight_min = np.vectorize(lambda x : round_str(x))(data_n.weight_min)

In [62]:
#data_custom_n.weight_mean = np.vectorize(lambda x : round_str(x))(data_custom_n.weight_mean)

In [63]:
node2ind = { name: i for i, name in enumerate(set(data_n['source'].values) | set(data_n['target'].values)) }

In [64]:
#node2ind = { name: i for i, name in enumerate(set(data_custom_n['source'].values) | set(data_custom_n['target'].values)) }

In [65]:
data_n['my_index'] = data_n['target'].map(node2ind)
data_n['parent_index'] = data_n['source'].map(node2ind)

In [66]:
#data_custom_n['my_index'] = data_custom_n['target'].map(node2ind)
#data_custom_n['parent_index'] = data_custom_n['source'].map(node2ind)

In [67]:
data_n

Unnamed: 0,source,target,weight,level,weight_mean,weight_min,weight_max,my_index,parent_index
0,РБ,ВСП (ветка всп),"68% (100%, 68%)",0,,,,25,15
1,ВСП (ветка всп),Первый платеж (ветка всп),"63% (94%, 94%)",1,23.0,1.0,392.0,10,25
2,ВСП (ветка всп),Неактивный (ветка всп),"3% (5%, 5%)",1,,,,22,25
3,ВСП (ветка всп),Закрыт (ветка всп),"1% (1%, 1%)",1,5.0,0.0,60.0,4,25
4,Первый платеж (ветка всп),Неактивный 2 месяца (ветка всп),"0% (0%, 0%)",2,,,,17,10
5,Первый платеж (ветка всп),Блокировка (ветка всп),"1% (1%, 1%)",2,30.0,0.0,307.0,0,10
6,Блокировка (ветка всп),00.01.1900 (ветка всп),"0% (0%, 7%)",3,,,,1,0
7,Блокировка (ветка всп),Банкротство (ветка всп),"0% (0%, 17%)",3,,,,12,0
8,Блокировка (ветка всп),Налоговая (ветка всп),"0% (0%, 29%)",3,,,,6,0
9,Блокировка (ветка всп),Комплаенс (ветка всп),"0% (0%, 47%)",3,,,,11,0


In [68]:
#data_custom_n

In [69]:
_prepare_node_stats(data_n)

{15: [[['ВСП (ветка всп)', 68.0], ['ДРС', 32.0]], [25, 21]],
 25: [[['Первый платеж (ветка всп)', 63.0],
   ['Неактивный (ветка всп)', 3.0],
   ['Закрыт (ветка всп)', 1.0]],
  [10, 22, 4]],
 10: [[['Неактивный 2 месяца (ветка всп)', 0.0],
   ['Блокировка (ветка всп)', 1.0],
   ['Активный после платежа (ветка всп)', 61.0],
   ['Закрытие после платежа (ветка всп)', 2.0]],
  [17, 0, 8, 16]],
 0: [[['00.01.1900 (ветка всп)', 0.0],
   ['Банкротство (ветка всп)', 0.0],
   ['Налоговая (ветка всп)', 0.0],
   ['Комплаенс (ветка всп)', 0.0],
   ['Ликвидация (ветка всп)', 0.0]],
  [1, 12, 6, 11, 24]],
 21: [[['Аннулированы', 11.0], ['ВСП (ветка ДРС) (ветка дрс, всп)', 22.0]],
  [13, 3]],
 3: [[['Первый платеж (ветка дрс, всп)', 21.0],
   ['Неактивный (ветка дрс, всп)', 1.0]],
  [18, 26]],
 18: [[['Неактивный 2 месяца (ветка дрс, всп)', 0.0],
   ['Блокировка (ветка дрс, всп)', 0.0],
   ['Активный после платежа (ветка дрс, всп)', 21.0],
   ['Закрытие после платежа (ветка дрс, всп)', 0.0]],
  [20, 9

In [70]:
#_prepare_node_stats(data_custom_n)

In [71]:
data_n.head()

Unnamed: 0,source,target,weight,level,weight_mean,weight_min,weight_max,my_index,parent_index
0,РБ,ВСП (ветка всп),"68% (100%, 68%)",0,,,,25,15
1,ВСП (ветка всп),Первый платеж (ветка всп),"63% (94%, 94%)",1,23.0,1.0,392.0,10,25
2,ВСП (ветка всп),Неактивный (ветка всп),"3% (5%, 5%)",1,,,,22,25
3,ВСП (ветка всп),Закрыт (ветка всп),"1% (1%, 1%)",1,5.0,0.0,60.0,4,25
4,Первый платеж (ветка всп),Неактивный 2 месяца (ветка всп),"0% (0%, 0%)",2,,,,17,10


In [72]:
data_n.dtypes

source          object
target          object
weight          object
level            int64
weight_mean     object
weight_min      object
weight_max      object
my_index         int64
parent_index     int64
dtype: object

In [73]:
def get_medium_perc(text):
    return float(text.split()[1][1:-2])

In [74]:
get_medium_perc(data_n['weight'].values[0])

100.0

In [75]:
data_n['medium_perc'] = np.vectorize(get_medium_perc)(data_n['weight'])

In [76]:
data_n.shape

(26, 10)

In [77]:
data_n = data_n[data_n['medium_perc'] >= 1]

In [78]:
data_n.shape

(14, 10)

In [80]:
#data = pd.read_csv('test_csv.csv', index_col=0)
a = graph(data_n, _node_params, node2ind, **kwargs)

In [82]:
a

'index.html'

In [40]:
stats_for_graph = pd.read_csv('stats_for_graph.csv', sep=';', encoding='1251')

FileNotFoundError: [Errno 2] File b'stats_for_graph.csv' does not exist: b'stats_for_graph.csv'

In [None]:
#graph(data_n, _node_params, node2ind, **kwargs)