In [1]:
import json
import os
import shutil
import random
import numpy as np
import math

from glob import glob
from PIL import Image, ImageDraw
from tqdm import tqdm

In [2]:
def _count_label(shapes, label):
    """Return how many annotation shapes carry the given label."""
    return sum(1 for shape in shapes if shape['label'] == label)


def _slabs_by_orientation(shapes):
    """Split polygon 'slab' shapes into (horizontal, vertical) lists of (x, y) vertices."""
    horizontal, vertical = [], []
    for shape in shapes:
        if shape['label'] != 'slab' or shape['shape_type'] != 'polygon':
            continue
        points = [tuple(p) for p in shape['points']]
        if not points:
            continue
        xs = [x for x, _ in points]
        ys = [y for _, y in points]
        # Compare the extents directly instead of dividing them, so a slab
        # with zero height cannot raise ZeroDivisionError.
        if max(xs) - min(xs) > max(ys) - min(ys):
            horizontal.append(points)
        else:
            vertical.append(points)
    return horizontal, vertical


def oversampling_on_seambroken(src_path, dest_path):
    """Paste 'seambroken' patches onto slab images that have none.

    Pass 1 scans every ``*.json`` in ``src_path`` for annotations with exactly
    one 'seambroken' shape and exactly one horizontal plus one vertical 'slab'
    polygon. The bounding box of the seambroken polygon is cropped from the
    matching ``.bmp`` together with a binary mask of the polygon.

    Pass 2 collects annotations with exactly one horizontal and one vertical
    slab polygon and no 'seambroken', 'cornerfracture', 'patch' or 'repair'
    shape. These are the paste targets.

    Finally ``copy_image`` pastes the patches and writes the results into
    ``dest_path``.

    Args:
        src_path: folder holding ``<name>.json`` / ``<name>.bmp`` pairs.
        dest_path: output folder; created if it does not exist.
    """
    json_paths = glob(os.path.join(src_path, '*.json'))

    # Each entry: (cropped image, cropped mask, side index, domination point,
    #              crop origin, seambroken vertices relative to the crop).
    all_seam = []

    for name in tqdm(json_paths):
        with open(name) as fo:
            ann = json.load(fo)
        shapes = ann['shapes']

        # Keep only annotations with one seambroken and two slabs.
        if _count_label(shapes, 'seambroken') != 1:
            continue
        if _count_label(shapes, 'slab') != 2:
            continue

        slab_hori, slab_vert = _slabs_by_orientation(shapes)
        if len(slab_hori) != 1 or len(slab_vert) != 1:
            continue

        seam_polygons = [[tuple(p) for p in shape['points']]
                         for shape in shapes
                         if shape['label'] == 'seambroken'
                         and shape['shape_type'] == 'polygon']
        # The single seambroken shape may not be a polygon; skip instead of
        # failing on an empty list.
        if len(seam_polygons) != 1 or not seam_polygons[0]:
            continue
        seam_polygon = seam_polygons[0]

        # A zero-width or zero-height polygon yields an empty crop; skip it.
        seam_xs = [x for x, _ in seam_polygon]
        seam_ys = [y for _, y in seam_polygon]
        if max(seam_xs) == min(seam_xs) or max(seam_ys) == min(seam_ys):
            continue

        # seam_points_4: corners of the bounding frame, clockwise from top-left.
        index, domination_point, seam_points_4 = seambro_external_frame(
            seam_polygon, slab_hori[0], slab_vert[0])
        if index is None:
            # The side could not be determined, so the patch cannot be placed.
            continue

        # Snap the crop box outwards to whole pixels and use that same integer
        # origin to shift the annotation, so patch and polygon stay aligned.
        (min_x, min_y), _, (max_x, max_y), _ = seam_points_4
        left, top = math.floor(min_x), math.floor(min_y)
        cut_box = (left, top, math.ceil(max_x), math.ceil(max_y))

        # Open the bmp that belongs to this json, draw the polygon mask and
        # crop both to the bounding frame. The source file is closed afterwards.
        with Image.open('{}.bmp'.format(os.path.splitext(name)[0])) as image:
            mask = Image.new('1', image.size)
            ImageDraw.Draw(mask).polygon(seam_polygon, fill=1)

            cut_original_image = image.crop(cut_box)
            cut_mask_image = mask.crop(cut_box)

        # Move coordinates into the crop's own coordinate system.
        domination_point = (domination_point[0] - left, domination_point[1] - top)
        seam_in_crop = [(x - left, y - top) for x, y in seam_polygon]

        all_seam.append((cut_original_image, cut_mask_image, index,
                         domination_point, (left, top), seam_in_crop))

    # Collect paste targets: two slabs and none of the defect/repair labels.
    excluded_labels = ('seambroken', 'cornerfracture', 'patch', 'repair')
    all_pasted = []

    for name in tqdm(json_paths):
        with open(name) as fo:
            ann = json.load(fo)
        shapes = ann['shapes']

        if any(_count_label(shapes, label) for label in excluded_labels):
            continue
        if _count_label(shapes, 'slab') != 2:
            continue

        points_slab_hori, points_slab_vert = _slabs_by_orientation(shapes)
        if len(points_slab_hori) != 1 or len(points_slab_vert) != 1:
            continue

        # Integer bounding boxes of both slabs; used to choose the paste spot.
        paste_range = find_paste_range(points_slab_hori[0], points_slab_vert[0])

        # Keep the target image together with its annotation.
        image = Image.open('{}.bmp'.format(os.path.splitext(name)[0]))
        all_pasted.append((image, paste_range, ann))

    # Paste and save.
    os.makedirs(dest_path, exist_ok=True)
    copy_image(dest_path, all_seam, all_pasted)

