In [None]:
import os
from collections import defaultdict

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from custom_dataset import IND2CLASS
from SAM2UNet import SAM2UNet

# NOTE(review): later cells also use `tqdm`, `A`, `ToTensorV2` and `smp`, none of
# which are imported anywhere in this notebook -- presumably tqdm, albumentations
# and segmentation_models_pytorch. Confirm and add those imports here.

# --- [RLE encoding function] ---
def encode_mask_to_rle(mask):
    """Run-length encode a mask into a space-separated string.

    The mask is flattened, padded with a zero on both ends, and every position
    where the value changes is recorded (1-based). For a binary mask this
    yields alternating ``start length`` pairs, one pair per run of foreground
    pixels. An all-zero mask produces an empty string.

    Args:
        mask: array-like with a ``flatten()`` method (e.g. a numpy array).

    Returns:
        str: the encoded runs joined by single spaces.
    """
    # Zero padding guarantees that a run touching either border is still
    # opened and closed by a value change.
    padded = np.concatenate([[0], mask.flatten(), [0]])
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    # Every second boundary is a run end; turn it into a run length in place.
    np.subtract(boundaries[1::2], boundaries[::2], out=boundaries[1::2])
    return ' '.join(map(str, boundaries))

# --- [Dataset class] ---
class XRayInferenceDataset(Dataset):
    """Inference dataset yielding normalised X-ray images and their relative paths.

    Every ``.png`` file (case-insensitive extension) found below ``image_root``
    is listed in sorted order. Within each folder the first file in sorted
    order is labelled ``'Right'`` and the second ``'Left'``; any further files
    are left out of ``hand_side_map``. Images labelled ``'Right'`` are flipped
    horizontally when loaded, so callers must flip predictions back using
    ``hand_side_map``.

    Args:
        image_root: directory that is walked recursively for PNG files.
        transforms: optional callable invoked as ``transforms(image=image)``
            whose result is indexed with ``["image"]``.
    """

    def __init__(self, image_root, transforms=None):
        self.image_root = image_root
        self.transforms = transforms

        # 1. Collect the PNG paths, relative to image_root.
        pngs = {
            os.path.relpath(os.path.join(root, fname), start=image_root)
            for root, _dirs, files in os.walk(image_root)
            for fname in files
            if os.path.splitext(fname)[1].lower() == ".png"
        }
        self.filenames = np.array(sorted(pngs))

        # 2. Hand side (Left/Right) detection.
        # Assumption: with the file names of one folder sorted, the first file
        # is the right hand and the second is the left hand.
        self.hand_side_map = {}
        files_by_folder = defaultdict(list)
        for fname in self.filenames:
            files_by_folder[os.path.dirname(fname)].append(fname)

        for files in files_by_folder.values():
            files.sort()
            if len(files) > 0: self.hand_side_map[files[0]] = 'Right'
            if len(files) > 1: self.hand_side_map[files[1]] = 'Left'

    def __len__(self):
        """Return the number of PNG files that were found."""
        return len(self.filenames)

    def __getitem__(self, item):
        """Load one image.

        Returns:
            tuple: ``(image, image_name)`` where ``image_name`` is the path
            relative to ``image_root``. ``image`` is the output of
            ``transforms`` when that is not a numpy array; otherwise it is a
            float tensor transposed from (H, W, C) to (C, H, W).

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        image_name = self.filenames[item]
        image_path = os.path.join(self.image_root, image_name)

        # 1. Load the image as single-channel grayscale.
        image_gray = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image_gray is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the path rather than later with a cryptic TypeError.
            raise OSError(f"cv2.imread failed to read image: {image_path}")

        # 2. [Preprocessing] Mirror right hands so they look like left hands
        #    (the model was trained on left hands only).
        hand_side = self.hand_side_map.get(image_name, 'Unknown')
        if hand_side == 'Right':
            image_gray = cv2.flip(image_gray, 1)

        # 3. Replicate to 3 channels and scale by 1/255.
        image = np.stack([image_gray]*3, axis=-1)
        image = image / 255.0
        image = image.astype(np.float32)

        if self.transforms:
            image = self.transforms(image=image)["image"]

        # 4. To tensor: (H, W, C) -> (C, H, W), only if the transform did not
        #    already return something other than a numpy array.
        if isinstance(image, np.ndarray):
            image = image.transpose(2, 0, 1)
            image = torch.from_numpy(image).float()

        return image, image_name

In [None]:
# Locations of the test images, the model weights and the output CSV.
IMAGE_ROOT = "../data/test/DCM"
HIERA_PATH = "../checkpoints/sam2_hiera_large.pt"
CHECKPOINT_PATH = "../sam2_unet_result_checkpoints/experiment20.pth"
SAVE_PATH = "result_full_size.csv"

# The inference helpers in this notebook read only image_names[0] of a batch.
BATCH_SIZE = 1

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

In [None]:
def test_full_size(model, data_loader, thr=0.5):
    """Run whole-image inference and RLE-encode every class mask.

    Each batch is pushed through ``model`` in one pass. Outputs whose spatial
    size is not 2048x2048 are bilinearly resized to that size, passed through
    a sigmoid and thresholded. Predictions for images the dataset labels as
    ``'Right'`` are flipped back horizontally, because the dataset mirrors
    those inputs before inference.

    Args:
        model: network returning logits, or a tuple whose first item is the
            logits.
        data_loader: loader yielding ``(images, image_names)``; its
            ``dataset`` must expose ``hand_side_map``.
        thr: probability threshold for a foreground pixel.

    Returns:
        tuple: ``(rles, filename_and_class)``, two parallel lists holding the
        RLE string and the ``"<class>_<image_name>"`` key of every mask.
    """
    target_size = (2048, 2048)  # spatial size the masks are encoded at

    model.eval()
    rles = []
    filename_and_class = []
    dataset = data_loader.dataset

    print("üöÄ [Full Size] Inference Start...")

    with torch.no_grad():
        for images, image_names in tqdm(data_loader):
            images = images.to(DEVICE)

            # 1. Inference
            preds = model(images)
            if isinstance(preds, tuple): preds = preds[0]

            # 2. Size correction: check height as well as width.
            if tuple(preds.shape[-2:]) != target_size:
                preds = F.interpolate(preds, size=target_size, mode="bilinear", align_corners=False)

            probs = torch.sigmoid(preds)

            for prob, img_name in zip(probs, image_names):
                # 3. [Re-flip] Decide per sample, so a batch mixing left and
                #    right hands is not flipped according to its first image.
                if dataset.hand_side_map.get(img_name, 'Unknown') == 'Right':
                    prob = torch.flip(prob, dims=[-1])

                # 4. Thresholding
                output = (prob > thr).detach().cpu().numpy()

                # 5. RLE encoding, one entry per class channel
                for c, segm in enumerate(output):
                    rles.append(encode_mask_to_rle(segm))
                    filename_and_class.append(f"{IND2CLASS[c]}_{img_name}")

    return rles, filename_and_class

In [None]:
# Patch size and step used by the sliding-window pass.
WINDOW_SIZE = 1024
STRIDE = 512

def sliding_window_inference(model, image, window_size, stride, num_classes=29):
    """Average per-patch sigmoid outputs over overlapping windows.

    The image is tiled with ``window_size`` x ``window_size`` patches spaced
    ``stride`` apart; an extra patch flush with the bottom/right border is
    added when the regular grid does not reach it. Each pixel's result is the
    mean of the probabilities from all patches covering it.

    Args:
        model: network returning logits (or a tuple whose first item is the
            logits) with ``num_classes`` channels at the patch's resolution.
        image: tensor of shape (1, C, H, W).
        window_size: patch edge length; clamped to the image size when the
            image is smaller than the window.
        stride: step between patch origins; must be positive.
        num_classes: number of output channels.

    Returns:
        torch.Tensor: probabilities of shape (num_classes, H, W), on the same
        device as ``image``.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    model.eval()
    _, _, H, W = image.shape

    # Clamp the window so an image smaller than window_size is covered by a
    # single full-image patch instead of producing a negative start index
    # (which left uncovered pixels with count 0, i.e. NaN after division).
    win_h = min(window_size, H)
    win_w = min(window_size, W)

    # Allocate on the input's device so the buffers always match the patches.
    prob_map = torch.zeros((1, num_classes, H, W), device=image.device)
    # The coverage count is identical for every class; one channel broadcasts.
    count_map = torch.zeros((1, 1, H, W), device=image.device)

    # Patch origins, including a final one flush with the border.
    h_steps = list(range(0, H - win_h + 1, stride))
    w_steps = list(range(0, W - win_w + 1, stride))
    if h_steps[-1] != H - win_h: h_steps.append(H - win_h)
    if w_steps[-1] != W - win_w: w_steps.append(W - win_w)

    with torch.no_grad():
        for h in h_steps:
            for w in w_steps:
                img_patch = image[:, :, h:h+win_h, w:w+win_w]
                preds = model(img_patch)
                if isinstance(preds, tuple): preds = preds[0]

                prob_map[:, :, h:h+win_h, w:w+win_w] += torch.sigmoid(preds)
                count_map[:, :, h:h+win_h, w:w+win_w] += 1.0

    return (prob_map / count_map).squeeze(0) # (C, H, W)

