<a href="https://colab.research.google.com/github/cannin/gsoc_2023_pytorch_pathway_commons/blob/main/InMemoryDataset_Class_with_brca_tcga.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch Geometric into the running kernel's environment (%pip, not
# !pip, so the package lands where the kernel imports from). Pinned to the
# version this notebook was executed with, as shown by the saved `pip list`
# output further down.
%pip install torch-geometric==2.3.1

# Importing Data and Libraries

In [2]:
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import InMemoryDataset, Data, download_url, extract_zip
import os

In [3]:
# Display the installed PyTorch version string.
torch.__version__

'2.0.1+cu118'

In [4]:
# List every installed package whose `pip list` line contains "torch",
# together with its version.
!pip list | grep torch

torch                            2.0.1+cu118
torch-geometric                  2.3.1
torchaudio                       2.0.2+cu118
torchdata                        0.6.1
torchsummary                     1.5.1
torchtext                        0.15.2
torchvision                      0.15.2+cu118


# Creating InMemoryDataset Class for Train Set

In [5]:
class brca_tcga(InMemoryDataset):
    """In-memory dataset with one graph per row of ``graph_idx.csv``.

    Each graph gets the row's values as a single-column node-feature matrix
    ``x``, the edge index loaded from ``edge_index.pt`` (the same tensor for
    every graph) and the matching entry of ``graph_labels.csv`` as ``y``.

    Args:
        root: Directory under which the raw and processed files are kept.
        transform: Forwarded unchanged to ``InMemoryDataset``.
        pre_transform: Forwarded to ``InMemoryDataset``; if given, it is
            applied to every graph in ``process`` before the data is saved.
        pre_filter: Forwarded to ``InMemoryDataset``; if given, only graphs
            for which it returns a truthy value are kept in ``process``.
    """

    # Archive that contains the raw files.
    url = 'https://zenodo.org/record/8179187/files/brca_tcga.zip?download=1'

    # Folder inside ``raw_dir`` in which the extracted files are found.
    _raw_subdir = 'brca_tcga'

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # The base-class constructor has downloaded/processed the data if
        # needed, so the processed file exists at this point.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Raw files, given relative to ``raw_dir``.

        The names include the sub-folder the files are read from, so that
        the base class's existence check on ``raw_paths`` can succeed and an
        already-present download is not fetched again.
        """
        names = ('graph_idx.csv', 'graph_labels.csv', 'edge_index.pt')
        return [os.path.join(self._raw_subdir, name) for name in names]

    @property
    def processed_file_names(self):
        """Name of the single file that stores the collated dataset."""
        return 'breast_data.pt'

    def download(self):
        """Fetch the archive from ``url`` and unpack it into ``raw_dir``."""
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        # The zip file is no longer needed once it has been extracted.
        os.unlink(path)

    def process(self):
        """Build one ``Data`` object per feature row and save the collated result."""
        # Same order as in raw_file_names.
        features_path, labels_path, edges_path = self.raw_paths

        # First CSV column is the index; the remaining values of a row are
        # the node features of one graph.
        graph_features = pd.read_csv(features_path, index_col=0).values

        # A file with a single value is parsed as a 0-d array; make sure the
        # labels can always be indexed by graph position.
        graph_labels = np.atleast_1d(np.loadtxt(labels_path, delimiter=','))

        # NOTE(review): torch.load unpickles the file; this is only safe as
        # long as the downloaded archive is trusted.
        edge_index = torch.load(edges_path)

        # Every graph references the same edge_index tensor; only the node
        # features and the label differ between graphs.
        data_list = [
            Data(x=torch.tensor(node_features.reshape(-1, 1)),
                 edge_index=edge_index,
                 y=torch.tensor(graph_labels[i]))
            for i, node_features in enumerate(graph_features)
        ]

        # Honour the optional hooks accepted by the constructor.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        # Save the processed data
        torch.save((data, slices), self.processed_paths[0])

In [6]:
# Build the dataset with the current directory as its root; the saved log
# below shows the archive being downloaded to ./raw and then processed.
# Note: despite its name, `df` is a brca_tcga dataset, not a DataFrame.
df = brca_tcga(root='')

Downloading https://zenodo.org/record/8179187/files/brca_tcga.zip?download=1
Extracting ./raw/brca_tcga.zip
Processing...
Done!


In [7]:
#Access the attributes of a specific data object in the training set
sample = df[0]  # Get the first data object
print(sample)  # Print the data object

# Access the node features, edge indices, and target label
node_features = sample.x
edge_index = sample.edge_index
target = sample.y

print(node_features)  # Print the node features
print(edge_index)  # Print the edge indices
print(target)  # Print the target label

Data(x=[9288, 1], edge_index=[2, 271771], y=[1])
tensor([[   0.0000],
        [5798.3700],
        [   8.6165],
        ...,
        [ 415.8240],
        [ 931.9570],
        [1180.4600]], dtype=torch.float64)
tensor([[   0,    0,    0,  ..., 9287, 9287, 9287],
        [ 451,  452,  453,  ..., 3323, 3340, 3341]])
tensor([133.0506], dtype=torch.float64)
