In [None]:
# TODO: add check for cycles involving duplicated blocks that are getting ignored here
def all_paths_from_matrix(edge_matrix, start_idx, end_idx):
    """
    Enumerate every simple path leading from start_idx to end_idx.

    Cycles caused by duplicated blocks are usually not a problem here, because
    such blocks tend to be small insertions that have many low count edges.

    Parameters
    ----------
    edge_matrix : array-like (n x n)
        Adjacency matrix. Nonzero = edge (weight kept as attribute 'weight').
    start_idx, end_idx : int
        Node indices (rows/cols of the matrix).

    Returns
    -------
    list of list
        One list of node indices per simple path found.
    """
    # Directed graph: entry [i, j] is an edge i -> j.
    graph = nx.from_numpy_array(edge_matrix, create_using=nx.DiGraph)

    found = []
    for node_seq in nx.all_simple_paths(graph, source=start_idx, target=end_idx):
        found.append(list(node_seq))
    return found

# Enumerate every simple path between the two selected node indices.
# NOTE(review): edge_matrix, start_node and end_node are not defined in this cell;
# presumably they come from an earlier notebook cell -- confirm before running it alone.
all_paths = all_paths_from_matrix(edge_matrix, start_node, end_node)

# TODO: do I need a flow limit on combined paths?

In [None]:
def _get_or_make_dedup_idx(variant: tuple[int, int | None], variant_to_dedup_idx, dedup_idx_to_variant) -> int:
    if variant not in variant_to_dedup_idx:
        didx = len(variant_to_dedup_idx)
        variant_to_dedup_idx[variant] = didx
        dedup_idx_to_variant[didx] = variant
    return variant_to_dedup_idx[variant]

def deduplicate_paths(blockstats_df, paths_blocks):
    """Re-encode block paths so duplicated blocks are told apart by their left context.

    Each block becomes a variant ``(block_id, context)``. ``context`` is ``None``
    for blocks not flagged as duplicated; for a duplicated block it is the most
    recent non-duplicated block seen earlier in the same path (``None`` if there
    was none yet). Every distinct variant gets an integer dedup index.

    Parameters
    ----------
    blockstats_df : DataFrame indexed by block id with a 'duplicated' column.
    paths_blocks : dict mapping isolate -> list of block ids.

    Returns
    -------
    deduplicated_paths : dict mapping isolate -> list of dedup indices
    variant_to_dedup_idx : dict mapping variant -> dedup index
    dedup_idx_to_variant : dict mapping dedup index -> variant
    """
    is_duplicated = blockstats_df['duplicated'] == True
    duplicated_ids = set(blockstats_df.loc[is_duplicated].index)

    # Two mutually inverse lookup tables, filled by _get_or_make_dedup_idx.
    variant_to_dedup_idx = {}
    dedup_idx_to_variant = {}

    deduplicated_paths = {}
    for isolate, path in paths_blocks.items():
        anchor = None  # last non-duplicated block seen in this path
        encoded = []
        for block in path:
            if block not in duplicated_ids:
                anchor = block
                key = (block, None)
            else:
                key = (block, anchor)
            encoded.append(
                _get_or_make_dedup_idx(key, variant_to_dedup_idx, dedup_idx_to_variant)
            )
        deduplicated_paths[isolate] = encoded

    return deduplicated_paths, variant_to_dedup_idx, dedup_idx_to_variant

