# Predict CDH1 status for breast invasive carcinoma patients

In this notebook we study the prediction of CDH1 status for breast invasive carcinoma patients (BRCA) from H&E slides. We will use vision transformers with 224x224 pixels patches (encoded with [UNI feature extractor](https://huggingface.co/MahmoodLab/UNI)) as tokens.   


**Note:** In the following, we will use and update the code from [STAMP](https://github.com/KatherLab/STAMP) to train and test our models.

In [1]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path
import pandas as pd
import torch

## 0. Prepare files (run only one time)

The following code creates several data files from the raw data so that functions from **STAMP** can be easily adapted.

**Note:** This should only be run one time for the initialisation of the project.

In [None]:
df_clinical = pd.read_csv("data/sample_matrix.txt", sep="\t")
df_clinical = df_clinical.rename(columns={"studyID:sampleId": "PATIENT"})
df_clinical["PATIENT"] = df_clinical.PATIENT.map(lambda x: x.split(':')[1])
df_clinical.head()

Unnamed: 0,PATIENT,Altered,CDH1
0,TCGA-AN-A0FJ-01,0,0
1,TCGA-AN-A0FF-01,0,0
2,TCGA-AN-A0FD-01,0,0
3,TCGA-AN-A0AT-01,0,0
4,TCGA-AN-A0AS-01,0,0


In [None]:
df_clinical.to_csv("data/clinical.csv")

In [None]:
feature_dir = Path("drive/MyDrive/TCGA_BRCA/features/")
h5s = set(feature_dir.glob("*.h5"))

h5_df = pd.DataFrame(h5s, columns=["slide_path"])
h5_df["FILENAME"] = h5_df.slide_path.map(lambda p: p.stem)
h5_df["PATIENT"] = h5_df.FILENAME.map(lambda x: ('-'.join(x.split(".")[0].split('-')[:4]))[:-1])

h5_df = h5_df.drop(columns="slide_path")

h5_df.head()

Unnamed: 0,FILENAME,PATIENT
0,TCGA-A7-A13G-01Z-00-DX2.72EF429E-75A7-4D1B-AFF...,TCGA-A7-A13G-01
1,TCGA-AC-A3QQ-01Z-00-DX1.86463263-AB12-49FB-896...,TCGA-AC-A3QQ-01
2,TCGA-EW-A1PE-01Z-00-DX1.8EF56824-0B37-4AD1-AF3...,TCGA-EW-A1PE-01
3,TCGA-B6-A0IQ-01Z-00-DX1.662EA039-825E-41FF-91D...,TCGA-B6-A0IQ-01
4,TCGA-E9-A247-01Z-00-DX1.3B2DF1CB-054A-44C4-9AD...,TCGA-E9-A247-01


In [None]:
h5_df.to_csv("data/slide_table.csv")

## 1. Initial train-test split for both tasks

We first split our data in training, validation and test sets (with a custom function in cdh1pred package) so that we will use the same sets for all our experiments.

In [2]:
from cdh1pred.utils import init_train_val_test

In [3]:
path_clin = Path("data/clinical.csv")
path_slide = Path("data/slide_table.csv")
path_feature = Path("data/features/")

train_df, test_df, val_indexes, enc = init_train_val_test(path_clin,
                                                          path_slide,
                                                          path_feature,
                                                          target_label = "CDH1")

## 2. Warm-up task

In this first experiment we use a two-layer transformer (using **STAMP**) to predict CDH1 status from WSIs. We use all the default settings.

In [4]:
from stamp.modeling.marugoto.transformer.base import train, deploy

In [5]:
learn = train(
    bags=train_df.slide_path.values,
    targets=(enc, train_df["CDH1"].values),
    valid_idxs=train_df.PATIENT.isin(val_indexes).values,
    cores=max(1, os.cpu_count() // 4),
    n_epoch=1
    )

Model: TransMIL(
  (fc): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): GELU(approximate='none')
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-1): 2 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (mhsa): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
        )
        (1): FeedForward(
          (mlp): Sequential(
            (0): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=512, out_features=512, bias=True)
            (2): GELU(approximate='none')
            (3): Dropout(p=0.0, inplace=False)
            (4): Linear(in_features=512, out_features=512, bias=True)
            (5): Dropout(p=0.0, inplace=False)
          )
        )
      )
    )
    (norm): LayerNorm((512,), eps=1e-0

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.164452,0.143096,0.771707,09:57


Better model found at epoch 0 with valid_loss value: 0.14309552311897278.


  state = torch.load(file, map_location=device, **torch_load_kwargs)


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == "cuda":
  torch.set_float32_matmul_precision("high")

patient_preds_df = deploy(test_df=test_df,
                          learn=learn,
                          target_label="CDH1",
                          cat_labels=False,
                          cont_labels=False,
                          device=device)

patient_preds_df.head()

Unnamed: 0,PATIENT,CDH1,pred,CDH1_0,CDH1_1,loss
56,TCGA-AN-A0AL-01,0,0,0.756654,0.243346,0.469073
20,TCGA-A2-A3XV-01,0,0,0.733053,0.266947,0.487009
127,TCGA-D8-A1JT-01,0,0,0.729519,0.270482,0.48974
142,TCGA-E2-A14Y-01,0,0,0.724617,0.275383,0.493547
8,TCGA-A2-A0CS-01,0,0,0.723395,0.276605,0.4945


## 3. Main task

We now predict the CDH1 status, considering the relative distance between patches within each WSI. To do so, We use an updated version of the TransMIL model from **STAMP** (implemented in the package **cdh1pred**) with a self-attention mechanism that takes into account the relative distance between tokens (i.e encoded patches).

Our strategy to incorporate the relative distance between tokens within the attention computation is based on the solution presented in [Shaw et al. Self-Attention with Relative Position Representations, 2018](https://arxiv.org/pdf/1803.02155). It modifies a [PyTorch implementation of such strategy](https://github.com/AliHaiderAhmad001/Self-Attention-with-Relative-Position-Representations) and adapts it to 2D relative distances.   

#### 1. Compute pairwise distances

Using the coordinates of each patch we compute the euclidean distance $d_{ij} \in \mathbb{R}$ for each pair of patches $i$ and $j$. This produces a distance matrix $D \in \mathbb{R}^{n \times n}$ ($n$ number of tokens).

#### 2. Rank neighbors and clip

For a given patch $i$ we rank every other patch $j$ from the nearest-neighbor of $i$ to the furthest. It produces a matrix $N \in \mathbb{R}^{n \times n}$ such that $n_{ij} \in \{0, 1, ..., n-1\}$ corresponds to the neighbor rank of patch $j$ with respect to patch $i$ ($n_{ij} = 0 \text{ if } i=j$).

We then clip this rank to a pre-specified number $k$ (hyperparameter of the model), using a similar strategy as the one described in [Shaw et al (2018)](https://arxiv.org/pdf/1803.02155).

**Note 1:** In our implementation we arbitrarily set $k=10$ to focus on the 10th-nearest neighbor of each patch.

**Note 2:** To deal with the additional class token $c$ we set $n_{ic} = n_{jc} = 0$ for all patches $i$ and $j$, considering the same "null distance" between each patch and the class token.  

#### 3. Associate each edge to learnable labels

For each edge type (i.e., neighbor rank $\in \{0, ..., k\})$ we associate a vector $w_l \in \mathbb{R}^{d} \, (l \in {0, ..., k})$ (with learnable coefficients) and build a tensor $\mathbf{A} \in \mathbb{R}^{n \times n \times d}$ such that $a_{ij} \in \mathbb{R}^{d}$ represents the edge (i.e., neighbor relationship) between pathches $i$ and $j$.   

This tensor is then used as a bias in the self-attention mechanism, using the formulas **(3)** and **(4)** described in [Shaw et al (2018)](https://arxiv.org/pdf/1803.02155).

#### **Pros:**

* Such strategy can be easily generalizable to other distances, or any graph structure over the patches. For instance we can think of additional features (e.g., gene expression from spatial transcriptomics) for each patch, non-euclidean distances, incorporate patches annotations from pathologists...

* By using learnable parameters to represent the different neighbor ranks, this strategy offers flexibility to the model to learn how to best incorporate the relative distances in the self-attention mechanism.

#### **Cons:**

* Although the clipping step allows to generalize to higher number of tiles and to focus on the relative position only for close neighbors it adds another hyperparameter to the model that might need cumbersome tuning to best leverage the relative distances (that also might depend on the predictive task).

* Using the neighbor rank rather than the euclidean distance may lead to cases where two pairs of patches with the same rank (e.g., $n_{ij} = n_{kl} = 1$) are considered the same while being associated with very different euclidean distances (e.g., $d_{ij} >> d_{kl}$). This may occur, in particular, when random sampling is used during training and a small subset of patches is considered.

* This strategy adds $(k+1) \times d$ parameters to the model. While this adds flexibility it might complexify the training, especially for data sets with small/moderate sample sizes.

In [4]:
from cdh1pred.base_dist import train_dist, deploy_dist

In [5]:
learn = train_dist(
    bags=train_df.slide_path.values,
    targets=(enc, train_df["CDH1"].values),
    valid_idxs=train_df.PATIENT.isin(val_indexes).values,
    cores=max(1, os.cpu_count() // 4),
    n_epoch=1
    )

Model: TransMILDist(
  (fc): Sequential(
    (0): Linear(in_features=1024, out_features=4, bias=True)
    (1): GELU(approximate='none')
  )
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): TransformerDist(
    (layers): ModuleList(
      (0-1): 2 x ModuleList(
        (0): AttentionDist(
          (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
          (mhsa): RelationAwareMultiHeadAttention(
            (relative_position_k): RelativePosition()
            (relative_position_v): RelativePosition()
            (attention_heads): ModuleList(
              (0-1): 2 x RelationAwareAttentionHead(
                (query_weights): Linear(in_features=4, out_features=2, bias=True)
                (key_weights): Linear(in_features=4, out_features=2, bias=True)
                (value_weights): Linear(in_features=4, out_features=2, bias=True)
              )
            )
            (fc): Linear(in_features=4, out_features=4, bias=True)
          )
        )
        (1):

epoch,train_loss,valid_loss,roc_auc_score,time
0,0.160414,0.159942,0.443171,1:53:09


Better model found at epoch 0 with valid_loss value: 0.15994204580783844.


  state = torch.load(file, map_location=device, **torch_load_kwargs)


In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == "cuda":
  torch.set_float32_matmul_precision("high")

patient_preds_dist_df = deploy_dist(test_df=test_df,
                                  learn=learn,
                                  target_label="CDH1",
                                  device=device)

patient_preds_dist_df.head()

Unnamed: 0,PATIENT,CDH1,pred,CDH1_0,CDH1_1,loss
68,TCGA-AO-A12E-01,1,1,0.411545,0.588455,0.608599
1,TCGA-4H-AAAK-01,1,1,0.414843,0.585157,0.611611
139,TCGA-D8-A27T-01,1,1,0.415857,0.584143,0.61254
45,TCGA-A8-A0AB-01,1,1,0.418232,0.581768,0.614719
30,TCGA-A7-A4SC-01,1,1,0.418599,0.581401,0.615056
