In [21]:
from copy import deepcopy
from json import load
from math import floor
from pathlib import Path

import networkx as nx
from rich.progress import Progress, BarColumn, TextColumn

In [6]:
# Pre-computed NLP results consumed by MakeGraph below (it reads the
# 'tok/fine' and 'sdp' keys). Built with pathlib so the path resolves on any OS,
# not only on Windows where '\\' is the separator.
DATA_PATH = Path('data_for_exp_5') / 'gi_net' / '0.1_捕风的异乡人.json'
with DATA_PATH.open('r', encoding='utf-8') as f:
    nlp_results = load(f)

In [15]:
class MakeGraph:
    '''Build directed semantic-dependency graphs from parse results.

    ``sdp_result`` must provide two parallel lists under the keys
    ``'tok/fine'`` and ``'sdp'``. Each entry (presumably one sentence) holds
    a list of tokens and, position by position, a list of relations for that
    token. Each relation is indexable as ``rel[0]`` = 1-based head position
    within the same entry (``0`` = root, no edge) and ``rel[1]`` = relation
    label.
    NOTE(review): the layout looks like HanLP multi-task output; confirm
    against the producer.
    '''

    def __init__(self, sdp_result: dict) -> None:
        # One list per entry of (token, relations) pairs.
        self.__sdp_result = [
            list(zip(tokens, sdps))
            for tokens, sdps in zip(sdp_result['tok/fine'], sdp_result['sdp'])
        ]

    def __get_subarray(self,
                       arr,
                       size: int,
                       *,
                       is_soft: bool = False,
                       proportion: float = 1):
        '''Split ``arr`` into consecutive chunks of ``size`` items.

        Every chunk except possibly the last holds exactly ``size`` items.
        A shorter trailing chunk is handled as follows:

        * if ``is_soft`` is true, it is kept when its length is at least
          ``proportion * size``;
        * if ``is_soft`` is false, it is always discarded.

        A full last chunk is never discarded.

        :param arr: any sliceable sequence (list, tuple, str, ...).
        :param size: chunk length; must be positive.
        :param is_soft: whether ``proportion`` decides the fate of a short
            trailing chunk.
        :param proportion: retention threshold for the short trailing chunk,
            as a fraction of ``size``. Values ``>= 1`` always drop it; values
            ``<= 0`` always keep it.
        :returns: list of slices of ``arr``; empty when ``arr`` is empty.
        :raises ValueError: if ``size`` is not positive.
        '''
        if size <= 0:
            raise ValueError(f'size must be a positive integer, got {size!r}')

        threshold = proportion * size if is_soft else size
        chunks = [arr[start:start + size] for start in range(0, len(arr), size)]

        # Only a *short* trailing chunk is subject to the threshold, so
        # proportion > 1 cannot discard a complete chunk.
        if chunks and len(chunks[-1]) < size and len(chunks[-1]) < threshold:
            chunks.pop()
        return chunks

    @staticmethod
    def __add_sentence(graph, sentence) -> None:
        '''Add one entry's tokens as nodes and its relations as edges to ``graph``.

        An edge ``head_token -> token`` carrying ``{'relation': label}`` is
        added for every relation whose head position is not 0. Nodes are token
        strings, so repeated tokens collapse into one node.

        :raises KeyError: if a head position does not exist in ``sentence``.
        '''
        graph.add_nodes_from([token for token, _ in sentence])

        # Head ids are 1-based positions within this same entry.
        token_at = {pos: token for pos, (token, _) in enumerate(sentence, 1)}
        # Built as a full list first, so a bad head id adds none of this
        # entry's edges.
        edges = [
            (token_at[rel[0]], token, {'relation': rel[1]})
            for token, rels in sentence
            for rel in rels
            if rel[0] != 0
        ]
        graph.add_edges_from(edges)

    def run(self, size: int, *, is_soft: bool = True,
            proportion: float = 1) -> 'list[nx.Graph]':
        '''Build one ``nx.DiGraph`` per chunk of ``size`` parsed entries.

        Within a chunk, identical tokens from different entries share a node.
        Because ``DiGraph`` is not a multigraph, a later relation between the
        same (head, token) pair overwrites the earlier ``relation`` attribute.

        :param size: number of entries per graph.
        :param is_soft: forwarded to ``__get_subarray``; whether
            ``proportion`` decides if a short trailing chunk is kept.
        :param proportion: forwarded to ``__get_subarray``.
        :returns: list of directed graphs, one per retained chunk.
        '''
        chunks = self.__get_subarray(self.__sdp_result,
                                     size,
                                     is_soft=is_soft,
                                     proportion=proportion)

        progress = Progress(TextColumn("Generating graphs from SDP..."),
                            BarColumn(bar_width=20))
        graphs = []
        with progress:
            for chunk in progress.track(chunks, description='Working...'):
                graph = nx.DiGraph()
                for sentence in chunk:
                    self.__add_sentence(graph, sentence)
                graphs.append(graph)
        return graphs

In [16]:
# One directed graph per chunk of 10 parsed entries (presumably sentences);
# a short trailing chunk is dropped under run()'s defaults.
mg = MakeGraph(nlp_results)
g = mg.run(size=10)

Output()

In [24]:
def get_graph_prop(graph_list: list[nx.Graph], name_func: dict):
    '''Apply each named graph metric to every graph in ``graph_list``.

    :param graph_list: graphs to measure.
    :param name_func: mapping of display name -> callable taking one graph.
    :returns: one single-key dict per metric, in ``name_func`` iteration order:
        ``[{name: [func(g) for g in graph_list]}, ...]``. Values are stored
        exactly as ``func`` returns them (no defensive deep copy).
    '''
    progress = Progress(BarColumn(bar_width=20))
    results = []

    with progress:
        print('Running calculation...')
        for name, func in progress.track(name_func.items()):
            print(f'{name} is being computed...')
            # A fresh list is built per metric, so it never needs to be
            # copied or cleared before the next metric runs.
            values = [func(graph) for graph in progress.track(graph_list)]
            results.append({name: values})

    return results


Output()

[{'degree cantrality': [{'哇': 0.014285714285714285,
    '——': 0.014285714285714285,
    '那': 0.014285714285714285,
    '就': 0.04285714285714286,
    '是': 0.19999999999999998,
    '「': 0.02857142857142857,
    '七': 0.04285714285714286,
    '天': 0.02857142857142857,
    '神像': 0.05714285714285714,
    '」': 0.02857142857142857,
    '了': 0.02857142857142857,
    '。': 0.05714285714285714,
    '神灵': 0.11428571428571428,
    '的': 0.12857142857142856,
    '造像': 0.04285714285714286,
    '散布': 0.05714285714285714,
    '在': 0.02857142857142857,
    '大陆': 0.04285714285714286,
    '上': 0.014285714285714285,
    '，': 0.09999999999999999,
    '象征': 0.05714285714285714,
    '神': 0.09999999999999999,
    '守护': 0.04285714285714286,
    '世界': 0.014285714285714285,
    '位': 0.08571428571428572,
    '元素': 0.02857142857142857,
    '中': 0.014285714285714285,
    '这': 0.014285714285714285,
    '一': 0.014285714285714285,
    '掌控': 0.02857142857142857,
    '风': 0.11428571428571428,
    '虽然': 0.014285714285714285

In [26]:
# Scratch check: the concrete iterator class that iter() returns for a list.
iter([1, 2, 3]).__class__

list_iterator