In [None]:
# Train/Fine Tune SAM 2 on LabPics 1 dataset
# This mode uses several images in a single batch
# Labpics can be downloaded from: https://zenodo.org/records/3697452/files/LabPicsV1.zip?download=1

import numpy as np
import torch
from scipy.ndimage import label
import cv2
import os
import tifffile
from pathlib import Path
import shutil
from PIL import Image

from torch.onnx.symbolic_opset11 import hstack

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


# Input dataset root; the preprocessing step reads its "EM" and "Masks" sub-directories.
data_dir = r"/Volumes/Chris_SSD/AllSegmentations/Lab/Mito/Liver/"
# Output root; the preprocessing step writes "image" and "annotation" files below it.
outdir = r"/Volumes/Chris_SSD/AllSegmentations/SAM2_processed/Mito/"

def preprocess_data_for_sam(indir, outdir):
    """Split every labelled mask into per-instance masks and copy its image.

    For each file name in ``<indir>/EM`` that also exists in ``<indir>/Masks``
    the mask is split into connected components with ``scipy.ndimage.label``.
    Every component that is non-empty inside the top-left 1024x1024 crop is
    written as a 0/255 image to ``<outdir>/annotation/<stem>_<i>.tif``. The
    matching EM image, cropped the same way, is written once to
    ``<outdir>/image/<name>`` (an already existing output image is kept).

    Args:
        indir: Dataset root containing the ``EM`` and ``Masks`` directories.
        outdir: Destination root; ``annotation`` and ``image`` sub-directories
            are created when missing.

    Raises:
        IOError: If OpenCV cannot decode an EM image that has at least one
            usable mask component.
    """
    em_dir = os.path.join(indir, "EM")
    mask_dir = os.path.join(indir, "Masks")
    annotation_dir = os.path.join(outdir, "annotation")
    image_dir = os.path.join(outdir, "image")
    # The writers below do not create missing parent directories themselves.
    os.makedirs(annotation_dir, exist_ok=True)
    os.makedirs(image_dir, exist_ok=True)

    for name in sorted(os.listdir(em_dir)):
        # Dot-files (e.g. macOS ".DS_Store" / "._*" entries) are not images.
        if name.startswith("."):
            continue
        img_inpath = os.path.join(em_dir, name)
        mask_inpath = os.path.join(mask_dir, name)
        if not (os.path.isfile(img_inpath) and os.path.isfile(mask_inpath)):
            continue

        # Compare against zero instead of casting to uint8: a cast would wrap
        # wider label values such as 256 back to 0 and lose that foreground.
        annotation = tifffile.imread(mask_inpath) != 0
        labeled_array, num_features = label(annotation)
        # Label on the full mask, then crop, so component ids are unaffected
        # by the crop; the crop itself is loop-invariant.
        cropped_labels = labeled_array[:1024, :1024]
        stem = os.path.splitext(name)[0]
        img_path = os.path.join(image_dir, name)

        for i in range(1, num_features + 1):
            separate_mask = (cropped_labels == i).astype(np.uint8) * 255  # 0 / 255
            if not separate_mask.any():
                continue  # component lies entirely outside the crop

            # Write the image before its first annotation so that a decode
            # failure cannot leave an annotation without an image.
            if not os.path.exists(img_path):
                img = cv2.imread(img_inpath)
                if img is None:
                    raise IOError(f"cv2.imread could not decode image: {img_inpath}")
                cv2.imwrite(img_path, img[:1024, :1024, :])

            Image.fromarray(separate_mask).save(
                os.path.join(annotation_dir, f"{stem}_{i}.tif")
            )
# Build the per-instance image/annotation files from the raw dataset.
preprocess_data_for_sam(indir=data_dir, outdir=outdir)

