In [None]:
from __future__ import print_function, division
import shutil, sys  
import torch
import torch.nn as nn
import Reconstruction_dataset as dt
import model_4_o as mdo
import model_ResUGAN as md
import numpy as np
import os, glob
from torch.utils.data import Dataset, DataLoader
from collections import OrderedDict
import matplotlib.pyplot as plt
from torchvision import transforms, datasets, utils, models
import pandas as pd
import math
import torchvision
from skimage import io, transform, img_as_float, exposure
import cv2
from tkinter.filedialog import askopenfilename, askopenfilenames
from skimage.transform import rescale, resize, downscale_local_mean
import imagej
#ij=imagej.init(r'C:\fiji-win64\Fiji.app')

In [None]:
def get_image_names():
    """Show a Tk multi-file picker and hand back whatever it returns.

    The result of ``askopenfilenames`` (the selected file paths) is passed
    through untouched.
    """
    return askopenfilenames()

In [None]:
def image_padding(image, stride, patch_size):
    """Zero-pad the bottom/right of ``image`` so a sliding window tiles it exactly.

    After padding, ``size - patch_size`` is a multiple of ``stride`` along the
    first two axes. Windows of ``patch_size`` taken every ``stride`` pixels
    therefore end flush with the padded edge.

    Parameters
    ----------
    image : numpy.ndarray
        Array with at least two axes; axes 0 and 1 are padded, any further
        (e.g. channel) axes are left untouched.
    stride : int
        Step between consecutive windows.
    patch_size : int
        Side length of the square window.

    Returns
    -------
    (padded_image, pad_y, pad_x)
        The padded array (same dtype as ``image``) and the number of zero rows
        added at the bottom / zero columns added on the right.
    """
    # Complement of the remainder. It is 0 when the last window already ends
    # flush, otherwise the shortfall to the next stride step. This is the same
    # value the previous divmod / if-else logic produced.
    pad_y = (patch_size - image.shape[0]) % stride
    pad_x = (patch_size - image.shape[1]) % stride
    # The old code also computed an unused skimage ``resize`` of the whole image
    # here. That was pure wasted work and has been removed.
    pad_width = [(0, pad_y), (0, pad_x)] + [(0, 0)] * (image.ndim - 2)
    padded_image = np.pad(image, pad_width, mode="constant", constant_values=0)
    return (padded_image, pad_y, pad_x)

In [None]:
def patch_num_calculator(image,patch_size, stride):
    """Count how many ``patch_size`` windows at step ``stride`` fit per axis.

    Returns
    -------
    (rows, cols) : tuple of int
        ``floor((extent - patch_size) / stride) + 1`` evaluated for image
        axes 0 and 1 respectively.
    """
    def _windows_along(extent):
        # Number of window start positions that keep the window inside `extent`.
        return int(math.floor((extent - patch_size) / stride) + 1)

    height, width = image.shape[0], image.shape[1]
    return (_windows_along(height), _windows_along(width))

