In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import glob
import numpy as np
from tqdm import tqdm           
import copy
import random 
import torch.nn.functional as F
from datetime import datetime
import pytz

# Environment switches are set before the first CUDA/MPS availability query below.
os.environ.update(
    {
        "PYTORCH_ENABLE_MPS_FALLBACK": "1",
        "CUDA_VISIBLE_DEVICES": "1, 2, 3, 4",
    }
)

# Pick the compute device, preferring CUDA, then MPS, then CPU.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context that is never exited, so it applies to the
    # entire notebook.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # Turn on TF32 for Ampere (compute capability >= 8) GPUs, see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Default font for every figure produced by this notebook.
plt.rcParams['font.family'] = 'Arial'



In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed passed to every generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Deterministic kernels on, autotuner off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def show_mask(mask, ax, random_color=False, borders = True):
    """Draw a semi-transparent mask overlay on a Matplotlib axis.

    Args:
        mask: Array whose last two dimensions are (height, width); non-zero
            entries are painted.
        ax: Axis-like object providing ``imshow``.
        random_color: Use a random RGB colour instead of the default blue.
        borders: Also draw the external contours of the mask on the overlay.
    """
    rgb = np.random.random(3) if random_color else np.array([30/255, 144/255, 255/255])
    rgba = np.append(rgb, 0.6)  # fixed alpha of 0.6

    height, width = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]

    if borders:
        outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        outlines = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
        overlay = cv2.drawContours(overlay, outlines, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on an axis: label 1 in green, label 0 in red.

    Args:
        coords: Array of (x, y) rows.
        labels: Array of per-point labels aligned with ``coords``.
        ax: Axis-like object providing ``scatter``.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Outline a box given as (x0, y0, x1, y1) in green on the axis.

    Args:
        box: Indexable with four entries: left, top, right, bottom.
        ax: Axis-like object providing ``add_patch``.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

def show_masks(image, masks, scores, point_coords=None, box_coords=None, input_labels=None, borders=True):
    """Render each predicted mask over the image in its own figure.

    Args:
        image: Image passed to ``plt.imshow``.
        masks: Iterable of masks, one figure per mask.
        scores: Iterable of scores aligned with ``masks``; shown in the title
            when more than one score is given.
        point_coords: Optional prompt point coordinates to draw.
        box_coords: Optional (x0, y0, x1, y1) box to draw.
        input_labels: Labels for ``point_coords``; required when points are given.
        borders: Whether mask contours are drawn.
    """
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)
        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())
        if box_coords is not None:
            # boxes
            show_box(box_coords, plt.gca())
        if len(scores) > 1:
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        # The original referenced ``plt.show`` without calling it.
        plt.show()
        
def extract_random_points(binary_image, n):
    """Sample between 1 and ``n`` random foreground pixels as (x, y) points.

    The mask is eroded once before sampling. If erosion removes every pixel,
    the un-eroded foreground is used instead. The number of points is drawn
    uniformly from ``[1, n]`` and then capped at the number of available
    pixels, so sampling without replacement cannot fail.

    Args:
        binary_image: 2-D mask; pixels ``> 0`` are foreground.
        n: Upper bound (inclusive) on the number of points to return.

    Returns:
        Integer array of shape (k, 2) with columns (x, y). ``k`` is 0 when the
        mask has no foreground at all.
    """
    eroded = cv2.erode(binary_image, None, iterations=1)
    points = np.argwhere(eroded > 0)
    if len(points) == 0:
        # Very small regions can vanish under erosion; fall back to the raw mask.
        points = np.argwhere(binary_image > 0)
    points = points[:, ::-1]  # (row, col) -> (x, y)
    if len(points) == 0:
        return points

    count = min(np.random.randint(1, n + 1), len(points))
    chosen_indices = np.random.choice(len(points), count, replace=False)
    return points[chosen_indices]  # (k, 2)
    
def negative_random_point(binary_image, n, outer_iterations=40, inner_iterations=25):
    """Sample between 1 and ``n`` random background points from a ring around the mask.

    The ring is the difference between the mask dilated ``outer_iterations``
    times and the mask dilated ``inner_iterations`` times. The number of points
    is drawn uniformly from ``[1, n]`` and capped at the number of ring pixels,
    so sampling without replacement cannot fail.

    Args:
        binary_image: 2-D mask; pixels ``> 0`` are foreground.
        n: Upper bound (inclusive) on the number of points to return.
        outer_iterations: Dilation iterations for the outer edge of the ring.
        inner_iterations: Dilation iterations for the inner edge of the ring.

    Returns:
        Integer array of shape (k, 2) with columns (x, y). ``k`` is 0 when the
        ring is empty.
    """
    outer = cv2.dilate(binary_image, None, iterations=outer_iterations)
    inner = cv2.dilate(binary_image, None, iterations=inner_iterations)
    ring = outer - inner
    points = np.argwhere(ring > 0)
    points = points[:, ::-1]  # (row, col) -> (x, y)
    if len(points) == 0:
        return points

    count = min(np.random.randint(1, n + 1), len(points))
    chosen_indices = np.random.choice(len(points), count, replace=False)
    return points[chosen_indices]
    
