In [1]:
from functools import lru_cache
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import networkx as nx
from sklearn.model_selection import train_test_split
import sys
sys.path.append(r"/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/scripts")
from model import ModularPathwayConv, ModularGNN
#pd.set_option('display.max_rows', None)  # Set a higher number if needed
#pd.set_option('display.max_columns', None)
torch.set_printoptions(threshold=torch.inf)

In [2]:
@lru_cache(maxsize=None)
def get_data(n_fold=0, fp_radius=2):
    import math

    # Load driver genes and hierarchies
    hierarchies = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/gene_to_pathway_final_with_hierarchy.csv")
    driver_genes = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/driver_genes_2.csv")
    rnaseq = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/rnaseq_normcount.csv", index_col=0)

    # Load the gene network
    gene_network = nx.read_edgelist(
        r"/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/filtered_gene_network.edgelist",
        nodetype=str
    )

    # Create a dictionary mapping from Ensembl_ID to HGNC
    ensembl_to_hgnc = dict(zip(hierarchies['Ensembl_ID'], hierarchies['HGNC']))

    # Relabel nodes in the graph
    mapped_gene_network = nx.relabel_nodes(gene_network, ensembl_to_hgnc)

    # Convert the graph edges to a DataFrame
    edges_df = pd.DataFrame(
        list(mapped_gene_network.edges(data="weight")),
        columns=["source", "target", "weight"]
    )
    
    # Ensure the weight column is numeric
    edges_df["weight"] = edges_df["weight"].fillna(1.0).astype(float)

    # Filter RNA-seq data for driver genes
    driver_columns = rnaseq.columns.isin(hierarchies["HGNC"])
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())

    # Create a dictionary mapping cell lines to their expression tensors
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}

    # Get the set of valid nodes (columns in filtered RNA)
    valid_nodes = set(filtered_rna.columns)

    # Filter edges for valid nodes
    filtered_edges = edges_df[
        (edges_df["source"].isin(valid_nodes)) & (edges_df["target"].isin(valid_nodes))
    ]

    # Map nodes to numeric indices
    node_to_idx = {node: idx for idx, node in enumerate(valid_nodes)}
    filtered_edges["source_idx"] = filtered_edges["source"].map(node_to_idx)
    filtered_edges["target_idx"] = filtered_edges["target"].map(node_to_idx)

    # Create PyTorch edge tensors
    edge_index = torch.tensor(
        filtered_edges[["source_idx", "target_idx"]].values,
        dtype=torch.long
    ).T
    edge_attr = torch.tensor(
        filtered_edges["weight"].values,
        dtype=torch.float32
    )


    # Filter edges for valid nodes
    filtered_hierarchy = hierarchies[
        (hierarchies["HGNC"].isin(valid_nodes))
    ]
    # Step 1: Setup the initial pathway_dict
    pathway_dict = {
        gene: pathway.split(':', 1)[1].split('[', 1)[0].strip() if isinstance(pathway, str) and ':' in pathway else None
        for gene, pathway in dict(zip(filtered_hierarchy['HGNC'], filtered_hierarchy['Level_1'])).items()
    }

    # Step 2: Process the pathway_dict (group genes by pathway)
    grouped_pathway_dict = {}
    for gene, pathway in pathway_dict.items():
        if pathway:  # Ignore genes without valid pathways
            grouped_pathway_dict.setdefault(pathway, []).append(gene)

    # Step 3: Map pathways to numeric indices
    pathway_groups = {
        pathway: [node_to_idx[gene] for gene in genes if gene in node_to_idx]
        for pathway, genes in grouped_pathway_dict.items()
    }

    # Step 4: Convert pathway_groups to PyTorch Tensors
    pathway_tensors = {
        pathway: torch.tensor(nodes, dtype=torch.long) for pathway, nodes in pathway_groups.items()
    }

    # Create PyTorch Geometric Data objects for each cell line
    # Create PyTorch Geometric Data objects for each cell line
    graph_data_list = []
    for cell, x in cell_dict.items():
        # Transform x to have shape [num_nodes, num_features]
        if x.ndim == 2 and x.shape[0] == 1:  # [1, num_nodes]
            x = x.T
        elif x.ndim == 1:  # [num_nodes]
            x = x.unsqueeze(1)

        graph_data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        graph_data.y = None
        graph_data.cell_line = cell
        graph_data_list.append(graph_data)



    # Load drug response data and filter missing cell lines
    data = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/GDSC1.csv", index_col=0)
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys()")

    # Split cell lines into folds for training/validation/testing
    unique_cell_lines = data["SANGER_MODEL_ID"].unique()
    np.random.seed(420)  # Ensure reproducibility
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)
    test_lines = folds[n_fold]
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)
    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]

    # Split the drug response data
    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")

    # Create datasets for training, validation, and testing
    train_graphs = [graph for graph in graph_data_list if graph.cell_line in train_lines]
    val_graphs = [graph for graph in graph_data_list if graph.cell_line in validation_lines]
    test_graphs = [graph for graph in graph_data_list if graph.cell_line in test_lines]

    # Check if filtered edges still reference valid nodes
    invalid_source_nodes = filtered_edges[~filtered_edges["source"].isin(valid_nodes)]
    invalid_target_nodes = filtered_edges[~filtered_edges["target"].isin(valid_nodes)]
    
    if not invalid_source_nodes.empty or not invalid_target_nodes.empty:
        print(f"Invalid source nodes: {invalid_source_nodes}")
        print(f"Invalid target nodes: {invalid_target_nodes}")
        raise ValueError("Edges reference nodes not in valid_nodes.")

    return train_graphs, val_graphs, test_graphs, pathway_tensors


