<a href="https://colab.research.google.com/github/cosmo3769/s3d-mil-nce/blob/main/notebooks/s3d_mil_nce.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
print(f'Installing Weights and Biases')
!pip install -qq --upgrade wandb

Installing Weights and Biases
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m17.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m277.3/277.3 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import cv2
import math
import wandb
import numpy as np
import tensorflow_hub as hub
import tensorflow.compat.v2 as tf

from IPython import display

In [None]:
# Authenticate with Weights & Biases before any run is created.
# In the recorded output this prompted interactively for an API key.
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
# Load the model once from TF-Hub.
# `hub_model` is kept at module level and passed to `generate_embeddings`
# so the load does not have to be repeated for every inference call.
hub_handle = 'https://tfhub.dev/deepmind/mil-nce/s3d/1'
hub_model = hub.load(hub_handle)

def generate_embeddings(model, input_frames, input_words):
  """Embed a batch of videos and a batch of text queries with `model`.

  Args:
    model: Loaded model exposing 'video' and 'text' entries in `signatures`.
    input_frames: Video frames normalized in [0, 1], of the shape
      Batch x T x H x W x 3. They are cast to float32 before inference.
    input_words: Text queries handed to the 'text' signature.

  Returns:
    A `(video_embedding, text_embedding)` pair, read from the
    'video_embedding' and 'text_embedding' outputs of the two signatures.
  """
  # Video branch: cast the frames to float32, then run the 'video' signature.
  frames_f32 = tf.constant(tf.cast(input_frames, dtype=tf.float32))
  video_embedding = model.signatures['video'](frames_f32)['video_embedding']

  # Text branch: the queries are passed through as a constant tensor.
  words = tf.constant(input_words)
  text_embedding = model.signatures['text'](words)['text_embedding']

  return video_embedding, text_embedding

In [None]:
def crop_center_square(frame):
  """Return the largest centered square region of `frame`.

  The side of the square is the smaller of the frame's first two dimensions
  (rows, columns); any trailing dimensions are kept unchanged.
  """
  height, width = frame.shape[:2]
  side = min(height, width)
  # Offsets that center the square along each axis.
  top = height // 2 - side // 2
  left = width // 2 - side // 2
  return frame[top:top + side, left:left + side]

In [None]:
def load_video(video_url, max_frames=32, resize=(224, 224)):
  """Download a video/GIF and return a fixed-length stack of RGB frames.

  Args:
    video_url: URL of the clip; fetched through `tf.keras.utils.get_file`
      under a file name derived from the URL's last path component.
    max_frames: Number of frames in the returned array.
    resize: Target size handed to `cv2.resize` for every frame.

  Returns:
    A NumPy array holding `max_frames` frames, each center-cropped to a
    square, resized, reordered from BGR to RGB and divided by 255.0.

  Raises:
    ValueError: If not a single frame could be decoded from the file
      (previously this surfaced as an opaque ZeroDivisionError).
  """
  # Only the last 128 characters of the base name are used as the file name.
  path = tf.keras.utils.get_file(os.path.basename(video_url)[-128:], video_url)
  cap = cv2.VideoCapture(path)
  frames = []
  try:
    while True:
      ret, frame = cap.read()
      if not ret:
        break
      frame = crop_center_square(frame)
      frame = cv2.resize(frame, resize)
      # Reverse the channel axis: OpenCV decodes BGR, the model expects RGB.
      frame = frame[:, :, [2, 1, 0]]
      frames.append(frame)

      if len(frames) == max_frames:
        break
  finally:
    # Always release the capture handle, even if decoding raised.
    cap.release()

  if not frames:
    raise ValueError(
        'No frames could be read from {!r}'.format(video_url))

  frames = np.array(frames)
  if len(frames) < max_frames:
    # Pad short clips. Note that `repeat` duplicates each frame in place
    # (f0, f0, f1, f1, ...) rather than looping the clip; the result is
    # then truncated to `max_frames` below.
    n_repeat = int(math.ceil(max_frames / float(len(frames))))
    frames = frames.repeat(n_repeat, axis=0)
  frames = frames[:max_frames]
  # Scale 8-bit pixel values into [0, 1].
  frames = frames / 255.0
  return frames

In [None]:
def display_video(urls):
  """Build an HTML table that shows each video under a numbered header.

  Args:
    urls: Iterable of image/GIF URLs; each one becomes an `<img>` cell.

  Returns:
    HTML string for a two-row table: a header row ("Video 1", "Video 2",
    ...) and a row of images rendered 224 pixels high. The header row has
    one column per URL, so it always lines up with the image row (it used
    to be fixed at three columns regardless of how many URLs were passed).
  """
  urls = list(urls)
  header_cells = ''.join(
      '<th>Video {}</th>'.format(position)
      for position in range(1, len(urls) + 1))
  image_cells = ''.join(
      '<td><img src="{}" height="224"></td>'.format(url) for url in urls)
  return '<table><tr>{}</tr><tr>{}</tr></table>'.format(
      header_cells, image_cells)

