## Step 3: Read and process the inference record, then visualize
**Process**

Input: the inference TFRecord and the detection score threshold.

Output: the groups of anomaly frames (both filtered and unfiltered), stored in .csv files.

**Visualize**

Input: the original folder of images (the same folder as in Step 1).

Output: labeled images, grouped into one sub-folder per anomaly group inside the destination folder.


## Imports

In [1]:
from tkinter import *
import tkinter.filedialog
import tkinter.messagebox
import subprocess,time
import threading

import tensorflow as tf
import json
import numpy as np
from google.protobuf.json_format import MessageToJson
from utils import visualization_utils as vis_util
from utils import label_map_util
from PIL import Image
import base64,os,csv,shutil

## Variables

Set default values here to save time when re-running the notebook.

If you do, also replace the corresponding 'not selected' placeholder strings in the GUI cell at the bottom, so the form shows the preset paths.

In [2]:
# ---- Settings: the GUI overwrites these before each run ----
detection_tfrecord = ''      # inference TFRecord path (set by askinput)
processed_destination = ''   # output folder for CSVs and images (set by askdestination)
processed_prefix = ''        # prefix for output file/folder names (read from the form)
image_dir = ''               # folder holding the original frames (set by askimagedir)
group_choice = ''            # 'filtered_groups' or 'positive_groups' (read in image())
process_done = False         # True once process() has exported its results
NUM_CLASSES = 1
threshold = 0.9              # detection score threshold (read from the form)

# Label map used to name the boxes drawn by visualize().
PATH_TO_LABELS = os.path.join('data', 'anomaly_label_map.pbtxt')
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)

# ---- Per-frame results filled by process(), one entry per record, in record order ----
label_list = []        # frame file name of every record
positive_list = []     # labels of frames whose highest score exceeds the threshold
filtered_list = []     # labels of positive frames that also pass the box-shape rules
positive_flags = []    # 1/0 per frame: positive or not
filtered_flags = []    # 1/0 per frame: filtered-positive or not
group_list = []        # not used by the code visible in this notebook
highest_scores = []    # highest detection score per frame

# ---- Grouped results set by process(); initialised so they always exist ----
positive_groups = []
filtered_groups = []

## Functions

In [3]:
def display(string):
    """Append ``string`` as a new line at the end of the GUI log text box."""
    line = string + '\n'
    Gui.t.insert(END, line)

def askdestination():
    """Ask the user for the output folder and show it in the form.

    Cancelling the dialog keeps the previously chosen folder (if any) instead
    of clearing it.
    """
    global processed_destination
    selection = tkinter.filedialog.askdirectory()
    if not selection:  # dialog cancelled ('' or, on some Tk builds, ())
        return
    processed_destination = selection
    Gui.entry_3.configure(text=processed_destination)
    
def askimagedir():
    """Ask the user for the folder of original frames and show it in the form.

    Cancelling the dialog keeps the previously chosen folder (if any) instead
    of clearing it.
    """
    global image_dir
    selection = tkinter.filedialog.askdirectory()
    if not selection:  # dialog cancelled ('' or, on some Tk builds, ())
        return
    image_dir = selection
    Gui.entry_5.configure(text=image_dir)

def askinput():
    """Ask the user for the inference TFRecord file and show it in the form.

    Cancelling the dialog keeps the previously chosen file (if any). Some Tk
    builds return ``()`` instead of ``''`` on cancel, which would otherwise be
    stored and slip past the empty-path check in process_button().
    """
    global detection_tfrecord
    selection = tkinter.filedialog.askopenfilename()
    if not selection:
        return
    detection_tfrecord = selection
    Gui.entry_1.configure(text=detection_tfrecord)

def _float_feature(dic, key):
    """Return the float list stored under ``key`` in a JSON-decoded feature map.

    MessageToJson may omit empty repeated fields, so a frame without
    detections can lack the 'value' entry; that case is treated as [].
    """
    return dic.get(key, {}).get('floatList', {}).get('value', [])


def _passes_box_rules(xmin, xmax, ymin, ymax):
    """Return True if one detection box passes the hand-tuned shape/position rules.

    Coordinates appear to be normalized (visualize() draws these boxes with
    use_normalized_coordinates=True).
    """
    width = xmax - xmin
    height = ymax - ymin
    # Checking the size bounds first also rejects zero/negative heights
    # before the ratio is computed, avoiding ZeroDivisionError.
    if not (0.1 < width < 0.5 and 0.1 < height < 0.5):
        return False
    ratio = width / height
    centroid_x = (xmax + xmin) / 2
    centroid_y = (ymax + ymin) / 2
    return 0.8 < ratio < 2 and centroid_y > 0.3 and centroid_x > 0.3


