In [1]:
import glob
import cv2
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
from torchvision.ops import masks_to_boxes
from torchvision.io import read_image
from prettytable import PrettyTable
from segment_anything import sam_model_registry
import os
from segment_anything.utils.transforms import ResizeLongestSide
from skimage.transform import resize

# Number of views per scene; a flat sample index is split into a
# (scene index, view number) pair with // and % on this value.
NUM_VIEWS_PER_SCENE = 6
# Root folder whose immediate sub-directories are globbed as scene directories.
TOD_filepath = 'C:/dataset/TOD/training_set/'
# Root folder under which the per-scene outputs are written with np.save.
TOD_PROCESSED = 'C:/dataset/TOD/preprocessed/'

# Utils Functions
def get_bounding_boxes(mask):
    """ Compute one bounding box per labelled object of a segmentation mask.

        The label 0 is treated as background (the same convention used by
        visualize_segmentation) and is removed by value, not by position, so
        a mask without any background pixel keeps all of its objects.

        @param mask: an integer label tensor, e.g. [1 x H x W] as returned by
                     torchvision.io.read_image, or [H x W]

        @return: a [N x 4] float numpy array of (xmin, ymin, xmax, ymax) boxes,
                 one row per non-zero label in ascending label order. N is 0
                 when the mask contains no object.
    """
    obj_ids = torch.unique(mask)
    obj_ids = obj_ids[obj_ids != 0]

    # Nothing but background: return an explicit empty (0, 4) array.
    if obj_ids.numel() == 0:
        return np.zeros((0, 4), dtype=np.float32)

    # Broadcast to one boolean mask per object id: [N x H x W]
    masks = mask == obj_ids[:, None, None]
    return masks_to_boxes(masks).detach().cpu().numpy()

def array_to_tensor(array):
    """ Turn a channels-last numpy array into a channels-first float tensor.

        (N x H x W x C) numpy.ndarray -> (N x C x H x W) torch.FloatTensor
        (H x W x C)     numpy.ndarray -> (C x H x W)     torch.FloatTensor
        Arrays of any other rank are converted without reordering the axes.
    """
    # Axis order to apply for each supported rank
    axis_orders = {4: (0, 3, 1, 2),  # NHWC -> NCHW
                   3: (2, 0, 1)}     # HWC  -> CHW

    converted = torch.from_numpy(array)
    order = axis_orders.get(array.ndim)
    if order is not None:
        converted = converted.permute(*order)

    return converted.float()

def show_mem(device):
    """ Describe the peak GPU memory allocated by tensors on `device`.

        @param device: a device string / torch.device / index, or None for the
                       current CUDA device

        @return: a string of the form "gpu used <x> GB memory". A non-CUDA
                 device, or a machine without CUDA, reports 0.0 GB instead of
                 raising, so the notebook's CPU fallback keeps working.
    """
    is_cuda_device = device is None or torch.device(device).type == "cuda"
    if is_cuda_device and torch.cuda.is_available():
        peak_bytes = torch.cuda.max_memory_allocated(device)
    else:
        # torch.cuda.max_memory_allocated rejects non-CUDA devices.
        peak_bytes = 0
    return f"gpu used {peak_bytes / (1024 ** 3):.02} GB memory"

# Visualization & Ploting Functions
def visualize_segmentation(im, masks, nc=None):
    """ Visualize segmentations nicely. Based on code from:
        https://github.com/roytseng-tw/Detectron.pytorch/blob/master/lib/utils/vis.py

        Every non-zero label is tinted with its own colour, blended 50/50 with
        the image, and outlined in white. Label 0 is background and left as is.

        @param im: a [H x W x 3] RGB image. numpy array of dtype np.uint8
        @param masks: a [H x W] numpy array of dtype np.uint8 with values in {0, ..., nc-1}
        @param nc: total number of colors. If None, this will be inferred by masks

        @return: a [H x W x 3] numpy array of dtype np.uint8
    """
    masks = masks.astype(int)
    im = im.copy()

    # Generate the colour palette, one entry per possible label
    NUM_COLORS = masks.max() + 1 if nc is None else nc
    cm = plt.get_cmap('gist_rainbow')
    colors = [cm(1. * i/NUM_COLORS) for i in range(NUM_COLORS)]

    # Foreground labels only (0 is background)
    labels = [i for i in np.unique(masks) if i != 0]

    # Paint each object with its colour, lightened towards white by w_ratio
    w_ratio = .4
    imgMask = np.zeros(im.shape)
    for i in labels:
        color_mask = np.array(colors[i][:3]) * (1 - w_ratio) + w_ratio
        imgMask[masks == i] = color_mask

    # Add the mask to the image
    imgMask = (imgMask * 255).round().astype(np.uint8)
    im = cv2.addWeighted(im, 0.5, imgMask, 0.5, 0.0)

    # Draw mask contours
    for i in labels:
        contours, _ = cv2.findContours(
            (masks == i).astype(np.uint8), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)

        # contourIdx=-1 already draws every contour, so one call is enough
        if len(contours) > 0:
            cv2.drawContours(im, contours, -1, (255,255,255), 2)

    return im

