# Test the Trained Model
Run this notebook to see how the training (fine-tuning) affects the model's segmentation results.

In [121]:
# Segment an image region using the fine-tuned model
# See Train.py for how to fine-tune/train the model
import numpy as np
import torch
import cv2
import os
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

In [122]:
# use bfloat16 for the entire script (memory efficient)
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

Select the image to be segmented and its mask. `num_samples` points are drawn at random from inside this mask and passed to SAM2 as point prompts.

In [None]:
# Input paths can be overridden with environment variables so the notebook runs
# on other machines without editing this cell; the defaults are the original paths.
image_path = os.environ.get(
    "SAM2_TEST_IMAGE",
    r"C:\Users\K3000\Desktop\conversion test\new format\Images\3M0030_3M0030_3M0033_D_11122017_141731.mpg_01475.jpg",
)  # path to image
mask_path = os.environ.get(
    "SAM2_TEST_MASK",
    r"C:\Users\K3000\Desktop\conversion test\new format\Masks\3M0030_3M0030_3M0033_D_11122017_141731.mpg_01475.png",
)  # path to mask; the mask defines the image region to segment
num_samples = 10  # number of points to sample; each point prompts one predicted segment

# Seed NumPy's global RNG so the sampled points and annotation colours are
# reproducible on Restart & Run All. Set to None for a fresh random draw each run.
random_seed = 0
if random_seed is not None:
    np.random.seed(random_seed)

# Switch this to False if you want to see the untrained model's results
use_trained_model = True

In [124]:
def read_image(image_path, mask_path, max_size=1024):
    """Read an image and its region mask and rescale both together.

    The image is scaled (up or down) so that its longer side becomes
    ``max_size`` pixels while keeping its aspect ratio.

    Args:
        image_path: Path to the image file.
        mask_path: Path to the mask file; it is read as single-channel grayscale.
        max_size: Target length in pixels of the image's longer side. Defaults
            to 1024, the value this function originally hard-coded.

    Returns:
        (img, mask): ``img`` is the resized image with its channels flipped from
        OpenCV's BGR order to RGB; ``mask`` is the resized grayscale mask with
        exactly the same height and width as ``img``.

    Raises:
        FileNotFoundError: If OpenCV cannot read either file.
    """
    img = cv2.imread(image_path)
    if img is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f"Could not read image: {image_path}")
    mask = cv2.imread(mask_path, 0)
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    img = img[..., ::-1]  # BGR -> RGB
    r = min(max_size / img.shape[1], max_size / img.shape[0])
    new_size = (int(img.shape[1] * r), int(img.shape[0] * r))  # cv2 takes (width, height)
    img = cv2.resize(img, new_size)
    # Resize the mask to the image's exact size (nearest-neighbour keeps labels
    # crisp) so point coordinates sampled from it always index the image, even if
    # the mask file's dimensions differ from the image's.
    mask = cv2.resize(mask, new_size, interpolation=cv2.INTER_NEAREST)
    return img, mask
def get_points(mask, num_points, rng=None):
    """Sample random single-point prompts from the foreground of ``mask``.

    Args:
        mask: 2-D array; pixels with a value > 0 count as foreground.
        num_points: Number of points to draw (independently, with replacement).
        rng: Optional ``numpy.random.Generator``. When None (the default, as
            before), NumPy's legacy global RNG is used.

    Returns:
        Integer array of shape (num_points, 1, 2). Each entry is one prompt
        holding a single point in (x, y) = (column, row) order.

    Raises:
        ValueError: If ``mask`` has no foreground pixels to sample from.
    """
    # Locate the foreground once instead of once per sampled point.
    coords = np.argwhere(mask > 0)  # (row, col) pairs
    if len(coords) == 0:
        raise ValueError("mask has no foreground (> 0) pixels to sample points from")
    if rng is None:
        idx = np.random.randint(len(coords), size=num_points)
    else:
        idx = rng.integers(len(coords), size=num_points)
    # Swap (row, col) -> (x, y); fancy indexing yields a contiguous copy.
    xy = coords[idx][:, [1, 0]]
    return xy.reshape(-1, 1, 2)

In [125]:
# Read the image and its region mask, then sample point prompts inside the mask.
image, mask = read_image(image_path, mask_path)
# Sampled (x, y) points index the image, so the mask must share its height/width.
assert image.shape[:2] == mask.shape[:2], (
    f"image size {image.shape[:2]} and mask size {mask.shape[:2]} differ"
)
input_points = get_points(mask, num_samples)


