Create a JSON file of inferences, skipping the frames that would arrive while inference on an earlier frame is still running (i.e., when the inference time exceeds the frame period).

In [None]:
import custom_funcs as cf
import matplotlib.pyplot as plt
import os, torch, cv2, shutil, json, csv
import numpy as np
import pandas as pd
from utils.metrics import bbox_iou
import utils.general as gen
from matplotlib.lines import Line2D

# Settings shared by the decimation and plotting cells below.

# Two alternative locations of the ccom_yolov5 directory; yolo_dir selects the one in use.
dhufish_path = '/home/field/jennaDir/ccom_yolov5'
vm_path = '/home/jennaehnot/Desktop/ccom_yolov5'
yolo_dir = dhufish_path

# Model under test, its weights file, and the per-video folder name holding its inference output.
model = 'model4_3'
weights_path = os.path.join(yolo_dir, f'model_testing/{model}_best.pt')
inference = f'{model}_inference'

# Annotated videos: each entry of buoy_vids is a sub-directory of annotated_dir.
annotated_dir = '/home/field/jennaDir/dive_tools/annotated_vids'
buoy_vids = ['mooringBall_115cmH', 'lobsterPot_140cmH', 'redBall_150cmH']
plot_titles = ['White Mooring Ball', 'Lobster Pot', 'Red Mooring Ball']  # display name for each entry of buoy_vids
plot_save_name = '_conf_by_dist_by_imgsz.png'  # appended to the video name to build the plot filename

# Camera timing: frames per second -> frames per millisecond -> milliseconds per frame.
fps = 10
fpms = fps / 1000
mspf = 1 / fpms

# Inference input sizes; results for each live in an 'imgsz<size>' folder.
img_sz = [640, 960, 1280, 1600, 1920]
images = 'images'
labels = 'labels'

def _decimate_inferences(inf_data, frame_period_ms):
    """Return the subset of ``inf_data`` that a real-time pipeline would have processed.

    Frames are visited in the order their keys appear in ``inf_data``. This
    presumably is ascending frame order; TODO confirm the JSON is written that way.
    Every visited frame is kept. If its total inference time is at least one
    frame period, the frames that would have arrived before that inference
    finished are skipped.

    Args:
        inf_data: mapping of frame name -> raw inference record, as loaded
            from ``inference_results.json``.
        frame_period_ms: time between consecutive frames, in milliseconds.

    Returns:
        A new dict containing only the kept frames, in visiting order.
    """
    frame_keys = list(inf_data)
    decimated = {}
    idx = 0
    while idx < len(frame_keys):
        frame = frame_keys[idx]
        decimated[frame] = inf_data[frame]
        # Entry 3 of the timing stats is the total inference time in ms.
        tot_t = cf.Inference(inf_data[frame]).times[3]
        if tot_t >= frame_period_ms:
            # Round up: a frame that arrived mid-inference is also missed.
            idx += int(np.ceil(tot_t / frame_period_ms))
        else:
            idx += 1
    return decimated


for video in buoy_vids:
    inf_dir = os.path.join(annotated_dir, video, inference)

    for sz in img_sz:
        sz_dir = os.path.join(inf_dir, 'imgsz' + str(sz))

        # Only the raw inference results are needed here; validated results
        # are read by the plotting cell.
        inf_json = os.path.join(sz_dir, 'inference_results.json')
        with open(inf_json, 'r') as jsonfile:
            inf_data = json.load(jsonfile)

        decimated = _decimate_inferences(inf_data, mspf)

        print(video + str(sz) + ' completed')
        save_json_path = os.path.join(sz_dir, 'decimated_results.json')
        with open(save_json_path, 'w') as json_file:
            json.dump(decimated, json_file)


Now match the decimated frames against the validated detections and plot detection confidence versus buoy distance for each input size.

In [None]:
# Scatter color per predicted class id; -1 marks frames without a usable detection.
color_map = {0: 'limegreen', 1: 'magenta', 2: 'royalblue', -1: 'white'}


