In [350]:
import networkx as nx
import matplotlib.pyplot as plt
import random

In [351]:
MIN_PER_RANK = 1  # Nodes/Rank: How 'fat' the DAG should be. (not read by generate_dag)
MAX_PER_RANK = 5  # Width of every rank built by generate_dag (the last rank takes the remainder).
MIN_RANKS = 3     # Ranks: How 'tall' the DAG should be. (not read by generate_dag)
MAX_RANKS = 5     # (not read by generate_dag)
PERCENT = 30      # Chance, in percent, of an edge from an older node to a node of a newer rank.

def generate_dag(min_w, max_w, total_nodes, seed=None):
    """Generate a random layered DAG hanging off a synthetic root node 0.

    Nodes 1..total_nodes are split into consecutive ranks of MAX_PER_RANK
    nodes (the last rank takes the remainder). Each node gets a random integer
    'weight' in [min_w, max_w] (inclusive, via random.randint). When a rank is
    added, every node from earlier ranks gets an edge to every new node with
    probability PERCENT%. Edges only go from lower to higher ids, so the graph
    is acyclic.

    Nodes that ended up with no edges are removed. A root node 0 with weight 0
    is then always added, with an edge to every remaining node of in-degree 0.
    Every node, including the root, gets 'ranges' = [(0, 0)].

    Args:
        min_w, max_w: inclusive bounds of the random node weights.
        total_nodes: number of candidate (non-root) nodes to generate.
        seed: passed to random.seed; None (the default) seeds from system
            entropy, as before. Note this reseeds the global `random` module.

    Returns:
        The generated networkx.DiGraph.
    """
    random.seed(seed)

    G = nx.DiGraph()

    # Widths of the ranks: full ranks of MAX_PER_RANK, plus a final partial one.
    ranks = [min(MAX_PER_RANK, total_nodes - start)
             for start in range(0, total_nodes, MAX_PER_RANK)]

    next_id = 1  # ids start at 1; 0 is reserved for the synthetic root
    for width in ranks:
        new_nodes = range(next_id, next_id + width)
        for node in new_nodes:
            G.add_node(node, weight=random.randint(min_w, max_w))

        # Edges from every node of earlier ranks (ids 1..next_id-1) to each new node.
        for old in range(1, next_id):
            for new in new_nodes:
                if random.randint(0, 99) < PERCENT:
                    G.add_edge(old, new)

        next_id += width

    # Remove nodes that received no edges at all.
    G.remove_nodes_from(list(nx.isolates(G)))

    root_id = 0  # Root node is 0
    sources = [node for node in G.nodes() if G.in_degree(node) == 0]
    # Add the root explicitly: if every node was isolated there are no edges to
    # create it implicitly, and setting its weight would raise KeyError.
    G.add_node(root_id, weight=0)
    G.add_edges_from((root_id, source) for source in sources)

    for node in G.nodes():
        G.nodes[node]['ranges'] = [(0, 0)]

    return G

In [352]:
def draw_graph(G, seed=None):
    """Draw G with a spring layout, labelling each node with its 'weight'.

    The first node found with in-degree 0 (the root, for graphs built by
    generate_dag) is redrawn in red. A graph with no such node (e.g. an empty
    graph) is drawn without the highlight instead of raising IndexError.

    Args:
        G: the networkx graph to draw.
        seed: optional seed forwarded to nx.spring_layout so the layout is
            reproducible; None keeps the previous random layout.
    """
    plt.figure(figsize=(10, 10))
    pos = nx.spring_layout(G, iterations=200, scale=4, seed=seed)
    nx.draw(G, pos, with_labels=False, node_size=500, font_size=10, node_color='skyblue')
    node_labels = nx.get_node_attributes(G, 'weight')
    nx.draw_networkx_labels(G, pos, labels=node_labels)

    root = next((node for node in G.nodes() if G.in_degree(node) == 0), None)
    if root is not None:
        nx.draw_networkx_nodes(G, pos, nodelist=[root], node_color='red', node_size=500)

    plt.show()

In [353]:
# def merge_ranges_and_remove_none(ranges):

#     sorted_ranges = sorted(ranges, key = lambda x: x[0])
#     merged_ranges = [sorted_ranges[0]]
#     for current_range in sorted_ranges[1:]:
#         if current_range[0] <= merged_ranges[-1][1] + 1:
#             merged_ranges[-1] = (merged_ranges[-1][0], max(merged_ranges[-1][1], current_range[1]))
#         else:
#             merged_ranges.append(current_range)
#     return [current_range for current_range in merged_ranges if current_range[0] != 0 or current_range[1] != 0]

def merge_ranges_and_remove_none(ranges):
    """Merge overlapping or adjacent integer ranges and drop the (0, 0) sentinel.

    Args:
        ranges: iterable of inclusive ``(start, end)`` pairs, or None/empty.
            ``(0, 0)`` is the "no ranges" sentinel used elsewhere in this
            notebook and is discarded.

    Returns:
        A new list of disjoint ``(start, end)`` tuples sorted by start. Ranges
        that overlap or touch (next start <= previous end + 1) are coalesced.
        Returns [] when the input is empty or holds only sentinels.
    """
    # Drop the sentinel *before* merging: otherwise (0, 0) touches (1, x) and
    # the two coalesce into a bogus (0, x) that no longer looks like a sentinel.
    real_ranges = sorted((start, end) for start, end in (ranges or ()) if start or end)

    merged = []
    for start, end in real_ranges:
        if merged and start <= merged[-1][1] + 1:  # overlaps or is adjacent
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((start, end))
    return merged


