In [7]:
# https://www.kaggle.com/code/florisvanwettum/panda-submission
# TODO:
# 1. Tile images (256x256)
# 2. Filter tiles (i.e. discard ones that are mostly white background)
# 3. Mask - red/green for cancer/non-cancer

In [1]:
import os
import sys
# import glob
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import pytorch_lightning as pl
# import torch.nn.functional as F

# Directory holding the OpenSlide DLLs; it is only consulted when os.add_dll_directory
# exists (Windows). Set the OPENSLIDE_PATH environment variable to override this
# machine-specific default instead of editing the notebook.
OPENSLIDE_PATH = os.environ.get('OPENSLIDE_PATH', r'C:\Users\johng\Downloads\openslide-win64-20231011\bin')

import os
if hasattr(os, 'add_dll_directory'):
    # Windows
    with os.add_dll_directory(OPENSLIDE_PATH):
        import openslide
else:
    import openslide
    
from tqdm.notebook import tqdm
import zipfile
import timm

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cpu')

In [17]:
# Create an empty submission.csv for Kaggle to recognise
# Create (or truncate) an empty submission.csv up front so Kaggle recognises the
# notebook's output file.
submis = open('submission.csv', 'w')
submis.close()

In [18]:
# Competition metadata tables, read relative to the current working directory.
train, test, sample = (
    pd.read_csv(csv_name)
    for csv_name in ('train.csv', 'test.csv', 'sample_submission.csv')
)

In [19]:
class Tiler():
    """Select and extract the most tissue-rich square tiles from a whole-slide image.

    The slide is scanned on a regular grid at pyramid level ``level``. Each grid
    cell is scored by the sum of its RGB values, where white background scores
    high and tissue scores low. The ``N_tiles`` lowest-scoring cells are kept.

    Args:
        N_tiles: Number of tiles returned per slide.
        tile_size: Edge length of each square tile, in pixels at ``level``.
        level: Pyramid level of the slide to read from.
    """

    def __init__(self, N_tiles= 36, tile_size = 2**8, level=1):
        self.N_tiles = N_tiles
        self.tile_size = tile_size
        self.level = level

    def _level0_location(self, slide, xloc, yloc):
        """Convert top-left coordinates at ``self.level`` to the level-0 coordinates ``read_region`` expects."""
        # Use the slide's own downsample factor rather than assuming 4x per level.
        downsample = int(round(slide.level_downsamples[self.level]))
        return int(xloc) * downsample, int(yloc) * downsample

    def _read_rgb_tile(self, slide, xloc, yloc):
        """Read one tile as an (H, W, 3) uint8 array with out-of-slide pixels set to white.

        ``read_region`` returns RGBA, and pixels beyond the slide border come back
        transparent (alpha == 0). Only those pixels are whitened. Genuinely black
        pixels inside the slide (e.g. ink) are left untouched.
        """
        region = slide.read_region(self._level0_location(slide, xloc, yloc), self.level,
                                   (self.tile_size, self.tile_size))
        rgba = np.asarray(region)
        rgb = rgba[..., :3].copy()  # copy: the array view of a PIL image may be read-only
        rgb[rgba[..., 3] == 0] = 255
        return rgb

    def _to_tensor(self, rgb, transform):
        """Convert an (H, W, 3) uint8 array to a (3, H, W) tensor, applying ``transform`` if given."""
        img = torch.from_numpy(np.ascontiguousarray(rgb.transpose(2, 0, 1)))
        if transform is not None:
            # Transforms such as Normalize expect float tensors scaled to [0, 1].
            img = transform(img.float() / 255.0)
        return img

    def _get_tile_locations_from_slide(self, slide):
        """Return up to ``N_tiles`` tile records, ordered from most to least tissue.

        Each record is a dict with the following keys:
        - ``xloc`` / ``yloc``: the tile's top-left corner at ``self.level``.
        - ``pixel_sum``: the tile's score.
        - ``required_padding``: whether the tile extends past the slide border.

        Tiles are read one at a time, so at most one tile is held in memory.
        """
        width, height = slide.level_dimensions[self.level]  # level_dimensions are (width, height)
        tiles = []
        for xloc in range(0, width, self.tile_size):
            for yloc in range(0, height, self.tile_size):
                rgb = self._read_rgb_tile(slide, xloc, yloc)
                tiles.append({
                    'xloc': xloc,
                    'yloc': yloc,
                    'pixel_sum': int(rgb.sum(dtype=np.int64)),
                    'required_padding': xloc + self.tile_size > width or yloc + self.tile_size > height,
                })
        # Lower pixel sum == less white background == more tissue; sort() is stable for ties.
        tiles.sort(key=lambda tile: tile['pixel_sum'])
        return tiles[:self.N_tiles]

    def get_individual_tiles(self, slide, transform=None):
        """Return a float tensor of shape (N_tiles, 3, tile_size, tile_size).

        Without ``transform`` the values are the raw 0-255 pixel intensities. With
        ``transform``, each tile is scaled to [0, 1] and then transformed. If the
        slide yields fewer than ``N_tiles`` grid cells, the remaining slots are
        filled with white tiles (processed the same way) instead of uninitialized
        memory.
        """
        tiles = torch.empty((self.N_tiles, 3, self.tile_size, self.tile_size))
        tiles_info = self._get_tile_locations_from_slide(slide)
        for i, info in enumerate(tiles_info):
            rgb = self._read_rgb_tile(slide, info['xloc'], info['yloc'])
            tiles[i] = self._to_tensor(rgb, transform)
        if len(tiles_info) < self.N_tiles:
            blank = np.full((self.tile_size, self.tile_size, 3), 255, dtype=np.uint8)
            tiles[len(tiles_info):] = self._to_tensor(blank, transform)
        return tiles