def show_masks_together(image, masks, masks_overlay, scores, point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
    """Plot a 2x2 summary: image, reference overlay, prompt points, predicted masks.

    Args:
        image: Image shown in three of the four panels.
        masks: Sequence of predicted masks drawn on the bottom-right panel.
        masks_overlay: Image shown in the top-right panel.
        scores: Single numeric score shown in the bottom-right title.
        point_coords: Optional prompt point coordinates.
        box_coords: Optional (x0, y0, x1, y1) box.
        input_labels: Labels for ``point_coords``; required when points are given.
        borders: Whether mask contours are drawn.
        save_path: If given, the figure is saved there and closed; otherwise it
            is shown.
    """
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))

    axs[0, 0].imshow(image)
    axs[0, 0].axis('off')
    axs[0, 0].set_title("Original Image", fontsize=16)

    axs[0, 1].imshow(masks_overlay)
    axs[0, 1].axis('off')
    axs[0, 1].set_title("Masks Overlay", fontsize=16)

    axs[1, 0].imshow(image)
    if point_coords is not None:
        assert input_labels is not None
        show_points(point_coords, input_labels, axs[1, 0])
    axs[1, 0].axis('off')
    axs[1, 0].set_title("Initial Points Visualization", fontsize=16)

    result_ax = axs[1, 1]
    result_ax.imshow(image)
    for mask in masks:
        show_mask(mask, result_ax, borders=borders)
    # Prompts are drawn once, on top of all masks, instead of once per mask.
    if point_coords is not None:
        show_points(point_coords, input_labels, result_ax)
    if box_coords is not None:
        show_box(box_coords, result_ax)
    # len(masks) equals the last 1-based mask index and is defined even when
    # ``masks`` is empty, where the loop variable would be unbound.
    result_ax.set_title(f"Mask {len(masks)}, Score: {scores:.3f}", fontsize=16)
    result_ax.axis('off')

    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        # Close this specific figure so figures do not pile up across a loop.
        plt.close(fig)
    else:
        plt.show()

def calculate_f1_score(true_mask, pred_mask):
    """Compute the pixel-wise F1 score between two binary masks.

    Both inputs are binarised with ``> 0``, so 0/1, 0/255 and boolean masks
    are all accepted. Comparing against the literal value 1 silently scores a
    0/255 ground-truth mask as having no foreground.

    Args:
        true_mask: Ground-truth mask (array-like); values ``> 0`` are foreground.
        pred_mask: Predicted mask (array-like); values ``> 0`` are foreground.

    Returns:
        F1 score as a float in [0, 1]; 0.0 when precision + recall is zero.
    """
    true_fg = np.asarray(true_mask).ravel() > 0
    pred_fg = np.asarray(pred_mask).ravel() > 0

    tp = np.sum(true_fg & pred_fg)
    fp = np.sum(~true_fg & pred_fg)
    fn = np.sum(true_fg & ~pred_fg)

    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall > 0:
        return float(2 * (precision * recall) / (precision + recall))
    return 0.0


