In [1]:
import numpy as np
import os
import sys
import scipy
import cv2
import gc

#解析使用
import xml
from xml.etree import ElementTree as ET

from glob import glob

import keras.backend as K
from keras.applications import VGG19
from keras.models import Model
from keras.utils import to_categorical

import imageio
from skimage import transform

from matplotlib import pyplot as plt
%matplotlib inline

from sklearn.svm import SVC #类别分类使用
from sklearn.linear_model import Ridge #bounding-box回归
from sklearn.externals import joblib

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
import tensorflow as tf

from tensorflow.contrib import slim

from ImageNet_classes import class_names #验证alexnet使用

In [3]:
TRAIN_DATA_PATH = '../../../tensorflow2/dataset/VOCtrainval_11-May-2012/JPEGImages/'
TEST_DATA_PATH = '../../../tensorflow2/dataset/VOC2012test/JPEGImages/'

TRAIN_XML_PATH = '../../../tensorflow2/dataset/VOCtrainval_11-May-2012/Annotations/'
TEST_XML_PATH = '../../../tensorflow2/dataset/VOC2012test/Annotations/'

CLASSES_NUM = 20

STR = [
    'person',
    'bird','cat','cow','dog','horse','sheep',
    'aeroplane','bicycle','boat','bus','car','motorbike','train',
    'bottle','chair','diningtable','pottedplant','sofa','tvmonitor'
]

LABEL2STR = {idx:value for idx , value in enumerate(STR)}
STR2LABEL = {value:key for key,value in LABEL2STR.items()}
#STR2LABEL = {value:idx for idx , value in enumerate(STR)}

STR2LABEL['none'] = 'none' #先不使用part部分 只进行naive目标检测

#目标检测相关
IoU_THRESHOLD = 0.5

#SVM相关
SVM_IoU_THRESHOLD = 0.3

#NMS相关
NMS_IoU_THRESHOLD = 0.3 #or ~0.5

#bbox回归
BBOX_REGRESS_IoU_THRESHOLD = 0.6

In [4]:
xml_file_names_train = glob(TRAIN_XML_PATH + '*') #所有的xml文件 完整路径

#从xml文件中读出图片相关的信息

def xml_parse(xml_file):
    '''
    return filename , shape , name_boxes , crop_boxes
    xml文件中的shape格式为 (width height 3)
    '''
    xml_file = xml.dom.minidom.parse(xml_file)
    xml_file_docu_ele = xml_file.documentElement

    filename_list = xml_file_docu_ele.getElementsByTagName('filename')
    
    #filename_list可能有多个filename的 所以要索引0(此数据集中filename只有一个)
    filename = filename_list[0].childNodes[0].data #filename_list.firstChild.data

    #图像的尺寸信息
    size_list = xml_file_docu_ele.getElementsByTagName('size')

    for size in size_list:
        width_list = size.getElementsByTagName('width')
        width = int(width_list[0].childNodes[0].data)

        height_list = size.getElementsByTagName('height')
        height = int(height_list[0].childNodes[0].data)

        channel_list = size.getElementsByTagName('depth')
        channel = int(channel_list[0].childNodes[0].data)

    #一个文件中有多个object
    object_list = xml_file_docu_ele.getElementsByTagName('object')

    #多个object与多个object对应的详细信息
    name_boxes = [] #一个元素就是一个object
    crop_boxes = []

    for objects in object_list:
        #一次循环处理一个object信息
        #一个xml文件（即一个图像中）有多个object

        #name
        name_list = objects.getElementsByTagName('name')

        name_box = name_list[0].childNodes[0].data

        #bounding box points
        bndbox = objects.getElementsByTagName('bndbox')

        x1_list = bndbox[0].getElementsByTagName('xmin')
        x1 = int( round( float(x1_list[0].childNodes[0].data) ) )

        y1_list = bndbox[0].getElementsByTagName('ymin')
        y1 = int(round(float( y1_list[0].childNodes[0].data )))

        x2_list = bndbox[0].getElementsByTagName('xmax')
        x2 = int(round(float( x2_list[0].childNodes[0].data )))

        y2_list = bndbox[0].getElementsByTagName('ymax')
        y2 = int(round(float( y2_list[0].childNodes[0].data )))

        crop_box = [x1,x2,y1,y2]

        name_boxes.append(name_box)
        crop_boxes.append(crop_box)

    #crop_box:[x1 x2 y1 y2]
    return filename , name_boxes , np.array(crop_boxes) #filename调试使用


