In [23]:
from functools import lru_cache
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import networkx as nx
from sklearn.model_selection import train_test_split


In [28]:
from functools import lru_cache
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data
import networkx as nx


@lru_cache(maxsize=None)
def get_data(n_fold=0, fp_radius=2):
    # Load driver genes and hierarchies
    hierarchies = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/gene_to_pathway_final_with_hierarchy.csv")
    driver_genes = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/driver_genes_2.csv")
    rnaseq = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/rnaseq_normcount.csv", index_col=0)

    # Filter RNA-seq data for driver genes
    driver_columns = rnaseq.columns.isin(hierarchies["HGNC"])
    filtered_rna = rnaseq.loc[:, driver_columns]
    tensor_exp = torch.Tensor(filtered_rna.to_numpy())

    # Create a dictionary mapping cell lines to their expression tensors
    cell_dict = {cell: tensor_exp[i] for i, cell in enumerate(filtered_rna.index.to_numpy())}

    # Get the set of valid nodes (columns in filtered RNA)
    valid_nodes = set(filtered_rna.columns)

    # Load the Gene Regulatory Network (GRN)
    GRN = nx.read_edgelist(
        "/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/filtered_gene_network.edgelist"
    )

    # Filter edges based on valid nodes
    filtered_edges = [(u, v) for u, v in GRN.edges if u in valid_nodes and v in valid_nodes]
    edge_index = torch.tensor(filtered_edges, dtype=torch.long).T  # Convert to PyTorch edge_index format

    # Create PyTorch Geometric Data objects for each cell line
    graph_data_list = []
    for cell, x in cell_dict.items():
        graph_data = Data(x=x.unsqueeze(0), edge_index=edge_index)  # Create Data object
        graph_data.y = None  # Placeholder for labels (add when necessary)
        graph_data.cell_line = cell  # Store cell line identifier
        graph_data_list.append(graph_data)

    # Load drug response data and filter missing cell lines
    data = pd.read_csv("/work/haarscheid/cancer_baseline2/cancer_baseline/Graphs/data/GDSC1.csv", index_col=0)
    data = data.query("SANGER_MODEL_ID in @cell_dict.keys()")

    # Split cell lines into folds for training/validation/testing
    unique_cell_lines = data["SANGER_MODEL_ID"].unique()
    np.random.seed(420)  # Ensure reproducibility
    np.random.shuffle(unique_cell_lines)
    folds = np.array_split(unique_cell_lines, 10)
    test_lines = folds[n_fold]
    train_idxs = list(range(10))
    train_idxs.remove(n_fold)
    validation_idx = np.random.choice(train_idxs)
    train_idxs.remove(validation_idx)
    train_lines = np.concatenate([folds[idx] for idx in train_idxs])
    validation_lines = folds[validation_idx]
    test_lines = folds[n_fold]

    # Split the drug response data
    train_data = data.query("SANGER_MODEL_ID in @train_lines")
    validation_data = data.query("SANGER_MODEL_ID in @validation_lines")
    test_data = data.query("SANGER_MODEL_ID in @test_lines")

    # Create datasets for training, validation, and testing
    train_graphs = [graph for graph in graph_data_list if graph.cell_line in train_lines]
    val_graphs = [graph for graph in graph_data_list if graph.cell_line in validation_lines]
    test_graphs = [graph for graph in graph_data_list if graph.cell_line in test_lines]

    return train_graphs, val_graphs, test_graphs


In [29]:

train_graphs, val_graphs, test_graphs = get_data(n_fold=0)



  edge_index = torch.tensor(filtered_edges, dtype=torch.long).T  # Convert to PyTorch edge_index format
