In [5]:
from top2vec import Top2Vec
import plotly.graph_objects as go
import copy
from itertools import chain
import json

In [2]:
t2v = Top2Vec.load('../data/models/t2v_211122_100_deep.pkl')

In [3]:
def get_labels(t2v, n):
    
    labels = []
    
    for top_words in t2v.get_topics(reduced=True)[0]:
        labels.append('-'.join(top_words[:n]))
        
    return labels

In [4]:
def get_hierarchies_and_labels(t2v, reductions):
    
    hierarchies = []
    labels = []
    
    for r in reductions:
        print(f'Reducing to {r}')
        t2v.hierarchical_topic_reduction(r)
        hierarchies.append(t2v.get_topic_hierarchy())
        labels.append(get_labels(t2v, 3))
        
    return hierarchies, labels

In [6]:
hierarchies, labels = get_hierarchies_and_labels(t2v, [10, 30, 50])

Reducing to 10
Reducing to 30
Reducing to 50


In [7]:
def create_link(t2v, hierarchies, labels=None, colors=None):
    
    levels = copy.deepcopy(hierarchies)
    topic_dict = {}
    sizes = {}
    
    nodes_flat = dict(enumerate(list(chain(*levels))))
    nodes = []
    
    # makes a list of a dictionary for each level
    for i, level in enumerate(levels):
        nodes.append(dict(list(nodes_flat.items())[len(list(chain(*nodes))):len(list(chain(*nodes)))+len(level)]))
    
    for source_level, target_level in zip(nodes[:-1], nodes[1:]):
        for key, group in source_level.items():
            target_nodes_in_group = []
            for topic in group:
                for tkey, tvalue in target_level.items():
                    if topic in tvalue:
                        target_nodes_in_group.append(tkey)
            source_level[key] = tuple(set(target_nodes_in_group))
            
    #return nodes
                
    weight_dict = {}
    
    for key, group in nodes[-1].items():
        weight_dict[key] = sum([t2v.topic_sizes[top] for top in group])
        
    for level in nodes[-2:0:-1]: # in reverse, all levels except first and last
        for key, group in level.items():
            weight_dict[key] = sum([weight_dict[key2] for key2 in group])
            
            
    sources = []
    targets = []
    weights = []
    
    for level in nodes[:-1]:
        for source, group in level.items():
            for target in group:
                sources.append(source)
                targets.append(target)
                weights.append(weight_dict[target])
    
    link = dict(source = sources, target = targets, value = weights)
    node = dict()
    
    #if labels:
    print(sources)
    print(targets)
        
    
    return link

In [20]:
link = create_link(t2v, hierarchies)
node = dict(label = list(chain(*labels)), pad=5, thickness=100)
data = go.Sankey(link=link, node=node)
fig = go.Figure(data)

[0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 13, 13, 13, 14, 15, 16, 17, 17, 17, 18, 18, 19, 19, 20, 21, 21, 21, 21, 22, 22, 23, 23, 24, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39]
[36, 13, 20, 23, 25, 24, 26, 10, 29, 16, 19, 15, 17, 35, 21, 38, 18, 12, 33, 14, 32, 11, 34, 27, 30, 28, 22, 31, 37, 39, 64, 50, 68, 70, 74, 78, 47, 80, 41, 83, 87, 56, 76, 69, 42, 40, 49, 82, 67, 71, 81, 75, 89, 44, 43, 88, 84, 85, 86, 66, 58, 65, 73, 61, 79, 46, 53, 45, 48, 51, 52, 54, 60, 62, 59, 55, 57, 63, 72, 77]


In [21]:
fig.update_layout(
    autosize=False,
    width=1000,
    height=1500)
fig.show()

In [23]:
with open('../temp/topics_sankey_diagram.html', 'w', encoding='utf8') as f:
    f.write(fig.to_html())

In [324]:
nodes

