In [2]:
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from PIL import Image
import requests
#url = "https://huggingface.co/datasets/shi-labs/oneformer_demo/blob/main/ade20k.jpeg"
#image = Image.open(requests.get(url, stream=True).raw)

# Loading a single model for all three tasks
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_large")
model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_large")

# Semantic Segmentation
semantic_inputs = processor(images=image, task_inputs=["semantic"], return_tensors="pt")
semantic_outputs = model(**semantic_inputs)
# pass through image_processor for postprocessing
predicted_semantic_map = processor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]

# Instance Segmentation
instance_inputs = processor(images=image, task_inputs=["instance"], return_tensors="pt")
instance_outputs = model(**instance_inputs)
# pass through image_processor for postprocessing
predicted_instance_map = processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])[0]["segmentation"]

# Panoptic Segmentation
panoptic_inputs = processor(images=image, task_inputs=["panoptic"], return_tensors="pt")
panoptic_outputs = model(**panoptic_inputs)
# pass through image_processor for postprocessing
predicted_semantic_map = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]["segmentation"]


OSError: We couldn't connect to 'https://huggingface.co' to load this file, couldn't find it in the cached files and it looks like shi-labs/oneformer_ade20k_swin_large is not the path to a directory containing a file named preprocessor_config.json.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'.

In [None]:
# 以下是完整补全后的代码，包括：
# 1. 封装后的 `semantic_infer_one_image()` 函数
# 2. 支持从 WebDataset 中读取图像，生成语义 mask，并存回新 tar 的主程序流程
# 3. 多卡分布式运行支持

import os
import io
import tarfile
import torch
import numpy as np
from PIL import Image
import argparse
from datasets import load_dataset
from torchvision.transforms import Resize
from torchvision.transforms.functional import to_pil_image
import torch.nn.functional as F
from transformers import (
    CLIPProcessor, CLIPModel,
    AutoProcessor, CLIPSegForImageSegmentation,
    OneFormerProcessor, OneFormerForUniversalSegmentation,
    BlipProcessor, BlipForConditionalGeneration
)
from clip import clip_classification
from clipseg import clipseg_segmentation
from oneformer import oneformer_coco_segmentation, oneformer_ade20k_segmentation
from blip import open_vocabulary_classification_blip
from configs.ade20k_id2label import CONFIG as CONFIG_ADE20K_ID2LABEL
from configs.coco_id2label import CONFIG as CONFIG_COCO_ID2LABEL
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
import torch.multiprocessing as mp
import torch.distributed as dist
from torchvision.transforms import InterpolationMode

# 参数配置
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

input_dir = '/mnt/33t/cy/blip3o_dataset'
output_dir = '/mnt/33t/cy/mask_dataset'
base_dir = '/mnt/33t/cy/mllm_models/semantic_sam'

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=1)
    parser.add_argument('--save_img', action='store_true')
    parser.add_argument('--ckpt_path', type=str, required=True)
    return parser.parse_args()

# ====================== 语义推理函数 ======================
def semantic_infer_one_image(img: Image.Image, processors, models, rank) -> np.ndarray:
    """对一张图像执行语义标注，返回 shape=(256,) 的单通道 numpy array，值为 0~num_class"""
    from mmcv import imcrop
    import pycocotools.mask as maskUtils

    img = np.array(img.convert("RGB"))
    anns = {'annotations': processors['sam_generator'].generate(img)}
    class_ids_from_oneformer_coco = oneformer_coco_segmentation(Image.fromarray(img),
                                                                 processors['oneformer_coco'], models['oneformer_coco'], rank)
    class_ids_from_oneformer_ade20k = oneformer_ade20k_segmentation(Image.fromarray(img),
                                                                     processors['oneformer_ade20k'], models['oneformer_ade20k'], rank)
    semantic_mask = np.zeros(img.shape[:2], dtype=np.uint8)
    for ann in anns['annotations']:
        valid_mask = torch.tensor(maskUtils.decode(ann['segmentation'])).bool()
        coco_ids = class_ids_from_oneformer_coco[valid_mask]
        ade_ids = class_ids_from_oneformer_ade20k[valid_mask]
        coco_labels = [CONFIG_COCO_ID2LABEL['refined_id2label'].get(str(i.item()), '') for i in torch.bincount(coco_ids).topk(1).indices]
        ade_labels = [CONFIG_ADE20K_ID2LABEL['id2label'].get(str(i.item()), '') for i in torch.bincount(ade_ids).topk(1).indices]
        labels = list(set(coco_labels + ade_labels))

        x0, y0, w, h = ann['bbox']
        patch = imcrop(img, np.array([x0, y0, x0+w, y0+h]), scale=1.5)
        open_vocab_labels = open_vocabulary_classification_blip(patch, processors['blip'], models['blip'], rank)
        candidate_labels = list(set(labels + open_vocab_labels))
        top_labels = clip_classification(patch, candidate_labels, min(3, len(candidate_labels)),
                                         processors['clip'], models['clip'], rank)
        seg = clipseg_segmentation(patch, top_labels, processors['clipseg'], models['clipseg'], rank).argmax(0)

        ann_mask = torch.tensor(maskUtils.decode(ann['segmentation']))
        if ann_mask.shape != seg.shape:
            ann_mask = F.interpolate(ann_mask.unsqueeze(0).unsqueeze(0).float(), size=seg.shape, mode='nearest').squeeze(0).squeeze(0).bool()
        seg = seg.cpu().numpy()
        class_name = top_labels[torch.bincount(torch.tensor(seg[ann_mask.numpy()].flatten())).topk(1).indices.item()]
        class_id = processors['label2index'].setdefault(class_name.lower().strip(" a the"), len(processors['label2index']))
        semantic_mask[ann_mask.numpy()] = class_id

    h, w = semantic_mask.shape
    scale = (256 / (h * w)) ** 0.5
    new_h, new_w = max(1, int(h * scale)), max(1, int(w * scale))
    mask_img = Image.fromarray(semantic_mask).resize((new_w, new_h), resample=Image.NEAREST)
    mask_arr = np.array(mask_img).astype(np.uint8).flatten()
    return mask_arr  # shape = (256,) numpy array