def show_box(box, ax):
    """ Draw a bounding box as a green, unfilled rectangle on a matplotlib axes.

        @param box: a sequence whose first four entries are xmin, ymin, xmax, ymax
        @param ax: the axes the rectangle patch is added to
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0,0,0,0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

# # Ref: https://stackoverflow.com/questions/49201236/check-the-total-number-of-parameters-in-a-pytorch-model
def count_parameters(model):
    """ Print a table of the trainable parameters of a model and their total.

        @param model: an object exposing named_parameters(), e.g. a torch.nn.Module

        @return: the number of elements summed over all parameters that
                 have requires_grad set
    """
    # (name, element count) for every parameter that requires gradients
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, size in trainable:
        table.add_row([name, size])
    total_params = sum(size for _, size in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

def show_bbox_and_mask(rgb_img, bbox, mask_img):
    """ Show three panels side by side: the image with its boxes, the raw
        mask, and the segmentation overlay.

        @param rgb_img: the image shown in the first panel and used as the
                        background of the overlay
        @param bbox: an iterable of boxes, each drawn with show_box
        @param mask_img: the label mask shown in the second panel and passed
                         to visualize_segmentation
    """
    seg_img = visualize_segmentation(rgb_img, mask_img)
    figure = plt.figure(figsize=(10, 10))

    panels = (("BBOX", rgb_img), ("MASK", mask_img), ("Segmentation", seg_img))
    for position, (title, picture) in enumerate(panels, start=1):
        axis = figure.add_subplot(1, 3, position)
        axis.imshow(picture)
        if position == 1:
            # Boxes are only overlaid on the RGB panel
            for box in bbox:
                show_box(box, axis)
        axis.axis('on')
        axis.set_title(title)

    plt.show()

# Every immediate sub-directory of the dataset root is one scene; sort them so
# that a sample index always maps to the same scene.
scene_dirs = glob.glob(TOD_filepath + '*/')
scene_dirs.sort()
num_scenes = len(scene_dirs)
num_samples = NUM_VIEWS_PER_SCENE * num_scenes
print(num_scenes, num_samples)

  from .autonotebook import tqdm as notebook_tqdm


40000 240000


In [2]:
# Alternative (smaller) backbone:
# checkpoint = "sam_vit_b_01ec64.pth"
# model_type = "vit_b"

# Run on the GPU when one is available, otherwise on the CPU
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(show_mem(device))  # memory report before the model is loaded

# Build the SAM model from its checkpoint and move it to the chosen device
model_type = "vit_h"
checkpoint = "sam_vit_h_4b8939.pth"
sam_model = sam_model_registry[model_type](checkpoint=checkpoint)
sam_model = sam_model.to(device)

count_parameters(sam_model)
print(show_mem(device))  # memory report after the model is loaded

gpu used 0.0 GB memory
+-----------------------------------------------------------------------------+------------+
|                                   Modules                                   | Parameters |
+-----------------------------------------------------------------------------+------------+
|                           image_encoder.pos_embed                           |  5242880   |
|                    image_encoder.patch_embed.proj.weight                    |   983040   |
|                     image_encoder.patch_embed.proj.bias                     |    1280    |
|                     image_encoder.blocks.0.norm1.weight                     |    1280    |
|                      image_encoder.blocks.0.norm1.bias                      |    1280    |
|                    image_encoder.blocks.0.attn.rel_pos_h                    |    2160    |
|                    image_encoder.blocks.0.attn.rel_pos_w                    |    2160    |
|                    image_encoder.blocks.0.att

In [3]:
# Ref: https://github.com/bowang-lab/MedSAM/blob/66cf4799a9ab9a8e08428a5087e73fc21b2b61cd/finetune_and_inference_tutorial_2D_dataset.ipynb
#
# For every sample index in [START_DATA, END_DATA): read the RGB image and its
# segmentation mask, derive the bounding boxes, run the SAM image encoder and
# save the embedding, image, mask and boxes as .npy files under TOD_PROCESSED.

START_DATA = 1500
END_DATA = 2100

# The profiler traces are written to ./log/sam; view them with
#   tensorboard --logdir=./log/sam

# Loop-invariant values, built once instead of once per sample
sam_img_size = sam_model.image_encoder.img_size
sam_transformations = ResizeLongestSide(sam_img_size)

with torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/sam'),
        record_shapes=True,
        with_stack=True
) as prof:
    for idx in tqdm(range(START_DATA, END_DATA)):

        # Split the flat sample index into a scene and a view
        scene_idx = idx // NUM_VIEWS_PER_SCENE              # index into scene_dirs
        scene_dir = scene_dirs[scene_idx]
        view_num = (idx % NUM_VIEWS_PER_SCENE) + 1          # 1 .. NUM_VIEWS_PER_SCENE

        # Image and mask paths
        rgb_img_filename = scene_dir + f"rgb_{view_num:05d}.jpeg"
        mask_filename = scene_dir + f"segmentation_{view_num:05d}.png"

        # Read image and segmentation mask. cv2.imread returns None instead of
        # raising when the file cannot be read, so fail with a clear message.
        bgr_image = cv2.imread(rgb_img_filename)
        if bgr_image is None:
            raise FileNotFoundError(f"Could not read image: {rgb_img_filename}")
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)  # H x W x 3, RGB
        mask = np.array(Image.open(mask_filename))          # label mask

        original_image_size = image.shape[:2]

        # Generate bounding boxes (xmin, ymin, xmax, ymax) from the mask.
        # They are in original image coordinates; no rescaling to the SAM
        # input size is applied here.
        boxes = get_bounding_boxes(read_image(mask_filename))

        # Resize to the square encoder input. skimage returns floats in
        # [0, 1]; convert back to uint8 explicitly so the PIL conversion inside
        # apply_image does not depend on how the installed torchvision treats
        # float arrays (NOTE(review): behaviour differs between releases).
        resize_image = resize(
            image,
            (sam_img_size, sam_img_size),
            anti_aliasing=False)
        resize_image = (resize_image * 255).astype(np.uint8)

        # Transform input image to the proper input format and size
        transform_image = sam_transformations.apply_image(resize_image)     # H x W x 3, numpy
        transform_image = torch.as_tensor(transform_image)
        transform_image = transform_image.permute(2, 0, 1).contiguous()    # 3 x H x W, torch

        # Number of objects in the image
        num_objects = boxes.shape[0]

        # Add the batch dimension: 1 x 3 x H x W
        torch_image = transform_image.unsqueeze(0)

        # Generate the image embedding
        with torch.no_grad():
            input_image = sam_model.preprocess(torch_image.to(device))
            assert input_image.shape == (1, 3, sam_img_size, sam_img_size)
            image_embedding = sam_model.image_encoder(input_image)

        # Save everything for this view into the scene's output folder
        scene_out_dir = TOD_PROCESSED + f"scene_{scene_idx:05d}"
        os.makedirs(scene_out_dir, exist_ok=True)

        np.save(scene_out_dir + "/" + f"embeddings_{view_num:05d}", image_embedding.cpu().numpy())
        np.save(scene_out_dir + "/" + f"image_{view_num:05d}", np.array(image))
        np.save(scene_out_dir + "/" + f"mask_{view_num:05d}", np.array([mask]))     # leading axis of size 1
        np.save(scene_out_dir + "/" + f"boxes_{view_num:05d}", np.array(boxes))     # N x 4

        prof.step() # Need to call this at the end of each step to notify profiler of steps' boundary.

100%|██████████| 600/600 [24:17<00:00,  2.43s/it]


In [None]:
# %tensorboard --logdir="D:/UTD/Summer/Dataset/log/sam"
# DATASET_PATH = "D:/UTD/Summer/Dataset/"

# np.savez_compressed(DATASET_PATH+f"original_dataset_{START_DATA}_to_{END_DATA}.npz", imgs=original_images, masks=original_masks)
# np.savez_compressed(DATASET_PATH+f"almost_preprocessed_dataset_{START_DATA}_to_{END_DATA}.npz", imgs=original_images, img_inputs=sam_input_images, masks=original_masks)
# # np.savez_compressed(DATASET_PATH+f"preprocessed_dataset_{START_DATA}_to_{END_DATA}.npz", imgs=original_images, img_embeddings=sam_image_embeddings, masks=original_masks)