In [5]:
#xml_parse(xml_file_names_train[897])

In [6]:
class Image(object):
    '''
    图片的真实信息
    '''
    def __init__(self):
        self.img_file_names_train = glob(TRAIN_DATA_PATH+'*') #训练全路径信息
                
    def load(self , img_path_name = None):
        if not img_path_name:
            img_path_name = np.random.choice(self.img_file_names_train) #随机选择一张图片
            #img_path_idx = np.random.randint(0 , high = len(self.img_file_names_train)) #随机索引

        img_arr = cv2.imread(img_path_name) #BGR height*width*chanel
        
        xml_file_name = TRAIN_XML_PATH + img_path_name[-15:-4] +  '.xml'
        
        _ , name_boxes , crop_boxes = xml_parse(xml_file_name)
        
        labels = [] #存储与bndbox对应的 label信息

        for i in range(len(crop_boxes)): #多个object 
            labels.append(STR2LABEL.get(name_boxes[i] , 'none'))
        
        return img_arr , labels , crop_boxes
    

In [7]:
class Img_generator(object):
    def __init__(self):
        self.img_loader = Image()

    #计算bbox面积
    def bbox_area(self , bbox):
        w = bbox[1] - bbox[0]
        h = bbox[3] - bbox[2]
        
        return w*h
    
    #计算交并比
    def IoU(self , bbox_a , bbox_b):
        xmin_a = bbox_a[0]
        xmax_a = bbox_a[1]
        ymin_a = bbox_a[2]
        ymax_a = bbox_a[3]
        
        xmin_b = bbox_b[0]
        xmax_b = bbox_b[1]
        ymin_b = bbox_b[2]
        ymax_b = bbox_b[3]
        
        if   xmin_a < xmax_b <= xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
            flag = True
        elif xmin_a <= xmin_b < xmax_a and (ymin_a < ymax_b <= ymax_a or ymin_a <= ymin_b < ymax_a):
            flag = True
        elif xmin_b < xmax_a <= xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
            flag = True
        elif xmin_b <= xmin_a < xmax_b and (ymin_b < ymax_a <= ymax_b or ymin_b <= ymin_a < ymax_b):
            flag = True
        else:
            flag = False
        
        if flag:
            x_sorted_list = sorted([xmin_a, xmax_a, xmin_b, xmax_b])
            y_sorted_list = sorted([ymin_a, ymax_a, ymin_b, ymax_b])
            
            x_intersect_w = x_sorted_list[2] - x_sorted_list[1] #0 1 2 3
            y_intersect_h = y_sorted_list[2] - y_sorted_list[1] #0 1 2 3
            
            area_inter = x_intersect_w * y_intersect_h #计算重合面积
            
            union_area = self.bbox_area(bbox_a) + self.bbox_area(bbox_b) - area_inter
            
            return area_inter/union_area
        else:
            return 0.0
    
    
    def map2new(self , img_shape , ground_truth_coord):
        '''
        坐标系映射至448*448坐标系中
        '''
        original_height = img_shape[0]
        original_width = img_shape[1]
        
        ground_truth_coord[: , :2] = np.array( ground_truth_coord[: , :2] * (448/original_width) , dtype=int) #x1 x2
        ground_truth_coord[: , 2:] = np.array( ground_truth_coord[: , 2:] * (448/original_height) , dtype=int) #y1 y2
        
        return ground_truth_coord
        
    
    def get_train_proposal(self , labels , ground_truth_coord):
        #get_train_proposal为关键函数
        
        def _center(gt):
            '''
            gt的中心坐标
            '''
            x = int( ( gt[0] + gt[1] ) / 2 )
            y = int( ( gt[2] + gt[3] ) / 2 )
        
            return [x,y]
        
        def _is_in_grid(gt_center , grid):
            '''
            判断gt的中心是否在此grid cell中
            '''
            if (grid[0] <= gt_center[0] <= grid[1]) and (grid[2] <= gt_center[1] <= grid[3]):
                return True
            else:
                return False
        
        def _target_gt(gt , grid):
            '''
            将gt变为target需要的格式
            '''
            center = _center(gt)
            #gt中心坐标在对应的grid中的偏移
            target_x = ( center[0] - grid[0] ) / 64
            target_y = ( center[1] - grid[2] ) / 64
            
            target_w = ( gt[1]-gt[0] ) / 448
            target_h = ( gt[3]-gt[2] ) / 448
            
            return [target_x , target_y , target_w , target_h]
            
        
        '''下面操作在448*448坐标系中进行'''
        
        #S*S grid cells
        x_slice = [ [64*i , 64*(i+1)] for i in [0,1,2,3,4,5,6] ]
        y_slice = [ [64*i , 64*(i+1)] for i in [0,1,2,3,4,5,6] ]
        
        grid = np.zeros(shape=[7 , 7] , dtype=list)
        
        for x_idx , x in enumerate(x_slice):
            for y_idx , y in enumerate(y_slice):
                grid[x_idx][y_idx] = y + x
        
        #训练样本中的y
        #[confidence x y w h]*2 + cls_score
        target = np.zeros(shape=[7 , 7 , 30] , dtype=float)
        
        for i in range(7):
            for j in range(7):
                for idx , gt in enumerate(ground_truth_coord):
                    
                    if _is_in_grid( _center(gt) , grid[i][j]):
                        '''
                        此gt的中心位于此grid cell中
                        '''
                        iou = self.IoU(gt , grid[i][j])
                        
                        #如果出现一个grid cell有多个gt对应 则只保留iou最高的前两个
                        if target[i][j][0] > target[i][j][5]:
                            if iou > target[i][j][5]:
                                target[i][j][5] = iou
                                target[i][j][6 : 10] = _target_gt(gt , grid[i][j])
                                target[i][j][ labels[idx] + 10 ] = 1.0
                        elif target[i][j][0] <= target[i][j][5]:
                            if iou > target[i][j][0]:
                                target[i][j][0] = iou
                                target[i][j][1 : 5] = _target_gt(gt , grid[i][j])
                                target[i][j][ labels[idx] + 10 ] = 1.0
                        #并入上面
                        #else:
                        #    #先执行此处
                        #    if iou > target[i][j][0]:
                        #        #只要大于 随意选一个位置即可
                        #        target[i][j][0] = iou
                        #        target[i][j][1 : 5] = _target_gt(gt)
                        #        target[i][j][ labels[idx] + 10 ] = 1.0
                
                #here
                #处理完一个grid cell
                #如果[i , j] grid只有一个gt与之对应 则将其翻倍（因为yolo v1一个grid对应两个bounding box）
                #只有一个gt与之对应 则只会出现在target[i][j][0 1 2 3 4]处
                if (target[i][j][0] != 0.0) and (target[i][j][5] == 0.0):
                    target[i][j][5:10] = target[i][j][0:5]
                    #cls_score是一样的 同一个位置
                                                  
        return np.array(target)
    
    
    def load(self , img_path_name):
        '''
        img_path_name:绝对路径
        '''
        
        #图片数据 ground truth具体数据 ground truth对应label ground truth坐标信息 图片文件名
        img_arr , labels , ground_truth_coord = self.img_loader.load(img_path_name)
        
        ground_truth_coord = self.map2new(img_arr.shape , ground_truth_coord) #将ground_truch坐标从原坐标系映射至448*448坐标系中
        
        img_arr = cv2.resize(img_arr , (448 , 448))
        img_arr = img_arr / 127.5 - 1 #对下面的get_train_proposal没有影响
        
        target = self.get_train_proposal(labels , ground_truth_coord)
        
        '''
        resize 并 归一化像素值
        img_arr 为 BGR形式
        '''
        
        '''[R G B] [123.68 116.779 103.939]
        减去每个通道的像素平均值 归一化'''
        #img_arr[:,:,0] = img_arr[:,:,0] - 103.939
        #img_arr[:,:,1] = img_arr[:,:,1] - 116.779
        #img_arr[:,:,2] = img_arr[:,:,2] - 123.680
        
        #'''增加一维 batch维'''
        return np.expand_dims(img_arr , axis=0) , target
    
    
    def get_test_proposal(self , img_arr):
        '''
        return:rois
        proposals_coord
        '''
        
        h = img_arr.shape[0]
        w = img_arr.shape[1]
        
        def bbox_trans(_rect):
            rect = [-1,-1,-1,-1]
            
            rect[0] = int(_rect[0]*600 / w)
            rect[1] = int(_rect[1]*600 / w)
            rect[2] = int(_rect[2]*1000 / h)
            rect[3] = int(_rect[3]*1000 / h)
        
            return rect
        
        anchors = [] #x1 x2 y1 y2 计算iou使用
        
        feature_map_height = 61
        feature_map_width = 36
        
        scales = [128 , 256 , 512]
        ratios = [[1,2] , [1,1] , [2,1]] #用scale除以即可 [height_ratio width_ratio]
        
        '''
        x_0 y_0 为 feature map中的坐标
        
        x_0_coord y_0_coord 为原图中的坐标（中点坐标）
        
        跨越边界的anchor 进行截断
        '''
        
        for x_0 in range(61): #height
            for y_0 in range(36): #width
                
                x_0_coord = x_0 * 16
                y_0_coord = y_0 * 16
                
                for scale in scales:
                    for ratio in ratios:
                        scale_height = int(scale / ratio[0])
                        scale_width = int(scale / ratio[1])
                    
                        x_1_coord = int(x_0_coord - scale_width/2)
                        y_1_coord = int(y_0_coord - scale_height/2)

                        if x_1_coord < 0:
                            x_1_coord = 0
                            
                        if y_1_coord < 0:
                            y_1_coord = 0

                        x_2_coord = int(x_0_coord + scale_width/2)
                        y_2_coord = int(y_0_coord + scale_height/2)

                        if x_2_coord > 600:
                            x_2_coord = 600
                            
                        if y_2_coord > 1000:
                            y_2_coord = 1000
                        
                        anchors.append( [x_1_coord , x_2_coord , y_1_coord , y_2_coord] )

        return np.array(anchors)
    
    def load_test(self , img_path_name):
        img_arr = cv2.imread(img_path_name)
        
        anchors = self.get_test_proposal(img_arr)
        
        img_arr = cv2.resize(img_arr , (600 , 1000))
        img_arr = img_arr / 127.5 - 1.0
        
        return np.expand_dims(img_arr , axis=0) , anchors

