Copyright 2021 Google LLC.
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# **Socratic Models: MSR-VTT Video-to-Text Retrieval**

Socratic Models (SMs) is a framework that composes multiple pre-existing foundation models (e.g., large language models, visual language models, audio-language models) to provide results for new multimodal tasks, without any model finetuning.

This colab runs SMs for zero-shot video-to-text retrieval on the [MSR-VTT](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/06/cvpr16.msr-vtt.tmei_-1.pdf) Full and 1k-A test sets. Specifically, it augments [Portillo-Quintero et al. 2021](https://arxiv.org/pdf/2102.12443.pdf) with audio information in three steps: an ALM transcribes speech to text, a causal LM (e.g., GPT-3) summarizes each transcription, and a masked LM (e.g., RoBERTa) compares each summary against candidate captions to re-rank the CLIP (VLM) video-caption matching scores.

This is a reference implementation of one task demonstrated in the work: [Socratic Models: Composing Zero-Shot Multimodal Reasoning with Language](https://socraticmodels.github.io/)

**Disclaimer:** this colab uses CLIP, GPT-3, and RoBERTa as foundation models, which may be subject to unwanted biases. This code should be used with caution (and checked for correctness) in downstream applications.

### **Quick Start:**

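**Step 1.** Register for an [OpenAI API key](https://openai.com/blog/openai-api/) to use GPT-3 (there's a free trial) and enter it below. The key is only needed if you generate your own summaries; by default, precomputed GPT-3 summaries are downloaded and used instead.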

**Step 2.** Menu > Change runtime type > Hardware accelerator > "GPU"

**Step 3.** Menu > Runtime > Run all

In [None]:
import os

# OpenAI API key, only needed when GPT-3 summaries are generated rather than
# loaded from the downloaded cache. It is read from the OPENAI_API_KEY
# environment variable when that is set, so the key does not have to be pasted
# into (and saved with) the notebook. Otherwise the placeholder is kept, and you
# can overwrite it here directly.
openai_api_key = os.environ.get("OPENAI_API_KEY", "your-api-key")

## **Setup**
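This installs a few dependencies: gdown (to download the precomputed data files), sentence-transformers (RoBERTa), the OpenAI client (GPT-3), and ftfy. PyTorch is not installed here and is expected to be preinstalled (as on Colab). CLIP features are not computed in this notebook; they are downloaded precomputed.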

In [None]:
!pip install -U --no-cache-dir gdown --pre
!pip install -U sentence-transformers
!pip install openai ftfy
!nvidia-smi  # Show GPU info.

Thu May  5 00:59:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.46       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    37W / 300W |  10927MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import json
import os

import numpy as np
import openai
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer
from sentence_transformers import util as st_utils
import torch

# Set the module-level API key on the openai client; it is used by the GPT-3
# calls in prompt_llm (only reached when summaries are generated, not loaded).
openai.api_key = openai_api_key

In [None]:
# From: https://github.com/Deferf/CLIP_Video_Representation
if not os.path.exists('MSRVTT_test_dict_CLIP_text.pt'):
  !gdown 1-3tpfZzo1_D18WdrioQzc-iogEl-KSnA -O "MSRVTT_test_dict_CLIP_text.pt"
if not os.path.exists('MSRVTT_test_dict_CLIP_visual.pt'):
  !gdown 1Gp3_I_OvcKwjOQmn334-T4wfwQk29TCp -O "MSRVTT_test_dict_CLIP_visual.pt"
if not os.path.exists('test_videodatainfo.json'):
  !gdown 1BzTt1Bf-XJSUXxBfJVxLL3mYWLZ6odsw -O "test_videodatainfo.json"
if not os.path.exists('JS_test_dict_CLIP_text.pt'):
  !gdown --id 15mvFQxrWLNvBvFg4_9rr_Kqyzsy9dudj -O "JS_test_dict_CLIP_text.pt"

# Load generated video transcriptions from Google cloud speed-to-text API.
if not os.path.exists('video_id_to_gcloud_transcription_full.json'):
  !gdown 1LTmvtf9zzw61O7D8YUqdS2mbql76nO6E -O "video_id_to_gcloud_transcription_full.json"

# Load generated summaries from LM (comment this out to generate your own with GPT-3).
if not os.path.exists('msr_full_summaries.pkl'):
  !gdown 1ESXkRv3-3Kz1jZTNtkIhBXME6k1Jr9SW -O "msr_full_summaries.pkl"

In [None]:
# Import helper functions from Portillo-Quintero et al. 2021
!git clone https://github.com/Deferf/Experiments
%cd Experiments
from metrics import rank_at_k_precomputed,stack_encoded_dict,generate_sim_tensor,tensor_video_to_text_sim,tensor_text_to_video_metrics,normalize_matrix,pad_dict,list_recall
%cd "/content"

fatal: destination path 'Experiments' already exists and is not an empty directory.
/content/Experiments
/content


##### Load RoBERTa (masked LM)

In [None]:
# Run RoBERTa on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
# Sentence-embedding model used to compare LLM summaries with ground-truth captions.
roberta_model = SentenceTransformer('stsb-roberta-large')
roberta_model = roberta_model.to(device)

##### Wrap GPT-3 (causal LM)

In [None]:
gpt_version = "text-davinci-002"  # GPT-3 engine used for summarization.


def prompt_llm(prompt, max_tokens=64, temperature=0, stop=None):
  """Send `prompt` to the GPT-3 completion endpoint and return its reply.

  Args:
    prompt: Text to complete.
    max_tokens: Upper bound on the number of generated tokens.
    temperature: Sampling temperature (0 for greedy decoding).
    stop: Optional stop sequence(s) passed through to the API.

  Returns:
    The text of the first returned choice, with surrounding whitespace removed.
  """
  completion = openai.Completion.create(
      engine=gpt_version,
      prompt=prompt,
      max_tokens=max_tokens,
      temperature=temperature,
      stop=stop,
  )
  first_choice = completion["choices"][0]
  return first_choice["text"].strip()

## **Evaluate on MSR-Full**

In [None]:
# Load raw text captions from MSR-Full and group them by video ID, keeping the
# order in which captions appear in the file.
with open('test_videodatainfo.json', 'r') as caption_file:
  msr_full_info = json.load(caption_file)
msr_full_vid_id_to_captions = {}
for sentence in msr_full_info['sentences']:
  captions_for_video = msr_full_vid_id_to_captions.setdefault(sentence['video_id'], [])
  captions_for_video.append(sentence['caption'])

In [None]:
# Reproduce original results with original eval code.
# Files are loaded from the working directory, which is where the download cell
# saved them; this is /content on Colab, but the relative paths also work elsewhere.
# NOTE: torch.load unpickles its input; only load these files from a trusted source.
msr_full_vid_id_to_clip_vid_feats = torch.load("MSRVTT_test_dict_CLIP_visual.pt", map_location="cpu")
msr_full_vid_ids_to_clip_text_feats = torch.load("MSRVTT_test_dict_CLIP_text.pt", map_location="cpu")
msr_full_vid_ids = list(msr_full_vid_ids_to_clip_text_feats.keys())
msr_full_sim_tensor = generate_sim_tensor(msr_full_vid_ids_to_clip_text_feats, msr_full_vid_id_to_clip_vid_feats, msr_full_vid_ids)
msr_full_vid_text_sim = tensor_video_to_text_sim(msr_full_sim_tensor)
msr_full_metrics_vtt = rank_at_k_precomputed(msr_full_vid_text_sim)
print(msr_full_metrics_vtt)

{'R@1': 40.301002502441406, 'R@5': 69.7324447631836, 'R@10': 79.19732666015625, 'Median_Rank': 2.0, 'Mean_Rank': 13.206688963210702, 'Std_Rank': 43.19973161311378}


In [None]:
# Transcription results from gCloud API (an entry may be None).
with open('video_id_to_gcloud_transcription_full.json', 'r') as transcript_file:
  msr_full_vid_id_to_transcript = json.load(transcript_file)

# Order video IDs from longest to shortest transcript; a None transcript counts
# as length 0 and is not included in num_transcripts.
vid_transcripts = [msr_full_vid_id_to_transcript[vid] for vid in msr_full_vid_ids]
num_transcripts = sum(1 for text in vid_transcripts if text is not None)
transcript_lengths = [0 if text is None else len(text) for text in vid_transcripts]
descending_order = np.argsort(transcript_lengths)[::-1]
msr_full_sorted_vid_ids = [msr_full_vid_ids[idx] for idx in descending_order]

In [None]:
# Summarize transcriptions with LLM, or reuse a cached set of summaries.
if os.path.exists('msr_full_summaries.pkl'):
  # NOTE: pickle.load can execute arbitrary code. By default this file is the
  # precomputed one fetched by the download cell; only load it from a trusted source.
  with open('msr_full_summaries.pkl', 'rb') as summaries_file:
    msr_full_vid_id_to_summary = pickle.load(summaries_file)
else:

  # Zero-shot LLM: summarize transcriptions.
  msr_full_vid_id_to_summary = {}
  for vid_id in msr_full_sorted_vid_ids:
    transcript = msr_full_vid_id_to_transcript[vid_id]
    print('Video ID:', vid_id)
    print('Transcript:', transcript)

    # Videos without a transcript get no summary; the ranking cell then keeps
    # their Portillo (CLIP-only) ranking.
    if transcript is None:
      continue
    transcript = transcript.strip()
    prompt = 'I am an intelligent video captioning bot.'
    prompt += f'\nI hear a person saying: "{transcript}".'
    prompt += "\nQ: What's a short video caption for this video? A: In this video,"
    print('Prompt:', prompt)
    summary = prompt_llm(prompt, temperature=0, stop='.')
    print('Summary:', summary)
    msr_full_vid_id_to_summary[vid_id] = summary

  # Write the cache once, after all summaries exist. Writing it inside the loop
  # rewrote the whole file for every video, and a cache left by an interrupted
  # run would be loaded as complete on the next run (see the branch above).
  # Delete the file to regenerate the summaries.
  with open('msr_full_summaries.pkl', 'wb') as summaries_file:
    pickle.dump(msr_full_vid_id_to_summary, summaries_file)

In [None]:
# Compute RoBERTa features for all captions: for each video, encode its list of
# ground-truth captions into a tensor of embeddings on `device`.
msr_full_vid_id_to_roberta_feats = {
    video_id: roberta_model.encode(
        msr_full_vid_id_to_captions[video_id], convert_to_tensor=True, device=device)
    for video_id in msr_full_sorted_vid_ids
}

In [None]:
# For every video, rank candidate videos by CLIP caption-video similarity
# (Portillo), then, when an LLM summary exists, re-rank the top-K candidates
# with RoBERTa summary-caption similarity (optionally multiplied by CLIP score).
topk = 100  # Pre-rank with top-100 from Portillo.
combine_clip_roberta = True  # Combine CLIP (text-video) x RoBERTa (text-text) scores?
portillo_vid_id_to_topk_vid_ids = {}  # vid_id -> top-K candidate vid_ids by CLIP score only.
socratic_vid_id_to_topk_vid_ids = {}  # vid_id -> those candidates, re-ranked when a summary exists.
# Caption CLIP features for all videos, concatenated in msr_full_sorted_vid_ids
# order. The reshape(-1, 20) calls below assume exactly 20 contiguous captions per video.
msr_full_all_clip_text_feats = torch.cat([msr_full_vid_ids_to_clip_text_feats[i] for i in msr_full_sorted_vid_ids], dim=0).cpu().numpy()
for vid_id in msr_full_sorted_vid_ids:

  # Get Portillo top-K captions.
  vid_feats = msr_full_vid_id_to_clip_vid_feats[vid_id]  # CLIP features for all frames of the video
  # Mean-pool over frames (dim 0) into a single row; normalize_matrix comes from
  # the Experiments repo (presumably L2-normalizes rows - TODO confirm).
  vid_feat = normalize_matrix(torch.mean(vid_feats, dim = 0, keepdim = True)).cpu().numpy()
  clip_scores = msr_full_all_clip_text_feats @ vid_feat.T  # one score per caption
  clip_scores = clip_scores.squeeze()
  clip_scores = clip_scores.reshape(-1, 20)  # one row of 20 caption scores per candidate video
  clip_scores = np.max(clip_scores, axis=1)  # best-matching caption per candidate video
  # Descending CLIP order. The default argsort kind is not stable, so tie order
  # is implementation-defined.
  sorted_idx = np.argsort(clip_scores).squeeze()[::-1]
  portillo_topk_vid_ids = [msr_full_sorted_vid_ids[i] for i in sorted_idx[:topk]]
  portillo_vid_id_to_topk_vid_ids[vid_id] = portillo_topk_vid_ids

  # If no LLM summary, default to Portillo ranking.
  socratic_vid_id_to_topk_vid_ids[vid_id] = portillo_topk_vid_ids
  if vid_id not in msr_full_vid_id_to_summary:
    continue

  # Get RoBERTa scores between LLM summary and captions.
  summary = msr_full_vid_id_to_summary[vid_id]
  summary_feat = roberta_model.encode([summary], convert_to_tensor=True, device=device)
  # Captions of the top-K candidates, in Portillo order (20 per candidate assumed).
  caption_feats = torch.cat([msr_full_vid_id_to_roberta_feats[i] for i in portillo_topk_vid_ids], dim=0)
  roberta_scores = st_utils.pytorch_cos_sim(caption_feats, summary_feat).detach().cpu().numpy().squeeze()
  roberta_scores = roberta_scores.reshape(-1, 20)
  roberta_scores = np.max(roberta_scores, axis=1)  # best caption per candidate, aligned with portillo_topk_vid_ids

  # Re-rank top-K with RoBERTa scores.
  # Reversing a stable ascending sort means that, among equal scores, the
  # candidate that comes later in Portillo order is ranked first.
  sort_idx = np.argsort(roberta_scores, kind='stable').squeeze()[::-1]
  socratic_vid_id_to_topk_vid_ids[vid_id] = [portillo_topk_vid_ids[i] for i in sort_idx]

  # Combine CLIP (text-video) x RoBERTa (text-text) scores.
  if combine_clip_roberta:
    # Top-K CLIP scores in descending order, which lines them up with
    # portillo_topk_vid_ids. Ties only swap equal values, so the value alignment still holds.
    clip_scores = np.sort(clip_scores, kind='stable').squeeze()[::-1][:topk]
    # NOTE(review): cosine similarities can be negative. For a negative RoBERTa
    # score, a higher CLIP score makes the product smaller - confirm this is intended.
    scores = clip_scores * roberta_scores
    sort_idx = np.argsort(scores, kind='stable').squeeze()[::-1]
    socratic_vid_id_to_topk_vid_ids[vid_id] = [portillo_topk_vid_ids[i] for i in sort_idx]  # Override ranking from only LLM

In [None]:
# Return R@1, R@5, R@10.
def get_recall(vid_ids, socratic_subset, k=(1, 5, 10), *, portillo_topk=None, socratic_topk=None):
  """Compute video-to-text recall@k (in percent) and median rank over `vid_ids`.

  Each video is scored against a ranked list of candidate video IDs:
  - Videos in `socratic_subset` use the Socratic re-ranked list.
  - All other videos use the Portillo (CLIP-only) list.

  A video is a hit at cutoff `i` when its own ID appears among the first `i`
  candidates.

  Args:
    vid_ids: Video IDs to evaluate.
    socratic_subset: Iterable of video IDs that should use the Socratic ranking.
    k: Cutoffs at which recall is reported.
    portillo_topk: Mapping vid_id -> ranked candidate vid_ids (CLIP only).
      Defaults to the global `portillo_vid_id_to_topk_vid_ids`.
    socratic_topk: Mapping vid_id -> re-ranked candidate vid_ids.
      Defaults to the global `socratic_vid_id_to_topk_vid_ids`.

  Returns:
    A tuple `(recall, mdr)`:
    - `recall` is an array with one percentage per cutoff in `k`.
    - `mdr` is the median 1-based rank. A video absent from its candidate list
      gets rank len(list) + 1, just past the last retrieved candidate.
  """
  if portillo_topk is None:
    portillo_topk = portillo_vid_id_to_topk_vid_ids
  if socratic_topk is None:
    socratic_topk = socratic_vid_id_to_topk_vid_ids
  socratic_subset = set(socratic_subset)  # O(1) membership instead of a list scan per video
  recall = []
  rank = []
  for vid_id in vid_ids:
    if vid_id in socratic_subset:
      sorted_vid_ids = socratic_topk[vid_id]
    else:
      sorted_vid_ids = portillo_topk[vid_id]
    recall.append([(vid_id in sorted_vid_ids[:i]) for i in k])
    if vid_id in sorted_vid_ids:
      rank.append(sorted_vid_ids.index(vid_id) + 1)
    else:
      # Not retrieved in the top-K list, so the true rank is > len(list). Using
      # len(list) would collide with a genuine last-place rank.
      rank.append(len(sorted_vid_ids) + 1)
  mdr = np.median(rank)
  return np.mean(np.float32(recall) * 100, axis=0), mdr
 
subset_size = 1007  # Subset of long transcripts.

# msr_full_sorted_vid_ids is ordered by descending transcript length, so its
# prefix holds the longest transcripts. Four settings are evaluated, in order:
#   1. Portillo only.
#   2. Socratic + Portillo (Socratic ranking on the long-transcript subset).
#   3. Portillo only on long transcripts.
#   4. Socratic + Portillo on long transcripts.
long_transcript_vid_ids = msr_full_sorted_vid_ids[:subset_size]
no_socratic_vid_ids = msr_full_sorted_vid_ids[:0]
for eval_vid_ids, eval_socratic_ids in (
    (msr_full_sorted_vid_ids, no_socratic_vid_ids),
    (msr_full_sorted_vid_ids, long_transcript_vid_ids),
    (long_transcript_vid_ids, no_socratic_vid_ids),
    (long_transcript_vid_ids, long_transcript_vid_ids),
):
  recall, mdr = get_recall(eval_vid_ids, eval_socratic_ids)
  print(f'R@1: {recall[0]:.1f}\tR@5: {recall[1]:.1f}\tR@10: {recall[2]:.1f}\tMdR: {mdr}')

R@1: 40.2	R@5: 69.7	R@10: 79.2	MdR: 2.0
R@1: 44.7	R@5: 71.2	R@10: 80.0	MdR: 2.0
R@1: 41.5	R@5: 69.6	R@10: 77.4	MdR: 2.0
R@1: 54.9	R@5: 74.0	R@10: 79.9	MdR: 1.0
