In [1]:
import pandas as pd

def get_internal_steps(df):
    """Return the internal (non-endpoint) steps of every path.

    A path is identified by the ``(sample_id, direction, path_id)`` triple.
    Within each path steps are ordered by ``step_id`` and the first and last
    step are dropped, so paths with two or fewer steps contribute no rows.

    Args:
        df: Step table containing at least ``sample_id``, ``direction``,
            ``path_id`` and ``step_id``.

    Returns:
        A new DataFrame with the same columns and dtypes as ``df`` (also when
        no path has an internal step), ordered by path key and then
        ``step_id``, with a fresh ``RangeIndex``.
    """
    keys = ['sample_id', 'direction', 'path_id']
    ordered = df.sort_values('step_id', kind='stable')
    grouped = ordered.groupby(keys)

    # Position of each step counted from the start and from the end of its
    # path; a step is internal when it is neither first nor last. Rows whose
    # key contains NaN are excluded by groupby (cumcount yields NaN there),
    # so the comparisons below are False for them.
    from_start = grouped.cumcount().to_numpy()
    from_end = grouped.cumcount(ascending=False).to_numpy()
    internal = ordered[(from_start > 0) & (from_end > 0)]

    # Group rows path by path (as groupby().apply() used to), keeping the
    # step order inside each path.
    return (
        internal.sort_values(keys + ['step_id'], kind='stable')
                .reset_index(drop=True)
    )

def extract_nodes(df, node_type_col, node_id_col):
    """Project ``df`` onto one endpoint's columns under generic names.

    Keeps ``sample_id`` and ``path_id`` together with the two given columns,
    renaming ``node_type_col`` to ``node_type`` and ``node_id_col`` to
    ``node_id``.
    """
    selected = ['sample_id', 'path_id', node_type_col, node_id_col]
    generic_names = {
        node_type_col: 'node_type',
        node_id_col: 'node_id',
    }
    projected = df[selected]
    return projected.rename(columns=generic_names)

def count_internal_nodes_across_samples(df):
    """Count, per named node, how many paths pass through it internally.

    Head (``h_type``/``h_name``) and tail (``t_type``/``t_name``) endpoints of
    the internal steps (see ``get_internal_steps``) are pooled, and each node
    is counted at most once per ``(sample_id, path_id)``.

    NOTE(review): ``get_internal_steps`` separates paths by ``direction`` as
    well, but the deduplication key here does not include it, so two
    directions sharing a ``path_id`` in one sample count as one path —
    confirm this is intended.

    Returns:
        counts_per_sample: columns ``sample_id, node_type, node_name,
            path_count`` — number of distinct ``path_id`` values in that
            sample that contain the node.
        counts_overall: columns ``node_type, node_name,
            num_paths_with_node`` — ``path_count`` summed over samples,
            sorted descending with a fresh index.
    """
    internal_df = get_internal_steps(df)
    node_cols = ['sample_id', 'path_id', 'node_type', 'node_name']

    # Stack head endpoints first, then tail endpoints, under shared names.
    endpoint_frames = []
    for type_col, name_col in (('h_type', 'h_name'), ('t_type', 't_name')):
        endpoint = internal_df[['sample_id', 'path_id', type_col, name_col]]
        endpoint_frames.append(endpoint.set_axis(node_cols, axis=1))
    stacked = pd.concat(endpoint_frames, ignore_index=True)

    # A node seen several times on one path counts once for that path.
    distinct_nodes = stacked.drop_duplicates(subset=node_cols)

    per_sample_sizes = distinct_nodes.groupby(['sample_id', 'node_type', 'node_name']).size()
    counts_per_sample = per_sample_sizes.reset_index(name='path_count')

    summed = counts_per_sample.groupby(['node_type', 'node_name'])['path_count'].sum()
    counts_overall = summed.reset_index(name='num_paths_with_node')
    counts_overall = counts_overall.sort_values(by='num_paths_with_node', ascending=False)
    counts_overall = counts_overall.reset_index(drop=True)

    return counts_per_sample, counts_overall