def read_single(traindir): # read random image and single mask from  the dataset (LabPics)
    """Pick one random instance mask and return it with its image and prompts.

    Args:
        traindir: Directory with an ``annotation`` sub-directory (one
            single-instance mask per file, named ``<image stem>_<index>.tif``)
            and an ``image`` sub-directory (``<image stem>.tif``).

    Returns:
        A 4-tuple ``(Img, mask, point, bbox)``:
        Img   -- the image in RGB channel order as a contiguous array.
        mask  -- ``(H, W)`` uint8 array, 1 on the instance and 0 elsewhere.
        point -- ``[[x, y]]``, a random pixel inside the mask.
        bbox  -- ``[x0, y0, x1, y1]`` around the mask, grown by a random
                 0-19 px per side and clipped to the mask extent.

    Raises:
        ValueError: If there are no annotation files or the chosen mask has
            no foreground pixel.
        IOError: If OpenCV cannot decode the mask or the image.
    """
    annotation_dir = os.path.join(traindir, "annotation")
    # Dot-files (e.g. macOS ".DS_Store" / "._*" entries) are not masks.
    annotation_fns = [
        os.path.join(annotation_dir, x)
        for x in os.listdir(annotation_dir)
        if not x.startswith(".")
    ]
    if not annotation_fns:
        raise ValueError(f"no annotation files found in {annotation_dir}")

    # select image and annotation
    annotation_fn = annotation_fns[np.random.randint(len(annotation_fns))]  # random entry
    # Read as single channel: the default flag returns a 3-channel image,
    # which is neither a valid 2-D mask nor a valid loss target.
    mask = cv2.imread(annotation_fn, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise IOError(f"cv2.imread could not decode mask: {annotation_fn}")
    mask = (mask > 0).astype(np.uint8)  # stored as 0/255 -> binary 0/1

    img_name = "_".join(os.path.basename(annotation_fn).split("_")[0:-1]) + ".tif"
    img_path = os.path.join(traindir, "image", img_name)
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise IOError(f"cv2.imread could not decode image: {img_path}")
    # BGR -> RGB. The reversed view has negative strides, which
    # torch.from_numpy rejects, so materialise a contiguous copy.
    Img = np.ascontiguousarray(bgr[..., ::-1])

    ys, xs = np.nonzero(mask)  # row (y) and column (x) indices of the foreground
    if xs.size == 0:
        raise ValueError(f"annotation has no foreground pixels: {annotation_fn}")

    height, width = mask.shape
    pad = np.random.randint(0, 20, size=4)  # random padding per side
    bbox = [
        max(int(xs.min()) - int(pad[0]), 0),
        max(int(ys.min()) - int(pad[1]), 0),
        min(int(xs.max()) + 1 + int(pad[2]), width),
        min(int(ys.max()) + 1 + int(pad[3]), height),
    ]  # x0, y0, x1, y1 kept inside the image

    k = np.random.randint(xs.size)  # choose random foreground pixel
    return Img, mask, [[xs[k], ys[k]]], bbox

def read_batch(traindir,batch_size=4):
    """Draw ``batch_size`` random samples with ``read_single`` and stack them.

    Returns a 5-tuple: the list of images, then arrays built from the masks,
    the point prompts, an all-ones ``(batch_size, 1)`` label array, and the
    bounding boxes.
    """
    samples = [read_single(traindir) for _ in range(batch_size)]

    # The images stay in a plain list; the other fields become arrays.
    images = [sample[0] for sample in samples]
    masks = np.array([sample[1] for sample in samples])
    points = np.array([sample[2] for sample in samples])
    boxes = np.array([sample[3] for sample in samples])
    point_labels = np.ones([batch_size, 1])

    return images, masks, points, point_labels, boxes


In [None]:
# Load model
sam2_checkpoint = "sam2_hiera_small.pt" # path to model weight
model_cfg = "sam2_hiera_s.yaml" #  model config
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") # load model
predictor = SAM2ImagePredictor(sam2_model)

# Set training parameters

predictor.model.sam_mask_decoder.train(True) # enable training of mask decoder
predictor.model.sam_prompt_encoder.train(True) # enable training of prompt encoder
predictor.model.image_encoder.train(True) # enable training of image encoder: For this to work you need to scan the code for "no_grad" and remove them all
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler() # mixed precision

# Training loop

# Running IoU estimate. Defined before the loop so that skipping the first
# batch cannot leave it unset.
mean_iou = 0

for itr in range(100000):
    # read_batch returns five values; the boxes are not used as prompts here.
    image, mask, input_point, input_label, _bbox = read_batch(outdir, batch_size=4) # load data batch
    if mask.shape[0] == 0:
        continue # ignore empty batches
    if mask.ndim == 4:
        mask = mask[..., 0] # tolerate (B, H, W, C) masks by keeping one channel

    # Only the forward pass and the loss run under autocast; the backward
    # pass and optimizer step are kept outside it.
    with torch.cuda.amp.autocast(): # cast to mix precision
        predictor.set_image_batch(image) # apply SAM image encoder to the images

        # prompt encoding

        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=None, mask_logits=None, normalize_coords=True)
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=None, masks=None)

        # mask decoder

        # Pass the high-resolution features of the whole batch. Indexing
        # feat_level[-1] would hand the last image's features to every image.
        high_res_features = list(predictor._features["high_res_feats"])
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"],
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=False,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1]) # Upscale the masks to the original image resolution

        # Segmentation loss calculation

        # Binarise the target: masks stored as 0/255 would otherwise put
        # values far outside [0, 1] into the cross entropy.
        gt_mask = torch.tensor((mask > 0).astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0]) # Turn logit map to probability map
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean() # cross entropy loss

        # Score loss calculation (intersection over union) IOU

        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05 # mix losses

    # apply back propagation

    predictor.model.zero_grad() # empty gradient
    scaler.scale(loss).backward() # Backpropagate
    scaler.step(optimizer)
    scaler.update() # Mix precision

    if itr % 1000 == 0:
        torch.save(predictor.model.state_dict(), "model.torch") # save model

    # Display results

    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
    print("step)",itr, "Accuracy(IOU)=",mean_iou)