# Convert JSON annotations to PNG masks
Provided by Prof. Lee

# Import Packages

In [None]:
import json
import math
import os

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import PIL.ImageDraw
from PIL import Image

# Define shape_to_mask function

In [None]:
def shape_to_mask(
    img_shape, points, shape_type=None, line_width=10, point_size=5
):
    """Rasterize one annotated shape into a boolean mask.

    Args:
        img_shape: Image shape; only the first two entries (height, width)
            are used to size the mask.
        points: Sequence of (x, y) points describing the shape, in PIL
            pixel coordinates (x = column, y = row).
        shape_type: One of "circle", "rectangle", "line", "linestrip" or
            "point". Any other value (including None) is drawn as a
            filled polygon.
        line_width: Stroke width in pixels for "line" and "linestrip".
        point_size: Radius in pixels of the disc drawn for "point".

    Returns:
        numpy.ndarray of dtype bool with shape ``img_shape[:2]``; True where
        the shape covers a pixel.

    Raises:
        AssertionError: If the number of points does not fit ``shape_type``.
    """
    # Draw onto a single-channel 8-bit PIL image of the target size.
    mask = PIL.Image.fromarray(np.zeros(img_shape[:2], dtype=np.uint8))
    draw = PIL.ImageDraw.Draw(mask)

    # PIL drawing calls expect a sequence of (x, y) tuples.
    xy = [tuple(point) for point in points]

    if shape_type == "circle":
        assert len(xy) == 2, "Shape of shape_type=circle must have 2 points"
        # First point is the centre; the second lies on the circumference.
        (cx, cy), (px, py) = xy
        # `math` is imported at file level; previously it was missing, so
        # every circle raised NameError.
        d = math.hypot(cx - px, cy - py)
        draw.ellipse([cx - d, cy - d, cx + d, cy + d], outline=1, fill=1)
    elif shape_type == "rectangle":
        assert len(xy) == 2, "Shape of shape_type=rectangle must have 2 points"
        # Two opposite corners of the rectangle.
        draw.rectangle(xy, outline=1, fill=1)
    elif shape_type == "line":
        assert len(xy) == 2, "Shape of shape_type=line must have 2 points"
        draw.line(xy=xy, fill=1, width=line_width)
    elif shape_type == "linestrip":
        # Open polyline through all points (not filled).
        draw.line(xy=xy, fill=1, width=line_width)
    elif shape_type == "point":
        assert len(xy) == 1, "Shape of shape_type=point must have 1 points"
        cx, cy = xy[0]
        r = point_size
        draw.ellipse([cx - r, cy - r, cx + r, cy + r], outline=1, fill=1)
    else:
        # Default: closed, filled polygon.
        assert len(xy) > 2, "Polygon must have points more than 2"
        draw.polygon(xy=xy, outline=1, fill=1)

    # Drawn pixels hold 1, background holds 0 -> convert to boolean.
    return np.array(mask, dtype=bool)

# Set the folder to process

In [None]:
# Folder holding the annotation .json files -- replace the placeholder with a
# real path. Output masks go to the same path with "Train_Annotations"
# replaced by "Train_Annotations_png".
folder_path: str = "{YOUR PATH}"

# Create the visualization function

In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword is used as the
    panel title, with underscores turned into spaces and then title-cased.
    The value is passed to ``plt.imshow``. Axis ticks are hidden, and the
    figure is displayed with ``plt.show()``.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 16))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, position)
        # Hide tick marks; only the picture and its title are of interest.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture)
    plt.show()

# Process all JSON files in the folder

In [None]:
# Masks are written to a sibling folder whose path is folder_path with
# "Train_Annotations" replaced by "Train_Annotations_png".
write_folder_path = folder_path.replace("Train_Annotations", "Train_Annotations_png")
if os.path.normpath(write_folder_path) == os.path.normpath(folder_path):
    # folder_path does not contain "Train_Annotations": masks would land next
    # to the input JSON files.
    print(f"Warning: output folder equals input folder: {folder_path}")
# exist_ok replaces the old bare `except:`, which also hid permission or
# invalid-path errors behind an "already exists" message.
os.makedirs(write_folder_path, exist_ok=True)

for filename in os.listdir(folder_path):
    # Only handle files with a .json extension (any case); a substring test
    # would also match names such as "notes_json.txt".
    stem, ext = os.path.splitext(filename)
    if ext.lower() != ".json":
        continue

    json_path = os.path.join(folder_path, filename)
    # Replace only the extension; str.replace("json", "png") would also
    # rewrite "json" appearing elsewhere in the file name.
    write_msk_img_name = stem + ".png"

    # Read the JSON annotation file.
    with open(json_path, "r", encoding="utf-8") as f:
        dj = json.load(f)

    height, width = dj["imageHeight"], dj["imageWidth"]

    # Union of all shapes: a pixel is foreground if any shape covers it.
    temp_mask_img = np.zeros((height, width), dtype=bool)
    for shape in dj["shapes"]:
        mask = shape_to_mask(
            (height, width),
            shape["points"],
            # A missing shape_type falls back to shape_to_mask's polygon branch.
            shape_type=shape.get("shape_type"),
            line_width=1,
            point_size=1,
        )
        temp_mask_img |= mask

    print(
        f"temp_mask_img.range is {int(temp_mask_img.min())} "
        f"to {int(temp_mask_img.max())}"
    )

    # Save as an 8-bit PNG (foreground 255, background 0). The path is joined
    # properly; plain string concatenation dropped the directory separator.
    out_path = os.path.join(write_folder_path, write_msk_img_name)
    if not cv2.imwrite(out_path, (temp_mask_img * 255).astype(np.uint8)):
        print(f"Failed to write mask: {out_path}")