In [451]:
import os
import copy
from itertools import chain
import json
import pandas as pd
import plotly.graph_objects as go
from top2vec import Top2Vec
import colorcet as cc
import colorsys
import matplotlib.pyplot as plt
from matplotlib.colors import rgb2hex
from pylab import cm

In [52]:
# Load the previously trained Top2Vec model; the path is relative to the
# notebook's working directory.
# NOTE(review): the `.pkl` suffix suggests a pickle-based file - only load
# model files that come from a trusted source.
t2v = Top2Vec.load('../data/models/t2v_211122_100_deep.pkl')

In [6]:
# Reduction levels to visualise. Each value selects a
# `../data/topics/reduction_<value>/` directory in the loader functions, and
# the order given here is the order in which node ids are handed out.
# Presumably each value is the number of topics at that level - TODO confirm.
reductions = [12, 36, 72]

In [112]:
# hierarchy_df = pd.read_csv('../data/topics/topic_hierarchy.tsv', sep='\t', encoding='utf8')

In [111]:
### DEPRECATED

# def get_sizes(reductions):
#     
#     sizes = []
#     
#     for r in reductions:
#         with open(f'../data/topics/reduction_{r}/sizes.json', 'r', encoding='utf8') as f:
#             reduction_sizes = json.load(f)
#             sizes += reduction_sizes
#             
#     return dict(zip(range(sum(reductions)), sizes))

# sizes = get_sizes(reductions)

In [9]:
def get_labels(reduction):
    """Load the default topic labels of every reduction level.

    Parameters
    ----------
    reduction : iterable of int, or int
        Reduction sizes. Each value ``r`` selects the file
        ``../data/topics/reduction_{r}/default_labels.json``. A single int is
        treated as a one-element list. (The parameter keeps its singular name
        so existing calls continue to work.)

    Returns
    -------
    dict
        Integer ids ``0 .. sum(reduction) - 1`` mapped to the labels read from
        the files, in the order of ``reduction`` and then file order. Pairing
        stops at whichever of the two sequences is shorter.
    """
    # Use the argument rather than a notebook-level global, so the function
    # works on a fresh kernel and honours what the caller passes in.
    if isinstance(reduction, int):
        levels = [reduction]
    else:
        # Materialise once: the sizes are iterated here and summed below.
        levels = list(reduction)

    labels = []
    for r in levels:
        with open(f'../data/topics/reduction_{r}/default_labels.json', 'r', encoding='utf8') as f:
            labels.extend(json.load(f))

    return dict(zip(range(sum(levels)), labels))

In [10]:
# Integer node id -> default label, covering every level listed in `reductions`.
labels = get_labels(reductions)

In [113]:
def get_hierarchies(reductions):
    """Read the stored hierarchy file of every reduction level.

    For each entry ``r`` of ``reductions`` the JSON document at
    ``../data/topics/reduction_{r}/reduction_hierarchy.json`` is parsed.

    Returns a list holding one parsed document per entry, in input order.
    """
    def _read_level(size):
        # One file per reduction level, decoded as UTF-8.
        path = f'../data/topics/reduction_{str(size)}/reduction_hierarchy.json'
        with open(path, 'r', encoding='utf8') as handle:
            return json.load(handle)

    return [_read_level(size) for size in reductions]

In [114]:
# One parsed hierarchy document per entry of `reductions`, in the same order.
hierarchies = get_hierarchies(reductions)

In [115]:
def create_sizes(t2v, hierarchies):
    """Compute the total size of every node across all hierarchy levels.

    The groups of all levels are numbered consecutively from 0 in the order
    they appear (level by level). For each group, ``t2v.topic_sizes`` is
    indexed with the whole group and the selected entries are summed.

    Parameters
    ----------
    t2v : object exposing ``topic_sizes``
        ``topic_sizes`` must accept a list of topic ids as an index
        (e.g. a NumPy array or pandas Series).
    hierarchies : list of levels, each a list of groups of topic ids.

    Returns
    -------
    dict mapping node id -> summed size.
    """
    # Work on a private copy so the caller's structure is never touched.
    level_copies = copy.deepcopy(hierarchies)

    sizes = {}
    for node_id, group in enumerate(chain.from_iterable(level_copies)):
        sizes[node_id] = sum(t2v.topic_sizes[group])

    return sizes

In [116]:
# Integer node id -> sum of `t2v.topic_sizes` over the topics grouped in that node.
sizes = create_sizes(t2v, hierarchies)

