In [5]:
#split and merge image
from PIL import Image

def split_img(img_path):
    image = Image.open(img_path).convert("RGB")

    image = image.resize((1024,1024))
    sub_imgs =[]
    for row in range(3):
        for col in range(3):
            left = col *256
            upper = row *256
            right = left +512
            lower = upper+512

            sub_image =image.crop((left,upper,right,lower))
            sub_imgs.append(sub_image)
    
    return sub_imgs

def merge_image(images):
    merge_image =Image.new("L",[1024,1024])
    new_image_0_0 = images[0].crop((0,0,384,384))
    new_image_0_1 = images[1].crop((128,0,384,384))
    new_image_0_2 = images[2].crop((128,0,512,384))
    new_image_1_0 = images[3].crop((0,128,384,384))
    new_image_1_1 = images[4].crop((128,128,384,384))
    new_image_1_2 = images[5].crop((128,128,512,384))
    new_image_2_0 = images[6].crop((0,128,384,512))
    new_image_2_1 = images[7].crop((128,128,384,512))
    new_image_2_2 = images[8].crop((128,128,512,512))

    merge_image.paste(new_image_0_0,(0,0))
    merge_image.paste(new_image_0_1,(384,0))
    merge_image.paste(new_image_0_2,(640,0))
    merge_image.paste(new_image_1_0,(0,384))
    merge_image.paste(new_image_1_1,(384,384))
    merge_image.paste(new_image_1_2,(640,384))
    merge_image.paste(new_image_2_0,(0,640))
    merge_image.paste(new_image_2_1,(384,640))
    merge_image.paste(new_image_2_2,(640,640))

    return merge_image


In [3]:
# post-process operations 
import cv2 
import numpy as np
import copy
def remove_edge_cells(mask_image):
    w,h = mask_image.shape
    pruned_mask = copy.deepcopy(mask_image)
    remove_list = []
    edges = mask_image[0,:],mask_image[w-1,:],mask_image[:,0],mask_image[:,h-1]
    for edge in edges:
        edge_masks = np.unique(edge)
        for edge_mask in edge_masks:
            remove_list.append(edge_mask)
            pruned_mask[np.where(mask_image==edge_mask)] = 0

    return pruned_mask

def remove_small_cells(mask_image,area_threshold=10):
    w,h = mask_image.shape
    pruned_mask = copy.deepcopy(mask_image)
    for mask_index in np.unique(mask_image):
        # if mask_index == mask_image[330,640]:
        #     a = 1
        area = np.sum(mask_image == mask_index)
        if area < area_threshold:
            pruned_mask[np.where(mask_image == mask_index)] = 0


    return pruned_mask

def remove_concentric_masks(mask_image):
    # Convert the mask image to grayscale
    cell_values = np.unique(mask_image)
    for i in range(1, len(cell_values)):# remove background
        mask_one = np.array(mask_image == cell_values[i],dtype=np.uint8)
        # mask_one_dilated = cv2.dilate(mask_one, np.ones((5, 5), np.uint8),100)
        # xmin, xmax, ymin, ymax = np.min(np.where(mask_one == 1)[0]), np.max(np.where(mask_one == 1)[0]),\
        #     np.min(np.where(mask_one == 1)[1]), np.max(np.where(mask_one == 1)[1]),
        contour, _ = cv2.findContours(mask_one, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contour) > 0:
            largest_contour = max(contour, key=cv2.contourArea)

            mask_image = cv2.drawContours(mask_image, [largest_contour], -1, (int(cell_values[i])), thickness=cv2.FILLED)
    return mask_image


