# Introduction: PseMix Walkthrough

PseMix (pseudo-bag mixup) contains two key steps:
- **generating pseudo-bags** (`Step 1`; [its notebook](https://github.com/liupei101/PseMix/blob/main/notebooks/psemix_walkthrough_step1_pseudo_bag_generation.ipynb)),
- **mixing pseudo-bags** (`Step 2`; [its notebook](https://github.com/liupei101/PseMix/blob/main/notebooks/psemix_walkthrough_step2_pseudo_bag_mixup.ipynb)).

This notebook aims to help you get started with ***Step 2: Pseudo-bag Mixup***. 

First, load the required packages:

In [1]:
import os
import os.path as osp
import numpy as np
from sklearn.manifold import TSNE
import torch
import torch.nn.functional as F

import matplotlib as mpl
from matplotlib import pyplot as plt
import seaborn as sns

os.chdir("..")
os.getcwd()

'/home/liup/repo/PseMix'

# Step 2: Pseudo-bag Mixup

This step shows how to obtain **mixed bags** (also called **mixup-augmented bags**).

Two WSI bags are used as examples for illustration:
- Bag A: `./wsi_feats/feat_wsi_A_TCGA_3C_AALI.pt`.
- Bag B: `./wsi_feats/feat_wsi_B_TCGA_BH_A2L8.pt`.

We first obtain the pseudo-bags of bag A and bag B, as explained in [the tutorial of the first step of PseMix](https://github.com/liupei101/PseMix/blob/main/notebooks/psemix_walkthrough_step1_pseudo_bag_generation.ipynb).

In [2]:
from utils.io import read_patch_data
from utils.core import PseudoBag
from utils.func import seed_everything
seed_everything(42)

NUM_CLUSTER = 8 # the number of clusters
NUM_PSEB = 30 # the number of pseudo-bags
NUM_FT = 8 # the number of fine-tuning iterations

PB = PseudoBag(NUM_PSEB, NUM_CLUSTER, proto_method='mean', pheno_cut_method='quantile', iter_fine_tuning=NUM_FT)

# Bag A:
# load WSI features
bag_feats_A = read_patch_data("./wsi_feats/feat_wsi_A_TCGA_3C_AALI.pt", dtype='torch').to(torch.float)
label_A = torch.LongTensor([1])
# label_pseudo_bag: the indicator of pseudo-bags
label_pseudo_bag_A = PB.divide(bag_feats_A, ret_pseudo_bag=False)
print(f"[info] Bag A: it has {bag_feats_A.shape[0]} instances.")
print(f"[info] Bag A: its first pseudo-bag has {(label_pseudo_bag_A == 0).sum()} instances.")
# note: index NUM_PSEB - 1 selects the last pseudo-bag of bag A
print(f"[info] Bag A: its second pseudo-bag has {(label_pseudo_bag_A == NUM_PSEB - 1).sum()} instances.")

# Bag B:
# load WSI features
bag_feats_B = read_patch_data("./wsi_feats/feat_wsi_B_TCGA_BH_A2L8.pt", dtype='torch').to(torch.float)
label_B = torch.LongTensor([0])
# label_pseudo_bag: the indicator of pseudo-bags
label_pseudo_bag_B = PB.divide(bag_feats_B, ret_pseudo_bag=False)
print("\n")
print(f"[info] Bag B: it has {bag_feats_B.shape[0]} instances.")
print(f"[info] Bag B: its first pseudo-bag has {(label_pseudo_bag_B == 0).sum()} instances.")
# note: index NUM_PSEB - 1 selects the last pseudo-bag of bag B
print(f"[info] Bag B: its second pseudo-bag has {(label_pseudo_bag_B == NUM_PSEB - 1).sum()} instances.")

[setup] seed: 42
ProtoDiv-based pseudo-bag dividing: n = 30, l = 8.
[info] Bag A: it has 5584 instances.
[info] Bag A: its first pseudo-bag has 185 instances.
[info] Bag A: its second pseudo-bag has 187 instances.


[info] Bag B: it has 6544 instances.
[info] Bag B: its first pseudo-bag has 219 instances.
[info] Bag B: its second pseudo-bag has 219 instances.


### (1) Generating a random Mixup coefficient from a Beta distribution

In [3]:
PARAM_ALPHA = 1.0 # the parameter of Beta distribution
NUM_ITER = 1
for i in range(NUM_ITER):
    lam = np.random.beta(PARAM_ALPHA, PARAM_ALPHA)
print(f"[info] current Mixup coefficient is {lam}")

[info] current Mixup coefficient is 0.5000386523859661


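Here we map the Mixup coefficient to an integer: the number of pseudo-bags to sample from bag A. The remaining `NUM_PSEB` minus this integer pseudo-bags are sampled from bag B.

Further discussion:
- With `PARAM_ALPHA = 1.0`, the Beta distribution is uniform on `[0, 1]`, so this integer is uniformly distributed over `(0, 1, ..., NUM_PSEB)`.
- The probability of an actual Pseudo-bag Mixup is therefore `(n - 2) / n`, where `n = NUM_PSEB + 1` is the number of possible integer values, as there is no Mixup when the integer is `0` or `NUM_PSEB`.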
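
The two points above can be checked with a short sketch that relies only on the standard library. The helper names `lam_to_int` and `mixup_probability` are hypothetical and introduced here for illustration; they are not part of the PseMix code base. With `NUM_PSEB = 30`, counting the outcomes gives `29/31`, that is, `(n - 2) / n` with `n = NUM_PSEB + 1 = 31` possible values.

```python
import random
from fractions import Fraction

NUM_PSEB = 30  # number of pseudo-bags, as in this notebook


def lam_to_int(lam: float, n: int = NUM_PSEB) -> int:
    """Map a Mixup coefficient in [0, 1] to an integer in {0, ..., n}."""
    lam_temp = lam if lam != 1.0 else lam - 1e-5  # keeps the result <= n
    return int(lam_temp * (n + 1))


def mixup_probability(n: int = NUM_PSEB) -> Fraction:
    """Exact chance of a real mix when the integer is uniform on {0, ..., n}."""
    mixing_values = [k for k in range(n + 1) if k not in (0, n)]
    return Fraction(len(mixing_values), n + 1)


# Monte Carlo check with Beta(1, 1), i.e. the uniform distribution on [0, 1]
rng = random.Random(0)
draws = [lam_to_int(rng.betavariate(1.0, 1.0)) for _ in range(100_000)]
empirical = sum(0 < k < NUM_PSEB for k in draws) / len(draws)

print(mixup_probability())  # 29/31
print(empirical)            # close to 29/31
```

For other values of `PARAM_ALPHA` the integer is no longer uniform, so this count-based probability would not apply as is.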

In [4]:
lam_temp = lam if lam != 1.0 else lam - 1e-5
lam_int  = int(lam_temp * (NUM_PSEB + 1))
print(f"[info] current Mixup coefficient (integer) is {lam_int}")

[info] current Mixup coefficient (integer) is 15


### (2) Mixing pseudo-bags according to the generated Mixup coefficient

First, we fetch the pseudo-bags from bag A and bag B:

In [5]:
def fetch_pseudo_bags(X, ind_X, n:int, n_parts:int):
    """
    X: bag features, usually with a shape of [N, d]
    ind_X: pseudo-bag indicator, usually with a shape of [N, ]
    n: pseudo-bag number, int
    n_parts: the pseudo-bag number to fetch, int
    """
    if len(X.shape) > 2:
        X = X.squeeze(0)
    assert n_parts <= n, 'the pseudo-bag number to fetch is invalid.'
    if n_parts == 0:
        return None

    ind_fetched = torch.randperm(n)[:n_parts]
    X_fetched = torch.cat([X[ind_X == ind] for ind in ind_fetched], dim=0)

    return X_fetched

In [6]:
bag_A = fetch_pseudo_bags(bag_feats_A, label_pseudo_bag_A, NUM_PSEB, lam_int)
print(f"[info] After fetching {lam_int} pseudo-bags from bag A, there are {bag_A.shape[0]} instances left.")
bag_B = fetch_pseudo_bags(bag_feats_B, label_pseudo_bag_B, NUM_PSEB, NUM_PSEB - lam_int)
print(f"[info] After fetching {NUM_PSEB - lam_int} pseudo-bags from bag B, there are {bag_B.shape[0]} instances left.")

[info] After fetching 15 pseudo-bags from bag A, there are 2796 instances left.
[info] After fetching 15 pseudo-bags from bag B, there are 3273 instances left.


Next, we directly mix the two masked bags by concatenating.

In our PseMix implementation, we introduce a special `Random Mixup` mechanism, which differs from vanilla Mixup: the two masked bags are mixed only with probability `PROB_MIXUP`; otherwise, the masked bag A is kept on its own. The motivation is as follows:
- It could enhance the diversity of training samples.
- It could make the model efficiently learn from vicinity samples (mixed bags), as stated in our PseMix paper.
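
One edge case is worth keeping in mind: `fetch_pseudo_bags` returns `None` when asked for zero pseudo-bags, which happens when the integer coefficient is `0` or `NUM_PSEB`. Below is a minimal sketch of a mixing helper that tolerates this. The name `mix_bags` is hypothetical and not a function of the PseMix code base, the fallback behaviour is an assumption made for illustration, and the standard-library `random` module stands in for `np.random`.

```python
import random
import torch


def mix_bags(bag_a, bag_b, lam_int, n_pseb, prob_mixup=0.98):
    """Concatenate two masked bags; either may be None (nothing fetched).

    Returns (new_bag, mixup_ratio), where mixup_ratio weights bag A's label.
    """
    if bag_a is None:  # lam_int == 0: only bag B contributes
        return bag_b, 0.0
    if bag_b is None:  # lam_int == n_pseb: only bag A contributes
        return bag_a, 1.0
    if random.random() <= prob_mixup:  # mix with probability prob_mixup
        return torch.cat([bag_a, bag_b], dim=0), lam_int / n_pseb
    return bag_a, 1.0  # otherwise keep the masked bag A alone


# toy masked bags: 4 and 6 instances with 3-dimensional features
toy_a, toy_b = torch.ones(4, 3), torch.zeros(6, 3)
mixed, ratio = mix_bags(toy_a, toy_b, lam_int=15, n_pseb=30, prob_mixup=1.0)
print(tuple(mixed.shape), ratio)  # (10, 3) 0.5
```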

In [7]:
PROB_MIXUP = 0.98 # the probability to mix the pseudo-bags from two bags 

if np.random.rand() <= PROB_MIXUP: # our Random Mixup mechanism
    new_bag = torch.cat([bag_A, bag_B], dim=0) # instance-axis concat
    mixup_ratio = lam_int / NUM_PSEB
else:
    new_bag = bag_A
    mixup_ratio = 1.0

print(f"[info] New bag: it has {new_bag.shape[0]} instances.")
print(f"[info] New bag: its Mixup ratio is {mixup_ratio}.")

[info] New bag: it has 6069 instances.
[info] New bag: its Mixup ratio is 0.5.


### (3) Training MIL models using new mixed bags (pseudo-bag-augmented bags)

At this point, the new mixed bag and its label can be expressed as follows:
```python
Mixup_sample = new_bag # obtained above
Mixup_label = mixup_ratio * label_A + (1 - mixup_ratio) * label_B
```

This new sample (`Mixup_sample`) and its label (`Mixup_label`) can be used to supervise model training.
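
As a small worked example with the values from this notebook (`label_A = 1`, `label_B = 0`, `mixup_ratio = 0.5`), the mixed label is a soft label between the two classes. The sketch below uses plain Python numbers rather than tensors, and `mixup_label` is a hypothetical helper name introduced for illustration.

```python
def mixup_label(ratio: float, label_a: float, label_b: float) -> float:
    """Convex combination of two bag labels, weighted by the Mixup ratio."""
    return ratio * label_a + (1 - ratio) * label_b


# values from this walkthrough: 15 of 30 pseudo-bags come from bag A
print(mixup_label(15 / 30, 1, 0))  # 0.5
# no mixing (ratio 1.0) recovers the label of bag A
print(mixup_label(1.0, 1, 0))      # 1.0
```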

In practice, `Mixup_label` is usually not used directly for training; instead, a weighted loss is often used, as follows:

Pseudo-code:
```python
# forward pass
pred = MIL_network(new_bag)

# predictive loss weighted by the `mixup_ratio`
clf_loss = mixup_ratio * BCE_loss(pred, label_A) + (1 - mixup_ratio) * BCE_loss(pred, label_B)

# clear stale gradients, backpropagate, and update the network
optimizer.zero_grad()
clf_loss.backward()
optimizer.step()
```
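
For binary cross-entropy, the weighted loss above coincides with the loss computed on the soft label `Mixup_label`, because this loss is linear in the target. The sketch below checks this numerically with the standard library only. The function `bce` is a hand-written helper for illustration, and the prediction `0.7` is an arbitrary assumed value, not a model output.

```python
import math


def bce(pred: float, target: float) -> float:
    """Binary cross-entropy for one predicted probability and a target in [0, 1]."""
    return -(target * math.log(pred) + (1 - target) * math.log(1 - pred))


pred = 0.7  # assumed predicted probability of class 1
label_a, label_b = 1.0, 0.0
mixup_ratio = 15 / 30

# loss weighted by the Mixup ratio, as in the pseudo-code
weighted = mixup_ratio * bce(pred, label_a) + (1 - mixup_ratio) * bce(pred, label_b)
# loss evaluated once on the soft (mixed) label
soft = bce(pred, mixup_ratio * label_a + (1 - mixup_ratio) * label_b)

print(math.isclose(weighted, soft))  # True
```

The equivalence relies on the loss being linear in the target; for losses without this property, the two forms can differ.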