In [None]:
# 利用xml檔測試
import cv2
import json
import os
import numpy as np
import xml.etree.ElementTree as ET
from glob import glob
from segment_anything import sam_model_registry, SamPredictor
import time

def write_json(image, filename_element, all_contours, save_path):
    """Write the given contours as polygon shapes to a LabelMe-style JSON file.

    The file is named after the image (extension swapped for ``.json``) and
    placed in ``save_path``.

    Args:
        image: Array-like with a ``shape`` attribute; ``shape[0]`` and
            ``shape[1]`` are stored as ``imageHeight`` and ``imageWidth``.
        filename_element: Image file name as a string (despite the parameter
            name, this is the text, not an XML element). Stored verbatim as
            ``imagePath``.
        all_contours: Iterable of contour groups. Each contour must support
            ``reshape((-1, 2))`` and be accepted by ``cv2.contourArea``
            (presumably the arrays returned by ``cv2.findContours``).
        save_path: Directory the JSON file is written to.
    """
    shapes = []
    for contours in all_contours:
        for contour in contours:
            points = contour.reshape((-1, 2)).tolist()
            # A polygon needs at least three vertices; contours with an area
            # of 8 or less are dropped.
            if len(points) > 2 and cv2.contourArea(contour) > 8:
                shapes.append({
                    "label": "NG",
                    "text": "",
                    "points": points,
                    "group_id": None,
                    "shape_type": "polygon",
                    "flags": {},
                })

    d = {
        "version": "4.2.9",
        "flags": {},
        "shapes": shapes,
        "imagePath": filename_element,
        "imageData": None,
        "imageHeight": image.shape[0],
        "imageWidth": image.shape[1],
    }

    # Swap only the extension. A plain str.replace('jpg', 'json') would also
    # rewrite "jpg" inside the stem, and for non-.jpg images it would leave
    # the name untouched so the JSON would overwrite the image itself.
    stem = os.path.splitext(filename_element)[0]
    ph_j = os.path.join(save_path, stem + '.json')
    with open(ph_j, "w") as f:
        json.dump(d, f)

