# 提取并保存DINO基础检测模型的特征图

In [1]:
import os
import numpy as np
import torch
import cv2
from hq_det.models.dino import hq_dino

In [2]:
model_path = "/root/autodl-tmp/seat_model/ckpt.pth.dino.hunhe"
imgs_dir = "/root/autodl-tmp/seat_dataset/chengdu_resplit-21/valid"
# os.listdir returns entries in arbitrary order and lists every entry in the
# folder, so keep only files with an image extension and sort the result to
# make `imgs[0]` reproducible across runs and machines.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")
imgs = sorted(
    os.path.join(imgs_dir, name)
    for name in os.listdir(imgs_dir)
    if name.lower().endswith(IMG_EXTENSIONS)
)

In [3]:
# Construct the HQDINO wrapper, passing the configured checkpoint path as `model`.
model = hq_dino.HQDINO(model=model_path)

627 627


In [None]:
def create_batch_data(img_path, model, max_size=1536):
    """Read an image from disk and turn it into a batch via ``model.imgs_to_batch``.

    The image is loaded with OpenCV, its channel order is reversed, and it is
    downscaled (aspect ratio preserved) when its longer side exceeds
    ``max_size``. Smaller images keep their original size.

    Args:
        img_path: Path of the image file to read.
        model: Object exposing ``imgs_to_batch(list_of_images)``; the batch
            format is whatever that method returns.
        max_size: Upper bound, in pixels, for the longer image side.

    Returns:
        Tuple ``(batch_data, img_scale)`` where ``img_scale`` is the resize
        factor that was applied (``1.0`` when the image was not resized).

    Raises:
        FileNotFoundError: If ``img_path`` is not an existing file.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising,
        # which would otherwise surface as an opaque error in cvtColor.
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image file not found: {img_path}")
        raise ValueError(f"Could not decode image file: {img_path}")

    # imread yields BGR channel order; reverse it. COLOR_BGR2RGB performs the
    # same channel swap as COLOR_RGB2BGR, it just names the direction correctly.
    # NOTE(review): assumes imgs_to_batch expects RGB input -- confirm.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    img_scale = 1.0
    height, width = img.shape[:2]
    max_hw = max(height, width)
    if max_hw > max_size:
        img_scale = max_size / max_hw
        # Clamp to 1 px so an extreme aspect ratio cannot truncate a side to 0,
        # which cv2.resize rejects.
        new_width = max(1, int(width * img_scale))
        new_height = max(1, int(height * img_scale))
        img = cv2.resize(img, (new_width, new_height))

    batch_data = model.imgs_to_batch([img])
    return batch_data, img_scale

def get_dino_object_detect_feature_map(img_path, model, max_size=1536):
    """Extract the detector's feature maps for a single image.

    The image is batched with ``create_batch_data``, passed through
    ``model.model.data_preprocessor`` and then through
    ``model.model.extract_feat``.

    Args:
        img_path: Path of the image file to process.
        model: Wrapper whose ``model`` attribute provides ``data_preprocessor``,
            ``extract_feat`` and ``training``, and which ``create_batch_data``
            can use to build the batch.
        max_size: Upper bound, in pixels, for the longer image side; forwarded
            to ``create_batch_data``.

    Returns:
        Whatever ``extract_feat`` returns for the preprocessed inputs. The
        tensors are produced under ``torch.no_grad()`` and therefore carry no
        autograd history.
    """
    # The resize factor is not needed for feature extraction.
    batch_data, _ = create_batch_data(img_path, model, max_size)
    detector = model.model
    # Inference only: without no_grad every returned feature map would keep
    # the whole backbone graph alive and waste memory.
    with torch.no_grad():
        batch_data.update(detector.data_preprocessor(batch_data, detector.training))
        feats = detector.extract_feat(batch_data['inputs'])
    return feats



In [5]:
# Use the first entry of the image list as the sample input.
img_path = imgs[0]

In [6]:
# Extract the feature maps for the sample image (default max_size of 1536).
img_feats = get_dino_object_detect_feature_map(img_path, model)

In [7]:
# Print the shape of each extracted feature map, one per line.
for feat in img_feats:
    print(feat.shape)

torch.Size([1, 256, 192, 192])
torch.Size([1, 256, 96, 96])
torch.Size([1, 256, 48, 48])
torch.Size([1, 256, 24, 24])


In [11]:
import numpy as np

def save_feats_to_npz(img_feats, filename="img_feats.npz"):
    """Save a list of feature maps into one ``.npz`` archive as float16.

    Args:
        img_feats: Iterable of feature maps. Each item may be a tensor-like
            object (anything with ``detach``/``cpu``/``numpy`` methods), a
            NumPy array, or any value ``np.asarray`` can convert.
        filename: Destination path (``str`` or path-like) or an open binary
            file object. Defaults to ``"img_feats.npz"``. As with ``np.savez``,
            ``.npz`` is appended to a path that does not already end with it.

    Returns:
        The path that was actually written (including the ``.npz`` suffix), or
        the file object itself when one was passed in.
    """
    def _as_float16(feat):
        # Tensors are detached from autograd and moved to host memory first.
        if hasattr(feat, 'detach'):
            feat = feat.detach()
        if hasattr(feat, 'cpu'):
            feat = feat.cpu().numpy()
        # Casting to float16 loses precision; values outside the float16
        # range become inf.
        return np.asarray(feat, dtype=np.float16)

    feat_arrays = [_as_float16(feat) for feat in img_feats]

    # np.savez silently appends ".npz" to paths lacking it; normalise up front
    # so the message and return value name the file that really gets written.
    if not hasattr(filename, 'write'):
        filename = os.fspath(filename)
        if not filename.endswith('.npz'):
            filename = filename + '.npz'

    # Arrays are stored positionally under the keys arr_0, arr_1, ...
    np.savez(filename, *feat_arrays)
    print(f"特征图已保存为 {filename}（float16）")
    return filename

# Usage example: write the extracted feature maps to img_feats.npz
save_feats_to_npz(img_feats, "img_feats.npz")


特征图已保存为 img_feats.npz（float16）