In [117]:
def create_hierarchy_df(t2v, hierarchies):
    """Build a table linking every last-level node to its ancestors.

    Groups are numbered consecutively from 0 across all levels (first level
    first). For each group of the last level the function walks towards the
    first level, and at every level records the id of each group that is a
    superset of the most recently recorded node.

    Parameters
    ----------
    t2v : unused; kept so existing calls keep working.
    hierarchies : list of levels, each a list of groups of topic ids.

    Returns
    -------
    pandas.DataFrame
        One row per last-level group and one column per level, named
        ``reduction_<number of groups in that level>``, ordered from the first
        level to the last. Cells hold node ids.

    Notes
    -----
    A path gains one id per matching superset. The frame therefore only has
    one id per level when every group is contained in exactly one group of
    each preceding level.
    """
    level_groups = copy.deepcopy(hierarchies)

    # One {node id: group} mapping per level, with ids running on across levels.
    nodes = []
    next_id = 0
    for groups in level_groups:
        nodes.append({next_id + offset: group for offset, group in enumerate(groups)})
        next_id += len(groups)

    depth = len(nodes)
    ancestors_first = []
    for leaf_id in nodes[-1]:
        path = [leaf_id]
        # Walk from the penultimate level back to the first one.
        for candidates in nodes[-2::-1]:
            # The node recorded last lives in the level given by the path length.
            current = set(nodes[depth - len(path)][path[-1]])
            path.extend(
                parent_id
                for parent_id, parent_group in candidates.items()
                if current.issubset(set(parent_group))
            )
        # Paths are collected leaf-first; the table wants ancestors first.
        ancestors_first.append(path[::-1])

    hierarchy_df = pd.DataFrame(ancestors_first)
    hierarchy_df.columns = [f'reduction_{len(level)}' for level in nodes]

    return hierarchy_df

In [118]:
# One row per last-level node, one column per reduction level holding node ids.
df = create_hierarchy_df(t2v, hierarchies)

In [325]:
def hex_to_rgb(hex_code):
    """Convert a hex colour such as ``'#ff8000'`` to an ``(r, g, b)`` int tuple.

    Leading ``#`` characters are optional; the first six hex digits are read
    as three two-digit channels.
    """
    digits = hex_code.lstrip('#')
    channels = []
    for start in range(0, 6, 2):
        channels.append(int(digits[start:start + 2], 16))
    return tuple(channels)

def rgb_to_hex(rgb):
    """Format a 3-tuple of integer channels as a lowercase ``#rrggbb`` string."""
    template = '#%02x%02x%02x'
    return template % rgb

def rgb_to_string(rgb):
    """Render the first three channels of ``rgb`` as a CSS-style ``rgb(r,g,b)`` string."""
    red, green, blue = rgb[0], rgb[1], rgb[2]
    return f'rgb({red},{green},{blue})'

def scale_lightness(color, scale):
    """Lighten or darken a colour by a fixed offset on every channel.

    ``scale * 256`` is added to each channel; the result is rounded and
    clamped to the valid ``0..255`` range.

    Parameters
    ----------
    color : str or tuple
        A hex string (converted with ``hex_to_rgb``) or a tuple of channel values.
    scale : float
        Fraction of the 256-step channel range to add. Positive values
        lighten, negative values darken, ``0`` leaves the colour unchanged.

    Returns
    -------
    str or tuple
        A hex string (via ``rgb_to_hex``) when ``color`` is a string,
        otherwise a tuple of ints.

    Raises
    ------
    TypeError
        If ``color`` is neither a string nor a tuple.
    """
    if isinstance(color, str):
        # Reuse the tuple path so both input kinds share one formula.
        return rgb_to_hex(scale_lightness(hex_to_rgb(color), scale))

    if not isinstance(color, tuple):
        raise TypeError(
            f'color must be a hex string or an RGB tuple, got {type(color).__name__}'
        )

    offset = scale * 256
    # Clamp on both sides: without the lower bound a negative scale produces
    # negative channels, which cannot be rendered as a colour.
    return tuple(max(0, min(round(channel + offset), 255)) for channel in color)

