-
Notifications
You must be signed in to change notification settings - Fork 8
/
run_gazenet.py
76 lines (66 loc) · 2.98 KB
/
run_gazenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import argparse
import sys
from pathlib import Path
import cv2
import torch
# Import local files and utils
root_dir = Path.cwd()
sys.path.append(str(root_dir))
import src.cam_utils as cam_utils
import src.data_utils as data_utils
'''
This file is used to run the Gaze Net. It takes images from your webcam, feeds them through the model
and outputs your current gaze location on the screen.
'''
# Command-line interface: which saved model to run, which camera device to
# read from, the capture resolution, and the preview window's title.
parser = argparse.ArgumentParser(description='Gazenet Runner')
parser.add_argument('--model', type=str, default=None, help='Model to run[default: None]')
parser.add_argument('-src', '--source', dest='video_source', type=int, default=0,
                    help='Device index of the camera.')
parser.add_argument('-wd', '--width', dest='width', type=int, default=128,
                    help='Width of the frames in the video stream.')
parser.add_argument('-ht', '--height', dest='height', type=int, default=96,
                    help='Height of the frames in the video stream.')
parser.add_argument('--window_name', type=str, default='GazeNet',
                    help='Name of window for when running [default: GazeNet]')
# Parsed options are read as module-level globals by the rest of the script.
args = parser.parse_args()
def gaze_inference(image_np, model):
    """Run one webcam frame through the gaze model and render the result.

    Args:
        image_np: A single frame as a numpy ndarray (as returned by the
            cv2-based webcam stream; presumably H x W x C BGR — TODO confirm
            against cam_utils.WebcamVideoStream).
        model: Trained gaze network; must accept the tensor produced by
            data_utils.ndimage_to_variable and output a gaze coordinate pair.

    Returns:
        The canvas image produced by cam_utils.screen_plot with the predicted
        gaze location drawn on it.
    """
    # Convert the raw frame to a model-ready tensor, resized to the CLI
    # height/width. Only request the GPU when CUDA is actually available —
    # the original hard-coded use_gpu=True, which crashes on CPU-only hosts.
    input_image = data_utils.ndimage_to_variable(
        image_np,
        imsize=(args.height, args.width),
        use_gpu=torch.cuda.is_available())
    # Forward pass. Clamp output to [0, 1]: gaze coordinates outside that
    # normalized screen range don't make sense.
    gaze_output = model(input_image).clamp(0, 1)
    gaze_list = gaze_output.cpu().data.numpy().tolist()[0]
    # Visualize the detection on screen and return the composited canvas.
    canvas = cam_utils.screen_plot(gaze_list, image=image_np,
                                   window_name=args.window_name)
    return canvas
if __name__ == '__main__':
    # Echo the run configuration so logs record how the model was launched.
    print('Gazenet Model Runner. Parameters:')
    for attr, value in args.__dict__.items():
        print('%s : %s' % (attr.upper(), value))
    # --model has no default; fail here with a clear CLI error instead of
    # the opaque TypeError that `Path / None` would raise below.
    if args.model is None:
        parser.error('--model is required (file name under src/models)')
    # Load the PyTorch model from the saved-models directory. map_location
    # lets a checkpoint saved on GPU load on a CPU-only machine.
    model_path = str(Path.cwd() / 'src' / 'models' / args.model)
    print('Loading model from %s' % model_path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.load(model_path, map_location=device)
    model.eval()  # inference mode: freezes dropout/batch-norm behavior
    # Start up the threaded webcam stream and the FPS tracker.
    video_capture = cam_utils.WebcamVideoStream(src=args.video_source,
                                                width=args.width,
                                                height=args.height).start()
    fps = cam_utils.FPS().start()
    try:
        while True:
            frame = video_capture.read()
            output_frame = gaze_inference(frame, model)
            cv2.imshow(args.window_name, output_frame)
            fps.update()
            # Press 'q' in the preview window to quit.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always report timing and release the camera/windows, even if the
        # loop exits via an exception rather than the 'q' key.
        fps.stop()
        print('[INFO] elapsed time (total): {:.2f}'.format(fps.elapsed()))
        print('[INFO] approx. FPS: {:.2f}'.format(fps.fps()))
        video_capture.stop()
        cv2.destroyAllWindows()