In [None]:
def image_to_minibatch(image, stride,patch_size, device,rows_patch_num,culs_patch_num, unpadd_patch, HE_name):
    """Tile ``image`` into overlapping patches on disk and write a tile layout file.

    Each patch is saved under ``<cwd>/Patches/HEpatches/HEpatches``. A
    ``TileConfiguration.txt`` listing every patch and its pixel offset is written
    to ``<cwd>/Patches/HPan-Ade170Sur-01_101/<image name>``.

    Parameters
    ----------
    image : numpy.ndarray
        Image to tile, indexed as ``image[rows, cols, channels]`` (the slicing
        below requires a third axis).
    stride, patch_size : int
        Window step and window side length in pixels.
    device :
        Unused; kept for interface compatibility.
    rows_patch_num, culs_patch_num : int
        Number of windows along axis 0 / axis 1.
    unpadd_patch : int
        Offset subtracted from every tile coordinate written to the layout file.
    HE_name : str
        Path of the source image; its base name (up to ``.tif``) names the
        output folder and prefixes every patch file.

    Returns
    -------
    (HEpatch_path, output_path, to_be_cropped_row, to_be_cropped_col)
        Patch folder, output folder, and the file names of the patches in the
        last row / last column (these later get their padding cropped off).
    """
    path = os.getcwd()
    HEpatch_path1 = os.path.join(path, "Patches", "HEpatches")
    # Patches go one level below HEpatch_path1 because torchvision's
    # ImageFolder treats sub-folders of its root as classes.
    HEpatch_path = os.path.join(HEpatch_path1, "HEpatches")
    # makedirs also creates missing parents; os.mkdir raised if "Patches" was absent.
    os.makedirs(HEpatch_path, exist_ok=True)

    output_path1 = os.path.join(path, "Patches", "HPan-Ade170Sur-01_101")
    img_name = os.path.basename(HE_name).split('.tif')[0]
    output_path = os.path.join(output_path1, img_name)
    os.makedirs(output_path, exist_ok=True)

    config_file = os.path.join(output_path, "TileConfiguration.txt")
    openning_text = "# Define the number of dimensions we are working on" + "\n" + "dim = 2" + "\n\n" + "# Define the image coordinates"+ "\n"
    to_be_cropped_row = []
    to_be_cropped_col = []
    # `with` guarantees the layout file is closed even if saving a patch fails.
    with open(config_file, "w+") as config:
        config.write(openning_text)
        for i in range(rows_patch_num):
            for j in range(culs_patch_num):
                HEpatch = image[i*stride : i*stride + patch_size, j*stride : j*stride + patch_size, :]
                # The first coordinate written is the column offset (j), the
                # second the row offset (i).
                x_offset = j*stride - unpadd_patch
                y_offset = i*stride - unpadd_patch
                image_name = img_name + "_SHGpatch_" + str(i) + "_" + str(j) + ".tif"
                image_path = os.path.join(HEpatch_path, image_name)
                coordinates = "; ;" + " (" + str(x_offset) + "," + str(y_offset) + ")" + "\n"
                config.write(image_name + coordinates)
                io.imsave(image_path, HEpatch)
                if i == rows_patch_num - 1:
                    to_be_cropped_row.append(image_name)
                if j == culs_patch_num - 1:
                    to_be_cropped_col.append(image_name)
    return (HEpatch_path, output_path, to_be_cropped_row, to_be_cropped_col)

In [None]:
# Select the compute device and print GPU info when one is available.
# The CUDA queries must sit behind the availability check: calling
# torch.cuda.set_device on a CPU-only machine raises, which made the CPU
# fallback below unreachable.
USE_GPU = 1
if USE_GPU and torch.cuda.is_available():
    torch.cuda.set_device(0)
    currentDevice = torch.cuda.current_device()
    print("Current GPU: " + str(currentDevice))
    print(str(torch.cuda.device_count()))
    print(str(torch.cuda.get_device_capability(currentDevice)))
    device = torch.device('cuda:0')
else:
    device = torch.device("cpu")
print(torch.__version__)
print(device)

In [None]:
# Load the trained generator weights and prepare the model for inference.
cwd = os.getcwd()
path = os.path.join(cwd, 'Saved model', 'GAN_output', 'trained_generator.pth')
# path = os.path.join(cwd, 'Best model', 'encoderresinfo_0414.pth')

# Map tensors onto the device selected earlier instead of hard-coding 'cuda:0',
# so the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
state_dict = torch.load(path, map_location=device)

# Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
# Strip the prefix only where it is present: the old unconditional k[7:]
# corrupted every key of a checkpoint saved without DataParallel.
module_prefix = 'module.'
new_state_dict = OrderedDict(
    (k[len(module_prefix):] if k.startswith(module_prefix) else k, v)
    for k, v in state_dict.items()
)

model = md.GeneratorUNet()
model.load_state_dict(new_state_dict)
model.eval()
model.to(device)

