In [None]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import time
import os

import tifffile as tif
import pandas as pd
from PIL import ImageEnhance
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import cv2


import patch_gen_test

In [None]:
# Select the GPU when present and fall back to the CPU otherwise, so that
# `device` is always defined for the model-loading cell below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the fully pickled classifier and prepare it for inference.
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True, and loading a whole
# pickled nn.Module then needs weights_only=False explicitly.
model = torch.load('../model_name.pth', map_location=device)
model = model.to(device)
# Switch layers such as dropout / batch-norm (if present) to inference behaviour.
model.eval()
next(model.parameters()).is_cuda

In [None]:
class CancerTestDataset(Dataset):
    """Map-style dataset over image patches stored as individual files.

    Each item is an ``(image, filename)`` pair: ``image`` is the patch after
    ``transforms`` (or the PIL image itself when no transform is given) and
    ``filename`` is the ``img_list`` entry it was read from.

    Args:
        img_path: Directory containing the patch files.
        img_list: Filenames (relative to ``img_path``) to expose, in order.
        transforms: Optional callable applied to each PIL image.
    """

    def __init__(self, img_path, img_list, transforms=None):
        super().__init__()
        self.img_path = img_path
        self.img_list = img_list
        self.transforms = transforms

    def __getitem__(self, index):
        filename = self.img_list[index]
        # Image.open is lazy and keeps the file handle open until the pixels are
        # read. copy() forces the read inside the context manager, so the handle
        # is always released (open handles can block deleting the patches later,
        # e.g. on Windows).
        with Image.open(os.path.join(self.img_path, filename)) as handle:
            image = handle.copy()

        if self.transforms:
            image = self.transforms(image)

        return image, filename

    def __len__(self):
        return len(self.img_list)

In [None]:
def get_default_device():
    """Return ``torch.device('cuda')`` when CUDA is usable, else ``torch.device('cpu')``."""
    name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(name)
    
def to_device(data, device):
    """Move tensor(s) to the chosen device.

    Lists and tuples are handled recursively and always come back as a list.
    Objects without a ``.to`` method (e.g. the filename strings that
    ``CancerTestDataset`` yields next to each image) are returned unchanged, so
    whole ``(images, filenames)`` batches can be moved without crashing.

    Args:
        data: A tensor (or any object with ``.to``), or a nested list/tuple of them.
        device: Target ``torch.device`` or device string.

    Returns:
        ``data`` moved to ``device``; containers are returned as lists.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if not hasattr(data, 'to'):
        return data
    return data.to(device, non_blocking=True)

class DeviceDataLoader():
    """Iterable wrapper that moves every batch of a data loader to ``device``.

    Each iteration walks the wrapped loader again and passes every batch
    through ``to_device``; ``len()`` is forwarded to the wrapped loader.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        """Yield each batch of the wrapped loader after moving it to ``self.device``."""
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        """Number of batches in the wrapped loader."""
        return len(self.dl)

