In [None]:
# @model
import json, os, random
import numpy as np

import torch
import torch.nn as nn

from model import Mask2Former
from lib.config_helper import merge


# Checkpoint file whose weights are restored further down in this cell.
ckpt_path = "./transnext_mask2former_checkpoints/transnext_mask2former_cls_T512_pix_add1_ema.pth"

# Base experiment configuration.
with open("config.json", "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

# When the configured structure is Mask2Former, fold its dedicated config
# file into the base configuration using the project's `merge` helper.
uses_m2f_structure = config["Model"].get("structure") == "mask2former"
if uses_m2f_structure:
    with open("./model/mask2former/m2f_config.json", "r", encoding="utf-8") as m2f_file:
        m2f_config = json.load(m2f_file)
    config = merge(config, m2f_config)
    
# Seed every RNG used in this notebook (stdlib `random`, PyTorch, NumPy)
# with the same fixed value.
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(0)
# torch.backends.cudnn.benchmark = True

# Set device: use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    device_name = torch.cuda.get_device_name(device)
else:
    device = torch.device("cpu")
    device_name = "CPU"
print(f"Device: {device_name}")

# Derive a display name for the model from the configured encoder backbone.
# Only configurations that contain a Mask2Former decoder section are supported.
if "mask2former_decoder" not in config["Model"]:
    raise NotImplementedError("config['Model'] has no 'mask2former_decoder' section")

# Supported encoder backbones and the model name printed for each of them.
MODEL_NAME_BY_ENCODER = {
    "TransNeXt-Tiny": "TransNeXt-Mask2Former",
    "Swin-Small": "SwinUNet-Mask2Former",
    "Swin-SMT": "SwinSMT-Mask2Former",
}

encoder_type = config["Model"]["unet_backbone"]["encoder_type"]
if encoder_type not in MODEL_NAME_BY_ENCODER:
    # Previously an unknown encoder left `model_name` unbound, so the print
    # below failed with a NameError; report the real cause instead.
    raise NotImplementedError(f"Unsupported encoder_type: {encoder_type!r}")
model_name = MODEL_NAME_BY_ENCODER[encoder_type]
print(model_name)

# --------------------------------------------------------------------------
# Model Configuration & Initialization
# --------------------------------------------------------------------------
# Flags set alongside the model; `det_head` is later passed to `inference`.
det_head = False
anchors_pos = None

# Merge the Mask2Former-specific config into the main config.
# NOTE(review): the same file is already merged earlier in this cell when
# config["Model"]["structure"] == "mask2former" — confirm that merging it a
# second time is harmless.
with open("./model/mask2former/m2f_config.json", "r", encoding="utf-8") as m2f_file:
    m2f_config = json.load(m2f_file)
config = merge(config, m2f_config)

# Create the model from the merged configuration.
model = Mask2Former(config)
# model.load_pretrained_weight()

# --------------------------------------------------------------------------
# Read ckpt
# --------------------------------------------------------------------------
# Map all tensors to the CPU while loading: `model` has not been moved to any
# device in this cell, and without `map_location` a checkpoint saved from a
# CUDA device cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
ckpt = torch.load(ckpt_path, map_location="cpu")
if "n_averaged" in ckpt:
    # A checkpoint containing "n_averaged" is treated as one saved from an
    # AveragedModel (EMA) wrapper, whose key names differ slightly from the
    # bare model's; wrap the model the same way so the keys match.
    ema_model = torch.optim.swa_utils.AveragedModel(
        model,
        multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(config["Validation"]["ema"]),
    )
    model = ema_model
# strict=True: fail if any key is missing from, or unexpected in, the checkpoint.
model.load_state_dict(ckpt, strict=True)
print(f"Ckpt loaded {ckpt_path}")

In [None]:
## Folders to run inference on; they are processed in the order listed here.
folder_name_list = [
    "Sonosite20241008_1401_474frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_1207_158frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_1301_360frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_1334_212frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_1146_1066frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20240924_1205_378frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_1148_1070frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_1144_1056frams_abc(待茹瑄_似蓉_承原)",
    "Sonosite20241008_0946_910frams_abc(待茹瑄_似蓉_承原)",
]

In [None]:
from torch.utils.data import DataLoader, ConcatDataset
from dataset import UnlabeledDataset, Augmentation
from evaluation_v2 import inference

# Root directory that contains one sub-folder per entry of `folder_name_list`.
DATA_ROOT = "C:/Dropbox/家庭資料室"

# Transform with every augmentation switch turned off and image_size=512.
valid_transform = Augmentation(crop=False, rotate=False, color_jitter=False, horizontal_flip=False, image_size=512)

# Run inference folder by folder; each folder gets its own dataset and loader.
for folder_name in folder_name_list:

    medium_test_dataset = UnlabeledDataset(
        f"{DATA_ROOT}/{folder_name}",
        transform=valid_transform,
        time_window=3,
        buffer_num_sample=1,
        line_width=20,
        det_num_classes=1,
    )
    num_samples = len(medium_test_dataset)
    print(num_samples)
    if num_samples == 0:
        # An empty dataset would make `next(iter(loader))` below raise
        # StopIteration and abort all remaining folders; skip it instead.
        print(f"Skipping {folder_name}: dataset is empty")
        continue

    ## Test batch size 1
    loader = DataLoader(
        medium_test_dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        num_workers=4,
        persistent_workers=True,
        pin_memory=True,
    )

    # Peek at the first batch as a quick sanity check of what the loader yields.
    sample = next(iter(loader))
    print(sample["images"].shape)
    print(sample["img_path"][2])

    ## similar to evaluate function but writes points to json file
    # NOTE(review): presumably `inference` moves the model to `device` and puts
    # it in eval mode — neither happens in this notebook; confirm.
    inference(model, device, loader, det_head, save_json=True)
    # break