# Training of SAM2 with conversion of Dataset
[Source for Training part](https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code)


In [41]:
import numpy as np
import torch
import cv2
import os
import json
from PIL import Image, ImageDraw
import shutil
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

Set the paths of the labeled data and set the formatted dataset

In [42]:
labeled_dataset_path = r"C:\Users\K3000\Desktop\conversion test\old format"
formatted_dataset_path = r"C:\Users\K3000\Desktop\conversion test\new format"

## Conversion
This step converts the labeled dataset into the correct data format, if already formatted dataset exists, this step can be skipped.

In [43]:
def generate_masks(dataset_path, output_path):
    # Create new directory structure
    images_output_path = os.path.join(output_path, "Images")
    masks_output_path = os.path.join(output_path, "Masks")
    os.makedirs(images_output_path, exist_ok=True)
    os.makedirs(masks_output_path, exist_ok=True)

    # Traverse the directory structure
    for video_folder in os.listdir(dataset_path):
        video_path = os.path.join(dataset_path, video_folder)
        if not os.path.isdir(video_path):
            continue

        # Locate the JSON file
        json_file = [f for f in os.listdir(video_path) if f.endswith('.json')]
        if len(json_file) != 1:
            print(f"JSON file not found or multiple JSONs in {video_path}")
            continue
        json_path = os.path.join(video_path, json_file[0])

        # Parse the JSON
        with open(json_path, 'r') as f:
            annotations = json.load(f)

        # Locate the source images directory
        images_path = os.path.join(video_path, "source images")
        if not os.path.exists(images_path):
            print(f"Source images directory not found in {video_path}")
            continue

        # Process images and masks
        for image_file in os.listdir(images_path):
            if not image_file.endswith('.jpg'):
                continue

            # Image number based on filename
            image_number = str(int(os.path.splitext(image_file)[0]))

            # Get corresponding annotations
            if image_number not in annotations:
                print(f"No annotations for {image_file}")
                continue

            # Read the image dimensions
            image_path = os.path.join(images_path, image_file)
            with Image.open(image_path) as img:
                width, height = img.size

            # Copy the image to the new directory
            image_output_path = os.path.join(images_output_path, f"{video_folder}_{image_file}")
            shutil.copy(image_path, image_output_path)

            # Create a blank mask
            mask = Image.new("L", (width, height), 0)

            # Draw polygons onto the mask
            draw = ImageDraw.Draw(mask)
            observations = annotations[image_number]["Observations"]

            for category_idx, observation in enumerate(observations.values(), start=1):
                for mask_polygon in observation.get("Mask Polygon", []):
                    # Extract points and draw each polygon
                    points = [(int(point[0]), int(point[1])) for point in mask_polygon]
                    if len(points) > 2:  # Ensure it's a valid polygon
                        draw.polygon(points, fill=category_idx)

            # Save the mask
            mask_output_path = os.path.join(masks_output_path, f"{video_folder}_{int(image_number):05d}.png")
            mask.save(mask_output_path)
            print(f"Saved image: {image_output_path}, mask: {mask_output_path}")



generate_masks(labeled_dataset_path, formatted_dataset_path)

Saved image: C:\Users\K3000\Desktop\conversion test\new format\Images\01850100.MPG_02550.jpg, mask: C:\Users\K3000\Desktop\conversion test\new format\Masks\01850100.MPG_02550.png
Saved image: C:\Users\K3000\Desktop\conversion test\new format\Images\01850100.MPG_02575.jpg, mask: C:\Users\K3000\Desktop\conversion test\new format\Masks\01850100.MPG_02575.png
Saved image: C:\Users\K3000\Desktop\conversion test\new format\Images\01850100.MPG_02600.jpg, mask: C:\Users\K3000\Desktop\conversion test\new format\Masks\01850100.MPG_02600.png
Saved image: C:\Users\K3000\Desktop\conversion test\new format\Images\01850100.MPG_02625.jpg, mask: C:\Users\K3000\Desktop\conversion test\new format\Masks\01850100.MPG_02625.png
Saved image: C:\Users\K3000\Desktop\conversion test\new format\Images\01850100.MPG_02650.jpg, mask: C:\Users\K3000\Desktop\conversion test\new format\Masks\01850100.MPG_02650.png
Saved image: C:\Users\K3000\Desktop\conversion test\new format\Images\01850100.MPG_02675.jpg, mask: C:\Us

JSONDecodeError: Expecting ',' delimiter: line 8117 column 32 (char 238056)

## Loading Data

