This script loads an existing QuPath project,
reads all images with their annotations and/or detections
and classifies them using a Deep Learning model of your choice.
It outputs a new QuPath project with the same images and
classified annotations.
If an image does not contain any detections, it is skipped
and is not added to the new QuPath project.

This code is based on code from Andrew J.
(https://github.com/choosehappy/QuPathGeoJSONImportExport)
but makes use of the paquo library (https://paquo.readthedocs.io)

Written by Sabina K. and Andrew J.

In [1]:
from __future__ import annotations
import os
import shutil
import tiffslide as openslide
from tqdm.notebook import tqdm
from math import ceil
import matplotlib.pyplot as plt
from shapely.geometry import shape
from shapely.strtree import STRtree
from shapely.geometry import Polygon
import torch
from torch import nn
from torchsummary import summary
import numpy as np
import cv2
import sys
from pathlib import Path
from paquo.projects import QuPathProject
from paquo.images import QuPathImageType
from paquo.classes import QuPathPathClass
from paquo.colors import QuPathColor

In [94]:
# Folder name of the existing QuPath project; the project file path is derived from it.
PROJECT_NAME = "project1"

In [95]:
# ---- Project locations
# Input project: <cwd>/<PROJECT_NAME>/project.qpproj
PROJECT_FILE = f"{PROJECT_NAME}/project.qpproj"
PROJECT_PATH = f"{os.getcwd()}/{PROJECT_FILE}"
# Output project: <cwd>/<PROJECT_NAME>_out/project.qpproj
NEW_PROJECT_NAME = f"{PROJECT_NAME}_out"
NEW_PROJECT_FILE = f"{NEW_PROJECT_NAME}/project.qpproj"  # reuse the folder name instead of re-deriving it
NEW_PROJECT_PATH = f"{os.getcwd()}/{NEW_PROJECT_FILE}"

# ---- Processing parameters
# Use the GPU when CUDA is available and fall back to the CPU otherwise,
# so the notebook also runs on machines without a GPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
OPENSLIDELEVEL = 0  # Level from openslide to read
TILESIZE = 10000  # Size of the tile to load from openslide
PATCHSIZE = 32  # Patch size needed by our DL model
MINHITS = 100  # The minimum number of objects needed to be present within a tile for the tile to be computed on
BATCHSIZE = 1024  # How many patches we want to send to the GPU at a single time
NUM_OF_CLASSES = 2  # Number of output classes our model is providing
CLASSNAMES = ["Other", "Lymphocyte"]  # The names of those classes which will appear in QuPath later on
CLASSCOLORS = [-377282, -9408287]  # Their associated color, see selection of different color values at the bottom of the file
MASK_PATCHES = False  # If we would like to black out the space around the object of interest, this is determined by how the model was trained
# MODEL_FNAME="lymph_model.pth"  # DL model to use


# ---- Load your model here
# model = LoadYourModelHere().to(DEVICE)
# checkpoint = torch.load(MODEL_FNAME, map_location=lambda storage, loc: storage)  # load checkpoint to CPU and then put to device https://discuss.pytorch.org/t/saving-and-loading-torch-models-on-2-machines-with-different-number-of-gpu-devices/6666
# model.load_state_dict(checkpoint["model_dict"])
# model.eval()
# summary(model, (3, 32, 32))

In [96]:
def read_qupath_annotations(image):
    """Collect the ROI of every annotation stored on a QuPath image.

    Prints how many annotations the image has and returns their ROIs as a
    list. The list is empty when the image has no annotations.
    """
    found = image.hierarchy.annotations  # the hierarchy exposes the annotations
    print(f"Image {image.image_name} has {len(found)} annotations.")
    if not found:
        return []
    return [item.roi for item in found]


def read_qupath_detections(image):
    """Collect the ROI of every detection stored on a QuPath image.

    Prints how many detections the image has and returns their ROIs as a
    list. The list is empty when the image has no detections.
    """
    found = image.hierarchy.detections  # the hierarchy exposes the detections
    print(f"Image {image.image_name} has {len(found)} detections.")
    rois = []
    if found:
        rois.extend(item.roi for item in found)
    return rois


def add_qupath_classes(classnames: list, colors: list, qp):
    """Register one path class per (name, color) pair on a QuPath project.

    ``classnames`` and ``colors`` are paired positionally; each color is
    passed to ``QuPathColor.from_java_rgba``. The resulting list is assigned
    to ``qp.path_classes`` and the names of the project's classes are printed.
    """
    # Assigning to path_classes replaces whatever classes the project had before.
    qp.path_classes = [
        QuPathPathClass(name=label, color=QuPathColor.from_java_rgba(rgba))
        for label, rgba in zip(classnames, colors)
    ]
    print("Adding project classes to new QuPath project:")
    for registered in qp.path_classes:
        print(f"'{registered.name}'")


def find_tile(tilesize: int, searchtree: STRtree, scalefactor: int, y, x):
    """Query the search tree with the square tile whose top-left corner is (x, y).

    The square has a side of ``tilesize * scalefactor``. Returns a tuple
    ``(hits, searchtile)``: the result of ``searchtree.query`` and the
    polygon that was used for the query.
    """
    side = tilesize * scalefactor
    right = x + side
    bottom = y + side
    corners = [[x, y], [right, y], [right, bottom], [x, bottom]]
    searchtile = Polygon(corners)
    return searchtree.query(searchtile), searchtile


def get_tile(openslidelevel: int, tilesize: int, osh, paddingsize: int, y, x):
    """Read a padded square region from the slide and return its first three channels.

    The region starts ``paddingsize`` before (x, y) in both directions and
    has a side of ``tilesize + 2 * paddingsize``. It is requested from
    ``osh.read_region`` with ``as_array=True`` (tiffslide can return an
    array directly).
    """
    # NOTE(review): paddingsize is applied unscaled to both the location and the
    # size - confirm this is intended when openslidelevel > 0.
    top_left = (x - paddingsize, y - paddingsize)
    side = tilesize + 2 * paddingsize
    region = osh.read_region(top_left, openslidelevel, (side, side), as_array=True)
    return region[:, :, 0:3]  # drop the alpha channel


def construct_mask(allshapes: list, scalefactor: int, paddingsize: int, int_coords, y, x, hits: list, tile):
    """Rasterise the outlines of ``hits`` into a 0/1 mask aligned with ``tile``.

    Args:
        allshapes: All shapes of the image; used to look up a hit by ``hit.id``
            when the hit carries such an attribute.
        scalefactor: Divisor applied to the shifted outline coordinates.
        paddingsize: Padding that was added around the tile when it was read.
        int_coords: Callable turning a coordinate sequence into an integer array.
        y, x: Top-left corner of the (unpadded) tile.
        hits: Shapes whose outlines are filled in the mask.
        tile: Image tile; the mask takes its height, width and dtype.

    Returns:
        A 2-D array with 1 inside the filled outlines and 0 elsewhere.
    """
    mask = np.zeros(tile.shape[0:2], dtype=tile.dtype)
    # Top-left corner of the padded tile; outlines are shifted relative to it.
    tile_origin = np.asarray([x - paddingsize, y - paddingsize])

    polygons = []
    for hit in hits:
        # Hits that carry an ``id`` are looked up in ``allshapes`` (original
        # behaviour); hits without one are used directly instead of raising
        # AttributeError.
        shape_id = getattr(hit, "id", None)
        geom = hit if shape_id is None else allshapes[shape_id]
        outline = int_coords(geom.boundary.coords)
        shifted = (outline - tile_origin) // scalefactor
        # Subtracting the (platform-int) origin can promote the array to int64,
        # which cv2.fillPoly rejects; it needs 32-bit signed points.
        polygons.append(shifted.astype(np.int32))

    if polygons:
        cv2.fillPoly(mask, polygons, 1)
    return mask


def get_maskpatch(patchsize, mask, c, r, patch):
    """Multiply ``patch`` by the mask window centred on column ``c`` / row ``r``.

    The window spans ``patchsize // 2`` pixels on each side of (r, c) and is
    broadcast over the channel axis of ``patch``.
    """
    half = patchsize // 2
    window = mask[r - half:r + half, c - half:c + half]
    return np.multiply(patch, np.expand_dims(window, axis=2))


def divide_batch(arr, size: int):
    """Yield consecutive slices of ``arr`` along its first axis, ``size`` rows at a time.

    The last slice is shorter when the row count is not a multiple of ``size``.
    """
    row_count = arr.shape[0]
    yield from (arr[start:start + size, ::] for start in range(0, row_count, size))


def process_batch(arr_out, hits: list):
    """Classify the patches in ``arr_out`` and store the result on each hit.

    The patches are processed in chunks of ``BATCHSIZE``. Each hit receives a
    ``class_id`` attribute, paired positionally with the patches.

    NOTE: as long as the model call below stays commented out, the class ids
    are drawn at random from {0, 1} as a placeholder.

    Args:
        arr_out: Array of patches, one per hit, stacked along the first axis.
        hits: Shapes to label; nothing is assigned when there are no patches.
    """
    classids = []
    for batch_arr in tqdm(divide_batch(arr_out, BATCHSIZE), leave=False):
        # batch_arr_gpu = torch.from_numpy(batch_arr.transpose(
        #     0, 3, 1, 2)).type('torch.FloatTensor').to(DEVICE) / 255
        # Get results
        # classids.append(torch.argmax(model.img2class(batch_arr_gpu), dim=1).detach().cpu().numpy())
        classids.append(np.random.choice([0, 1], batch_arr.shape[0]))

    if not classids:
        # No patches at all: np.hstack would raise ValueError on an empty list.
        return

    classids = np.hstack(classids)

    for hit, classid in zip(hits, classids):
        hit.class_id = classid


def add_annotations(qpout, entry, ann: list, allshapes: list):
    """Write the classified shapes and the original annotations to ``entry``.

    Every shape in ``allshapes`` becomes an annotation named after its
    ``geom_type``. Its path class is ``qpout.path_classes[shape.class_id]``
    when the shape has a ``class_id`` attribute, otherwise ``None``. The ROIs
    in ``ann`` are then added unchanged, without a path class.
    """
    for geom in allshapes:
        if hasattr(geom, "class_id"):
            assigned_class = qpout.path_classes[geom.class_id]
        else:
            assigned_class = None
        created = entry.hierarchy.add_annotation(roi=geom, path_class=assigned_class)
        created.name = str(geom.geom_type)  # a name can be attached to annotations as well

    if not ann:
        return
    # Carry the original annotations over as they were.
    for untouched in ann:
        entry.hierarchy.add_annotation(roi=untouched)


In [97]:
def project_cleanup():
    """Delete the output project when none of its images has annotations.

    Opens the project at ``NEW_PROJECT_PATH``. If at least one image carries
    annotations the project is kept and a completion message is printed;
    otherwise the project folder is removed from disk.

    The decision is made once for the whole project. Deciding per image (as
    before) removed the folder as soon as a single image lacked annotations
    and called ``shutil.rmtree`` again for every further such image, which
    fails on the already deleted folder.
    """
    qpout = QuPathProject(NEW_PROJECT_PATH)
    has_annotations = any(image.hierarchy.annotations for image in qpout.images)

    if has_annotations:
        print(f"Done. Please look at QuPath project '{qpout.name}' in QuPath.")
        return

    print(f"Removing new QuPath project '{qpout.name}'. No need to keep project without annotations.")
    # Remove the folder of the project that was actually opened, independent of
    # the current working directory.
    shutil.rmtree(os.path.dirname(NEW_PROJECT_PATH))

In [98]:
def _wsi_path_from_uri(uri: str) -> str:
    """Turn the image URI stored in a QuPath project into a local file path."""
    from urllib.parse import urlparse
    from urllib.request import url2pathname

    # Splitting the URI on ":" truncates Windows drive URIs ("file:/C:/...")
    # and leaves percent-escapes such as "%20" in the path; parsing handles both.
    return url2pathname(urlparse(uri).path)


def _int_coords(coords):
    """Round a coordinate sequence and return it as an int32 array."""
    return np.array(coords).round().astype(np.int32)


def _extract_patches(tile, mask, hits, scalefactor, paddingsize, y, x):
    """Cut one PATCHSIZE x PATCHSIZE patch around the centroid of every hit.

    ``mask`` is either ``None`` or the mask built for ``tile``; when given,
    each patch is multiplied by the matching mask window.
    """
    half = PATCHSIZE // 2
    arr_out = np.zeros((len(hits), PATCHSIZE, PATCHSIZE, 3))
    for hit, arr in zip(hits, arr_out):
        px, py = hit.centroid.coords[:][0]  # Faster than hit.x and hit.y, likely because of call stack overhead
        # Centroid position inside the padded tile, which starts `paddingsize` before (x, y).
        c = int((px - x + paddingsize) // scalefactor)
        r = int((py - y + paddingsize) // scalefactor)
        patch = tile[r - half:r + half, c - half:c + half, :]

        if mask is not None:
            patch = get_maskpatch(PATCHSIZE, mask, c, r, patch)

        arr[:] = patch  # `arr` is a view, so this fills the matching row of arr_out
    return arr_out


def _classify_image(qpout, image):
    """Classify the detections of one image and add the result to ``qpout``.

    Images without detections are reported and skipped (they are not added to
    the output project). Annotations are copied over but not classified.
    """
    ann = read_qupath_annotations(image)  # We keep the annotations, but we don't classify them
    det = read_qupath_detections(image)
    if not det:
        print("No detections in this image.")
        return

    allshapes = det  # We only want to classify the detections
    # NOTE(review): the hits returned by the tree are used as geometries
    # (`.centroid`, attribute assignment); this matches Shapely 1.x, whereas
    # Shapely 2 returns indices from STRtree.query - confirm the pinned version.
    searchtree = STRtree(allshapes)
    wsi_fname = _wsi_path_from_uri(image.uri)
    entry = qpout.add_image(wsi_fname, image_type=QuPathImageType.BRIGHTFIELD_H_E,
                            allow_duplicates=True)

    osh = openslide.OpenSlide(wsi_fname)
    try:
        scalefactor = int(osh.level_downsamples[OPENSLIDELEVEL])
        paddingsize = PATCHSIZE // 2 * scalefactor
        width, height = osh.level_dimensions[0]
        step = TILESIZE * scalefactor

        # Now lets start finding interesting tiles to operate on
        for y in tqdm(range(0, height, step), desc="outer", leave=False):
            for x in tqdm(range(0, width, step), desc=f"inner {y}", leave=False):
                hits, searchtile = find_tile(TILESIZE, searchtree, scalefactor, y, x)
                hits = [hit for hit in hits if hit.centroid.intersects(searchtile)]  # filter by centroid

                # Skip sparse tiles; also never hand an empty batch to the classifier.
                if not hits or len(hits) < MINHITS:
                    continue

                tile = get_tile(OPENSLIDELEVEL, TILESIZE, osh, paddingsize, y, x)
                mask = (construct_mask(allshapes, scalefactor, paddingsize,
                                       _int_coords, y, x, hits, tile)
                        if MASK_PATCHES else None)

                # Get patches from hits within this tile and classify them
                arr_out = _extract_patches(tile, mask, hits, scalefactor, paddingsize, y, x)
                process_batch(arr_out, hits)
    finally:
        osh.close()  # release the slide file handle even if processing fails

    add_annotations(qpout, entry, ann, allshapes)


def main():
    """Classify the detections of every image of the input project.

    Opens (or creates) the output project at ``NEW_PROJECT_PATH``, registers
    the configured classes, then processes each image of the project at
    ``PROJECT_PATH``. Prints "No images to process." and stops when the input
    project is empty.
    """
    with QuPathProject(NEW_PROJECT_PATH, mode='a') as qpout:
        print(f"Created new QuPath project: '{qpout.name}'.")
        add_qupath_classes(CLASSNAMES, CLASSCOLORS, qpout)

        qp = QuPathProject(PROJECT_PATH, mode='r')
        print(f"Opened project ‘{qp.name}’ ")
        print(f"Project has {len(qp.images)} image(s).")

        # Explicit emptiness check instead of a broad `except IndexError`, which
        # would also mask index errors raised while processing an image.
        if len(qp.images) == 0:
            print("No images to process.")
            return

        for image in qp.images:
            _classify_image(qpout, image)

        print(f"Done. Please look at QuPath project '{qpout.name}' in QuPath.")
        # project_cleanup()

In [None]:
# Run the pipeline only when this file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()

In [100]:
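# Selection of class names and matching "colorRGB" values that can be used for CLASSNAMES / CLASSCOLORS:
#         "name": "Positive",  # add colors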
#         "colorRGB": -377282
#         "name": "Other",
#         "colorRGB": -14336
#         "name": "Stroma",
#         "colorRGB": -6895466
#         "name": "Necrosis",
#         "colorRGB": -13487566
#         "name": "Tumor",
#         "colorRGB": -3670016
#         "name": "Immune cells",
#         "colorRGB": -6268256
#         "name": "Negative",
#         "colorRGB": -9408287

# Use this variant to mask by the entire polygon when hits are complex (multi-part) objects:
# exteriors = [int_coords(geo.coords) for hit in hits for geo in hit.boundary.geoms]  # Need this modification for complex structures
# Use this variant to mask by center (lookup via hit.id) when hits are complex (multi-part) objects:
# exteriors = [int_coords(geo.coords) for hit in hits for geo in allshapes[hit.id].boundary.geoms]