In [1]:
from transformers import SegformerFeatureExtractor
# Pre-processing object used to turn each input image into model-ready tensors.
# NOTE(review): `reduce_labels` presumably only affects segmentation maps passed
# to the extractor; the inference loop below passes images only - confirm
# against the transformers documentation whether it has any effect here.
feature_extractor = SegformerFeatureExtractor(reduce_labels=True)

from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import torch

from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
import cv2
from imantics import Polygons, Mask
import json
# Single source of truth for the label tables below.
# Each row is (original label number, class name); the row position is the
# class index used as the (string) key of `model_label_dict`.
_LABEL_TABLE = (
    (18, "obstacle"),
    (20, "vegetation"),
    (22, "void"),
    (7, "truck"),
    (19, "building"),
    (13, "sidewalk"),
    (14, "crosswalk"),
    (12, "road"),
    (5, "pedestrian"),
    (17, "stone"),
    (3, "bicycle"),
    (1, "vehicle"),
)

# Class index (as a string) -> class name.
model_label_dict = {
    str(position): row[1] for position, row in enumerate(_LABEL_TABLE)
}

# Original label number -> class name.
ori_label_name_dict = {number: name for number, name in _LABEL_TABLE}

# Class name -> original label number (inverse of `ori_label_name_dict`).
ori_label_name_dict_convert = {name: number for number, name in _LABEL_TABLE}

def ade_palette():
    """Return a fresh list of 12 colour triplets, each a list of three ints.

    NOTE(review): the entries are presumably RGB values indexed by model
    class index (12 colours for 12 classes) - confirm against the plotting
    code that consumes this palette.
    """
    triplets = (
        (204, 5, 255),
        (4, 250, 7),
        (255, 173, 0),
        (255, 0, 20),
        (255, 184, 184),
        (0, 10, 255),
        (255, 5, 153),
        (180, 120, 120),
        (150, 5, 61),
        (220, 220, 220),
        (255, 245, 0),
        (0, 102, 200),
    )
    # Build new inner lists on every call so callers can mutate the result.
    return [list(triplet) for triplet in triplets]

# --- Model -------------------------------------------------------------------
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_path = "../custom_checkpoint_2/model_345.pth"
# `map_location` remaps the checkpoint's tensors onto `device`, so a checkpoint
# that was saved from a GPU can still be loaded on a CPU-only machine.
# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code - only load checkpoints from a trusted source.
model = torch.load(model_path, map_location=device)
model.to(device)
model.eval()  # inference mode for layers such as dropout / batch norm

# --- Paths -------------------------------------------------------------------
save_path = './test_dataset_result_json'  # one JSON file per image is written here
save_dir = './test_dataset_result'
dataset_path = '../test_dataset'
fileEx = r'.jpg'

# Make sure the JSON output directory exists before the loop tries to write to it.
os.makedirs(save_path, exist_ok=True)

# Names of every file in the dataset directory that ends with `fileEx`.
file_list = [file for file in os.listdir(dataset_path) if file.endswith(fileEx)]

## left,top,w,h
def getBboxFromPolygon(polygon):
    """Return the bounding box of `polygon` as ``(left, top, width, height)``.

    `polygon` is iterated once per extreme; for every item, element ``[0]``
    contributes to the horizontal extent and element ``[1]`` to the vertical
    extent. Both elements must expose ``.min()`` / ``.max()`` (NumPy scalars
    or arrays), so an (N, 2) NumPy array of points works.

    The four returned values are plain Python scalars.
    Raises ValueError (from ``min``/``max``) when `polygon` is empty.
    """
    x_low = min(item[0].min() for item in polygon)
    y_low = min(item[1].min() for item in polygon)
    x_high = max(item[0].max() for item in polygon)
    y_high = max(item[1].max() for item in polygon)

    # Going through one array gives all four corners a common dtype before
    # they are converted back to Python scalars.
    corners = np.array(((x_low, y_low), (x_high, y_high)))
    left = corners[0][0].item()
    top = corners[0][1].item()
    right = corners[1][0].item()
    bottom = corners[1][1].item()
    return (left, top, right - left, bottom - top)


