-
Notifications
You must be signed in to change notification settings - Fork 1
/
demo.py
88 lines (68 loc) · 2.35 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# from data import DataSet
import os
import glob
import numpy as np
from extractor import Extractor
from keras.models import load_model
import sys
from subprocess import call
import shutil
import matplotlib.pyplot as plt
# Demo: score a video for violence, one score pair per second, and plot them.
# Usage: python demo.py VIDEO

# The LSTM expects one sequence of SEQ_LENGTH per-frame feature vectors,
# i.e. input shape (40, 2048) for a single window.
SEQ_LENGTH = 40
# Scratch directory for frames dumped by ffmpeg; wiped before and after a run.
FRAMES_DIR = 'data/extracted_frames'
# Trained custom LSTM checkpoint.
SAVED_MODEL = 'data/checkpoints/lstm-features.008-0.105.hdf5'


def _reset_dir(path):
    """Delete *path* and everything in it, then recreate it empty."""
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)


def _extract_frames(video_path, stem, seq_length=SEQ_LENGTH, frames_dir=FRAMES_DIR):
    """Dump *video_path* to JPEGs in *frames_dir* and return their sorted paths.

    ffmpeg resamples the video to ``seq_length`` frames per second, so each
    window of ``seq_length`` consecutive frames covers one second of video.

    Raises:
        RuntimeError: if ffmpeg exits with a non-zero status.
    """
    pattern = os.path.join(frames_dir, stem + '-%04d.jpg')
    status = call(['ffmpeg', '-i', video_path, '-r', str(seq_length), pattern])
    if status != 0:
        raise RuntimeError('ffmpeg failed to extract frames from ' + video_path)
    # Match only this video's frames: "<stem>-NNNN.jpg", with the stem escaped
    # in case the file name contains glob metacharacters such as '['.
    return sorted(glob.glob(os.path.join(frames_dir, glob.escape(stem) + '-*.jpg')))


def _score_windows(frames, inception, model, seq_length=SEQ_LENGTH):
    """Score each full window of *seq_length* frames.

    Trailing frames that do not fill a whole window are ignored.

    Returns:
        (y_violent, y_non_violent): two lists of floats, one entry per window.
    """
    y_violent = []
    y_non_violent = []
    usable = len(frames) - len(frames) % seq_length
    for start in range(0, usable, seq_length):
        sequence = [inception.extract(frame) for frame in frames[start:start + seq_length]]
        prediction = model.predict(np.expand_dims(sequence, axis=0))
        # prediction[0][0] is the non-violent score, prediction[0][1] the violent one.
        y_non_violent.append(float(prediction[0][0]))
        y_violent.append(float(prediction[0][1]))
    return y_violent, y_non_violent


def _plot_scores(x, y_violent, y_non_violent):
    """Plot both score curves against time in seconds and show the figure."""
    plt.plot(x, y_violent, 'r', label='violence-score')
    # Distinct colour so the two curves can actually be told apart.
    plt.plot(x, y_non_violent, 'g', label='non-violence-score')
    plt.xlabel('time(s)')
    plt.ylabel('violence')
    plt.title('Violence in video')
    plt.ylim(0, 1)
    plt.legend()
    plt.show()


def main(argv=None):
    """Run the demo on the video named in ``argv[1]``.

    Args:
        argv: argument vector; defaults to ``sys.argv``.

    Exits with status 1 (message on stderr) when no video is given, when the
    file does not exist, or when the video yields fewer than SEQ_LENGTH frames.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) < 2:
        print("No args! exiting", file=sys.stderr)
        sys.exit(1)
    video_path = argv[1]
    # Fail fast, before the slow model loading and ffmpeg run.
    if not os.path.isfile(video_path):
        print('Video file not found: {}'.format(video_path), file=sys.stderr)
        sys.exit(1)
    stem = os.path.splitext(os.path.basename(video_path))[0]

    # Start from an empty scratch dir so stale frames from an aborted run are
    # neither globbed nor trigger ffmpeg's interactive overwrite prompt.
    _reset_dir(FRAMES_DIR)
    try:
        frames = _extract_frames(video_path, stem)
        if len(frames) < SEQ_LENGTH:
            print('Video too short: need at least {} frames, got {}'.format(
                SEQ_LENGTH, len(frames)), file=sys.stderr)
            sys.exit(1)
        inception = Extractor()  # pre-trained Inception feature extractor
        model = load_model(SAVED_MODEL)  # trained custom LSTM
        y_violent, y_non_violent = _score_windows(frames, inception, model)
    finally:
        # Always remove the captured frames, even if extraction/scoring fails.
        _reset_dir(FRAMES_DIR)

    # One window per second of video (see _extract_frames), so index == seconds.
    x = list(range(len(y_violent)))
    print(x)
    print(y_violent)
    print(y_non_violent)
    _plot_scores(x, y_violent, y_non_violent)


if __name__ == '__main__':
    main()