Skip to content

Commit

Permalink
Fix to dist_fn when using dtw. Also added option to normalize embeddi…
Browse files Browse the repository at this point in the history
…ngs.

PiperOrigin-RevId: 265731973
  • Loading branch information
debidatta authored and Copybara-Service committed Aug 27, 2019
1 parent 58d4e1e commit 75537ae
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tcc/visualize_alignment.py
Expand Up @@ -35,13 +35,18 @@

gfile = tf.io.gfile

flags.DEFINE_string('video_path', '/tmp/aligned.mp4', 'Path to aligned video.')
flags.DEFINE_string('embs_path', '/tmp/embeddings.npy', 'Path to '
EPSILON = 1e-7

flags.DEFINE_string('video_path', None, 'Path to aligned video.')
flags.DEFINE_string('embs_path', None, 'Path to '
'embeddings. Can be regex.')
flags.DEFINE_boolean('use_dtw', False, 'Use dynamic time warping.')
flags.DEFINE_integer('reference_video', 0, 'Reference video.')
flags.DEFINE_integer('switch_video', 10, 'Reference video.')
flags.DEFINE_integer('candidate_video', None, 'Target video.')
flags.DEFINE_boolean(
'normalize_embeddings', False, 'If True, L2 normalizes the embeddings '
'before aligning.')
flags.DEFINE_boolean(
'grid_mode', True, 'If False, switches to dynamically '
'jumping between videos.')
Expand All @@ -54,7 +59,7 @@


def dist_fn(x, y):
dist = -1.0 * np.matmul(x, y.T)
dist = np.sum((x-y)**2)
return dist


Expand Down Expand Up @@ -200,7 +205,10 @@ def visualize():
query_dict = np.load(file_obj, allow_pickle=True).item()

for j in range(len(query_dict['embs'])):
embs.append(query_dict['embs'][j])
curr_embs = query_dict['embs'][j]
if FLAGS.normalize_embeddings:
curr_embs = [x/(np.linalg.norm(x) + EPSILON) for x in curr_embs]
embs.append(curr_embs)
frames.append(query_dict['frames'][j])

if FLAGS.grid_mode:
Expand Down

0 comments on commit 75537ae

Please sign in to comment.