# Inference of a single pathology patch using LiteFM

In [1]:
from models import get_model
import torch

In [2]:
# LiteFM in LitePath can be used as a patch-level feature extractor.
# A random tensor stands in for one preprocessed RGB patch of 224x224 pixels.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn(1, 3, 224, 224).to(device)
model = get_model('LiteFM', device)
# Pure inference: skip autograd bookkeeping to save memory and time.
# NOTE(review): assumes get_model returns the model already in eval mode — confirm.
with torch.no_grad():
    feat = model(x)
print(feat.shape)  # saved output: torch.Size([1, 1024])



torch.Size([1, 1024])


# Inference with the full LitePath pipeline

In [None]:
from litepath_deploy import ModelDeployment
from models import DAttention, AdaPatchSelector, get_custom_transformer
import torch
import os
import pandas as pd
from preprocessing.create_patches_fp import mp_seg_and_patch, seg_and_patch
from datasets import PatchDataset

from preprocessing.extract_images_and_pack2h5 import read_images, get_wsi_path, read_images_parallel
from torch.utils.data import DataLoader, SubsetRandomSampler, Subset
from multiprocessing.pool import Pool
import json

# GPU used by the pipeline below. Defaults to index 1 (the original example setup);
# set the LITEPATH_GPU environment variable to choose another device, e.g. 0 on single-GPU machines.
torch.cuda.set_device(int(os.environ.get("LITEPATH_GPU", "1")))

## Data Preprocessing

### Get Coords

In [4]:
# --- Inputs / outputs for segmentation and patch-coordinate extraction ---
task = 'Lung_NSCLC_Subtyping'
wsi_format = 'svs'
preset = 'tcga.csv'  # CSV read from presets/ that overrides the default parameters below
use_mp = False       # True -> mp_seg_and_patch, False -> seg_and_patch

wsi_dir = f"examples/{task}/slides/"
save_dir = f"examples/{task}/patches"
patch_save_dir, mask_save_dir, stitch_save_dir = (
    os.path.join(save_dir, sub_dir) for sub_dir in ("coords_h5", "masks", "stitches")
)

# --- Patching options ---
seg, patch, stitch = True, True, False
patch_level = 0
patch_size = step_size = 256  # step equals patch size
auto_skip = True
process_list = None

for label, path in (
    ("source", wsi_dir),
    ("patch_save_dir", patch_save_dir),
    ("mask_save_dir", mask_save_dir),
    ("stitch_save_dir", stitch_save_dir),
):
    print(f"{label}: ", path)

directories = {
    "source": wsi_dir,
    "save_dir": save_dir,
    "patch_save_dir": patch_save_dir,
    "mask_save_dir": mask_save_dir,
    "stitch_save_dir": stitch_save_dir,
}

# Echo every directory and create all output ones; the slide source is input-only.
for name, path in directories.items():
    print(f"{name} : {path}")
    if name != "source":
        os.makedirs(path, exist_ok=True)

# Default parameters per stage; when a preset is given, each value is replaced
# by the matching column of the preset CSV's first row.
seg_params = dict(seg_level=-1, sthresh=8, mthresh=7, close=4, use_otsu=False,
                  keep_ids="none", exclude_ids="none")
filter_params = dict(a_t=100, a_h=16, max_n_holes=32)
vis_params = dict(vis_level=-1, line_thickness=120)
patch_params = dict(use_padding=True, contour_fn="four_pt")

if preset:
    preset_df = pd.read_csv(os.path.join("presets", preset))
    for param_group in (seg_params, filter_params, vis_params, patch_params):
        param_group.update({name: preset_df.loc[0, name] for name in param_group})

parameters = dict(seg_params=seg_params, filter_params=filter_params,
                  patch_params=patch_params, vis_params=vis_params)

print(parameters)
print(directories)
# Segment tissue and write patch coordinates for every slide in the source directory.
fn = mp_seg_and_patch if use_mp else seg_and_patch
seg_times, patch_times = fn(
    **directories,
    **parameters,
    patch_size=patch_size,
    step_size=step_size,
    patch_level=patch_level,
    seg=seg,
    patch=patch,
    stitch=stitch,
    save_mask=True,
    use_default_params=False,
    process_list=process_list,
    auto_skip=auto_skip,
    wsi_format=wsi_format,
)


