In [1]:
from pathlib import Path
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd

In [2]:
def get_graph_data_from_topo(filepath=None):
    """
    Read a whitespace-delimited .topo file and build a directed graph from it.

    The file must start with a header row naming the columns ``Source``,
    ``Target`` and ``Type``. Every following row becomes one directed edge
    ``Source -> Target`` whose edge weight is that row's ``Type`` value.

    :param filepath: path to the topology file (required, despite the ``None`` default)
    :return: ``(G, gene_to_idx)`` where ``G`` is a NetworkX DiGraph with gene names
        as node labels, and ``gene_to_idx`` maps every gene that appears as a
        source or target to its position in the sorted list of gene names
        (useful for ML models that need integer node ids, e.g. PyG)
    :raises ValueError: if ``filepath`` is None
    """
    if filepath is None:
        raise ValueError("filepath must be provided")

    # keep_default_na=False: by default pandas converts tokens such as "NA",
    # "NULL" or "nan" into float NaN. Those are legitimate gene labels here, and a
    # NaN mixed into the name columns makes the sorted() call below raise
    # TypeError (float vs. str comparison).
    df = pd.read_csv(filepath, sep=r"\s+", keep_default_na=False)

    # Create gene-to-index mapping for optional ML use
    genes = sorted(set(df['Source']).union(df['Target']))
    gene_to_idx = {gene: idx for idx, gene in enumerate(genes)}

    # Build NetworkX DiGraph; each (source, target, type) triple is one weighted edge
    edges_with_weights = list(zip(df['Source'], df['Target'], df['Type']))
    G = nx.DiGraph()
    G.add_weighted_edges_from(edges_with_weights)

    return G, gene_to_idx

In [2]:
def create_sample_topo_file(filepath, num_nodes=500, num_edges=1000, num_hubs=2, seed=None):
    """
    Write a random hub-based topology to a whitespace-delimited .topo file.

    The first ``num_hubs`` nodes are hubs: each hub gets a directed edge to every
    other node. Unique random directed edges (no self-loops) are then added until
    the file contains exactly ``num_edges`` edges. Every edge gets a ``Type`` of
    1 or 2 chosen at random. Missing parent directories of ``filepath`` are created.

    Node names are uppercase alphabetic strings without digits ("AA", "AB", ...).
    Two letters are used while they suffice (up to 676 nodes); longer names are
    used automatically for larger graphs.

    NOTE: for ``num_nodes >= 339`` the node list contains the name "NA", which
    pandas' default CSV parsing turns into a missing value. Read the file with
    ``keep_default_na=False`` to keep it as a label.

    :param filepath: destination path of the .topo file
    :param num_nodes: number of nodes, must be > 1
    :param num_edges: total number of directed edges to write
    :param num_hubs: number of hub nodes, ``1 <= num_hubs < num_nodes``
    :param seed: optional seed for reproducible output; when None the global
        ``random`` module state is used
    :return: ``pathlib.Path`` of the written file
    :raises ValueError: if ``filepath`` is None or the node/edge/hub counts are
        inconsistent
    :raises RuntimeError: if random sampling fails to find enough unique edges
    """
    import random
    import string
    from itertools import islice, product
    from pathlib import Path

    if filepath is None:
        raise ValueError("filepath must be provided")
    p = Path(filepath)

    if num_nodes <= 1:
        raise ValueError("num_nodes must be > 1")
    if not (1 <= num_hubs < num_nodes):
        raise ValueError("num_hubs must be >=1 and less than num_nodes")

    # maximum directed edges without self-loops
    max_edges = num_nodes * (num_nodes - 1)
    required_hub_edges = num_hubs * (num_nodes - 1)  # each hub -> every other node (excluding self)

    if num_edges < required_hub_edges:
        raise ValueError(
            f"num_edges ({num_edges}) is smaller than the number of edges required to connect "
            f"{num_hubs} hubs to all other nodes ({required_hub_edges})."
        )
    if num_edges > max_edges:
        raise ValueError(f"num_edges ({num_edges}) exceeds maximum possible directed edges without self-loops ({max_edges}).")

    # A private generator keeps seeded runs reproducible without touching the
    # global random state; without a seed the module-level generator is used.
    rng = random if seed is None else random.Random(seed)

    # Generate alphabetic-only gene names (no digits). Start with two letters and
    # only grow the name length when 26**length cannot cover num_nodes.
    letters = string.ascii_uppercase
    name_len = 2
    while len(letters) ** name_len < num_nodes:
        name_len += 1
    nodes = [''.join(t) for t in islice(product(letters, repeat=name_len), num_nodes)]

    hubs = nodes[:num_hubs]

    # Use dict to avoid duplicate directed edges: mapping (src, tgt) -> weight
    edges_map = {}

    # Connect hubs to all other nodes (exclude self-loops)
    for hub in hubs:
        for node in nodes:
            if hub == node:
                continue
            edges_map[(hub, node)] = rng.choice([1, 2])

    # Add random unique edges until reaching desired count
    attempts = 0
    while len(edges_map) < num_edges:
        attempts += 1
        if attempts > (num_edges * 100):  # safety to avoid infinite loop
            raise RuntimeError("Too many attempts to generate unique random edges; adjust parameters.")
        src = rng.choice(nodes)
        tgt = rng.choice(nodes)
        if src == tgt:
            continue
        if (src, tgt) in edges_map:
            continue
        edges_map[(src, tgt)] = rng.choice([1, 2])

    # Write to file
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open('w') as f:
        f.write("Source Target Type\n")
        for (src, tgt), weight in edges_map.items():
            f.write(f"{src} {tgt} {weight}\n")

    print(f"Created {p} with {num_nodes} nodes, {len(edges_map)} edges, {num_hubs} hubs.")
    return p