def _collect_frame_stats(frames, dec_data, val_data):
    """Gather per-frame confidence, total inference time and class id.

    A frame contributes real values only when it survived decimation (its
    image name is a key of ``dec_data``) and has at least one validated
    detection. Only the first validated detection of a frame is used.
    Otherwise the frame gets NaN confidence, NaN time and class -1. Each
    frame's fields are read in full before anything is appended, so the
    three outputs always stay aligned with ``frames``.

    Returns:
        (conf, times, clss): float arrays, each the same length as ``frames``.
    """
    conf, times, clss = [], [], []
    for frame_num in frames:
        img_name = f"frame_{frame_num:05d}.png"
        frame_conf, frame_time, frame_cls = np.nan, np.nan, -1
        if img_name in dec_data:
            try:
                record = val_data[img_name]
                first_det = record['Detections'][0]
                # Detection row layout used here: index 4 = confidence, 5 = class id.
                frame_conf, frame_cls = first_det[4], first_det[5]
                frame_time = record['Time Stats'][3]  # total time in ms
            except (KeyError, IndexError):
                # Missing validated entry or an empty detection list: keep the
                # "no detection" defaults set above.
                frame_conf, frame_time, frame_cls = np.nan, np.nan, -1
        conf.append(frame_conf)
        times.append(frame_time)
        clss.append(frame_cls)
    return (np.asarray(conf, dtype=float),
            np.asarray(times, dtype=float),
            np.asarray(clss, dtype=float))


for q, video in enumerate(buoy_vids):
    inf_dir = os.path.join(annotated_dir, video, inference)
    csvpath = os.path.join(annotated_dir, video, 'frames_dist_corr.csv')
    plot_save_path = os.path.join(inf_dir, video + plot_save_name)
    frames, dists = cf.frames2distances(csvpath)  # interpolated distance for each frame number
    # squeeze=False keeps `axis` 2-D even when img_sz holds a single size.
    figure, axis = plt.subplots(len(img_sz), 1, figsize=(12, 10), sharex=True, sharey=True, squeeze=False)

    for i, sz in enumerate(img_sz):
        sz_dir = os.path.join(inf_dir, 'imgsz' + str(sz))

        dec_json_path = os.path.join(sz_dir, 'decimated_results.json')
        with open(dec_json_path, 'r') as json_file:
            dec_data = json.load(json_file)

        validated_path = os.path.join(sz_dir, 'validated_results.json')
        with open(validated_path, 'r') as jsonfile:
            val_data = json.load(jsonfile)

        conf, times, clss = _collect_frame_stats(frames, dec_data, val_data)
        colors = np.array([color_map[int(c)] for c in clss])
        avg_t = np.nanmean(times)

        ax = axis[i, 0]
        ax.scatter(dists, conf, s=2, c=colors)
        ax.grid(True, which='both', linestyle='--', color='gray', linewidth=0.5, alpha=0.5)
        # Left limit > right limit already draws the distance axis reversed
        # (far on the left), so no separate invert_xaxis() call is needed.
        ax.set_xlim(100, 4)
        ax.set_ylim(0, 1.0)
        ax.set_yticks([0, 0.25, 0.5, 0.75, 1])
        ax.set_xticks(np.linspace(100, 5, 20))
        ax.set_title(f"Input Size = {sz}," + r'  $\bar{t} = $' + f"{avg_t:.2f} ms", fontsize=10)

    # Class 2 ('fishingBuoy', royalblue) has no legend entry; add one if that class can appear.
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='limegreen', markersize=10, label='0: navBuoy'),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='magenta', markersize=10, label='1: buoy'),
    ]

    title = plot_titles[q]
    figure.suptitle(f'Detection of {title} at Different Compression Sizes', fontsize=18, y=0.96)
    figure.text(0.5, 0.03, 'Buoy Distance from Camera', ha='center', va='center', fontsize=18)
    figure.text(0.03, 0.5, 'Confidence', ha='center', va='center', rotation='vertical', fontsize=18)
    figure.legend(handles=legend_elements, loc='center', ncol=3, bbox_to_anchor=(0.5, 0.91))
    figure.tight_layout(rect=[0.04, 0.04, 0.95, 0.95])

    figure.savefig(plot_save_path, dpi=400)