In [None]:
def plot_junction_pangraph(pan: pp.Pangraph, add_consensus: bool = False, consensus_paths: list = None, order="tree"):
    """Draw every path of a junction pangraph as one row of coloured block bars.

    Each node of a path is drawn as a horizontal bar spanning its ``start`` to
    ``end`` coordinates. Blocks keep one colour across all rows: core blocks
    draw from the "pastel" palette, all others from "rainbow". The bar outline
    is black when the node's ``strand`` is truthy and red otherwise.

    Parameters
    ----------
    pan : pp.Pangraph
        Pangraph to draw.
    add_consensus : bool
        If True, append one extra row per entry of ``consensus_paths``.
    consensus_paths : list of list, optional
        Block-id sequences; their blocks are laid end to end starting at 0 using
        the block length from the block statistics. ``None`` is treated as empty.
    order : str
        "tree" takes the row order from ``get_tree_order()`` (defined elsewhere;
        assumed to return path names). Any other value falls back to the key
        order of ``pan.to_path_dictionary()``.

    Returns
    -------
    None
        The figure is created through pyplot and left as the current figure.
    """
    path_dict = pan.to_path_dictionary()

    if order == "tree":
        leaf_order = get_tree_order()
    else:
        # Any other value used to leave leaf_order undefined (NameError in the loop).
        leaf_order = list(path_dict)

    bdf = pan.to_blockstats_df()
    n_core = bdf["core"].sum()
    n_acc = len(bdf) - n_core
    cgen_acc = iter(sns.color_palette("rainbow", n_acc))
    cgen_core = iter(sns.color_palette("pastel", n_core))
    block_colors = {}

    def color_for(block_id):
        # Assign colours lazily on first use so that consensus-only blocks get one too.
        if block_id not in block_colors:
            palette = cgen_core if bdf.loc[block_id, "core"] else cgen_acc
            block_colors[block_id] = next(palette)
        return block_colors[block_id]

    fig, ax = plt.subplots(figsize=(12, len(path_dict) * 0.2))
    y = 0
    y_labels = []

    for name in leaf_order:
        # The ordering may list names that are absent from this pangraph.
        if name not in pan.paths:
            continue
        path = pan.paths[name]
        for node_id in path.nodes:
            block, strand, start, end = pan.nodes[node_id][
                ["block_id", "strand", "start", "end"]
            ]
            ax.barh(
                y,
                width=end - start,
                left=start,
                color=color_for(block),
                edgecolor="black" if strand else "red",
            )
        y_labels.append(name)
        y += 1

    if add_consensus:
        # `or []` tolerates add_consensus=True without any paths supplied.
        for i, cons_path in enumerate(consensus_paths or []):
            start = 0
            for block_id in cons_path:
                block_len = bdf.loc[block_id, "len"]
                ax.barh(
                    y,
                    width=block_len,
                    left=start,
                    color=color_for(block_id),
                    edgecolor="black", # TODO: potentially consider block orientation in consensus
                )
                start += block_len
            y_labels.append(f"consensus_{i+1}")
            y += 1

    ax.set_yticks(range(len(y_labels)), y_labels)
    ax.set_xlabel("genomic position (bp)")
    #ax.set_title(f"Junction graph for edge {selected_edge}")
    ax.grid(axis="x", alpha=0.4)
    ax.set_ylim(-1, len(y_labels))
    sns.despine()
    plt.tight_layout()

def build_edge_matrix(paths: dict, block_names: list):
    """
    Count, for every ordered pair of blocks, how often one directly follows the other.

    Inverted blocks are currently counted as if they were not inverted.

    Args:
        paths: dict[str, list[int]] – for each path, the list of node IDs
        block_names: list[int] – all node IDs, in the row/column order wanted

    Returns:
        edge_matrix: np.ndarray of shape (N, N); entry [i, j] is the number of
            times block j comes right after block i over all paths
        block_to_idx: dict mapping node ID -> matrix index
        idx_to_block: dict mapping matrix index -> node ID
    """
    n = len(block_names)
    block_to_idx = dict(zip(block_names, range(n)))
    idx_to_block = {idx: node for node, idx in block_to_idx.items()}

    edge_matrix = np.zeros((n, n), dtype=int)
    for node_seq in paths.values():
        # Pair each element with its successor.
        for src, dst in zip(node_seq[:-1], node_seq[1:]):
            row = block_to_idx[src]
            col = block_to_idx[dst]
            edge_matrix[row, col] += 1

    return edge_matrix, block_to_idx, idx_to_block

