In [1]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torchvision
# Report the library versions and whether PyTorch can see a CUDA device, so the
# environment that produced the outputs below is recorded with the notebook.
print("PyTorch version:", torch.__version__)
print("Torchvision version:", torchvision.__version__)
print("CUDA is available:", torch.cuda.is_available())

PyTorch version: 2.5.1+cu124
Torchvision version: 0.20.1+cu124
CUDA is available: False


  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
# The conditional expression short-circuits, so MPS is only probed without CUDA.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and never leave it, so that every
    # following cell of the notebook runs under autocast.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Compute capability 8+ (Ampere and newer) supports TF32; enable it for
    # matmuls and cuDNN
    # (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

# Fix NumPy's global random seed for the rest of the notebook.
np.random.seed(3)


using device: cpu


First, load the SAM 2 model and predictor. Change the path below to point to the SAM 2 checkpoint. Running on CUDA and using the default model are recommended for best results.

In [3]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Location of the SAM 2.1 "hiera large" checkpoint, relative to the notebook's
# working directory. Change this to wherever the checkpoint was downloaded.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_large.pt"
# Model config matching the checkpoint above.
# NOTE(review): presumably resolved by build_sam2 as a config name inside the
# sam2 package rather than as a filesystem path - confirm against sam2.build_sam.
model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

# Build the model from config + checkpoint on the device selected earlier.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Wrap the model in the image predictor used by the cropping loop
# (set_image / predict).
predictor = SAM2ImagePredictor(sam2_model)

In [4]:
import numpy as np
from scipy.ndimage import label, find_objects, binary_erosion, binary_dilation
from scipy.ndimage import generate_binary_structure


def get_largest_connected_component_and_bbox_mask(mask, radius=5, dilation_radius=20):
    """Isolate the largest blob of a 2-D mask and build a padded bounding-box mask.

    The mask is first cleaned with a morphological opening (``radius`` erosions
    followed by ``radius`` dilations, 4-connected cross structuring element) to
    drop small speckles. The largest 4-connected component of the cleaned mask
    is kept, its axis-aligned bounding box is rasterised, and that box is grown
    by ``dilation_radius`` dilation steps.

    Parameters
    ----------
    mask : array_like, 2-D
        Segmentation mask; any non-zero value counts as foreground.
    radius : int, optional
        Number of erosion/dilation iterations used for the opening. Values
        ``<= 0`` skip the opening (SciPy would otherwise interpret them as
        "iterate until nothing changes", which erases the whole mask).
    dilation_radius : int, optional
        Number of dilation iterations applied to the bounding-box mask. Values
        ``<= 0`` leave the box undilated (SciPy would otherwise flood the
        entire image).

    Returns
    -------
    largest_component : numpy.ndarray of uint8
        1 inside the largest connected component, 0 elsewhere; same shape as
        ``mask``.
    bbox_mask_dilated : numpy.ndarray of bool
        The component's bounding box after dilation. Because the structuring
        element is a cross, the grown box has clipped corners, but its row and
        column extents grow by ``dilation_radius`` on each side (limited by the
        image border).

    Raises
    ------
    ValueError
        If ``mask`` contains no foreground pixel at all.
    """
    mask_bool = np.asarray(mask).astype(bool)

    # 4-connected (cross-shaped) structuring element for all morphology below.
    struct_elem = generate_binary_structure(2, 1)

    # Morphological opening to remove small noise.
    if radius > 0:
        mask_eroded = binary_erosion(mask_bool, structure=struct_elem, iterations=radius)
        mask_cleaned = binary_dilation(mask_eroded, structure=struct_elem, iterations=radius)
    else:
        mask_cleaned = mask_bool

    labeled_array, num_features = label(mask_cleaned)
    if num_features == 0:
        # The opening erased everything (object thinner than ~2 * radius).
        # Fall back to the raw mask instead of failing on an empty argmax.
        labeled_array, num_features = label(mask_bool)
    if num_features == 0:
        raise ValueError(
            "mask contains no foreground pixels; cannot extract a connected component"
        )

    # Pixel count per label in a single pass; index 0 is the background.
    component_sizes = np.bincount(labeled_array.ravel())[1:]
    largest_component_label = int(np.argmax(component_sizes)) + 1

    largest_component = (labeled_array == largest_component_label).astype(np.uint8)

    # find_objects returns one slice tuple per label, ordered by label value.
    row_slice, col_slice = find_objects(labeled_array)[largest_component_label - 1]

    bbox_mask = np.zeros(mask_bool.shape, dtype=np.uint8)
    bbox_mask[row_slice, col_slice] = 1

    # Pad the bounding box by `dilation_radius` dilation steps.
    if dilation_radius > 0:
        bbox_mask_dilated = binary_dilation(
            bbox_mask, structure=struct_elem, iterations=dilation_radius
        )
    else:
        bbox_mask_dilated = bbox_mask.astype(bool)

    return largest_component, bbox_mask_dilated