def compute_and_associate_ranges(G, node_id):
    """Compute the ranges of ``node_id`` from its predecessors' stored ranges.

    For every range stored in any predecessor's 'ranges' attribute, with end
    ``end``, the node contributes ``(end + 1, end + weight)``, where ``weight``
    is the node's own 'weight' attribute. The collected ranges are merged with
    merge_ranges_and_remove_none.

    Only the predecessors' *stored* values are read. A predecessor holding the
    ``[(0, 0)]`` sentinel is treated as a range ending at 0.

    Returns:
        The merged list of ranges, or [] when the node has no predecessors.
        The graph is not modified; callers decide whether to store the result.
    """
    predecessors = list(G.predecessors(node_id))
    if not predecessors:
        return []

    weight = G.nodes[node_id]['weight']
    candidate_ranges = [
        (predecessor_range[1] + 1, predecessor_range[1] + weight)
        for predecessor in predecessors
        for predecessor_range in G.nodes[predecessor]['ranges']
    ]
    # Merge once after collecting instead of re-sorting after every predecessor.
    return merge_ranges_and_remove_none(candidate_ranges)

def print_node_ranges(G):
    """Print each node id followed by one tab-indented line per stored range."""
    for node_id in G.nodes():
        report = [f"Node: {node_id}"]
        report.extend(f"\tRange: {rng}" for rng in G.nodes[node_id]['ranges'])
        print("\n".join(report))

def process_graph(G, print_output=False):
    """Compute and store 'ranges' for every non-source node of G, in place.

    Nodes are visited in topological order, so each node's predecessors
    already hold their computed ranges when it is processed. Source nodes
    (in-degree 0, e.g. the root added by generate_dag) keep their stored base
    ranges. The original `[1:]` slice skipped only the first of them and reset
    any other source to [] (starving its descendants).

    Args:
        G: DAG whose nodes carry 'weight' and base 'ranges' attributes.
        print_output: when True, print every node's ranges afterwards.

    Returns:
        The same graph object G, mutated.
    """
    for node in nx.topological_sort(G):
        if G.in_degree(node) == 0:
            continue
        G.nodes[node]['ranges'] = compute_and_associate_ranges(G, node)
    if print_output:
        print_node_ranges(G)
    return G

def remove_half_ranges(G):
    """Return a copy of G whose odd topological generations have their ranges cleared.

    Every node in generations 1, 3, 5, ... gets 'ranges' replaced by the
    [(0, 0)] sentinel. The new lists are assigned on the copy's node
    attributes, not on G's.
    """
    G_half = G.copy()
    generations = list(nx.topological_generations(G_half))
    cleared_nodes = [
        node
        for depth, generation in enumerate(generations)
        if depth % 2 == 1
        for node in generation
    ]
    for node in cleared_nodes:
        G_half.nodes[node]['ranges'] = [(0, 0)]
    return G_half

def get_node_ranges(G, node_id):
    """Return the ranges of ``node_id``, recomputing them if they were cleared.

    A node whose 'ranges' equals the [(0, 0)] sentinel (see remove_half_ranges)
    has its ranges rebuilt from its predecessors, as in
    compute_and_associate_ranges. Any ancestor that was cleared too (and has
    predecessors) is rebuilt first, in topological order. Otherwise its
    sentinel would be read as a real range ending at 0, which gives wrong
    results whenever a cleared node has a cleared parent. Ancestors with no
    predecessors (e.g. the root) keep their stored ranges as the base case.
    The graph is not modified.

    Returns:
        The stored ranges when present, otherwise the recomputed list
        ([] for a cleared node without predecessors).
    """
    if G.nodes[node_id]['ranges'] != [(0, 0)]:
        print("Node already has ranges, returning them...")
        return G.nodes[node_id]['ranges']

    print("Node has no ranges, computing ranges...")
    resolved = {}
    lineage = nx.ancestors(G, node_id) | {node_id}
    # Every predecessor of an ancestor is itself an ancestor, so a topological
    # walk of this subgraph resolves each node's predecessors before the node.
    for node in nx.topological_sort(G.subgraph(lineage)):
        stored = G.nodes[node]['ranges']
        predecessors = list(G.predecessors(node))
        needs_compute = node == node_id or (stored == [(0, 0)] and len(predecessors) > 0)
        if not needs_compute:
            resolved[node] = stored
        elif not predecessors:
            resolved[node] = []
        else:
            weight = G.nodes[node]['weight']
            resolved[node] = merge_ranges_and_remove_none(
                [(pred_range[1] + 1, pred_range[1] + weight)
                 for pred in predecessors
                 for pred_range in resolved[pred]]
            )
    return resolved[node_id]

In [354]:
# Build a DAG from 30 candidate nodes with weights in [1, 10] and store every
# node's ranges; then derive a copy with odd generations' ranges cleared.
G = process_graph(generate_dag(min_w=1, max_w=10, total_nodes=30))
G_half = remove_half_ranges(G)