def remove_duplicated_blocks(bs_df, edge_matrix, block_to_idx):
    """Zero every row and column of ``edge_matrix`` that belongs to a duplicated block.

    Blocks are taken from the rows of ``bs_df`` whose 'duplicated' column equals
    True and mapped to matrix indices through ``block_to_idx``. The matrix is
    modified in place and also returned.
    """
    duplicated_rows = bs_df[bs_df['duplicated'] == True]
    dupl_indices = [block_to_idx[block_id] for block_id in duplicated_rows.index]

    # Drop all edges touching those blocks, in both directions.
    edge_matrix[:, dupl_indices] = 0
    edge_matrix[dupl_indices, :] = 0

    return edge_matrix

def remove_duplicated_blocks_early(blockstats_df, path_dict):
    """Strip duplicated blocks from the paths before any edge counting.

    ``path_dict`` maps a name to a list of 2-tuples whose first element is the
    block id; the second element is discarded. The result maps each name to
    the list of block ids that are not flagged in the 'duplicated' column of
    ``blockstats_df``.
    """
    dup_mask = blockstats_df['duplicated'] == True
    duplicated_blocks = set(blockstats_df[dup_mask].index)

    filtered_data = {}
    for name, lst in path_dict.items():
        kept = []
        for bid, _ in lst:
            if bid in duplicated_blocks:
                continue
            kept.append(bid)
        filtered_data[name] = kept

    return filtered_data

def filter_rare_blocks(blockstats_df, paths_blocks, rare_threshold):
    """Drop from every path the blocks whose 'count' is below ``rare_threshold``.

    ``paths_blocks`` maps a key to a list of block ids; the returned dict has
    the same keys with the rare blocks removed and the order otherwise kept.
    """
    is_rare = blockstats_df['count'] < rare_threshold
    rare_blocks = set(blockstats_df.loc[is_rare].index)

    filtered_data = {}
    for key, bids in paths_blocks.items():
        filtered_data[key] = [bid for bid in bids if bid not in rare_blocks]
    return filtered_data

def all_paths_from_matrix(edge_matrix, start_idx, end_idx):
    """
    Return all simple paths from start_idx to end_idx.

    Cycles caused by duplicated blocks are usually not a problem because such
    blocks are small insertions that have many low count edges.

    Parameters
    ----------
    edge_matrix : array-like (n x n)
        Adjacency matrix. Any nonzero entry [i, j] is treated as a single
        directed edge i -> j; the magnitude of the entry is ignored.
    start_idx, end_idx : int
        Node indices (rows/cols of the matrix).

    Returns
    -------
    list of list of int
        One list of vertex indices per simple path, as returned by igraph.
    """
    # igraph reads each entry of a directed adjacency matrix as the NUMBER of edges
    # between two vertices. A count matrix would therefore create parallel edges,
    # which can make the enumeration repeat the same vertex sequence. Collapse it
    # to 0/1 so that "nonzero = one edge" actually holds.
    has_edge = (np.asarray(edge_matrix) != 0).astype(int)
    G = ig.Graph.Adjacency(has_edge.tolist())

    # Debug aid (matplotlib backend instead of cairo):
    #fig, ax = plt.subplots()
    #ig.plot(G, target=ax)
    #plt.show()

    paths = G.get_all_simple_paths(v=start_idx, to=end_idx, mode='out')
    return paths

# TODO: do I need a flow limit on combined paths?
# TODO: should I have a cycle check on paths?

def transform_path_indices_to_block_ids(paths, idx_to_block):
    """Translate each path of matrix indices into the matching list of block ids.

    ``idx_to_block`` maps matrix index -> block id. A new list of new lists is
    returned; an index missing from the mapping raises KeyError.
    """
    return [[idx_to_block[idx] for idx in index_path] for index_path in paths]

def transform_block_ids_to_path_indices(paths, block_to_idx):
    """Translate each path of block ids into the matching list of matrix indices.

    ``block_to_idx`` maps block id -> matrix index. A new list of new lists is
    returned; a block id missing from the mapping raises KeyError.
    """
    return [[block_to_idx[block] for block in id_path] for id_path in paths]