def test_sliding(model, data_loader, thr=0.5):
    """Run sliding-window inference over a loader and RLE-encode the masks.

    Only the first image name of each batch is used, so the loader is
    expected to yield one image per batch. Images the dataset labels as
    ``'Right'`` have their probability maps flipped back horizontally before
    thresholding.

    Args:
        model: network handed to ``sliding_window_inference``.
        data_loader: loader yielding ``(images, image_names)``; its
            ``dataset`` must expose ``hand_side_map``.
        thr: probability threshold for a foreground pixel.

    Returns:
        tuple: ``(rles, filename_and_class)``, two parallel lists holding the
        RLE string and the ``"<class>_<image_name>"`` key of every mask.
    """
    rles, filename_and_class = [], []
    dataset = data_loader.dataset

    print("üöÄ [Sliding Window] Inference Start...")

    for images, image_names in tqdm(data_loader):
        images = images.to(DEVICE)
        image_name = image_names[0]

        # Averaged class probabilities, shape (C, H, W).
        probs = sliding_window_inference(model, images, WINDOW_SIZE, STRIDE)

        # Undo the horizontal mirroring applied to right-hand inputs.
        is_right_hand = dataset.hand_side_map.get(image_name, 'Unknown') == 'Right'
        if is_right_hand:
            probs = torch.flip(probs, dims=[-1])

        # Binarise and encode one mask per class channel.
        class_masks = (probs > thr).detach().cpu().numpy()
        for class_idx, class_mask in enumerate(class_masks):
            rles.append(encode_mask_to_rle(class_mask))
            filename_and_class.append(f"{IND2CLASS[class_idx]}_{image_name}")

    return rles, filename_and_class