In [44]:
data=[] # list of files in dataset
for ff, name in enumerate(os.listdir(os.path.join(formatted_dataset_path, "Images"))):  # go over all folder annotation
    data.append({"image":os.path.join(formatted_dataset_path, "Images", name) , "annotation": os.path.join(formatted_dataset_path, "Masks", name[:-4] + ".png")})

Show if data is loaded correctly

In [45]:
data[0]

{'image': 'C:\\Users\\K3000\\Desktop\\conversion test\\new format\\Images\\01850100.MPG_02550.jpg',
 'annotation': 'C:\\Users\\K3000\\Desktop\\conversion test\\new format\\Masks\\01850100.MPG_02550.png'}

## Training

In [46]:
def read_batch(data): # read random image and its annotation from  the dataset

   #  select image

        ent  = data[np.random.randint(len(data))] # choose random entry
        Img = cv2.imread(ent["image"])[...,::-1]  # read image
        ann_map = cv2.imread(ent["annotation"]) # read annotation
   # resize image

        r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]]) # scalling factor
        Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
        ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)),interpolation=cv2.INTER_NEAREST)

   # merge vessels and materials annotations

        mat_map = ann_map[:,:,0] # material annotation map
        ves_map = ann_map[:,:,2] # vessel  annotaion map
        mat_map[mat_map==0] = ves_map[mat_map==0]*(mat_map.max()+1) # merged map

   # Get binary masks and points

        inds = np.unique(mat_map)[1:] # load all indices
        points= []
        masks = [] 
        for ind in inds:
            mask=(mat_map == ind).astype(np.uint8) # make binary mask
            masks.append(mask)
            coords = np.argwhere(mask > 0) # get all coordinates in mask
            yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate
            points.append([[yx[1], yx[0]]])
        return Img,np.array(masks),np.array(points), np.ones([len(masks),1])

Select, which model shall be trained

In [47]:
sam2_checkpoint = r"C:\Users\K3000\sam2\checkpoints\sam2.1_hiera_tiny.pt" # path to model weight
model_cfg = r"C:\Users\K3000\sam2\sam2\configs\sam2.1\sam2.1_hiera_t.yaml" # model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # load model
predictor = SAM2ImagePredictor(sam2_model) # load net

In [48]:
predictor.model.sam_mask_decoder.train(True) # enable training of mask decoder 
predictor.model.sam_prompt_encoder.train(True) # enable training of prompt encoder

PromptEncoder(
  (pe_layer): PositionEmbeddingRandom()
  (point_embeddings): ModuleList(
    (0-3): 4 x Embedding(1, 256)
  )
  (not_a_point_embed): Embedding(1, 256)
  (mask_downscaling): Sequential(
    (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
    (1): LayerNorm2d()
    (2): GELU(approximate='none')
    (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
    (4): LayerNorm2d()
    (5): GELU(approximate='none')
    (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (no_mask_embed): Embedding(1, 256)
)

In [49]:
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)

In [50]:
scaler = torch.cuda.amp.GradScaler() # set mixed precision

In [None]:
for itr in range(100000):
    with torch.cuda.amp.autocast(): # cast to mix precision
        #with torch.cuda.amp.autocast():
            image,mask,input_point, input_label = read_batch(data) # load data batch
            if mask.shape[0]==0: continue # ignore empty batches
            predictor.set_image(image) # apply SAM image encodet to the image

            # prompt encoding

            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels),boxes=None,masks=None,)

            # mask decoder

            batched_mode = unnorm_coords.shape[0] > 1 # multi object prediction
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),sparse_prompt_embeddings=sparse_embeddings,dense_prompt_embeddings=dense_embeddings,multimask_output=True,repeat_image=batched_mode,high_res_features=high_res_features,)
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])# Upscale the masks to the original image resolution

            # Segmentaion Loss caclulation

            gt_mask = torch.tensor(mask.astype(np.float32)).cuda()
            prd_mask = torch.sigmoid(prd_masks[:, 0])
            seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

            # Score loss calculation (intersection over union) IOU

            inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
            iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
            loss=seg_loss+score_loss*0.05  # mix losses

            # apply back propogation

            predictor.model.zero_grad() # empty gradient
            scaler.scale(loss).backward()  # Backpropogate
            scaler.step(optimizer)
            scaler.update() # Mix precision

            if itr%1000==0: torch.save(predictor.model.state_dict(), "model.torch") # save model

            # Display results

            if itr==0: 
                mean_iou=0
                
            mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
            
            print("step", itr, "Accuracy(IOU)=",mean_iou)