def remove_paths_with_rare_edges(paths_blocks, edge_matrix, block_to_idx, flow_threshold = 10):
    """Return the distinct observed paths whose every edge is well supported.

    A path is kept only if each pair of consecutive blocks has an entry of at
    least ``flow_threshold`` in ``edge_matrix``. Paths with fewer than two
    blocks have no edges and are always kept. Identical paths from different
    isolates are reported once.

    Parameters
    ----------
    paths_blocks : dict mapping isolate -> list of block ids
    edge_matrix : 2-D array indexed as [from_idx, to_idx]
    block_to_idx : dict mapping block id -> matrix index
    flow_threshold : minimum edge entry required for an edge to count as supported

    Returns
    -------
    list of list
        Kept paths in order of first appearance in ``paths_blocks``. A plain
        set was used before, which made the order arbitrary.

    Prints the number of unique paths found.
    """
    def has_rare_edge(path):
        # Lazily walks consecutive pairs and stops at the first weak edge.
        return any(
            edge_matrix[block_to_idx[a], block_to_idx[b]] < flow_threshold
            for a, b in zip(path, path[1:])
        )

    # TODO: potentially extend to order isolates to paths, need to find a way what to do with removed paths
    # A dict is used as an insertion-ordered set so the output order is reproducible.
    kept = dict.fromkeys(
        tuple(path) for path in paths_blocks.values() if not has_rare_edge(path)
    )

    unique_paths = [list(p) for p in kept]
    print(f"Found {len(unique_paths)} unique paths.")

    return unique_paths

def get_consensus_paths(pangraph, flow_threshold=7, remove_duplicates = False, filter_rare = False, rare_threshold = 5):
    """
    Collect the distinct observed block paths that only use well-supported edges.

    Steps: reduce each path of the pangraph to its block ids (optionally without
    duplicated and/or rare blocks), count block-to-block transitions over all
    paths, then keep the paths accepted by ``remove_paths_with_rare_edges``.

    Parameters
    ----------
    pangraph : object providing ``to_path_dictionary()`` and ``to_blockstats_df()``
    flow_threshold : passed to ``remove_paths_with_rare_edges`` as the edge cutoff
    remove_duplicates : if truthy, duplicated blocks are removed from the paths first
    filter_rare : if truthy, blocks with count below ``rare_threshold`` are removed
    rare_threshold : count cutoff used when ``filter_rare`` is set

    Returns
    -------
    all_paths : kept paths expressed as edge-matrix indices
    all_paths_ids : the same paths expressed as block ids
    edge_matrix : transition count matrix built from the (filtered) paths
    """
    path_dict = pangraph.to_path_dictionary()
    blockstats_df = pangraph.to_blockstats_df()

    # Row/column order of the matrix: blocks by ascending overall count
    # (TODO: do I still need this?)
    by_count = blockstats_df.sort_values("count", ascending=True)
    block_order = by_count.index.tolist()

    if not remove_duplicates:
        # Keep every block; only strip the second element of each path entry.
        paths_blocks = {}
        for name, lst in path_dict.items():
            paths_blocks[name] = [bid for bid, _ in lst]
    else:
        paths_blocks = remove_duplicated_blocks_early(blockstats_df, path_dict)

    if filter_rare:
        paths_blocks = filter_rare_blocks(blockstats_df, paths_blocks, rare_threshold=rare_threshold)

    edge_matrix, block_to_idx, idx_to_block = build_edge_matrix(paths_blocks, block_order)

    # Disabled alternatives kept for reference:
    #  - zeroing matrix entries below flow_threshold,
    #  - zeroing duplicated blocks in the matrix via remove_duplicated_blocks,
    #  - enumerating all simple paths between the first and last block of the
    #    first path with all_paths_from_matrix.
    # Only paths that were actually observed are considered instead.
    all_paths_ids = remove_paths_with_rare_edges(paths_blocks, edge_matrix, block_to_idx, flow_threshold)
    all_paths = transform_block_ids_to_path_indices(all_paths_ids, block_to_idx)

    return all_paths, all_paths_ids, edge_matrix
