In [6]:
import networkx as nx
import pandas as pd
from tqdm import tqdm

In [28]:
def read_nodes(data_path):
    """Load the first three fields of a delimited nodes file into a DataFrame.

    The file is read without a header row and each line is split on any run
    of ``|`` and tab characters.

    Args:
        data_path: Location of the nodes file, handed to ``pandas.read_csv``.

    Returns:
        A DataFrame whose columns are ``"Node ID"``, ``"Parent Node ID"`` and
        ``"rank"``, in that order.
    """
    column_names = ["Node ID", "Parent Node ID", "rank"]

    # A regular-expression separator needs the python parsing engine;
    # every field after the third one is dropped while reading.
    nodes = pd.read_csv(
        data_path,
        sep="[|\t]+",
        header=None,
        engine="python",
        usecols=[0, 1, 2],
    )
    nodes.columns = column_names
    return nodes


# Store the data as a graph
def build(df: pd.DataFrame) -> nx.DiGraph:
    """Build a directed graph with one ``parent -> child`` edge per table row.

    Every node carries a ``name`` attribute (its own ID) and a ``rank``
    attribute. The rank stored on a node is the one from the row in which it
    appears as ``"Node ID"``; a node that only ever appears as a parent keeps
    ``rank=None``.

    Args:
        df: Table with the columns ``"Node ID"``, ``"Parent Node ID"`` and
            ``"rank"``.

    Returns:
        The assembled ``nx.DiGraph``. A row whose node is its own parent
        produces a self-loop, which is left in place for the caller to handle.
    """
    G = nx.DiGraph()

    # Iterating the columns directly avoids the per-row Series that iterrows builds
    rows = zip(
        df["Node ID"].tolist(),
        df["Parent Node ID"].tolist(),
        df["rank"].tolist(),
    )
    for node_id, parent_node_id, node_rank in rows:
        # The rank on this row belongs to the child, not to the parent: a parent
        # seen before its own row is only a placeholder until that row arrives.
        if not G.has_node(parent_node_id):
            G.add_node(parent_node_id, name=parent_node_id, rank=None)

        # add_node on an existing node updates its attributes, so a placeholder
        # created earlier receives its real rank here.
        G.add_node(node_id, name=node_id, rank=node_rank)

        G.add_edge(parent_node_id, node_id)
    return G


def get_children(tree: nx.DiGraph, node_ID: int) -> list[int]:
    """Return the direct children of ``node_ID``, i.e. the targets of its outgoing edges."""
    children = []
    for _source, child in tree.out_edges(node_ID):
        children.append(child)
    return children


def get_parent(tree: nx.DiGraph, node_ID: int) -> int | None:
    """Return the single parent of ``node_ID``, or ``None`` when it has no incoming edge.

    Raises:
        Exception: If more than one edge points at ``node_ID``.
    """
    candidates = [source for source, _target in tree.in_edges(node_ID)]

    if not candidates:
        return None
    if len(candidates) == 1:
        return candidates[0]
    raise Exception(f"There are multiple parent nodes for node {node_ID}")


def restrict_tree(tree: nx.DiGraph, df: pd.DataFrame) -> nx.DiGraph:
    """Restrict a tree graph to the seven standard taxonomic ranks.

    Works on a copy of ``tree``. Every node whose rank in ``df`` is not one of
    the standard ranks is removed and its children are re-attached to its
    parent. A non-standard node that has no parent is left in the tree and
    only counted; the count is printed before returning.

    Args:
        tree: Graph in which every node has at most one parent.
        df: Table with the columns ``"Node ID"`` and ``"rank"`` describing the
            nodes of ``tree``.

    Returns:
        The restricted copy of ``tree``.

    Raises:
        Exception: Propagated from ``get_parent`` when a node has several parents.
    """
    restricted_tree = tree.copy()
    standard_taxonomic_ranks = [
        "superkingdom",  # is actually kingdom in the data
        "phylum",
        "class",
        "order",
        "family",
        "genus",
        "species",
    ]

    # IDs of all nodes whose rank is not one of the standard ranks
    is_non_standard = ~df["rank"].isin(standard_taxonomic_ranks)
    non_standard_nodes = df.loc[is_non_standard, "Node ID"].tolist()

    no_parent_node_counter = 0
    for current_node_ID in non_standard_nodes:
        parent_node_ID = get_parent(restricted_tree, current_node_ID)

        # Compare against None explicitly so that a falsy node ID such as 0
        # is still recognised as a valid parent.
        if parent_node_ID is None:
            no_parent_node_counter += 1
            continue

        # Bypass the node: hook each of its children onto its parent, then
        # drop the node together with all edges touching it.
        for child_node_ID in get_children(restricted_tree, current_node_ID):
            restricted_tree.add_edge(parent_node_ID, child_node_ID)
        restricted_tree.remove_node(current_node_ID)

    print(f"Found {no_parent_node_counter} nodes with no parents")
    return restricted_tree


