In [1]:
import networkx as nx

from rich.table import Column
from rich.progress import Progress, BarColumn, TextColumn
from math import floor
from json import load

In [2]:
with open('data_for_exp_5\\gi_net\\0.1_捕风的异乡人.json', 'r', encoding='utf-8') as f:
    nlp_results = load(f)

In [9]:
class MakeGraph:
    def __init__(self, sdp_result:dict) -> None:
        self.__sdp_result = \
            [list(zip(tok, sdp)) for tok, sdp in zip(sdp_result['tok/fine'], sdp_result['sdp'])]

    def __get_subarray(self, 
                       arr, 
                       size:int, 
                       *, 
                       is_all:bool=False, 
                       is_soft:bool=False, 
                       proportion:float=1):
        '''
        `is_all`: if this parameter is set `True`, `arr` will be returned without
        separation.
        `is_soft`: wether trailing chunk (it may not contain enough items, and thus
        cannot be returned as a subarray) is truncated according to the `proportion`
        argument.
        `proportion`: an argument which determines wether the last chunk will be 
        discarded. It must be a float < 1. If the size of the last chunk greater 
        then `proportion` * `size`, and the chunk will be retained; otherwise, 
        discarded.
        '''
        if is_all:
            return arr
        
        if is_soft:
            proportion = proportion
        else:
            proportion = 1
             
        if len(arr) % size == 0:
            split_pos = [i*size for i in range(1, int(len(arr)/size)+1)]
        else:
            split_pos = [i*size for i in range(1, floor(len(arr)/size)+1)] + [len(arr)]
        
        subarrays = []
        start_pos = 0
        for end_pos in split_pos:
            subarrays.append(arr[start_pos:end_pos])
            start_pos = end_pos
         
        if len(subarrays[-1]) < proportion * size:
            subarrays.pop()
        
        return subarrays
    
    def run(self, size:int, *, is_all:bool=False, is_soft:bool=True, proportion:float=1) -> list[nx.Graph]:
        text_column = TextColumn("Generating graphs from SDP...", table_column=Column(ratio=1))
        bar_column = BarColumn(bar_width=None, table_column=Column(ratio=2))
        progress = Progress(text_column, bar_column, expand=True)
        
        subarries = self.__get_subarray(self.__sdp_result,
                                        size,
                                        is_all=is_all,
                                        is_soft=is_soft,
                                        proportion=proportion)
        
        graphs = []
        with progress:
            for sub in progress.track(subarries):
                G = nx.DiGraph()
                for node_semrel_lst in sub:
                    nodes = [node_semrel[0] for node_semrel in node_semrel_lst]
                    G.add_nodes_from(nodes)
                
                    id_node_dict = dict(enumerate(node_semrel_lst, 1))
                
                    edges = []
                    for node, sem_rels in node_semrel_lst:
                        for rel in sem_rels:
                            head_id = rel[0]
                            if head_id != 0:
                                edges.append((id_node_dict[rel[0]][0], node, {'relation': rel[1]}))
                        
                    G.add_edges_from(edges)
                
                graphs.append(G)
        
        return graphs

In [11]:
mg = MakeGraph(nlp_results)
mg.run(1)

Output()

[<networkx.classes.digraph.DiGraph at 0x16b0893fb60>,
 <networkx.classes.digraph.DiGraph at 0x16b0893e450>,
 <networkx.classes.digraph.DiGraph at 0x16b0893f9e0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893e6c0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893e2d0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893f680>,
 <networkx.classes.digraph.DiGraph at 0x16b0893e1e0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893ffb0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893ff80>,
 <networkx.classes.digraph.DiGraph at 0x16b0893ff50>,
 <networkx.classes.digraph.DiGraph at 0x16b0893ff20>,
 <networkx.classes.digraph.DiGraph at 0x16b0893fec0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893fe90>,
 <networkx.classes.digraph.DiGraph at 0x16b0893fe60>,
 <networkx.classes.digraph.DiGraph at 0x16b0893eed0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893fe30>,
 <networkx.classes.digraph.DiGraph at 0x16b0893fce0>,
 <networkx.classes.digraph.DiGraph at 0x16b0893fdd0>,
 <networkx.classes.digraph.D