class CustomDataset(Dataset):
    """Dataset yielding (image, mask, mask_overlay, file_name) for index-aligned path lists."""

    def __init__(self, image_paths, masks_paths, masks_overlay_paths, transform=None):
        """Store the three path sequences and the optional augmentation callable."""
        self.image_paths = image_paths
        self.masks_paths = masks_paths
        self.masks_overlay_paths = masks_overlay_paths
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from the image path list."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample.

        The image is converted with ``COLOR_GRAY2RGB``. When a transform is
        set it is called as ``transform(image=..., mask=...)`` and the
        resulting image is permuted with ``permute(1, 2, 0)``. The file name
        is the basename of the overlay path.
        """
        overlay_path = self.masks_overlay_paths[idx]

        image = np.array(Image.open(self.image_paths[idx]))
        mask = np.array(Image.open(self.masks_paths[idx]))
        mask_overlay = np.array(Image.open(overlay_path))
        file_name = os.path.basename(overlay_path)

        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)

        if not self.transform:
            return image, mask, mask_overlay, file_name

        augmented = self.transform(image=image, mask=mask)
        # (c, h, w) -> (h, w, c)
        image = augmented['image'].permute(1, 2, 0)
        mask = augmented['mask']
        return image, mask, mask_overlay, file_name

In [None]:
set_seed(42)

# Dataset layout: <data_dir>/Input/png, <data_dir>/Mask/extracted, <data_dir>/Mask/raw.
data_dir = 'path'
image_dir = os.path.join(data_dir, 'Input', 'png')
mask_dir = os.path.join(data_dir, 'Mask', 'extracted')
mask_overlay_dir = os.path.join(data_dir, 'Mask', 'raw')

# Sorted PNG listings of the three folders; samples are paired by position.
image_path, mask_path, mask_overlay_path = (
    np.sort(glob.glob(os.path.join(folder, '*.png')))
    for folder in (image_dir, mask_dir, mask_overlay_dir)
)

# Cell output: number of input images found.
len(image_path)

In [None]:
# Photometric augmentations followed by conversion to a tensor.
transform = A.Compose([
    A.RandomGamma(gamma_limit=(80, 120), p=0.2),
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=1),
    ToTensorV2()
])

_batch_size = 8
# The first 80% of each sorted listing forms the training split. The overlay
# list is sliced like the other two so the dataset only holds training paths.
train_image_path = image_path[:int(len(image_path) * 0.8)]
train_mask_path = mask_path[:int(len(mask_path) * 0.8)]
train_mask_overlay_path = mask_overlay_path[:int(len(mask_overlay_path) * 0.8)]
dataset = CustomDataset(train_image_path, train_mask_path, train_mask_overlay_path, transform)
dataloader = DataLoader(dataset, batch_size=_batch_size, shuffle=True)

# Model: build SAM 2.1 (hiera small) from its config and checkpoint, wrapped in an image predictor.
sam2_checkpoint = "../checkpoints/sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
torch.cuda.empty_cache()
# Mixed-precision training: the gradient scaler is paired with the autocast region below.
scaler = torch.cuda.amp.GradScaler()

optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=0.00001, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.2)

# Gradient accumulation: number of samples whose gradients are summed per optimizer step.
# This value stays constant; a shorter trailing group is handled per batch below
# instead of overwriting it (overwriting leaked into all later batches and epochs).
accumulation_steps = 4

# Put the mask decoder and prompt encoder in train mode; keep the image encoder in eval mode.
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
predictor.model.image_encoder.train(False)

num_epochs = 2
epsilon = 1e-6

for epoch in range(num_epochs):
    running_loss = 0.0
    running_dice = 0.0

    for images, masks, _, _ in tqdm(dataloader):
        batch_size = images.shape[0]
        batch_loss = 0.0
        batch_dice = 0.0

        # Samples left after the last full accumulation group form a shorter
        # trailing group. When the batch divides evenly, tail_start equals
        # batch_size and no sample falls into it.
        tail_size = batch_size % accumulation_steps
        tail_start = batch_size - tail_size

        for i in range(batch_size):
            # Number of samples in the accumulation group this sample belongs to.
            group_size = tail_size if i >= tail_start else accumulation_steps

            with torch.cuda.amp.autocast():

                image = images[i].cpu().numpy()
                mask = masks[i].cpu().numpy()
                mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
                _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

                # Generate random positive and negative prompt points.
                positive_point = extract_random_points(mask, 5)
                negative_point = negative_random_point(mask, 5)
                negative_input_label = np.zeros((len(negative_point),), dtype=np.uint8)
                input_label = np.ones((len(positive_point),), dtype=np.uint8)
                _input_point = np.concatenate((positive_point, negative_point), axis=0)
                _input_label = np.concatenate((input_label, negative_input_label), axis=0)

                ##
                # Parts to be modified -> set_image_batch
                ##
                predictor.set_image(image)

                mask_input, point_ccoordinate, labels, _ = predictor._prep_prompts(
                    _input_point, _input_label, box=None, mask_logits=None, normalize_coords=True
                )

                # Use the prompt encoder to produce sparse and dense embeddings.
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(point_ccoordinate, labels), boxes=None, masks=None,
                )

                # Collect the predictor's high-resolution feature maps: take the last
                # entry of each level and add a batch dimension with unsqueeze(0).
                high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]

                # Run the SAM mask decoder to get mask logits (low_res_masks) and scores (prd_scores).
                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),  # last image embedding, batch dim added
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),            # dense positional encoding
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,                                                 # single-mask prediction
                    repeat_image = 1,
                    high_res_features=high_res_features,                                    # high-resolution features for the decoder
                )

                # Dice computation. The target is placed on the notebook's selected
                # device instead of a hard-coded 'cuda' string.
                gt_mask = torch.tensor((mask / 255).astype(np.float32), device=device, requires_grad=False)
                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
                prd_mask = prd_masks[:, 0]
                prd_mask = prd_mask.squeeze(0)
                prd_mask = torch.sigmoid(prd_mask)

                inter = torch.sum(prd_mask * gt_mask)
                gt_sum = torch.sum(gt_mask)
                prd_sum = torch.sum(prd_mask)
                dice_coef = (2 * inter) / (gt_sum + prd_sum + epsilon)
                dice_loss = 1 - dice_coef

                # Cross-entropy loss; kept for the alternative objective below, not optimised.
                seg_loss = - (gt_mask * torch.log(prd_mask + epsilon) + (1 - gt_mask) * torch.log(1 - prd_mask + epsilon)).mean()

                # loss = dice_loss + seg_loss
                loss = dice_loss

                batch_loss += loss.item()
                batch_dice += dice_coef.item()

                # Normalise by the size of this sample's accumulation group.
                loss = loss / group_size

            # Exits ``autocast`` before backward().
            # Backward passes under ``autocast`` are not recommended.
            # Backward ops run in the same ``dtype`` ``autocast`` chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # Update the weights after every full group and at the end of the batch.
            if (i + 1) % accumulation_steps == 0 or i == batch_size - 1:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

        # Batch loss and Dice average.
        batch_loss /= batch_size
        batch_dice /= batch_size
        running_loss += batch_loss
        running_dice += batch_dice

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader):.4f}, Dice: {running_dice/len(dataloader):.4f}")
    scheduler.step()

# 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 711/711 [12:32<00:00,  1.06s/it]
# Epoch 1, Loss: 0.3829, Dice: 0.6171
# 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 711/711 [12:27<00:00,  1.05s/it]
# Epoch 2, Loss: 0.2937, Dice: 0.7063

In [None]:
# The last 20% of each sorted listing forms the held-out split.
test_image_path = image_path[int(len(image_path) * 0.8):]
test_mask_path = mask_path[int(len(mask_path) * 0.8):]
# Slice the overlay listing itself; the original sliced ``mask_path`` here, so
# the "overlay" shown at evaluation time was the extracted mask.
test_mask_overlay_path = mask_overlay_path[int(len(mask_overlay_path) * 0.8):]
# No augmentation at test time.
test_dataset = CustomDataset(test_image_path, test_mask_path, test_mask_overlay_path, transform=None)
test_dataloader = DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)

In [None]:
##
# Parts to be modified
##
# Persist the fine-tuned weights, then rebuild a fresh model and load them back.
checkpoint_file = 'test_model_parameters.torch'
torch.save(predictor.model.state_dict(), checkpoint_file)
new_sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
new_predictor = SAM2ImagePredictor(new_sam2_model)
# weights_only=True restricts unpickling to tensors and plain containers;
# map_location places the loaded tensors on the notebook's selected device.
new_predictor.model.load_state_dict(
    torch.load(checkpoint_file, map_location=device, weights_only=True)
)
predictor.model.eval()
new_predictor.model.eval()

In [None]:
# Output folder named after the current time in Seoul.
korea_timezone = pytz.timezone('Asia/Seoul')
korea_time = datetime.now(korea_timezone).strftime("%Y-%m-%d-%H:%M:%S ")
save_dir = f'path/{korea_time} N'
# makedirs also creates missing parent folders and tolerates an existing one.
os.makedirs(save_dir, exist_ok=True)

torch.cuda.empty_cache()

# (file_name, f1) for every evaluated sample.
f1_list = []

# NOTE(review): this evaluates `predictor`, not the reloaded `new_predictor` - confirm intent.
for images, masks, mask_overlays, file_names in tqdm(test_dataloader):
    batch_loss = 0.0
    batch_dice = 0.0
    batch_size = images.shape[0]
    for i in range(batch_size):
        with torch.cuda.amp.autocast():
            image = images[i].cpu().numpy()
            mask = masks[i].cpu().numpy()
            mask_overlay = mask_overlays[i]
            file_name = file_names[i]
            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
            _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

            # Exactly one positive and one negative prompt point per image.
            positive_point = extract_random_points(mask, 1)
            negative_point = negative_random_point(mask, 1)

            negative_input_label = np.zeros((len(negative_point),), dtype=np.uint8)
            input_label = np.ones((len(positive_point),), dtype=np.uint8)
            _input_point = np.concatenate((positive_point, negative_point), axis=0)
            _input_label = np.concatenate((input_label, negative_input_label), axis=0)

            predictor.set_image(image)
            predict_masks, predict_scores, predict_logits = predictor.predict(
                point_coords=_input_point,
                point_labels=_input_label,
                multimask_output=False,
            )

            # The thresholded ground truth is 0/255; binarise both masks so the
            # metric compares foreground with foreground.
            f1 = calculate_f1_score(mask > 0, predict_masks[0] > 0)
            show_masks_together(image, predict_masks, mask_overlay, f1, point_coords=_input_point, input_labels=_input_label, borders=True, save_path = f'{save_dir}/{file_name}')
            f1_list.append((file_name, f1))