In [1]:
from skimage import io, filters, transform 
import tifffile as tiff
import albumentations as A
 
import os
import cv2
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class config:
    DEVICE = "cuda"
    FOLDS = 5
    LR = 1e-3

In [3]:
class HubDataset(torch.utils.data.Dataset):
    def __init__(self,reader, coords_list,):
        self.reader = reader
        self.coords_list = coords_list

        
    def __len__(self):
        return len(self.coords_list)
    
    def __getitem__(self,item):
        
        coords = self.coords_list[item]
        image = self.reader.get_tiles(coords[0], coords[1])

        image = image.astype(np.uint8)/255    
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)

        return {
            "image": torch.tensor(image, dtype=torch.float),
            "coords": coords,
        }

In [4]:
import torch
import torch.nn as nn

@torch.no_grad()    
def infer(model,valid_loader,device):
    model.eval()
    final_coords = []
    masks = []
    for data in valid_loader:
        inputs = data['image']
        
        inputs = inputs.to(device, dtype=torch.float)

        output = model(inputs,)  
        output = torch.sigmoid(output)
        
        output = output[:,0,:,:].detach().cpu().numpy()  ## B, H, W

        ## postprocess
        for idx in range(output.shape[0]):
            threshold = filters.threshold_mean(output[idx]) ##  isodata, otsu, li, mean, yen, minimum
            mask = output[idx] > threshold
            mask = mask.astype(np.int8)*255
            masks.append(mask)
            
    return masks  #, final_coords

In [5]:
import sys
sys.path.append('CoAT/')

from coat import *
from daformer import *
from helper import *

In [6]:
import torch
import torch.nn as nn
import timm

class MixUpSample(nn.Module):
    def __init__(self, scale_factor=4):
        super().__init__()
        self.mixing = nn.Parameter(torch.tensor(0.5))
        self.scale_factor = scale_factor

    def forward(self, x):
        x = self.mixing * F.interpolate(
            x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
        ) + (1 - self.mixing) * F.interpolate(
            x, scale_factor=self.scale_factor, mode="nearest"
        )
        return x
    
class Net(nn.Module):
    
    def __init__(self,
                 encoder=coat_lite_medium,
                 decoder=daformer_conv3x3,
                 encoder_cfg={},
                 decoder_cfg={},
                 ):
        
        super(Net, self).__init__()
        decoder_dim = decoder_cfg.get('decoder_dim', 320)

        self.encoder = encoder

        self.rgb = RGB()

        encoder_dim = self.encoder.embed_dims
        # [64, 128, 320, 512]

        self.decoder = decoder(
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
        )
#         self.logit = nn.Sequential(
#             nn.Conv2d(decoder_dim, 1, kernel_size=1),
#             nn.Upsample(scale_factor = 4, mode='bilinear', align_corners=False),
#         )
        self.logit = nn.Conv2d(decoder_dim, 1, kernel_size=1)
        self.mixup = MixUpSample()
    def forward(self, x):

        x = self.rgb(x)

        B, C, H, W = x.shape
        encoder = self.encoder(x)

        last, decoder = self.decoder(encoder)
        logits = self.logit(last)
        
        upsampled_logits = self.mixup(logits)
        
        return upsampled_logits
    

### encoder
class coat_parallel_small_plus1 (CoaT):
    def __init__(self, **kwargs):
        super(coat_parallel_small_plus1, self).__init__(
            patch_size=4,
            embed_dims=[152, 320, 320, 320, 320],
            serial_depths=[2, 2, 2, 2, 2],
            parallel_depth=6,
            num_heads=8,
            mlp_ratios=[4, 4, 4, 4, 4],
            pretrain ='coat_small_7479cf9b.pth',
            **kwargs)


def HubmapModel():
    encoder = coat_lite_medium()
    checkpoint = 'coat_lite_medium_384x384_f9129688.pth'
    checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    state_dict = checkpoint['model']
    encoder.load_state_dict(state_dict,strict=False)
    
    net = Net(encoder=encoder).cuda()
    
    return net

model_paths = [
              "weights/model-0.pth",
              "weights/model-1.pth",
              "weights/model-2.pth",
              "weights/model-3.pth",
              "weights/model-4.pth", 
              ]

## WSI inference

In [7]:
WSI = "/mnt/prj001/Rama_Downloaded/hamarepository_data/69026_H&E.ndpi"
batch_size = 4
import matplotlib.pyplot as plt
from wsi_inference.patch_generation import ImageReader

# def one_whole_slide(wsi_name):
image_reader = ImageReader(WSI, tile_size=1024, scale_factor=2)
mask_details = image_reader.get_mask(magnification=10)

coords_list = mask_details["list_indices"]
stitch_shape = mask_details["shape"]
steps = mask_details["step_size"]
scale = mask_details["scaling"]

# infer
valid_dataset = HubDataset(reader=image_reader,coords_list=coords_list)
valid_loader = torch.utils.data.DataLoader(valid_dataset,batch_size=batch_size,shuffle=False,pin_memory=True)

model = HubmapModel().cuda()
model.load_state_dict(torch.load(model_paths[0]))
masks_output = infer(model=model,valid_loader=valid_loader,device=config.DEVICE)

# batched_coords = [coords_list[i:i + batch_size] for i in range(0, len(coords_list), batch_size)]  


In [8]:
# def get_stitiched(image_list, coords_list, mask_shape, step_size, scaling):

#         empty_mask = np.zeros(mask_shape)

#         for batch in range(len(image_list)):
#             for indx,coords in enumerate(coords_list[batch]):
#                 patch = transform.resize(image=image_list[batch][indx],output_shape=(step_size,step_size),mode="constant")
#                 empty_mask[int(coords[0]/scaling):int(coords[0]/scaling) + step_size,int(coords[1]/scaling):int(coords[1]/scaling) + step_size] =  patch #np.ones((step_size,step_size))


#         return empty_mask

def get_stitiched(image_list, coords_list, mask_shape, step_size, scaling):

        empty_mask = np.zeros(mask_shape)
        for indx,coords in enumerate(coords_list):
            patch = transform.resize(image=image_list[indx],output_shape=(step_size,step_size),mode="constant")
            empty_mask[int(coords[0]/scaling):int(coords[0]/scaling) + step_size,int(coords[1]/scaling):int(coords[1]/scaling) + step_size] =  patch #np.ones((step_size,step_size))

        return empty_mask

In [9]:
masks_output[1]

array([[  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   0,   0, ...,   0,   0,   0],
       ...,
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0],
       [255, 255, 255, ...,   0,   0,   0]], dtype=int16)

In [10]:
out_wsi = get_stitiched(image_list=masks_output,
                                     coords_list=coords_list,
                                     mask_shape=stitch_shape,
                                     step_size=steps,
                                     scaling=scale)


In [None]:
io.imshow(out_wsi)

In [1]:
io.imshow(out_wsi)

NameError: name 'io' is not defined