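# Inference with OneFormer: Universal Image Segmentation
Original paper: https://arxiv.org/abs/2211.06220
OneFormer integrates a text module into the Mask2Former framework to condition the model on the subtask at hand (instance, semantic, or panoptic segmentation). This yields more accurate results, at the cost of increased latency.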

## Environment setup
MindSpore 2.5.0

MindNLP   0.4.0

Python    3.9.0
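
Before running the notebook, it can help to confirm that the installed packages match the versions listed above. The snippet below is a minimal sketch using only the standard library. It assumes the pip distribution names are `mindspore` and `mindnlp`, and `installed_versions` is a hypothetical helper introduced here for illustration.

```python
from importlib import metadata

# Versions listed in this section; the pip distribution names are an assumption.
EXPECTED = {"mindspore": "2.5.0", "mindnlp": "0.4.0"}

def installed_versions(packages):
    """Return {package: installed version, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found

versions = installed_versions(EXPECTED)
for name, expected in EXPECTED.items():
    status = "ok" if versions[name] == expected else "differs"
    print(f"{name}: expected {expected}, found {versions[name]} ({status})")
```

A mismatch does not necessarily mean the notebook will fail; it only signals that results may differ from those described here.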

## Load the image

Next, we load the image we want to run inference on. Here we use the familiar cats image, which is part of the COCO dataset.

In [None]:
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

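## Prepare the image for the model

We can prepare the image with the processor. OneFormer's processor internally consists of an image processor (for the image modality) and a tokenizer (for the text modality). OneFormer is effectively a multimodal model, since it combines image and text inputs to solve image segmentation.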

In [None]:
from mindnlp.transformers import AutoProcessor

# the Auto API loads a OneFormerProcessor for us, based on the checkpoint
processor = AutoProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")

In [None]:
# prepare image for the model
panoptic_inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="ms")
for k,v in panoptic_inputs.items():
  print(k,v.shape)

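As you can see, this model takes an additional `task_inputs` entry, which MaskFormer and Mask2Former do not have. These text inputs let the model distinguish between instance, semantic, and panoptic segmentation.

We can decode the task inputs back to text: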

In [None]:
processor.tokenizer.batch_decode(panoptic_inputs.task_inputs)
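
As a rough illustration of what the text branch receives: this notebook later describes the text input for semantic segmentation as the sentence "the task is semantic". The sketch below builds that sentence for each task in plain Python. The helper `build_task_prompt` is hypothetical, and the template is inferred from that description rather than taken from the processor's source; the real processor additionally tokenizes the sentence.

```python
VALID_TASKS = ("panoptic", "semantic", "instance")

def build_task_prompt(task):
    """Build the sentence for one task name (the template is an assumption)."""
    if task not in VALID_TASKS:
        raise ValueError(f"unknown task {task!r}; expected one of {VALID_TASKS}")
    return f"the task is {task}"

prompts = [build_task_prompt(task) for task in VALID_TASKS]
print(prompts)
```

Comparing these strings with the decoded output of the cell above is a quick way to check whether the assumed template matches what the processor actually produces.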

## Load the model

Next, let's load a model from mindnlp.transformers. Here we load a OneFormer model with a Swin-Large backbone, trained on the COCO dataset.

In [8]:
from mindnlp.transformers import AutoModelForUniversalSegmentation

model = AutoModelForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large")

## Forward pass

A forward pass in mindnlp looks like this:

In [None]:
from mindnlp.core import no_grad

# forward pass
with no_grad():
  outputs = model(**panoptic_inputs)

## Visualization

Next, we post-process the raw outputs and visualize the predictions.

In [None]:
panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
print(panoptic_segmentation.keys())
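
The drawing function used later in this notebook reads two things from this result: a 2-D `segmentation` map of segment IDs and a `segments_info` list whose entries carry `id` and `label_id`. The sketch below summarizes such a result on made-up toy data. `summarize_segments`, `toy_result`, and `toy_id2label` are hypothetical names, and the label IDs are not the real COCO IDs.

```python
from collections import Counter

# Toy stand-in for one post-processed result (all values are made up).
toy_result = {
    "segmentation": [
        [1, 1, 2, 2],
        [1, 1, 2, 2],
        [3, 3, 3, 4],
    ],
    "segments_info": [
        {"id": 1, "label_id": 0},
        {"id": 2, "label_id": 0},
        {"id": 3, "label_id": 1},
        {"id": 4, "label_id": 2},
    ],
}
toy_id2label = {0: "cat", 1: "couch", 2: "remote"}  # hypothetical mapping

def summarize_segments(result, id2label):
    """Return (segment id, label name, pixel count) for every segment."""
    pixel_counts = Counter(v for row in result["segmentation"] for v in row)
    return [
        (seg["id"], id2label[seg["label_id"]], pixel_counts[seg["id"]])
        for seg in result["segments_info"]
    ]

summary = summarize_segments(toy_result, toy_id2label)
for segment_id, label, area in summary:
    print(f"segment {segment_id}: {label}, {area} pixels")
```

To try this on the real output, the segmentation tensor would first need to be converted to a nested list, for example with `.asnumpy().tolist()`, and `model.config.id2label` would replace the toy mapping.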

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from mindspore import Tensor

def draw_panoptic_segmentation(segmentation, segments_info):
    # Convert a MindSpore tensor (or any array-like) to a NumPy integer array
    if isinstance(segmentation, Tensor):
        segmentation_np = segmentation.asnumpy()
    else:
        segmentation_np = np.array(segmentation)
    if not np.issubdtype(segmentation_np.dtype, np.integer):
        segmentation_np = segmentation_np.astype(np.int32)

    # One discrete color per segment ID, from 0 up to the largest ID
    max_segment = int(np.max(segmentation_np))
    viridis = plt.get_cmap('viridis', max_segment + 1)

    fig, ax = plt.subplots()
    # Fix the color range so the image colors match the legend patches
    ax.imshow(segmentation_np, cmap=viridis, vmin=0, vmax=max_segment)

    instances_counter = defaultdict(int)
    handles = []

    # Add one legend entry per segment, numbered per class (e.g. "cat-0", "cat-1")
    for segment in segments_info:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = model.config.id2label[segment_label_id]
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        color = viridis(segment_id)
        handles.append(mpatches.Patch(color=color, label=label))

    ax.legend(handles=handles)
    plt.savefig('cats_panoptic.png')

draw_panoptic_segmentation(**panoptic_segmentation)


As you can see, the model correctly distinguishes the two cats from each other, as well as the two remote controls.

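## Inference: semantic segmentation
We can also use the same model to perform semantic segmentation on the cats image! We only need to change the task input (the model's text input) to "the task is semantic".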

In [None]:
# prepare image for the model
semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="ms")
for k,v in semantic_inputs.items():
  print(k,v.shape)

In [13]:
# forward pass
with no_grad():
  outputs = model(**semantic_inputs)

Let's post-process the results and visualize them:

In [None]:
semantic_segmentation = processor.post_process_semantic_segmentation(outputs)[0]
semantic_segmentation.shape
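
For intuition about what semantic post-processing has to do, here is a NumPy sketch of one common way mask-classification models such as MaskFormer turn per-query predictions into a label map: weight each query's mask probability by its class probability, sum over queries, and take the per-pixel argmax. The function name, array shapes, and the trailing "no object" class are assumptions for this illustration; this is not mindnlp's implementation.

```python
import numpy as np

def semantic_map_from_queries(class_logits, mask_logits):
    """Combine per-query class and mask logits into an (H, W) label map.

    class_logits: (num_queries, num_classes + 1), last column = "no object"
    mask_logits:  (num_queries, H, W)
    """
    # Softmax over classes, then drop the "no object" column.
    shifted = class_logits - class_logits.max(axis=-1, keepdims=True)
    class_probs = np.exp(shifted)
    class_probs /= class_probs.sum(axis=-1, keepdims=True)
    class_probs = class_probs[:, :-1]

    # Sigmoid turns mask logits into per-pixel probabilities.
    mask_probs = 1.0 / (1.0 + np.exp(-mask_logits))

    # Score per class and pixel: sum over queries of class prob * mask prob.
    scores = np.einsum("qc,qhw->chw", class_probs, mask_probs)
    return scores.argmax(axis=0)

# Toy example: 2 queries, 2 real classes, a 2x2 image.
class_logits = np.array([[5.0, 0.0, 0.0],   # query 0 votes for class 0
                         [0.0, 5.0, 0.0]])  # query 1 votes for class 1
mask_logits = np.array([[[4.0, 4.0], [-4.0, -4.0]],   # query 0 covers the top row
                        [[-4.0, -4.0], [4.0, 4.0]]])  # query 1 covers the bottom row
label_map = semantic_map_from_queries(class_logits, mask_logits)
print(label_map)
```

In the toy example the top row ends up labeled with class 0 and the bottom row with class 1, because each query is confident about one class and one half of the image.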

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np


def draw_semantic_segmentation(segmentation):
    # Convert a MindSpore tensor (or any array-like) to a NumPy integer array
    if hasattr(segmentation, 'asnumpy'):
        segmentation = segmentation.asnumpy()
    segmentation = np.array(segmentation).astype(np.int32)

    # One discrete color per label ID, from 0 up to the largest ID present
    max_label = int(np.max(segmentation))
    viridis = plt.get_cmap('viridis', max_label + 1)

    labels_ids = np.unique(segmentation).tolist()

    fig, ax = plt.subplots()
    # Fix the color range so the image colors match the legend patches
    ax.imshow(segmentation, cmap=viridis, vmin=0, vmax=max_label)
    handles = []

    # Add one legend entry per class present in the image
    for label_id in labels_ids:
        label = model.config.id2label[label_id]
        color = viridis(label_id)
        handles.append(mpatches.Patch(color=color, label=label))

    ax.legend(handles=handles)

draw_semantic_segmentation(semantic_segmentation)

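As you can see, semantic segmentation does not distinguish individual instances (countable things such as cats or remote controls). Instead, a single mask is produced per category, for example one mask for the "cat" class.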
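
The relationship between the two kinds of output can be made concrete with a toy example: replacing every segment ID in a panoptic map with its class label merges all instances of a class into one region, which is what a semantic map looks like. The data and the helper `panoptic_to_semantic` below are hypothetical, the label IDs are made up, and `background` is an assumed fill value for pixels that belong to no segment.

```python
def panoptic_to_semantic(segmentation, segments_info, background=-1):
    """Replace each segment ID with its label ID, merging instances per class."""
    id_to_label = {seg["id"]: seg["label_id"] for seg in segments_info}
    return [
        [id_to_label.get(segment_id, background) for segment_id in row]
        for row in segmentation
    ]

# Segments 1 and 2 are two instances of the same (made-up) class 15.
toy_segmentation = [
    [1, 1, 2, 2],
    [3, 3, 3, 0],
]
toy_segments_info = [
    {"id": 1, "label_id": 15},
    {"id": 2, "label_id": 15},
    {"id": 3, "label_id": 57},
]
semantic_map = panoptic_to_semantic(toy_segmentation, toy_segments_info)
print(semantic_map)
```

After the mapping, the two instances in the first row are indistinguishable, since both carry label 15.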

## Inference: instance segmentation

Similarly, we can use the same model for instance segmentation; we only need to change the text input.

In [None]:
# prepare image for the model
instance_inputs = processor(images=image, task_inputs=["instance"], return_tensors="ms")
for k,v in instance_inputs.items():
  print(k,v.shape)

In [17]:
# forward pass
with no_grad():
  outputs = model(**instance_inputs)

Let's post-process the results and visualize them:

In [None]:
instance_segmentation = processor.post_process_instance_segmentation(outputs)[0]
instance_segmentation.keys()
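
Instance predictions are often filtered by confidence before drawing. The sketch below assumes each `segments_info` entry carries a `score` field in addition to `id` and `label_id`. The notebook does not print the entries, so it is worth checking `segments_info[0].keys()` first. `filter_segments` and the 0.5 cutoff are illustrative choices, not part of the library.

```python
def filter_segments(segments_info, min_score=0.5):
    """Keep segments whose (assumed) 'score' field is at least min_score."""
    return [seg for seg in segments_info if seg.get("score", 0.0) >= min_score]

# Toy entries with made-up IDs and scores.
toy_segments_info = [
    {"id": 1, "label_id": 15, "score": 0.97},
    {"id": 2, "label_id": 15, "score": 0.31},
    {"id": 3, "label_id": 65, "score": 0.88},
]
kept = filter_segments(toy_segments_info)
print([seg["id"] for seg in kept])
```

Entries without a `score` field are treated as having score 0 and are dropped, which is a conservative default for this sketch.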

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np

def draw_instance_segmentation(segmentation, segments_info):
    # Convert a MindSpore tensor (or any array-like) to a NumPy int32 array
    if hasattr(segmentation, 'asnumpy'):
        segmentation = segmentation.asnumpy()
    segmentation = np.array(segmentation, dtype=np.int32)

    # One discrete color per segment ID, from 0 up to the largest ID
    max_segment_id = int(np.max(segmentation))
    viridis = plt.get_cmap('viridis', max_segment_id + 1)

    fig, ax = plt.subplots()
    # Fix the color range so the image colors match the legend patches
    ax.imshow(segmentation, cmap=viridis, vmin=0, vmax=max_segment_id)

    instances_counter = defaultdict(int)
    handles = []
    # Add one legend entry per instance, numbered per class (e.g. "cat-0", "cat-1")
    for segment in segments_info:
        segment_id = segment['id']
        segment_label_id = segment['label_id']
        segment_label = model.config.id2label[segment_label_id]
        label = f"{segment_label}-{instances_counter[segment_label_id]}"
        instances_counter[segment_label_id] += 1
        color = viridis(segment_id)
        handles.append(mpatches.Patch(color=color, label=label))

    ax.legend(handles=handles)
    plt.savefig('cats_instance.png')

# instance_segmentation must contain the keys 'segmentation' and 'segments_info'
draw_instance_segmentation(**instance_segmentation)