def crop_image_by_bbox_per_channel(image, bbox_mask):
    """Crop ``image`` to the tight bounding rectangle of ``bbox_mask``.

    Parameters
    ----------
    image : array_like
        Image whose first two axes are rows and columns: either ``(H, W, C)``
        or a single-channel ``(H, W)`` array.
    bbox_mask : array_like, 2-D
        Mask whose non-zero pixels mark the region to keep. Only its row and
        column extents are used, so the region does not need to be rectangular.

    Returns
    -------
    numpy.ndarray
        A new array (not a view) holding the cropped region, with the dtype and
        any trailing channel axis of ``image``.

    Raises
    ------
    ValueError
        If ``bbox_mask`` has no non-zero pixel.
    """
    image = np.asarray(image)

    # Indices of the rows / columns that contain at least one mask pixel.
    rows = np.flatnonzero(np.any(bbox_mask, axis=1))
    cols = np.flatnonzero(np.any(bbox_mask, axis=0))
    if rows.size == 0 or cols.size == 0:
        raise ValueError("bbox_mask contains no foreground pixels; nothing to crop")

    min_row, max_row = rows[0], rows[-1]
    min_col, max_col = cols[0], cols[-1]

    # One slice over the two spatial axes handles every channel at once (and
    # 2-D images too); copy so the result does not alias the input image.
    return image[min_row:max_row + 1, min_col:max_col + 1].copy()

def get_center_and_center_rectangle(array, size_factor=0.2):
    """Return the centre of a 2-D array plus the corners of a centred rectangle.

    The rectangle's width and height are ``size_factor`` times the array's
    width and height (truncated to integers). All points are ``[x, y]`` pairs,
    i.e. ``[column, row]``.

    Parameters
    ----------
    array : numpy.ndarray, 2-D
        Only its ``(height, width)`` shape is used.
    size_factor : float, optional
        Rectangle side length as a fraction of the corresponding array side.

    Returns
    -------
    numpy.ndarray of shape (5, 2)
        Rows, in order: centre, ``(x0, y0)``, ``(x0, y1)``, ``(x1, y0)``,
        ``(x1, y1)``, where ``(x0, y0)`` is the corner with the smallest
        coordinates and ``(x1, y1)`` the opposite corner.
    """
    height, width = array.shape
    center_x, center_y = width // 2, height // 2

    rect_w = int(width * size_factor)
    rect_h = int(height * size_factor)

    # Rectangle extents, placed so the rectangle is centred on the centre point.
    x0 = center_x - rect_w // 2
    y0 = center_y - rect_h // 2
    x1 = x0 + rect_w
    y1 = y0 + rect_h

    return np.array(
        [
            [center_x, center_y],
            [x0, y0],
            [x0, y1],
            [x1, y0],
            [x1, y1],
        ]
    )



## Example image

In [None]:
from PIL import Image, UnidentifiedImageError

# Directories for input images and output cropped images
img_dir  = 'quarter_reslution_tactile_img_717/'  # Input directory containing raw tactile images
crop_dir = 'quarter_reslution_tactile_img_cropped_1point/'  # Output directory for cropped images

# Image.save does not create missing directories, so make sure the target exists.
os.makedirs(crop_dir, exist_ok=True)

# Iterate through all entries of the input directory in sorted order
for idx, d in enumerate(sorted(os.listdir(img_dir))):
    img_path = os.path.join(img_dir, d)

    # Sub-directories and hidden files (e.g. .DS_Store) are not images; skip
    # them instead of letting them abort the whole batch.
    if d.startswith('.') or not os.path.isfile(img_path):
        print(idx, d, 'skipped (not a regular file)')
        continue

    # Insert '_cropped' before the extension ('id001.jpg' -> 'id001_cropped.jpg').
    # splitext also copes with names that have no extension at all.
    stem, ext = os.path.splitext(d)
    crop_d = stem + '_cropped' + ext
    targ_path = os.path.join(crop_dir, crop_d)  # Full path for the cropped image output

    # Log the index and file name being processed
    print(idx, d)

    # Load the image as an RGB uint8 array; the context manager closes the file
    # handle as soon as the pixel data has been copied.
    try:
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))
    except UnidentifiedImageError:
        print(idx, d, 'skipped (not a readable image)')
        continue

    # Set the image for the predictor model
    predictor.set_image(image)

    # Prompt the model with a single foreground point at the image centre (x, y)
    input_point = np.array([[image.shape[1] // 2, image.shape[0] // 2]])
    input_label = np.array([1])  # Label for the input point (1 = foreground)

    # Perform prediction using the SAM2 model
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,  # Generate multiple candidate masks
    )

    # Sort the candidates by their confidence scores in descending order
    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    # Extract the largest connected component and the padded bounding-box mask
    # from the highest-scoring mask. The helper raises ValueError when no
    # component can be found; skip that image rather than stopping the batch.
    try:
        largest_component, bbox_mask = get_largest_connected_component_and_bbox_mask(masks[0])
    except ValueError as err:
        print(idx, d, 'skipped (no usable mask):', err)
        continue

    # Crop the image using the bounding box mask. `image` is uint8 (converted
    # from an RGB PIL image above) and cropping keeps the dtype, so the result
    # can be saved directly.
    cropped_image = crop_image_by_bbox_per_channel(image, bbox_mask)

    # Save the cropped image to the target path
    Image.fromarray(cropped_image).save(targ_path)

    # Log the success of the cropping operation
    print(idx, d, "Cropped image shape:", cropped_image.shape, 'saved', targ_path)
