## 1. 공통 모듈 및 평가 함수 정의

In [None]:
import torch
from torch.utils.data import DataLoader
import os
from datetime import datetime

from isegm.model.metrics import PerClassIoU, MultiClassIoU
from isegm.model.is_trimap_plaintvit_model import TrimapPlainVitModel
from tqdm import tqdm

def evaluate_model(model, dataset, dataset_name, device, log_file=None):
    """Evaluate a trimap model on one dataset and report per-class / mean IoU.

    The model is put in eval mode and run under ``torch.no_grad()`` over
    ``dataset`` one sample at a time (batch size 1, no shuffling). IoU is
    accumulated with ``PerClassIoU`` and ``MultiClassIoU``. Results are always
    printed and, when ``log_file`` is given, appended to that file as a small
    tab-separated table.

    Args:
        model: Callable mapping an image batch to predictions accepted by the
            IoU metrics. It is switched to eval mode but NOT moved to
            ``device``; the caller must place it there.
        dataset: Dataset whose samples are mappings with at least 'image' and
            'trimap' entries (presumably tensors -- they are moved with
            ``.to(device)``).
        dataset_name: Label used for the progress bar, console output and log.
        device: Device the image and trimap label are moved to.
        log_file: Optional path of a text file to append results to. When
            ``None`` (the default) nothing is written to disk.

    Returns:
        Tuple ``(per_class, mean)`` holding the values returned by the metrics'
        ``get_epoch_value()``. ``per_class`` is indexed with the keys 'bg',
        'unknown' and 'fg'.
    """
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    num_samples = len(dataset)

    per_class_iou = PerClassIoU()
    mean_iou = MultiClassIoU()

    model.eval()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for sample in tqdm(loader, desc=f"Evaluating {dataset_name}"):
            image = sample['image'].to(device)
            trimap_label = sample['trimap'].to(device)

            output = model(image)
            per_class_iou.update(output, trimap_label)
            mean_iou.update(output, trimap_label)

    per_cls = per_class_iou.get_epoch_value()
    mean = mean_iou.get_epoch_value()

    print(f"\n=== {dataset_name} ===")
    print("Dataset Size:", num_samples)
    print("Per-Class IoU:", per_cls)
    print("Mean IoU:", mean)

    # File logging is optional so the function can also be used for quick,
    # console-only evaluation.
    if log_file is not None:
        with open(log_file, 'a', encoding='utf-8') as f:
            f.write(f"Dataset: {dataset_name}\n")
            f.write(f"Samples: {num_samples}\n")
            f.write("BG IoU\tUnknown IoU\tFG IoU\tMean IoU\n")
            f.write(
                f"{per_cls['bg']:.4f}\t{per_cls['unknown']:.4f}\t"
                f"{per_cls['fg']:.4f}\t{mean:.4f}\n"
            )
            f.write("\n")

    # Clear the accumulated statistics (the metric objects are local to this call).
    per_class_iou.reset_epoch_stats()
    mean_iou.reset_epoch_stats()

    return per_cls, mean

ModuleNotFoundError: No module named 'isegm'

## 모델 정의