# Run segmentation on every image in `file_list`, convert each predicted class
# region into polygons, and write one JSON file per image into `save_path`.
for img_file in file_list:
    img_full_path = f'{dataset_path}/{img_file}'
    img_save_path = f'{save_dir}/{img_file}'  # currently unused in this loop

    # Close the file handle promptly and force three channels so grayscale or
    # RGBA inputs do not break the extractor's per-channel processing.
    with Image.open(img_full_path) as opened_image:
        ori_image = opened_image.convert("RGB")

    encoding = feature_extractor(ori_image, return_tensors="pt")
    pixel_values = encoding.pixel_values.to(device)

    # Inference only: skip autograd bookkeeping so no graph is kept per image.
    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits.cpu()

        # First, rescale logits to the original image size.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=ori_image.size[::-1],  # PIL size is (width, height) -> (height, width)
            mode='bilinear',
            align_corners=False,
        )

    # Second, apply argmax on the class dimension to get one class id per pixel.
    seg = upsampled_logits.argmax(dim=1)[0]

    pred = seg.numpy()
    uniques = np.unique(pred)
    print('uniques', uniques)

    # Original label number -> list of flat polygon coordinate lists.
    polygon_total_dict = {}

    for category in uniques:
        category = int(category)

        label_name = model_label_dict[str(category)]
        label_category = ori_label_name_dict_convert[label_name]

        # Binary mask for this class: 255 where the pixel belongs to it, 0
        # elsewhere. Built from `pred == category` so class 0 needs no special
        # handling.
        class_mask = np.where(pred == category, 255, 0).astype(pred.dtype)
        plain_polygons = Mask(class_mask).polygons()

        # Keep only polygons with more than 3 coordinate values.
        polygon_total_dict.setdefault(label_category, []).extend(
            polygon for polygon in plain_polygons.segmentation if len(polygon) > 3
        )

    seg_list = []

    for real_cate_number, polygon_list in polygon_total_dict.items():
        for polygon in polygon_list:
            # Flat [x0, y0, x1, y1, ...] list -> (N, 2) array of points.
            points = np.array(polygon).reshape(-1, 2)

            # Fewer than three points cannot form a polygon; skip before doing
            # any further work on it.
            if len(points) < 3:
                continue

            xywh = getBboxFromPolygon(points)
            box = {"left": xywh[0], "top": xywh[1], "width": xywh[2], "height": xywh[3]}

            # .item() converts NumPy scalars to plain Python numbers for JSON.
            ppoints = [{"x": point[0].item(), "y": point[1].item()} for point in points]

            seg_list.append({
                "box": box,
                "points": ppoints,
                "label_number": real_cate_number,
                "label_name": ori_label_name_dict[real_cate_number],
            })

    # Strip the extension without assuming it is exactly four characters long.
    filename = os.path.splitext(img_file)[0]
    save_json_file_path = save_path + '/' + filename + '.json'
    json_data = {"image_name": img_file, "seg": seg_list}
    with open(save_json_file_path, 'w', encoding='utf-8') as file:
        json.dump(json_data, file, indent="\t")
    print(f'SAVE COMPLETE {save_json_file_path}')
    print(f'____________________________________________________________')


# Display the polygon result
#     plt.figure(figsize=(15, 10))
#     plt.imshow(image)
#     plt.show()
    
    
##################################################################################    
# Display the model result

#     color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) # height, width, 3
#     palette = np.array(ade_palette())
#     for label, color in enumerate(palette): 
#         color_seg[seg == label, :] = color
#         img = np.array(ori_image) * 0.5 + color_seg * 0.5

    
#     img = img.astype(np.uint8)
#     plt.figure(figsize=(15, 10))
#     plt.imshow(img)
#     plt.show()
#     plt.imshow(color_seg)
#     plt.show()







uniques [0 1 2 3 4 5 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131925_cam02.json
____________________________________________________________
uniques [0 1 2 4 5 6 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131740_cam04.json
____________________________________________________________
uniques [0 1 2 4 5 7 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131910_cam02.json
____________________________________________________________
uniques [0 1 2 4 5 6 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131745_cam04.json
____________________________________________________________
uniques [0 1 2 4 5 6 7]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131840_cam01.json
____________________________________________________________
uniques [0 1 2 4 5 6 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131825_cam01.json
____________________________________________________________
uniques [0 1 2 3 4 5 7 8]
SA

uniques [0 1 2 4 5 6 7]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131755_cam01.json
____________________________________________________________
uniques [0 1 2 3 4 5 6 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131745_cam03.json
____________________________________________________________
uniques [0 1 2 4 5 6 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131915_cam04.json
____________________________________________________________
uniques [0 1 2 4 5 7 8]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131920_cam04.json
____________________________________________________________
uniques [0 1 2 4 5 6 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131830_cam04.json
____________________________________________________________
uniques [0 1 2 4 5 7 8 9]
SAVE COMPLETE ./test_dataset_result_json/SC_PAR_20220831_131840_cam04.json
____________________________________________________________
uniques [0 1 2 3 4 5 6 7

In [2]:
# Scratch check: dropping the last four characters removes the ".jpg" suffix.
a = 'SC_PAR_20220831_131730_cam03.jpg'
b = a[:len(a) - 4]
b

'SC_PAR_20220831_131730_cam03'

In [3]:
# Scratch check: visit only the multiples of 3 below 10 and print each one
# followed by a marker line.
for i in range(0, 10, 3):
    print(i)
    print('d')

0
d
3
d
6
d
9
d
