Breast cancer stage prediction from pathological whole slide images with hierarchical image pyramid transformers.
Project developed under the "High Risk Breast Cancer Prediction Contest Phase 2" 
by Nightingale, Association for Health Learning & Inference (AHLI)
and Providence St. Joseph Health

Parts of the code were taken over and adapted from the HIPT library.

https://github.com/mahmoodlab/HIPT/blob/master/HIPT_4K/hipt_4k.py

https://github.com/mahmoodlab/HIPT/blob/master/HIPT_4K/hipt_model_utils.py

Copyright (C) 2023 Zsolt Bedohazi, Andras Biricz, Istvan Csabai

In [None]:
import numpy as np
#from geojson import GeoJSON
import json
import os
import glob
#import shapely
#from rtree import index
#from shapely.ops import cascaded_union, unary_union
from collections import Counter
import matplotlib.pyplot as plt
import h5py
from tqdm import tqdm
from PIL import Image
import pandas as pd
import torch

import sys
sys.path.append('../HIPT_semicol/HIPT_4K/')
import vision_transformer4k as vits4k

### Locate data

In [None]:
# Directory that is searched for the input `.npy` embedding files below.
# The commented-out line is an alternative input directory (its name differs
# only in `level1` vs `level0`); swap the two lines to switch inputs.
source = '/home/ngsci/resnet50_embeddings_4096region_256times1024_level0_holdout/'
#source = '/home/ngsci/resnet50_embeddings_4096region_256times1024_level1_holdout/'

In [None]:
# Collect every `.npy` file directly inside `source`. The list is sorted so
# that positional indices into `files` are reproducible between runs.
slide_fp = os.path.join(source, '*.npy')
files = np.array(sorted(glob.glob(slide_fp)))
# Notebook display: number of files found and the first three paths.
files.shape, files[:3]

In [None]:
def load_h5_file(filename):
    """
    Read the 'coords' and 'features_4k' datasets of an HDF5 file into memory.

    Args:
    - filename (str): Path to the HDF5 file (opened read-only).

    Returns:
    - (coords, imgs): Full contents of the 'coords' and 'features_4k'
      datasets, in that order.
    """
    with h5py.File(filename, "r") as h5:
        # `[()]` reads the whole dataset, so the values stay usable after the
        # file is closed.
        coordinates = h5['coords'][()]
        features = h5['features_4k'][()]
    return coordinates, features

#### HIPT model

