# Medical Whole Slide Processing: Speed Up!

The first challenge with **medical whole slide images** is probably their huge size and the pre-processing required to prepare them for model training and further analysis. The major concerns are:
* Generating Tiles
* Thresholding (discard tiles with no structural data, e.g. mostly background tiles)
* Resource management (RAM and Storage)
* Time Management (Heavy Computational Processes)

Here I share my experience with some tricky slides from [this Kaggle competition](https://www.kaggle.com/competitions/mayo-clinic-strip-ai) and discuss some solutions.

To view the code in detail, you will need to open the notebook in Kaggle or view it in Google Colab.

In [None]:
import warnings
warnings.filterwarnings('ignore')


# install fast kaggle
try: import fastkaggle
except ModuleNotFoundError:
    !pip install -Uqq fastkaggle
    
from fastkaggle import *

In [None]:
# install pyvips
!conda install -y -qq --channel conda-forge pyvips > quiet.txt

In [None]:
# set up the competition
comp = 'mayo-clinic-strip-ai'
path = setup_comp(comp, install='"fastcore>=1.4.5" "fastai>=2.7.1" "timm>=0.6.2.dev0"')
path.ls()

In [None]:
from fastai.vision.all import *
from fastcore.parallel import *
import pyvips

In [None]:
trn_path = path/'train'
trn_slides = get_image_files(trn_path)

tst_path = path/'test'
tst_slides = get_image_files(tst_path)

## What do these slides look like anyway?

Let’s get to it and see what we are dealing with.

In [None]:
slides = tst_slides[3], tst_slides[1]
fig = plt.figure(figsize=(12,8))

ax = fig.add_subplot(1,2,1)
wst = pyvips.Image.thumbnail(slides[0], 400)
wsi = pyvips.Image.new_from_file(slides[0])
ax.set_title(f'{wsi.width} x {wsi.height} pixel')
implot = plt.imshow(wst)

ax = fig.add_subplot(1,2,2)
wst = pyvips.Image.thumbnail(slides[1], 400)
wsi = pyvips.Image.new_from_file(slides[1])
ax.set_title(f'{wsi.width} x {wsi.height} pixel')
implot = plt.imshow(wst)

Well, if nothing else, they are huge in pixel size, and the colour tone can differ a lot from slide to slide, apparently because of the microscope, lab lighting, and other factors that are not my concern here.

## What is a Tile about?

To make these whole slides manageable, we need to break them into smaller pieces. It would be worth a discussion to ask:
* What tile size is OK?
* How many of them? Is sampling OK?
* What should each tile be labelled?

Very important questions I suppose, but not for this blog post.

Let’s have a closer look at some of those tiles from one slide.

In [None]:
slide = tst_slides[1]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(1500, 3000, 2000, 2000),
         wsi.crop(3500, 3000, 2000, 2000),
         wsi.crop(5000, 3000, 2000, 2000)]
_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    show_image(tiles[i], ctx=ax, title=f'tile-0{i+1}')

Obviously no one wants to waste time and compute resources on background tiles like **tile-03** which do not contain any structural data.

What about **tile-02**? Is it important to keep that one?

What if you have hundreds of tiles like **tile-01** from one slide? Is it still worth it to keep **tile-02**?

I would say there is no definite answer except through experiments!

For now let’s see how thresholding is going to help.

## Thresholding: When Things Do Not Work as Expected!

Thresholding is basically separating the foreground from the background, so it is used to discard tiles or parts of an image that do not contain any useful data. The [skimage](https://scikit-image.org/docs/stable/auto_examples/applications/plot_thresholding.html) documentation is probably a good place to dive into more details.

Anyway, there are different algorithms for thresholding, and in practice it becomes fairly clear why there are so many: they do not perform consistently across the images they process.

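Thresholding is usually performed at the grayscale (single band) level. One fun fact to know is that medical images do not necessarily look grey at grayscale :) (most likely because the plotting library applies its default colour map to single-band images).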

In [None]:
_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    show_image(tiles[i].colourspace('b-w'), ctx=ax, title=f'tile-0{i+1}-grey')

Let’s start with a fairly simple method known as [threshold minimum](https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.threshold_minimum).

In [None]:
from skimage.filters import threshold_minimum

_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    tile = tiles[i].colourspace('b-w')
    tile = tile > threshold_minimum(tile.numpy())
    show_image( tile, ctx=ax, title=f'tile-0{i+1} ({tile.avg()/255:.0%})')

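Well, the first thing to point out is that I am apparently measuring the background rather than the foreground. By definition, `image > threshold` should grab the foreground, which does not seem to be the case here: that convention assumes bright objects on a dark background, whereas on these slides the background is the bright part.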
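
The polarity issue is easy to reproduce on a synthetic tile. The sketch below is a minimal illustration, assuming a bright background and darker tissue; the pixel values and the fixed cut-off of 128 are made up for the example and stand in for whatever a thresholding function returns.

```python
import numpy as np

# Synthetic 100 x 100 greyscale tile: bright background (230) with a
# darker 40 x 40 "tissue" patch (60). Values are illustrative only.
tile = np.full((100, 100), 230, dtype=np.uint8)
tile[30:70, 30:70] = 60

threshold = 128  # assumed cut-off, standing in for a computed threshold

bright_fraction = (tile > threshold).mean()  # selects the background here
dark_fraction = (tile < threshold).mean()    # selects the tissue patch

print(f"tile > threshold: {bright_fraction:.0%}")
print(f"tile < threshold: {dark_fraction:.0%}")
```

With a bright background, `>` measures the background and `<` measures the tissue, so the percentage has to be read (or the comparison flipped) accordingly.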

Apart from that, the results seem promising! Don't they!?


Let’s grab another set of tiles from a different slide and see how this method works out!

These are the new tiles:

In [None]:
slide = tst_slides[3]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(4000, 14000, 2000, 2000),
         wsi.crop(10000, 12000, 2000, 2000),
         wsi.crop(10000, 16000, 2000, 2000)]

_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    show_image(tiles[i], ctx=ax, title=f'tile-0{i+4}')

And these are the thresholding results:

In [None]:
_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    tile = tiles[i].colourspace('b-w')
    tile = tile > threshold_minimum(tile.numpy())
    show_image( tile, ctx=ax, title=f'tile-0{i+4} ({tile.avg()/255:.0%})')

I expected both **tile-05** and **tile-06** to be discarded as garbage, but the method begs to differ. Now what?

## There Should be Better Methods to Do This!?

There are various thresholding methods implemented by `skimage` that you can try. I am going to try one or two more here, starting with a well-known method called [**otsu thresholding**](https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.threshold_otsu).

Let’s see how `otsu` will perform on the latest tiles.

In [None]:
from skimage.filters import threshold_otsu

_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    tile = tiles[i].colourspace('b-w')
    tile = tile > threshold_otsu(tile.numpy())
    show_image( tile, ctx=ax, title=f'tile-0{i+4} (otsu: {tile.avg()/255:.0%})')

Does not look good! Does it!? Gorgeous output though! :)
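
One plausible reason a histogram-based method struggles on background-only tiles is that it always returns some threshold, even when the histogram has a single peak. The sketch below is a hand-rolled NumPy version of Otsu's between-class variance criterion (not the `skimage` implementation), applied to a synthetic noisy background tile; the noise level and the seed are assumptions for illustration.

```python
import numpy as np

def otsu_threshold(img):
    """Return the grey level that maximises the between-class variance."""
    hist = np.bincount(img.ravel(), minlength=256).astype(float)
    p = hist / hist.sum()                  # grey-level probabilities
    omega = np.cumsum(p)                   # weight of the darker class
    mu = np.cumsum(p * np.arange(256))     # cumulative mean
    mu_total = mu[-1]
    with np.errstate(divide="ignore", invalid="ignore"):
        between = (mu_total * omega - mu) ** 2 / (omega * (1.0 - omega))
    between = np.nan_to_num(between, nan=0.0, posinf=0.0, neginf=0.0)
    return int(np.argmax(between))

rng = np.random.default_rng(0)
# Background-only tile: grey level around 230 with mild noise (assumed).
noise = rng.normal(230, 3, size=(200, 200))
background = np.clip(np.rint(noise), 0, 255).astype(np.uint8)

t = otsu_threshold(background)
fraction_above = (background > t).mean()
print(f"threshold={t}, pixels above threshold: {fraction_above:.0%}")
```

The threshold lands inside the single peak, so a large share of an empty tile ends up on each side of it. A tile like that cannot be told apart from a busy one by the percentage alone.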

There is another method I would like to try, called [yen thresholding](https://scikit-image.org/docs/stable/api/skimage.filters.html#skimage.filters.threshold_yen).

The output looks like this:

In [None]:
from skimage.filters import threshold_yen

_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    tile = tiles[i].colourspace('b-w')
    tile = tile > threshold_yen(tile.numpy())
    show_image( tile, ctx=ax, title=f'tile-0{i+4} (yen: {tile.avg()/255:.0%})')

On one hand, it has successfully recognised both background tiles; on the other hand, it tends to underrate **tile-04**, which is mostly valuable data. I could set the discard margin at something like 20% and live with it. But is this really a consistent result? You have to keep trying!
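
The discard rule mentioned here can be written as a small helper. This is only a sketch: `keep_tile` and its `margin` parameter are hypothetical names, and the 20% default simply mirrors the figure quoted in the text.

```python
import numpy as np

def keep_tile(mask, margin=0.20):
    """Keep a tile when the share of foreground pixels in `mask` exceeds `margin`."""
    return bool(np.asarray(mask, dtype=bool).mean() > margin)

# Two toy binary masks: one with 50% foreground, one with 10%.
busy = np.zeros((10, 10), dtype=bool)
busy[:5, :] = True
sparse = np.zeros((10, 10), dtype=bool)
sparse[:1, :] = True

print(keep_tile(busy), keep_tile(sparse))
```

The margin only helps if the thresholding step scores similar tiles similarly, which is exactly what the next experiment puts to the test.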

Let's give it another try on a third slide.

In [None]:
from skimage.filters import threshold_yen

slide = trn_slides[157]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(14000, 20000, 2000, 2000),
         wsi.crop(14000, 25000, 2000, 2000),
         wsi.crop(14000, 27000, 2000, 2000),
         wsi.crop(13000, 28000, 2000, 2000)]

_,axs = subplots(1, 4)
for i, ax in enumerate(axs):
    show_image(tiles[i], ctx=ax, title=f'tile-0{i+7}')

Funny right! But they are honestly all from a single slide! :)

Let's check out the `yen` method:


In [None]:
slide = trn_slides[157]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(14000, 20000, 2000, 2000),
         wsi.crop(14000, 25000, 2000, 2000),
         wsi.crop(14000, 27000, 2000, 2000),
         wsi.crop(13000, 28000, 2000, 2000)]

_,axs = subplots(1, 4)
for i, ax in enumerate(axs):
    tile = tiles[i].colourspace('b-w')
    tile = tile > threshold_yen(tile.numpy())
    show_image( tile, ctx=ax, title=f'tile-0{i+7} (yen: {tile.avg()/255:.0%})')

Mission accomplished! The `yen` thresholding method has been successfully broken!

But it is not the end of the world yet.

## Morphology: Robust but Expensive!!

There is a totally separate chapter of image processing called **Morphology**. It is a collection of `non-linear` operations to manipulate an image. I suppose the [skimage documentation](https://scikit-image.org/docs/stable/auto_examples/applications/plot_morphology.html) would be a good start.

I did a bit of research, and apparently a very common pipeline is to apply dilation or erosion to an image whose edges have been detected or enhanced first!

I will implement a `canny()` edge detector plus a `binary_dilation()` operation with a disk-shaped `mask` of radius `6` (13 x 13 pixels).

In [None]:
from skimage.feature import canny
from skimage.morphology import binary_dilation, disk

def skimage_morph(tile):
    tile = tile.colourspace('b-w').numpy()
    tile = canny(tile)
    tile = binary_dilation(tile, disk(6))
    return tile
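
To get a feel for what the dilation step does to a thin edge map, here is a NumPy-only sketch. The `disk_footprint` and `dilate` helpers are hand-written stand-ins, not the `skimage` functions; the footprint follows the usual `x**2 + y**2 <= r**2` definition of a disk.

```python
import numpy as np

def disk_footprint(radius):
    """Boolean disk of the given radius; the array is (2*radius + 1) wide."""
    yy, xx = np.ogrid[-radius:radius + 1, -radius:radius + 1]
    return xx**2 + yy**2 <= radius**2

def dilate(mask, footprint):
    """Binary dilation: a pixel turns on if any pixel under the footprint is on."""
    r = footprint.shape[0] // 2
    padded = np.pad(mask, r)  # pad with False so borders are handled
    out = np.zeros_like(mask)
    h, w = mask.shape
    for dy, dx in zip(*np.nonzero(footprint)):
        out |= padded[dy:dy + h, dx:dx + w]
    return out

fp = disk_footprint(6)
edges = np.zeros((50, 50), dtype=bool)
edges[25, 25] = True  # a single "edge" pixel

grown = dilate(edges, fp)
print(fp.shape, int(edges.sum()), int(grown.sum()))
```

A single edge pixel grows into the whole footprint, which is why sparse Canny edges turn into a dense coverage estimate after dilation.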

Let's see how it will perform on all the tiles above!

In [None]:
slide = tst_slides[1]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(1500, 3000, 2000, 2000),
         wsi.crop(3500, 3000, 2000, 2000),
         wsi.crop(5000, 3000, 2000, 2000)]
_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    tile = skimage_morph(tiles[i])
    show_image(tile, ctx=ax, title=f'tile-0{i+1} ({tile.mean():.0%})')

In [None]:
slide = tst_slides[3]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(4000, 14000, 2000, 2000),
         wsi.crop(10000, 12000, 2000, 2000),
         wsi.crop(10000, 16000, 2000, 2000)]
_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    tile = skimage_morph(tiles[i])
    show_image(tile, ctx=ax, title=f'tile-0{i+4} ({tile.mean():.0%})')

In [None]:
slide = trn_slides[157]
wsi = pyvips.Image.new_from_file(slide)
tiles = [wsi.crop(14000, 20000, 2000, 2000),
         wsi.crop(14000, 25000, 2000, 2000),
         wsi.crop(14000, 27000, 2000, 2000),
         wsi.crop(13000, 28000, 2000, 2000)]
_,axs = subplots(1, 4)
for i, ax in enumerate(axs):
    tile = skimage_morph(tiles[i])
    show_image(tile, ctx=ax, title=f'tile-0{i+7} ({tile.mean():.0%})')

Voilà! It cannot get better! Right?

Well, not so fast!

Let’s put it to a more practical test and see how it will perform in terms of both the number of tiles generated and the time required!

## “skimage” Time Based Test

I am going to run the pipeline on the slide below which is quite massive and messy! Kudos to the lab technician! :)

In [None]:
slide = trn_slides[157]
wst = pyvips.Image.thumbnail(slide, 600)
wsi = pyvips.Image.new_from_file(slide)
title = f'{wsi.width} x {wsi.height} pixel'
show_image(wst, title=title, figsize=(8,10));

I am running the test on a **Kaggle** notebook instance. At the time of writing, it is powered by a Linux machine with 4 `CPU` cores (two threads per core) and 16GB of `RAM`. I also run the pipeline in parallel through a `ThreadPoolExecutor` to utilise the resources as much as possible.
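
As a rough picture of what the thread pool does, here is a standard-library sketch with `concurrent.futures.ThreadPoolExecutor`. The `process_slide` function and the slide names are placeholders, not the pipeline used in this post.

```python
from concurrent.futures import ThreadPoolExecutor
import time

def process_slide(name):
    """Placeholder for per-slide work; sleeps briefly to mimic I/O."""
    time.sleep(0.05)
    return f"{name}: done"

slides = [f"slide_{i}" for i in range(8)]  # hypothetical slide names

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=8) as pool:
    # map() hands one slide to each free worker and keeps the input order
    results = list(pool.map(process_slide, slides))
elapsed = time.perf_counter() - start

print(results[0], f"({elapsed:.2f}s for {len(slides)} slides)")
```

Each worker takes one slide at a time, so the pool only pays off when the list holds several slides; with a single slide, any parallelism has to come from inside the libraries themselves.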

What I am doing here is basically cropping the slide into 2k x 2k pixel tiles and, if a tile contains more than 40% blood clot, downsizing it to a scale of 0.25 and saving it.

**One more thing!**

During my research I found out that, apparently, with medical slides the green band (of an `RGB` image) contains the most structural (/useful) data. So I will use the green band here instead of converting the tile to `grayscale`.
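
A quick way to see why the green band can carry more contrast than a weighted greyscale is to compare the two on a pair of colours. The sketch below uses made-up RGB values for a pinkish stain and a near-white background, and the Rec. 709 luma weights as one common greyscale conversion; it is an illustration, not a measurement from these slides.

```python
import numpy as np

# Illustrative colours (assumed, not sampled from the slides)
stain = np.array([230, 120, 170], dtype=float)       # pinkish tissue
background = np.array([240, 240, 240], dtype=float)  # near-white glass

luma_weights = np.array([0.2126, 0.7152, 0.0722])    # Rec. 709

grey_contrast = abs(background @ luma_weights - stain @ luma_weights)
green_contrast = abs(background[1] - stain[1])

print(f"greyscale contrast: {grey_contrast:.0f}, green contrast: {green_contrast:.0f}")
```

A stain that absorbs green strongly but stays bright in red and blue loses part of that contrast once the bands are mixed into a single grey value.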

In [None]:
slide = trn_slides[176]
wsi = pyvips.Image.new_from_file(slide)
tile = wsi.crop(2000, 1000, 2000, 2000)
tiles = [tile, tile.colourspace('b-w'), tile[1]]
titles = ['Original', 'Greyscale', 'G band of RGB']
_,axs = subplots(1, 3)
for i, ax in enumerate(axs):
    show_image(tiles[i], ctx=ax, title=titles[i])

In [None]:
!rm -rf /tmp/tiles/

tpath = Path('/tmp/tiles/')
folders = ['skimage', 'pyvips']

for f in folders:
    if not (tpath/f).exists():
        (tpath/f).mkdir(exist_ok=True, parents=True)

def skimage_pct(tile, thresh=0.4):
    tile = tile[1].numpy()
    tile = canny(tile)
    tile = binary_dilation(tile, disk(6))
    return tile.mean() > thresh
    
mask = pyvips.Image.mask_ideal(11, 11, 1, optical=True, reject=True)
mask = (mask * 128 + 128).cast("uchar").copy_memory()

def pyvips_pct(tile, thresh=0.4, mask=mask):
    tile = tile[1].canny(precision='integer').dilate(mask)
    return tile.avg()/255 > thresh

def get_tiles(slide, t_size=2_000):
    slide = pyvips.Image.new_from_file(slide)
    return (
        slide.crop(x, y, min(t_size, slide.width - x), min(t_size, slide.height - y)
                   ).gravity('south-east', t_size, t_size, extend='repeat')
        for y in range(0, slide.height, t_size)
        for x in range(0, slide.width, t_size))


def save_jpeg(tiles, tname, clot_pct, folder, tcount=1000):
    for tile in tiles:
        if clot_pct(tile):
            tile = tile.resize(0.25, kernel='linear')
            fname = f'{tname}_{tcount}.jpg'
            tile.write_to_file(tpath/folder/fname)
            tcount += 1
            
def make_jpeg_skimage(slide, tname='skimage', clot_pct=skimage_pct, folder='skimage'):
    tiles = get_tiles(slide)
    save_jpeg(tiles=tiles, tname=tname, clot_pct=clot_pct, folder=folder)
    
def make_jpeg_pyvips(slide, tname='pyvips', clot_pct=pyvips_pct, folder='pyvips'):
    tiles = get_tiles(slide)
    save_jpeg(tiles=tiles, tname=tname, clot_pct=clot_pct, folder=folder)
    
def walltime(s, e):
    t = e - s
    m = int(t//60)
    s = t - (m * 60)
    if m==0: return f"{s:.0f}s"
    else: return f"{m}min {s:.0f}s"
    
elapsed=list()
t_num=list()
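
The grid walked by `get_tiles` comes down to simple arithmetic, sketched here without `pyvips`. The `tile_boxes` helper and the slide size are made up for illustration; it yields the crop box of each tile before the edge tiles are padded back to full size.

```python
import math

def tile_boxes(width, height, t_size=2000):
    """Yield (x, y, w, h) crop boxes row by row; edge tiles are clipped."""
    for y in range(0, height, t_size):
        for x in range(0, width, t_size):
            yield x, y, min(t_size, width - x), min(t_size, height - y)

boxes = list(tile_boxes(5000, 3000))  # hypothetical 5000 x 3000 slide

n_expected = math.ceil(5000 / 2000) * math.ceil(3000 / 2000)
print(len(boxes), n_expected)
print(boxes[-1])  # the bottom-right tile is smaller than 2000 x 2000
```

In the pipeline, the clipped edge tiles are then extended to `t_size` x `t_size` with `gravity`, so every tile handed to the filter has the same shape.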

In [None]:
start = time.time()
parallel(make_jpeg_skimage, trn_slides[157:158], n_workers=8, progress=False, threadpool=True)
end = time.time()
elapsed.append(walltime(s=start, e=end))
f = get_image_files(tpath/'skimage')
t_num.append(len(f))

OK, let’s see how the performance is!

In [None]:
df = pd.DataFrame({
    'library': ['skimage'],
    'time elapsed': elapsed,
    'tiles generated': t_num
})
df

As I expected, that is terrible!

Considering a collection of 1,000 slides and an average of 10 minutes per slide (being optimistic), that is about 10,000 minutes, or roughly a week of pre-processing. Given the trial-and-error nature of deep learning, this does not sound practical at all.


The tiles generated look good though!

In [None]:
def label_func(f): return f.stem

dbl = DataBlock(
    blocks = (ImageBlock, CategoryBlock),
    get_items = get_image_files,
    get_y = label_func)


dls = dbl.dataloaders(tpath/'skimage', bs=64)

dls.show_batch(max_n=12)

## “pyvips” is Out There to Astonish You!

Before I run the test for [pyvips](https://libvips.github.io/pyvips/vimage.html), I must make it clear that even in the [skimage](https://scikit-image.org/docs/stable/) test, the backbone processes such as loading those massive slides, cropping, and grabbing tiles have all been handled by `pyvips`.

In [None]:
start = time.time()
parallel(make_jpeg_pyvips, trn_slides[157:158], n_workers=8, progress=False, threadpool=True)
end = time.time()
elapsed.append(walltime(s=start, e=end))
# count the tiles written to the pyvips folder
f = get_image_files(tpath/'pyvips')
t_num.append(len(f))

Having said that, let’s see how `pyvips` performs under the same conditions:

In [None]:
df = pd.DataFrame({
    'library': folders,
    'time elapsed': elapsed,
    'tiles generated': t_num
})
df

That is something! Isn’t it? That is about five times faster.

The table reports the same number of tiles for `pyvips` as well, and the tiles are certainly up to expectation:


In [None]:
dbl = DataBlock(
    blocks = (ImageBlock, CategoryBlock),
    get_items = get_image_files,
    get_y = label_func)

dls = dbl.dataloaders(tpath/'pyvips', bs=64)

dls.show_batch(max_n=12)

## Satisfied!?

Honestly, not quite there yet!

Although things have improved amazingly thanks to `pyvips`, it will still take about a day or so to run through all the slides, which I am not comfortable calling practical.

What can we do better!?

Well, one solution could be to use `histogram`-based operations (e.g. `otsu`) and morphology back to back. You may lose some useful data, but it could shorten the process to the scale of *a couple of hours*.
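
A back-to-back setup could look like the sketch below: a cheap histogram-style test rejects obvious background first, and the expensive morphology test only runs on the survivors. All names, cut-offs, and the stand-in `slow_check` are hypothetical; in practice the cheap test might be an `otsu`-based fraction and the slow one the Canny-plus-dilation check.

```python
import numpy as np

def cheap_check(tile, cut=200, min_dark=0.05):
    """Fast pre-filter: does the tile have enough pixels darker than `cut`?"""
    return (tile < cut).mean() >= min_dark

def two_stage_filter(tiles, expensive_check):
    """Run `expensive_check` only on tiles that pass `cheap_check`."""
    kept, expensive_calls = [], 0
    for tile in tiles:
        if not cheap_check(tile):
            continue  # obvious background, skip the costly step
        expensive_calls += 1
        if expensive_check(tile):
            kept.append(tile)
    return kept, expensive_calls

def slow_check(tile):
    """Stand-in for the morphology test (assumed 40% coverage rule)."""
    return (tile < 200).mean() > 0.4

blank = np.full((64, 64), 240, dtype=np.uint8)
tissue = blank.copy()
tissue[:, :32] = 90  # half of the tile is dark "tissue"

kept, calls = two_stage_filter([blank, blank, tissue, blank], slow_check)
print(len(kept), calls)
```

The saving depends on how many tiles the cheap test can reject; on slides that are mostly background, the expensive step would run on only a small share of the tiles.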

**I would say it is a trade-off worth trying!**
