In [1]:
import json
import os
import shutil
from glob import glob
from typing import Dict, List, Literal, Optional

import cv2
import numpy as np
from rich.progress import track
from sklearn.model_selection import train_test_split

In [2]:
# Folder holding the raw *.tif images next to their same-named *.json annotation
# files (folder name "标注信息" = "annotation information").
root_path = "/".join(("..", "data", "raw_data", "标注信息"))

In [3]:
# glob() returns files in filesystem order, which differs between machines; sort so
# the seeded train/val split in the next cell is reproducible.
image_paths = sorted(glob(os.path.join(root_path, "*.tif")))
# Swap only the extension: str.replace("tif", "json") would also rewrite any other
# "tif" in the path (e.g. a folder or file name such as "motif_01.tif").
ann_paths = [os.path.splitext(filename)[0] + ".json" for filename in image_paths]
# The two counts printed below are equal by construction, so check explicitly that
# every image really has an annotation file before building the dataset.
missing_anns = [path for path in ann_paths if not os.path.exists(path)]
if missing_anns:
    raise FileNotFoundError(
        f"{len(missing_anns)} image(s) have no matching .json annotation, "
        f"e.g. {missing_anns[0]!r}"
    )
print("数据集总大小:", len(image_paths), len(ann_paths))

数据集总大小: 1151 1151


In [4]:
# 80/20 split of paired (image, annotation) paths; the fixed seed makes the split
# repeatable for a given input order.
image_train, image_test, ann_train, ann_test = train_test_split(
    image_paths,
    ann_paths,
    train_size=0.8,
    random_state=42,
)
print(*map(len, (image_train, image_test, ann_train, ann_test)))

920 231 920 231


In [5]:
def get_mask_by_json(
    filename: str, label_map: Optional[Dict[int, int]] = None
) -> np.ndarray:
    """
    Rasterise the polygon annotations of one JSON file into a class-index mask.

    Args:
        filename: path to the annotation JSON. It must contain ``imageHeight``,
            ``imageWidth`` and a ``shapes`` list whose entries carry a ``label``
            (convertible to ``int``) and a ``points`` list of ``[x, y]`` pairs.
        label_map: mapping from raw annotation label to output class index.
            Defaults to ``{0: 0, 1: 1, 2: 2, 3: 2, 4: 2, 5: 3}``, i.e. raw labels
            2, 3 and 4 are merged into class 2 and raw label 5 becomes class 3.

    Returns:
        An ``int8`` array of shape ``(imageHeight, imageWidth)``. Pixels not covered
        by any polygon are 0; shapes are filled in file order, so a later polygon
        overwrites an earlier one where they overlap.

    Raises:
        KeyError: if a shape's label has no entry in ``label_map``.
    """
    if label_map is None:
        label_map = {0: 0, 1: 1, 2: 2, 3: 2, 4: 2, 5: 3}

    # Context manager closes the handle; explicit UTF-8 (the JSON interchange
    # encoding) avoids depending on the platform default, e.g. GBK on Windows.
    with open(filename, encoding="utf-8") as fp:
        annotation = json.load(fp)

    # NOTE(review): int8 mask kept for compatibility with existing outputs; uint8 is
    # the more common dtype for label images — confirm what the training loader expects.
    mask = np.zeros(
        (annotation["imageHeight"], annotation["imageWidth"]), dtype="int8"
    )
    for shape in annotation["shapes"]:
        raw_label = int(shape["label"])
        try:
            class_index = label_map[raw_label]
        except KeyError:
            raise KeyError(
                f"{filename}: label {raw_label} has no entry in label_map"
            ) from None
        # Round instead of truncating: astype() alone floors positive coordinates,
        # shifting every vertex up to one pixel toward the origin.
        points = np.rint(np.asarray(shape["points"], dtype=np.float64)).astype(np.int32)
        if points.size == 0:
            # Nothing to draw; cv2.fillPoly rejects empty point arrays.
            continue
        # NOTE(review): every shape is filled as a polygon; non-polygon shapes with
        # only a couple of points (e.g. rectangles) would not be filled as intended.
        cv2.fillPoly(mask, [points], class_index)
    return mask

In [6]:
def json_to_image(json_path: str, image_path: str) -> None:
    """
    Convert the annotations in ``json_path`` to a mask image saved at ``image_path``.

    Args:
        json_path: annotation JSON file, rasterised with ``get_mask_by_json``.
        image_path: destination file; its extension selects the OpenCV encoder.

    Raises:
        OSError: if ``cv2.imwrite`` reports that the file could not be written.
    """
    mask = get_mask_by_json(json_path)
    # cv2.imwrite reports some failures (e.g. a missing parent directory) only
    # through a False return value; surface them instead of silently skipping.
    if not cv2.imwrite(image_path, mask):
        raise OSError(f"cv2.imwrite failed to write mask to {image_path!r}")

In [7]:
def create_dataset(
    image_paths: List[str],
    ann_paths: List[str],
    phase: Literal["train", "val"],
    output_dir: str,
):
    """
    Copy one split's images and write their rasterised masks under ``output_dir``.

    Produces::

        <output_dir>/<phase>/img/<image basename>
        <output_dir>/<phase>/ann/<annotation stem>.tif

    Args:
        image_paths: source image files for this split.
        ann_paths: annotation JSON files, paired index-by-index with ``image_paths``.
        phase: split name, used as the sub-directory name.
        output_dir: dataset root directory; missing directories are created.

    Raises:
        ValueError: if ``image_paths`` and ``ann_paths`` differ in length.
    """
    if len(image_paths) != len(ann_paths):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            f"{phase}: {len(image_paths)} images but {len(ann_paths)} annotations"
        )

    base_path = os.path.join(output_dir, phase)
    img_dir = os.path.join(base_path, "img")
    ann_dir = os.path.join(base_path, "ann")
    # Neither shutil.copy nor cv2.imwrite creates missing parent directories.
    os.makedirs(img_dir, exist_ok=True)
    os.makedirs(ann_dir, exist_ok=True)

    for image_path, ann_path in track(
        zip(image_paths, ann_paths),
        description=f"{phase} dataset",
        total=len(image_paths),  # zip() has no len(), so rich cannot infer it
    ):
        ann_stem = os.path.splitext(os.path.basename(ann_path))[0]
        ann_save_path = os.path.join(ann_dir, ann_stem + ".tif")

        # Copy the image into the split's img directory.
        shutil.copy(image_path, os.path.join(img_dir, os.path.basename(image_path)))

        # Write the rasterised annotation into the split's ann directory.
        json_to_image(ann_path, ann_save_path)

In [8]:
# Output root; create_dataset writes <root>/<phase>/img and <root>/<phase>/ann.
root = "../data/grass"
create_dataset(
    image_paths=image_train, ann_paths=ann_train, phase="train", output_dir=root
)
create_dataset(
    image_paths=image_test, ann_paths=ann_test, phase="val", output_dir=root
)

Output()

Output()