def contract_elementary_paths(tree: nx.DiGraph) -> nx.DiGraph:
    """Contract every node that has exactly one parent and exactly one child.

    Works on a copy of ``tree``. Each such node is removed and its parent is
    connected directly to its child. Nodes without a parent (the root) are
    kept even when they have a single child, and leaves and branching nodes
    are left untouched.

    The function assumes that every node in ``tree`` has at most one parent.

    Args:
        tree: The tree graph to contract; it is not modified.

    Returns:
        The contracted copy of ``tree``.
    """
    contracted_tree = tree.copy()

    # Iterate over the untouched input so that removing nodes from the copy
    # does not disturb the iteration.
    for current_node in tree.nodes:
        children = get_children(contracted_tree, current_node)
        if len(children) != 1:
            # leaf (no children) or branching node (several children)
            continue

        parent_node_ID = get_parent(contracted_tree, current_node)
        if parent_node_ID is None:
            # A parentless node is not an inner node of a path; connecting
            # "None" to its child would be rejected by networkx.
            continue

        contracted_tree.add_edge(parent_node_ID, children[0])
        contracted_tree.remove_node(current_node)

    return contracted_tree


def read_mapping(mapping_path):
    """Read a whitespace-separated mapping file into a dictionary.

    Each non-blank line holds a sequence ID followed by zero or more integer
    IDs. Blank (or whitespace-only) lines are skipped.

    Args:
        mapping_path: Path of the mapping file.

    Returns:
        A dict mapping each sequence ID (str) to the list of ints that follow
        it on its line. If a sequence ID occurs on several lines, the last
        line wins.

    Raises:
        ValueError: If a token after the sequence ID is not an integer.
    """
    mappings = {}
    with open(mapping_path, "r") as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # e.g. a trailing empty line at the end of the file
                continue
            sequence_id, *id_tokens = tokens
            mappings[sequence_id] = [int(token) for token in id_tokens]
    return mappings


def find_lineage(graph: nx.DiGraph, node_ID: int, root_node_ID=1):
    """Return the lineage of ``node_ID`` as a list of nodes.

    The lineage is the shortest path from ``root_node_ID`` (default: 1) to
    ``node_ID``, obtained from ``nx.shortest_path`` with ``method="dijkstra"``.
    """
    return nx.shortest_path(
        graph,
        source=root_node_ID,
        target=node_ID,
        method="dijkstra",
    )


# Function to find the lowest common ancestor in a directed acyclic graph (DAG)
def find_lca(tree: nx.DiGraph, nodes: list[int]) -> int | None:
    """Return the lowest common ancestor (LCA) of ``nodes`` in ``tree``.

    A node counts as its own ancestor, so when one of the nodes lies on the
    lineage of all the others, that node itself is the LCA.

    Args:
        tree: Directed graph with edges pointing from parent to child.
        nodes: The nodes whose LCA is wanted.

    Returns:
        The common ancestor that has no child which is also a common ancestor,
        or ``None`` when ``nodes`` is empty or the nodes share no ancestor. In
        a tree this node is unique; in a general DAG one of the lowest common
        ancestors is returned.
    """
    if not nodes:
        return None

    # Intersect the "ancestors plus the node itself" sets of all nodes.
    # nx.ancestors excludes the node it is called on, so add it back.
    common_ancestors = None
    for node in nodes:
        lineage = nx.ancestors(tree, node) | {node}
        if common_ancestors is None:
            common_ancestors = lineage
        else:
            common_ancestors = common_ancestors & lineage

    # The lowest common ancestor is the one with no child that is itself a
    # common ancestor; self-loops are ignored so they cannot hide it.
    for candidate in common_ancestors:
        has_lower_common_ancestor = any(
            child != candidate and child in common_ancestors
            for child in tree.successors(candidate)
        )
        if not has_lower_common_ancestor:
            return candidate

    # No common ancestor at all
    return None

In [29]:
### Read in the nodes and construct the tree
# `df` (node table) and `total_G` (full graph) are reused by later cells.
df = read_nodes("handins/handin6/nodes.dmp")
total_G = build(df)
total_G.remove_edge(1, 1)  # remove the one cycle from the first node to itself

# Optional sanity checks (left disabled):
#print(nx.is_tree(total_G))
#print(nx.is_directed_acyclic_graph(total_G))

In [30]:
### Restrict the tree to the seven standard taxonomic ranks:
# restrict_tree works on a copy, so `total_G` stays intact; it also prints how
# many non-standard nodes had no parent.
restricted_tree = restrict_tree(total_G, df)
# Optional sanity checks (left disabled):
#print(nx.is_tree(restricted_tree))
#print(nx.is_directed_acyclic_graph(restricted_tree))

Found 1 nodes with no parents


In [31]:
### Read the mapping from sequence-read IDs to lists of taxonomy IDs
mappings = read_mapping("handins/handin6/mapping.txt")

In [32]:
### Find lineages
# For every sequence read, collect one lineage per mapped taxonomy ID,
# looked up in the full (unrestricted) graph.
lineages = {}
for sequence_id, tax_ids in mappings.items():
    lineages[sequence_id] = [find_lineage(total_G, tax_id) for tax_id in tax_ids]