In [None]:
def get_vit4k(pretrained_weights, arch='vit4k_xs', device=torch.device('cuda:0'), input_embed_dim=1024):
    """
    Builds a frozen ViT-4K model in eval mode and, if available, loads
    pretrained weights into it.

    Args:
    - pretrained_weights (str): Path to ViT-4K Model Checkpoint. If the path is
      not an existing file, no weights are loaded and the model keeps the
      weights it was constructed with.
    - arch (str): Name of the model constructor looked up in
      `vision_transformer4k`.
    - device (torch.device | str): Device to place the model on. A CUDA device
      is replaced by the CPU when CUDA is not available.
    - input_embed_dim (int): Currently unused; kept so existing callers that
      pass it keep working.

    Returns:
    - model4k (torch.nn.Module): Initialized model with all parameters frozen.
    """

    checkpoint_key = 'teacher'

    # Honour the caller's device choice, falling back to the CPU only when a
    # CUDA device was requested on a machine without CUDA.
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    model4k = vits4k.__dict__[arch](num_classes=0)
    for p in model4k.parameters():
        p.requires_grad = False
    model4k.eval()
    model4k.to(device)

    if os.path.isfile(pretrained_weights):
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(pretrained_weights, map_location="cpu")
        if checkpoint_key is not None and checkpoint_key in state_dict:
            print(f"Take key {checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        # strict=False tolerates missing/unexpected keys; they are reported in `msg`.
        msg = model4k.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
    else:
        # Make the silent fallback visible: without a checkpoint the produced
        # embeddings do not come from the intended pretrained model.
        print('WARNING: no checkpoint file at {}; no pretrained weights were loaded'.format(pretrained_weights))

    return model4k

In [None]:
class HIPT_4K(torch.nn.Module):
    """
    ViT-4K stage of HIPT applied to pre-extracted patch embeddings.

    This wrapper contains no ViT-256 stage and does no image handling: `forward`
    takes, per region, 256 patch embeddings of width 1024 (presumably ResNet50
    features, judging by the variable names and input paths), arranges them as a
    [16 x 16] grid and passes the grid to ViT-4K.
    """
    def __init__(self,
        #model4k_path: str = '../Checkpoints/vit4k_xs_dino.pth',

        # stage 2 model trained locally without finetuning on platform
        #model4k_path: str = 'nightingale_checkpoint_ViT4096_on_resnet50_embeddings.pth',

        # stage 2 model trained locally finetuned on platform
        model4k_path: str = '/home/ngsci/project/checkpoints_for_hipt_stage3_input_generator_resnet_level0/checkpoint_on_resnet_level0_nofinetune_from_local_bracs_dinoloss1.6_1.6.pth',
        device4k = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
        """
        Args:
            - model4k_path (str): Path to the ViT-4K checkpoint handed to `get_vit4k`.
            - device4k (torch.device): Device the model and the inputs are moved to.
        """
        super().__init__()
        # Build the model for the requested device; the trailing `.to` keeps the
        # model on `device4k` regardless of where `get_vit4k` placed it.
        self.model4k = get_vit4k(pretrained_weights=model4k_path, device=device4k).to(device4k)
        #self.model4k = get_vit4k(pretrained_weights='None').to(device4k)
        self.device4k = device4k

    def forward(self, x):
        """
        Forward pass: patch embeddings -> [16 x 16] feature grid -> ViT-4K output.
        1. x is viewed as a tensor of shape [B x 256 x 1024] (B regions, 256 patch embeddings each).
        2. The last two axes are swapped so the embedding dimension comes first: [B x 1024 x 256].
        3. The 256 patch positions are laid out as a 2D grid: [B x 1024 x 16 x 16].
        4. The grid is moved to `self.device4k` and passed to ViT-4K.

        Args:
            - x (numpy.ndarray or torch.Tensor): [B x 256 x 1024] patch embeddings.

        Return:
            - features_cls4k (torch.Tensor): Output of ViT-4K for each region
              ([B x 192] per the original comments, for the default architecture).
        """
        # as_tensor accepts both numpy arrays and tensors (from_numpy rejects tensors).
        features_resnet = torch.as_tensor(x)                                    # 1. [B x 256 x 1024]
        features_resnet = features_resnet.transpose(1, 2)                       # 2. [B x 1024 x 256]
        features_resnet = features_resnet.reshape(features_resnet.shape[0], 1024, 16, 16)  # 3. B, embed_dim, w, h
        features_resnet = features_resnet.to(self.device4k, non_blocking=True)  # 4. [B x 1024 x 16 x 16]
        features_cls4k = self.model4k.forward(features_resnet)
        return features_cls4k

In [None]:
# Build the wrapper with its default checkpoint path and device.
model = HIPT_4K()
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module structure in the notebook.
model.eval()

In [None]:
import warnings

### numpy file

In [None]:
# Run the model on the `.npy` inputs listed in `files` and store one float16
# `.npy` output per input in `destination` (same file name as the input).
#destination = 'embeddings/vit_xs_embeddings_nofinetuned_resnet50_embeddings_4096region_256times1024_level0_holdout/'
destination = '/home/ngsci/vit_xs_embeddings_nofinetuned_resnet50_embeddings_4096region_256times1024_level0_holdout/'
os.makedirs(destination, exist_ok=True)

# This run handles the index slice [9000, 12500) of `files`; the upper bound is
# clamped so a shorter file list does not raise an IndexError.
stop_index = min(12500, files.shape[0])

# record=True captures warnings instead of printing them; the captured list is
# discarded, so warnings raised inside the loop stay silent.
with warnings.catch_warnings(record=True):
    for p in tqdm(range(9000, stop_index)):
        # One output path, used for both the skip check and the save, so the
        # two can never disagree.
        out_path = os.path.join(destination, os.path.basename(files[p]))

        # skip already processed
        if os.path.exists(out_path):
            continue

        emb_4k = np.load(files[p]).astype(np.float32)

        # skip empty files:
        if emb_4k.size == 0:
            print(f"Skipping empty file: {files[p]}")
            continue

        # Outputs are stored as float16 to halve the size on disk.
        preds = model(emb_4k).cpu().numpy().astype(np.float16)
        np.save(out_path, preds)

### hdf5 file

In [None]:
# Run the model on HDF5 inputs listed in `files` and store one float16 `.npy`
# output per input in `destination` (input file name with the extension
# replaced by `.npy`).
# NOTE(review): `files` is globbed with '*.npy' earlier in this notebook; this
# cell reads each entry with h5py, so `files` presumably has to be rebuilt for
# HDF5 inputs before running it -- confirm.
destination = 'vit_xs_embeddings_finetuned_on_bracs_on_top_of_resnet50_embeddings_4096region_256times1024_level1_finetuned_on_nightingale_level0_holdout/'
os.makedirs(destination, exist_ok=True)

# This run handles the first 4500 entries of `files`; the bound is clamped so a
# shorter file list does not raise an IndexError.
stop_index = min(4500, files.shape[0])

# record=True captures warnings instead of printing them; the captured list is
# discarded, so warnings raised inside the loop stay silent.
with warnings.catch_warnings(record=True):
    for p in tqdm(range(stop_index)):
        # np.save writes `<stem>.npy`, so the skip check must look for that
        # name (checking for the input's own file name never matches).
        stem = os.path.splitext(os.path.basename(files[p]))[0]
        out_path = os.path.join(destination, stem + '.npy')

        # skip already processed
        if os.path.exists(out_path):
            continue

        _, emb_4k = load_h5_file(files[p])
        emb_4k = emb_4k.astype(np.float32)

        # skip empty files:
        if emb_4k.size == 0:
            print(f"Skipping empty file: {files[p]}")
            continue

        # Outputs are stored as float16 to halve the size on disk.
        preds = model(emb_4k).cpu().numpy().astype(np.float16)
        np.save(out_path, preds)

In [None]:
# Sanity check: shape of the most recently computed `preds` array.
preds.shape

In [None]:
# Histogram (100 bins) of the values in row 10 of the most recent `preds`
# array; requires that array to have at least 11 rows.
_ = plt.hist( preds[10], bins=100 )

In [None]:
# Histogram (100 bins) of the values in row 30 of the most recent `preds`
# array; requires that array to have at least 31 rows.
_ = plt.hist( preds[30], bins=100 )