In [None]:
def minibatch_to_network(HEpatch_path, model, minibatch_size, device, output_path, patch_size,to_be_cropped_row, to_be_cropped_col,pad_y,pad_x):
    """Run every saved patch through ``model`` and write the outputs as images.

    Parameters
    ----------
    HEpatch_path : str
        Root folder for ``torchvision.datasets.ImageFolder``. The patches must
        live in a sub-folder of it, because ImageFolder treats sub-folders as
        classes.
    model : torch.nn.Module
        Generator; expected to already be on ``device`` and in eval mode.
    minibatch_size : int
        DataLoader batch size.
    device : torch.device or str
        Device the input batches are moved to.
    output_path : str
        Folder the synthesized patches are written to, using the input file names.
    patch_size : int
        Patch side length. Each output is reshaped to ``(patch_size, patch_size)``,
        so the model is assumed to emit a single channel.
    to_be_cropped_row, to_be_cropped_col : list of str
        File names of last-row / last-column patches whose padding is cropped.
    pad_y, pad_x : int
        Padding rows / columns removed from those edge patches.
    """
    image_datasets = datasets.ImageFolder(HEpatch_path, transform=transforms.ToTensor())
    dataloader = DataLoader(image_datasets, batch_size=minibatch_size,
                            shuffle=False, num_workers=0)
    # Sets give O(1) membership tests per patch.
    crop_rows = set(to_be_cropped_row)
    crop_cols = set(to_be_cropped_col)
    # Running index into image_datasets.imgs; valid because shuffle=False.
    sample_idx = 0
    # Inference only: no_grad stops autograd from building a graph per batch.
    with torch.no_grad():
        for inputs, _labels in dataloader:
            network_output = model(inputs.to(device)).cpu().numpy()
            for out in network_output:
                patch1 = out.transpose(1, 2, 0).reshape(patch_size, patch_size)
                patch = exposure.adjust_gamma(patch1, 0.9)
                img_file_name = image_datasets.imgs[sample_idx][0]
                sample_idx += 1
                # basename follows the OS separator; split('\\') only worked on Windows.
                img_name = os.path.basename(img_file_name)
                image_path = os.path.join(output_path, img_name)
                # Scale to 8-bit, assuming outputs lie in [0, 1]. Clip first so
                # out-of-range values saturate instead of wrapping in the uint8 cast.
                img = np.clip(255 * patch, 0, 255).astype(np.uint8)
                _, img = cv2.threshold(img, 40, 255, cv2.THRESH_TOZERO)

                # Edge patches extend into the zero padding; trim it off.
                if img_name in crop_cols:
                    img = img[:, 0:patch_size - pad_x]
                if img_name in crop_rows:
                    img = img[0:patch_size - pad_y, :]

                cv2.imwrite(image_path, img)


In [None]:
# Tile each selected H&E image, run the patches through the generator and keep
# the synthesized patches (plus TileConfiguration.txt) in each image's output folder.
unpadd_patch=0
stride=96
patch_size= 128
minibatch_size = 128
HE_names = get_image_names()
for i, HE_name in enumerate(HE_names):
    HE_image = io.imread(HE_name)
    (padded_image, pad_y, pad_x) = image_padding(HE_image, stride, patch_size)

    (rows_patch_num, culs_patch_num) = patch_num_calculator(padded_image, patch_size, stride)
    (HEpatch_path, output_path, to_be_cropped_row, to_be_cropped_col) = image_to_minibatch(padded_image,
                                                                                          stride, patch_size,
                                                                                          device, rows_patch_num,
                                                                                          culs_patch_num, unpadd_patch,
                                                                                          HE_name)
    try:
        # ImageFolder needs the parent of the patch folder. Derive it from the
        # path actually used instead of a hard-coded "F:\..." string, which only
        # matched when cwd was F:\HE_SHG_project (and had invalid escapes).
        minibatch_to_network(os.path.dirname(HEpatch_path), model, minibatch_size, device, output_path,
                             patch_size, to_be_cropped_row, to_be_cropped_col, pad_y, pad_x)
    finally:
        # Always remove the temporary input patches so the next image starts clean.
        shutil.rmtree(HEpatch_path)


In [None]:
# Synthesize each selected image patch-wise, then stitch the patches with Fiji's
# Grid/Collection stitching plugin into `save_path`.
unpadd_patch=0
stride=96
patch_size= 128
minibatch_size = 128
# Fiji install used for stitching (same location as the commented-out init in the imports cell).
FIJI_DIR = r'C:\fiji-win64\Fiji.app'
# `ij` was never created (its init is commented out), so run_plugin raised
# NameError. Initialise it here, once per kernel session.
if 'ij' not in globals():
    ij = imagej.init(FIJI_DIR)
