In [1]:
# PYTHON IMPORTS
import os
import copy
from tqdm.notebook import trange, tqdm

# IMAGE IMPORTS 
from PIL import Image
import cv2

# DATA IMPORTS 
import numpy as np

# PLOTTING
import matplotlib.pyplot as plt
import matplotlib.cm as cm

# SAM
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

# MY OWN CLASSES
from TileLocator import *

In [2]:
sam = sam_model_registry["vit_h"](checkpoint="data/SAM/sam_vit_h_4b8939.pth")
sam = sam.to("cuda")

In [3]:
mask_generator = SamAutomaticMaskGenerator(sam)

In [4]:
def resize_image(image, n_pixels):
    
    # Get the dimensions of the image
    height, width = image.shape[:2]

    # Determine the scaling factor for resizing
    if width > height:
        scaling_factor = n_pixels / width
    else:
        scaling_factor = n_pixels / height

    # Resize the image
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    resized_image = cv2.resize(image, (new_width, new_height))

    return resized_image

def show_anns(anns, ax):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    # ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
        
    return img

In [5]:
image = 'data/TileIndices/48201CIND0_0992.tif'

In [None]:
def split_and_SAM(image_path,mask_generator, tilesize=1024):
    # Load the image
    image = Image.open(image_path)
    
    # Calculate the number of tiles needed
    width, height = image.size
    num_tiles_x = (width + tilesize-1) // tilesize
    num_tiles_y = (height + tilesize-1) // tilesize
    
    # Create an empty list to store the output tiles
    output_masks = []
    
    # Iterate over each tile
    for tile_x in tqdm(range(num_tiles_x)):
        for tile_y in range(num_tiles_y):
                        
            # Calculate the coordinates for the current tile
            x0 = tile_x * tilesize
            y0 = tile_y * tilesize
            x1 = min(x0 + tilesize, width)
            y1 = min(y0 + tilesize, height)
            
            # Crop the image to the current tile
            tile = image.crop((x0, y0, x1, y1))
            
            # Pad the tile if needed
            pad_width = tilesize - tile.width
            pad_height = tilesize - tile.height
            if pad_width > 0 or pad_height > 0:
                padding = ((0, pad_height), (0, pad_width))
                tile = np.pad(tile, padding, mode='constant')
            
            tile = np.asarray(tile)
            tile = np.dstack((tile, tile, tile))
            tile = np.where(tile, 255, 0).astype(np.uint8)
            
            print(tile.shape)
            
            
            # Run the CNN on the tile
            masks = mask_generator.generate(tile)
            
            output_masks.append(masks)
            
    return output_masks

masks = split_and_SAM(image, mask_generator)



  0%|          | 0/15 [00:00<?, ?it/s]

(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1024, 3)
(1024, 1

In [None]:
print(test.shape)

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(20, 20))
axs[0].imshow(image_res)
#axs[1].imshow(image_res)
masks_img = show_anns(masks, axs[1])
axs[1].imshow(masks_img)

In [None]:
def revert_resizing(resized_image, original_image):
    # Get the dimensions of the original image
    original_height, original_width = original_image.shape[:2]

    # Get the dimensions of the resized image
    resized_height, resized_width = resized_image.shape[:2]

    # Determine the scaling factor for reverting the resizing
    if resized_width > resized_height:
        scaling_factor = original_width / resized_width
    else:
        scaling_factor = original_height / resized_height

    # Resize the image back to the original dimensions
    new_width = int(resized_width * scaling_factor)
    new_height = int(resized_height * scaling_factor)
    reverted_image = cv2.resize(resized_image, (new_width, new_height))

    return reverted_image * 255

or_res_masks = revert_resizing(masks_img, image).astype(np.uint8)

plt.imshow(or_res_masks)

In [None]:
plt.hist(or_res_masks.flatten())

In [None]:
cv2.imwrite("masks_test.png", or_res_masks)

In [None]:
or_res_masks_grey  = cv2.cvtColor(or_res_masks, cv2.COLOR_BGR2GRAY)
plt.imshow(or_res_masks_grey)

In [None]:
or_res_masks_black = np.where(or_res_masks_grey < 250, 0, 255)

plt.imshow(or_res_masks_black)

In [None]:
cv2.imwrite("masks_black.png", or_res_masks_black)