In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
import glob
import numpy as np
from tqdm import tqdm           
import copy
import random 
from sklearn.metrics import f1_score
import pandas as pd
import seaborn as sns
import time
from datetime import datetime
import pytz

# --- Runtime configuration ---------------------------------------------------
# These variables are set before the first CUDA call below so they apply to this process.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
# NOTE(review): restricts the visible GPUs; device indices used below are relative
# to this list, not to the physical GPU numbering.
os.environ["CUDA_VISIBLE_DEVICES"]="1, 2, 3, 4"

# Use Arial for figures when a matching font file is installed, otherwise the default sans-serif.
available_fonts = fm.findSystemFonts(fontpaths=None, fontext='ttf')
if any('Arial' in font for font in available_fonts):
    plt.rcParams['font.family'] = 'Arial'
else:
    plt.rcParams['font.family'] = 'sans-serif'

# Select the device for computation.
# The notebook previously overrode the detected device with a hard-coded 'cuda:1',
# which fails on hosts with fewer than two visible GPUs (or none at all). The
# preferred index is kept, but only when that device actually exists.
PREFERRED_CUDA_INDEX = 1
if torch.cuda.is_available():
    if torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
        device = f"cuda:{PREFERRED_CUDA_INDEX}"
    else:
        device = "cuda:0"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"using device: {device}")

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators and pin cuDNN settings.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    # Python / NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU, current CUDA device, all CUDA devices)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags: request deterministic kernels and disable the autotuner
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def show_mask(mask, ax, random_color=False, borders = True):
    """Draw a semi-transparent coloured overlay of ``mask`` on ``ax``.

    Args:
        mask: Array whose last two dimensions are (height, width); it is cast to uint8.
        ax: Object exposing ``imshow`` (e.g. a matplotlib Axes).
        random_color: If True, pick a random RGB colour; otherwise use a fixed blue.
        borders: If True, outline the external contours of the mask.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )

    height, width = mask.shape[-2:]
    mask_u8 = mask.astype(np.uint8)
    # Broadcast (h, w, 1) * (1, 1, 4) -> (h, w, 4) RGBA overlay
    overlay = mask_u8.reshape(height, width, 1) * rgba.reshape(1, 1, -1)

    if borders:
        found_contours, _hierarchy = cv2.findContours(mask_u8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        smoothed_contours = []
        for contour in found_contours:
            smoothed_contours.append(cv2.approxPolyDP(contour, epsilon=0.01, closed=True))
        overlay = cv2.drawContours(overlay, smoothed_contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; entries equal to 1 or 0 are drawn.
        ax: Object exposing ``scatter`` (e.g. a matplotlib Axes).
        marker_size: Marker size forwarded as ``s``.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)

# Random positive points
def positive_random_point(binary_image, n, random_num=False):
    """Sample random (x, y) prompt points from the foreground of a binary mask.

    The mask is eroded once so that points are not placed on its outermost
    pixels. If erosion removes every pixel (thin or tiny structures), the
    un-eroded foreground is used instead.

    Args:
        binary_image: Mask array; pixels ``> 0`` are foreground.
        n: Number of points to draw (upper bound when ``random_num`` is True).
        random_num: If True, draw a random count in ``[1, n]`` instead of exactly ``n``.

    Returns:
        Integer array of shape (count, 2) holding (x, y) coordinates.

    Raises:
        ValueError: If the mask has no foreground pixel at all.
    """
    eroded_image = cv2.erode(binary_image, None, iterations=1)
    points = np.argwhere(eroded_image > 0)
    if len(points) == 0:
        # Erosion wiped out the mask; any foreground pixel is still a valid positive prompt.
        points = np.argwhere(binary_image > 0)
    if len(points) == 0:
        raise ValueError("positive_random_point: the mask contains no foreground pixels")

    # np.argwhere yields (row, col); prompts are consumed as (x, y).
    points = points[:, ::-1]

    count = np.random.randint(1, n + 1) if random_num else n
    # Sample without replacement when possible; fall back to replacement only
    # when there are fewer candidates than requested, so the output keeps `count` rows.
    chosen_indices = np.random.choice(len(points), count, replace=count > len(points))
    return points[chosen_indices]  # (count, 2)

# Random negative points
def negative_random_point(binary_image, n, random_num=False, inner_iterations=15, outer_iterations=30):
    """Sample random (x, y) prompt points from a ring of background around the mask.

    The ring is the difference between the mask dilated ``outer_iterations``
    times and the mask dilated ``inner_iterations`` times, so points lie
    near the object but not directly on its border.

    Args:
        binary_image: Mask array; pixels ``> 0`` are foreground.
        n: Number of points to draw (upper bound when ``random_num`` is True).
        random_num: If True, draw a random count in ``[1, n]`` instead of exactly ``n``.
        inner_iterations: Dilation iterations defining the inner edge of the ring.
        outer_iterations: Dilation iterations defining the outer edge of the ring.

    Returns:
        Integer array of shape (count, 2) holding (x, y) coordinates.

    Raises:
        ValueError: If ``outer_iterations <= inner_iterations`` or the ring is
            empty (e.g. empty mask, or a mask whose dilation covers the image).
    """
    if outer_iterations <= inner_iterations:
        raise ValueError("negative_random_point: outer_iterations must exceed inner_iterations")

    outer_region = cv2.dilate(binary_image, None, iterations=outer_iterations)
    inner_region = cv2.dilate(binary_image, None, iterations=inner_iterations)
    ring = outer_region - inner_region

    points = np.argwhere(ring > 0)
    if len(points) == 0:
        raise ValueError("negative_random_point: no background ring around the mask to sample from")

    # np.argwhere yields (row, col); prompts are consumed as (x, y).
    points = points[:, ::-1]

    count = np.random.randint(1, n + 1) if random_num else n
    # Sample without replacement when possible; fall back to replacement only
    # when there are fewer candidates than requested.
    chosen_indices = np.random.choice(len(points), count, replace=count > len(points))
    return points[chosen_indices]

def show_masks_together(image, masks, masks_overlay, scores, point_coords=None, box_coords=None, input_labels=None, borders=True, save_path=None):
    """Render a 2x2 summary figure for one prediction.

    Panels: original image, reference overlay, image with prompt points, and
    image with the predicted mask(s) plus prompts.

    Args:
        image: Image accepted by ``Axes.imshow``.
        masks: Iterable of masks, each drawn with ``show_mask``.
        masks_overlay: Reference overlay image shown in the top-right panel.
        scores: Single number shown in the last panel's title (formatted with ``.3f``).
        point_coords: Optional prompt points passed to ``show_points``.
        box_coords: Optional box; the first four values are read as (x0, y0, x1, y1).
        input_labels: Labels for ``point_coords``; required when points are given.
        borders: Forwarded to ``show_mask``.
        save_path: If given, the figure is saved there and closed; otherwise it is shown.
    """
    # 2x2 subplot
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))

    # Original image
    axs[0, 0].imshow(image)
    axs[0, 0].axis('off')
    axs[0, 0].set_title("Original Image", fontsize=16)

    # Mask overlay
    axs[0, 1].imshow(masks_overlay)
    axs[0, 1].axis('off')
    axs[0, 1].set_title("Masks Overlay", fontsize=16)

    # Image and points
    axs[1, 0].imshow(image)
    if point_coords is not None:
        assert input_labels is not None
        show_points(point_coords, input_labels, axs[1, 0])
    axs[1, 0].axis('off')
    axs[1, 0].set_title("Initial Points Visualization", fontsize=16)

    # Predicted mask(s) with prompts on top
    result_ax = axs[1, 1]
    result_ax.imshow(image)
    for mask in masks:
        show_mask(mask, result_ax, borders=borders)
    # Prompts are drawn once, after all masks, instead of once per mask.
    if point_coords is not None:
        show_points(point_coords, input_labels, result_ax)
    if box_coords is not None:
        # The notebook defines no `show_box` helper, so the rectangle is drawn inline.
        x0, y0, x1, y1 = box_coords[:4]
        result_ax.add_patch(
            plt.Rectangle((x0, y0), x1 - x0, y1 - y0, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
        )
    # len(masks) equals the previous "last index + 1" and is also defined for an empty list.
    result_ax.set_title(f"Mask {len(masks)}, Score: {scores:.3f}", fontsize=16)
    result_ax.axis('off')

    if save_path is not None:
        fig.savefig(save_path, dpi=200, bbox_inches='tight')
        # Close this specific figure so figures do not accumulate when called in a loop.
        plt.close(fig)
    else:
        plt.show()

def calculate_dice_score(pred_mask, gt_mask):
    """Compute the Dice coefficient (binary F1) between two masks.

    Both inputs are binarised with the threshold ``> 0.5``. The score is
    ``2 * |P ∩ G| / (|P| + |G|)``, computed directly with NumPy rather than
    through scikit-learn's per-sample ``f1_score``, which is much slower on
    full-resolution masks. The measure is symmetric in its two arguments.

    Args:
        pred_mask: Array-like predicted mask.
        gt_mask: Array-like ground-truth mask with the same number of elements.

    Returns:
        Dice score as a float in [0, 1]; 0.0 when both masks are empty
        (the value ``f1_score`` yields with its default zero-division handling).

    Raises:
        ValueError: If the two masks differ in number of elements.
    """
    pred_binary = (np.asarray(pred_mask) > 0.5).ravel()
    gt_binary = (np.asarray(gt_mask) > 0.5).ravel()
    if pred_binary.size != gt_binary.size:
        raise ValueError(
            f"calculate_dice_score: mask sizes differ ({pred_binary.size} vs {gt_binary.size})"
        )

    intersection = np.count_nonzero(pred_binary & gt_binary)
    denominator = np.count_nonzero(pred_binary) + np.count_nonzero(gt_binary)
    if denominator == 0:
        return 0.0

    dice_score = 2.0 * intersection / denominator
    return dice_score

def visualize_images(images, titles, cmap='gray', figsize=(10, 5)):
    """Show a row of images side by side, one subplot per image.

    Args:
        images: Sequence of 2D images; torch tensors are detached and moved to
            the CPU, anything else is handed to ``imshow`` as is.
        titles: Sequence of subplot titles, paired with ``images`` by position.
        cmap: Colormap forwarded to ``imshow``.
        figsize: Figure size forwarded to ``plt.subplots``.
    """
    if len(images) == 0:
        # Nothing to draw; plt.subplots would reject zero columns.
        return

    # squeeze=False always returns a 2D array of axes, so a single image works too.
    fig, axes = plt.subplots(1, len(images), figsize=figsize, squeeze=False)
    for ax, image, title in zip(axes[0], images, titles):
        if torch.is_tensor(image):
            image = image.detach().cpu().numpy()
        ax.imshow(image, cmap=cmap)
        ax.axis('off')
        ax.set_title(title)
    plt.tight_layout()
    plt.show()

# Preview of the most recent training batch.
# `prd_masks` and `gt_mask` are only assigned inside the training loop further down
# in this notebook, so on a fresh kernel (Restart & Run All) they do not exist yet.
# Guard the preview instead of failing with a NameError; re-run this cell after training.
if "prd_masks" in globals() and "gt_mask" in globals():
    prd_titles = [f"Pred {i+1}" for i in range(4)]
    visualize_images(prd_masks[:4], prd_titles)

    gt_titles = [f"Ground Truth {i+1}" for i in range(4)]
    visualize_images(gt_mask[:4], gt_titles)
else:
    print("Skipping mask preview: run the training cell first to create prd_masks and gt_mask.")

In [None]:
# Training-time augmentation pipeline: brightness/contrast jitter followed by tensor conversion.
_gamma_jitter = A.RandomGamma(gamma_limit=(80, 120), p=0.2)
_brightness_contrast_jitter = A.RandomBrightnessContrast(
    brightness_limit=(-0.2, 0.2),
    contrast_limit=(-0.2, 0.2),
    p=0.6,
)
transform = A.Compose([_gamma_jitter, _brightness_contrast_jitter, ToTensorV2()])

class CustomDataset(Dataset):
    """Dataset yielding (image, mask, mask_overlay, file_name) tuples read from disk.

    Samples are paired by position: index ``idx`` reads ``image_paths[idx]``,
    ``masks_paths[idx]`` and ``masks_overlay_paths[idx]``.

    Args:
        image_paths: Sequence of input image file paths.
        masks_paths: Sequence of mask file paths.
        masks_overlay_paths: Sequence of overlay image paths; the returned
            ``file_name`` is the basename of this path.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            returning a mapping with ``'image'`` and ``'mask'``. Any falsy value disables it.
    """

    def __init__(self, image_paths, masks_paths, masks_overlay_paths, transform=None):
        self.image_paths = image_paths
        self.masks_paths = masks_paths
        self.masks_overlay_paths = masks_overlay_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    @staticmethod
    def _load_array(path):
        """Read an image file into a NumPy array and close the file handle."""
        with Image.open(path) as handle:
            return np.array(handle)

    def __getitem__(self, idx):
        image = self._load_array(self.image_paths[idx])
        mask = self._load_array(self.masks_paths[idx])
        mask_overlay = self._load_array(self.masks_overlay_paths[idx])
        file_name = os.path.basename(self.masks_overlay_paths[idx])

        # Bring the input to 3 channels. Only single-channel images need the
        # grayscale conversion; applying it to a colour image raises in OpenCV.
        if image.ndim == 2 or image.shape[2] == 1:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            # The transform returns the image channel-first; move channels last again.
            image = augmented['image'].permute(1, 2, 0)
            mask = augmented['mask']

        return image, mask, mask_overlay, file_name

In [None]:
set_seed(42)

# Dataset layout: <data_dir>/Input/png, <data_dir>/Mask/extracted, <data_dir>/Mask/raw
data_dir = 'path'
image_dir = os.path.join(data_dir, 'Input', 'png')
mask_dir = os.path.join(data_dir, 'Mask', 'extracted')
mask_overlay_dir = os.path.join(data_dir, 'Mask', 'raw')

# Sorted so the three lists line up by position (samples are paired by index later on).
image_path = np.sort(glob.glob(os.path.join(image_dir, '*.png')))
mask_path = np.sort(glob.glob(os.path.join(mask_dir, '*.png')))
mask_overlay_path = np.sort(glob.glob(os.path.join(mask_overlay_dir, '*.png')))

# Positional pairing is only meaningful when every image has exactly one mask and one overlay.
if not (len(image_path) == len(mask_path) == len(mask_overlay_path)):
    raise ValueError(
        f"Mismatched file counts: {len(image_path)} images, "
        f"{len(mask_path)} masks, {len(mask_overlay_path)} overlays"
    )

len(image_path)

In [None]:
# Train, validation, test dataset split
_batch_size = 4

train_split = 0.7
valid_split = 0.1
test_split = 0.2
assert abs(train_split + valid_split + test_split - 1.0) < 1e-9, "split fractions must sum to 1"

# Compute the boundaries once and reuse them for images, masks and overlays so the
# three lists stay aligned. round(..., 6) absorbs floating-point error such as
# 0.7 + 0.1 == 0.7999999999999999, which would otherwise truncate one index short.
num_samples = len(image_path)
train_end = int(round(num_samples * train_split, 6))
valid_end = int(round(num_samples * (train_split + valid_split), 6))

train_image_path = image_path[:train_end]
train_mask_path = mask_path[:train_end]
train_mask_overlay_path = mask_overlay_path[:train_end]

valid_image_path = image_path[train_end:valid_end]
valid_mask_path = mask_path[train_end:valid_end]
valid_mask_overlay_path = mask_overlay_path[train_end:valid_end]

test_image_path = image_path[valid_end:]
test_mask_path = mask_path[valid_end:]
# Previously sliced from mask_path, which showed the extracted mask instead of the raw overlay.
test_mask_overlay_path = mask_overlay_path[valid_end:]

# Each dataset now receives the overlay slice matching its own split; passing the full
# overlay list made validation overlays and file names point at training samples.
train_dataset = CustomDataset(train_image_path, train_mask_path, train_mask_overlay_path, transform)
train_dataloader = DataLoader(train_dataset, batch_size=_batch_size, shuffle=True)

# Augmentation is applied to the training split only.
valid_dataset = CustomDataset(valid_image_path, valid_mask_path, valid_mask_overlay_path, transform=None)
valid_dataloader = DataLoader(valid_dataset, batch_size=_batch_size, shuffle=False)

test_dataset = CustomDataset(test_image_path, test_mask_path, test_mask_overlay_path, transform=None)
test_dataloader = DataLoader(test_dataset, batch_size=_batch_size, shuffle=False)

In [None]:
# Build the SAM2 model and wrap it in an image predictor
sam2_checkpoint = "../checkpoints/sam2.1_hiera_base_plus.pt"
model_cfg = "../sam2/configs/sam2.1/sam2.1_hiera_b+.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

# Train mode for the mask decoder and prompt encoder, eval mode for the image encoder
_component_modes = (
    ("sam_mask_decoder", True),
    ("sam_prompt_encoder", True),
    ("image_encoder", False),
)
for _component_name, _train_flag in _component_modes:
    getattr(sam2_model, _component_name).train(_train_flag)

# Report the resulting mode of each component as seen through the predictor
for _component_name, _ in _component_modes:
    print(f'predictor.model.{_component_name}.training :', getattr(predictor.model, _component_name).training)

In [None]:
torch.cuda.empty_cache()
scaler = torch.cuda.amp.GradScaler()
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=1e-4)
num_epochs = 50
epsilon = 1e-6  # smoothing term that keeps the Dice ratio defined when both masks are empty
train_loss_history = []
train_dice_history = []
val_f1_history = []

for epoch in range(num_epochs):
    epoch_loss = 0
    epoch_dice = 0
    # Number of batches that actually contributed to the running sums.
    # (The average used to be divided by len(loader) minus the number of *images*
    # in the skipped partial batch, which mixes batch and sample counts.)
    processed_batches = 0
    epoch_start_time = time.time()

    with tqdm(train_dataloader, desc=f"🚀 Epoch {epoch+1}/{num_epochs} - Training", unit="batch", colour="green", dynamic_ncols=True) as tbar:
        for images, masks, _, _ in tbar:

            batch_size = images.shape[0]

            # Skip the trailing batch if it is smaller than the configured batch size.
            if batch_size < _batch_size:
                break

            batch_image = []
            batch_mask = []
            batch_point = []
            batch_label = []

            for i in range(batch_size):

                image = images[i].cpu().numpy()
                mask = masks[i].cpu().numpy()
                # Collapse the mask to one channel and binarise it to {0, 255}.
                mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
                _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

                # One random positive and one random negative prompt point per image.
                positive_point = positive_random_point(mask, 1)
                negative_point = negative_random_point(mask, 1)
                positive_label = np.ones((len(positive_point),), dtype=np.uint8)
                negative_label = np.zeros((len(negative_point),), dtype=np.uint8)
                _input_point = np.concatenate((positive_point, negative_point), axis=0)
                _input_label = np.concatenate((positive_label, negative_label), axis=0)

                batch_image.append(image)
                batch_mask.append(mask)
                batch_point.append(_input_point)
                batch_label.append(_input_label)

            with torch.cuda.amp.autocast():
                predictor.set_image_batch(batch_image)

                # Prepare the points and labels for prediction; normalize_coords=True
                # rescales the coordinates to the model's input size.
                _, point_ccoordinate, labels, _ = predictor._prep_prompts(
                    np.array(batch_point), np.array(batch_label), box=None, mask_logits=None, normalize_coords=True
                )

                # Prompt encoder: turn the points and labels into sparse (point-based)
                # and dense prompt embeddings.
                sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                    points=(point_ccoordinate, labels), boxes=None, masks=None,
                )

                # High-resolution feature maps for the whole batch.
                # The previous `feat_level[-1].unsqueeze(0)` kept only the LAST entry
                # along the first dimension of each level, so every image in the batch
                # was decoded with the last image's high-res features (the size-1 dim
                # broadcasts silently).
                # NOTE(review): assumes set_image_batch stores each level as
                # (batch, C, H, W) — confirm against the installed SAM2 version.
                high_res_features = list(predictor._features["high_res_feats"])

                # Mask decoder: predict low-resolution masks and scores from the image
                # embedding, its positional encoding, the prompt embeddings and the
                # high-resolution features.
                low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                    image_embeddings=predictor._features["image_embed"],
                    image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=sparse_embeddings,
                    dense_prompt_embeddings=dense_embeddings,
                    multimask_output=False,
                    repeat_image=False,
                    high_res_features=high_res_features,
                )

                # Post-process the predicted masks back to the original image size.
                prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

                # Keep the single mask channel (index 0 on dim 1) and map logits to probabilities.
                prd_masks = torch.sigmoid(prd_masks[:, 0])

                # Ground-truth masks scaled from {0, 255} to {0, 1}.
                gt_mask = torch.tensor((np.array(batch_mask) / 255).astype(np.float32), device=device)

                # Soft Dice loss, computed per image and averaged over the batch.
                intersection = torch.sum(gt_mask * prd_masks, dim=(1, 2))
                total = torch.sum(gt_mask, dim=(1, 2)) + torch.sum(prd_masks, dim=(1, 2))
                dice = (2. * intersection + epsilon) / (total + epsilon)
                loss = (1 - dice).mean()

            epoch_loss += loss.item()
            epoch_dice += dice.mean().item()
            processed_batches += 1

            # Backpropagation (gradient scaling for mixed precision)
            predictor.model.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

    # Time check
    epoch_end_time = time.time()
    epoch_duration = epoch_end_time - epoch_start_time

    # Average over the batches that were actually processed; guard against an empty epoch.
    epoch_loss = epoch_loss / max(processed_batches, 1)
    epoch_dice = epoch_dice / max(processed_batches, 1)
    train_loss_history.append(epoch_loss)
    train_dice_history.append(epoch_dice)

    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Dice: {epoch_dice:.4f}")
    print(f"Epoch {epoch+1} Processing time: {epoch_duration:.2f} seconds.")

    # Validate every 5th epoch (0-based) and on the final epoch.
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        # Evaluation mode for the whole model during validation.
        sam2_model.eval()

        val_dice_list = []

        with tqdm(valid_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", unit="batch") as vbar:
            for images, masks, mask_overlays, file_names in vbar:
                batch_size = images.shape[0]
                with torch.no_grad():
                    for i in range(batch_size):
                        image = images[i].cpu().numpy()
                        mask = masks[i].cpu().numpy()
                        mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
                        _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

                        # Two random positive prompt points; negative prompts are not used
                        # here (negative_random_point can supply them if needed).
                        positive_point = positive_random_point(mask, 2)
                        positive_point_label = np.ones((len(positive_point),), dtype=np.uint8)

                        with torch.cuda.amp.autocast():
                            predictor.set_image(image)
                            predict_masks, predict_scores, predict_logits = predictor.predict(
                                point_coords=positive_point,
                                point_labels=positive_point_label,
                                multimask_output=False)

                        # Score EVERY validation image. This used to sit outside the
                        # per-image loop, so only the last image of each batch was scored.
                        val_dice_score = calculate_dice_score(mask, predict_masks[0])
                        val_dice_list.append(val_dice_score)

        val_f1_history.append(np.mean(val_dice_list))

        print(f"Validation F1 Score: {np.mean(val_dice_list):.4f}, Median : {np.median(val_dice_list):.4f}")

        # Restore the training configuration after validation:
        # decoder and prompt encoder in train mode, image encoder in eval mode.
        sam2_model.sam_mask_decoder.train(True)
        sam2_model.sam_prompt_encoder.train(True)
        sam2_model.image_encoder.train(False)
       
#🚀 Epoch 1/50 - Training: 100%|██████████████████████████████████████████████████████████████████████▉| 1242/1243 [08:28<00:00,  2.44batch/s]
#Epoch 1, Loss: 0.2385, Dice: 0.7631
#Epoch 1 Processing time: 508.72 seconds.
#🚀 Epoch 2/50 - Training: 100%|██████████████████████████████████████████████████████████████████████▉| 1242/1243 [06:24<00:00,  3.23batch/s]
#Epoch 2, Loss: 0.1881, Dice: 0.8135
#Epoch 2 Processing time: 384.01 seconds.
#🚀 Epoch 3/50 - Training: 100%|██████████████████████████████████████████████████████████████████████▉| 1242/1243 [06:23<00:00,  3.24batch/s]
#Epoch 3, Loss: 0.1735, Dice: 0.8281
#Epoch 3 Processing time: 383.79 seconds.

In [None]:
# Save the fine-tuned weights together with the dtype label of the first parameter
first_param = next(sam2_model.parameters())
if first_param.dtype == torch.float16:
    model_type = 'float16'
else:
    model_type = 'float32'

checkpoint_payload = {
    'model_state_dict': sam2_model.state_dict(),
    'model_type': model_type,
}
torch.save(checkpoint_payload, 'sam2_model.torch')

In [None]:
# Loss
# Epoch numbers are 1-based to match the "Epoch N" lines printed during training.
train_epochs = list(range(1, len(train_loss_history) + 1))

plt.figure(figsize=(10, 5))
plt.plot(train_epochs, train_loss_history, label='Training Loss', color='blue')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.legend()
plt.grid(True)
plt.show()

# Dice
# Validation runs only when `epoch % 5 == 0` or on the last epoch (see the training cell),
# so its scores must be plotted at those epochs instead of at consecutive indices.
val_epochs = [
    epoch_idx + 1
    for epoch_idx in range(len(train_dice_history))
    if epoch_idx % 5 == 0 or epoch_idx == num_epochs - 1
]
# zip() truncates to the shorter list, which keeps the plot valid after an interrupted run.
val_points = list(zip(val_epochs, val_f1_history))

plt.figure(figsize=(10, 5))
plt.plot(train_epochs, train_dice_history, label='Training Dice', color='green')
plt.plot(
    [epoch_number for epoch_number, _score in val_points],
    [score for _epoch_number, score in val_points],
    label='Validation F1 Score', color='orange', marker='o',
)
plt.xlabel('Epochs')
plt.ylabel('Dice / F1 Score')
plt.title('Training Dice and Validation F1 Score Over Epochs')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Load the fine-tuned weights into a freshly built model
# map_location="cpu" lets the checkpoint load regardless of the GPU it was saved
# from; load_state_dict then copies the values into the model, which was built on `device`.
checkpoint = torch.load('sam2_model.torch', map_location="cpu")
new_sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
new_sam2_model.load_state_dict(checkpoint['model_state_dict'])
new_predictor = SAM2ImagePredictor(new_sam2_model)
new_predictor.model.eval()
torch.cuda.empty_cache()

In [None]:
# To organize results in a directory named with the current timestamp
korea_timezone = pytz.timezone('Asia/Seoul')
korea_time = datetime.now(korea_timezone).strftime("%Y-%m-%d-%H:%M:%S ")
save_dir = f'results-local/positive_random/{korea_time} N'
# makedirs creates the missing parent folders too; os.mkdir raised when
# 'results-local/positive_random' did not exist yet.
os.makedirs(save_dir, exist_ok=True)

test_file_name = []
test_dice_scores = []

# Use tqdm with a dynamic update of the Dice score
with tqdm(test_dataloader, desc="🚀 Testing ", unit="batch", colour="blue") as tbar:
    for images, masks, mask_overlays, file_names in tbar:
        batch_size = images.shape[0]

        with torch.no_grad():
            for i in range(batch_size):
                image = images[i].cpu().numpy()
                mask = masks[i].cpu().numpy()
                mask_overlay = mask_overlays[i]
                file_name = file_names[i]
                # Collapse the mask to one channel and binarise it to {0, 255}.
                mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
                _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)

                # Two random positive prompt points; negative prompts are not used here
                # (negative_random_point can supply them if needed).
                positive_point = positive_random_point(mask, 2)
                positive_point_label = np.ones((len(positive_point),), dtype=np.uint8)

                with torch.cuda.amp.autocast():
                    new_predictor.set_image(image)
                    predict_masks, predict_scores, predict_logits = new_predictor.predict(
                        point_coords=positive_point,
                        point_labels=positive_point_label,
                        multimask_output=False)

                test_dice = calculate_dice_score(mask, predict_masks[0])
                # Save a 2x2 summary figure per test image under the original file name.
                show_masks_together(image, predict_masks, mask_overlay, test_dice, point_coords=positive_point, input_labels=positive_point_label, borders=True, save_path=f'{save_dir}/{file_name}')

                test_dice_scores.append(test_dice)
                test_file_name.append(file_name)

        # Running mean Dice over all images scored so far, shown in the progress bar
        mean_dice = np.mean(test_dice_scores)
        tbar.set_postfix(mean_dice=mean_dice)

In [None]:
# Collect the per-file test scores into a table
df = pd.DataFrame({
    'file_name': list(test_file_name),
    'fine_tune': list(test_dice_scores),
})

# Box plot of the scores with the mean drawn as a white circle
mean_marker_style = {
    'marker': 'o',
    'markerfacecolor': 'white',
    'markeredgecolor': 'black',
    'markersize': '5',
}
plt.figure(figsize=(5,7))
sns.boxplot(df['fine_tune'], showmeans=True, meanprops=mean_marker_style)

plt.title(f"f1-score comparison (mean: {np.mean(df['fine_tune']):.2f})")