In [3]:
def find_paste_range(slab_hori, slab_vert):
    """Return the integer bounding boxes of the horizontal and vertical slab.

    Args:
        slab_hori: sequence of (x, y) vertices of the horizontal slab.
        slab_vert: sequence of (x, y) vertices of the vertical slab.

    Returns:
        ``[hori_box, vert_box]`` where each box is
        ``((min_x, min_y), (max_x, max_y))``; minima are floored and maxima
        are ceiled.
    """
    def integer_bounds(polygon):
        xs = [px for px, py in polygon]
        ys = [py for px, py in polygon]
        top_left = (math.floor(min(xs)), math.floor(min(ys)))
        bottom_right = (math.ceil(max(xs)), math.ceil(max(ys)))
        return (top_left, bottom_right)

    return [integer_bounds(slab_hori), integer_bounds(slab_vert)]

In [4]:
def seambro_external_frame(seambroken_points, slab_hori, slab_vert):
    """Compute the bounding frame of a seambroken polygon and the slab side it crosses.

    Args:
        seambroken_points: sequence of (x, y) vertices of the seambroken polygon.
        slab_hori: sequence of (x, y) vertices of the horizontal slab.
        slab_vert: sequence of (x, y) vertices of the vertical slab.

    Returns:
        ``(index, domination_point, external_frame_points)``:

        * ``index`` is ``'up'``/``'down'`` for a frame wider than tall that
          extends above/below the horizontal slab, ``'left'``/``'right'`` for
          any other frame that extends past the vertical slab's left/right
          edge, or ``None`` when the frame crosses neither edge.
        * ``domination_point`` is the frame's top-left corner.
        * ``external_frame_points`` are the four frame corners, clockwise
          from the top-left.
    """
    seam_xs = [x for x, y in seambroken_points]
    seam_ys = [y for x, y in seambroken_points]

    min_seam_x, max_seam_x = min(seam_xs), max(seam_xs)
    min_seam_y, max_seam_y = min(seam_ys), max(seam_ys)

    external_frame_points = ((min_seam_x, min_seam_y),
                             (max_seam_x, min_seam_y),
                             (max_seam_x, max_seam_y),
                             (min_seam_x, max_seam_y))

    slab_hori_ys = [y for x, y in slab_hori]
    slab_vert_xs = [x for x, y in slab_vert]

    min_slab_hy, max_slab_hy = min(slab_hori_ys), max(slab_hori_ys)
    min_slab_vx, max_slab_vx = min(slab_vert_xs), max(slab_vert_xs)

    # A local starting at None: a frame matching no branch must not inherit
    # the side found for a previous polygon.
    index = None

    # Comparing width and height directly (rather than their ratio) keeps a
    # zero-height frame from raising ZeroDivisionError.
    if (max_seam_x - min_seam_x) > (max_seam_y - min_seam_y):
        if min_seam_y < min_slab_hy:
            index = 'up'
        # Checked second on purpose: 'down' wins when both edges are crossed.
        if max_seam_y > max_slab_hy:
            index = 'down'
    else:
        if min_seam_x < min_slab_vx:
            index = 'left'
        # Checked second on purpose: 'right' wins when both edges are crossed.
        if max_seam_x > max_slab_vx:
            index = 'right'

    domination_point = (min_seam_x, min_seam_y)

    return (index, domination_point, external_frame_points)

In [5]:
def _candidate_offsets(first_span, second_span):
    """Return every integer in the two (start, stop) spans, clamped to be non-negative."""
    offsets = []
    for start, stop in (first_span, second_span):
        offsets.extend(range(max(0, start), max(0, stop)))
    return offsets


