In [8]:
import torch
import torchvision
import torchvision.transforms as T
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from torchvision import transforms

from tqdm import tqdm
import os, time
import json

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt

# 转换数据集

## 数据集文件夹结构：

- DatasetRoot:
    - bmp: bmp 原始图片
    - jsons: json 原始标注文件
    - Images: 转换成 jpg 格式的图片
    - segm: 由标注文件生成的分割掩码图片（png 格式）

## 转换步骤

1. 读取 bmp 文件夹下的图片，转换成 jpg 格式，保存到 Images 文件夹下。
2. 读取 jsons 文件夹下的标注文件，生成分割掩码图片（png 格式），保存到 segm 文件夹下。



In [4]:
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Display a list of images on a ``num_rows`` x ``num_cols`` grid.

    Args:
        imgs (list): Images to draw. Objects exposing ``permute`` (e.g. torch
            tensors) are permuted with ``(1, 2, 0)`` before plotting, i.e.
            they are assumed to be channel-first; anything else is handed to
            ``imshow`` unchanged.
        num_rows (int): Number of rows
        num_cols (int): Number of columns
        titles (list, optional): One title per image. May be shorter than
            ``imgs``; the remaining images get no title. Defaults to None.
        scale (float, optional): Size of one grid cell, used to compute the
            figure size. Defaults to 1.5.

    Returns:
        numpy.ndarray: Flattened array of all axes of the grid.
    """
    figsize = (num_cols * scale, num_rows * scale)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 (or 1xN) grid;
    # without it a 1x1 grid yields a bare Axes object that has no .flatten().
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    # zip stops at the shorter sequence, so surplus images or axes are ignored.
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        ax.imshow(img.permute(1, 2, 0) if hasattr(img, 'permute') else img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Guard the index so a short title list does not raise IndexError.
        if titles is not None and i < len(titles):
            ax.set_title(titles[i])
    return axes

In [5]:
# Annotation label names; the position of a name in this list selects the
# fill colour at the same position in COLORS.
CLASSES = [
    'pool',
    'lack_of_fusion',
]

# 3-channel fill colours, index-aligned with CLASSES.
COLORS = [
    (0, 255, 0),    # 'pool'
    (255, 0, 0),    # 'lack_of_fusion'
]

In [7]:
def label2mask(json_file, root='mydata', show=False):
    """Rasterise the polygons of one annotation file into a colour mask.

    A black ``imageHeight`` x ``imageWidth`` 3-channel uint8 image is created
    and every entry of ``shapes`` is filled with the colour registered for
    its label (``COLORS[CLASSES.index(label)]``). The result is written to
    ``{root}/segm/<stem>.png``, where ``<stem>`` is the part of the json file
    name before its first dot.

    Args:
        json_file (str): Json file path. The file must provide the keys
            ``imageHeight``, ``imageWidth`` and ``shapes``; every shape needs
            ``label`` and ``points``.
        root (str, optional): Dataset root directory. The ``segm`` folder is
            created inside it when missing. Defaults to 'mydata'.
        show (bool, optional): Show the mask. Defaults to False.

    Raises:
        ValueError: If a shape carries a label that is not in ``CLASSES``.
        OSError: If OpenCV reports that the mask file could not be written.
    """
    mask_dir = os.path.join(root, 'segm')
    os.makedirs(mask_dir, exist_ok=True)

    try:
        with open(json_file, 'r') as f:
            data = json.load(f)
    except (UnicodeDecodeError, json.JSONDecodeError):
        # The platform default encoding could not read the file; retry as
        # UTF-8. Other errors (missing file, Ctrl-C, ...) propagate directly.
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

    img = np.zeros(
        (data['imageHeight'], data['imageWidth'], 3), dtype=np.uint8)
    # NOTE(review): the stem is cut at the FIRST dot, so 'a.b.json' maps to
    # 'a.png'. Kept as is because convert_img derives image names the same
    # way and the two outputs presumably have to pair up by name.
    save_path = os.path.join(
        mask_dir, os.path.basename(json_file).split('.')[0] + '.png')

    for shape in data['shapes']:
        name = shape['label']
        if name not in CLASSES:
            raise ValueError(
                'Unknown label {!r} in {}; expected one of {}'.format(
                    name, json_file, CLASSES))
        # Vertex coordinates are cast to int32 before being passed to fillPoly.
        points = np.array(shape['points'], dtype=np.int32)
        cv.fillPoly(img, [points], COLORS[CLASSES.index(name)])

    # cv.imwrite reports failure through its return value instead of raising.
    if not cv.imwrite(save_path, img):
        raise OSError('Failed to write mask file: {}'.format(save_path))

    if show:
        # NOTE(review): OpenCV stores channels as BGR while matplotlib reads
        # RGB, so red/blue in this preview are swapped relative to the saved
        # file - confirm which interpretation downstream code expects.
        plt.imshow(img)
        plt.show()


def convert_img(root='mydata'):
    """Convert every readable image in ``{root}/bmp`` into .jpg format.

    Each entry of ``{root}/bmp`` is loaded with OpenCV and written to
    ``{root}/Images/<stem>.jpg``, where ``<stem>`` is the part of the file
    name before its first dot. Entries OpenCV cannot decode are skipped with
    a warning instead of aborting the whole conversion.

    Args:
        root (str, optional): Dataset root directory containing the ``bmp``
            folder. The ``Images`` folder is created inside it when missing.
            Defaults to 'mydata'.

    Raises:
        OSError: If OpenCV reports that a .jpg file could not be written.
    """
    import warnings

    src_dir = os.path.join(root, 'bmp')
    dst_dir = os.path.join(root, 'Images')
    os.makedirs(dst_dir, exist_ok=True)

    for img_file in tqdm(os.listdir(src_dir)):
        src_path = os.path.join(src_dir, img_file)
        img = cv.imread(src_path)
        if img is None:
            # cv.imread returns None when it cannot decode the entry (stray
            # non-image file, sub-directory, ...); handing None to cv.imwrite
            # would raise and stop the batch.
            warnings.warn('Skipping unreadable image: {}'.format(src_path))
            continue
        # NOTE(review): the stem is cut at the FIRST dot, matching how
        # label2mask names its masks.
        dst_path = os.path.join(dst_dir, img_file.split('.')[0] + '.jpg')
        # cv.imwrite reports failure through its return value.
        if not cv.imwrite(dst_path, img):
            raise OSError('Failed to write image file: {}'.format(dst_path))


# One-off dataset conversion (masks first, then .jpg images); uncomment to regenerate.
# for json_file in tqdm(os.listdir('mydata/jsons')):
#     label2mask(os.path.join('mydata/jsons', json_file))
# convert_img()

100%|██████████| 1177/1177 [00:19<00:00, 60.32it/s]
100%|██████████| 1177/1177 [00:27<00:00, 42.60it/s]