def _group_flags(flags, labels, scores, buffer=3, minimum=3):
    """Group per-frame 0/1 anomaly flags into runs of anomalous frames.

    A group opens at a flagged frame and closes once ``buffer`` consecutive
    unflagged frames follow; it ends at its last flagged frame. A group still
    open after the final frame is closed the same way. Groups spanning fewer
    than ``minimum`` frames are dropped.

    Returns a list of rows: [start label, end label, number of flagged frames,
    flagged/span ratio, sum of highest scores over the span (inclusive)].
    """
    runs = []  # (first index, last flagged index, flagged count)
    start = last_hit = None
    hits = 0
    for i, flag in enumerate(flags):
        if flag:
            if start is None:
                start, hits = i, 0
            hits += 1
            last_hit = i
        elif start is not None and i - last_hit >= buffer:
            runs.append((start, last_hit, hits))
            start = None
    if start is not None:
        # Anomaly still ongoing at the last frame: keep it instead of dropping it.
        runs.append((start, last_hit, hits))

    groups = []
    for first, last, count in runs:
        length = last - first + 1
        if length >= minimum:
            groups.append([labels[first], labels[last], count, count / length,
                           sum(scores[first:last + 1])])
    return groups


def _rank_groups(groups):
    """Sort groups by anomaly ratio, then score sum, then anomaly count (descending)."""
    return sorted(groups, key=lambda g: (float(g[3]), float(g[4]), int(g[2])), reverse=True)


def _write_groups_csv(path, rows):
    """Write ``rows`` below the standard group header to the CSV file at ``path``."""
    fieldnames = ['start frame', 'end frame', 'number of anomaly frames',
                  'anomaly frames ratio', 'sum of highest scores']
    with open(path, 'w', newline='') as output:
        writer = csv.writer(output, lineterminator='\n')
        writer.writerow(fieldnames)
        writer.writerows(rows)


def process():
    """Classify every frame in the inference TFRecord, group anomalies, export CSVs.

    Reads the module-level settings detection_tfrecord, threshold,
    processed_destination and processed_prefix. Rebuilds label_list,
    positive_list, filtered_list, positive_flags, filtered_flags and
    highest_scores (one entry per record, in record order). Sets
    positive_groups / filtered_groups, writes four CSV files (raw and ranked,
    for both group kinds) and finally sets process_done.
    """
    global positive_groups
    global filtered_groups
    global process_done
    display('Start Processing...')
    process_done = False

    # Reset per-run state; otherwise a second run appends to the first and the
    # frame indices used later by image() no longer match the record order.
    for per_frame in (label_list, positive_list, filtered_list,
                      positive_flags, filtered_flags, highest_scores):
        per_frame.clear()

    # Identify positive and filtered frames.
    for example in tf.python_io.tf_record_iterator(detection_tfrecord):
        result = tf.train.Example.FromString(example)
        dic = json.loads(MessageToJson(result))['features']['feature']
        # MessageToJson renders bytes fields as base64.
        label = base64.b64decode(dic['label']['bytesList']['value'][0]).decode()
        scores = _float_feature(dic, 'image/detection/score')
        ymax = _float_feature(dic, 'image/detection/bbox/ymax')
        ymin = _float_feature(dic, 'image/detection/bbox/ymin')
        xmax = _float_feature(dic, 'image/detection/bbox/xmax')
        xmin = _float_feature(dic, 'image/detection/bbox/xmin')

        label_list.append(label)
        highest_score = max(scores, default=0.0)  # frame without detections scores 0
        highest_scores.append(highest_score)

        is_positive = highest_score > threshold
        is_filtered = is_positive and any(
            score > threshold and _passes_box_rules(xmin[i], xmax[i], ymin[i], ymax[i])
            for i, score in enumerate(scores))
        positive_flags.append(int(is_positive))
        filtered_flags.append(int(is_filtered))
        if is_positive:
            positive_list.append(label)
        if is_filtered:
            filtered_list.append(label)

    display('Number of positive frames: ' + str(len(positive_list)))
    display('Number of filtered frames: ' + str(len(filtered_list)))

    # Group up the frames (3-frame exit buffer, 3-frame minimum span).
    positive_groups = _group_flags(positive_flags, label_list, highest_scores)
    filtered_groups = _group_flags(filtered_flags, label_list, highest_scores)

    # Output results to .csv files.
    display('Exporting processed results...')
    outputs = (
        ('_positive.csv', positive_groups),
        ('_filtered.csv', filtered_groups),
        ('_positive_ranked.csv', _rank_groups(positive_groups)),
        ('_filtered_ranked.csv', _rank_groups(filtered_groups)),
    )
    for suffix, rows in outputs:
        _write_groups_csv(os.path.join(processed_destination, processed_prefix + suffix), rows)

    process_done = True
    display('All Processed results exported to ' + processed_destination)
    
    