def _choose_paste_point(index, paste_range, pic_width, pic_height):
    """Pick a random top-left paste position for a patch, or None when nothing fits."""
    (hori_left_top, hori_right_bottom), (vert_left_top, vert_right_bottom) = paste_range

    if index in ('up', 'down'):
        # Slide along the horizontal slab, on either side of the vertical slab.
        xs = _candidate_offsets(
            (hori_left_top[0], vert_left_top[0] - pic_width),
            (vert_right_bottom[0], hori_right_bottom[0] - pic_width))
        if not xs:
            return None
        middle_y = (hori_left_top[1] + hori_right_bottom[1]) // 2
        y = middle_y - pic_height if index == 'up' else middle_y
        return (random.choice(xs), y)

    if index in ('left', 'right'):
        # Slide along the vertical slab, above or below the horizontal slab.
        ys = _candidate_offsets(
            (vert_left_top[1], hori_left_top[1] - pic_height),
            (hori_right_bottom[1], vert_right_bottom[1] - pic_height))
        if not ys:
            return None
        middle_x = (vert_left_top[0] + vert_right_bottom[0]) // 2
        x = middle_x - pic_width if index == 'left' else middle_x
        return (x, random.choice(ys))

    return None


def copy_image(dest_path, all_seam, all_pasted):
    """Paste a randomly chosen seambroken patch onto every target image and save it.

    Args:
        dest_path: folder the augmented images and json files are written to.
        all_seam: list of ``(patch, mask, index, domination_point,
            frame_left_top, seambroken_points)`` tuples; ``seambroken_points``
            are relative to the patch's top-left corner.
        all_pasted: list of ``(image, paste_range, ann)`` tuples, where
            ``paste_range`` is ``[hori_box, vert_box]`` with each box given as
            ``((min_x, min_y), (max_x, max_y))``.

    A target is skipped (nothing is written for it) when the chosen patch has
    no recognised side index or does not fit anywhere along the slab. Nothing
    happens at all when ``all_seam`` is empty.
    """
    if not all_seam:
        # No patches to paste from; random.choice would raise on an empty list.
        return

    for image, paste_range, ann in all_pasted:
        pic, mask, index, _domination_point, _frame_left_top, seambroken = \
            random.choice(all_seam)

        paste_point = _choose_paste_point(index, paste_range, pic.width, pic.height)
        if paste_point is None:
            # No valid position for this patch on this image: skip it rather
            # than reuse the position computed for a previous image.
            continue

        image.paste(pic, paste_point, mask=mask)
        image = image.convert('RGB')

        # Shift the patch-relative annotation to its position in the target.
        new_points = [(x + paste_point[0], y + paste_point[1])
                      for x, y in seambroken]

        # Save the image and its json annotation.
        save_image_json(dest_path, image, ann, new_points)
    

In [6]:
def save_image_json(dest_path, image, ann, new_points):
    """Write the augmented image and its annotation into ``dest_path``.

    The image is saved under ``ann['imagePath']``. A new 'seambroken' polygon
    shape with ``new_points`` is appended to ``ann['shapes']`` (``ann`` is
    modified in place), and the annotation is written next to the image with
    the image's extension replaced by ``.json``.

    Args:
        dest_path: existing output folder.
        image: object with a ``save(path)`` method (a PIL image in this notebook).
        ann: annotation dict with at least 'imagePath' and 'shapes'.
        new_points: vertices of the pasted seambroken polygon.
    """
    image.save(os.path.join(dest_path, ann['imagePath']))

    ann['shapes'].append({"label": "seambroken",
                          "line_color": None,
                          "fill_color": None,
                          "points": new_points,
                          "shape_type": "polygon",
                          "flags": {}})

    # splitext copes with extensions of any length, not just three characters.
    stem = os.path.splitext(ann['imagePath'])[0]
    json_file_path = os.path.join(dest_path, '{}.json'.format(stem))

    # The context manager guarantees the file is flushed and closed.
    with open(json_file_path, mode='w') as json_file:
        json.dump(ann, json_file)

In [7]:
# Run configuration: source folder with json/bmp pairs, output folder, and a
# fixed seed so the random patch choice and placement are repeatable.
SRC_DIR = "D:/data/APD202004v2/train"
DEST_DIR = "D:/data/temp"
RANDOM_SEED = 2020

random.seed(RANDOM_SEED)
oversampling_on_seambroken(SRC_DIR, DEST_DIR)

100%|████████████████████████████████████████████████████████████████████████████| 2769/2769 [00:01<00:00, 1878.74it/s]
100%|████████████████████████████████████████████████████████████████████████████| 2769/2769 [00:00<00:00, 3852.74it/s]