In [126]:
# Load the base SAM2 model (you need the pretrained checkpoint already downloaded).
# Paths can be overridden with environment variables; the defaults are the
# original local paths.
sam2_checkpoint = os.environ.get(
    "SAM2_CHECKPOINT", r"C:\Users\K3000\sam2\checkpoints\sam2.1_hiera_tiny.pt"
)  # path to model weights
model_cfg = os.environ.get(
    "SAM2_MODEL_CFG", r"C:\Users\K3000\sam2\sam2\configs\sam2.1\sam2.1_hiera_t.yaml"
)  # model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Wrap the network in the image predictor used for prompting below
predictor = SAM2ImagePredictor(sam2_model)

Load your trained model here. Set `use_trained_model = False` to skip this step and see how SAM2 performs without the training.

In [127]:
# Optionally replace the base weights with the trained state_dict.
trained_weights_path = "model.torch"
if use_trained_model:
    # weights_only=True refuses to unpickle arbitrary Python objects (a
    # state_dict only needs tensors); map_location="cpu" avoids depending on the
    # device the weights were saved from -- load_state_dict copies them onto the
    # model's own parameters.
    state_dict = torch.load(trained_weights_path, map_location="cpu", weights_only=True)
    predictor.model.load_state_dict(state_dict)

<All keys matched successfully>

In [128]:
# Predict masks: every sampled point is its own prompt, all labelled 1
# (positive point), so the predictor returns one set of masks per point.
input_labels = np.ones([input_points.shape[0], 1])
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels,
    )


In [129]:
# Sort predicted masks from highest to lowest score.
# NOTE(review): the next cell recomputes these same variables with tensor
# handling; this cell looks redundant and could be removed.
np_masks = np.array(masks[:, 0])
np_scores = scores[:, 0]
descending_order = np.argsort(np_scores)[::-1]
shorted_masks = np_masks[descending_order]


In [130]:
# Sort predicted masks from highest to lowest score, accepting either NumPy
# arrays or torch tensors from the predictor.
def _as_numpy(x):
    """Return ``x`` as a new NumPy array; tensors are detached, cast to float32 and moved to CPU."""
    if isinstance(x, torch.Tensor):
        # Tensor.numpy() fails on CUDA tensors, on tensors that require grad,
        # and on bfloat16 (NumPy has no such dtype), so normalise all three first.
        return x.detach().float().cpu().numpy()
    return np.array(x)

np_masks = _as_numpy(masks[:, 0])
np_scores = _as_numpy(scores[:, 0])
shorted_masks = np_masks[np.argsort(np_scores)][::-1]

In [131]:
# Merge the sorted masks into one label map (0 = background). Masks are visited
# from highest to lowest score. A mask is dropped if more than `max_overlap` of
# its area is already claimed by higher-scoring masks. Otherwise its unclaimed
# pixels get label i + 1, where i is its position in the sorted order.
max_overlap = 0.15
# Smallest unsigned dtype that can hold every label: uint8 for < 256 masks, as
# before, but without overflowing when more masks are sampled.
label_dtype = np.min_scalar_type(shorted_masks.shape[0])
seg_map = np.zeros_like(shorted_masks[0], dtype=label_dtype)
occupancy_mask = np.zeros_like(shorted_masks[0], dtype=bool)

# `pred_mask` (not `mask`) so the input mask from read_image is not overwritten.
for i, pred_mask in enumerate(shorted_masks):
    area = pred_mask.sum()
    if area == 0:
        continue  # empty prediction: nothing to label, and avoids a 0/0 ratio
    if (pred_mask * occupancy_mask).sum() / area > max_overlap:
        continue

    mask_bool = pred_mask.astype(bool)  # copy, safe to edit below
    mask_bool[occupancy_mask] = False   # keep only pixels not yet claimed
    seg_map[mask_bool] = i + 1
    occupancy_mask[mask_bool] = True

In [132]:
# Create a coloured annotation map: every label id in seg_map gets a random colour.
height, width = seg_map.shape

# Empty 3-channel image for the annotation. It is only saved/shown through
# OpenCV, so its channels are interpreted in OpenCV's BGR order.
rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
# int() so `max + 1` cannot wrap around in seg_map's small unsigned dtype.
for id_class in range(1, int(seg_map.max()) + 1):
    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]

# `image` is RGB (read_image flips OpenCV's BGR), but cv2.imwrite/imshow expect
# BGR, so convert it back before blending, saving or displaying it.
image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
mix = (rgb_image / 2 + image_bgr / 2).astype(np.uint8)

# Save and display
cv2.imwrite("annotation.png", rgb_image)
cv2.imwrite("mix.png", mix)

cv2.imshow("annotation", rgb_image)
cv2.imshow("mix", mix)
cv2.imshow("image", image_bgr)
cv2.waitKey()  # blocks until a key is pressed in one of the windows
cv2.destroyAllWindows()

-1