{0: (18, 5, 6, 14),
 1: (10, 20, 13),
 2: (8, 11, 12, 17, 19),
 3: (24, 21, 15),
 4: (7, 9, 16, 22, 23),
 5: (26, 52, 39),
 6: (25, 74, 58, 69),
 7: (49, 30, 47),
 8: (42, 27),
 9: (33, 45, 62),
 10: (65, 66, 70, 56, 61),
 11: (28, 63),
 12: (32, 34),
 13: (48, 43, 44, 53),
 14: (38, 54, 71),
 15: (67, 31),
 16: (57, 36),
 17: (59, 55),
 18: (73, 50),
 19: (72, 51),
 20: (41, 46),
 21: (35, 60, 68),
 22: (64, 40),
 23: (29,),
 24: (37,),
 25: [83, 151, 169, 0],
 26: [38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9],
 27: [104, 94, 96, 3],
 28: [148, 2],
 29: [322,
  199,
  238,
  273,
  339,
  201,
  192,
  304,
  233,
  257,
  290,
  342,
  218,
  268,
  288,
  215,
  305,
  264,
  191,
  333,
  310,
  345,
  326,
  278,
  281,
  311,
  302,
  213,
  244,
  250,
  253,
  188,
  346,
  255,
  249,
  189,
  348,
  289,
  297,
  237,
  275,
  241,
  243,
  216],
 30: [167, 80, 185, 221, 54, 162, 254, 75, 4],
 31: [63, 1],
 32: [130, 229, 212, 57, 234, 6],
 33: [270,
  261,
  247,
  338,

In [333]:
def create_link(t2v, input_levels):
    
    levels = copy.deepcopy(input_levels)
    topic_dict = {}
    sizes = {}
    n_nodes = sum([len(level) for level in levels])
    
    for i, group in enumerate(levels[-1], start=n_nodes-len(levels[-1])):
        print(i, group)
        for top in group:
            if len(top) > 1:
                topic_dict[top] = i
            else:
            if i in sizes.keys():
                sizes[i] += t2v.topic_sizes[top]
            else:
                sizes[i] = t2v.topic_sizes[top]
        levels[-1][i] = i
            
    #levels[-1] = list(range(len(levels[-1])))
        
        
    for level in levels[:-1]:
        for i, group in enumerate(level):
            level[i] = tuple(set([topic_dict[top] for top in group]))
                
    
    nodes = {}
    ids_used = 0
    
    # assign nmbers to nodes
    for level in levels:
        for i, group in enumerate(level, start=ids_used):
            nodes[i] = group
            ids_used += 1
            
    
    sources = []
    targets = []
    weigths = []
        
    #for key, topics in nodes.items():
    #    for topic
        
        
        
            
    return levels, sizes
        

IndentationError: expected an indented block (2300367558.py, line 14)

In [111]:
def create_link(source_list, target_list, source_labels, target_labels):
    
    label_set = set(source_labels + target_labels)
    label_dict = dict(zip(label_set, range(len(label_set))))
    
    sources = []
    targets = []
    values  = []
    
    # for each target group (more detailed)
    for target_group, target_label in zip(target_list, target_labels):
        
        # for each source group (more general)
        for source_group, source_label in zip(source_list, source_labels):
                       
            # if any of the topics in the target group is in the source group
            if target_group[0] in source_group:
                       
                # ergo all of them are
                sources.append(label_dict[source_label])
                targets.append(label_dict[target_label])
                values.append(sum([t2v.get_topic_sizes()[0][top]/100 for top in target_group]))
                
                continue
                       
                       
    return dict(source = sources, target = targets, value = values)

In [236]:
create_link_data(t2v, [topics_5, topics_20, topics_50])

25 [83, 151, 169, 0]
26 [38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9]
27 [104, 94, 96, 3]
28 [148, 2]
29 [322, 199, 238, 273, 339, 201, 192, 304, 233, 257, 290, 342, 218, 268, 288, 215, 305, 264, 191, 333, 310, 345, 326, 278, 281, 311, 302, 213, 244, 250, 253, 188, 346, 255, 249, 189, 348, 289, 297, 237, 275, 241, 243, 216]
30 [167, 80, 185, 221, 54, 162, 254, 75, 4]
31 [63, 1]
32 [130, 229, 212, 57, 234, 6]
33 [270, 261, 247, 338, 219, 317, 323, 284, 282, 336, 296, 318, 329, 341, 314, 340, 331, 232, 327, 308, 277, 226, 81, 330, 325, 211, 178, 299, 347, 195, 265, 343, 223, 271, 203, 287, 295, 324, 138]
34 [272, 198, 283, 303, 141, 174, 156, 321, 5]
35 [97, 163, 7]
36 [165, 240, 8]
37 [120, 171, 259, 209, 145, 175, 70, 127, 105, 177, 10]
38 [128, 197, 66, 124, 73, 18]
39 [29, 41, 154, 15]
40 [227, 76, 236, 88, 65, 263, 246, 12]
41 [31, 61, 152, 90, 159, 56]
42 [51, 89, 117, 196, 33]
43 [32, 149, 91, 153, 220, 126, 99, 157, 109]
44 [28, 146, 16]
45 [306, 349, 285, 312, 186, 334, 182

TypeError: 'int' object is not iterable

In [240]:
topics_50[-3]

[224, 67, 276, 121, 58]

In [143]:
print(topics_5[0])
print('\n')
print(topics_20[:4])
print('\n')
print(topics_50[:6])

[291, 161, 155, 168, 315, 204, 252, 62, 207, 69, 48, 72, 113, 137, 46, 83, 151, 169, 0, 123, 170, 225, 55, 95, 267, 20, 179, 190, 140, 64, 26, 128, 197, 66, 124, 73, 18, 25, 150, 101, 92, 200, 68, 300, 269, 147, 222, 235, 52, 29, 41, 154, 15, 24, 38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9]


[[29, 41, 154, 15, 24, 38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9], [291, 161, 155, 168, 315, 204, 252, 62, 207, 69, 48, 72, 113, 137, 46, 83, 151, 169, 0], [43, 110, 116, 239, 37, 139, 53, 111, 286, 242, 71, 167, 80, 185, 221, 54, 162, 254, 75, 4], [51, 89, 117, 196, 33, 104, 94, 96, 3]]


[[83, 151, 169, 0], [38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9], [104, 94, 96, 3], [148, 2], [322, 199, 238, 273, 339, 201, 192, 304, 233, 257, 290, 342, 218, 268, 288, 215, 305, 264, 191, 333, 310, 345, 326, 278, 281, 311, 302, 213, 244, 250, 253, 188, 346, 255, 249, 189, 348, 289, 297, 237, 275, 241, 243, 216], [167, 80, 185, 221, 54, 162, 254, 75, 4]]


In [130]:
[1,2,3,4,5][:-1][::-1]

[4, 3, 2, 1]

In [123]:
topics_20

[[29, 41, 154, 15, 24, 38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9],
 [291,
  161,
  155,
  168,
  315,
  204,
  252,
  62,
  207,
  69,
  48,
  72,
  113,
  137,
  46,
  83,
  151,
  169,
  0],
 [43,
  110,
  116,
  239,
  37,
  139,
  53,
  111,
  286,
  242,
  71,
  167,
  80,
  185,
  221,
  54,
  162,
  254,
  75,
  4],
 [51, 89, 117, 196, 33, 104, 94, 96, 3],
 [306,
  349,
  285,
  312,
  186,
  334,
  182,
  307,
  319,
  337,
  202,
  193,
  292,
  230,
  335,
  298,
  313,
  208,
  320,
  251,
  210,
  293,
  260,
  280,
  187,
  274,
  344,
  350,
  231,
  262,
  328,
  316,
  332,
  119,
  248,
  47,
  245,
  27,
  270,
  261,
  247,
  338,
  219,
  317,
  323,
  284,
  282,
  336,
  296,
  318,
  329,
  341,
  314,
  340,
  331,
  232,
  327,
  308,
  277,
  226,
  81,
  330,
  325,
  211,
  178,
  299,
  347,
  195,
  265,
  343,
  223,
  271,
  203,
  287,
  295,
  324,
  138],
 [14,
  100,
  85,
  93,
  164,
  133,
  136,
  78,
  86,
  125,
  36,
  176,
  194,
  112

In [122]:
topics_50

[[83, 151, 169, 0],
 [38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9],
 [104, 94, 96, 3],
 [148, 2],
 [322,
  199,
  238,
  273,
  339,
  201,
  192,
  304,
  233,
  257,
  290,
  342,
  218,
  268,
  288,
  215,
  305,
  264,
  191,
  333,
  310,
  345,
  326,
  278,
  281,
  311,
  302,
  213,
  244,
  250,
  253,
  188,
  346,
  255,
  249,
  189,
  348,
  289,
  297,
  237,
  275,
  241,
  243,
  216],
 [167, 80, 185, 221, 54, 162, 254, 75, 4],
 [63, 1],
 [130, 229, 212, 57, 234, 6],
 [270,
  261,
  247,
  338,
  219,
  317,
  323,
  284,
  282,
  336,
  296,
  318,
  329,
  341,
  314,
  340,
  331,
  232,
  327,
  308,
  277,
  226,
  81,
  330,
  325,
  211,
  178,
  299,
  347,
  195,
  265,
  343,
  223,
  271,
  203,
  287,
  295,
  324,
  138],
 [272, 198, 283, 303, 141, 174, 156, 321, 5],
 [97, 163, 7],
 [165, 240, 8],
 [120, 171, 259, 209, 145, 175, 70, 127, 105, 177, 10],
 [128, 197, 66, 124, 73, 18],
 [29, 41, 154, 15],
 [227, 76, 236, 88, 65, 263, 246, 12],
 [31, 61, 

In [118]:
# viimase kihi gruppidele anna konkreetsed väärtused

In [33]:
for group in topics_20:
    for top in group:
        top_in = 0
        for root in topics_5:
            if top in root:
                top_in += 1
            else:
                pass
        if top_in != 1:
            print(top)
            break
            #print(f'Topic {top} in group {root}')

In [30]:
ex = {}

for group in range(len(topics_20)):
    for top in topics_20[group]:
        for i in range(len(topics_5)):
            if top in topics_5[i]:
                if i in ex.keys():
                    ex[i].append(top)
                else:
                    ex[i] = [top]
            else:
                pass

In [34]:
topics_20[0]

[29, 41, 154, 15, 24, 38, 183, 214, 39, 180, 115, 160, 118, 84, 166, 9]

In [36]:
for top in topics_20[0]:
    if top not in topics_5[0]:
        print(top)