source:  examples/Lung_NSCLC_Subtyping/slides/
patch_save_dir:  examples/Lung_NSCLC_Subtyping/patches/coords_h5
mask_save_dir:  examples/Lung_NSCLC_Subtyping/patches/masks
stitch_save_dir:  examples/Lung_NSCLC_Subtyping/patches/stitches
source : examples/Lung_NSCLC_Subtyping/slides/
save_dir : examples/Lung_NSCLC_Subtyping/patches
patch_save_dir : examples/Lung_NSCLC_Subtyping/patches/coords_h5
mask_save_dir : examples/Lung_NSCLC_Subtyping/patches/masks
stitch_save_dir : examples/Lung_NSCLC_Subtyping/patches/stitches
{'seg_params': {'seg_level': np.int64(-1), 'sthresh': np.int64(8), 'mthresh': np.int64(7), 'close': np.int64(4), 'use_otsu': np.True_, 'keep_ids': 'none', 'exclude_ids': 'none'}, 'filter_params': {'a_t': np.int64(16), 'a_h': np.int64(4), 'max_n_holes': np.int64(8)}, 'patch_params': {'use_padding': np.True_, 'contour_fn': 'four_pt'}, 'vis_params': {'vis_level': np.int64(-1), 'line_thickness': np.int64(100)}}
{'source': 'examples/Lung_NSCLC_Subtyping/slides/', 'save_dir': 'e

processing examples/Lung_NSCLC_Subtyping/slides/1028189.svs
####################################################################################################
levels: ((77688, 46977), (19422, 11744), (4855, 2936), (2427, 1468))
mpp: 0.503, object_power: 20x, patch_size: 256, step_size: 256
####################################################################################################
Performing segmentation
Creating patches for:  1028189 ...
Total number of contours to process:  3
Bounding Box: 66324 18016 10852 13697
Contour Area: 62042336.0
Extracted 997 coordinates
Bounding Box: 57713 16288 9668 14657
Contour Area: 85740368.0
Extracted 1382 coordinates
Bounding Box: 16965 10144 17734 25089
Contour Area: 268604208.0
Extracted 4197 coordinates
segmentation took 0.399716854095459 seconds
patching took 0.43612051010131836 seconds
stitching took -1 seconds


progress: 0.50, 1/2
processing examples/Lung_NSCLC_Subtyping/slides/1019708.svs
############################################

### Crop patch images and pack them into h5 files

In [None]:
# Crop the patch images at the saved coordinates and pack them into one h5 file per slide.
wsi_format = 'svs'
h5_root = f"examples/{task}/patches/coords_h5"
save_root = f"examples/{task}/packed_images"
wsi_root = f"examples/{task}/slides"
os.makedirs(save_root, exist_ok=True)
cpu_cores = 6  # num_workers passed to read_images_parallel

# Only pick up coordinate files, and sort them so slides are processed in a reproducible
# order (os.listdir order is arbitrary and may include stray files such as .DS_Store).
h5_files = sorted(name for name in os.listdir(h5_root) if name.endswith('.h5'))
h5_paths = [os.path.join(h5_root, name) for name in h5_files]
wsi_paths = get_wsi_path(wsi_root, h5_files, wsi_format)
save_roots = [os.path.join(save_root, name) for name in h5_files]

# Slides are handled one after another; read_images_parallel presumably parallelizes within a slide.
# Alternative: map read_images over (h5_path, save_root, wsi_path) tuples with
# multiprocessing.Pool(cpu_cores) to process several slides concurrently.
for h5, wsi_path, sr in zip(h5_paths, wsi_paths, save_roots):
    read_images_parallel((h5, sr, wsi_path), num_workers=cpu_cores)

Processing:Processing:  examples/Lung_NSCLC_Subtyping/patches/coords_h5/1019708.h5examples/Lung_NSCLC_Subtyping/patches/coords_h5/1028189.h5  examples/Lung_NSCLC_Subtyping/slides/1019708.svsexamples/Lung_NSCLC_Subtyping/slides/1028189.svs

examples/Lung_NSCLC_Subtyping/slides/1028189.svs finished!
examples/Lung_NSCLC_Subtyping/slides/1019708.svs finished!
All slides have been cropped!


### Data loader preparation

In [2]:
# Class names indexed by the predicted class id.
# NOTE(review): order assumed to match the one used at training time — confirm.
task = 'Lung_NSCLC_Subtyping'
labels = ['LUSC', 'LUAD']

# examples/selection.json maps each task to a two-element budget [k_u, k_a]:
# k_u = number of uniformly sampled patches, k_a = budget for the attention-based selection.
with open('examples/selection.json', 'r') as selection_file:
    selection = json.load(selection_file)
selection_number = (k_u, k_a) = selection[task]

In [3]:
# Slide to classify: the packed patch images written by the cropping step.
img_id = "1019708"  # the other example slide is "1028189"
img_root = f"examples/{task}/packed_images"
slide_path = f"{img_root}/{img_id}.h5"

# LiteFM-specific preprocessing applied to every patch.
transform = get_custom_transformer('litefm')
case_dataset = PatchDataset(
    slide_path,
    transform=transform,
    # True preloads all patches into RAM: faster inference at the cost of much more memory.
    load_to_memory=False,
)



In [4]:
# Split the slide's patches into two disjoint index sets:
#   * uniform_indices: k_u indices evenly spread over [0, num_patches - 1]; the float
#     positions are truncated by .int(), so indices repeat when num_patches < k_u.
#   * remaining_indices: every other patch, fed to the attention-based loader.
num_patches = len(case_dataset)
# Fail early with a clear message instead of an opaque IndexError inside a DataLoader worker.
assert num_patches > 0, f"No patches found in {slide_path}"
uniform_indices = torch.linspace(0, num_patches - 1, steps=k_u).int()
all_indices = torch.arange(num_patches)
mask = ~torch.isin(all_indices, uniform_indices)  # True for indices NOT in uniform_indices
remaining_indices = all_indices[mask]

# A suitable num_workers is important for inference speed; 32 was used for this example on an RTX 3090.
loader_batch_size = 256
loader_num_workers = 32


def _make_loader(indices):
    """Build an unshuffled DataLoader over the given subset of case_dataset."""
    return DataLoader(Subset(case_dataset, indices), batch_size=loader_batch_size,
                      shuffle=False, num_workers=loader_num_workers)


uniform_loader = _make_loader(uniform_indices) if k_u > 0 else None
attention_loader = _make_loader(remaining_indices) if k_a > 0 else None

## Inference

In [5]:
# Checkpoints for this task: the patch selector (presumably the AdaPatchSelector "APS") and the MIL model.
aps_ckpt = f"examples/{task}/models/aps_model_best.pth.tar"
mil_ckpt = f"examples/{task}/models/model_best.pth.tar"
# n_classes is derived from the label list so the two cannot drift apart.
deployer = ModelDeployment('LiteFM', n_classes=len(labels), k_a=k_a, k_u=k_u,
                           aps_ckpt=aps_ckpt, mil_ckpt=mil_ckpt)

Loaded APS model from examples/Lung_NSCLC_Subtyping/models/aps_model_best.pth.tar
Loaded MIL model from examples/Lung_NSCLC_Subtyping/models/model_best.pth.tar
batch_size: 256, k_a: 100, k_u: 1900, Buffer threshold: 2500


In [6]:
logits, prob, pred = deployer.infer_litepath(uniform_loader, attention_loader)
# One slide per call: take its (only) prediction once, as a plain int, so the label
# lookup and the probability lookup use the same index whatever container pred is.
pred_idx = int(pred[0])
print(f"Prediction: {labels[pred_idx]}, Probability: {prob[0][pred_idx]}")

Uniform features shape: torch.Size([1900, 1024]), time cost: 2.721139907836914
Attention features shape: torch.Size([100, 1024]), time cost: 6.701376914978027
MIL model inference time cost: 0.04475545883178711
Prediction: LUSC, Probability: 0.9450033903121948


In [7]:
# Sanity check: total patches in the slide vs. the uniform (k_u) and attention (k_a) budgets.
(num_patches, k_u, k_a)

(26868, 1900, 100)