In [1]:
import json
import os
import shutil
import random
import numpy as np
import math
import cv2 as cv
import matplotlib.pyplot as plt

from glob import glob
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm

In [2]:
def oversampling_on_cornerfracture(src_path, temp_path, dest_path):
    """    
    all_cor = [cut_original_image:剪切后的原始图像, 
    cut_mask_image:剪切后的掩模, 
    domination_num:支配点, 
    domination_point:支配点坐标,
    points_cor:剪切后的原始图象中conerfracture的标记点
    center_point:剪切后图像的中心点，用来求旋转后corner fracture标记点的坐标值]
    
    """
    all_cor = []

    for name in tqdm(glob(os.path.join(src_path, '*.json'))):
        with open(name) as fo:
            ann = json.load(fo)

        # 筛选json文件，记录slab与conerfracture的points
        if len([shape for shape in ann['shapes'] if shape['label'] == 'cornerfracture']) != 1:
            continue
        if len([shape for shape in ann['shapes'] if shape['label'] == 'slab']) != 2:
            continue

        slab_hori = []
        slab_vert = []
        cornerfrac = []
        
        for shape in ann['shapes']:
            if shape['shape_type'] != 'polygon':
                continue
            
            # slab的points
            if shape['label'] == 'slab':
                slab_x = [x for x, y in shape['points']]
                slab_y = [y for x, y in shape['points']]

                max_x = max(slab_x)
                min_x = min(slab_x)
                max_y = max(slab_y)
                min_y = min(slab_y)

                # 对slab进行横/竖分类
                ratio = (max_x - min_x) / (max_y - min_y)
                if ratio > 1:
                    slab_hori.append(zip(slab_x, slab_y))
                else:
                    slab_vert.append(zip(slab_x, slab_y))

            # cornerfracture的points   
            if shape['label'] == 'cornerfracture':
                cor_xy = [tuple(p) for p in shape['points']]
                cornerfrac.append(cor_xy)
        
        # 要求两条slab一横一竖
        if len(slab_hori) != 1 or len(slab_vert) != 1:
            continue
        
        # slab交点  (x, y)
        overlap_points = find_slab_intersaction(slab_hori[0], slab_vert[0])
    
        # center_point：外界框中心点
        # cor_points_4：cornerfracture外接框顶点坐标
        center_point, cor_points_4 = find_cor_external_frame(cornerfrac[0]) 
        
        # 判断cornerfracture是否在交点处，找到支撑点
        domination_point, domination_index = judgement_cor_location(overlap_points, cor_points_4)
        if domination_index == -1:
            continue
    
        # 读取该json文件对应的bmp文件，得到其mask，再将mask截为cor外接框大小
        image = Image.open('{}.bmp'.format(name[:-5]))
        
        mask = Image.new('1', image.size)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.polygon(cornerfrac[0], fill=1)
        mask_save_path = ann['imagePath']
        
        cut_box = (cor_points_4[0] + cor_points_4[2])
        cut_original_image = image.crop(cut_box)
        cut_mask_image = mask.crop(cut_box)

        cut_mask_image.save(os.path.join(temp_path, mask_save_path))
        cut_mask_image = cv.imread(os.path.join(temp_path, mask_save_path), cv.IMREAD_COLOR)
        os.remove(os.path.join(temp_path, mask_save_path))
        
        # mask 膨胀        
        kernel = np.ones((10,10),np.uint8)
        cut_mask_image = cv.dilate(cut_mask_image,kernel)
        cut_mask_image = Image.fromarray(cut_mask_image.astype('uint8')).convert('RGB')

        # 坐标归零
        left, top = cor_points_4[0]
        domination_point = (domination_point[0] - left, domination_point[1] - top)
        cornerfrac = [(x - left, y - top) for x, y in cornerfrac[0]]
        center_point = (center_point[0] - left, center_point[1] - top)
        
        all_cor.append((cut_original_image, 
                        cut_mask_image, 
                        domination_index, 
                        domination_point, 
                        cornerfrac, 
                        center_point, 
                        mask_save_path
                        ))

    # 筛选所有两条slab且没有cor的json
    """
    h_all_cor = [image:无cornerfracture病害的图像, 
    center_point：slab交点框的中心点,
    ann：image对应的json文件]
    
    """
    
    h_all_cor = []
    
    for name in tqdm(glob(os.path.join(src_path, '*.json'))):
        with open(name) as fo:
            ann = json.load(fo)

        # 筛选json文件，记录slab的points
        if len([shape for shape in ann['shapes'] if shape['label'] == 'cornerfracture']) != 0:
            continue
        if len([shape for shape in ann['shapes'] if shape['label'] == 'slab']) != 2:
            continue
        if len([shape for shape in ann['shapes'] if shape['label'] == 'patch']) != 0:
            continue
        if len([shape for shape in ann['shapes'] if shape['label'] == 'repair']) != 0:
            continue      
        
        points_slab_w = []
        points_slab_h = []
        healthy_slab_w_points = []
        healthy_slab_h_points = []
        
        for shape in ann['shapes']:
            if shape['label'] != 'slab' or shape['shape_type'] != 'polygon':
                continue
                
            # slab的points
            xy_slab = [tuple(p) for p in shape['points']]                
            slab_all_x = [p[0] for p in shape['points']]
            slab_all_y = [p[1] for p in shape['points']]

            max_x = max(slab_all_x)
            min_x = min(slab_all_x)
            max_y = max(slab_all_y)
            min_y = min(slab_all_y)

            # 对slab进行横/竖分类
            ratio = (max_x - min_x) / (max_y - min_y)
            if ratio > 1:
                points_slab_w.append(xy_slab)
            else:
                points_slab_h.append(xy_slab)
        
        # 要求slab一横一竖
        if len(points_slab_w) != 1 or len(points_slab_h) != 1:
            continue
        
        healthy_slab_w_points = [list(xy) for xy in points_slab_w[0]]
        healthy_slab_h_points = [list(xy) for xy in points_slab_h[0]]
        
        # slab交点框的中心点
        center_point = find_slab_intersaction(points_slab_w[0], points_slab_h[0])        
        
        # 记录bmp文件
        image = Image.open('{}.bmp'.format(name[:-5]))
        h_all_cor.append((image, center_point, ann))

    # 粘贴
    copy_image(temp_path, dest_path, all_cor, h_all_cor)