def create_equal_topo_file(filepath, num_nodes=500, seed=None):
    """
    Write a ring topology to a whitespace-delimited .topo file.

    All nodes are arranged in a circle and each node points to its left and
    right neighbour, so every node has the same out-degree. Both edges leaving a
    node share one ``Type`` value (1 or 2) chosen at random. With exactly two
    nodes the left and right neighbour coincide, so each node gets a single edge.
    Missing parent directories of ``filepath`` are created.

    Node names are "GENE" followed by uppercase letters ("GENEAA", "GENEAB", ...).
    Two letters are used while they suffice (up to 676 nodes); longer suffixes
    are used automatically for larger rings.

    :param filepath: destination path of the .topo file
    :param num_nodes: number of nodes in the ring, must be > 1
    :param seed: optional seed for reproducible edge types; when None the global
        ``random`` module state is used
    :return: ``pathlib.Path`` of the written file
    :raises ValueError: if ``filepath`` is None or ``num_nodes <= 1``
    """
    import string
    import random
    from itertools import islice, product
    from pathlib import Path
    if filepath is None:
        raise ValueError("filepath must be provided")
    p = Path(filepath)

    # A single node would only produce self-loops, which is not a ring.
    if num_nodes <= 1:
        raise ValueError("num_nodes must be > 1")

    rng = random if seed is None else random.Random(seed)

    # Generate alphabetic-only gene names (no digits), prefixed with "GENE".
    # Grow the suffix length only when 26**length cannot cover num_nodes.
    letters = string.ascii_uppercase
    name_len = 2
    while len(letters) ** name_len < num_nodes:
        name_len += 1
    nodes = ["GENE" + ''.join(t) for t in islice(product(letters, repeat=name_len), num_nodes)]

    # Create a circular topology
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open('w') as f:
        f.write("Source Target Type\n")
        for i in range(num_nodes):
            left = nodes[(i - 1) % num_nodes]
            right = nodes[(i + 1) % num_nodes]
            weight = rng.choice([1, 2])
            f.write(f"{nodes[i]} {left} {weight}\n")
            # In a two-node ring both neighbours are the same node; avoid
            # writing the identical edge twice.
            if right != left:
                f.write(f"{nodes[i]} {right} {weight}\n")

    print(f"Created {p} with {num_nodes} nodes arranged in a circle.")
    return p

# Write the ring topology next to the notebook (path is relative to the working
# directory). 500 nodes with two outgoing edges each give 1000 edges, which is
# presumably what the "500_1000" in the file name refers to.
topo_filepath = "equal_500_1000.topo"
create_equal_topo_file(topo_filepath, num_nodes=500)

# Alternative: call create_sample_topo_file(topo_filepath) here instead to write a
# random hub-based topology to the same path.

Created equal_500_1000.topo with 500 nodes arranged in a circle.


PosixPath('equal_500_1000.topo')

In [3]:
# Load the topology as a directed graph; the gene-to-index mapping returned
# alongside it is not needed here and is discarded.
# NOTE(review): this reads "dorothea_500_1000.topo", not the "equal_500_1000.topo"
# file written by the previous cell — confirm which topology is intended.
# NOTE(review): the cell defining get_graph_data_from_topo must be run first; the
# saved output of this cell is a NameError from running it on a kernel without it.
G, _ = get_graph_data_from_topo(filepath=Path("dorothea_500_1000.topo"))
# Node labels in the graph's iteration order; reused by later cells.
nodes = list(G.nodes())


NameError: name 'get_graph_data_from_topo' is not defined

In [8]:
# Sanity check: report how many nodes and directed edges the loaded graph has.
print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

Graph has 500 nodes and 1000 edges.


In [9]:
# Count the nodes that act as a source for at least one edge by walking the
# (node, out-degree) pairs of the graph once.
print("number of nodes with outgoing edges:", sum(1 for _node, out_deg in G.out_degree() if out_deg > 0))

number of nodes with outgoing edges: 279


In [5]:
# NOTE(review): pandas is already imported as pd in the notebook's first cell; this
# repeated import only matters when that cell has not been run.
import pandas as pd
# Load the per-gene metrics table. The path is relative to the working directory
# and the file must provide a "gene" column.
df = pd.read_csv("XXL_run_gene_metrics.csv")
# Gene names in row order; duplicates are kept.
genes_XXL = df["gene"].tolist()

FileNotFoundError: [Errno 2] No such file or directory: 'XXL_run_gene_metrics.csv'

In [None]:
# Compare the size of the graph's node list with the gene list from the metrics CSV.
print(len(nodes), len(genes_XXL))
print(len(set(nodes)), len(set(genes_XXL)))  # unique counts only; this does not test whether the two sets are equal
# Rows per gene; a count above 1 means the gene appears in several rows of df.
print(df["gene"].value_counts())

150 228
150 125
gene
PGR      2
FOS      2
JUN      2
STAT1    2
HIF1A    2
        ..
NR1H2    1
HNF4A    1
MAFB     1
KLF13    1
HBP1     1
Name: count, Length: 125, dtype: int64


In [3]:
# Read splits.pt (path relative to the working directory) into `splits`.
# NOTE(review): torch.load deserializes via pickle, which can execute arbitrary
# code from the file — only load splits.pt from a trusted source. Consider passing
# weights_only=True if the stored objects allow it (TODO confirm file contents).

import torch
splits = torch.load("splits.pt")

In [8]:
# Print the number of entries in each index set of the loaded splits, one per
# line, in train / test / val order with forward before backward.
for split_key in (
    "train_index_forward",
    "train_index_backward",
    "test_index_forward",
    "test_index_backward",
    "val_index_forward",
    "val_index_backward",
):
    print(len(splits[split_key]))

3895
3895
1299
1299
1298
1298