In [None]:
def build_model(cfg):
    """Create the trimap segmentation model and load its pretrained backbone.

    The backbone is a plain ViT with 14x14 patches, width 1280, 32 blocks and
    16 attention heads (ViT-H/14-sized). It is initialised from
    ``cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE``. A neck maps the backbone
    features to four levels (240/480/960/1920 channels), and a 3-class head
    decodes them. The head's channel width is picked from ``cfg.upsample``.

    Args:
        cfg: Config object. Reads ``INPUT_SIZE``, ``upsample`` (one of 'x1',
            'x2', 'x4'; any other value raises ``KeyError``), ``random_split``,
            ``IMAGENET_PRETRAINED_MODELS.MAE_HUGE`` and ``device``.

    Returns:
        The ``TrimapPlainVitModel`` instance, after ``.to(cfg.device)`` has
        been called on it.
    """
    side = cfg.INPUT_SIZE
    vit_width = 1280
    level_dims = (240, 480, 960, 1920)
    # Head channel width per upsampling mode.
    head_width_for = {'x1': 256, 'x2': 128, 'x4': 64}

    backbone_params = {
        'img_size': (side, side),
        'patch_size': (14, 14),
        'in_chans': 3,
        'embed_dim': vit_width,
        'depth': 32,
        'num_heads': 16,
        'mlp_ratio': 4,
        'qkv_bias': True,
    }

    # Fresh lists for neck and head so neither shares a mutable object.
    neck_params = {
        'in_dim': vit_width,
        'out_dims': list(level_dims),
    }

    upsample_mode = cfg.upsample
    head_params = {
        'in_channels': list(level_dims),
        'in_index': list(range(len(level_dims))),
        'dropout_ratio': 0.1,
        'num_classes': 3,
        'loss_decode': torch.nn.CrossEntropyLoss(),
        'align_corners': False,
        'upsample': upsample_mode,
        'channels': head_width_for[upsample_mode],
    }

    net = TrimapPlainVitModel(
        use_disks=True,
        norm_radius=5,
        with_prev_mask=True,
        backbone_params=backbone_params,
        neck_params=neck_params,
        head_params=head_params,
        random_split=cfg.random_split,
    )

    mae_checkpoint = cfg.IMAGENET_PRETRAINED_MODELS.MAE_HUGE
    net.backbone.init_weights_from_pretrained(mae_checkpoint)
    net.to(cfg.device)
    return net

## Dataset 생성

In [None]:
from isegm.data.datasets.composition import COMPOSITIONTrimapDataset
from isegm.data.datasets.p3m10k import P3M10KTrimapDataset
from isegm.data.datasets.aim500 import AIM500TrimapDataset
from isegm.data.datasets.am200 import AM200TrimapDataset

# Build the four evaluation datasets. Each constructor receives a relative path
# (resolved against the current working directory) to that benchmark's data.
# NOTE(review): the root prefix is inconsistent -- Composition-1k and P3M use
# './dataset/...' while AIM-500 and AM-200 use './datasets/...'. Composition-1k
# and AIM-500 both sit under 'Seg2TrimapDataset', so at least one prefix is
# probably a typo; confirm against the on-disk layout.
evaldataset1 = COMPOSITIONTrimapDataset('./dataset/Seg2TrimapDataset/Composition-1k-testset')
# The path selects the 'P3M-500-NP' validation split; make sure any label used
# when reporting results for this dataset matches that split.
evaldataset2 = P3M10KTrimapDataset('./dataset/2.P3M-10k/P3M-10k/validation/P3M-500-NP')
evaldataset3 = AIM500TrimapDataset('./datasets/Seg2TrimapDataset/AIM-500') 
evaldataset4 = AM200TrimapDataset('./datasets/Seg2TrimapDataset/AM-200')

In [None]:
## Evaluation

In [None]:
device = "cpu"

# NOTE(review): `model` is not created in the visible cells; if no earlier cell
# defines it, build it first, e.g. `model = build_model(cfg)` with a config
# providing INPUT_SIZE, upsample, random_split, IMAGENET_PRETRAINED_MODELS
# and device.
# build_model moves the model to `cfg.device`. Move it to the device the
# inputs are sent to, so the forward pass does not hit a device mismatch.
model.to(device)

# One timestamped log file per run; evaluate_model appends each dataset's
# results to it.
log_dir = "./eval_logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"trimap_eval_{datetime.now():%Y%m%d_%H%M%S}.txt")

eval_jobs = [
    (evaldataset1, 'Composition-1K'),
    # evaldataset2 is loaded from the .../validation/P3M-500-NP directory.
    (evaldataset2, 'P3M-500-NP'),
    (evaldataset3, 'AIM-500'),
    (evaldataset4, 'AM-200'),
]
for eval_dataset, eval_name in eval_jobs:
    evaluate_model(model, eval_dataset, eval_name, device, log_file)