In [20]:
class TestDataset(Dataset):
    """Map-style dataset that yields the normalized tile stack for each slide in ``df``.

    Args:
        df: DataFrame with an ``image_id`` column; one item per row.
        dir_name: Sub-directory of ``data_dir`` that holds the ``<image_id>.tiff`` slides.
        data_dir: Root directory of the competition data (defaults to the Kaggle mount).
    """

    def __init__(self, df, dir_name, data_dir='/kaggle/input/prostate-cancer-grade-assessment'):
        super().__init__()
        self.df = df
        self.img_dir = f'{data_dir}/{dir_name}/'
        self.tiler = Tiler(N_tiles = 36, tile_size=2**8, level=1)

        # ImageNet channel statistics, applied to tiles scaled to [0, 1].
        self.normalize = transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])

    def __getitem__(self, idx):
        """Return the tile tensor produced by ``Tiler.get_individual_tiles`` for row ``idx``."""
        image_id = self.df['image_id'].iloc[idx]
        file_path = os.path.join(self.img_dir, f'{image_id}.tiff')
        slide = openslide.OpenSlide(file_path)
        try:
            return self.tiler.get_individual_tiles(slide, transform=self.normalize)
        finally:
            # Release the slide's file handle; otherwise every item leaks one.
            slide.close()

    def __len__(self):
        return len(self.df)

In [21]:
def inference(model, dataloader, device, threshold=0.5):
    """Predict a grade for every sample yielded by ``dataloader``.

    The model outputs are passed through a sigmoid. Each sample's prediction is
    the count of outputs above ``threshold``, summed over dim 1. This presumably
    matches an ordinal/cumulative target encoding; NOTE(review): confirm it
    against the training code.

    Args:
        model: Module mapping a batch of inputs to per-sample output vectors.
        dataloader: Iterable of input batches (tensors).
        device: Device to run the model on.
        threshold: Probability above which an output counts towards the grade.

    Returns:
        numpy int64 array with one prediction per sample, in loader order. The
        array is empty if the loader yields no batches.
    """
    model.eval()
    model.to(device)
    batch_preds = []
    with torch.no_grad():
        for imgs in dataloader:
            logits = model(imgs.to(device).float())
            probs = torch.sigmoid(logits)
            batch_preds.append((probs > threshold).sum(dim=1).cpu().numpy())
    if not batch_preds:
        # np.concatenate raises on an empty list.
        return np.empty(0, dtype=np.int64)
    return np.concatenate(batch_preds)

In [22]:
def submit(model, sample, dir_name='test_images', batch_size=16):
    """Return a copy of ``sample`` with ``isup_grade`` filled from model predictions.

    Inference only runs when the image directory exists (presumably only during
    Kaggle's scoring run for ``test_images``). Otherwise the rows are returned
    unchanged. The caller's DataFrame is never modified, so re-running the cell
    is idempotent.

    Args:
        model: Model passed to ``inference``.
        sample: DataFrame with an ``image_id`` column (e.g. the sample submission).
        dir_name: Image sub-directory of the competition data to predict on.
        batch_size: Number of slides per inference batch.
    """
    submission = sample.copy()
    if os.path.exists(f'../input/prostate-cancer-grade-assessment/{dir_name}'):
        print('run inference')
        test_dataset = TestDataset(submission, dir_name)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        submission['isup_grade'] = inference(model, test_loader, device)
    return submission

In [25]:
# # check using train_images
# model = torch.jit.load('/kaggle/input/ismi-group3-panda-trained-models/convnext_pico_submission.pt')
# submission = submit(model, train.head(), dir_name='train_images')
# submission['isup_grade'] = submission['isup_grade'].astype(int)
# submission.to_csv('submission.csv', index=False)
# submission.head()