# Create Graph Data

This notebook covers the fundamentals of graphs and of converting data we have in some format (in this case, detected cells in a `geojson` file) into the format expected by Pytorch Geometric. In any application of Graph Neural Networks (GNNs), this is typically one of the steps that requires the most effort, and it is where a large part of modelling the problem takes place.

## Example dataset

While we use a very specific example in this workshop, you should be able to generalize what you learn in this notebook to any problem involving graph-structured data. We focus on building a GNN for node and graph classification. Throughout, we use a dataset of cell graphs derived from the [Prostate cANcer graDe Assessment (PANDA) Challenge](https://www.kaggle.com/competitions/prostate-cancer-grade-assessment/data). The challenge dataset consists of about 10,000 Whole Slide Images (WSI) of microscopy scans of prostate biopsy samples, each accompanied by a cancer severity grade. Our main goal will be to classify each WSI according to its severity.

### Pre-processing
We have prepared the graph dataset by preprocessing the WSIs with the [CellViT](https://github.com/TIO-IKIM/CellViT) Vision Transformer. This neural network has detected the cell nuclei in each WSI and classified the cells into connective, dead, epithelial, inflammatory and neoplastic cells.

See the image below for an example of how this looks in QuPath:

![Example of histopathology image with annotations from CellViT](images/QuPath_cell_example.png)

Our main goal in this workshop is to build a graph neural network that classifies these cell graphs into the cancer grades given in the dataset. While it's unlikely that we'll match the performance of the challenge winners, the example illustrates how very high-dimensional data (here, the WSI) can be represented by a much more compact graph that may still capture much of the information needed to solve the problem.

## Representing graphs in Pytorch Geometric

Pytorch Geometric (PyG) is a general-purpose framework for developing neural networks that operate on geometric objects, primarily graphs but also point clouds. Here we focus on how to take data in some representation, such as points in Euclidean space, and convert it into the data structures PyG expects.

### The Data object

In Pytorch Geometric, graphs and point clouds are represented by a `Data` object, which encapsulates the nodes and edges of the graph. The signature of the `Data` constructor is as follows:

```python
class Data(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Tensor] = None, pos: Optional[Tensor] = None, **kwargs)
```

The parameters for the constructor are as follows:

 - **x** (torch.Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)
 - **edge_index** (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)
 - **edge_attr** (torch.Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)
 - **y** (torch.Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)
 - **pos** (torch.Tensor, optional) – Node position matrix with shape [num_nodes, num_dimensions]. (default: None)

At minimum we need to construct the **x** tensor, which contains the node features (cell features in the case of our cell graph). The central parameter for our cell graph is **edge_index**, which holds the edge list in COO format: each column is one edge, given as a pair (source node index, target node index). This is the first hurdle in any application of GNNs: taking the graph you have in some format and converting it to a representation suitable for (in this case) Pytorch Geometric.
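To make the COO layout concrete, here is a minimal sketch using a hypothetical four-node path graph 0-1-2-3, written in plain NumPy so no PyG install is needed. Every column of `edge_index` is one directed edge, so each undirected edge appears twice, once in each direction.

```python
import numpy as np

# Adjacency matrix of a small undirected path graph 0-1-2-3 (hypothetical example)
adjacency = np.array([
    [0, 1, 0, 0],
    [1, 0, 1, 0],
    [0, 1, 0, 1],
    [0, 0, 1, 0],
])

# np.nonzero returns the row (source) and column (target) index of every edge;
# stacking them gives the [2, num_edges] COO layout that edge_index expects.
sources, targets = np.nonzero(adjacency)
edge_index = np.stack([sources, targets], axis=0).astype(np.int64)

print(edge_index)
# [[0 1 1 2 2 3]
#  [1 0 2 1 3 2]]
```

In a notebook with PyTorch available, `torch.as_tensor(edge_index)` turns this array into the tensor passed to `Data`.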




## Detected cells from CellViT

In this workshop we will work with the spatial information obtained from the CellViT framework, which was used to segment the cell nuclei and locate them; the results are stored in `geojson` files. GeoJSON is a simple JSON-based format for encoding geographical features. The files we have from CellViT have the following structure (when represented with Python data types):

```python
[{'type': 'Feature',
  'id': '3cddcb72-475c-466a-aeac-bd1c3b166b71',
  'geometry': {'type': 'MultiPoint', 'coordinates': [...]},
  'properties': {'objectType': 'annotation',
   'classification': {'name': 'Neoplastic', 'color': [255, 0, 0]}}},
 {'type': 'Feature',
  'id': '12e50fb8-33e2-48da-939c-31f11694739d',
  'geometry': {'type': 'MultiPoint', 'coordinates': [...]},
  'properties': {'objectType': 'annotation',
   'classification': {'name': 'Inflammatory', 'color': [34, 221, 77]}}},
 {'type': 'Feature',
  'id': '72d03b5f-d96e-4ad9-9571-2559c70bff0c',
  'geometry': {'type': 'MultiPoint', 'coordinates': [...]},
  'properties': {'objectType': 'annotation',
   'classification': {'name': 'Connective', 'color': [35, 92, 236]}}},
 {'type': 'Feature',
  'id': 'eedec27b-d487-4ba5-89b3-578c0a214315',
  'geometry': {'type': 'MultiPoint', 'coordinates': [...]},
  'properties': {'objectType': 'annotation',
   'classification': {'name': 'Dead', 'color': [254, 255, 0]}}},
 {'type': 'Feature',
  'id': '3b297a88-7715-40a4-825c-ae4b18a4a2cf',
  'geometry': {'type': 'MultiPoint', 'coordinates': [...]},
  'properties': {'objectType': 'annotation',
   'classification': {'name': 'Epithelial', 'color': [255, 159, 68]}}}]
```

As you can see, each cell type is encapsulated in its own "Feature" object. Each "Feature" object has the attributes "geometry" and "properties", which we will use. The "geometry" attribute contains a 'coordinates' list, shown here as an ellipsis for compactness; in reality it contains the coordinates of all detected cells of that type, which we will use to construct our cell graph.
We want to "transpose" this object, so that all points (cell locations) are in the same list, while a separate list contains the classified type of each cell.
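The transposition can be illustrated on a tiny, made-up feature list that mimics the CellViT structure above (the coordinates and cell types are invented for illustration). After transposing, index `i` in both output lists refers to the same cell.

```python
# A tiny, made-up feature list mimicking the CellViT structure shown above
toy_geojson = [
    {"type": "Feature",
     "geometry": {"type": "MultiPoint", "coordinates": [[10.0, 20.0], [12.5, 21.0]]},
     "properties": {"classification": {"name": "Neoplastic"}}},
    {"type": "Feature",
     "geometry": {"type": "MultiPoint", "coordinates": [[30.0, 5.0]]},
     "properties": {"classification": {"name": "Inflammatory"}}},
]

# "Transpose": one flat list of points plus a parallel list of cell types
positions = []
types = []
for feature in toy_geojson:
    coords = feature["geometry"]["coordinates"]
    positions.extend(coords)
    # Repeat the feature's class name once per point it contains
    types.extend([feature["properties"]["classification"]["name"]] * len(coords))

print(positions)  # [[10.0, 20.0], [12.5, 21.0], [30.0, 5.0]]
print(types)      # ['Neoplastic', 'Neoplastic', 'Inflammatory']
```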

In [None]:
import json
with open('../datasets/example/008069b542b0439ed69b194674051964/cell_detection/cell_detection.geojson') as fp:
    detected_cells_geojson = json.load(fp)

In [None]:
import numpy as np

def transpose_geojson(geojson_obj):
    node_positions = []
    node_types = []
    for feature_obj in geojson_obj:
        if feature_obj.get('type') == 'Feature':
            feature_coordinates = feature_obj['geometry']['coordinates']
            node_positions.extend(feature_coordinates)
            feature_labels = [feature_obj['properties']['classification']['name']]*len(feature_coordinates)
            node_types.extend(feature_labels)
    return np.array(node_positions), node_types

In [None]:
node_positions, node_types = transpose_geojson(detected_cells_geojson)

In [None]:
set(node_types)

{'Connective', 'Dead', 'Epithelial', 'Inflammatory', 'Neoplastic'}

In [None]:
node_positions

array([[18186.09090909,   801.1038961 ],
       [18197.59292035,   815.59292035],
       [18198.96638655,   834.02521008],
       ...,
       [20459.7340824 ,  1939.54681648],
       [10196.14912281,  2927.88596491],
       [11058.5270936 ,  2956.82758621]])

The slides have a resolution of 0.48 microns per pixel. Since CellViT gives the cell coordinates in pixel space, constraining the neighbourhood radius to 20 µm means using a radius of 20/0.48 ≈ 41.7 pixels.
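The unit conversion as a small sketch. `microns_to_pixels` is a hypothetical helper, and it assumes every slide shares the same 0.48 µm/px resolution; if resolutions differed between slides, the value would have to be read per slide.

```python
MICRONS_PER_PIXEL = 0.48  # resolution stated for this dataset

def microns_to_pixels(length_um, mpp=MICRONS_PER_PIXEL):
    """Convert a physical length in microns to a length in pixel units."""
    return length_um / mpp

radius_px = microns_to_pixels(20)
print(round(radius_px, 2))  # 41.67
```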

In [None]:
# To efficiently build the graph, we will use the cKDTree of scipy
import numpy as np
from scipy.spatial import cKDTree

RADIUS = 20/0.48
def create_edges(node_positions, radius=RADIUS):
    spatial_index = cKDTree(node_positions)
    sparse_distances = spatial_index.sparse_distance_matrix(spatial_index, radius, output_type='coo_matrix')
    sparse_distances.eliminate_zeros()  # We eliminate all zeros from the matrix, we don't want self-loops
    pair_indices = np.stack([sparse_distances.row, sparse_distances.col], axis=0)
    distances = np.copy(sparse_distances.data)
    return pair_indices, distances
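As a sanity check, the sketch below compares the KD-tree approach used in `create_edges` with a brute-force radius graph on synthetic random points (the points and radius are made up). Both should give the same set of ordered pairs, with each undirected edge present in both directions and no self-loops. One caveat: `eliminate_zeros()` also drops pairs of distinct cells that happen to sit at exactly the same coordinates, since their distance is zero too.

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)
points = rng.uniform(0, 100, size=(50, 2))  # synthetic cell positions in pixels
radius = 20.0

# Brute force: all ordered pairs (i, j) with i != j and distance <= radius
diff = points[:, None, :] - points[None, :, :]
dist = np.sqrt((diff ** 2).sum(axis=-1))
mask = (dist <= radius) & ~np.eye(len(points), dtype=bool)
rows, cols = np.nonzero(mask)
brute = set(zip(rows.tolist(), cols.tolist()))

# KD-tree version, mirroring create_edges above
tree = cKDTree(points)
sparse = tree.sparse_distance_matrix(tree, radius, output_type='coo_matrix')
sparse.eliminate_zeros()  # drop the zero-distance self pairs
kd = set(zip(sparse.row.tolist(), sparse.col.tolist()))

print(brute == kd)  # expected: True
```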


In [None]:
edges, distances = create_edges(node_positions)

In [None]:
import numpy as np
import torch  # used by later cells

def edges_to_multiline_feature(edges: np.ndarray, node_pos: np.ndarray, properties=None):
    """Convert a [2, num_edges] edge index array to a GeoJSON MultiLineString Feature."""
    if properties is None:
        properties = {}
    # Indexing with edges.T gives shape [num_edges, 2, num_dims]:
    # one (start point, end point) pair per edge
    lines = node_pos[edges.T]
    coordinates = lines.tolist()
    geometry = {"type": "MultiLineString", "coordinates": coordinates}
    feature = {"type": "Feature", "geometry": geometry, "properties": properties}
    return feature

def show_edges(edges):
    import tissuumaps.jupyter as tj
    viewer = tj.loaddata(["../datasets/example/008069b542b0439ed69b194674051964.tiff"])
    pass
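A self-contained sketch of the same indexing trick on hypothetical toy data, wrapped in a GeoJSON `FeatureCollection` that could be written to a file for visual inspection. Keeping only one direction of each undirected edge is an assumption made here so that every line is drawn once rather than twice.

```python
import json
import numpy as np

# Hypothetical toy data: three nodes and a symmetric edge list
node_pos = np.array([[0.0, 0.0], [10.0, 0.0], [10.0, 10.0]])
edges = np.array([[0, 1, 1, 2],
                  [1, 0, 2, 1]])

# Keep one direction of each undirected edge to avoid drawing every line twice
undirected = edges[:, edges[0] < edges[1]]

# Same indexing as edges_to_multiline_feature: shape [num_edges, 2, 2]
lines = node_pos[undirected.T]
feature = {
    "type": "Feature",
    "geometry": {"type": "MultiLineString", "coordinates": lines.tolist()},
    "properties": {"name": "cell-graph edges"},
}
collection = {"type": "FeatureCollection", "features": [feature]}
geojson_text = json.dumps(collection)
print(lines.shape)  # (2, 2, 2)
```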

In [None]:
#show_edges(edges)

## Convert the dataset

We've included the preprocessed dataset here without any of the images. This is to keep the download size manageable since whole slide image datasets are often tens to hundreds of gigabytes in size.

In [None]:
import json
def convert_geojson_to_graph(geojson_string):
    geojson_object = json.loads(geojson_string)
    node_positions, node_classes = transpose_geojson(geojson_object)
    edges, distances = create_edges(node_positions)
    return {'edges': edges, 'edge_distance': distances, 'node_positions': node_positions, 'node_types': node_classes}
    
    

In [None]:
# Download the zipped dataset to ../datasets/PANDa/raw
!wget https://github.com/eryl/aida-gnn-workshop-code/releases/download/test_subset_v1/cell_detection.geojson.zip -P ../datasets/PANDa/raw


The dataset is a zip archive containing a csv file with all the labels and a directory called "cell_detection.geojson" containing the GeoJSON files. Instead of extracting the archive, we use Python's `zipfile` module to read the files in memory.
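A self-contained sketch of reading files from a zip archive without extracting it. The archive here is built in memory with hypothetical file names; it includes an explicit directory entry to show why skipping names ending in `/` is a sensible guard (whether the real archive contains such entries is not stated).

```python
import io
import json
import zipfile

# Build a small archive entirely in memory (hypothetical file names)
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as zf:
    zf.writestr("cell_detection.geojson/", "")  # an explicit directory entry
    zf.writestr("cell_detection.geojson/slide_a", json.dumps([{"type": "Feature"}]))
    zf.writestr("labels.csv", "image_id,isup_grade\nslide_a,3\n")

# Read it back without writing anything to disk
contents = {}
buffer.seek(0)
with zipfile.ZipFile(buffer) as zf:
    for name in zf.namelist():
        if name.endswith("/"):
            continue  # directory entries have no content to parse
        contents[name] = zf.read(name)  # raw bytes

print(sorted(contents))  # ['cell_detection.geojson/slide_a', 'labels.csv']
```

`json.loads` accepts these bytes directly, and `pd.read_csv(io.BytesIO(...))` can parse the csv entry, as in the cell below.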

In [None]:
import zipfile
from io import BytesIO

from tqdm.notebook import tqdm
import pandas as pd

GEOJSON_DIRECTORY = 'cell_detection.geojson'
LABEL_COLUMN = 'isup_grade'
ID_COLUMN = 'image_id'

with zipfile.ZipFile('../datasets/PANDa/raw/cell_detection.geojson.zip') as zf:
    graphs = dict()
    csv_file = None
    for name in tqdm(zf.namelist(), desc="Creating graph data"):
        if name.endswith('/'):
            continue  # skip directory entries, they have no content to parse
        if GEOJSON_DIRECTORY in name:
            *_, image_id = name.split('/')
            geojson_string = zf.read(name)
            graph_data = convert_geojson_to_graph(geojson_string)
            graph_data['image_id'] = image_id
            graphs[image_id] = graph_data
        elif name.endswith('.csv'):
            csv_file = pd.read_csv(BytesIO(zf.read(name)))

if csv_file is None:
    raise RuntimeError("No csv file in archive")

# Attach the ISUP grade from the csv file to each graph
for row in csv_file.to_dict('records'):
    image_id = row[ID_COLUMN]
    if image_id in graphs:
        graphs[image_id]['label'] = row[LABEL_COLUMN]


In [None]:
# We can take a look at an arbitrary graph from the dictionary using next(iter(dict_obj.items()))
g_id, g_dict = next(iter(graphs.items()))
g_dict

{'edges': array([[ 1717,  1717,  1730, ...,  1630, 16158, 16158],
        [ 1728,  1729,  1718, ..., 16158, 16153,  1630]], dtype=int32),
 'edge_distance': array([ 0.22300516, 24.51444931,  0.1238014 , ..., 37.56006034,
        37.99611239, 37.56006034]),
 'node_positions': array([[ 7661.1943734 ,  9505.08184143],
        [ 7609.30769231,  9557.06318681],
        [ 7576.50592885,  9566.48221344],
        ...,
        [ 6250.01941748, 22480.95631068],
        [ 6070.08095238, 22506.25714286],
        [ 6091.        , 22704.73786408]]),
 'node_types': ['Neoplastic',
  'Neoplastic',
  'Neoplastic',
  ...

### Formatting the target labels

Now that we've extracted the data into Python dictionaries, we will convert them into a Pytorch Geometric dataset. First we need to gather information about our categorical variables: the cell types and the labels.

In [None]:
# We can inspect what labels we have by putting them in a set
labels = set()
for graph in graphs.values():
    labels.add(graph['label'])
labels

{0, 1, 2, 3, 4, 5}

### Encoding node types
The cell detection framework has also classified the cells into different types, which we need to encode as integers for the neural network. We start by creating a mapping from the string-valued types to integers. It's preferable for this mapping to be stable across runs, which is why we sort the string values before creating it.
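Why sorting matters: Python randomizes string hashing per interpreter process (see `PYTHONHASHSEED`), so iterating over a `set` of strings can yield a different order in different runs. Sorting first makes the mapping depend only on which types exist. The sketch below uses two hypothetical observation orders; `build_type_map` is an illustrative helper.

```python
# Two different orders in which the same cell types might be encountered
seen_in_run_a = ["Neoplastic", "Connective", "Dead", "Epithelial", "Inflammatory"]
seen_in_run_b = ["Inflammatory", "Epithelial", "Neoplastic", "Dead", "Connective"]

def build_type_map(observed):
    """Map each distinct string to an integer, independent of observation order."""
    return {name: i for i, name in enumerate(sorted(set(observed)))}

map_a = build_type_map(seen_in_run_a)
map_b = build_type_map(seen_in_run_b)
print(map_a == map_b)  # True

# The inverse mapping turns integer predictions back into readable names
inverse_map = {i: name for name, i in map_a.items()}
print(inverse_map[4])  # Neoplastic
```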

In [None]:
node_types = set()
for graph in graphs.values():
    node_types.update(graph['node_types'])
node_type_map = {node_type: i for i,node_type in enumerate(sorted(node_types))}
node_type_map

{'Connective': 0,
 'Dead': 1,
 'Epithelial': 2,
 'Inflammatory': 3,
 'Neoplastic': 4}

We now go through all the graphs and convert the string-valued node types to integer tensors.

In [None]:
for graph in graphs.values():
    node_types = graph['node_types']
    node_types = torch.tensor([node_type_map[node_type] for node_type in node_types])
    graph['node_types'] = node_types

We now have everything we need to create the `Data` objects for Pytorch Geometric. Recall the signature of this class:

```python
class Data(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Tensor] = None, pos: Optional[Tensor] = None, **kwargs)
```

The parameters for the constructor are as follows:

 - **x** (torch.Tensor, optional) – Node feature matrix with shape [num_nodes, num_node_features]. (default: None)
 - **edge_index** (LongTensor, optional) – Graph connectivity in COO format with shape [2, num_edges]. (default: None)
 - **edge_attr** (torch.Tensor, optional) – Edge feature matrix with shape [num_edges, num_edge_features]. (default: None)
 - **y** (torch.Tensor, optional) – Graph-level or node-level ground-truth labels with arbitrary shape. (default: None)
 - **pos** (torch.Tensor, optional) – Node position matrix with shape [num_nodes, num_dimensions]. (default: None)

We will use the following mapping from the graph dictionaries we have created:
 - **x** = `graph['node_types']` (a vector with one integer node-type index per node, rather than a [num_nodes, num_node_features] matrix)
 - **edge_index** = `graph['edges']`
 - **edge_attr** = `graph['edge_distance']`
 - **y** = `graph['label']`
 - **pos** = `graph['node_positions']`
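Storing **x** as integer indices works if the model starts with an embedding layer (for example `torch.nn.Embedding`). An alternative, sketched below in NumPy under that assumption, is to one-hot encode the types into a [num_nodes, num_node_types] matrix that matches the documented shape of **x**; the four-node graph is hypothetical. In PyTorch, `torch.nn.functional.one_hot(x, num_classes=5).float()` does the same.

```python
import numpy as np

node_type_map = {"Connective": 0, "Dead": 1, "Epithelial": 2,
                 "Inflammatory": 3, "Neoplastic": 4}

# Integer-encoded node types for a hypothetical 4-node graph
type_ids = np.array([4, 4, 0, 3])

# Row i of the identity matrix is the one-hot vector for type i, so indexing
# with type_ids produces a [num_nodes, num_node_types] feature matrix.
one_hot = np.eye(len(node_type_map), dtype=np.float32)[type_ids]
print(one_hot.shape)  # (4, 5)
```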
 


In [None]:
import torch
from torch_geometric.data import Data

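def create_pyg_data(graph_dict):
    # Node types are already an integer tensor with one entry per node
    x = graph_dict['node_types']
    # PyG expects edge_index as int64 (torch.long); SciPy returns int32 indices
    edge_index = torch.as_tensor(graph_dict['edges'], dtype=torch.long)
    edge_attr = torch.as_tensor(graph_dict['edge_distance'], dtype=torch.float)
    y = graph_dict['label']
    pos = torch.as_tensor(graph_dict['node_positions'], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos)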

In [None]:
pyg_data = [create_pyg_data(g) for g in graphs.values()]
pyg_data[0]

Data(x=[17367], edge_index=[2, 111490], edge_attr=[111490], y=0, pos=[17367, 2])
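Before training, it can be worth sanity-checking the arrays that go into each `Data` object. The helper below is hypothetical (not part of PyG) and checks the requirements listed for the constructor: `edge_index` of shape [2, num_edges] with valid node indices, one edge attribute per edge, plus the properties we expect of our radius graph (no self-loops, every edge present in both directions). The toy edge list uses int32, like the SciPy output shown earlier, and is cast to int64 at the end.

```python
import numpy as np

def check_graph_arrays(edge_index, num_nodes, edge_attr=None):
    """Sanity-check a COO edge list (hypothetical helper, not part of PyG).

    Raises ValueError on structural problems and returns True if every
    edge (i, j) also appears as (j, i), as expected for a radius graph.
    """
    edge_index = np.asarray(edge_index)
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        raise ValueError("edge_index must have shape [2, num_edges]")
    if edge_index.size and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        raise ValueError("edge_index refers to a node that does not exist")
    if np.any(edge_index[0] == edge_index[1]):
        raise ValueError("edge_index contains self-loops")
    if edge_attr is not None and len(edge_attr) != edge_index.shape[1]:
        raise ValueError("expected one edge attribute per edge")
    pairs = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))
    return all((j, i) in pairs for i, j in pairs)

# Toy example with int32 indices, like the SciPy output shown earlier
edge_index = np.array([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=np.int32)
print(check_graph_arrays(edge_index, num_nodes=3, edge_attr=[1.0, 1.0, 2.5, 2.5]))  # True

# Cast to int64 before handing the array to PyG (torch.long)
edge_index = edge_index.astype(np.int64)
```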

## Pytorch Geometric Dataset

While PyG's DataLoader accepts a plain list of `Data` objects, we get convenient features such as caching of the processed results by subclassing one of PyG's dataset classes; here we use `InMemoryDataset`. Its interface has a few methods, which we override with the code we defined above.
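The caching idea can be sketched without PyG: process the raw data once, save the result under a `processed` directory, and on later runs load the saved file instead of reprocessing. PyG datasets follow a similar raw/processed pattern, only calling `process()` when the files named in `processed_file_names` are missing. The class below is a hypothetical stand-in using `pickle`, not PyG's implementation.

```python
import pickle
import tempfile
from pathlib import Path

class CachedGraphDataset:
    """Hypothetical minimal version of the raw/processed caching pattern."""

    def __init__(self, root):
        self.processed_path = Path(root) / "processed" / "graphs.pkl"
        self.loaded_from_cache = self.processed_path.exists()
        if not self.loaded_from_cache:
            self.processed_path.parent.mkdir(parents=True, exist_ok=True)
            self.process()
        with open(self.processed_path, "rb") as fp:
            self.graphs = pickle.load(fp)

    def process(self):
        # Stand-in for the expensive GeoJSON -> graph conversion
        graphs = [{"num_nodes": 3, "label": 0}, {"num_nodes": 5, "label": 2}]
        with open(self.processed_path, "wb") as fp:
            pickle.dump(graphs, fp)

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]

with tempfile.TemporaryDirectory() as root:
    first = CachedGraphDataset(root)   # runs process() and writes the cache
    second = CachedGraphDataset(root)  # finds the cached file, skips process()
    print(first.loaded_from_cache, second.loaded_from_cache, len(second))  # False True 2
```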

In [None]:
from pathlib import Path
import zipfile
from io import BytesIO
import json

import torch
from torch_geometric.data import InMemoryDataset, download_url, Data
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
from scipy.spatial import cKDTree


GEOJSON_DIRECTORY = 'cell_detection.geojson'
LABEL_COLUMN = 'isup_grade'
ID_COLUMN = 'image_id'
MPP = 0.48
RADIUS_MICRONS = 20
RADIUS = RADIUS_MICRONS / MPP


class PANDaGraphDataset(InMemoryDataset):
    data_url = 'https://github.com/eryl/aida-gnn-workshop-code/releases/download/PANDa_workshop_data_v1/PANDa_{}.zip'
    data_split = ''
    
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices, self.node_type_map = torch.load(self.processed_paths[0])
    
    @property
    def raw_file_names(self):
        return [f'PANDa_{self.data_split}.zip']

    @property
    def processed_file_names(self):
        return [f'PANDa_{self.data_split}.pt']

    def download(self):
        # Download to `self.raw_dir`.
        url = self.data_url.format(self.data_split)
        download_url(url, self.raw_dir)

    def _transpose_geojson(self, geojson_obj):
        node_positions = []
        node_types = []
        for feature_obj in geojson_obj:
            if feature_obj.get('type') == 'Feature':
                feature_coordinates = feature_obj['geometry']['coordinates']
                node_positions.extend(feature_coordinates)
                feature_labels = [feature_obj['properties']['classification']['name']]*len(feature_coordinates)
                node_types.extend(feature_labels)
        return np.array(node_positions), node_types

    def _create_edges(self, node_positions, radius=RADIUS):
        spatial_index = cKDTree(node_positions)
        sparse_distances = spatial_index.sparse_distance_matrix(spatial_index, radius, output_type='coo_matrix')
        sparse_distances.eliminate_zeros()  # We eliminate all zeros from the matrix, we don't want self-loops
        pair_indices = np.stack([sparse_distances.row, sparse_distances.col], axis=0)
        distances = np.copy(sparse_distances.data)
        return pair_indices, distances

    def _convert_geojson_to_graph(self, geojson_string):
        geojson_object = json.loads(geojson_string)
        node_positions, node_classes = self._transpose_geojson(geojson_object)
        edges, distances = self._create_edges(node_positions)
        return {'edges': edges, 'edge_distance': distances, 'node_positions': node_positions, 'node_types': node_classes}

    def _read_geojson_archive(self):
        graphs = dict()
        for filename in self.raw_file_names:
            file_path = Path(self.raw_dir) / filename
            with zipfile.ZipFile(file_path) as zf:
                csv_file = None
                for name in tqdm(zf.namelist(), desc="Creating graph data"):
                    if name.endswith('/'):
                        continue  # skip directory entries
                    if GEOJSON_DIRECTORY in name:
                        *_, image_id = name.split('/')
                        geojson_string = zf.read(name)
                        graph_data = self._convert_geojson_to_graph(geojson_string)
                        graph_data['image_id'] = image_id
                        graphs[image_id] = graph_data
                    elif name.endswith('.csv'):
                        csv_file = pd.read_csv(BytesIO(zf.read(name)))

                if csv_file is None:
                    raise RuntimeError(f"No csv file in archive {file_path}")
                # Attach the ISUP grade to each graph from this archive
                for row in csv_file.to_dict('records'):
                    image_id = row[ID_COLUMN]
                    if image_id in graphs:
                        graphs[image_id]['label'] = row[LABEL_COLUMN]
        return graphs
    
    def _convert_node_types_inplace(self, graphs):
        node_types = set()
        for graph in graphs.values():
            node_types.update(graph['node_types'])
        node_type_map = {node_type: i for i,node_type in enumerate(sorted(node_types))}
        for graph in graphs.values():
            node_types = graph['node_types']
            node_types = torch.tensor([node_type_map[node_type] for node_type in node_types])
            graph['node_types'] = node_types
        return node_type_map

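    def _convert_to_pyg_data(self, graph):
        x = graph['node_types']
        # PyG expects int64 (torch.long) edge indices; SciPy returns int32
        edge_index = torch.as_tensor(graph['edges'], dtype=torch.long)
        edge_attr = torch.as_tensor(graph['edge_distance'], dtype=torch.float)
        y = graph['label']
        pos = torch.as_tensor(graph['node_positions'], dtype=torch.float)
        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y, pos=pos)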

    def process(self):
        # Read the raw archive into one large `Data` list
        graphs = self._read_geojson_archive()
        node_type_map = self._convert_node_types_inplace(graphs)

        data_list = [self._convert_to_pyg_data(graph) for graph in tqdm(graphs.values(), desc="Converting graphs")]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices, node_type_map), self.processed_paths[0])

    
class PANDaGraphDatasetTrain(PANDaGraphDataset):
    data_split = 'train'

class PANDaGraphDatasetDev(PANDaGraphDataset):
    data_split = 'dev'

class PANDaGraphDatasetTest(PANDaGraphDataset):
    data_split = 'test'


In [None]:
train_dataset = PANDaGraphDatasetTrain('../datasets/PANDa')
test_dataset = PANDaGraphDatasetTest('../datasets/PANDa')
dev_dataset = PANDaGraphDatasetDev('../datasets/PANDa')

In [None]:
train_dataset

PANDaGraphDatasetTrain(1596)

### DataLoaders

Pytorch Geometric has its own data loaders. They batch graphs efficiently by exploiting the fact that all PyG operations rely on sparse implementations: the graphs in a batch are combined into one large graph made of disconnected components, so the algorithms don't need to keep special track of individual examples as long as the operations follow the graph structure.
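The batching idea can be sketched in NumPy: concatenate the node features, shift each graph's edge indices by the number of nodes that came before it, and record which graph every node belongs to. PyG stores that last piece as the `batch` vector, which graph-level pooling functions such as `global_mean_pool(x, batch)` use to aggregate per graph. `batch_graphs` and the two toy graphs are hypothetical, not PyG's internal implementation.

```python
import numpy as np

def batch_graphs(graphs):
    """Merge graphs into one disjoint graph, mirroring the idea behind PyG batching.

    Each graph is a dict with 'x' ([num_nodes, F]) and 'edge_index' ([2, E]).
    """
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        num_nodes = g["x"].shape[0]
        xs.append(g["x"])
        # Shift node indices so they point into the concatenated node matrix
        edge_indices.append(g["edge_index"] + offset)
        # batch[i] records which graph node i came from
        batch.append(np.full(num_nodes, graph_id, dtype=np.int64))
        offset += num_nodes
    return (np.concatenate(xs, axis=0),
            np.concatenate(edge_indices, axis=1),
            np.concatenate(batch))

g1 = {"x": np.ones((2, 3)), "edge_index": np.array([[0, 1], [1, 0]])}
g2 = {"x": np.zeros((3, 3)), "edge_index": np.array([[0, 1, 2], [1, 2, 0]])}
x, edge_index, batch = batch_graphs([g1, g2])
print(edge_index)
# [[0 1 2 3 4]
#  [1 0 3 4 2]]
print(batch)  # [0 0 1 1 1]
```

Because no edge connects nodes from different graphs, message passing on the merged graph gives the same result as running on each graph separately.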

In [None]:
from torch_geometric.loader import DataLoader
batch_size = 32
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, drop_last=False, shuffle=True, num_workers=4)
dev_dataloader = DataLoader(dev_dataset, batch_size=batch_size, drop_last=False, shuffle=False, num_workers=4)
