# HIGT Usage Tutorial:

HIGT takes whole slide image (WSI) data at three levels as input: the thumbnail, 5x magnification and 10x magnification. To help readers understand and reproduce the pipeline, we provide a complete tutorial. The data can be downloaded from the [TCGA data](https://portal.gdc.cancer.gov/) portal. This tutorial is mainly built on [CLAM](https://github.com/mahmoodlab/CLAM), with some changes to suit HIGT's needs.

This tutorial is divided into three main parts: Preprocessing, GPU Training, and Testing and Evaluation. The details are as follows:

## 1. Preprocessing:

This section is mainly divided into five parts: Basic Information Statistics, Patch Segmentation, Feature Extraction, Hierarchical Graph (Tree) Generation and Dataset Split. The details of each part are as follows:

### 1.1 Basic Information Statistics:

Use the `generate_pl_bm` function to collect basic information about the whole slide images (WSIs). It does the following three things:

1. **Objective Lens Magnification**:   
Read and record the default objective lens magnification of each WSI file. The recorded data is saved in a file named `bm_mag{target_magnification}x_patch{base_patch_size}.csv`.

2. **Process List Generation**:   
Based on the target magnification and the patch size at the target magnification, generate the corresponding `pl_mag{target_magnification}x_patch{base_patch_size}_{target_patch_size}.csv`, where `target_patch_size` is the equivalent patch size at the slide's default magnification.

3. **Data Cleaning**:  
Skip WSIs that have no recorded default objective magnification.
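
The grouping logic behind the process lists can be illustrated without opening any slide. The following is a minimal sketch, not the tutorial's implementation: the helper name `group_by_native_patch_size` and the slide names are hypothetical, and the magnifications are given directly instead of being read with OpenSlide.

```python
from collections import defaultdict


def group_by_native_patch_size(base_mags, base_patch_size, target_mag):
    """Group slide names by the patch size needed at their default magnification.

    base_mags maps a slide name to its default objective magnification,
    or to None when the slide has no recorded magnification.
    """
    groups = defaultdict(list)
    for name, base_mag in base_mags.items():
        if base_mag is None:  # data cleaning: skip slides without metadata
            continue
        # A patch of base_patch_size pixels at target_mag covers this many
        # pixels at the slide's default magnification.
        native_patch_size = int(base_patch_size * base_mag / target_mag)
        groups[native_patch_size].append(name)
    # One process list per native patch size, named like the tutorial's files.
    return {
        f"pl_mag{target_mag}x_patch{base_patch_size}_{size}.csv": names
        for size, names in groups.items()
    }


# Hypothetical slides: two scanned at 20x, one at 40x, one without metadata.
slides = {"a.svs": 20, "b.svs": 20, "c.svs": 40, "d.svs": None}
process_lists = group_by_native_patch_size(slides, base_patch_size=512, target_mag=5)
for file_name, names in process_lists.items():
    print(file_name, names)
```

Slides with different default magnifications end up in different process lists, because each list is later passed to the patching script with its own `--patch_size`.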

##### (1) Parameter Description:

- **`--WSI_dir`**: Directory containing the WSI files.
- **`--save_dir`**: Directory in which the CSV files are saved.
- **`--base_patch_size`**: The patch size used at the target objective lens magnification.
- **`--target_mag`**: Target objective lens magnification.

##### (2) Usage Example:

In the following experiments, we use a dataset whose default objective lens magnification is 20x, called `WSI_bm20`. The basic statistics can be generated as follows:

```python
generate_pl_bm(
        WSI_dir="/path/to/exp/WSI_bm20", 
        save_dir="/path/to/dataset_csv/WSI_bm20/", 
        base_patch_size=512, 
        target_mag=5
)

generate_pl_bm(
        WSI_dir="/path/to/exp/WSI_bm20", 
        save_dir="/path/to/dataset_csv/WSI_bm20/", 
        base_patch_size=512, 
        target_mag=10
)
```

**Note**: Please replace the paths and values with those relevant to your specific setup.

```python
def generate_pl_bm(
        WSI_dir, 
        save_dir, 
        base_patch_size, 
        target_mag
    ):
    '''
    WSI_dir: directory containing the WSI files
        WSI_dir/1.svs,
        WSI_dir/2.svs,
        WSI_dir/3.svs,
        ...
    save_dir: directory for the process list csv files and the base magnification csv file
        save_dir/pl_mag20x_patch512_1024.csv,
        save_dir/pl_mag20x_patch512_2048.csv,
        ...
        save_dir/bm_mag20x_patch512.csv
    '''
    import glob, os
    import openslide
    import pandas as pd

    os.makedirs(save_dir, exist_ok=True)

    # Maps the patch size at the default magnification to a list of slide names
    process_list = {} 
    base_mag_csv = {
        "slide_path": [],
        "base_mag": []
    }

    for WSI in sorted(glob.glob(os.path.join(WSI_dir, "*"))):
        slide = openslide.open_slide(WSI)
        wsi_name = os.path.basename(WSI)

        # Data cleaning: skip slides without a recorded objective magnification
        objective_power = slide.properties.get(openslide.PROPERTY_NAME_OBJECTIVE_POWER)
        if objective_power is None:
            continue

        # The property is a string such as "20" or "20.0"
        base_mag = int(float(objective_power))
        # Patch size at the default magnification that corresponds to
        # base_patch_size at the target magnification
        target_min_patch_size = int(base_patch_size*(base_mag/target_mag))

        # Update for process_list
        if target_min_patch_size not in process_list:
            process_list[target_min_patch_size] = [wsi_name]
        else:
            process_list[target_min_patch_size].append(wsi_name)

        # Update for base_mag_csv
        base_mag_csv["slide_path"].append(WSI)
        base_mag_csv["base_mag"].append(base_mag)

    # Save the base magnification csv once, after all slides have been read
    df = pd.DataFrame(base_mag_csv)
    df.to_csv(os.path.join(save_dir, f"bm_mag{target_mag}x_patch{base_patch_size}.csv"))

    # Save one process list csv per patch size
    for k in process_list.keys():
        df = pd.DataFrame({
            "slide_id": process_list[k]
        })
        df.to_csv(os.path.join(save_dir, f"pl_mag{target_mag}x_patch{base_patch_size}_{k}.csv"))
```

### 1.2 Patch Segmentation:

Patches are cut at the slide's default objective magnification, so the patch size at the target magnification must first be converted to the equivalent patch size at the default magnification. For example, a 512-pixel patch at 5x corresponds to a 2048-pixel patch on a 20x slide.

HIGT requires patch results at 5x and 10x magnification. First, patches at 5x magnification are cut with CLAM's `create_patches_fp.py`. The 10x patches of the same dataset are then derived from the 5x results with the `generate_coords_file` function below.

#### 1.2.1 Patch segmentation at 5x magnification:

##### (1) Parameter Description:

- **`--source`**: Slide file root directory.
- **`--save_dir`**: Directory to save the results of patch cutting.
- **`--patch_size`**: The size of the patch at the default objective magnification.
- **`--step_size`**: Step size for cutting patches. If no overlap is required, this should be the same as `patch_size`.
- **`--seg`**: Flag to enable tissue segmentation (generates a mask).
- **`--patch`**: Flag to enable patch extraction.
- **`--stitch`**: Flag to enable generation of a stitched preview image.
- **`--process_list`**: Process list, mainly utilizing the `slide_id` column; other columns can remain at their default settings.

##### (2) Usage Example:
The command line for patch segmentation can be structured as follows:

```bash
python create_patches_fp.py --source /path/to/exp/WSI_bm20 --save_dir /path/to/exp/segmentation/WSI_bm20/mag5x_patch512_2048 --patch_size 2048 --step_size 2048 --seg --patch --stitch --process_list /path/to/dataset_csv/WSI_bm20/pl_mag5x_patch512_2048.csv
```

**Note**: Please replace the paths and values with those relevant to your specific setup.

#### 1.2.2 Generation of patch segmentation at 10x magnification:

The patch coordinates at 10x magnification are generated from the 5x results with the `generate_coords_file` function below: each 5x patch is split into four 10x patches.
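
The splitting step can be sketched on plain coordinate tuples. This is a minimal illustration under assumptions: `split_patches` is a hypothetical helper, the coordinates are made up, and HDF5 reading and writing are left out.

```python
def split_patches(coords, parent_size, factor=2):
    """Split each square patch of side parent_size into factor x factor children.

    coords holds the top-left (x, y) corner of every parent patch, expressed
    at the slide's default magnification.
    """
    child_size = parent_size // factor
    children = set()  # a set removes duplicate coordinates
    for x, y in coords:
        for dx in range(factor):
            for dy in range(factor):
                children.add((x + dx * child_size, y + dy * child_size))
    return sorted(children), child_size


# Two hypothetical 5x patches on a 20x slide (2048 pixels wide each).
coords_5x = [(0, 0), (2048, 0)]
coords_10x, size_10x = split_patches(coords_5x, parent_size=2048)
print(size_10x)
print(coords_10x)
```

Because every 10x patch lies entirely inside one 5x patch, the parent of a 10x patch can later be recovered by rounding its coordinates down to a multiple of the 5x patch size.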

##### (1) Parameter Description:
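- **`--ori_mag_seg_path`**: Path of the existing patch segmentation result (here, the 5x result). The folder name must follow the `mag{magnification}x_patch{base_patch_size}_{patch_size}` pattern, because the function parses the magnification and patch size from it.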
- **`--target_mag`**: Target magnification.

##### (2) Usage Example:

The code for the generation of the patch segmentation can be structured as follows:

```python
generate_coords_file(
    ori_mag_seg_path="/path/to/exp/segmentation/WSI_bm20/mag5x_patch512_2048", 
    target_mag=10
)
```

**Note**: Please replace the paths and values with those relevant to your specific setup.

```python
def generate_coords_file(ori_mag_seg_path, target_mag):
    
    import os, h5py, glob
    import numpy as np
    from wsi_core.wsi_utils import save_hdf5

    # The folder name is expected to look like "mag5x_patch512_2048"
    ori_mag_seg_path = ori_mag_seg_path.rstrip("/")
    folder_name = os.path.basename(ori_mag_seg_path)
    cur_mag = int(folder_name.split("_")[0].replace("mag","").replace("x",""))
    cur_patch_size = int(folder_name.split("_")[-1])
    target_patch_size = int(cur_patch_size/int(target_mag/cur_mag))

    # The output folder sits next to the input folder
    save_path = os.path.dirname(ori_mag_seg_path)+f"/mag{target_mag}x_patch512_{target_patch_size}/patches/"
    os.makedirs(save_path, exist_ok=True)

    for h5 in glob.glob(ori_mag_seg_path+"/patches/*"):
        with h5py.File(h5,'r') as h5_content:
            coords = h5_content["coords"][:]
            # Copy the attributes of the original file, updating only the patch size
            attr = {
                'patch_size' :            target_patch_size, 
                'patch_level' :           h5_content["coords"].attrs["patch_level"],
                'downsample':             h5_content["coords"].attrs["downsample"],
                'downsampled_level_dim' : h5_content["coords"].attrs["downsampled_level_dim"],
                'level_dim':              h5_content["coords"].attrs["level_dim"],
                'name':                   h5_content["coords"].attrs["name"],
                'save_path':              save_path
            }

        # Split every patch into 2x2 smaller patches
        # (this assumes target_mag is twice the current magnification)
        coords_ = []
        for coord in coords:
            x,y = coord
            coords_.append([x,y])
            coords_.append([x+target_patch_size,y])
            coords_.append([x,y+target_patch_size])
            coords_.append([x+target_patch_size,y+target_patch_size])
            
        coords_ = np.array(coords_).astype(coords.dtype)
        unique_coords = np.unique(coords_, axis=0)

        save_hdf5(
            save_path+os.path.basename(h5), 
            {"coords": unique_coords}, 
            {"coords":attr}, 
            mode="w"
        )
```

### 1.3 Feature Extraction:

Feature extraction is performed based on the results of patch segmentation. This section describes the parameters and provides examples for performing feature extraction. Feature extraction for the thumbnail is implemented by the `feature_extract_thumb` function below.

#### 1.3.1 Feature Extraction at 5x and 10x magnification:

Feature extraction of patch images based on existing patch cutting results.

##### (1) Parameter Description:

- **`--csv_path`**: Specifies the path to the process list. It mainly uses the `slide_id` column, and other columns can be omitted.
- **`--data_h5_dir`**: Specifies the root directory for the results of patch segmentation.
- **`--data_slide_dir`**: Specifies the root directory for slide files.
- **`--feat_dir`**: Specifies the root directory for saving extracted features.
- **`--batch_size`**: Defines the batch size for processing (e.g., `32`, `64`, etc.).
- **`--target_patch_size`**: Specifies the size of the input image.
- **`--slide_ext`**: Defines the file suffix for slide files.

##### (2) Usage Example:

The command line for feature extraction can be structured as follows:

```bash
# 5x Magnification
python feature_extraction.py --csv_path /path/to/dataset_csv/WSI_bm20/pl_mag5x_patch512_2048.csv --data_h5_dir /path/to/exp/segmentation/WSI_bm20/mag5x_patch512_2048 --data_slide_dir /path/to/exp/WSI_bm20 --feat_dir /path/to/exp/extracted_feature/WSI_bm20/mag5x_patch512_2048 --batch_size 256 --target_patch_size 256 --slide_ext .svs

# 10x Magnification
python feature_extraction.py --csv_path /path/to/dataset_csv/WSI_bm20/pl_mag10x_patch512_1024.csv --data_h5_dir /path/to/exp/segmentation/WSI_bm20/mag10x_patch512_1024 --data_slide_dir /path/to/exp/WSI_bm20 --feat_dir /path/to/exp/extracted_feature/WSI_bm20/mag10x_patch512_1024 --batch_size 256 --target_patch_size 256 --slide_ext .svs
```

**Note**: Please replace the paths and values with those relevant to your specific setup.

#### 1.3.2 Feature Extraction for Thumbnail:

Use the `feature_extract_thumb` function to extract features from WSI thumbnail images.

##### (1) Parameter Description:
- **`--data_slide_dir`**: Specifies the root directory for slide files.
- **`--cuda_id`**: ID of the GPU to use.
- **`--feat_dir`**: Specifies the root directory for saving extracted features.

##### (2) Usage Example:

The code for feature extraction of thumbnail can be structured as follows:

```python
feature_extract_thumb(
    data_slide_dir="/path/to/exp/WSI_bm20",
    cuda_id=0,
    feat_dir="/path/to/exp/extracted_feature/WSI_bm20"
)
```

**Note**: Please replace the paths and values with those relevant to your specific setup.

```python
def feature_extract_thumb(
    data_slide_dir,
    cuda_id,
    feat_dir
):
    import openslide, glob, torch, h5py, os
    from torchvision import transforms
    from models.resnet_custom import resnet50_baseline
    device = torch.device(f"cuda:{cuda_id}" if torch.cuda.is_available() else "cpu")
    
    model = resnet50_baseline(pretrained=True).to(device)
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # generate_graph_tree reads thumbnail features from
    # {feature_path}/thumbnail/h5_files/{slide_id}.h5, so write them there
    out_root_path = f'{feat_dir}/thumbnail/h5_files/'
    os.makedirs(out_root_path, exist_ok=True)

    with torch.no_grad():

        for slide_path in glob.glob(data_slide_dir+"/*"):
            slide = openslide.OpenSlide(slide_path)
            # Thumbnail that fits within 512x512 pixels
            thumbnail_image = slide.get_thumbnail((512,512))
            thumbnail_tensor = transform(thumbnail_image).to(device)

            # Add a batch dimension before the forward pass
            thumbnail_feat = model(thumbnail_tensor.unsqueeze(0))

            # File name without the slide extension, e.g. "1.svs" -> "1.h5"
            slide_id = os.path.splitext(os.path.basename(slide_path))[0]
            with h5py.File(out_root_path+slide_id+".h5", 'w') as h5f:
                h5f.create_dataset('features', data=thumbnail_feat.cpu().numpy())
                h5f.create_dataset('coords', data=[[0,0]])
            print(os.path.basename(slide_path)," done")
```

### 1.4 Graph (Tree) Generation:
Based on the thumbnail, 5x and 10x features extracted above, use the `generate_graph_tree()` function to construct a heterogeneous graph (tree). The thumbnail is the root node, the 5x patches are region nodes, and the 10x patches are patch nodes.
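
The tree part of the graph assigns each node a parent: the thumbnail has none, every region points to the thumbnail, and every patch points to the region that contains it. The following is a minimal sketch of that lookup on plain tuples; `build_node_tree` and the coordinates are hypothetical and only illustrate the indexing scheme.

```python
def build_node_tree(region_coords, patch_coords, region_size):
    """Return the parent index of every node: thumbnail, regions, then patches."""
    # Node 0 is the thumbnail, so regions occupy indices 1..len(region_coords).
    region_index = {tuple(c): i + 1 for i, c in enumerate(region_coords)}
    # The thumbnail has no parent (-1); every region's parent is node 0.
    node_tree = [-1] + [0] * len(region_coords)
    for x, y in patch_coords:
        # Round down to the top-left corner of the enclosing region.
        parent_xy = ((x // region_size) * region_size,
                     (y // region_size) * region_size)
        node_tree.append(region_index[parent_xy])
    return node_tree


# Hypothetical coordinates at the default magnification of a 20x slide.
regions = [(0, 0), (2048, 0)]
patches = [(0, 0), (1024, 0), (2048, 1024)]
print(build_node_tree(regions, patches, region_size=2048))
```

The dictionary lookup replaces a scan over all region coordinates for every patch, which matters when a slide has many thousands of patches.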

##### (1) Parameter Description:
- **`--target_mags`**: List of all target magnifications. `float("inf")` denotes the thumbnail.
- **`--bm_path`**: Path to the base magnification CSV file generated in Section 1.1.
- **`--feature_path`**: Specifies the root directory in which the extracted features were saved.
- **`--base_patch_size`**: The patch size at the target magnification.
- **`--save_path`**: Directory in which the generated graphs are saved.

##### (2) Usage Example:

The code for graph generation can be structured as follows:

```python
generate_graph_tree(
     target_mags = [float("inf"), 5, 10],
     bm_path = "/path/to/dataset_csv/WSI_bm20/bm_mag5x_patch512.csv",
     feature_path = "/path/to/exp/extracted_feature/WSI_bm20/",
     base_patch_size = 512,
     save_path = "/path/to/exp/extracted_feature/WSI_bm20/"
)
```
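
Within one level, `generate_graph_tree` connects each node to its 8 spatial neighbours on the patch grid. The following is a minimal sketch of that idea, assuming top-left patch coordinates on a regular grid. `neighbour_edges` is a hypothetical helper that emits each directed edge exactly once, which is a simplification rather than a copy of the function in this tutorial.

```python
def neighbour_edges(coords, size, offset=0):
    """Directed edges between grid nodes that touch (8-connectivity).

    coords holds top-left (x, y) corners, size is the patch side length,
    and offset is the index of the first node of this level in the graph.
    """
    index = {tuple(c): i + offset for i, c in enumerate(coords)}
    steps = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1)
             if (dx, dy) != (0, 0)]
    start, end = [], []
    for (x, y), i in index.items():
        for dx, dy in steps:
            j = index.get((x + dx * size, y + dy * size))
            if j is not None:  # the neighbour exists in the tissue region
                start.append(i)
                end.append(j)
    return [start, end]


# Three adjacent hypothetical patches and one isolated patch; node 0 is
# reserved for the thumbnail, hence offset=1.
coords = [(0, 0), (1024, 0), (1024, 1024), (4096, 4096)]
edges = neighbour_edges(coords, size=1024, offset=1)
print(list(zip(*edges)))
```

The returned `[start, end]` pair has the same layout as an `edge_index` list, so it can be converted to a tensor in the same way.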

```python
def generate_graph_tree(
    target_mags, bm_path, feature_path, base_patch_size, save_path
):
    import os, h5py, torch
    from torch_geometric.data import Data
    import numpy as np
    import pandas as pd

    def get_8nb_coords(x, y, size):
        # Top-left corners of the 8 patches surrounding the patch at (x, y)
        return [
            [x-size, y-size], [x, y-size], [x+size, y-size],
            [x-size, y], [x+size, y],
            [x-size, y+size], [x, y+size], [x+size, y+size]
        ]

    def get_edge_index(feature_i, region_patch_size, patch_patch_size):
    
        start = []
        end = []

        len_thumbnail = 1
    
        region_mag = list(feature_i.keys())[1]
        region_features = feature_i[region_mag]["features"]
        region_coords = feature_i[region_mag]["coords"]
        len_region_node = len(region_features)

        patch_mag = list(feature_i.keys())[2]
        patch_features = feature_i[patch_mag]["features"]
        patch_coords = feature_i[patch_mag]["coords"]
        len_patch_node = len(patch_features)
        
        # 1. region_level
        for i in range(len_region_node):

            # 1. thumbnail <=> region
            region_index = i+len_thumbnail
            x,y = region_coords[i]
            start.append(0)
            end.append(region_index)
            start.append(region_index)
            end.append(0)

            # 2. region <=> region (8 neighbours)
            for xy_ in get_8nb_coords(x, y, region_patch_size):
                x_, y_ = xy_
                if np.any(np.all(region_coords == [x_, y_], axis=1)):
                    to_region_index = np.where((region_coords==(x_, y_)).all(axis=1))[0][0]+len_thumbnail
                    start.append(region_index)
                    end.append(to_region_index)
                    start.append(to_region_index)
                    end.append(region_index)

        # 2. patch_level
        for i in range(len_patch_node):
            patch_index = i+len_thumbnail+len_region_node
            x,y = patch_coords[i]

            # 1. patch <=> region (the region that contains the patch)
            region_coord = (patch_coords[i]//region_patch_size)*region_patch_size
            x_region, y_region = region_coord
            region_index = np.where((region_coords==(x_region, y_region)).all(axis=1))[0][0]+len_thumbnail
            start.append(region_index)
            end.append(patch_index)
            start.append(patch_index)
            end.append(region_index)

            # 2. patch <=> patch (8 neighbours)
            for xy_ in get_8nb_coords(x, y, patch_patch_size):
                x_,y_ = xy_
                if np.any(np.all(patch_coords == [x_, y_], axis=1)):
                    to_patch_index = np.where((patch_coords==(x_,y_)).all(axis=1))[0][0]+len_thumbnail+len_region_node
                    start.append(patch_index)
                    end.append(to_patch_index)
                    start.append(to_patch_index)
                    end.append(patch_index)
        return [start, end]

    bm = pd.read_csv(bm_path)
    os.makedirs(f'{save_path}/pt_files', exist_ok=True)
    
    for slide_path,base_mag in zip(bm['slide_path'],bm['base_mag']):
        wsi_name = slide_path.split("/")[-1].replace(".svs","")
        h5_name = wsi_name+".h5"
        
        # Features and coordinates per level, in the order of target_mags
        feature_i = {}
        all_feature = []

        for cur_mag in target_mags:
            if cur_mag!=float("inf"):
                cur_patch_size = int(base_patch_size*(base_mag/cur_mag))
                cur_feature_path = f"{feature_path}/mag{cur_mag}x_patch512_{cur_patch_size}/h5_files/{h5_name}"
            else:
                cur_feature_path = f"{feature_path}/thumbnail/h5_files/{h5_name}"
            with h5py.File(cur_feature_path,'r') as h5_content:
                h5_features = h5_content["/features"][:]
                h5_coords = h5_content["/coords"][:]
            # The thumbnail is stored under key 0
            if cur_mag == float("inf"):
                cur_mag = 0
            feature_i[cur_mag] = {
                "features": h5_features,
                "coords": h5_coords
            }
            if len(h5_features.shape)<2:
                h5_features = np.expand_dims(h5_features, 0)
            all_feature.append(h5_features)

        # generate the Data
        # 0. preparation
        region_mag = list(feature_i.keys())[1]
        patch_mag = list(feature_i.keys())[2]

        region_coords = feature_i[region_mag]["coords"]
        patch_coords = feature_i[patch_mag]["coords"]

        len_region_node, len_patch_node = len(region_coords), len(patch_coords)

        # Patch sizes expressed at the slide's default magnification
        region_patch_size = int((base_mag/region_mag)*base_patch_size)
        patch_patch_size = int((base_mag/patch_mag)*base_patch_size)

        # 1. x
        all_feature = np.concatenate(all_feature, axis=0)

        # 2. edge_index_tree_8nb
        edge_index = get_edge_index(feature_i, region_patch_size, patch_patch_size)
        print("edge done")

        # 3. batch
        batch = [0]*len(all_feature)

        # 4. data_id
        data_id = wsi_name.split(".")[0]

        # 5. node_type (0: thumbnail, 1: region, 2: patch)
        node_type = [0]+[1]*len_region_node+[2]*len_patch_node

        # 6. node_tree (index of each node's parent, -1 for the root)
        node_tree_wo_patch = [-1]+[0]*len_region_node
        node_tree_patch = []
        for i in range(len_patch_node):
            x, y = patch_coords[i]
            x_region = int((x//region_patch_size)*region_patch_size)
            y_region = int((y//region_patch_size)*region_patch_size)
            region_index = np.where((region_coords==(x_region,y_region)).all(axis=1))[0][0]+1
            node_tree_patch.append(region_index)
        node_tree = node_tree_wo_patch+node_tree_patch

        # 7. x_y_index (grid position of each node within its level)
        x_y_index = np.concatenate([
            np.array([[0,0]]),
            region_coords/region_patch_size,
            patch_coords/patch_patch_size,
        ])
    
        # generate the Data 
        node_attr=torch.tensor(all_feature, dtype=torch.float)
        edge_index_tree_8nb = torch.tensor(edge_index,dtype=torch.long)
        batch = torch.tensor(batch)
        node_type = torch.tensor(node_type)
        node_tree = torch.tensor(node_tree)
        x_y_index = torch.tensor(x_y_index, dtype=torch.float)

        data = Data(
            x = node_attr,
            edge_index_tree_8nb = edge_index_tree_8nb,
            data_id = data_id,
            batch = batch,
            node_type = node_type,
            node_tree = node_tree,
            x_y_index = x_y_index
        )

        torch.save(data, f'{save_path}/pt_files/{data_id}.pt')
```

### 1.5 Dataset Split

The split script `create_splits_seq_re.py` is rewritten from CLAM's `create_splits_seq.py` with two changes: it takes the label file as an input, and the task can have any number of classes. The `label.csv` file needs to include three columns: `case_id`, `slide_id` and `label`.
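
Before creating the splits, it can help to check that the label file has the expected columns. The following is a minimal sketch using only the standard library. `summarize_labels`, the example rows and the label names are hypothetical, and the sketch assumes every slide of a case carries the same label.

```python
import csv
import io
from collections import Counter

REQUIRED_COLUMNS = {"case_id", "slide_id", "label"}


def summarize_labels(csv_file):
    """Check the required columns and count cases per label."""
    reader = csv.DictReader(csv_file)
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise ValueError(f"label file is missing columns: {sorted(missing)}")
    rows = list(reader)
    # Assumption: all slides of one case share a label, so keep one per case.
    case_labels = {row["case_id"]: row["label"] for row in rows}
    return len(rows), Counter(case_labels.values())


# Hypothetical label file with two cases and three slides.
example = io.StringIO(
    "case_id,slide_id,label\n"
    "case_1,slide_1a,tumor\n"
    "case_1,slide_1b,tumor\n"
    "case_2,slide_2a,normal\n"
)
n_slides, cases_per_label = summarize_labels(example)
print(n_slides, dict(cases_per_label))
```

To check a real file, pass an open file handle, for example `open("/path/to/dataset_csv/WSI_bm20/label.csv", newline="")`, instead of the `io.StringIO` object.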

##### (1) Parameter Description:
- **`--label_path`**: Path to the classification label file.
- **`--seed`**: Random seed.
- **`--task`**: Task name.
- **`--k`**: Number of splits.
- **`--label_frac`**: Fraction of labels.
- **`--val_frac`**: Fraction of labels for validation.
- **`--test_frac`**: Fraction of labels for test.

##### (2) Usage Examples:
The command line for dataset split can be structured as follows:

```bash
python create_splits_seq_re.py --label_path /path/to/dataset_csv/WSI_bm20/label.csv --seed 1 --task WSI_bm20 --label_frac 1.0 --k 10 --val_frac 0.2 --test_frac 0.2
```

## 2. GPU Training:

The training scripts `main_re.py` and `utils/core_utils_re.py` are rewritten from `main.py` and `utils/core_utils.py`, with minor additions and changes that enable training of the HIGT model.

##### (1) Parameter Description:
- **`--max_epochs`**: Maximum number of epochs to train.
- **`--label_path`**: Path to the classification label file.
- **`--drop_out`**: Enable dropout (p=0.25).
- **`--early_stopping`**: Enable early stopping.
- **`--lr`**: Learning rate.
- **`--k`**: Number of folds.
- **`--label_frac`**: Fraction of training labels.
- **`--exp_code`**: Experiment code for saving results.
- **`--weighted_sample`**: Enable weighted sampling.
- **`--bag_loss`**: Slide-level classification loss function.
- **`--task`**: Task name.
- **`--split_dir`**: Manually specify the set of splits to use.
- **`--model_type`**: Model type used for training.
- **`--log_data`**: Log data using tensorboard.
- **`--data_root_dir`**: Data directory.
- **`--results_dir`**: Results directory.

##### (2) Usage Examples:

The command line for GPU training can be structured as follows:

```bash
CUDA_VISIBLE_DEVICES=0 python main_re.py --max_epochs 100 --label_path /path/to/dataset_csv/WSI_bm20/label.csv --drop_out --early_stopping --lr 2e-4 --k 10 --label_frac 1.0 --exp_code test_binary_100 --weighted_sample --bag_loss ce --task WSI_bm20 --split_dir /path/to/exp/splits/WSI_bm20 --model_type HIGT --log_data --data_root_dir /path/to/exp/extracted_feature/WSI_bm20/graph --results_dir /path/to/exp/result/WSI_bm20
```

## 3. Testing and Evaluation:


The evaluation script `eval_re.py` is rewritten from `eval.py`, with minor additions and changes that enable testing and evaluation of the HIGT model.

##### (1) Parameter Description:
- **`--drop_out`**: Enable dropout (p=0.25).
- **`--k`**: Number of folds.
- **`--models_exp_code`**: Experiment code of the trained models to load (the `--exp_code` used for training).
- **`--save_exp_code`**: Experiment code for saving results.
- **`--task`**: Task name.
- **`--model_type`**: Model type used for training.
- **`--results_dir`**: Results directory.
- **`--split_dir`**: Manually specify the set of splits to use.
- **`--data_root_dir`**: Data directory.

##### (2) Usage Examples:

The command line for testing and evaluation can be structured as follows:

```bash
CUDA_VISIBLE_DEVICES=0 python eval_re.py --drop_out --k 10 --models_exp_code test_binary_100 --save_exp_code WSI_bm20 --task WSI_bm20 --model_type HIGT --results_dir /path/to/exp/result/WSI_bm20 --split_dir /path/to/exp/splits/WSI_bm20 --data_root_dir /path/to/exp/extracted_feature/WSI_bm20/graph
```