Modified code for Tensorflow 2.4 and scipy>1.0.0 #1192

Open · wants to merge 6 commits into master
54 changes: 45 additions & 9 deletions requirements.txt
@@ -1,9 +1,45 @@
-tensorflow==1.7
-scipy
-scikit-learn
-opencv-python
-h5py
-matplotlib
-Pillow
-requests
-psutil
+absl-py==0.11.0
+astunparse==1.6.3
+cachetools==4.2.1
+certifi==2020.12.5
+chardet==4.0.0
+cycler==0.10.0
+flatbuffers==1.12
+gast==0.3.3
+google-auth==1.24.0
+google-auth-oauthlib==0.4.2
+google-pasta==0.2.0
+grpcio==1.32.0
+h5py==2.10.0
+idna==2.10
+joblib==1.0.0
+Keras-Preprocessing==1.1.2
+kiwisolver==1.3.1
+Markdown==3.3.3
+matplotlib==3.3.4
+numpy==1.19.5
+oauthlib==3.1.0
+opencv-python==4.5.1.48
+opt-einsum==3.3.0
+Pillow==8.1.0
+protobuf==3.14.0
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pyparsing==2.4.7
+python-dateutil==2.8.1
+requests==2.25.1
+requests-oauthlib==1.3.0
+rsa==4.7
+scikit-learn==0.24.1
+scipy==1.6.0
+six==1.15.0
+tensorboard==2.4.1
+tensorboard-plugin-wit==1.8.0
+tensorflow==2.4.1
+tensorflow-estimator==2.4.0
+termcolor==1.1.0
+threadpoolctl==2.1.0
+typing-extensions==3.7.4.3
+urllib3==1.26.3
+Werkzeug==1.0.1
+wrapt==1.12.1
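The pins above move the project to TensorFlow 2.4.1, numpy 1.19.5, and scipy 1.6.0, where the `scipy.misc` image helpers used by the old code no longer exist. A quick, hedged sanity check for the new environment (assuming `pip install -r requirements.txt` has been run):

```python
# Minimal environment check for the pins above; run after
# `pip install -r requirements.txt`.
import numpy as np
import scipy
import tensorflow as tf

assert tf.__version__.startswith("2.4"), tf.__version__   # tensorflow==2.4.1
# scipy.misc image I/O is gone in scipy >= 1.x, hence the Pillow changes below.
assert tuple(int(v) for v in scipy.__version__.split(".")[:2]) >= (1, 6)
print("TF", tf.__version__, "| scipy", scipy.__version__, "| numpy", np.__version__)
```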
12 changes: 7 additions & 5 deletions src/align/align_dataset_mtcnn.py
@@ -35,6 +35,7 @@
 import align.detect_face
 import random
 from time import sleep
+from PIL import Image
 
 def main(args):
     sleep(random.random())
@@ -49,8 +50,8 @@ def main(args):
     print('Creating networks and loading parameters')
 
     with tf.Graph().as_default():
-        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction)
-        sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, log_device_placement=False))
+        gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=args.gpu_memory_fraction)
+        sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options, log_device_placement=False))
        with sess.as_default():
            pnet, rnet, onet = align.detect_face.create_mtcnn(sess, None)
 
@@ -80,7 +81,7 @@ def main(args):
                 print(image_path)
                 if not os.path.exists(output_filename):
                     try:
-                        img = misc.imread(image_path)
+                        img = np.array(Image.open(image_path))
                     except (IOError, ValueError, IndexError) as e:
                         errorMessage = '{}: {}'.format(image_path, e)
                         print(errorMessage)
@@ -121,14 +122,15 @@ def main(args):
                                 bb[2] = np.minimum(det[2]+args.margin/2, img_size[1])
                                 bb[3] = np.minimum(det[3]+args.margin/2, img_size[0])
                                 cropped = img[bb[1]:bb[3],bb[0]:bb[2],:]
-                                scaled = misc.imresize(cropped, (args.image_size, args.image_size), interp='bilinear')
+                                cropped = Image.fromarray(np.uint8(cropped))
+                                scaled = cropped.resize((args.image_size, args.image_size), Image.ANTIALIAS)
                                 nrof_successfully_aligned += 1
                                 filename_base, file_extension = os.path.splitext(output_filename)
                                 if args.detect_multiple_faces:
                                     output_filename_n = "{}_{}{}".format(filename_base, i, file_extension)
                                 else:
                                     output_filename_n = "{}{}".format(filename_base, file_extension)
-                                misc.imsave(output_filename_n, scaled)
+                                scaled.save(output_filename_n)
                                 text_file.write('%s %d %d %d %d\n' % (output_filename_n, bb[0], bb[1], bb[2], bb[3]))
                         else:
                             print('Unable to align "%s"' % image_path)
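The changes in this file swap the removed `scipy.misc` I/O for Pillow. Note that `Image.ANTIALIAS` is Lanczos resampling, so the resize filter quietly changes from the old `interp='bilinear'`; `Image.BILINEAR` would be the closer match if bit-compatibility matters. A standalone sketch of the substitution, with placeholder file names that are not part of the PR:

```python
# Sketch of the scipy.misc -> Pillow substitution used in this file.
# "input.jpg" / "output.png" are hypothetical paths for illustration.
import numpy as np
from PIL import Image

img = np.array(Image.open("input.jpg"))       # was: misc.imread("input.jpg")
cropped = img[0:160, 0:160, :]                # cropping stays plain numpy slicing
pil_img = Image.fromarray(np.uint8(cropped))
# ANTIALIAS == Lanczos (renamed Image.LANCZOS in newer Pillow; 8.1.0 is pinned here)
scaled = pil_img.resize((160, 160), Image.ANTIALIAS)
scaled.save("output.png")                     # was: misc.imsave("output.png", scaled)
```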
32 changes: 16 additions & 16 deletions src/align/detect_face.py
@@ -82,13 +82,13 @@ def load(self, data_path, session, ignore_missing=False):
         session: The current TensorFlow session
         ignore_missing: If true, serialized weights for missing layers are ignored.
         """
-        data_dict = np.load(data_path, encoding='latin1').item() #pylint: disable=no-member
+        data_dict = np.load(data_path, encoding='latin1', allow_pickle=True).item() #pylint: disable=no-member
 
         for op_name in data_dict:
-            with tf.variable_scope(op_name, reuse=True):
+            with tf.compat.v1.variable_scope(op_name, reuse=True):
                 for param_name, data in iteritems(data_dict[op_name]):
                     try:
-                        var = tf.get_variable(param_name)
+                        var = tf.compat.v1.get_variable(param_name)
                         session.run(var.assign(data))
                     except ValueError:
                         if not ignore_missing:
@@ -122,7 +122,7 @@ def get_unique_name(self, prefix):
 
     def make_var(self, name, shape):
         """Creates a new TensorFlow variable."""
-        return tf.get_variable(name, shape, trainable=self.trainable)
+        return tf.compat.v1.get_variable(name, shape, trainable=self.trainable)
 
     def validate_padding(self, padding):
         """Verifies that the padding is one of the supported ones."""
@@ -150,7 +150,7 @@ def conv(self,
         assert c_o % group == 0
         # Convolution for a given input and kernel
         convolve = lambda i, k: tf.nn.conv2d(i, k, [1, s_h, s_w, 1], padding=padding)
-        with tf.variable_scope(name) as scope:
+        with tf.compat.v1.variable_scope(name) as scope:
             kernel = self.make_var('weights', shape=[k_h, k_w, c_i // group, c_o])
             # This is the common-case. Convolve the input without any further complications.
             output = convolve(inp, kernel)
@@ -165,7 +165,7 @@
 
     @layer
     def prelu(self, inp, name):
-        with tf.variable_scope(name):
+        with tf.compat.v1.variable_scope(name):
             i = int(inp.get_shape()[-1])
             alpha = self.make_var('alpha', shape=(i,))
             output = tf.nn.relu(inp) + tf.multiply(alpha, -tf.nn.relu(-inp))
@@ -182,7 +182,7 @@ def max_pool(self, inp, k_h, k_w, s_h, s_w, name, padding='SAME'):
 
     @layer
     def fc(self, inp, num_out, name, relu=True):
-        with tf.variable_scope(name):
+        with tf.compat.v1.variable_scope(name):
             input_shape = inp.get_shape()
             if input_shape.ndims == 4:
                 # The input is spatial. Vectorize it first.
@@ -191,10 +191,10 @@ def fc(self, inp, num_out, name, relu=True):
                     dim *= int(d)
                 feed_in = tf.reshape(inp, [-1, dim])
             else:
-                feed_in, dim = (inp, input_shape[-1].value)
+                feed_in, dim = (inp, input_shape.as_list()[-1])
             weights = self.make_var('weights', shape=[dim, num_out])
             biases = self.make_var('biases', [num_out])
-            op = tf.nn.relu_layer if relu else tf.nn.xw_plus_b
+            op = tf.nn.relu_layer if relu else tf.compat.v1.nn.xw_plus_b
             fc = op(feed_in, weights, biases, name=name)
             return fc
 
@@ -210,7 +210,7 @@ def softmax(self, target, axis, name=None):
         max_axis = tf.reduce_max(target, axis, keepdims=True)
         target_exp = tf.exp(target-max_axis)
         normalize = tf.reduce_sum(target_exp, axis, keepdims=True)
-        softmax = tf.div(target_exp, normalize, name)
+        softmax = tf.compat.v1.div(target_exp, normalize, name)
         return softmax
 
 class PNet(Network):
@@ -277,16 +277,16 @@ def create_mtcnn(sess, model_path):
     if not model_path:
         model_path,_ = os.path.split(os.path.realpath(__file__))
 
-    with tf.variable_scope('pnet'):
-        data = tf.placeholder(tf.float32, (None,None,None,3), 'input')
+    with tf.compat.v1.variable_scope('pnet'):
+        data = tf.compat.v1.placeholder(tf.float32, (None,None,None,3), 'input')
         pnet = PNet({'data':data})
         pnet.load(os.path.join(model_path, 'det1.npy'), sess)
-    with tf.variable_scope('rnet'):
-        data = tf.placeholder(tf.float32, (None,24,24,3), 'input')
+    with tf.compat.v1.variable_scope('rnet'):
+        data = tf.compat.v1.placeholder(tf.float32, (None,24,24,3), 'input')
         rnet = RNet({'data':data})
         rnet.load(os.path.join(model_path, 'det2.npy'), sess)
-    with tf.variable_scope('onet'):
-        data = tf.placeholder(tf.float32, (None,48,48,3), 'input')
+    with tf.compat.v1.variable_scope('onet'):
+        data = tf.compat.v1.placeholder(tf.float32, (None,48,48,3), 'input')
         onet = ONet({'data':data})
         onet.load(os.path.join(model_path, 'det3.npy'), sess)
 
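Two migration patterns recur in this file: the TF1 graph-mode symbols (`variable_scope`, `get_variable`, `placeholder`, `div`) move under `tf.compat.v1`, and `np.load` needs `allow_pickle=True` because numpy >= 1.16.3 refuses pickled arrays by default (the `det*.npy` weight files are pickled dictionaries). A minimal sketch of the compat.v1 graph pattern on TF 2.x; the scope and variable names are illustrative, not taken from detect_face.py:

```python
# Graph-mode sketch under TF 2.x, mirroring the tf.compat.v1 usage above.
import tensorflow as tf

with tf.Graph().as_default():   # eager execution is disabled inside an explicit graph
    x = tf.compat.v1.placeholder(tf.float32, (None, 3), name="input")
    with tf.compat.v1.variable_scope("demo"):
        w = tf.compat.v1.get_variable("weights", shape=(3, 1))
    y = tf.matmul(x, w)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        print(sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]}))
```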
8 changes: 4 additions & 4 deletions src/classifier.py
@@ -40,7 +40,7 @@ def main(args):
 
     with tf.Graph().as_default():
 
-        with tf.Session() as sess:
+        with tf.compat.v1.Session() as sess:
 
             np.random.seed(seed=args.seed)
 
@@ -69,9 +69,9 @@ def main(args):
             facenet.load_model(args.model)
 
             # Get input and output tensors
-            images_placeholder = tf.get_default_graph().get_tensor_by_name("input:0")
-            embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
-            phase_train_placeholder = tf.get_default_graph().get_tensor_by_name("phase_train:0")
+            images_placeholder = tf.compat.v1.get_default_graph().get_tensor_by_name("input:0")
+            embeddings = tf.compat.v1.get_default_graph().get_tensor_by_name("embeddings:0")
+            phase_train_placeholder = tf.compat.v1.get_default_graph().get_tensor_by_name("phase_train:0")
             embedding_size = embeddings.get_shape()[1]
 
             # Run forward pass to calculate embeddings
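Only the session and graph lookups change here; the tensor names (`input:0`, `embeddings:0`, `phase_train:0`) stay the same because they come from the frozen FaceNet graph. A hedged sketch of the forward pass as a helper, assuming `facenet.load_model()` has already imported the graph into the session's default graph:

```python
import tensorflow as tf

def compute_embeddings(sess, images):
    """Sketch of the forward pass above; `images` is assumed to be a
    prewhitened float batch of shape (N, 160, 160, 3)."""
    graph = tf.compat.v1.get_default_graph()
    feed = {
        graph.get_tensor_by_name("input:0"): images,
        graph.get_tensor_by_name("phase_train:0"): False,
    }
    return sess.run(graph.get_tensor_by_name("embeddings:0"), feed_dict=feed)
```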
5 changes: 3 additions & 2 deletions src/facenet.py
@@ -40,6 +40,7 @@
 from tensorflow.python.platform import gfile
 import math
 from six import iteritems
+from PIL import Image
 
 def triplet_loss(anchor, positive, negative, alpha):
     """Calculate the triplet loss according to the FaceNet paper
@@ -244,7 +245,7 @@ def load_data(image_paths, do_random_crop, do_random_flip, image_size, do_prewhi
     nrof_samples = len(image_paths)
     images = np.zeros((nrof_samples, image_size, image_size, 3))
     for i in range(nrof_samples):
-        img = misc.imread(image_paths[i])
+        img = np.array(Image.open(image_paths[i]))
         if img.ndim == 2:
             img = to_rgb(img)
         if do_prewhiten:
@@ -368,7 +369,7 @@ def load_model(model, input_map=None):
     if (os.path.isfile(model_exp)):
         print('Model filename: %s' % model_exp)
         with gfile.FastGFile(model_exp,'rb') as f:
-            graph_def = tf.GraphDef()
+            graph_def = tf.compat.v1.GraphDef()
             graph_def.ParseFromString(f.read())
             tf.import_graph_def(graph_def, input_map=input_map, name='')
     else:
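`tf.GraphDef` likewise moves under `tf.compat.v1`; parsing and importing the frozen `.pb` is otherwise unchanged. A standalone sketch of the same loading pattern, with a placeholder path and `tf.io.gfile.GFile` as the TF2-native stand-in for the `gfile.FastGFile` kept in the diff:

```python
import tensorflow as tf

def load_frozen_graph(pb_path):
    """Sketch of the frozen-graph load in facenet.load_model();
    `pb_path` is a hypothetical path to a frozen model file."""
    graph_def = tf.compat.v1.GraphDef()           # was: tf.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, input_map=None, name="")
```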