In [33]:
### Build LCA skeleton trees for each sequence read
lca_skeleton_tree = {}
for read_id, node_list in mappings.items():
    lca = find_lca(restricted_tree, node_list)

    # Reads whose nodes share no ancestor get no skeleton tree
    if lca is None:
        continue

    # Collect the root (ID 1), every node on the root -> LCA path and every
    # node on each root -> mapped-node path. The LCA path does not depend on
    # the mapped node, so it is looked up once per read instead of once per node.
    nodes_to_include = {1}
    nodes_to_include.update(nx.shortest_path(restricted_tree, source=1, target=lca))
    for node in node_list:
        nodes_to_include.update(
            nx.shortest_path(restricted_tree, source=1, target=node)
        )

    skeleton_tree = restricted_tree.subgraph(nodes_to_include)

    lca_skeleton_tree[read_id] = skeleton_tree

In [34]:
# Number of nodes in LCA skeleton tree for each sequences read:
# There should be 16 in the first one and 32 in the second
num_nodes_per_read = {}
for read_id, lca_tree in lca_skeleton_tree.items():
    # Count the nodes that remain once the skeleton tree has been contracted
    contracted_skel_tree = contract_elementary_paths(lca_tree)
    num_nodes = contracted_skel_tree.number_of_nodes()
    num_nodes_per_read[read_id] = num_nodes

print("Number of nodes in LCA skeleton tree for each sequence read:")
for read_id, num_nodes in num_nodes_per_read.items():
    print(f"{read_id}: {num_nodes} nodes")

Number of nodes in LCA skeleton tree for each sequence read:
R00010: 16 nodes
R00020: 32 nodes
R00030: 44 nodes
R00040: 58 nodes
R00050: 70 nodes
R00060: 82 nodes
R00070: 99 nodes
R00080: 106 nodes
R00090: 129 nodes
R00100: 135 nodes


# Practical 6

## Task 1: write a Python script to find the LCA mapping for each sequence read

One way to do it: start from an arbitrary node among those mapped to the sequence read and walk upward through the tree (a bottom-up BFS). At each iteration (that is, each time we visit a new node in the upward direction), check whether all the other nodes of the sequence read can be reached from the current node. If not, move on to the next node up the tree. If they can, the current node is the LCA.

In [40]:
def find_lca(tree: nx.DiGraph, nodes: list[int]) -> int | None:
    """Return the lowest common ancestor (LCA) of ``nodes`` in ``tree``.

    A node counts as its own ancestor, so when one of the nodes lies on the
    lineage of all the others, that node itself is the LCA.

    Args:
        tree: Directed graph with edges pointing from parent to child.
        nodes: The nodes whose LCA is wanted.

    Returns:
        The common ancestor that has no child which is also a common ancestor,
        or ``None`` when ``nodes`` is empty or the nodes share no ancestor. In
        a tree this node is unique; in a general DAG one of the lowest common
        ancestors is returned.
    """
    if not nodes:
        return None

    # Intersect the "ancestors plus the node itself" sets of all nodes.
    # nx.ancestors excludes the node it is called on, so add it back.
    common_ancestors = None
    for node in nodes:
        lineage = nx.ancestors(tree, node) | {node}
        if common_ancestors is None:
            common_ancestors = lineage
        else:
            common_ancestors = common_ancestors & lineage

    # The lowest common ancestor is the one with no child that is itself a
    # common ancestor; self-loops are ignored so they cannot hide it.
    for candidate in common_ancestors:
        has_lower_common_ancestor = any(
            child != candidate and child in common_ancestors
            for child in tree.successors(candidate)
        )
        if not has_lower_common_ancestor:
            return candidate

    # No common ancestor at all
    return None
# Scratch check on the first sequence read only (the loop breaks after one pass):
# the LCA is computed but its result is discarded, and the ancestor set of the
# read's first taxonomy ID is printed instead.
# TODO: collect find_lca's result for every read to complete Task 1.
for sequence_ID in mappings.keys():
    find_lca(total_G, mappings[sequence_ID])
    ancestors = nx.ancestors(total_G, mappings[sequence_ID][0])
    print(ancestors)
    break
    

{6656, 1, 33154, 85512, 1206794, 43787, 2645396, 7197, 7198, 33317, 6960, 6072, 33208, 197562, 197563, 474171, 33340, 33213, 88770, 2759, 7496, 41831, 7147, 7148, 131567, 33392, 50557}


## Task 2: What is the highest taxonomic rank (that is, toward *kingdom*) for these LCA mappings?

## Task 3: What is the lowest taxonomic rank (that is, toward the *species*) for these LCA mappings?

## Task 4: write a Python script to find the optimal (in terms of the $F$-measure) taxonomic assignment for each sequence read.

## Task 5: What is the highest taxonomic rank (that is, toward *kingdom*) for these taxonomic assignments?

## Task 6: What is the lowest taxonomic rank (that is, toward *species*) for these taxonomic assignments?