In [452]:
def create_colors(colormap, hierarchy_df, scale):
    """Assign a colour to every node id found in ``hierarchy_df``.

    Nodes in the first column take their colour from ``colormap``, indexed by
    node id. Nodes in each later column take the colour of their parent in
    the preceding column, shifted by ``scale * step`` through
    ``scale_lightness``, where ``step`` is 1 for the second column, 2 for the
    third, and so on.

    Parameters
    ----------
    colormap : sequence of str
        Hex colours, indexable by the node ids of the first column.
    hierarchy_df : pandas.DataFrame
        One column per hierarchy level, ordered from root level to leaf level;
        cells hold node ids.
    scale : float
        Lightness change per level; ``0`` gives children their parent's colour.

    Returns
    -------
    dict mapping node id -> ``'rgb(r,g,b)'`` string.
    """
    # Read the frame passed in, not a notebook-level global, so the result
    # always matches the argument.
    columns = hierarchy_df.columns
    colors = {}

    for top in sorted(hierarchy_df[columns[0]].unique()):
        colors[top] = hex_to_rgb(colormap[top])

    for step, (source_col, target_col) in enumerate(zip(columns[:-1], columns[1:]), start=1):
        for top in hierarchy_df[target_col].unique():
            # The first parent listed for this node decides its colour family.
            root_top = hierarchy_df.loc[hierarchy_df[target_col] == top, source_col].unique()[0]
            colors[top] = scale_lightness(colors[root_top], scale * step)

    return {key: rgb_to_string(col) for key, col in colors.items()}

In [453]:
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
# `plt.get_cmap` is the supported lookup and returns the same colormap.
tab20_cmap = plt.get_cmap('tab20')
# The 20 tab20 colours as hex strings, indexable by top-level node id.
cmap = [rgb2hex(tab20_cmap(i)) for i in range(20)]

In [454]:
# scale=0 adds no lightness offset, so every node keeps the colour of its
# first-level ancestor.
colors = create_colors(cmap, df, 0)

In [455]:
def create_sankey_node(hierarchy_df, labels, colors):
    """Build the ``node`` dictionary for a Sankey trace.

    Parameters
    ----------
    hierarchy_df : unused; kept so existing calls keep working.
    labels : dict
        Node id -> label text. Iteration order defines the node order.
    colors : dict
        Node id -> colour string. Must contain every key of ``labels``.

    Returns
    -------
    dict
        ``label`` (``'<id>_<label>'`` strings) and ``color`` lists aligned
        entry by entry, plus fixed ``pad`` and ``thickness`` settings.

    Raises
    ------
    KeyError
        If a node id in ``labels`` has no entry in ``colors``.
    """
    node_labels = []
    node_colors = []
    for node_id, label in labels.items():
        node_labels.append(str(node_id) + '_' + label)
        # Look the colour up by id: the insertion order of `colors` need not
        # match the order of `labels`, and a positional pairing would paint
        # nodes with another node's colour.
        node_colors.append(colors[node_id])

    return dict(label=node_labels,
                color=node_colors,
                pad=5,
                thickness=100)

In [456]:
def create_sankey_link(hierarchy_df, sizes, labels, colors):
    """Build the ``link`` dictionary for a Sankey trace.

    Every pair of neighbouring columns in ``hierarchy_df`` contributes one
    link per distinct (parent id, child id) combination, ordered by parent id.
    Each link takes its value and colour from the child (target) node.

    Parameters
    ----------
    hierarchy_df : pandas.DataFrame with one column of node ids per level.
    sizes : dict mapping node id -> link value.
    labels : unused; kept so existing calls keep working.
    colors : dict mapping node id -> colour string.

    Returns
    -------
    dict with equally long ``source``, ``target``, ``value`` and ``color`` lists.
    """
    source, target = [], []

    level_columns = hierarchy_df.columns
    for parent_col, child_col in zip(level_columns[0:-1], level_columns[1:]):
        pairs = hierarchy_df[[parent_col, child_col]].drop_duplicates().sort_values(by=parent_col)
        source.extend(list(pairs[parent_col]))
        target.extend(list(pairs[child_col]))

    # Width and colour of a link are those of the node it flows into.
    value = [sizes[node_id] for node_id in target]
    color = [colors[node_id] for node_id in target]

    assert len(source) == len(target) == len(value)

    return dict(source=source, target=target, value=value, color=color)

In [457]:
# Assemble the two halves of the Sankey trace: the nodes (labels and colours)
# and the links between neighbouring hierarchy levels.
node = create_sankey_node(df, labels, colors)
link = create_sankey_link(df, sizes, labels, colors)
            
# Combine nodes and links into a single Sankey trace.
sankey_data = go.Sankey(link=link, node=node)

# Wrap the trace in a figure so layout options can be applied before display.
fig = go.Figure(sankey_data)

In [458]:
# Use a fixed 1000 x 1500 layout instead of letting the figure autosize.
fig.update_layout(
    autosize=False,
    width=1000,
    height=1500)

# Display the finished figure in the notebook output.
fig.show()