# class Img_generator

In [8]:
class Dataset(object):
    def __init__(self):
        self.img_generator = Img_generator()
        
        self.img_loader = Image()
        
        self.img_file_names_train = glob(TRAIN_DATA_PATH + '*')
        self.img_file_names_test = glob(TEST_DATA_PATH + '*')
    
    def get_batch(self):
        path = np.random.choice(self.img_file_names_train)
        
        x , target = self.img_generator.load(path)
    
        return x , target
    
    def get_batch_test(self , path):
        
        if not path:
            #未指定path 从测试目录中随机选一张图片测试
            path = np.random.choice(self.img_file_names_test)
        
        x , target = self.img_generator.load_test(path)
        
        return x , target
    
    
    def target2coord(self , bbox_pred , img_arr , anchors):
        img_height = img_arr.shape[0]
        img_width = img_arr.shape[1]
        
        def to(rect):
            x1 = rect[0]
            x2 = rect[1]
            y1 = rect[2]
            y2 = rect[3]
            
            w = x2-x1
            h = y2-y1
            
            x_c = (x1+x2)//2
            y_c = (y1+y2)//2
            
            return x_c , y_c , w , h
        
        def ot(target):
            x_c = target[0]
            y_c = target[1]
            w = target[2]
            h = target[3]
            
            x1 = 0.5*(2*x_c-w)
            y1 = 0.5*(2*y_c-h)
            x2 = x1+w
            y2 = y1+h
            
            x1=int(round(x1))
            y1=int(round(y1))
            x2=int(round(x2))
            y2=int(round(y2))
            
            if x1<0:
                x1 = 0
            if x2>img_width:
                x2 = img_width
            if y1<0:
                y1 = 0
            if y2>img_height:
                y2 = img_height
                            
            return [x1 , x2 , y1 , y2]
        
        def target2rect(target_hat , P_box):
            t_x = target_hat[0]
            t_y = target_hat[1]
            t_w = target_hat[2]
            t_h = target_hat[3]
            
            P_x , P_y , P_w , P_h = to(P_box) #将P框转换为 中点坐标 宽 高 形式
            
            G_x_hat = P_w*t_x+P_x
            G_y_hat = P_h*t_y+P_y
            G_w_hat = P_w*np.exp(t_w)
            G_h_hat = P_h*np.exp(t_h)
            
            return ot([G_x_hat , G_y_hat , G_w_hat , G_h_hat]) #ot还需要转化为(x1,x2,y1,y2)形式
        
        bbox_coord_pred = []
        
        for i in range(len(bbox_pred)):
            bbox_coord_pred.append( target2rect(bbox_pred[i] , anchors[i]) )
                
        return bbox_coord_pred