def post_process(final_mask,img_path):
    gray_image = final_mask.convert("L")
    image_array = np.array(gray_image)
    image = cv2.imread(img_path)
    resized_img = cv2.resize(image, (1024,1024))    

    _,thresholded_image =cv2.threshold(image_array,0,255,cv2.THRESH_BINARY)
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(thresholded_image, connectivity=8)
    
    output_image=np.zeros_like(gray_image)

    for label in range(1,num_labels):
        output_image[labels == label] =label
    
    pruned_mask =remove_edge_cells(output_image)
    pruned_mask_reduce = remove_small_cells(pruned_mask,area_threshold=150)
    pruned_mask_reduce = remove_concentric_masks(pruned_mask_reduce)
    
    cell_mask = np.zeros((pruned_mask.shape[0],pruned_mask.shape[1],3))
    cell_num = len(np.unique(pruned_mask_reduce)) - 1
    
    properties ={}
    start_idx = 0
    for i in range(1, cell_num+1):
        mask_one = np.array(pruned_mask_reduce == np.unique(pruned_mask_reduce)[i],dtype=np.uint8)
        try:
            #properties['cell %i'%(i+start_idx)] = analyze_cell_properties(mask_one)
            #cell_color = (np.random.randint(255),
            #                                            np.random.randint(255),
            #                                            np.random.randint(255))
            cell_color =(0,255,0)
            # cell_mask[np.where(mask_one==1)[0],np.where(mask_one==1)[1],:] = (255,0,0)
            contours, hierarchy = cv2.findContours(mask_one, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(cell_mask, contours, -1, cell_color, 3)

            text = str(i+start_idx)
            # Define the font properties
            font = cv2.FONT_HERSHEY_SIMPLEX
            font_scale = 1
            font_color = cell_color  # White color (BGR format)
            thickness = 3

            # Find the size of the text
            text_size, _ = cv2.getTextSize(text, font, font_scale, thickness)

            # Calculate the position to center the text on the image
            text_x = ( np.max(np.where(mask_one==1)[0]) - np.min(np.where(mask_one==1)[0]))//2 + np.min(np.where(mask_one==1)[0])
            text_y = ( np.max(np.where(mask_one==1)[1]) - np.min(np.where(mask_one==1)[1]))//2 + np.min(np.where(mask_one==1)[1])

            # Add the text to the image
            cell_mask = cv2.putText(cell_mask, text, (text_y, text_x), font, font_scale, font_color, thickness)
        except ZeroDivisionError:
            pass
    
    cell_mask = cv2.addWeighted(np.array(cell_mask, dtype=np.uint8), 1, resized_img, 1, 0)
    pure_mask = np.where(pruned_mask_reduce>0,255,0)
    return pure_mask,cell_mask


In [8]:
#main part
from organoiddataset import *
from torchvision import utils as vutils
import torch
from torch.utils.data import DataLoader
#import torchvision.transforms as transforms


def main(img_path):
    INPUT_SIZE=512
    input_imgs = split_img(img_path)
    
    transforms_test = transforms.Compose([
        transforms.Resize((INPUT_SIZE,INPUT_SIZE)),
        transforms.ToTensor()
    ])

    test_dataset =Organoid(
        images =input_imgs,transform =transforms_test
    )
    test_dataloader =DataLoader(test_dataset, batch_size=1,shuffle=False,num_workers=1)

    best_model = torch.load('checkpoint.pth', map_location='cpu')

    output_imgs = []
    for i, (X) in enumerate(test_dataloader):
        best_model.eval()
        prediction,attention =best_model(X)
        output =torch.argmax(prediction,dim=1)
        #X = X[0,1,:,:].reshape(1,512,512)
        print(output.shape)
        #exit()
        tensor_to_image =transforms.ToPILImage()
        output_img = tensor_to_image(output.to(torch.float32))
        output_imgs.append(output_img)
    second_mask = merge_image(output_imgs)
    final_output,final_mask = post_process(second_mask,img_path)
    
    return final_output,final_mask

if __name__ =="__main__":
    target_img_path = "B03_1.png"

    target_img = Image.open(target_img_path)
    width, height = target_img.size
    #print(width,height)
    final_ouput, final_mask = main(target_img_path)
    final_ouput = cv2.resize(final_ouput,(width,height))
    final_mask = cv2.resize(final_mask,(width,height))
    cv2.imwrite("final_output.png",final_ouput)
    cv2.imwrite("final_mask.png",final_mask)



1024 1024
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
torch.Size([1, 512, 512])
