In [3]:
import torch
import numpy as np
import cv2
from model.u2net import U2NET  

# Load the pretrained U2NET model.
# U2NET(3, 1): presumably 3 input channels / 1 output channel -- TODO confirm
# against model/u2net.py (the inference code below does feed a 3-channel image).
#
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the inference code in this notebook runs on the CPU.
# NOTE(security): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
# NOTE(review): another cell in this notebook also assigns the global name
# `model`; whichever cell ran last decides which network the functions see.
model = U2NET(3, 1)
model.load_state_dict(torch.load('u2net.pth', map_location='cpu'))
model.eval()  # inference mode: disables dropout / batch-norm statistic updates

def extract_main_object(image_path, input_size=320, center_fraction=0.5):
    """Return a binary mask of the largest salient object if it is centred.

    The image is read as grayscale (colour information is discarded), expanded
    to three identical channels, resized to ``input_size`` x ``input_size``,
    scaled to [0, 1] and passed through the global ``model``. The first element
    of the model output is used as the saliency map (presumably the fused
    U2NET output -- TODO confirm against the model definition).

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        input_size: Side length (pixels) of the square network input.
        center_fraction: Fraction of the image, centred, inside which the
            bounding-box centre of the largest contour must fall. The default
            0.5 accepts centres strictly between 25% and 75% of each axis.

    Returns:
        A ``uint8`` mask (values 0 or 255) with the original image height and
        width, or ``None`` when no contour is found or the largest contour is
        not centred.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``image_path``.
    """
    # Read and preprocess the image.
    gray_img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if gray_img is None:
        # cv2.imread reports a missing or undecodable file by returning None;
        # fail here with a clear message instead of a cryptic cvtColor error.
        raise OSError(f"cv2.imread could not read image: {image_path!r}")
    img_rgb = cv2.cvtColor(gray_img, cv2.COLOR_GRAY2RGB)
    img = cv2.resize(img_rgb, (input_size, input_size))
    img = img / 255.0
    # HWC -> CHW, then add a leading batch dimension.
    img = torch.tensor(img.transpose(2, 0, 1), dtype=torch.float32).unsqueeze(0)

    # Infer the saliency map.
    with torch.no_grad():
        saliency_map = model(img)[0].squeeze().cpu().numpy()

    # Binarise with Otsu's threshold. Clip first so that values outside [0, 1]
    # cannot wrap around in the uint8 cast.
    saliency_u8 = (np.clip(saliency_map, 0.0, 1.0) * 255).astype(np.uint8)
    _, mask = cv2.threshold(saliency_u8, 0, 255, cv2.THRESH_OTSU)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return None  # nothing salient was detected

    # Keep only the largest connected region and test whether it is centred.
    largest_contour = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest_contour)
    center_x, center_y = x + w // 2, y + h // 2

    margin = (1.0 - center_fraction) / 2.0
    lower, upper = margin * input_size, (1.0 - margin) * input_size
    if not (lower < center_x < upper and lower < center_y < upper):
        return None  # largest salient object is not in the central region

    final_mask = cv2.drawContours(np.zeros_like(mask), [largest_contour], -1, 255, -1)
    # Nearest-neighbour keeps the mask strictly binary; the default bilinear
    # interpolation would introduce intermediate grey values along the edges.
    return cv2.resize(
        final_mask,
        (gray_img.shape[1], gray_img.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

ModuleNotFoundError: No module named 'model'

In [2]:
import torchvision
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights

# Build Mask R-CNN (ResNet-50 FPN backbone) with the default pretrained weights
# and switch it to inference mode; nn.Module.eval() returns the module itself,
# so the call can be chained onto the constructor.
# NOTE(review): this rebinds the global name `model`, which another cell in
# this notebook also assigns.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
    weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT
).eval()

def extract_with_maskrcnn(image_path, target_classes=None, score_threshold=0.5):
    """Return the mask of the first confident, centred animal/plant detection.

    The image is read with OpenCV, converted from BGR to RGB, scaled to [0, 1]
    and passed to the global ``model`` as a single-image batch. Detections are
    visited in the order the model returns them; the first one that passes the
    score, class and centre tests wins.

    Args:
        image_path: Path of the image file to read with ``cv2.imread``.
        target_classes: Iterable of COCO category ids to accept. Defaults to
            the animal categories (16-25: bird ... giraffe) plus potted
            plant (64), using the id numbering of torchvision's COCO-trained
            detection models.
        score_threshold: A detection is considered only if its score is
            strictly greater than this value.

    Returns:
        A boolean array (mask probability > 0.5) with the image's height and
        width, or ``None`` if no detection qualifies.

    Raises:
        OSError: If ``cv2.imread`` cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread reports a missing or undecodable file by returning None.
        raise OSError(f"cv2.imread could not read image: {image_path!r}")
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # HWC uint8 -> CHW float32 in [0, 1], the input format of the detector.
    img_tensor = torch.tensor(img.transpose(2, 0, 1) / 255.0, dtype=torch.float32)

    with torch.no_grad():
        predictions = model([img_tensor])[0]

    if target_classes is None:
        # COCO ids: 1-15 are person/vehicles/street objects, 16-25 are the
        # animals, 63 is "couch" and 64 is "potted plant".
        target_classes = set(range(16, 26)) | {64}
    else:
        target_classes = set(target_classes)

    height, width = img.shape[:2]
    center = np.array([width // 2, height // 2])  # (x, y), same order as boxes
    # Use the spatial dimensions only; the channel axis must not take part.
    max_offset = max(height, width) // 4

    for score, label, mask, box in zip(predictions['scores'], predictions['labels'],
                                       predictions['masks'], predictions['boxes']):
        if score > score_threshold and label.item() in target_classes:
            # Boxes are (x1, y1, x2, y2); average the corners to get the centre.
            box_center = (box[:2] + box[2:]) / 2
            # Accept only detections whose centre is close to the image centre.
            if np.linalg.norm(box_center.cpu().numpy() - center) < max_offset:
                return mask[0].cpu().numpy() > 0.5
    return None

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /home/ubuntu/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
100%|██████████| 170M/170M [00:15<00:00, 11.7MB/s] 
