아래와 같이 클래스를 선언하고 실행하면 tracking된 결과 이미지를 확인할 수 있습니다.

# import

In [None]:
# Mark absl flags as parsed using only the program name, so any extra
# notebook/kernel arguments present in sys.argv are not passed to absl.
# NOTE(review): presumably required because yolov3_tf2 reads absl FLAGS when
# the model is built — confirm against yolov3_tf2.models.
from absl import flags
import sys
FLAGS = flags.FLAGS
FLAGS(sys.argv[:1])

import time # needed to measure the time spent per frame
import numpy as np
import cv2
import matplotlib.pyplot as plt

import tensorflow as tf
from yolov3_tf2.models import YoloV3
from yolov3_tf2.dataset import transform_images
from yolov3_tf2.utils import convert_boxes

from deep_sort import preprocessing # NMS
from deep_sort import nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet # feature generation

# detection 모델 선언

In [None]:
# Load the class labels, one per line (e.g. [car, person, ...]).
# The context manager closes the file as soon as the list is built instead of
# leaving the handle open until garbage collection.
with open('./data/labels/coco.names') as names_file:
    class_names = [line.strip() for line in names_file]

# Build the detector for that many classes and load the pretrained weights.
yolo = YoloV3(classes=len(class_names))
yolo.load_weights('./weights/yolov3.tf')

# frame 단위로 video capture 

In [None]:
# Open the input clip; frames are read from it one at a time further below.
vid = cv2.VideoCapture('./data/video/test.mp4')

# Read the source geometry and frame rate so the output file can mirror them.
# VideoCapture.get() returns floats, so each value is truncated to int.
vid_width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
vid_height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
vid_fps = int(vid.get(cv2.CAP_PROP_FPS))

# Writer for the annotated result video (XVID-encoded AVI).
codec = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter('./data/video/results.avi', codec, vid_fps, (vid_width, vid_height))

# Tracker

In [7]:
# Appearance-feature encoder used to describe each detected box.
model_filename = 'model_data/mars-small128.pb'
encoder = gdet.create_box_encoder(model_filename, batch_size=1)

# Threshold handed to the cosine nearest-neighbour metric.
# NOTE(review): presumably candidates whose cosine distance exceeds this value
# are rejected as matches — confirm against deep_sort.nn_matching.
max_cosine_distance = 0.5
# No budget is given for the metric (None).
nn_budget = None
# Overlap threshold passed to non-max suppression in the frame loop.
nms_max_overlap = 0.8

# Association metric and the tracker built on top of it.
metric = nn_matching.NearestNeighborDistanceMetric('cosine', max_cosine_distance, nn_budget)
tracker = Tracker(metric)

# Run detection + tracking frame by frame, display the annotated frame and
# append it to the output video.

# The colour palette does not depend on the frame, so build it once:
# 20 RGB colours sampled from the 'tab20b' colormap (values in 0..1).
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

try:
    # Create and size the preview window once; re-issuing the resize on every
    # frame would undo any manual resizing by the user.
    cv2.namedWindow("output", cv2.WINDOW_NORMAL)
    cv2.resizeWindow('output', 1024, 768)

    while True:
        ok, img = vid.read() # img: ndarray (height, width, channel), one frame at a time
        if not ok or img is None:
            print('Completed')
            break

        img_in = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Add a batch axis for the detector: (1, height, width, channel); becomes a tf.Tensor.
        img_in = tf.expand_dims(img_in, 0)
        # Per the original notes: resizes to 416 and divides pixel values by 255.
        img_in = transform_images(img_in, 416)

        t1 = time.time()

        # Detector output, per the original notes (numpy arrays, 100 box slots):
        #   boxes   (1, 100, 4): left, top, right, bottom
        #   scores  (1, 100):    confidence score
        #   classes (1, 100):    class index of the object in each box
        #   nums    (1,):        total number of detected objects
        boxes, scores, classes, nums = yolo.predict(img_in)

        # Map every class index of the single batch item to its label.
        names = np.array([class_names[int(class_id)] for class_id in classes[0]])
        # Per the original notes: boxes become (x_min, y_min, w, h).
        converted_boxes = convert_boxes(img, boxes[0])
        features = encoder(img, converted_boxes)

        detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in
                      zip(converted_boxes, scores[0], names, features)]

        # Non-max suppression over the detections, then keep only the survivors.
        boxs = np.array([d.tlwh for d in detections]) # left, top, width, height
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        tracker.predict() # Kalman-filter prediction step
        tracker.update(detections)

        for track in tracker.tracks:
            # Skip tracks that are not confirmed or were not updated recently.
            if not track.is_confirmed() or track.time_since_update > 1:
                continue

            bbox = track.to_tlbr() # min_x, min_y, max_x, max_y, as used by the cv2 drawing calls
            class_name = track.get_class()
            # Same track id -> same colour; scale 0..1 components to 0..255.
            color = colors[int(track.track_id) % len(colors)]
            color = [i*255 for i in color]

            # Object box (left-top, right-bottom).
            cv2.rectangle(img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]),int(bbox[3])), color, 2)
            # Filled label background above the box, sized from the label length.
            cv2.rectangle(img, (int(bbox[0]), int(bbox[1]-30)), (int(bbox[0])+(len(class_name)
                        +len(str(track.track_id)))*17,int(bbox[1])), color, -1)
            cv2.putText(img, class_name+"-"+str(track.track_id), (int(bbox[0]), int(bbox[1]-10)), 0, 0.75,
                        (255,255,255), 2)

        # Frames per second for the detection + tracking work of this frame.
        fps  = 1./(time.time()-t1)
        cv2.putText(img, "FPS: {:.2f}".format(fps), (0,30), 0, 1, (0,0,255), 2)
        cv2.imshow('output', img)
        out.write(img)

        # Let the GUI process events for 1 ms; pressing 'q' stops the run.
        if cv2.waitKey(1) == ord('q'):
            break
finally:
    # Always release the capture/writer and close the windows, even when the
    # run is interrupted (e.g. KeyboardInterrupt), so the output file is finalized.
    vid.release()
    out.release()
    cv2.destroyAllWindows()




KeyboardInterrupt: 