In [1]:
import sys
import urllib.request
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from PyQt5 import uic, QtGui, QtCore
import pyqtgraph as pg
from datetime import datetime
import time

import tensorflow as tf
import utils.deeplabv3_ as build_model
import utils.draw_predict as dp
import utils.call_data as cd

import cv2
import numpy as np

# Directory searched for the most recent TensorFlow checkpoint (used by
# saver.restore further down in this setup section).
model_path_root = './model/113'

# Number of output classes passed to the segmentation network.
class_num = 2
# Minimum contour area (cv2.contourArea) for a contour to count as a pupil candidate.
c_area_thr = 400
# Minimum bounding-box width / height ratio for a pupil candidate.
ratio_thr = 0.4
# Maximum allowed change in segmented-pixel count between consecutive frames;
# frames that change by this much or more are not searched for a pupil.
diff_thr = 350

# OpenCV text settings (not referenced in the visible part of this file).
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 0.5

# Radius of the calibration target dots, and the colours used to draw them.
# Colours are in OpenCV's BGR channel order.
c_radi = 10
red = (0, 0, 255)
black = (0, 0, 0)

# TensorFlow (1.x graph-mode API): build the segmentation graph once at import
# time and restore its weights from the latest checkpoint.
tf.reset_default_graph()
# Input batch of 3-channel images; batch size, height and width are left dynamic.
X = tf.placeholder(tf.float32, shape=[None, None, None, 3])
# NOTE(review): the meaning of the third argument (True) is defined in
# utils.deeplabv3_ and is not visible here -- confirm.
model = build_model.deeplabv3(X, class_num, True)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
# Overwrite the freshly initialised variables with the checkpointed weights.
saver.restore(sess, tf.train.latest_checkpoint(model_path_root))

# Load the Qt Designer UI file.
# Note: the UI file must be located in the same directory as this Python file.
form_class = uic.loadUiType("dongledongle.ui")[0]

def cvtPixmap(frame, img_size):
    """Convert a 3-channel BGR image array into a QPixmap of the given size.

    Args:
        frame: image array with three channels in BGR order (as produced by
            OpenCV). Any numeric dtype is accepted; it is cast to uint8.
        img_size: target size as ``(width, height)``, passed to ``cv2.resize``.

    Returns:
        A ``QtGui.QPixmap`` holding the resized image.
    """
    # QImage interprets the raw buffer as packed 8-bit RGB, so the array must
    # be C-contiguous uint8. Callers may hand in other dtypes (for example the
    # default float64 of np.zeros), which would otherwise be misread byte-wise.
    frame = np.ascontiguousarray(cv2.resize(frame, img_size), dtype=np.uint8)
    height, width, _ = frame.shape

    # Use the real row stride rather than assuming 3 * width.
    bytesPerLine = frame.strides[0]

    # rgbSwapped() returns a copy with R and B exchanged (BGR -> RGB); the copy
    # also means the result no longer refers to the numpy buffer.
    qImg = QImage(frame.data,
                  width,
                  height,
                  bytesPerLine,
                  QImage.Format_RGB888).rgbSwapped()

    return QtGui.QPixmap.fromImage(qImg)