In [None]:
# Ensemble members. Each entry names the architecture ("type"), the checkpoint
# to load ("path"), how the model is applied ("method": "sliding" or "full")
# and its weight in the soft vote.
MODEL_CONFIGS = [
    dict(
        name="sam_sliding",
        type="sam2unet",
        path="../sam2_unet_result_checkpoints/experiment16.pth",
        method="sliding",
        weight=0.5,
    ),
    dict(
        name="sam_full",
        type="sam2unet",
        path="../sam2_unet_result_checkpoints/experiment20.pth",
        method="full",
        weight=0.5,
    ),
]

# --- [Inference helper functions] ---
def sliding_window_inference(model, image, window_size=1024, stride=512, num_classes=29):
    """Average per-patch sigmoid outputs over overlapping windows.

    This redefines the helper of the same name from the earlier
    sliding-window cell, adding defaults for ``window_size`` and ``stride``;
    once this cell has run, this definition is the one in effect.

    The image is tiled with ``window_size`` x ``window_size`` patches spaced
    ``stride`` apart; an extra patch flush with the bottom/right border is
    added when the regular grid does not reach it. Each pixel's result is the
    mean of the probabilities from all patches covering it.

    Args:
        model: network returning logits (or a tuple whose first item is the
            logits) with ``num_classes`` channels at the patch's resolution.
        image: tensor of shape (1, C, H, W).
        window_size: patch edge length; clamped to the image size when the
            image is smaller than the window.
        stride: step between patch origins; must be positive.
        num_classes: number of output channels.

    Returns:
        torch.Tensor: probabilities of shape (num_classes, H, W), on the same
        device as ``image``.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    model.eval()
    _, _, H, W = image.shape

    # Clamp the window so an image smaller than window_size is covered by a
    # single full-image patch instead of producing a negative start index
    # (which left uncovered pixels with count 0, i.e. NaN after division).
    win_h = min(window_size, H)
    win_w = min(window_size, W)

    # Allocate on the input's device so the buffers always match the patches.
    prob_map = torch.zeros((1, num_classes, H, W), device=image.device)
    # The coverage count is identical for every class; one channel broadcasts.
    count_map = torch.zeros((1, 1, H, W), device=image.device)

    # Patch origins, including a final one flush with the border.
    h_steps = list(range(0, H - win_h + 1, stride))
    w_steps = list(range(0, W - win_w + 1, stride))
    if h_steps[-1] != H - win_h: h_steps.append(H - win_h)
    if w_steps[-1] != W - win_w: w_steps.append(W - win_w)

    with torch.no_grad():
        for h in h_steps:
            for w in w_steps:
                img_patch = image[:, :, h:h+win_h, w:w+win_w]
                preds = model(img_patch)
                if isinstance(preds, tuple): preds = preds[0]
                prob_map[:, :, h:h+win_h, w:w+win_w] += torch.sigmoid(preds)
                count_map[:, :, h:h+win_h, w:w+win_w] += 1.0

    return (prob_map / count_map).squeeze(0)

def ensemble_test(models_list, data_loader, thr=0.5):
    """Soft-voting ensemble: weighted mean of model probabilities, then RLE.

    For every image each active model contributes a probability map, obtained
    either by sliding-window inference (``method == 'sliding'``) or by a
    single full-image pass (any other method). The maps are combined as a
    weighted average normalised by the sum of the weights, flipped back for
    images the dataset labels as ``'Right'``, thresholded and RLE-encoded.

    Only the first image name of each batch is used and the accumulator is
    fixed at (29, 2048, 2048), so the loader is expected to yield one
    2048x2048 image per batch.

    Args:
        models_list: list of dicts with keys ``'weight'`` and ``'method'``
            and, for entries with a positive weight, ``'model'``. Entries
            whose weight is not positive are skipped without touching
            ``'model'``.
        data_loader: loader yielding ``(images, image_names)``; its
            ``dataset`` must expose ``hand_side_map``.
        thr: probability threshold for a foreground pixel.

    Returns:
        tuple: ``(rles, filename_and_class)``, two parallel lists holding the
        RLE string and the ``"<class>_<image_name>"`` key of every mask.
    """
    target_size = (2048, 2048)  # spatial size of the accumulated probabilities

    rles = []
    filename_and_class = []
    dataset = data_loader.dataset

    # Select the usable members once. The weight is checked BEFORE the model
    # is looked up, because an entry whose checkpoint failed to load is given
    # weight 0 and may have no 'model' key at all.
    active_configs = []
    for config in models_list:
        if config['weight'] <= 0:
            continue
        # The full-image path never switched the model to eval mode, so
        # train-mode layers (dropout, batch-norm updates) stayed active there.
        config['model'].eval()
        active_configs.append(config)

    print("üöÄ [Soft Voting] Ensemble Start...")

    with torch.no_grad():
        for images, image_names in tqdm(data_loader):
            images = images.to(DEVICE)
            image_name = image_names[0]

            final_prob = torch.zeros((29,) + target_size, device=DEVICE)
            total_weight = 0.0

            # --- Per-model inference and weighted accumulation ---
            for config in active_configs:
                model = config['model']
                weight = config['weight']
                method = config['method']

                if method == 'sliding':
                    prob = sliding_window_inference(model, images)
                else: # method == 'full'
                    preds = model(images)
                    if isinstance(preds, tuple): preds = preds[0]
                    # Size correction: check height as well as width.
                    if tuple(preds.shape[-2:]) != target_size:
                        preds = F.interpolate(preds, size=target_size, mode="bilinear", align_corners=False)
                    prob = torch.sigmoid(preds).squeeze(0)

                final_prob += prob * weight
                total_weight += weight

            # Normalise, in case the weights do not sum to 1.
            if total_weight > 0:
                final_prob /= total_weight

            # --- [Re-flip] Undo the mirroring applied to right-hand inputs ---
            hand_side = dataset.hand_side_map.get(image_name, 'Unknown')
            if hand_side == 'Right':
                final_prob = torch.flip(final_prob, dims=[-1])

            # --- Thresholding ---
            preds = (final_prob > thr).detach().cpu().numpy()

            for c, segm in enumerate(preds):
                rle = encode_mask_to_rle(segm)
                rles.append(rle)
                filename_and_class.append(f"{IND2CLASS[c]}_{image_name}")

    return rles, filename_and_class

In [None]:
if __name__ == '__main__':
    # 1. Build every ensemble member and load its checkpoint.
    #    A member that cannot be built or loaded is disabled (weight 0,
    #    model None) instead of aborting the whole run.
    for config in MODEL_CONFIGS:
        print(f"Loading {config['name']}...")
        if config['type'] == 'sam2unet':
            model = SAM2UNet(HIERA_PATH).to(DEVICE)
        elif config['type'] == 'smp_unetpp':
            model = smp.UnetPlusPlus(encoder_name="efficientnet-b3", classes=29).to(DEVICE)
        else:
            # Without this branch `model` was undefined, or silently reused
            # from the previous loop iteration.
            print(f"Unknown model type: {config['type']}")
            config['model'] = None
            config['weight'] = 0
            continue

        # Load the checkpoint and extract its state_dict.
        if os.path.exists(config['path']):
            # NOTE(review): torch.load unpickles the file, which can execute
            # arbitrary code; only load checkpoints from a trusted source.
            ckpt = torch.load(config['path'], map_location=DEVICE)
            if isinstance(ckpt, dict) and ('model' in ckpt or 'state_dict' in ckpt):
                state_dict = ckpt.get('model', ckpt.get('state_dict'))
            elif hasattr(ckpt, 'state_dict'): # a whole model object was saved
                state_dict = ckpt.state_dict()
            else:
                state_dict = ckpt

            # strict=False ignores key mismatches silently; report them so a
            # checkpoint that does not match the architecture is noticed.
            load_result = model.load_state_dict(state_dict, strict=False)
            if load_result.missing_keys or load_result.unexpected_keys:
                print(
                    f"Warning: {config['name']} loaded with "
                    f"{len(load_result.missing_keys)} missing and "
                    f"{len(load_result.unexpected_keys)} unexpected keys"
                )

            model.eval()
            config['model'] = model # keep the loaded model on the config
        else:
            print(f"‚ùå Path Not Found: {config['path']}")
            config['model'] = None # keep the key present for downstream lookups
            config['weight'] = 0 # a failed load gets weight 0

    # 2. Dataset
    tf = A.Compose([ToTensorV2()])
    test_ds = XRayInferenceDataset(IMAGE_ROOT, transforms=tf)
    # ensemble_test reads only image_names[0] of a batch, so keep batch_size=1.
    test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4)

    # 3. Run
    rles, fnames = ensemble_test(MODEL_CONFIGS, test_loader)

    # Split on the first underscore only, so image paths that contain "_"
    # do not break the unpacking. An empty result yields empty columns.
    if fnames:
        classes, filename = zip(*[x.split("_", 1) for x in fnames])
    else:
        classes, filename = (), ()
    df = pd.DataFrame({"image_name": [os.path.basename(f) for f in filename], "class": classes, "rle": rles})
    df.to_csv(SAVE_PATH, index=False)
    print(f"‚úÖ Ensemble Result Saved to {SAVE_PATH}")