HE_names = get_image_names()
save_path= r'F:\HE_SHG_project\stictched synthesized\New_breast'
for i, HE_name in enumerate(HE_names):
    HE_image = io.imread(HE_name)
    (padded_image, pad_y, pad_x) = image_padding(HE_image, stride, patch_size)

    (rows_patch_num, culs_patch_num) = patch_num_calculator(padded_image, patch_size, stride)
    (HEpatch_path, output_path, to_be_cropped_row, to_be_cropped_col) = image_to_minibatch(padded_image,
                                                                                          stride, patch_size,
                                                                                          device, rows_patch_num,
                                                                                          culs_patch_num, unpadd_patch,
                                                                                          HE_name)
    try:
        # ImageFolder needs the parent of the patch folder; derive it instead of
        # hard-coding "F:\...", which only worked with cwd = F:\HE_SHG_project.
        minibatch_to_network(os.path.dirname(HEpatch_path), model, minibatch_size, device, output_path,
                             patch_size, to_be_cropped_row, to_be_cropped_col, pad_y, pad_x)
        args = {'type': 'Positions from file', 'order': 'Defined by TileConfiguration', 'directory': output_path,
                # Key was misspelled 'ayout_file', so the layout file was never passed.
                'layout_file': 'TileConfiguration.txt', 'fusion_method': 'Linear Blending', 'regression_threshold': '0.30',
                'max/avg_displacement_threshold': '2.50', 'absolute_displacement_threshold': '3.50',
                'computation_parameters': 'Save memory (but be slower)', 'image_output': 'Write to disk',
                'output_directory': save_path}
        plugin = "Grid/Collection stitching"
        ij.py.run_plugin(plugin, args)
    finally:
        # Temporary input patches are always removed.
        shutil.rmtree(HEpatch_path)
    # Synthesized patches are removed only after a successful stitch, so they
    # remain available for inspection if stitching fails.
    shutil.rmtree(output_path)

In [None]:
# Zero out dim pixels (<= 40) in each selected image and save a copy to the
# "Thresholded" folder under the same file name.
HE_names = get_image_names()
output_path = "F:\\HE_SHG_project\\stictched synthesized\\SA_stroma\\Thresholded\\"
os.makedirs(output_path, exist_ok=True)
for i, HE_name in enumerate(HE_names):
    HE_image = io.imread(HE_name)
    # THRESH_TOZERO keeps pixels above 40 unchanged and zeroes the rest
    # (the maxval argument 200 is not used by this mode).
    ret, img = cv2.threshold(HE_image, 40, 200, cv2.THRESH_TOZERO)
    # The dialog paths are split on '/' elsewhere in this notebook, so
    # split('\\') could return the whole absolute path. os.path.join then
    # discarded output_path and the input file was overwritten.
    # basename takes only the file name.
    img_name = os.path.basename(HE_name)
    image_path = os.path.join(output_path, img_name)
    # NOTE(review): io.imread returns RGB while cv2.imwrite expects BGR;
    # colour inputs would be saved channel-swapped. Confirm inputs are grayscale.
    if not cv2.imwrite(image_path, img):
        raise OSError("cv2.imwrite failed for " + image_path)

In [None]:
    # Debug: rebuild the output path for the last image handled by the thresholding loop.
# The statement was indented at top level (IndentationError); dedented so the cell runs.
# Relies on `output_path` and `img_name` left over from that loop.
image_path = os.path.join(output_path, img_name)

image_path

In [None]:
# Debug check: recompute where the thresholding cell writes the last image.
# Depends on `img_name` left over from that loop.
output_path = ("F:\\HE_SHG_project\\stictched synthesized"
               "\\SA_stroma\\Thresholded\\")
image_path = os.path.join(output_path, img_name)
image_path

In [None]:
# Debug leftover: echo the loop index `i` from the most recently run loop cell.
print(str(i), end="\n")