def count_unique_internal_node_types_per_path(df):
    """Count internal nodes per node type, each node counted once per path.

    Head (``h_type``/``h_id``) and tail (``t_type``/``t_id``) endpoints of the
    internal steps (see ``get_internal_steps``) are pooled and deduplicated on
    ``(sample_id, path_id, node_type, node_id)``.

    Returns:
        unique_nodes_per_path: columns ``sample_id, path_id, node_type,
            node_id`` — one row per distinct node per path.
        type_counts: columns ``node_type, num_nodes_of_type`` — number of
            rows of ``unique_nodes_per_path`` per type, sorted descending
            with a fresh index.
    """
    internal_df = get_internal_steps(df)

    # Pool head and tail endpoints under generic column names.
    h_nodes = extract_nodes(internal_df, 'h_type', 'h_id')
    t_nodes = extract_nodes(internal_df, 't_type', 't_id')
    all_nodes = pd.concat([h_nodes, t_nodes], ignore_index=True)

    # node_type is part of the key so that equal ids of different types
    # (e.g. if ids are only unique within a type) are not merged into one
    # node; when ids are globally unique this changes nothing. It also
    # matches the (node_type, node_name) key used for name-based counts.
    unique_nodes_per_path = all_nodes.drop_duplicates(
        subset=['sample_id', 'path_id', 'node_type', 'node_id']
    )

    type_counts = (
        unique_nodes_per_path.groupby('node_type')
                             .size()
                             .reset_index(name='num_nodes_of_type')
                             .sort_values(by='num_nodes_of_type', ascending=False)
                             .reset_index(drop=True)
    )

    return unique_nodes_per_path, type_counts



In [2]:
import os
import pandas as pd

def process_single_file(df):
    """Build per-node internal-path counts for one step table.

    Combines the name-based per-node path counts from
    ``count_internal_nodes_across_samples`` with the per-type totals from
    ``count_unique_internal_node_types_per_path``.

    Returns:
        DataFrame with columns ``node_type, node_name, num_paths_with_node,
        num_nodes_of_type``, sorted by ``node_type`` ascending then
        ``num_paths_with_node`` descending (index not reset).
    """
    # Only the per-type totals are used from the id-based counting, and only
    # the overall per-node counts from the name-based counting.
    # NOTE(review): both helpers call get_internal_steps(df) independently,
    # so the internal-step filtering runs twice per file.
    _, type_counts = count_unique_internal_node_types_per_path(df)
    _, node_counts = count_internal_nodes_across_samples(df)

    # Left merge keeps every node row and attaches its type's total.
    return (
        node_counts.merge(type_counts, on='node_type', how='left')
                   .sort_values(by=['node_type', 'num_paths_with_node'],
                                ascending=[True, False])
    )

def get_internal_node_counts_across_seeds(base_path):
    """Run ``process_single_file`` on every CSV in ``base_path`` and stack them.

    CSV files are processed in sorted filename order, and each result gets a
    ``seed`` column equal to the file's 0-based position among the CSV files
    (a file that fails still occupies its position). Failures are printed and
    the file is skipped.

    NOTE(review): ``sorted`` is lexicographic ("seed_10.csv" < "seed_2.csv");
    if filenames encode the seed, consider parsing it from the name instead.

    Args:
        base_path: Directory containing one step-table CSV per seed.

    Returns:
        The per-file results concatenated with a fresh index, or an empty
        DataFrame if no CSV was processed successfully.
    """
    # Number seeds among CSVs only, so unrelated files in the directory
    # (notes, checkpoints, ...) do not shift or skip seed numbers.
    csv_names = sorted(name for name in os.listdir(base_path) if name.endswith(".csv"))

    per_seed = []
    for seed, fname in enumerate(csv_names):
        file_path = os.path.join(base_path, fname)
        try:
            df = pd.read_csv(file_path)
            res = process_single_file(df)
        except Exception as e:
            # Best-effort: report and continue with the remaining seeds.
            print(f"Failed to process {fname}: {e}")
            continue
        res['seed'] = seed
        per_seed.append(res)

    # Concatenate once instead of growing a frame inside the loop.
    if not per_seed:
        return pd.DataFrame()
    return pd.concat(per_seed, ignore_index=True)