In [None]:
def display_query_and_results_video(query, urls, scores):
  """Render a text query together with its videos ranked by score.

  Args:
    query: Query text shown in the heading.
    urls: Sequence of video/GIF URLs, aligned index-for-index with `scores`.
    scores: 1-D sequence or array of similarity scores, one per URL.

  Returns:
    HTML string with a heading for the query and a two-row table: rank
    headers with scores (highest first, two decimals) above the matching
    images rendered 224 pixels high. One column is produced per score, so
    any number of videos is supported, and both `<div>` wrappers are closed.
  """
  scores = np.asarray(scores)
  # Negating the scores makes argsort return indices in descending order.
  ranking = np.argsort(-scores)
  header_cells = ''.join(
      '<th>Rank #{}, Score:{:.2f}</th>'.format(rank, scores[idx])
      for rank, idx in enumerate(ranking, start=1))
  image_cells = ''.join(
      '<td><img src="{}" height="224"></td>'.format(urls[idx])
      for idx in ranking)
  return (
      '<h2>Input query: <i>{}</i> </h2><div>'
      'Results: <div>'
      '<table><tr>{}</tr><tr>{}</tr></table>'
      '</div></div>'
  ).format(query, header_cells, image_cells)

In [None]:
# Demo inputs: three GIFs hosted in the project repository.
all_videos_urls = [
    'https://github.com/cosmo3769/s3d-mil-nce/blob/main/gif_dir/dancing-cat.gif?raw=true',
    'https://github.com/cosmo3769/s3d-mil-nce/blob/main/gif_dir/sunset.gif?raw=true',
    'https://github.com/cosmo3769/s3d-mil-nce/blob/main/gif_dir/cycle.gif?raw=true',
]
video_1_url, video_2_url, video_3_url = all_videos_urls

# Download and preprocess every clip, in the same order as the URLs.
all_videos = [load_video(url) for url in all_videos_urls]
video_1, video_2, video_3 = all_videos

# One text query per clip, listed in the same order as the videos.
all_queries = ['Dancing', 'Sunset', 'Cycling']
query_1, query_2, query_3 = all_queries

In [None]:
# Preview the input GIFs inline. The HTML object is the cell's last
# expression, which is what makes the notebook render it.
display.HTML(display_video(all_videos_urls))

Video 1,Video 2,Video 3
,,


In [None]:
# Log the input-video table to Weights & Biases as an HTML panel.
# The run is used as a context manager so it is always finished, even if
# logging raises, instead of being left open in the kernel.
# NOTE: `entity` is hard-coded to one account; change it to log elsewhere.
with wandb.init(
    entity="cosmo3769",
    project="s3d-mil-nce",
    name="display_gif_table_6"
) as run:
  run.log({"display_video": wandb.Html(display_video(all_videos_urls))})

VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [None]:
# Prepare video inputs.
videos_np = np.stack(all_videos, axis=0)

# Prepare text input.
words_np = np.array(all_queries)

# Generate the video and text embeddings.
video_embd, text_embd = generate_embeddings(hub_model, videos_np, words_np)

# Scores between video and text is computed by dot products.
all_scores = np.dot(text_embd, tf.transpose(video_embd))

In [None]:
# Display results.
html = ''
for i, words in enumerate(words_np):
  html += display_query_and_results_video(words, all_videos_urls, all_scores[i, :])
  html += '<br>'
display.HTML(html)

"Rank #1, Score:3.38","Rank #2, Score:1.19","Rank #3, Score:-0.98"
,,

"Rank #1, Score:6.09","Rank #2, Score:3.98","Rank #3, Score:-1.34"
,,

"Rank #1, Score:7.59","Rank #2, Score:1.05","Rank #3, Score:0.36"
,,


In [None]:
# Log the ranked query results (the `html` markup built above) to
# Weights & Biases as an HTML panel.
# The run is used as a context manager so it is always finished, even if
# logging raises, instead of being left open in the kernel.
# NOTE: `entity` is hard-coded to one account; change it to log elsewhere.
with wandb.init(
    entity="cosmo3769",
    project="s3d-mil-nce",
    name="display_gif_table_7"
) as run:
  run.log({"display_query_and_results_video": wandb.Html(html)})

[34m[1mwandb[0m: Currently logged in as: [33mcosmo3769[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))