In [9]:
class Display(object):
    def __init__(self):
        pass
    
    def display(self , img_arr , labels , bbox , name):    
        for i in range(len(labels)):
            
            x1 = bbox[i][0]
            x2 = bbox[i][1]
            y1 = bbox[i][2]
            y2 = bbox[i][3]
            
            img_arr = cv2.rectangle(img_arr , (x1 , y1) , (x2 , y2) , (255,255,255))
            
            img_arr = cv2.putText(img_arr , labels[i] , org=(x1 , y1+10) , fontFace = cv2.FONT_HERSHEY_PLAIN , fontScale=1 , color = (255,255,255), thickness = 1)
        
        #plt.imshow(meta_img) #图像查看
        
        plt.imsave(arr=img_arr[: , : ,[2,1,0]] , fname = 'result/%s.jpg' % name) #保存图像
        

In [10]:
#refer:https://blog.csdn.net/two_vv/article/details/76769860
#alexnet原始模型以及预训练参数导入
class AlexNet_model(object):
    def __init__(self , is_training=True):
        
        self.x = tf.placeholder(dtype=tf.float32 , shape=[1 , 448 , 448 , 3])        
        
        self.build(is_training) #构建网络产生输出
        
        if is_training:
            self.target = tf.placeholder(dtype=tf.float32 , shape=[7 , 7 , 30])
            self.loss()

    def build(self , is_training):
        #arch from paper
        def _conv(_input , num_outputs , kernel_size , stride=1):
            return slim.conv2d(_input , num_outputs=num_outputs , kernel_size=kernel_size , stride=stride , activation_fn=tf.nn.leaky_relu ,
                             weights_initializer=tf.initializers.truncated_normal(stddev=0.01) ,
                             biases_initializer=tf.initializers.constant(0.0))
        
        def _max_pool(_input , kernel_size=2 , stride=2):
            return slim.max_pool2d(_input , kernel_size=kernel_size , stride=stride)
        
        def _conv_module_a(_input):
            _output = _conv(_input , 256 , 1)
            return _conv(_output , 512 , 3)
        
        def _conv_module_b(_input):
            _output = _conv(_input , 512 , 1)
            return _conv(_output , 1024 , 3)
        
        output = _conv(self.x , 64 , 7 , 2)
        output = _max_pool(output)
        
        output = _conv(output , 192 , 3)              
        output = _max_pool(output)
        
        output = _conv(output , 128 , 1)
        output = _conv(output , 256 , 3)
        output = _conv(output , 256 , 1)
        output = _conv(output , 512 , 3)
        output = _max_pool(output)
        
        #4 times
        output = _conv_module_a(output)
        output = _conv_module_a(output)
        output = _conv_module_a(output)
        output = _conv_module_a(output)
        
        output = _conv(output , 512 , 1)
        output = _conv(output , 1024 , 3)
        output = _max_pool(output)
        
        #twice
        output = _conv_module_b(output)
        output = _conv_module_b(output)

        output = _conv(output , 1024 , 3)
        output = _conv(output , 1024 , 3 , 2)
        
        output = _conv(output , 1024 , 3)
        output = _conv(output , 1024 , 3)
        
        output = slim.flatten(output)
        
        #paper中为4096 pc性能达不到
        output = slim.fully_connected(inputs=output , num_outputs=1024 , activation_fn=tf.nn.leaky_relu ,
                                     weights_initializer=tf.initializers.truncated_normal(stddev=0.01),
                                     biases_initializer=tf.initializers.constant(0.0))
        
        #引入dropout
        output = slim.dropout(output , keep_prob=0.5 , is_training=is_training)
        
        #tf.identity 使用线性激活函数 nan错误 使用leaky relu也会出错 换成relu
        output = slim.fully_connected(inputs=output , num_outputs=7*7*30 , activation_fn=tf.nn.relu  ,
                                     weights_initializer=tf.initializers.truncated_normal(stddev=0.01),
                                     biases_initializer=tf.initializers.constant(0.0))
        
        self.output = tf.reshape(output , shape=[7 , 7 , 30]) #丢弃掉batch维 没用

        
    def loss(self):
        lambda_coord = 5.0
        lambda_noobj = 0.5
        
        #[:,:,0]
        _mask = tf.cast( tf.greater( tf.slice(self.target,begin=[0,0,0],size=[7,7,1]) , np.zeros(shape=[7,7,1] , dtype=float) ) , dtype=tf.float32 )
        mask = tf.tile(_mask , multiples=[1 , 1 , 2]) #7*7*2
        #[:,:,[1,2]]
        loss_coord = tf.reduce_sum( tf.square( ( tf.slice(self.output,begin=[0,0,1],size=[7,7,2]) - tf.slice(self.target,begin=[0,0,1],size=[7,7,2]) ) * mask ) )
        #[:,:,[3,4]]
        loss_coord += tf.reduce_sum( tf.square( (tf.sqrt( tf.slice(self.output,begin=[0,0,3],size=[7,7,2]) ) - tf.sqrt( tf.slice(self.target,begin=[0,0,3],size=[7,7,2]))) * mask ) )
        
        loss_coord = lambda_coord * loss_coord
        
        #[:,:,:[0,5]]
        loss_iou = tf.reduce_sum( tf.square( ( tf.slice(self.output,begin=[0,0,0],size=[7,7,1]) - tf.slice(self.target,begin=[0,0,0],size=[7,7,1])) * mask ) )
        loss_iou += tf.reduce_sum( tf.square( ( tf.slice(self.output,begin=[0,0,5],size=[7,7,1]) - tf.slice(self.target,begin=[0,0,5],size=[7,7,1])) * mask ) )
        #[:,:,:[0,5]]
        loss_iou += lambda_noobj * (tf.reduce_sum( tf.square( ( tf.slice(self.output,begin=[0,0,0],size=[7,7,1]) - tf.slice(self.target,begin=[0,0,0],size=[7,7,1]) ) * (1.0-mask) ) ) +\
                                    tf.reduce_sum( tf.square( ( tf.slice(self.output,begin=[0,0,5],size=[7,7,1]) - tf.slice(self.target,begin=[0,0,5],size=[7,7,1]) ) * (1.0-mask) ) ))
        
        mask = tf.tile(_mask , multiples=[1,1,20]) #7*7*20 (因为20 classes)
        #[:,:,10:]
        loss_cls = tf.reduce_sum( tf.square( ( tf.slice(self.output,begin=[0,0,10],size=[7,7,20]) - tf.slice(self.target,begin=[0,0,10],size=[7,7,20]) ) * mask ) )
        
        self.total_loss = loss_coord + loss_iou + loss_cls