In [3]:
def find_slab_intersaction(slab_hori, slab_vert):
    # slab的交点框坐标
    all_hori_y = [y for x, y in slab_hori]
    all_vert_x = [x for x, y in slab_vert]

    hori_max_y = max(all_hori_y)
    hori_min_y = min(all_hori_y)
    vert_max_x = max(all_vert_x)
    vert_min_x = min(all_vert_x)

    center_intersaction_point = ((vert_max_x + vert_min_x) // 2, (hori_max_y + hori_min_y) // 2)
    
    return center_intersaction_point

In [4]:
def find_cor_external_frame(list_points_cor):
    all_x = [x for x, y in list_points_cor]
    all_y = [y for x, y in list_points_cor]
    
    max_x = max(all_x)
    min_x = min(all_x)
    max_y = max(all_y)
    min_y = min(all_y)
    
    # 中心点坐标，旋转时用到
    center_point = ((max_x + min_x) // 2, (max_y + min_y) // 2)
    
    external_frame_points = ((min_x, min_y), 
                             (max_x, min_y),
                             (max_x, max_y),
                             (min_x, max_y))
    
    return (center_point, external_frame_points)  # 中心点和外接框的四角

In [5]:
def judgement_cor_location(center_point, cor_points_4):
    # 通过cor外接框与slab交点框比对，找到支配点坐标以及其对应的标号index
    center_x, center_y = center_point
    
    domination_point = None
    square_root_MIN = 100
    
    index = -1
    for idx, (x, y) in enumerate(cor_points_4):
        square_root = ((center_x - x) ** 2 + (center_y - y) ** 2) ** 0.5
        if square_root < square_root_MIN:
            square_root_MIN = square_root
            domination_point = (x, y)
            index = idx

    return [domination_point, index]

In [6]:
def copy_image(temp_path, dest_path, all_cor, h_all_cor):
    for image, center_point, ann in h_all_cor:
        pic_and_mask = random.choice(all_cor)
        pic, \
        mask, \
        domination_index, \
        domination_point, \
        points_cor, \
        cor_center_point, \
        mask_save_path = pic_and_mask
        
        if pic.width < 100 or pic.height < 100:
            continue
        
        # ROTATE
        rotate_times = random.randint(0, 3)        
        
        pic_processed = pic.rotate(rotate_times * -90, expand = True)
        mask_processed = mask.rotate(rotate_times * -90, expand = True)
        domination_index = (domination_index + rotate_times) % 4

        temp_x, temp_y = (center_point[0] - pic_processed.width, center_point[1] - pic_processed.height)
    
        if domination_index == 0:
            center_intersaction_point = (int(center_point[0]) + mask_processed.width // 2,
                                         int(center_point[1]) + mask_processed.height // 2)
        if domination_index == 1:
            center_intersaction_point = (int(temp_x) + mask_processed.width // 2,
                                         int(center_point[1]) + mask_processed.height // 2)
        if domination_index == 2:
            center_intersaction_point = (int(temp_x) + mask_processed.width // 2,
                                        int(temp_y) + mask_processed.height // 2)
        if domination_index == 3:
            center_intersaction_point = (int(center_point[0]) + mask_processed.width // 2,
                                         int(temp_y) + mask_processed.height // 2)
        
        if center_intersaction_point[0] - mask_processed.width // 2 <= 0 \
        or center_intersaction_point[1] - mask_processed.height // 2 <= 0:
            continue
        if center_intersaction_point[0] + mask_processed.width // 2 >= image.width \
        or center_intersaction_point[1] - mask_processed.height // 2 <= 0:
            continue
        if center_intersaction_point[0] + mask_processed.width // 2 >= image.width \
        or center_intersaction_point[1] + mask_processed.height // 2 >= image.height:
            continue
        if center_intersaction_point[0] - mask_processed.width // 2 <= 0 \
        or center_intersaction_point[1] + mask_processed.height // 2 >= image.height:
            continue
        
        # write and read mask_processed, use cv read
        mask_processed.save(os.path.join(temp_path, mask_save_path))
        mask_processed = cv.imread(os.path.join(temp_path, mask_save_path), cv.IMREAD_COLOR)
        os.remove(os.path.join(temp_path, mask_save_path))
        
        pic_processed.save(os.path.join(temp_path, mask_save_path))
        pic_processed = cv.imread(os.path.join(temp_path, mask_save_path), cv.IMREAD_COLOR)
        os.remove(os.path.join(temp_path, mask_save_path))
        
        image.save(os.path.join(temp_path, mask_save_path))
        image = cv.imread(os.path.join(temp_path, mask_save_path), cv.IMREAD_COLOR)
        os.remove(os.path.join(temp_path, mask_save_path))
                
        image = cv.seamlessClone(pic_processed, image, mask_processed, center_intersaction_point, cv.NORMAL_CLONE)
        image = Image.fromarray(image.astype('uint8')).convert('RGB')
        pic_processed = Image.fromarray(pic_processed.astype('uint8')).convert('RGB')
        
        
#         粘贴点标记
#         x, y = center_intersaction_point[0], center_intersaction_point[1]
#         image_draw = ImageDraw.Draw(image)
#         font = ImageFont.truetype("consola.ttf", 40, encoding="unic") #设置字体
#         image_draw.text((x, y), 'p', 'fuchsia', font)
        
#         display(image)
        
        # 旋转后坐标值
        new_points_cor = []
        angel = rotate_times * (math.pi / 2)
        
        # cornerfracture标记点坐标值变换 
        for x, y in points_cor:
            x = x - cor_center_point[0]
            y = y - cor_center_point[1]
            
            new_x = x * math.cos(angel) - y * math.sin(angel)
            new_y = x * math.sin(angel) + y * math.cos(angel)
            final_x = new_x + center_intersaction_point[0] 
            final_y = new_y + center_intersaction_point[1] 
            
            final_x = max(0, final_x)
            final_x = min(final_x, image.width)
            final_y = max(0, final_y)
            final_y = min(final_y, image.height)
            
            new_points_cor.append((final_x, final_y))
        
        # 保存图片及对应的json文件
        save_image_json(dest_path, image, ann, new_points_cor)

In [7]:
def save_image_json(dest_path, image, ann, new_points_cor):
    image.save(os.path.join(dest_path, ann['imagePath']))
#     cv.imwrite(os.path.join(dest_path, ann['imagePath']), image)
    json_info = {"label": "cornerfracture",
                          "line_color": None,
                          "fill_color": None,
                          "points": new_points_cor ,
                          "shape_type": "polygon",
                          "flags": {}}      
    ann['shapes'].append(json_info)
    
    json_file_path = os.path.join(dest_path, '{}.json'.format(ann['imagePath'][:-4]))
    json_file = open(json_file_path, mode='w')
    json.dump(ann, json_file) 

In [8]:
oversampling_on_cornerfracture("D:/data/APD202004v2/train", "D:/data/temp_path/", "D:/data/new/new16/")

100%|████████████████████████████████████████████████████████████████████████████| 2769/2769 [00:01<00:00, 2417.20it/s]
100%|████████████████████████████████████████████████████████████████████████████| 2769/2769 [00:00<00:00, 4001.45it/s]