def show_mask(masks, image, input_point, boxes, txt, output_dir, random_color=False):
    """Blend masks onto the image, draw annotations and save the result.

    Nothing is drawn or written when ``masks`` is empty.

    Args:
        masks: Iterable of 2-D masks; each is cast to ``uint8`` and multiplied
            by the overlay colour (assumes 0/1 values with the same height and
            width as ``image`` — TODO confirm against callers).
        image: Image array the overlay is blended onto. It is not modified;
            blending produces a new array.
        input_point: Optional indexable ``(x, y)`` prompt point drawn as a
            filled circle. Skipped when ``None`` or when all entries are zero.
        boxes: Optional iterable of ``(x0, y0, x1, y1)`` boxes drawn as
            rectangles.
        txt: Optional sequence of scores, one per box, rendered as
            ``IOUs:<score>`` just above each box's top-left corner. Ignored
            when ``boxes`` is ``None``.
        output_dir: Path of the image file to write (a file path, despite the
            name).
        random_color: Use a random colour per mask instead of the fixed
            ``[200, 0, 0]``.
    """
    blended = False
    for mask in masks:
        if random_color:
            color = np.array([int(k) for k in np.random.random(3) * 250], dtype=np.uint8)
        else:
            color = np.array([200, 0, 0], dtype=np.uint8)

        h, w = mask.shape[0], mask.shape[1]
        mask_image = mask.reshape(h, w, 1).astype(np.uint8) * color.reshape(1, 1, -1)
        image = cv2.addWeighted(image, 0.7, mask_image, 0.4, 30)
        blended = True

    if not blended:
        return

    # Annotations are drawn once, after all masks are blended. Redrawing them
    # per mask only overwrote the same pixels with the same colours, so the
    # final image is the same.
    if boxes is not None:
        for box in boxes:
            box0 = (int(box[0]), int(box[1]))
            box1 = (int(box[2]), int(box[3]))
            image = cv2.rectangle(image, box0, box1, (100, 50, 0), 1)
    if input_point is not None and input_point.any():
        point = (int(input_point[0]), int(input_point[1]))
        image = cv2.circle(image, point, 3, (100, 50, 0), -1)
    # Labels need a box to anchor to, so they are skipped without boxes.
    if txt and boxes is not None:
        for box, score in zip(boxes, txt):
            text = 'IOUs:{:.3f}'.format(score)
            xx = int(box[0]) - 5
            yy = int(box[1]) - 5
            cv2.putText(image, text, (xx, yy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 1)
    cv2.imwrite(output_dir, image)

def _clamped_span(center, length, half):
    """Return (start, stop) of a 2*half window around center kept inside [0, length]."""
    size = 2 * half
    # Shift the window inward at either edge; when the axis is shorter than
    # the window the span collapses to the whole axis.
    start = min(max(center - half, 0), max(length - size, 0))
    return start, min(length, start + size)


def crop_img(image, center_x, center_y, crop_size = 800):
    """Crop a square window around a centre point, kept inside the image.

    The window spans ``crop_size // 2`` pixels on each side of the centre.
    Near a border it is shifted inward rather than shrunk, so the crop keeps
    its full size whenever the image is large enough. When the image is
    smaller than the window along an axis, the whole axis is returned.

    Args:
        image: Array indexed as ``image[y, x]`` with ``shape[0]`` = height and
            ``shape[1]`` = width.
        center_x: Integer x coordinate of the desired window centre.
        center_y: Integer y coordinate of the desired window centre.
        crop_size: Requested side length of the window in pixels.

    Returns:
        The cropped region, obtained by slicing ``image`` (for a NumPy array
        this is a view, not a copy).
    """
    height, width = image.shape[0], image.shape[1]
    half = crop_size // 2
    x1, x2 = _clamped_span(center_x, width, half)
    y1, y2 = _clamped_span(center_y, height, half)
    return image[y1:y2, x1:x2]

# Segment Anything configuration: checkpoint file (relative path), the
# registry key selecting the model builder, and the device to run on.
sam_checkpoint = "sam_pb4.pth"
model_type = "vit_l"
device = "cuda"

# Build the model from the checkpoint, move it to the device, then wrap it
# in a predictor that is reused for every image below.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam = sam.to(device=device)
predictor = SamPredictor(sam)

# Masks are exported only when SAM's predicted quality score is above this.
SCORE_THRESHOLD = 0.8

# Every XML annotation two directory levels below the dataset root.
xml_path = glob('D:\\yangu\\dataset\\wise\\TL7903BD_Final_ATST_new\\*\\*\\*.xml')
for path in xml_path:
    tree = ET.parse(path)
    root = tree.getroot()
    filename_element = root.find('filename')
    # An annotation without a <filename> entry cannot be matched to an image.
    if filename_element is None or not filename_element.text:
        print('skip {}: no <filename> entry'.format(path))
        continue

    # The image is expected next to its XML file.
    pic_basepath = os.path.dirname(path)
    img_path = os.path.join(pic_basepath, filename_element.text)
    if not os.path.isfile(img_path):
        continue

    image = cv2.imread(img_path)
    # cv2.imread returns None (it does not raise) when the file cannot be
    # decoded; passing None on would crash inside the predictor.
    if image is None:
        print('skip {}: image could not be read'.format(img_path))
        continue

    # NOTE(review): cv2.imread yields BGR while SamPredictor.set_image
    # defaults to image_format="RGB". Confirm how the sam_pb4.pth checkpoint
    # was trained before passing image_format="BGR" here.
    predictor.set_image(image)
    IOUs = []
    boxes = []
    maskes = []
    all_contours = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        if bndbox is None:
            continue
        xmin = int(bndbox.find('xmin').text)
        xmax = int(bndbox.find('xmax').text)
        ymin = int(bndbox.find('ymin').text)
        ymax = int(bndbox.find('ymax').text)
        box = [xmin, ymin, xmax, ymax]
        # Prompt SAM with the box plus its centre as a foreground point
        # (label 1).
        point = np.array([[int((xmin+xmax)/2), int((ymin+ymax)/2)]])
        masks, scores, _ = predictor.predict(
            point_coords=point,
            point_labels=np.array([1]),
            box=np.array(box),
            multimask_output=False)
        # Only confident masks are kept; contours are extracted after the
        # check so rejected masks cost no extra work.
        if scores[0] > SCORE_THRESHOLD:
            mask = np.uint8(masks[0])
            contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            boxes.append(box)
            maskes.append(mask)
            IOUs.append(round(scores[0], 3))
            all_contours.append(contours)

    if maskes:
        # For a debug overlay, show_mask(maskes, image, None, boxes, IOUs,
        # <output image path>) can be called here.
        # The JSON file is written next to the image.
        Json_gdsam_path = os.path.dirname(img_path)
        write_json(image, filename_element.text, all_contours, Json_gdsam_path)