In [12]:
#refer:https://blog.csdn.net/two_vv/article/details/76769860

class YOLO_V1(object):
    '''
    完整模型
    '''
    
    def __init__(self , is_training = True):      
        self.dataset = Dataset()
        self.display = Display()
        
        self.img_generator = Img_generator()
        
        self.filewriter_path = 'save/logs' #模型可视化
        self.checkpoint_path = 'save/model/' #模型持久化
                              
        self.model = AlexNet_model(is_training)
        
        self.sess = tf.Session()
        
        if is_training:
            '''训练参数'''
            self.epoch = 100000
            
            self.global_step = tf.Variable(initial_value=0 , trainable=False)
            
            self.learning_rate = tf.train.exponential_decay(learning_rate=0.0001 , global_step=self.global_step,
                                                            decay_steps=900 , decay_rate=0.8 , staircase=True)
            
            self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.model.total_loss , global_step=self.global_step)
        
            '''引入滑动平均'''
            self.ema = tf.train.ExponentialMovingAverage(decay=0.9) #滑动平均
            self.average_op = self.ema.apply(tf.trainable_variables()) #给所有的可训练变量应用滑动平均
            
            with tf.control_dependencies([self.optimizer]):
                self.train_op = tf.group(self.average_op)
            
        self.sess.run(tf.global_variables_initializer())
        
        if is_training:
            tf.summary.scalar('total_loss' , self.model.total_loss)
            self.merged_summary = tf.summary.merge_all() #merge all summaries in the default graph
            self.writer = tf.summary.FileWriter(self.filewriter_path , self.sess.graph) #可视化
            
        self.saver = tf.train.Saver(max_to_keep=2) #max_to_keep 最大保存5次模型  之后继续保存则会覆盖前面的模型
        
    def train(self):
        
        if os.path.exists(self.checkpoint_path+'checkpoint'):
            self.saver.restore(self.sess , tf.train.latest_checkpoint(self.checkpoint_path))
        else:
            self.sess.run(tf.global_variables_initializer()) 
        
        for i in range(100000):
            x , target = self.dataset.get_batch()
            
            self.sess.run(self.train_op , feed_dict={self.model.x : x , self.model.target : target} )

            if i % 10 == 0:
                self.saver.save(self.sess , self.checkpoint_path + 'model.ckpt' , global_step = i)
                
                total_loss , summary = self.sess.run([self.model.total_loss , self.merged_summary] , feed_dict={self.model.x : x , self.model.target : target})
                
                self.writer.add_summary(summary , global_step = i)
                
                print(i , total_loss)
            
        self.writer.close() #event to disk and close the file

    def predict(self , path=None , scores_threshold = 0.1 , nms_iou_threshold = 0.7):
        if os.path.exists(self.checkpoint_path + 'checkpoint'):
            self.saver.restore(self.sess , tf.train.latest_checkpoint(self.checkpoint_path) )
            
            return self._predict(path , scores_threshold , nms_iou_threshold)
            
        else:
            print('no model!!!')
            return 
    
    def _predict(self , path , scores_threshold , nms_iou_threshold):
        x , rois , img_arr , proposals_coord = self.dataset.get_batch_test(path)
        
        cls_pred , bbox_pred = self.sess.run([self.model.cls_pred , self.model.bbox_pred] , feed_dict={self.x : x , self.rois : rois})
        
        #转换为原始图片中的坐标
        bbox_coord_pred = self.dataset.target2coord(bbox_pred , img_arr , proposals_coord)
        
        '''
        由target到原始坐标 在进行nms
        '''
        
        scores_pred_f = [] #符合条件的概率值
        bbox_coord_pred_f = [] #符合条件的框子坐标
        
        labels_pred_f = [] #label名字
        
        for i in range(len(cls_pred)):
            if np.argmax(cls_pred[i]) != 0 and (np.max(cls_pred[i]) > scores_threshold):
                scores_pred_f.append(np.max(cls_pred[i]))
                
                bbox_coord_pred_f.append(bbox_coord_pred[i])
                
                labels_pred_f.append(LABEL2STR[np.argmax(cls_pred[i])])
        
        scores_pred_f = np.array(scores_pred_f)
        bbox_coord_pred_f = np.array(bbox_coord_pred_f)
        labels_pred_f = np.array(labels_pred_f)
        
        #降序scores
        sort_idx = np.argsort(- np.array(scores_pred_f) )
        
        scores_pred_f = scores_pred_f[sort_idx]
        bbox_coord_pred_f = bbox_coord_pred_f[sort_idx]
        labels_pred_f = labels_pred_f[sort_idx]
                
        final_idx = self._nms(scores_pred_f , bbox_coord_pred_f , nms_iou_threshold)
                
        #scores_pred_f = scores_pred_f[final_idx] #用不上
        bbox_coord_pred_f = bbox_coord_pred_f[final_idx]
        labels_pred_f = labels_pred_f[final_idx]
                
        # 绘制并保存
        self.display.display(img_arr , labels_pred_f , bbox_coord_pred_f , 'first')
        
        return cls_pred , bbox_coord_pred , labels_pred_f , bbox_coord_pred_f
        
        
    def _nms(self , probability_hat , rects_hat , nms_iou_threshold):
        idx = []
        
        length = len(probability_hat)
        lost_flag = [1]*length #标记丢弃的框 0表示丢弃
        
        max_score_idx = 0 #记录当前最大score的idx
        
        while max_score_idx < length:
            max_score_rect = rects_hat[max_score_idx]
            
            for i in range(max_score_idx+1 , length):
                if lost_flag[i] == 1 and self.img_generator.IoU( max_score_rect , rects_hat[i] ) > nms_iou_threshold: #大于阈值 丢弃
                    lost_flag[i] = 0

            max_score_idx_bak = max_score_idx #后续使用
            
            #让max_score_idx指向下一个没被丢弃的最大值
            for i in range(max_score_idx+1 , length):
                if lost_flag[i] == 1:
                    max_score_idx = i
                    break
            
            #说明max_score_idx没有移动过 即后续的都被丢弃了 终止循环
            if max_score_idx == max_score_idx_bak:
                break
        
        for i in range(length):
            if lost_flag[i] == 1:
                idx.append(i)
                
        return idx

In [13]:
test = YOLO_V1()

In [None]:
test.train()

0 28.909016
10 5.7129765
20 2.821789
30 2.9340255
40 2.0163012
50 5.452017
60 2.9741642
70 14.488641
80 1.3236729


In [14]:
test.sess.close()

In [41]:
tf.reset_default_graph()