In [None]:

# Assuming `train_graphs`, `val_graphs`, `test_graphs`, and `pathway_tensors` are obtained
train_graphs, val_graphs, test_graphs, pathway_tensors = get_data(n_fold=0)

# Save the instance and its associated pathway tensor together
instance = train_graphs[0]
pathway_tensor = pathway_tensors  # Assuming pathway_tensors corresponds to train_graphs

# Combine them into a dictionary for structured saving
save_data = {
    'graph': instance,
    'pathway_tensor': pathway_tensor
}

# Save the dictionary
torch.save(save_data, 'instance_with_pathway.pth')


In [2]:
# Load the saved data
loaded_data = torch.load('instance_with_pathway.pth')

# Access the graph and pathway tensor
instance = loaded_data['graph']
pathway_tensors = loaded_data['pathway_tensor']

print("Cell Line Identifier:", instance.cell_line)
print("Node Features Shape:", instance.x.shape[0])  # Shape of node features (num_nodes, num_features)
print("Edge Index Shape:", instance.edge_index.shape)  # Shape of edge_index (2, num_edges)
print("Edge Index:", instance.edge_index.max())  # Edge connections
#print("Pathway Groups:", len(pathway_tensors))
print(instance.edge_attr)
print(instance.edge_attr.shape)

  loaded_data = torch.load('instance_with_pathway.pth')


Cell Line Identifier: SIDM00003
Node Features Shape: 7813
Edge Index Shape: torch.Size([2, 415867])
Edge Index: tensor(7812)


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [9]:
hidden_dim = 64             # Number of hidden layer features
output_dim = 10             # Desired number of output features
layer_modes = [True, True, True]  # Example: Global, Pathway, Pathway layers
pooling_mode = 'none'       # Example: No pooling
aggr_modes = ['mean', 'mean', 'mean']
model=ModularGNN(input_dim=instance.x.shape[1],
                 hidden_dim=2,
                 output_dim=instance.x.shape[1],
                 pathway_groups=pathway_tensors,
                 layer_modes = [True, True, True],
                 aggr_modes=aggr_modes
                )

In [10]:
output = model(instance.x, instance.edge_index, instance.edge_attr)

next layer
next layer
next layer


In [11]:
print(output.max())
display(output)

tensor(10624483., grad_fn=<MaxBackward1>)


tensor([[1.8776e+03],
        [0.0000e+00],
        [1.8182e+06],
        [3.5201e+06],
        [1.2586e+06],
        [0.0000e+00],
        [4.7059e+06],
        [1.0153e+00],
        [2.0024e+06],
        [5.0740e+05],
        [4.6075e+06],
        [4.4273e+06],
        [7.4154e+06],
        [3.7221e+06],
        [9.3863e+02],
        [1.4192e+06],
        [1.4836e+03],
        [2.3901e+06],
        [8.0135e+05],
        [5.4594e+06],
        [4.1917e+06],
        [4.2771e+03],
        [2.3902e+06],
        [1.2858e+06],
        [7.9040e+05],
        [2.1418e+05],
        [1.6740e+06],
        [4.1550e+06],
        [7.4382e+00],
        [9.0416e+05],
        [0.0000e+00],
        [3.8919e+05],
        [1.3828e+06],
        [6.7444e+06],
        [3.1067e+05],
        [1.5641e+06],
        [7.0690e+06],
        [0.0000e+00],
        [2.2696e+06],
        [2.9224e+06],
        [4.7430e+05],
        [3.1436e+06],
        [1.9384e+06],
        [4.3687e+06],
        [6.2676e+06],
        [7

## 

In [None]:
[[1.8776e+03],
        [0.0000e+00],
        [1.8182e+06],
        [3.5201e+06],
        [1.2586e+06],
        [0.0000e+00],
        [4.7059e+06],
        [1.0153e+00],
        [2.0024e+06],
        [5.0740e+05],
        [4.6075e+06],
        [4.4273e+06],
        [7.4154e+06],
        [3.7221e+06],
        [9.3863e+02],
        [1.4192e+06],
        [1.4836e+03],