This code is written with the help of the github reference provided in the assignment as it waas allowed in the comment section of GC.

In [5]:
import os
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from segment_anything import SamPredictor, SamAutomaticMaskGenerator
from segment_anything import sam_model_registry

from gradio_demo.Matcher import Matcher
from matcher.common import utils

from dinov2.models import vision_transformer as vits
import dinov2.utils.utils as dinov2_utils

In [6]:
img_dir = "Images"

def mask_creator(img_size):
    h, w = img_size
    mask_h = int(h *0.3)
    mask_w = int(w * 0.3)
    
    top = (h - mask_h) // 2
    left = (w - mask_w) // 2
    
    mask = np.zeros((h, w), dtype=np.uint8)
    
    mask[top:top+mask_h, left:left+mask_w] = 1
    
    return mask

In [None]:
def load_images(ref_path, target_path):
    transform = transforms.Compose([
        transforms.Resize(size=(518, 518)),
        transforms.ToTensor()
    ])
    
    ref_img = Image.open(ref_path).convert('RGB')
    target_img = Image.open(target_path).convert('RGB')
    
    ref_imgSize = (ref_img.size[1], ref_img.size[0])  
    target_imgSize = (target_img.size[1], target_img.size[0]) 
    
    ref_np = np.array(ref_img)
    
    mask = mask_creator(ref_np.shape[:2])
    
    ref_tensor = transform(ref_img)
    target_tensor = transform(target_img)
    
    ref_mask_t = torch.from_numpy(mask)[None, None, ...].float()
    ref_mask_t = F.interpolate(ref_mask_t, ref_tensor.size()[-2:], mode='nearest') > 0
    
    return {
        "support_img": ref_tensor[None, ...],
        "support_mask": ref_mask_t,
        "query_imgs": target_tensor[None, ...],
        "support_img_ori_size": ref_imgSize,
        "query_imgs_ori_size": [target_imgSize],
    }, ref_np, mask

In [None]:
def model_init(device="cuda:0" if torch.cuda.is_available() else "cpu"):
    sam_checkpoint = "models/sam_vit_h_4b8939.pth"
    model_type = "default"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    dinov2_kwargs = dict(
        img_size=518,
        patch_size=14,
        init_values=1e-5,
        ffn_layer='mlp',
        block_chunks=0,
        qkv_bias=True,
        proj_bias=True,
        ffn_bias=True,
    )
    dinov2 = vits.__dict__["vit_large"](**dinov2_kwargs)
    dinov2_utils.load_pretrained_weights(dinov2, "models/dinov2_vitl14_pretrain.pth", "teacher")
    dinov2.eval()
    dinov2.to(device=device)
    
    return sam, dinov2

In [None]:
def mask_process(data, predictor):
    sup_mask = data['support_mask'].squeeze().cpu().numpy()
    input_points = np.argwhere(sup_mask)  

    center_y, center_x = np.mean(input_points, axis=0)
    input_points = np.array([[center_x, center_y]])
    
    input_label = np.array([1] * len(input_points))
    
    support_img_np = data['support_img'].mul(255).byte()
    support_img_np = support_img_np.squeeze().permute(1,2,0).cpu().numpy()
    
    predictor.reset_image()
    predictor.set_image(support_img_np)
    
    masks, scores, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        multimask_output=True
    )
    predictor.reset_image()

    best_mask_idx = np.argmax(scores)
    data['support_mask'] = torch.tensor(masks[best_mask_idx:best_mask_idx+1])[None, ...]
    
    return data

In [None]:
def create_matcher(dinov2, generator, device):
    score_filter_cfg = {
        "emd": 0.0,
        "purity": 0.02,
        "coverage": 0.0,
        "score_filter": True,
        "score": 0.33,
        "score_norm": 0.1,
        "topk_scores_threshold": 0.0
    }
    
    matcher = Matcher(
        encoder=dinov2,
        generator=generator,
        num_centers=8,
        use_box=False,
        use_points_or_centers=True,
        sample_range=(1, 6),
        max_sample_iterations=64,
        alpha=1.0,
        beta=0.0,
        exp=0.0,
        score_filter_cfg=score_filter_cfg,
        num_merging_mask=9,
        device=device
    )
    
    return matcher
    
def single_shot(ref_path, target_path, output_dir="results"):
    os.makedirs(output_dir, exist_ok=True)
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    sam, dinov2 = model_init(device)
    
    predictor = SamPredictor(sam)
    generator = SamAutomaticMaskGenerator(
        sam,
        points_per_side=64,
        points_per_batch=64,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        stability_score_offset=1.0,
        sel_stability_score_thresh=0.90,
        sel_pred_iou_thresh=0.85,
        box_nms_thresh=0.65,
        sel_output_layer=3,
        output_layer=0,
        dense_pred=True,
        multimask_output=False,
        sel_multimask_output=True,
    )
    
    matcher = create_matcher(dinov2, generator, device)

    data, ref_np, initial_mask = load_images(ref_path, target_path)
    data = mask_process(data, predictor)
    
    with torch.no_grad():
        utils.fix_randseed(0)
        
        support_imgs = data["support_img"].to(device)[None, ...] 
        support_masks = data["support_mask"].to(device) 
        matcher.set_reference(support_imgs, support_masks)
        
        query_img = data["query_imgs"].to(device) 
        query_img_ori_size = data["query_imgs_ori_size"][0]
        matcher.set_target(query_img, query_img_ori_size)
        
        pred_mask, pred_mask_list = matcher.predict()
        matcher.clear()
        
        support_img_ori_size = data['support_img_ori_size']
        ref_mask = data['support_mask'].to(device).float()
        ref_mask = F.interpolate(ref_mask, support_img_ori_size, mode="bilinear", align_corners=False) > 0
        ref_mask = ref_mask.squeeze(0).cpu().numpy()
    
    target_img = Image.open(target_path).convert('RGB')
    target_np = np.array(target_img)
    
    if isinstance(pred_mask, torch.Tensor):
        pred_mask = pred_mask.cpu().numpy()
    pred_mask = np.squeeze(pred_mask)
    
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    
    axes[0, 0].imshow(ref_np)
    axes[0, 0].set_title("Reference Image")
    axes[0, 0].axis('off')
    
    axes[0, 1].imshow(ref_np)
    red_initial = np.zeros((*initial_mask.shape, 4))  
    red_initial[initial_mask > 0] = [1, 0, 0, 0.5]  
    axes[0, 1].imshow(red_initial)
    axes[0, 1].set_title("Initial Mask")
    axes[0, 1].axis('off')
    

    axes[1, 0].imshow(target_np)
    axes[1, 0].set_title("Target Image")
    axes[1, 0].axis('off')
    

    axes[1, 1].imshow(target_np)
    red_mask = np.zeros((*pred_mask.shape, 4))  
    red_mask[pred_mask > 0] = [1, 0, 0, 0.7]  
    axes[1, 1].imshow(red_mask)
    axes[1, 1].set_title("Predicted Mask")
    axes[1, 1].axis('off')
    
    plt.tight_layout()
    
    base_name = os.path.basename(ref_path).split('.')[0]
    output_path = os.path.join(output_dir, f"{base_name}_result.png")
    plt.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    
    return pred_mask