class TimeAxisItem(pg.AxisItem):
    """Plot axis that labels its ticks with local wall-clock time."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setLabel(text='Time(초)', units=None)
        self.enableAutoSIPrefix(False)

    def tickStrings(self, values, scale, spacing):
        """Return the text drawn next to each tick.

        Overrides the base implementation so every tick value is rendered as
        ``HH:MM:SS`` local time instead of a plain number.

        Args:
            values: iterable of x-axis tick values, interpreted as seconds
                since the epoch.
            scale: unused here; part of the overridden signature.
            spacing: unused here; part of the overridden signature.

        Returns:
            A list with one formatted time string per tick value.
        """
        labels = []
        for stamp in values:
            labels.append(time.strftime("%H:%M:%S", time.localtime(stamp)))
        return labels

# Main application window; its widgets come from the Qt Designer form class.
class WindowClass(QMainWindow, form_class):
    """Main window: live pupil view, pupil-size plot, and gaze calibration.

    Every long-running button handler runs its own capture loop and keeps the
    UI responsive by spinning a short nested Qt event loop once per frame.
    Segmentation uses the module-level TensorFlow session (``sess``, ``model``
    and ``X``).
    """

    def __init__(self):
        super().__init__()
        self.setupUi(self)

        # Video source and the indicator images shown next to the video.
        # (A recorded clip such as 'testavi_ori.avi' can be passed to
        # VideoCapture instead of the camera index for offline testing.)
        self.show_video = False
        self.cap = cv2.VideoCapture(1)
        self.img_blue = cv2.imread('./blue.jpg')
        self.img_red = cv2.imread('./red.jpg')
        self.blink_display = self.img_blue

        # Connect the buttons to their handlers.
        self.pButton_start.clicked.connect(self.pButtonStartFunction)
        self.pButton_finish.clicked.connect(self.pButtonFinishFunction)
        self.pButton_cali.clicked.connect(self.pButtonCaliFunction)
        self.pButton_calitest.clicked.connect(self.pButtonCaliTestFunction)
        self.pButton_ctstop.clicked.connect(self.pButtonCTStopFunction)
        self.pButton_gamestart.clicked.connect(self.pButtonGameStartFunction)
        # The stop button must call the *stop* handler (it used to be wired to
        # the start handler, so the game flag could never be cleared).
        self.pButton_gamestop.clicked.connect(self.pButtonGameStopFunction)

        # Pupil-size graph with a time-formatted x axis.
        self.graphWidget = pg.PlotWidget(
            title = '동공 크기',
            labels = {'left' : 'Size (픽셀)'},
            axisItems = {'bottom' : TimeAxisItem(orientation='bottom')}
        )
        self.verticalLayout.addWidget(self.graphWidget)
        self.graphWidget.setYRange(0, 5000, padding=0)
        self.graphWidget.showGrid(x=True, y=True)

        self.pdi = self.graphWidget.plot(pen='y')
        self.plotData = {'x' : [], 'y' : []}

        # Calibration state: pupil centres recorded while the user looks at
        # the left / centre / right targets, plus the canvas they are drawn on.
        self.cali_left = [0, 0]
        self.cali_right = [0, 0]
        self.cali_center = [0, 0]
        self.cali_frame = np.zeros((450, 1300, 3), np.uint8)
        height, width, channel = self.cali_frame.shape
        self.cali_frame_center_x = int(width/2)
        self.cali_frame_center_y = int(height/2)
        quarter_x = int(self.cali_frame_center_x / 2)
        self.cali_frame_left_x = quarter_x - int(quarter_x / 2)
        self.cali_frame_right_x = self.cali_frame_center_x + quarter_x + int(quarter_x / 2)
        # One row per calibration step:
        # [active target position, left dot colour, centre dot colour, right dot colour]
        self.color_list = [[(self.cali_frame_center_x, self.cali_frame_center_y), black, red, black],
                          [(self.cali_frame_left_x, self.cali_frame_center_y), red, black, black],
                          [(self.cali_frame_right_x, self.cali_frame_center_y), black, black, red]]
        self.cali_test = False

        # Game state.
        self.game_start = False

    def closeEvent(self, event):
        """Stop any running capture loop and release the camera on close."""
        self.show_video = False
        self.cali_test = False
        self.game_start = False
        # Releasing the capture makes every `while self.cap.isOpened()` loop
        # terminate on its next check.
        if self.cap.isOpened():
            self.cap.release()
        super().closeEvent(event)

    def _showCaliFrame(self):
        """Display the current calibration canvas in ``label_cali``."""
        self.label_cali.setPixmap(cvtPixmap(self.cali_frame, (1300, 450)))

    def _waitMs(self, msec):
        """Process Qt events for about ``msec`` milliseconds."""
        loop = QtCore.QEventLoop()
        QtCore.QTimer.singleShot(msec, loop.quit)
        loop.exec_()

    def pButtonGameStartFunction(self):
        """Mark the game as running."""
        self.game_start = True
        # TODO: sampling of the pupil centre for the game is not implemented
        # yet (the intended approach mirrors the sampling loop at the end of
        # pButtonCaliFunction).

    def pButtonGameStopFunction(self):
        """Mark the game as stopped."""
        self.game_start = False

    def pButtonCaliTestFunction(self):
        """Show the calibrated gaze position as a dot until stopped.

        The loop ends when the calibration-test stop button clears
        ``cali_test``, when the tab widget switches to index 0 or 2, or when
        the camera stops delivering frames.
        """
        self.cali_test = True

        f_n = 0
        pre_segment = 0
        while self.cap.isOpened():
            # Leaving the calibration tab ends the test and blanks the canvas.
            if self.tabWidget.currentIndex() in (0, 2):
                self.cali_frame[:] = (255, 255, 255)
                self._showCaliFrame()
                break

            ret, frame = self.cap.read()
            if not ret:
                # Without this the loop would spin without ever processing
                # events, freezing the UI.
                break

            self.cali_frame[:] = (255, 255, 255)
            frame = frame[60:420]
            _, center, segment_info = self.extractPupil(frame, f_n, pre_segment)

            f_n += 1
            pre_segment = len(segment_info[0])

            cali_center = self.getCaliCenter(center)
            cv2.circle(self.cali_frame, cali_center, 10, (255, 0, 0), -1)
            self._showCaliFrame()

            self._waitMs(25)

            if not self.cali_test:
                break

    def pButtonCTStopFunction(self):
        """Stop the calibration test and blank the calibration canvas."""
        self.cali_test = False
        self.cali_frame[:] = (255, 255, 255)
        self._showCaliFrame()

    def pButtonCaliFunction(self):
        """Run the three-step calibration (centre, then left, then right).

        For each target a shrinking ring is animated around it, after which a
        few camera frames are analysed and the last pupil centre obtained is
        stored as that target's reference.
        """
        self.cali_left = [0, 0]
        self.cali_right = [0, 0]
        self.cali_center = [0, 0]

        left_pt = (self.cali_frame_left_x, self.cali_frame_center_y)
        center_pt = (self.cali_frame_center_x, self.cali_frame_center_y)
        right_pt = (self.cali_frame_right_x, self.cali_frame_center_y)

        for cali_cnt in range(3):
            target, left_color, center_color, right_color = self.color_list[cali_cnt]

            for radi_cnt in range(90):
                canvas = self.cali_frame
                canvas[:] = (255, 255, 255)

                cv2.line(canvas, center_pt, left_pt, black, 3)
                cv2.line(canvas, left_pt, right_pt, black, 3)

                cv2.circle(canvas, left_pt, c_radi, left_color, -1)
                cv2.circle(canvas, center_pt, c_radi, center_color, -1)
                cv2.circle(canvas, right_pt, c_radi, right_color, -1)

                # Ring around the active target, shrinking from radius 100 to 11.
                cv2.circle(canvas, target, 100 - radi_cnt, (0, 0, 255), 3)

                self._showCaliFrame()

                cv2.waitKey(10)
                self._waitMs(25)

            # Analyse up to four frames. The first call (f_n == 0) never yields
            # a pupil, so the later frames provide the stored centre.
            f_n = 0
            pre_segment = 0
            center = None
            while self.cap.isOpened():
                ret, frame = self.cap.read()
                if not ret:
                    # A failed read previously raised NameError (or looped on
                    # stale data); stop sampling for this target instead.
                    break

                frame = frame[60:420]
                _, center, segment_info = self.extractPupil(frame, f_n, pre_segment)

                f_n += 1
                pre_segment = len(segment_info[0])

                if f_n > 3:
                    break

            if center is not None:
                if cali_cnt == 0:
                    self.cali_center = center
                elif cali_cnt == 1:
                    self.cali_left = center
                elif cali_cnt == 2:
                    self.cali_right = center

        print(self.cali_left, ' ', self.cali_center, ' ', self.cali_right)

    def pButtonStartFunction(self):
        """Stream camera frames, overlay the detected pupil and plot its size.

        Runs until the finish button is pressed, the tab widget switches to
        index 1 or 2, or the camera stops delivering frames.
        """
        self.show_video = True

        f_n = 0
        pre_segment = 0
        while self.cap.isOpened():
            # Leaving the video tab behaves like pressing "finish".
            if self.tabWidget.currentIndex() in (1, 2):
                self.pButtonFinishFunction()

            if not self.show_video:
                break

            ret, frame = self.cap.read()
            if not ret:
                break

            frame = frame[60:420]
            s_frame = frame.copy()
            pupil_info, center, segment_info = self.extractPupil(frame, f_n, pre_segment)

            # pupil_info[0] is the contour area, or 0 when no pupil was found.
            if pupil_info[0] != 0:
                cv2.ellipse(frame, pupil_info[1], (0, 0, 255), 2)
                cv2.circle(frame, center, 4, (255, 0, 0), -1)
                self.blink_display = self.img_blue
            else:
                self.blink_display = self.img_red

            f_n += 1
            pre_segment = len(segment_info[0])

            self.textEdit_size.clear()
            self.textEdit_size.setText(str(pupil_info[0]))

            # Keep the plot history bounded.
            if len(self.plotData['y']) > 100:
                del self.plotData['y'][0]
                del self.plotData['x'][0]

            # Update the plot directly, once per frame. The previous code
            # created a new parentless QTimer every frame; it was dropped on
            # the next iteration, so whether it ever fired was up to timer
            # jitter.
            self.update_plot(int(time.time()), pupil_info[0])

            # Paint the segmented pixels red on the thumbnail copy.
            s_frame[segment_info[0], segment_info[1]] = (0, 0, 255)
            qpixmap = cvtPixmap(frame, (640, 480))
            qpixmap2 = cvtPixmap(s_frame, (160, 120))
            qpixmap_blink = cvtPixmap(self.blink_display, (25, 25))

            self.label_1.setPixmap(qpixmap)
            self.label_2.setPixmap(qpixmap2)
            self.label_blinkCnt.setPixmap(qpixmap_blink)

            self._waitMs(25)

    def update_plot(self, new_time_data: int, pupil_size):
        """Append one sample to the plot and scroll the x axis to follow it.

        Args:
            new_time_data: x value of the sample (seconds since the epoch).
            pupil_size: y value of the sample.
        """
        self.plotData['y'].append(pupil_size)
        self.plotData['x'].append(new_time_data)

        # Always show only the most recent time window on the x axis.
        self.graphWidget.setXRange(new_time_data - 10, new_time_data + 1, padding=0)

        self.pdi.setData(self.plotData['x'], self.plotData['y'])

    def pButtonFinishFunction(self):
        """Stop the video loop, clear the plot history and blank the views."""
        self.show_video = False
        self.plotData = {'x' : [], 'y' : []}

        # uint8 so the buffer really is the 8-bit RGB layout QImage expects.
        frame = np.zeros((1, 1, 3), np.uint8)
        qpixmap = cvtPixmap(frame, (640, 480))
        qpixmap2 = cvtPixmap(frame, (160, 120))
        self.label_1.setPixmap(qpixmap)
        self.label_2.setPixmap(qpixmap2)

        self.textEdit_size.clear()

    def extractPupil(self, frame, f_n, pre_segment):
        """Segment one frame and pick the most plausible pupil contour.

        Args:
            frame: 3-channel image array fed to the network as a batch of one.
            f_n: index of the frame within the current loop; frame 0 is never
                searched for a pupil because there is no previous frame to
                compare against.
            pre_segment: number of segmented pixels in the previous frame.

        Returns:
            A tuple ``(pupil, (cx, cy), (y, x))`` where ``pupil`` is
            ``[area, ellipse, contour]`` for the largest accepted contour, or
            ``[0, 0, (0, 0)]`` when none was accepted; ``(cx, cy)`` is the
            contour centroid (``(0, 0)`` when none); ``(y, x)`` are the index
            arrays of all pixels classified as class 1.
        """
        in_x = np.expand_dims(frame, axis = 0)
        predict = sess.run([model], feed_dict = {X : in_x})
        # Per-pixel argmax over axis 2 of the first batch item
        # (assumes the network output is (batch, H, W, classes) -- TODO confirm).
        segmentation = np.argmax(predict[0][0], 2)
        segmentation = np.where(segmentation == 1, 255, 0).astype(np.uint8)
        y, x = np.where(segmentation == 255)

        pupil_cand_list = []
        # Skip frames whose segmented area jumped by diff_thr or more relative
        # to the previous frame.
        if f_n != 0 and abs(len(y) - pre_segment) < diff_thr:
            # [-2] selects the contour list on both OpenCV 3.x (which returns
            # three values) and 2.x / 4.x (which return two).
            contours = cv2.findContours(segmentation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

            for cnt in contours:
                c_area = cv2.contourArea(cnt)
                # cv2.fitEllipse needs at least five points.
                if c_area <= c_area_thr or len(cnt) < 5:
                    continue

                _, _, w, h = cv2.boundingRect(cnt)
                if (float(w) / h) > ratio_thr:
                    ellipse = cv2.fitEllipse(cnt)
                    pupil_cand_list.append([c_area, ellipse, cnt])

        if pupil_cand_list:
            # Compare by area only; comparing whole entries could fall through
            # to the numpy contour arrays on ties and raise.
            pupil = max(pupil_cand_list, key=lambda cand: cand[0])
            M = cv2.moments(pupil[2])
            cx = int(M['m10']/M['m00'])
            cy = int(M['m01']/M['m00'])
        else:
            pupil = [0, 0, (0, 0)]
            cx = 0
            cy = 0

        return pupil, (cx, cy), (y, x)

    def getCaliCenter(self, center):
        """Map a pupil centre to a point on the 1300-pixel-wide canvas.

        Only the x coordinate is mapped, using the left and right calibration
        references; the returned y coordinate is always 255.

        Args:
            center: ``(x, y)`` pupil centre in camera-frame pixels.

        Returns:
            ``(x, 255)`` with ``x`` as an int no greater than 1299. When the
            left and right references share the same x (for example before
            calibration has run) the horizontal middle, ``(650, 255)``, is
            returned instead of dividing by zero.
        """
        span = self.cali_left[0] - self.cali_right[0]
        if span == 0:
            return (650, 255)

        norm_x = (center[0] - self.cali_right[0]) / span
        convert_x = int(norm_x * 974)

        if convert_x > 650:
            convert_x = 650 - (convert_x - 650)
        else:
            convert_x = 1300 - convert_x

        convert_x -= 200

        if convert_x < 0:
            convert_x = 1
        # The canvas is 1300 wide, so 1300 itself is already out of range.
        if convert_x >= 1300:
            convert_x = 1299

        return (convert_x, 255)

app = QApplication(sys.argv)
myWindow = WindowClass()
myWindow.show()
app.exec_()

W1205 02:25:56.950892 149228 deprecation_wrapper.py:119] From C:\Users\th_k9\Desktop\AI Competition\utils\deeplabv3_.py:154: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.

W1205 02:25:56.952858 149228 deprecation_wrapper.py:119] From C:\Users\th_k9\Desktop\AI Competition\utils\deeplabv3_.py:155: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.

W1205 02:25:56.953867 149228 deprecation.py:506] From C:\Users\th_k9\AppData\Local\Continuum\anaconda3\envs\Kimtae\lib\site-packages\tensorflow\python\ops\init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W1205 02:25:56.971807 149228 deprecation_wrapper.py:119] From C:\Users\th_k9\Desktop\AI Competition\utils\deeplabv3_.py:168: The name tf.nn.max_

0