# ====================== 主函数 ======================
def main(rank, args):
    dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
    device = torch.device(f"cuda:{rank}")

    # 模型与处理器初始化
    processors = {
        'clip': CLIPProcessor.from_pretrained(f"{base_dir}/clip-vit-large-patch14"),
        'clipseg': AutoProcessor.from_pretrained(f"{base_dir}/clipseg-rd64-refined"),
        'oneformer_ade20k': OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_large"),
        'oneformer_coco': OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large"),
        'blip': BlipProcessor.from_pretrained(f"{base_dir}/blip-image-captioning-large"),
        'label2index': {}
    }
    processors['clipseg'].image_processor.do_resize = False

    models = {
        'clip': CLIPModel.from_pretrained(f"{base_dir}/clip-vit-large-patch14").to(device),
        'clipseg': CLIPSegForImageSegmentation.from_pretrained(f"{base_dir}/clipseg-rd64-refined").to(device),
        'oneformer_ade20k': OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_large").to(device),
        'oneformer_coco': OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large").to(device),
        'blip': BlipForConditionalGeneration.from_pretrained(f"{base_dir}/blip-image-captioning-large").to(device),
    }

    sam = sam_model_registry["vit_h"](checkpoint=args.ckpt_path).to(device)
    processors['sam_generator'] = SamAutomaticMaskGenerator(model=sam, points_per_side=32,
        pred_iou_thresh=0.86, stability_score_thresh=0.92, crop_n_layers=0,
        crop_n_points_downscale_factor=2, min_mask_region_area=100, output_mode='coco_rle')

    # 读取当前进程对应的数据
    tar_files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.tar')])
    local_files = tar_files[rank::args.world_size]
    for tar_path in local_files:
        dataset = load_dataset("webdataset", data_files=tar_path, split="train")
        tar_name = os.path.splitext(os.path.basename(tar_path))[0]
        output_tar_path = os.path.join(output_dir, f"{tar_name}.tar")
        with tarfile.open(output_tar_path, "w") as tar_out:
            for item in dataset:
                key = item["__key__"]
                img: Image.Image = item["image"]
                mask_arr = semantic_infer_one_image(img, processors, models, rank)
                img_buffer = io.BytesIO()
                item["image"].save(img_buffer, format="PNG")
                mask_buffer = io.BytesIO(mask_arr.astype(np.uint8).tobytes())

                for name, buf in [("image.png", img_buffer), ("mask.pgm", mask_buffer)]:
                    info = tarfile.TarInfo(f"{key}/{name}")
                    info.size = buf.getbuffer().nbytes
                    buf.seek(0)
                    tar_out.addfile(info, buf)
        if rank == 0:
            print(f"[rank {rank}] Saved: {output_tar_path}")

if __name__ == "__main__":
    args = parse_args()
    if args.world_size > 1:
        mp.spawn(main, args=(args,), nprocs=args.world_size, join=True)
    else:
        main(0, args)


In [5]:
2**16 - 1

65535

In [3]:
from PIL import Image
import numpy as np


img_arr = np.random.randint(0, 2^16-1, (64, 64))
img = Image.fromarray(img_arr, mode='I;16')
img.save('test.png')
img.save('test.tiff')
np.save('test', img_arr)

img_recover_npy = np.load("test.npy")
img_recover_tiff = Image.open('test.tiff')
img_recover_png = Image.open('test.png')

In [22]:

from PIL import Image
import numpy as np
import time

# Create a synthetic 16-bit image
img_arr = np.random.randint(0, 2**16 - 1, (16, 25), dtype=np.uint16)
img = Image.fromarray(img_arr, mode='I;16')

# Save in different formats
img.save('/mnt/data/test.png')
img.save('/mnt/data/test.tiff')
#img.save('/mnt/data/test.mask')
np.save('/mnt/data/test.npy', img_arr)

# Load and measure time & error
results = {}

# NPY
start = time.time()
img_recover_npy = np.load('/mnt/data/test.npy')
results['npy_time'] = time.time() - start
results['npy_error'] = np.abs(img_arr.astype(np.int16) - img_recover_npy.astype(np.int16)).max()

# TIFF
start = time.time()
img_recover_tiff = Image.open('/mnt/data/test.tiff')
img_recover_tiff = np.array(img_recover_tiff)
results['tiff_time'] = time.time() - start
results['tiff_error'] = np.abs(img_arr.astype(np.int16) - img_recover_tiff.astype(np.int16)).max()

# PNG
start = time.time()
img_recover_png = Image.open('/mnt/data/test.png')
img_recover_png = np.array(img_recover_png)
results['png_time'] = time.time() - start
results['png_error'] = np.abs(img_arr.astype(np.int16) - img_recover_png.astype(np.int16)).max()

results



{'npy_time': 0.00036334991455078125,
 'npy_error': 0,
 'tiff_time': 0.0006701946258544922,
 'tiff_error': 0,
 'png_time': 0.0002529621124267578,
 'png_error': 0}

In [15]:
img_arr == img_recover_tiff

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  Tru