In [None]:
def image_straight(img_dir):
    for subfolder in os.listdir(img_dir):
        subfolder_path = os.path.join(img_dir, subfolder)
        if os.path.isdir(subfolder_path):
            image_files = [f for f in os.listdir(subfolder_path) if f.endswith(('.jpg'))]
            ref_path = os.path.join(subfolder_path, image_files[0])
            target_path = os.path.join(subfolder_path, image_files[1])
            print(f" {subfolder}: {ref_path} -> {target_path}")
            os.makedirs(f"results/{subfolder}", exist_ok=True)
            single_shot(ref_path, target_path, output_dir=f"results/{subfolder}")


image_straight(img_dir)

 backpack: Images/backpack/00.jpg -> Images/backpack/05.jpg


 backpack_dog: Images/backpack_dog/00.jpg -> Images/backpack_dog/01.jpg
 barn: Images/barn/00.jpg -> Images/barn/01.jpg
 bear_plushie: Images/bear_plushie/00.jpg -> Images/bear_plushie/01.jpg
 berry_bowl: Images/berry_bowl/00.jpg -> Images/berry_bowl/01.jpg
 can: Images/can/00.jpg -> Images/can/01.jpg
 candle: Images/candle/00.jpg -> Images/candle/01.jpg
 cat: Images/cat/00.jpg -> Images/cat/01.jpg
 cat2: Images/cat2/00.jpg -> Images/cat2/01.jpg
 cat_statue: Images/cat_statue/00.jpg -> Images/cat_statue/01.jpg
 chair: Images/chair/00.jpg -> Images/chair/01.jpg
 clock: Images/clock/00.jpg -> Images/clock/01.jpg
 colorful_sneaker: Images/colorful_sneaker/00.jpg -> Images/colorful_sneaker/01.jpg
 colorful_teapot: Images/colorful_teapot/00.jpg -> Images/colorful_teapot/01.jpg
 dog: Images/dog/00.jpg -> Images/dog/03.jpg
 dog2: Images/dog2/00.jpg -> Images/dog2/01.jpg
 dog3: Images/dog3/00.jpg -> Images/dog3/01.jpg


In [8]:
def image_reverse(img_dir):
    for subfolder in os.listdir(img_dir):
        subfolder_path = os.path.join(img_dir, subfolder)
        if os.path.isdir(subfolder_path):
            image_files = [f for f in os.listdir(subfolder_path) if f.endswith(('.jpg'))]
            ref_path = os.path.join(subfolder_path, image_files[0])
            target_path = os.path.join(subfolder_path, image_files[1])
            print(f"Processing {subfolder}: {target_path} -> {ref_path}")
            os.makedirs(f"results/{subfolder}", exist_ok=True)
            single_shot(target_path, ref_path, output_dir=f"results/{subfolder}")

image_reverse(img_dir)

Processing backpack: Images/backpack/05.jpg -> Images/backpack/00.jpg


Processing backpack_dog: Images/backpack_dog/01.jpg -> Images/backpack_dog/00.jpg
Processing barn: Images/barn/01.jpg -> Images/barn/00.jpg
Processing bear_plushie: Images/bear_plushie/01.jpg -> Images/bear_plushie/00.jpg
Processing berry_bowl: Images/berry_bowl/01.jpg -> Images/berry_bowl/00.jpg
Processing can: Images/can/01.jpg -> Images/can/00.jpg
Processing candle: Images/candle/01.jpg -> Images/candle/00.jpg
Processing cat: Images/cat/01.jpg -> Images/cat/00.jpg
Processing cat2: Images/cat2/01.jpg -> Images/cat2/00.jpg
Processing cat_statue: Images/cat_statue/01.jpg -> Images/cat_statue/00.jpg
Processing chair: Images/chair/01.jpg -> Images/chair/00.jpg
Processing clock: Images/clock/01.jpg -> Images/clock/00.jpg
Processing colorful_sneaker: Images/colorful_sneaker/01.jpg -> Images/colorful_sneaker/00.jpg
Processing colorful_teapot: Images/colorful_teapot/01.jpg -> Images/colorful_teapot/00.jpg
Processing dog: Images/dog/03.jpg -> Images/dog/00.jpg
Processing dog2: Images/dog2/01.