In [12]:
#-----------------------------------------------------------------------#
#   predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
#   整合到了一个py文件中，通过指定mode进行模式的修改。
#-----------------------------------------------------------------------#
import time
import nbimporter

import cv2
import numpy as np
from PIL import Image

from yolo import YOLO

if __name__ == "__main__":
    yolo = YOLO()
    #----------------------------------------------------------------------------------------------------------#
    #   mode selects what this script does:
    #   'predict'           single-image prediction; see the notes in the 'predict' branch below for how to
    #                       save results, crop detected objects, count objects, etc.
    #   'video'             detection on a camera stream or a video file; see the notes below.
    #   'fps'               FPS benchmark using the image at fps_image_path (img/street.jpg by default).
    #   'dir_predict'       detect every image in a folder and save the results
    #                       (reads img/ and writes img_out/ by default).
    #   'heatmap'           visualise the prediction result as a heat map.
    #   'export_onnx'       export the model to ONNX (requires PyTorch >= 1.7.1).
    #----------------------------------------------------------------------------------------------------------#
    mode = "video"
    #-------------------------------------------------------------------------#
    #   crop                whether to crop detected objects after single-image prediction
    #   count               whether to count detected objects
    #   crop and count only take effect when mode='predict'
    #-------------------------------------------------------------------------#
    crop            = False
    count           = False
    #----------------------------------------------------------------------------------------------------------#
    #   video_path          path of the video to process; video_path=0 uses the camera.
    #                       To process a file, set e.g. video_path = "xxx.mp4" (xxx.mp4 in the working directory).
    #   video_save_path     where to save the processed video; video_save_path="" disables saving.
    #                       To save, set e.g. video_save_path = "yyy.mp4" (yyy.mp4 in the working directory).
    #   video_fps           frame rate of the saved video
    #
    #   video_path, video_save_path and video_fps only take effect when mode='video'.
    #   The saved file is finalised when the loop ends: at end of stream, on ESC, when the
    #   window is closed, or on Ctrl+C (the writer is released in a `finally` clause).
    #----------------------------------------------------------------------------------------------------------#
    video_path      = "vedio1.mp4"
    video_save_path = "video2.mp4"
    video_fps       = 25.0
    #----------------------------------------------------------------------------------------------------------#
    #   test_interval       number of detections run when measuring FPS; larger values give a more
    #                       accurate measurement.
    #   fps_image_path      image used for the FPS test
    #
    #   test_interval and fps_image_path only take effect when mode='fps'
    #----------------------------------------------------------------------------------------------------------#
    test_interval   = 100
    fps_image_path  = "img/street.jpg"
    #-------------------------------------------------------------------------#
    #   dir_origin_path     folder containing the images to detect
    #   dir_save_path       folder where detected images are saved
    #
    #   dir_origin_path and dir_save_path only take effect when mode='dir_predict'
    #-------------------------------------------------------------------------#
    dir_origin_path = "img/"
    dir_save_path   = "img_out/"
    #-------------------------------------------------------------------------#
    #   heatmap_save_path   where the heat map is saved (under model_data by default)
    #
    #   heatmap_save_path only takes effect when mode='heatmap'
    #-------------------------------------------------------------------------#
    heatmap_save_path = "model_data/heatmap_vision.png"
    #-------------------------------------------------------------------------#
    #   simplify            whether to run onnx-simplifier on the exported model
    #   onnx_save_path      where the ONNX model is saved
    #-------------------------------------------------------------------------#
    simplify        = True
    onnx_save_path  = "model_data/models.onnx"

    if mode == "predict":
        # Notes (these refer to code inside yolo.detect_image, which is not in this file):
        # 1. To keep the detected image, call r_image.save("img.jpg") right here.
        # 2. To get box coordinates, read top, left, bottom, right in the drawing part of
        #    yolo.detect_image.
        # 3. To crop detected objects, slice the original image with those four values
        #    inside yolo.detect_image.
        # 4. To draw extra text (e.g. the number of a specific class), test predicted_class in
        #    the drawing part of yolo.detect_image, e.g. `if predicted_class == 'car':`,
        #    keep a count, and write it with draw.text.
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except (OSError, ValueError):
                # Narrow catch: a bare `except:` also swallowed KeyboardInterrupt, which made
                # this prompt loop impossible to interrupt. Missing/unreadable files raise
                # OSError (PIL's UnidentifiedImageError is an OSError subclass).
                print('Open Error! Try again!')
                continue
            r_image = yolo.detect_image(image, crop=crop, count=count)
            # `display` is not imported in this file; presumably it is IPython's notebook
            # builtin, so this branch assumes a Jupyter/IPython session.
            display(r_image)

    elif mode == "video":
        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            capture.release()
            raise ValueError("未能正确读取摄像头（视频），请注意是否正确安装摄像头（是否正确填写视频路径）。")

        out = None
        if video_save_path != "":
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            # Output frames are written at the capture's native resolution.
            size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)

        fps = 0.0
        frames_read = 0
        try:
            while True:
                t1 = time.perf_counter()

                ref, frame = capture.read()
                if not ref:
                    # No frame at all means the source is unusable; otherwise this is
                    # simply the end of the stream.
                    if frames_read == 0:
                        print("未能正确读取摄像头，请检查摄像头是否正确连接。")
                    break
                frames_read += 1

                # OpenCV delivers BGR; the detector works on RGB PIL images.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = Image.fromarray(np.uint8(frame))
                frame = np.array(yolo.detect_image(frame))
                # Back to BGR for OpenCV display / writing.
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                # Smoothed FPS: seed with the first measurement instead of averaging with 0.
                current_fps = 1. / (time.perf_counter() - t1)
                fps = current_fps if frames_read == 1 else (fps + current_fps) / 2
                print("fps= %.2f" % (fps))
                frame = cv2.putText(frame, "fps= %.2f" % (fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

                cv2.imshow("video", frame)
                # Write before the exit checks so the last displayed frame is saved too.
                if out is not None:
                    out.write(frame)

                k = cv2.waitKey(1) & 0xff
                # Stop on ESC or when the user closes the display window.
                if k == 27 or cv2.getWindowProperty("video", cv2.WND_PROP_VISIBLE) < 1:
                    break
        finally:
            # Always release, even on Ctrl+C/exceptions, so the saved video is finalised.
            capture.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

        print("Video Detection Done!")
        if video_save_path != "":
            print("Save processed video to the path :" + video_save_path)

    elif mode == "fps":
        img = Image.open(fps_image_path)
        # get_FPS is expected to return the time per detection in seconds
        # (its reciprocal is reported as FPS below).
        tact_time = yolo.get_FPS(img, test_interval)
        print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')

    elif mode == "dir_predict":
        import os

        from tqdm import tqdm

        image_exts = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
        os.makedirs(dir_save_path, exist_ok=True)

        img_names = os.listdir(dir_origin_path)
        for img_name in tqdm(img_names):
            if not img_name.lower().endswith(image_exts):
                continue
            image_path  = os.path.join(dir_origin_path, img_name)
            image       = Image.open(image_path)
            r_image     = yolo.detect_image(image)
            # .jpg inputs (any case) are saved as lossless .png; other formats keep their name.
            stem, ext = os.path.splitext(img_name)
            save_name = stem + ".png" if ext.lower() == ".jpg" else img_name
            r_image.save(os.path.join(dir_save_path, save_name), quality=95, subsampling=0)

    elif mode == "heatmap":
        while True:
            img = input('Input image filename:')
            try:
                image = Image.open(img)
            except (OSError, ValueError):
                # Narrow catch so KeyboardInterrupt can still stop the prompt loop.
                print('Open Error! Try again!')
                continue
            yolo.detect_heatmap(image, heatmap_save_path)

    elif mode == "export_onnx":
        yolo.convert_to_onnx(simplify, onnx_save_path)

    else:
        raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")


initialize network with normal type
Fusing layers... 
logs_CA_SEA_backbone/best_epoch_weights.pth model, and classes loaded.
Configurations:
----------------------------------------------------------------------
|                     keys |                                   values|
----------------------------------------------------------------------
|               model_path | logs_CA_SEA_backbone/best_epoch_weights.pth|
|             classes_path |              model_data/fire_classes.txt|
|              input_shape |                               [640, 640]|
|                      phi |                                        s|
|               confidence |                                      0.5|
|                  nms_iou |                                      0.3|
|          letterbox_image |                                     True|
|                     cuda |                                    False|
----------------------------------------------------------------------
b'fi

b'fire 0.64' 770 96 865 240
b'fire 0.59' 714 1696 831 1848
b'fire 0.57' 769 1231 808 1274
fps= 3.19
b'fire 0.64' 770 96 865 240
b'fire 0.57' 769 1232 808 1274
b'fire 0.57' 713 1696 831 1849
fps= 3.18
b'fire 0.65' 770 96 865 240
b'fire 0.57' 769 1232 807 1274
b'fire 0.55' 715 1697 831 1850
b'fire 0.52' 680 972 757 1143
fps= 3.17
b'fire 0.65' 770 96 865 240
b'fire 0.58' 769 1232 808 1274
b'fire 0.55' 715 1697 831 1851
b'fire 0.51' 679 972 758 1148
fps= 3.09
b'fire 0.64' 769 96 865 240
b'fire 0.58' 769 1232 808 1274
b'fire 0.55' 714 1697 831 1851
b'fire 0.51' 679 972 758 1149
fps= 3.13
b'fire 0.64' 769 96 865 240
b'fire 0.58' 713 1697 831 1851
b'fire 0.58' 768 1232 807 1274
b'fire 0.52' 677 973 758 1151
fps= 3.12
b'fire 0.64' 769 96 865 240
b'fire 0.59' 713 1697 831 1851
b'fire 0.58' 769 1233 807 1274
b'fire 0.52' 677 973 758 1151
fps= 3.09
b'fire 0.64' 769 96 865 240
b'fire 0.58' 713 1698 830 1851
b'fire 0.58' 769 1232 807 1274
b'fire 0.52' 676 973 758 1151
fps= 3.11
b'fire 0.64' 769 96 

b'fire 0.65' 764 93 866 239
b'fire 0.61' 432 1261 627 1488
b'fire 0.58' 733 1700 831 1847
b'fire 0.52' 765 1229 808 1274
b'smoke 0.53' 10 5 837 1905
fps= 3.06
b'fire 0.64' 763 92 866 239
b'fire 0.64' 436 1264 626 1489
b'fire 0.58' 747 1702 832 1847
b'fire 0.52' 764 1229 808 1274
fps= 3.09
b'fire 0.65' 438 1262 625 1488
b'fire 0.65' 764 92 866 239
b'fire 0.59' 738 1701 829 1845
b'fire 0.52' 763 1229 808 1274
fps= 3.15
b'fire 0.66' 445 1261 626 1490
b'fire 0.65' 765 92 866 239
b'fire 0.59' 748 1702 832 1847
b'fire 0.54' 763 1229 808 1273
fps= 3.11
b'fire 0.65' 765 92 866 239
b'fire 0.64' 440 1251 626 1489
b'fire 0.60' 750 1701 832 1846
b'fire 0.54' 763 1229 808 1273
fps= 3.07
b'fire 0.65' 765 92 866 239
b'fire 0.63' 441 1252 626 1489
b'fire 0.59' 750 1701 832 1846
b'fire 0.54' 763 1229 808 1273
fps= 3.10
b'fire 0.67' 439 1247 625 1489
b'fire 0.65' 765 92 867 239
b'fire 0.58' 750 1701 832 1846
b'fire 0.53' 763 1229 808 1273
fps= 3.05
b'fire 0.66' 765 92 867 239
b'fire 0.60' 433 1240 625 1

b'fire 0.70' 680 977 754 1120
b'fire 0.52' 771 95 867 238
b'fire 0.51' 750 1229 807 1277
fps= 3.04
b'fire 0.68' 680 977 754 1120
b'fire 0.53' 771 95 867 238
fps= 3.08
b'fire 0.66' 679 977 754 1120
b'fire 0.52' 771 95 867 237
fps= 3.24
b'fire 0.66' 679 977 754 1120
b'fire 0.52' 771 95 868 237
fps= 3.19
b'fire 0.66' 679 977 754 1120
b'fire 0.52' 771 95 867 237
fps= 3.14
b'fire 0.61' 682 976 755 1120
b'fire 0.54' 772 95 867 237
fps= 3.25
b'fire 0.62' 681 976 756 1119
b'fire 0.53' 770 94 867 236
fps= 3.29
b'fire 0.62' 680 976 756 1120
b'fire 0.54' 770 94 867 236
fps= 3.35
b'fire 0.63' 680 976 756 1120
b'fire 0.54' 770 94 867 236
fps= 3.35
b'fire 0.63' 680 976 756 1119
b'fire 0.57' 769 93 867 236
fps= 3.17
b'fire 0.63' 680 976 755 1119
b'fire 0.58' 769 94 867 236
b'fire 0.51' 749 1701 829 1842
fps= 3.21
b'fire 0.62' 680 976 755 1119
b'fire 0.58' 769 94 867 236
b'fire 0.51' 749 1701 829 1842
fps= 3.20
b'fire 0.60' 771 95 867 236
b'fire 0.58' 681 976 756 1119
fps= 3.08
b'fire 0.59' 770 95 867

# 