In [None]:
from functools import lru_cache
import os
import zipfile
import requests
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import networkx as nx
from sklearn.model_selection import train_test_split
import sys
sys.path.append(r"scripts")
from model_GNN import ModularPathwayConv, ModularGNN
torch.set_printoptions(threshold=torch.inf)
from torch_geometric.loader import DataLoader

In [None]:
# Data locations and split configuration used by get_data.
_RNASEQ_ZIP_URL = "https://cog.sanger.ac.uk/cmp/download/rnaseq_all_20220624.zip"
_RNASEQ_ZIP_PATH = "data/rnaseq.zip"
_RNASEQ_CSV_PATH = "data/rnaseq_normcount.csv"
_RNASEQ_EXTRACTION_DIR = "data/"
_HIERARCHY_CSV_PATH = "data/gene_to_pathway_final_with_hierarchy.csv"
_GENE_NETWORK_PATH = "data/filtered_gene_network.edgelist"
_DRUG_RESPONSE_CSV_PATH = "data/GDSC1.csv"
_N_FOLDS = 10
_SPLIT_SEED = 420


def _download_if_not_present(url, filepath, timeout=60):
    """Download ``url`` to ``filepath`` unless ``filepath`` already exists.

    The payload is streamed into ``<filepath>.part`` and only renamed to
    ``filepath`` after the whole body was written, so an interrupted download
    is never mistaken for a complete file on the next run.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if os.path.exists(filepath):
        print(f"File already exists at {filepath}.")
        return
    print(f"File not found at {filepath}. Downloading...")
    os.makedirs(os.path.dirname(filepath) or ".", exist_ok=True)
    tmp_path = filepath + ".part"
    try:
        # Using the response as a context manager releases the connection
        # even if writing fails part-way.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        os.replace(tmp_path, filepath)
    finally:
        # Only left behind if the download or write failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download completed.")


def _load_rnaseq():
    """Return the RNA-seq table, downloading and extracting the ZIP first if needed.

    Raises:
        FileNotFoundError: if the ZIP or the extracted CSV cannot be found.
    """
    if not os.path.exists(_RNASEQ_CSV_PATH):
        print(f"RNA-seq file not found at {_RNASEQ_CSV_PATH}. Checking for ZIP file...")
        _download_if_not_present(_RNASEQ_ZIP_URL, _RNASEQ_ZIP_PATH)
        if not os.path.exists(_RNASEQ_ZIP_PATH):
            raise FileNotFoundError(
                f"ZIP file not found at {_RNASEQ_ZIP_PATH}. Could not extract RNA-seq data."
            )
        print("Extracting the ZIP file...")
        with zipfile.ZipFile(_RNASEQ_ZIP_PATH, "r") as zipf:
            zipf.extractall(_RNASEQ_EXTRACTION_DIR)
        print("Extraction completed.")

    if not os.path.exists(_RNASEQ_CSV_PATH):
        raise FileNotFoundError(f"RNA-seq file not found at {_RNASEQ_CSV_PATH} after extraction.")
    rnaseq = pd.read_csv(_RNASEQ_CSV_PATH, index_col=0)
    print("RNA-seq CSV file loaded successfully.")
    return rnaseq


def _build_edge_tensors(edges_df, node_to_idx):
    """Turn a (source, target, weight) edge table into PyG ``edge_index``/``edge_attr``.

    Edges with an endpoint missing from ``node_to_idx`` are dropped.
    ``edge_index`` has shape [2, E] (long) and ``edge_attr`` has shape [E] (float32).

    NOTE(review): each undirected networkx edge appears once, in one direction.
    If the model expects both directions, consider
    ``torch_geometric.utils.to_undirected`` — confirm against the model.
    """
    valid_nodes = set(node_to_idx)
    mask = edges_df["source"].isin(valid_nodes) & edges_df["target"].isin(valid_nodes)
    # .loc + .assign produces a new frame, avoiding SettingWithCopyWarning on a slice.
    kept_edges = edges_df.loc[mask].assign(
        source_idx=lambda df: df["source"].map(node_to_idx),
        target_idx=lambda df: df["target"].map(node_to_idx),
    )

    edge_index = torch.tensor(
        kept_edges[["source_idx", "target_idx"]].to_numpy(dtype=np.int64),
        dtype=torch.long,
    ).t().contiguous()
    edge_attr = torch.tensor(kept_edges["weight"].to_numpy(dtype=np.float32), dtype=torch.float32)

    # Invariant: every edge endpoint is a valid row of the node-feature matrix.
    num_nodes = len(node_to_idx)
    if edge_index.numel() > 0 and (int(edge_index.min()) < 0 or int(edge_index.max()) >= num_nodes):
        raise ValueError("Edges reference nodes not in valid_nodes.")
    return edge_index, edge_attr


def _parse_pathway_name(level_1):
    """Return the text between the first ':' and the first following '[' (stripped).

    Returns None when ``level_1`` is not a string or has no ':'.
    """
    if not isinstance(level_1, str) or ":" not in level_1:
        return None
    return level_1.split(":", 1)[1].split("[", 1)[0].strip()


def _build_pathway_tensors(hierarchies, node_to_idx):
    """Map each Level_1 pathway name to a LongTensor of the node indices it contains."""
    in_graph = hierarchies[hierarchies["HGNC"].isin(set(node_to_idx))]
    # NOTE(review): building a dict keeps only the LAST Level_1 value per gene.
    # Genes listed under several pathways end up in a single one — confirm
    # this is intended.
    gene_to_level_1 = dict(zip(in_graph["HGNC"], in_graph["Level_1"]))

    grouped = {}
    for gene, level_1 in gene_to_level_1.items():
        pathway = _parse_pathway_name(level_1)
        if pathway:  # skip genes without a parseable (non-empty) pathway name
            grouped.setdefault(pathway, []).append(node_to_idx[gene])
    return {pathway: torch.tensor(idxs, dtype=torch.long) for pathway, idxs in grouped.items()}


def _split_cell_lines(cell_lines, n_fold, n_splits=_N_FOLDS, seed=_SPLIT_SEED):
    """Shuffle cell lines into ``n_splits`` folds; return (train, validation, test) arrays.

    Fold ``n_fold`` is the test set. One other fold, drawn at random, is the
    validation set; the rest are training.

    A private RandomState is used so the global numpy RNG is left untouched.
    Its draws match the former ``np.random.seed(seed)`` followed by
    ``np.random.shuffle`` and ``np.random.choice``, so the splits are unchanged.
    """
    rng = np.random.RandomState(seed)
    shuffled = np.array(cell_lines, copy=True)
    rng.shuffle(shuffled)
    folds = np.array_split(shuffled, n_splits)

    remaining = [idx for idx in range(n_splits) if idx != n_fold]
    validation_idx = rng.choice(remaining)
    train_idxs = [idx for idx in remaining if idx != validation_idx]

    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    return train_lines, folds[validation_idx], folds[n_fold]


@lru_cache(maxsize=None)
def get_data(n_fold=0, fp_radius=2):
    """Build one gene graph per cell line and split the graphs into train/val/test.

    Each graph shares the same gene network (edges between HGNC genes that also
    appear as RNA-seq columns). Node features are that cell line's expression
    values, one feature per gene.

    Args:
        n_fold: index in ``[0, 9]`` of the fold used as the test set.
        fp_radius: accepted for backward compatibility; not used by this function.

    Returns:
        ``(train_graphs, val_graphs, test_graphs, pathway_tensors)``:
        - The three lists hold ``torch_geometric.data.Data`` objects with
          ``x`` of shape [num_genes, 1], the shared ``edge_index``/``edge_attr``,
          ``y=None`` and a ``cell_line`` attribute.
        - ``pathway_tensors`` maps a pathway name to a LongTensor of node indices.

        Results are memoized by ``lru_cache``: repeated calls return the SAME
        list/Data objects, so do not mutate them in place.

    Raises:
        ValueError: if ``n_fold`` is out of range.
        FileNotFoundError: if the RNA-seq data cannot be obtained.
        requests.HTTPError: if the RNA-seq download fails.

    Side effects: may download and extract files under ``data/``; prints progress.
    """
    if not 0 <= n_fold < _N_FOLDS:
        raise ValueError(f"n_fold must be in [0, {_N_FOLDS - 1}], got {n_fold!r}")

    rnaseq = _load_rnaseq()

    # NOTE(review): data/driver_genes_2.csv used to be loaded here but was never
    # used. Restore it if filtering by driver genes is intended.
    hierarchies = pd.read_csv(_HIERARCHY_CSV_PATH)
    gene_network = nx.read_edgelist(_GENE_NETWORK_PATH, nodetype=str)

    # Relabel Ensembl IDs to HGNC symbols so network nodes can match RNA-seq columns.
    # Unmapped nodes keep their original label and are dropped later.
    ensembl_to_hgnc = dict(zip(hierarchies["Ensembl_ID"], hierarchies["HGNC"]))
    mapped_gene_network = nx.relabel_nodes(gene_network, ensembl_to_hgnc)

    edges_df = pd.DataFrame(
        list(mapped_gene_network.edges(data="weight")),
        columns=["source", "target", "weight"],
    )
    edges_df["weight"] = edges_df["weight"].fillna(1.0).astype(float)

    # Keep only RNA-seq columns (genes) that appear in the pathway hierarchy.
    filtered_rna = rnaseq.loc[:, rnaseq.columns.isin(hierarchies["HGNC"])]
    tensor_exp = torch.tensor(filtered_rna.to_numpy(), dtype=torch.float32)

    # Node i of every graph MUST be column i of filtered_rna, because x rows
    # follow column order. The previous version enumerated a set, whose order is
    # hash-randomized, so edge and pathway indices pointed at the wrong genes.
    node_to_idx = {gene: idx for idx, gene in enumerate(filtered_rna.columns)}

    edge_index, edge_attr = _build_edge_tensors(edges_df, node_to_idx)
    pathway_tensors = _build_pathway_tensors(hierarchies, node_to_idx)

    # If a cell line is repeated in the index, the last row wins.
    cell_dict = {cell: tensor_exp[row] for row, cell in enumerate(filtered_rna.index)}

    graph_data_list = []
    for cell, expression in cell_dict.items():
        graph_data = Data(x=expression.unsqueeze(1), edge_index=edge_index, edge_attr=edge_attr)
        # NOTE(review): no label is attached here; drug-response targets must be
        # joined elsewhere.
        graph_data.y = None
        graph_data.cell_line = cell
        graph_data_list.append(graph_data)

    # Only cell lines that have both expression data and drug-response rows take part in the split.
    drug_response_raw = pd.read_csv(_DRUG_RESPONSE_CSV_PATH, index_col=0)
    drug_response = drug_response_raw[drug_response_raw["SANGER_MODEL_ID"].isin(set(cell_dict))]

    train_lines, validation_lines, test_lines = _split_cell_lines(
        drug_response["SANGER_MODEL_ID"].unique(), n_fold
    )
    train_set, validation_set, test_set = set(train_lines), set(validation_lines), set(test_lines)

    train_graphs = [graph for graph in graph_data_list if graph.cell_line in train_set]
    val_graphs = [graph for graph in graph_data_list if graph.cell_line in validation_set]
    test_graphs = [graph for graph in graph_data_list if graph.cell_line in test_set]

    return train_graphs, val_graphs, test_graphs, pathway_tensors

In [None]:
train_graphs, val_graphs, test_graphs, pathway_tensors=get_data(0=