In [3]:
def get_for_all_seeds(base_path):
    """Aggregate the per-seed internal-node counts found under ``base_path``.

    Returns:
        DataFrame with columns ``node_name, node_type, num_paths_with_node``
        (summed over all seeds) and ``num_nodes_of_type`` (the sum of
        ``num_paths_with_node`` over every node of that type), ordered by
        ``node_type`` ascending then ``num_paths_with_node`` descending.
    """
    per_seed = get_internal_node_counts_across_seeds(base_path)

    # Total path count per node across seeds.
    node_sums = per_seed.groupby(['node_name', 'node_type'])['num_paths_with_node'].sum()
    node_totals = node_sums.reset_index()
    node_totals = node_totals.sort_values(
        by=['node_type', 'num_paths_with_node'],
        ascending=[True, False],
    )

    # Per-type total of the per-node counts, attached to every node row.
    type_sums = node_totals.groupby('node_type')['num_paths_with_node'].sum()
    type_totals = type_sums.rename('num_nodes_of_type').reset_index()

    return node_totals.merge(type_totals, on='node_type', how='left')

In [4]:
def get_for_all_seeds_and_save(base_path, out_dir="rep_nodes"):
    """Write per-seed and seed-aggregated internal-node counts to CSV.

    Writes ``{out_dir}/{base_path}_seed.csv`` (from
    ``get_internal_node_counts_across_seeds``) and
    ``{out_dir}/{base_path}_all_seeds.csv`` (from ``get_for_all_seeds``),
    creating the destination directory first if it does not exist.

    Args:
        base_path: Directory of per-seed CSVs; also used as the output
            file-name prefix.
        out_dir: Destination directory (default ``"rep_nodes"``).
    """
    seed_csv = f"{out_dir}/{base_path}_seed.csv"
    all_seeds_csv = f"{out_dir}/{base_path}_all_seeds.csv"
    # to_csv does not create missing directories; base_path may also add a
    # sub-directory component to the output path.
    os.makedirs(os.path.dirname(seed_csv), exist_ok=True)

    get_internal_node_counts_across_seeds(base_path).to_csv(seed_csv, index=False)
    # NOTE(review): get_for_all_seeds re-reads and reprocesses every CSV in
    # base_path, so each input file is processed twice per call.
    get_for_all_seeds(base_path).to_csv(all_seeds_csv, index=False)
    


In [5]:
# Export counts for the unthresholded SL runs (CSV files under rep_nodes/).
get_for_all_seeds_and_save(base_path="SL_unthreshold")


In [42]:
# Same export for each dataset directory, processed in this order.
for dataset_name in (
    "SL",
    "mental_health",
    "cell_proliferation",
    "anemia",
    "adrenal_gland",
    "cardiovascular",
):
    get_for_all_seeds_and_save(dataset_name)


In [44]:
res = get_for_all_seeds("SL")
# Display only the rows whose node_type is 2 (the type code's meaning is not defined in this notebook).
res[res["node_type"] == 2]

Unnamed: 0,node_name,node_type,num_paths_with_node,num_nodes_of_type
286,KRAS,2.0,1559,10163
287,NRAS,2.0,1063,10163
288,HRAS,2.0,947,10163
289,REX1BD,2.0,602,10163
290,KLHL38,2.0,410,10163
...,...,...,...,...
662,WDR3,2.0,1,10163
663,WDR37,2.0,1,10163
664,WFDC13,2.0,1,10163
665,XK,2.0,1,10163