In [None]:
def get_transform():
    """Build the preprocessing pipeline applied to every test patch.

    Resizes the patch so its shorter side is 256 px, then converts it to a
    tensor with ``ToTensor``.
    """
    # NOTE(review): ImageNet normalisation is intentionally left out; confirm this
    # matches the preprocessing the model was trained with.
    steps = [
        transforms.Resize(256),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

In [None]:
# Load one strip of the whole-slide image and rescale it to 8 bits.
strip_num = 11
slide_path = 'H:/WSI from POSTECH/OPM data/colon/cancer/21S-122361A10'
slide_file = os.path.join(slide_path, f'{strip_num}.tiff')
slide_test = tif.imread(slide_file)

# NOTE(review): dividing by 65535 assumes 16-bit pixel data — confirm for other inputs.
img = slide_test / 65535 * 255
slide = img.astype(np.uint8)  # float -> uint8 truncates rather than rounds

# NOTE(review): numpy image arrays are (rows, cols, channels), so `W` holds the
# row count and `H` the column count; later cells use them in that same order
# (e.g. np.zeros((W, H)) has the shape of slide.shape[:2]).
W, H, C = slide.shape
W, H

In [None]:
# Output directory for the patches cut from this strip.
saving_path = f'../test/{strip_num}/'

# Start from an empty directory: create it if needed and delete patch files left
# over from a previous run. Only top-level files are removed; walking into
# sub-directories while joining names onto saving_path would target wrong paths.
os.makedirs(saving_path, exist_ok=True)
for name in os.listdir(saving_path):
    stale_path = os.path.join(saving_path, name)
    if os.path.isfile(stale_path):
        os.remove(stale_path)

# Tile the slide with patch_size=512 and a 384 px step (so neighbouring patches
# overlap). Presumably the generator writes the patches into saving_path (the
# next cell lists them from there). The meaning of the returned value is defined
# by patch_gen_test; it is displayed for inspection.
prob_arr_dim = patch_gen_test.patch_generator(saving_path, xo=0, yo=0,
                                              xstep=384, ystep=384, H=H, W=W, patch_size=512, img=slide)
prob_arr_dim

In [None]:
# Patch filenames to classify: regular files directly inside saving_path, sorted
# so the processing order is reproducible. os.walk would keep only the file list
# of the last directory it visited, and would leave img_list undefined if
# saving_path were missing; os.listdir fails loudly instead.
img_list = sorted(
    name for name in os.listdir(saving_path)
    if os.path.isfile(os.path.join(saving_path, name))
)

In [None]:
test_dataset2 = CancerTestDataset(img_path=saving_path, img_list=img_list, transforms=get_transform())

In [None]:
# Sanity check: filename of the third patch (this loads and transforms that patch).
sample_filename = test_dataset2[2][1]
sample_filename

In [None]:
# Label -> class index of the classifier output.
# NOTE(review): the indices equal the labels' sorted order (as torchvision
# ImageFolder would assign them) — confirm against the training setup.
# Alternative 5-class scheme: {'BG': 0, 'LYM': 1, 'STR': 2, 'cancer': 3, 'normal': 4}
class_to_int = {label: index for index, label in enumerate(['BG', 'STR', 'cancer', 'normal'])}

In [None]:
# Prediction for a single patch.

def predict_image2(img, model):
    """Classify one preprocessed patch.

    Args:
        img: Patch tensor without a batch axis (assumed (C, H, W) — TODO confirm);
            a batch axis of size 1 is added here.
        model: Classifier whose output is indexed as (batch, class) below.

    Returns:
        (probability, preds_indx): ``probability`` is the softmax probability of
        the ``'cancer'`` class formatted as a string with 4 decimals;
        ``preds_indx`` is a 1-element tensor with the arg-max class index.
    """
    # Convert to a batch of 1 on the default device.
    xb = to_device(img.unsqueeze(0), get_default_device())

    # A single forward pass feeds both outputs. no_grad skips building an
    # autograd graph that inference never uses.
    with torch.no_grad():
        outputs = model(xb)

    pred_sf = outputs.softmax(dim=1)
    probability = format(pred_sf[0][class_to_int['cancer']], '.4f')
    _, preds_indx = torch.max(outputs, 1)

    return probability, preds_indx

In [None]:
# Number of patches the test dataset will classify.
len(test_dataset2)

In [None]:
# Classify every patch and paint its cancer probability into a slide-sized map.
# pred2 gets one entry per patch: patch origin (x, y), predicted class index and
# the 'cancer' probability string returned by predict_image2.
pred2 = {'x': [], 'y': [], 'idx': [], 'cancer_score': []}
# Probability map in percent, same 2-D shape as the slide array.
im = np.zeros((W, H))
patch_size = 512  # must match the patch_size given to patch_generator

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for i in range(len(test_dataset2)):
        # Index the dataset once per patch so the image is read/transformed once.
        image, filename = test_dataset2[i]
        prediction, pred_indx = predict_image2(image, model)

        # The patch origin is parsed from the filename; assumes names shaped like
        # 'x<int>_y<int>...' — TODO confirm against patch_gen_test.
        # x indexes the first axis of `im`, y the second.
        x = int(filename.split('_')[0].split('x')[1])
        y = int(filename.split('_')[1].split('y')[1])

        pred2['x'].append(x)
        pred2['y'].append(y)
        pred2['idx'].append(pred_indx.item())
        pred2['cancer_score'].append(prediction)

        # Patches overlap (step < patch size), so the last patch written wins.
        im[x:x + patch_size, y:y + patch_size] = int(float(prediction) * 100)

In [None]:
# Grayscale copy of the slide, used as the background of the overlay plot.
# tifffile returns channels in the order stored in the TIFF (R, G, B), unlike
# cv2.imread which yields BGR; converting with BGR2GRAY would swap the red and
# blue luminance weights.
gr_slide = cv2.cvtColor(slide, cv2.COLOR_RGB2GRAY)

In [None]:
# Overlay the per-patch cancer-probability map on the grayscale slide and save it.
import matplotlib

# Bold text at size 18. A font family named 'normal' does not exist (Matplotlib
# warns and falls back to the default), and a size of 10 would be overridden by
# 18 anyway, so only the effective settings are applied.
matplotlib.rc('font', weight='bold', size=18)

fig = plt.figure(figsize=(18, 18))
ax2 = fig.add_subplot(1, 4, 3, aspect='equal')
ax2.imshow(np.squeeze(gr_slide), alpha=1, cmap='gray')

# matplotlib.cm.get_cmap was deprecated in 3.7 and removed in 3.9; pyplot.get_cmap
# works across versions. Values below the colour limit (40 %) are drawn fully
# transparent, so low-probability and unpredicted pixels show the slide.
my_cmap = plt.get_cmap('rainbow').copy()
my_cmap.set_under('k', alpha=0)
hm = ax2.imshow(im, alpha=0.6, clim=[40, 100], cmap=my_cmap, interpolation='gaussian')

fig.savefig(f'{slide_path}/{strip_num}_94_75.jpg')

In [None]:
# Per-patch predictions as a table, one row per patch.
pd.DataFrame.from_dict(pred2)