def process_button():
    """Validate the form and start process() on a background thread.

    Reads the output prefix and score threshold from the form. Shows a warning
    dialog instead of starting when the threshold is not a number, the
    TFRecord or destination has not been chosen, or the prefix is empty.
    """
    global processed_prefix
    global threshold
    Gui.t.delete(1.0, END)
    processed_prefix = Gui.entry_4.get()
    try:
        threshold = float(Gui.entry_2.get())
    except ValueError:
        # An uncaught exception here would only print a traceback to the console.
        tkinter.messagebox.showwarning('Alert', 'Please type in a numeric score threshold.')
        return
    if not detection_tfrecord or not processed_destination:
        tkinter.messagebox.showwarning('Alert', 'Please select source and destination directory.')
    elif processed_prefix == '':
        tkinter.messagebox.showwarning('Alert', 'Please type in prefix for outputs.')
    else:
        threading.Thread(target=process).start()

def load_image_into_numpy_array(image):
    """Return a writable (height, width, 3) uint8 RGB array copy of a PIL image.

    Non-RGB images (grayscale, RGBA, palette, ...) are converted to RGB first;
    reshaping their raw pixel data to 3 channels would otherwise fail. The
    result is a copy, so callers may draw on it in place.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')
    return np.array(image, dtype=np.uint8)

def visualize(image_folder, label, scores, ymax, ymin, xmax, xmin, full_img_path=''):
    """Draw one frame's detection boxes on its image and save the result.

    Args:
        image_folder: folder holding the original frame ``label``.
        label: frame file name; also used as the output file name.
        scores, ymax, ymin, xmax, xmin: per-detection values for this frame;
            the first ``len(scores)`` boxes are drawn.
        full_img_path: output folder. The default '' now means the current
            directory; previously it produced '/<label>' at the filesystem root.

    Boxes are drawn with normalized coordinates. Only boxes scoring at least
    the module-level ``threshold`` are drawn.
    """
    # Context manager so the source file handle is released immediately.
    with Image.open(os.path.join(image_folder, label)) as image:
        image_np = load_image_into_numpy_array(image)
    bbox = [[ymin[i], xmin[i], ymax[i], xmax[i]] for i in range(len(scores))]
    vis_util.visualize_boxes_and_labels_on_image_array(
        image_np,
        np.asarray(bbox, dtype=float).reshape(-1, 4),  # keep (N, 4) even when N == 0
        np.ones(len(scores), dtype=int),  # every detection is class 1
        np.asarray(scores),
        category_index,
        use_normalized_coordinates=True,
        min_score_thresh=threshold,
        line_thickness=3)
    Image.fromarray(image_np).save(os.path.join(full_img_path, label))
        
def image():
    """Export annotated copies of every frame in the selected group list.

    Uses the group list chosen in the option menu (Gui.var), as produced by
    process(). It recreates <processed_destination>/<processed_prefix>_<choice>
    and makes one sub-folder per group, named after the group's first frame
    without its extension. It then re-reads the inference TFRecord once and
    draws the detections of every frame inside a group with visualize().
    Groups must be in record order and non-overlapping, as process() emits them.
    """
    global group_choice
    Gui.t.delete(1.0, END)
    group_choice = Gui.var.get()
    if group_choice == 'filtered_groups':
        groups = filtered_groups
    elif group_choice == 'positive_groups':
        groups = positive_groups
    else:
        display('Unknown group choice: ' + group_choice)
        return

    # Prepare a fresh output directory.
    path = os.path.join(processed_destination, processed_prefix + '_' + group_choice)
    if os.path.exists(path):
        shutil.rmtree(path)
        time.sleep(1)  # presumably lets the filesystem finish deleting before mkdir
    os.mkdir(path)
    if not groups:
        display('No groups to visualize.')
        return

    def floats(dic, key):
        # Empty float lists may be absent from the JSON rendering; treat as [].
        return dic.get(key, {}).get('floatList', {}).get('value', [])

    # (first frame index, last frame index, sub-folder name) for each group.
    spans = []
    for group in groups:
        first = label_list.index(group[0])
        spans.append((first, label_list.index(group[1]),
                      os.path.splitext(label_list[first])[0]))

    span_iter = iter(spans)
    first, last, sub_folder = next(span_iter)
    full_img_path = os.path.join(path, sub_folder)
    os.mkdir(full_img_path)
    for idx, example in enumerate(tf.python_io.tf_record_iterator(detection_tfrecord)):
        if idx < first:
            continue  # outside every group: skip without parsing
        result = tf.train.Example.FromString(example)
        dic = json.loads(MessageToJson(result))['features']['feature']
        visualize(image_dir, label_list[idx],
                  floats(dic, 'image/detection/score'),
                  floats(dic, 'image/detection/bbox/ymax'),
                  floats(dic, 'image/detection/bbox/ymin'),
                  floats(dic, 'image/detection/bbox/xmax'),
                  floats(dic, 'image/detection/bbox/xmin'),
                  full_img_path)
        if idx == last:
            # This group is fully visualized; move on to the next one.
            display('visualized: ' + sub_folder)
            next_span = next(span_iter, None)
            if next_span is None:
                break
            first, last, sub_folder = next_span
            full_img_path = os.path.join(path, sub_folder)
            os.mkdir(full_img_path)
    display('All groups visualized!')

def image_button():
    """Start image() on a worker thread, or warn about what is still missing."""
    if not process_done:
        warning = 'Detections are not processed yet.'
    elif image_dir == '':
        warning = 'Please select source directory of images'
    else:
        threading.Thread(target=image).start()
        return
    tkinter.messagebox.showwarning('Alert', warning)

## User Interface

In [4]:
class Gui:
    """Tk form for this step.

    The widgets are created while the class body executes and are stored as
    class attributes, so the module-level functions reach them as
    ``Gui.<name>`` (e.g. ``Gui.t``, ``Gui.entry_4``, ``Gui.var``).
    """
    root = Tk()
    root.minsize(480,320)
    root.title('Process detections')

    # Field captions (column 0).
    label_1=Label(root,text = 'Detection TFrecord: ')
    label_2=Label(root,text = 'Score Threshold: ')
    label_3=Label(root,text = 'Output Destination: ')
    label_4=Label(root,text = 'Output Prefix: ')
    label_5=Label(root,text = 'Image Directory: ')
    # entry_1/3/5 are read-only Labels that show the chosen paths; the ask*
    # functions update their text. entry_2 and entry_4 are editable Entries.
    entry_1=Label(root,text = 'not selected',bg='White')
    entry_2=Entry(root,bg = 'White')
    entry_2.insert(END, '0.9')  # default score threshold shown in the form
    entry_3=Label(root,text = 'not selected',bg='White')
    entry_4=Entry(root,bg = 'White')
    entry_5=Label(root,text = 'not selected',bg='White')
    # Which group list image() exports; read via Gui.var.get().
    var = StringVar(root)
    var.set('filtered_groups') # default value
    entry_6= OptionMenu(root, var, 'filtered_groups','positive_groups')
    
    # Action buttons (column 2).
    button_1=Button(root,text='Browse',command= askinput)
    button_3=Button(root,text='Browse',command= askdestination)
    button_4=Button(root,text='Process',command= process_button)
    button_5=Button(root,text='Browse',command= askimagedir)
    button_6=Button(root,text='Export Images',command=image_button)
    # Grid layout: one row per field, captions right-aligned.
    label_1.grid(row=0,sticky=E)
    label_2.grid(row=1,sticky=E)
    label_3.grid(row=2,sticky=E)
    label_4.grid(row=3,sticky=E)
    label_5.grid(row=4,sticky=E)
    entry_1.grid(row=0,column=1)
    entry_2.grid(row=1,column=1)
    entry_3.grid(row=2,column=1)
    entry_4.grid(row=3,column=1)
    entry_5.grid(row=4,column=1)
    entry_6.grid(row=5,column=1)
    button_1.grid(row=0,column=2)
    button_3.grid(row=2,column=2)
    button_4.grid(row=3,column=2)
    button_5.grid(row=4,column=2)
    button_6.grid(row=5,column=2)
    # Log output box written to by display().
    t = Text(root,bg='White')
    t.grid(row=6,column=0,columnspan=3)

# Module-level handle to the Tk root window built by the Gui class.
GUI = Gui.root
# Hand control to Tk's event loop; this call blocks until the window is closed.
Gui.root.mainloop()