<a href="https://colab.research.google.com/github/nateraw/openai-vision-api-for-videos/blob/main/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Trying OpenAI's Vision API on Videos

OpenAI just launched their [Vision API](https://platform.openai.com/docs/guides/vision): a multimodal LLM that understands both language and images and acts as a helpful assistant.

The debut model, `gpt-4-vision-preview`, has a whopping [128k context window](https://platform.openai.com/docs/models/gpt-4-and-gpt-4-turbo) 🤯. This means a single request can hold up to 128,000 tokens (prompt and response combined).

In this notebook, we'll put that context to use by passing a sequence of video frames into a single prompt and seeing what the model can do with them.

# If you find this notebook helpful...

Please consider supporting me by:

- Giving [the repo](https://github.com/nateraw/openai-vision-api-for-videos) a ⭐️
- Following me on [GitHub](https://github.com/nateraw) and/or [Twitter](https://x.com/_nateraw). ❤️

Share issues/feature requests [here](https://github.com/nateraw/openai-vision-api-for-videos/issues).

<details>
  <summary><strong>⚠️ Note About Costs (Click to expand)</strong></summary>
  <p>Pricing is done by token, just like with language models. However, Vision tokens are counted differently: the number of tokens a single image is billed for depends on several factors, such as its size and the requested detail level. The details are outside the scope of this notebook, but I encourage you to read <a href="https://platform.openai.com/docs/guides/vision/calculating-costs">this rather complicated explanation</a> so you (hopefully 😅) understand how you'll be billed for the Vision API.</p>

  <p>You should be mindful of your spending, and make sure you have a spend limit set to avoid any bank-draining whoopsies.</p>

</details>
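To get a feel for the numbers, here is a rough estimator based on the tile rule the linked billing page describes for high-detail images on GPT-4-class vision models: downscale to fit in a 2048×2048 square, then scale the shortest side to 768px, then charge 170 tokens per 512px tile plus 85 base tokens (low detail is a flat 85). This is a sketch, not OpenAI's code: it assumes images already smaller than 768px are not upscaled, the helper names and the `reserved_tokens` default are hypothetical, and the rule may change, so check the linked page before relying on it.

```python
import math


def estimate_image_tokens(width: int, height: int, detail: str = "high") -> int:
    """Rough per-image token estimate following the documented tile-based rule."""
    if detail == "low":
        return 85  # flat cost for low-detail images
    # 1) Downscale (never upscale) to fit within a 2048 x 2048 square
    scale = min(1.0, 2048 / max(width, height))
    w, h = width * scale, height * scale
    # 2) Downscale so the shortest side is 768px (assumption: no upscaling)
    scale = min(1.0, 768 / min(w, h))
    w, h = w * scale, h * scale
    # 3) 170 tokens per 512px tile, plus a fixed 85 tokens per image
    tiles = math.ceil(w / 512) * math.ceil(h / 512)
    return 85 + 170 * tiles


def max_frames_in_context(tokens_per_frame: int, context_window: int = 128_000,
                          reserved_tokens: int = 1_000) -> int:
    """How many frames fit after reserving room for the text prompt and the reply."""
    return (context_window - reserved_tokens) // tokens_per_frame
```

For example, a 720p source scaled to a 512px short side comes out at 910×512, which this sketch prices at 425 tokens per frame, or 6,375 tokens for 15 frames.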

# Install dependencies

In [None]:
%%capture
! pip install openai yt-dlp av replicate

# Setup Env

On the left panel, select the 🔑 icon to add secrets:

- `OPENAI_API_KEY`: Key from OpenAI, which you can find [here](https://platform.openai.com/account/api-keys) when logged in.
- `REPLICATE_API_TOKEN`: (only required for transcription) A key from your account on Replicate, which can be found [here](https://replicate.com/account/api-tokens) if logged in.

<img src="https://huggingface.co/datasets/nateraw/documentation-images/resolve/main/Screen%20Shot%202023-11-07%20at%202.38.30%20AM.png?download=true">

In [None]:
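import os
from google.colab import userdata

def get_secret(name):
    """Return a Colab secret, or None if it's missing or access wasn't granted."""
    try:
        return userdata.get(name)
    except Exception:
        return None

os.environ["OPENAI_API_KEY"] = get_secret("OPENAI_API_KEY") or ""
if not os.environ["OPENAI_API_KEY"]:
    raise ValueError("OPENAI_API_KEY not set")

# Optional: only needed for the transcription section
os.environ["REPLICATE_API_TOKEN"] = get_secret("REPLICATE_API_TOKEN") or ""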

# Download YouTube video to analyze

The URL should be in this format:

`https://www.youtube.com/watch?v=<YouTube-ID>`

The cell below pulls the ID out of the `v=` query parameter, so other formats (e.g. `https://youtu.be/<YouTube-ID>` short links) won't work.
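If you'd like to accept short `youtu.be` links as well, a more forgiving parser can be built on the standard library's `urllib.parse`. This is a minimal sketch: `extract_youtube_id` is a hypothetical helper that is not used elsewhere in the notebook, and it only handles the two URL shapes shown here.

```python
from urllib.parse import parse_qs, urlparse


def extract_youtube_id(url: str) -> str:
    """Hypothetical helper: pull the video ID out of a watch?v= or youtu.be URL."""
    parsed = urlparse(url)
    host = parsed.netloc.lower()
    if host.endswith("youtu.be"):
        # Short links: https://youtu.be/<YouTube-ID>?t=...
        video_id = parsed.path.lstrip("/").split("/")[0]
    else:
        # Standard links: https://www.youtube.com/watch?v=<YouTube-ID>&t=...
        video_id = parse_qs(parsed.query).get("v", [""])[0]
    if not video_id:
        raise ValueError(f"Could not find a YouTube ID in {url!r}")
    return video_id
```

Because it parses the query string instead of matching everything after `v=`, extra parameters such as `&t=30s` are ignored.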

In [None]:
import subprocess
import re
from pathlib import Path
from datetime import datetime

from IPython.display import YouTubeVideo

youtube_url = "https://www.youtube.com/watch?v=S60GxA9JpLk"  # @param {type:"string"}
data_dir = Path("./videos")
data_dir.mkdir(exist_ok=True)

# check for existing matches based on the YTID
# (stop at "&" or "#" so extra URL params like &t=30s aren't captured)
ytid = re.search(r'v=([^&#]+)', youtube_url).group(1)
matches = list(data_dir.glob(f'*_{ytid}_*'))
if matches:
    filepath = Path(matches[-1])
    print(f"Found existing file {filepath}")
else:
    datetime_prefix = datetime.now().strftime("%Y%m%d%H%M%S")
    output_template = str(data_dir / f'{datetime_prefix}_%(id)s_%(title)s.%(ext)s')

    command = [
        'yt-dlp',
        '-f', 'best[ext=mp4]',
        '-o', output_template,
        youtube_url
    ]

    # check=True raises if yt-dlp fails, instead of a confusing StopIteration below
    subprocess.run(command, capture_output=True, text=True, check=True)
    filepath = next(data_dir.glob(f'{datetime_prefix}*.mp4'))
    print(f"Downloaded video path: {filepath}")


YouTubeVideo(ytid)

## Define Utilities for Extracting Frames / Hitting the API

#### ⚠️ Note
You could write this with WAY less code with the `decord` library if you just want to load frames.

I'm using `av` because it handles audio + video together quite well, which will be helpful when I extend the ideas in this notebook in the future. Some of this code is modified from [pytorchvideo](https://github.com/facebookresearch/pytorchvideo).
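For comparison, a frames-only loader with `decord` might look like the sketch below. It assumes `decord` is installed (`pip install decord`) and uses its `VideoReader`, `get_avg_fps`, and `get_batch` calls; both function names are hypothetical, and unlike the `av` code here it ignores audio entirely. The `decord` import lives inside the function so the index helper can be used on its own.

```python
import numpy as np


def clip_frame_indices(fps: float, start_sec: float, end_sec: float,
                       num_frames: int, total_frames: int) -> np.ndarray:
    """Evenly spaced frame indices covering [start_sec, end_sec), clamped to the video length."""
    first = int(round(start_sec * fps))
    last = min(int(round(end_sec * fps)), total_frames) - 1
    return np.linspace(first, last, num_frames).astype(int)


def load_frames_with_decord(path: str, start_sec: float, end_sec: float,
                            num_frames: int) -> np.ndarray:
    """Hypothetical decord-based loader returning (T, H, W, C) uint8 frames."""
    from decord import VideoReader, cpu  # optional dependency, imported lazily

    vr = VideoReader(path, ctx=cpu(0))
    idx = clip_frame_indices(vr.get_avg_fps(), start_sec, end_sec, num_frames, len(vr))
    return vr.get_batch(idx.tolist()).asnumpy()
```

Only the requested frames are decoded, which is why this route needs so much less code than decoding a whole clip and subsampling afterwards.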

In [None]:
# @title
%matplotlib inline

import av
import base64
import gc
import math
from io import BytesIO
from pathlib import Path
from typing import List, Tuple, Union

import cv2
import matplotlib.pyplot as plt
import numpy as np
import openai
from PIL import Image


def display_frames_as_grid(frames, rows=5, cols=5):
    # Adjust the number of rows to ensure there are no completely empty rows
    total_frames = frames.shape[0]
    while rows > 1 and total_frames <= (rows - 1) * cols:
        rows -= 1

    # Initialize a figure with the adjusted number of rows
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))

    # If there's only one row or column, axes might not be a 2D array
    if rows == 1 or cols == 1:
        axes = np.array(axes).reshape(-1)

    # Flatten axes array for easy indexing
    axes = axes.flatten()

    # Loop through the grid and add frames with indices
    for i in range(rows * cols):
        ax = axes[i]
        if i < total_frames:  # Check if the current frame exists
            ax.imshow(frames[i])
            ax.text(0.05, 0.95, f'Frame {i}', color='white', weight='bold',
                    transform=ax.transAxes, ha="left", va="top", fontsize=8)
        ax.axis('off')  # Hide axes

    plt.tight_layout()
    plt.show()


def _pyav_decode_stream(
    container: av.container.input.InputContainer,
    start_sec: float,
    end_sec: float,
    stream: av.video.stream.VideoStream,
    perform_seek: bool = True,
) -> np.ndarray:
    """
    Decode a video or audio stream with the PyAV decoder.
    Args:
        container (container): PyAV container.
        start_sec (float): the starting second to fetch the frames.
        end_sec (float): the ending second of the decoded frames.
        stream (stream): PyAV stream (video or audio).
        perform_seek (bool): whether to seek near start_sec before decoding.
    Returns:
        result (np.ndarray): video frames stacked as (T, H, W, C) uint8 RGB,
            or audio samples concatenated along the sample axis.
    """

    start_pts = math.ceil(start_sec / stream.time_base)
    end_pts = math.ceil(end_sec / stream.time_base)
    # NOTE:
    # Don't want to seek if iterating through a video due to slow-downs. I
    # believe this is some PyAV bug where seeking after a certain point causes
    # major slow-downs
    if perform_seek:
        # Seeking in the stream is imprecise. Thus, seek to an earlier pts by a
        # margin pts.
        margin = 1024
        seek_offset = max(start_pts - margin, 0)
        container.seek(int(seek_offset), any_frame=False, backward=True, stream=stream)

    frames = []
    stream_info = {'video': 0} if stream.type == 'video' else {'audio': 0}
    for frame in container.decode(**stream_info):
        if frame.pts >= start_pts and frame.pts < end_pts:
            frames.append(frame)
        elif frame.pts >= end_pts:
            break

    if stream.type == 'audio':
        return np.concatenate([x.to_ndarray() for x in frames], 1)
    else:
        return np.stack([x.to_ndarray(format='rgb24') for x in frames])


class EncodedVideo(object):
    """
    Wrapper around the _pyav_decode_stream fn that keeps the container open while we decode clips.

    Inspired by PyTorchVideo's EncodedVideo, but without the torch dependency.
    """

    def __init__(self, container_or_path: Union[str, Path, av.container.input.InputContainer]):
        if isinstance(container_or_path, (str, Path)):
            self.container = av.open(str(container_or_path))
        else:
            self.container = container_or_path

    def get_clip(self, start_sec, end_sec, decode_audio=True, perform_seek=True):
        # Handle video stream
        video_stream = self.container.streams.video[0]
        info = dict(video_fps=video_stream.average_rate)
        video_arr = _pyav_decode_stream(self.container, start_sec, end_sec, video_stream, perform_seek)

        # Handle audio stream if desired (and if the file actually has one)
        audio_arr = None
        if decode_audio and self.container.streams.audio:
            audio_stream = self.container.streams.audio[0]
            info['audio_sample_rate'] = audio_stream.rate
            audio_arr = _pyav_decode_stream(self.container, start_sec, end_sec, audio_stream, perform_seek)

        return video_arr, audio_arr, info

    def close(self):
        # Guard against __init__ having failed before self.container was set
        container = getattr(self, "container", None)
        if container is not None:
            container.close()
        gc.collect()

    def __del__(self):
        self.close()


# translated to np from pytorchvideo
def uniform_temporal_subsample(x: np.ndarray, num_samples: int, temporal_dim: int = -3) -> np.ndarray:
    """
    Uniformly subsamples num_samples indices from the temporal dimension of the video.
    When num_samples is larger than the size of temporal dimension of the video, it
    will sample frames based on nearest neighbor interpolation.
    Args:
        x (np.ndarray): A video tensor with dimension larger than one with numpy
            array type includes int, long, float, complex, etc.
        num_samples (int): The number of equispaced samples to be selected
        temporal_dim (int): dimension of temporal to perform temporal subsample.
    Returns:
        An x-like array with subsampled temporal dimension.
    """
    # Adjust the temporal_dim to work with numpy's axis definition
    if temporal_dim < 0:
        temporal_dim += x.ndim

    t = x.shape[temporal_dim]
    assert num_samples > 0 and t > 0
    # Sample by nearest neighbor interpolation if num_samples > t.
    indices = np.linspace(0, t - 1, num_samples)
    indices = np.clip(indices, 0, t - 1).astype(int)
    # Use advanced integer indexing to select the frames
    return np.take(x, indices, axis=temporal_dim)


# translated to numpy from pytorchvideo
def _interpolate_opencv(
    x: np.ndarray, size: Tuple[int, int], interpolation: str
) -> np.ndarray:
    """
    Down/up samples the input numpy array x to the given size with given interpolation
    mode.
    Args:
        x (np.ndarray): the input array to be down/up sampled.
        size (Tuple[int, int]): expected output spatial size.
        interpolation: mode to perform interpolation, options include `nearest`,
            `linear`, `bilinear`, `bicubic`.
    """
    _opencv_np_interpolation_map = {
        "nearest": cv2.INTER_NEAREST,
        "linear": cv2.INTER_LINEAR,
        "bilinear": cv2.INTER_LINEAR,
        "bicubic": cv2.INTER_CUBIC,
    }
    assert interpolation in _opencv_np_interpolation_map, "Invalid interpolation mode."
    new_h, new_w = size
    # Transpose to shape (H, W, C, T)
    x = x.transpose(2, 3, 0, 1)
    resized_array_list = [
        cv2.resize(
            x[:, :, :, t],
            (new_w, new_h),
            interpolation=_opencv_np_interpolation_map[interpolation],
        )
        for t in range(x.shape[3])
    ]
    # Stack on the last dimension and then transpose back to (C, T, H, W)
    img_array = np.stack(resized_array_list, axis=-1)
    return img_array.transpose(2, 3, 0, 1)


# translated to np from pytorchvideo
def short_side_scale(
    x: np.ndarray,
    size: int,
    interpolation: str = "bilinear"
) -> np.ndarray:
    """
    Determines the shorter spatial dim of the input (i.e. width or height) and scales
    it to the given size. To maintain aspect ratio, the longer side is then scaled
    accordingly.
    Args:
        x (np.ndarray): An array of shape (C, T, H, W).
        size (int): The size the shorter side is scaled to.
        interpolation (str): Algorithm used for upsampling,
            options: 'nearest' | 'linear' | 'bilinear' | 'bicubic'
    Returns:
        A NumPy array with scaled spatial dims.
    """
    assert len(x.shape) == 4, "Input must be a 4D array."
    c, t, h, w = x.shape
    if w < h:
        new_h = int(math.floor((float(h) / w) * size))
        new_w = size
    else:
        new_h = size
        new_w = int(math.floor((float(w) / h) * size))

    return _interpolate_opencv(x, size=(new_h, new_w), interpolation=interpolation)


def transform(
    video_arr,
    num_frames: int,
    short_side_size: int,
):
    # Sample T down to num_frames first, while still uint8 (T, H, W, C),
    # so we never hold the whole clip in memory as float32
    video_arr = uniform_temporal_subsample(video_arr, num_frames, temporal_dim=0)

    # T, H, W, C -> C, T, H, W
    # uint8 [0, 255] -> float32 [0.0, 1.0]
    video_arr = np.transpose(video_arr, (3, 0, 1, 2)).astype(np.float32) / 255.

    # Scale short side of video to short_side_size
    video_arr = short_side_scale(video_arr, short_side_size)

    # Rescale back to [0, 255], round, and convert to uint8 (T, H, W, C)
    video_arr = np.clip(np.round(video_arr * 255), 0, 255).astype(np.uint8)
    video_arr = np.transpose(video_arr, (1, 2, 3, 0))
    gc.collect()
    return video_arr

# Function to encode a numpy array to base64
def encode_numpy_image(np_img):
    """Takes in a np img array and returns a base64 encoded string"""
    pil_img = Image.fromarray(np_img.astype('uint8'), 'RGB')
    buffer = BytesIO()
    pil_img.save(buffer, format='JPEG')
    img_str = buffer.getvalue()
    return base64.b64encode(img_str).decode('utf-8')


def call_vision_api(frames, prompt: str, max_tokens=200, detail="auto",
                    model="gpt-4-vision-preview", **completion_kwargs):
    """
    frames: np array with shape (T, H, W, C), dtype np.uint8
    prompt: text sent to the model alongside the frames
    max_tokens: maximum number of tokens to generate
    detail: image detail level ("low", "high", or "auto"); affects token cost
    model: a vision-capable model. `gpt-4-vision-preview` has since been deprecated
        by OpenAI, so you may need to pass a current vision model instead.
    """
    client = openai.OpenAI()

    # Encode all frames to base64 JPEGs
    base64_frames = [encode_numpy_image(frame) for frame in frames]

    # Text prompt first, then one image part per frame, sent as base64 data URLs
    content = [{"type": "text", "text": prompt}]
    content += [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{b64}", "detail": detail},
        }
        for b64 in base64_frames
    ]

    # Make the API call
    result = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": content}],
        max_tokens=max_tokens,
        **completion_kwargs,
    )

    # Return the generated description
    return result.choices[0].message.content
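To make the shape bookkeeping in `transform` concrete, the sketch below reproduces only the index and size arithmetic of `uniform_temporal_subsample` and `short_side_scale`, without decoding anything. It assumes a 30 fps, 1280×720 source and the default settings used below (a 30-second clip, 15 frames, 512px short side); your video's frame rate and resolution may differ. The helper names are hypothetical.

```python
import math

import numpy as np


def subsample_indices(t: int, num_samples: int) -> np.ndarray:
    """Same index rule as uniform_temporal_subsample: evenly spaced, truncated to int."""
    return np.clip(np.linspace(0, t - 1, num_samples), 0, t - 1).astype(int)


def short_side_target(h: int, w: int, size: int) -> tuple:
    """Same size rule as short_side_scale: shorter side -> size, keep aspect ratio."""
    if w < h:
        return int(math.floor(h / w * size)), size
    return size, int(math.floor(w / h * size))


# A 30 s clip at an assumed 30 fps has 900 frames
indices = subsample_indices(900, 15)
# An assumed 1280x720 frame, scaled to a 512px short side
new_h, new_w = short_side_target(720, 1280, 512)
```

With those assumptions, 15 frames are picked roughly every 64 frames (about every 2 seconds) and each frame becomes 910×512. When more samples are requested than frames exist, the truncation repeats frames instead of inventing new ones.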

# Run it!

In [None]:
start_sec = 55  # @param {type:"integer"}
duration = 30   # @param {type:"integer"}
end_sec = start_sec + duration
num_frames_to_sample = 15  # @param {type:"integer"}
short_side_size = 512  # @param {type:"integer"}
max_tokens = 512  # @param {type:"integer"}
prompt = "These are frames from a video. Generate a description/summary that explains what is happening based on what you can see.\nYour response should **STRICTLY and ONLY** be the description, no titles, headers, or anything else."  # @param {type:"string"}
show_frame_grid = True  # @param {type:"boolean"}

video_arr, audio_arr, info = EncodedVideo(filepath).get_clip(start_sec, end_sec)
print(f"Original video shape: {video_arr.shape}")

video_arr = transform(video_arr, num_frames_to_sample, short_side_size)
print(f"Final video frames shape: {video_arr.shape}")

print('-' * 80)
print()

if show_frame_grid:
    display_frames_as_grid(video_arr, rows=30, cols=5)

out = call_vision_api(video_arr, prompt, max_tokens=max_tokens)
out

# Incorporating audio transcriptions

Let's not forget we have audio we can use too!

In [None]:
from IPython.display import Audio

Audio(audio_arr, rate=info["audio_sample_rate"])

Here we'll try using [WhisperX](https://github.com/m-bain/whisperX) to grab the text transcription of the video clip and see if including that as context in our prompt improves the video summary.

We'll use this model via [Replicate's API](https://replicate.com/zeke/zisper) thanks to [@zeke](https://github.com/zeke). ❤️

If you want to run the model directly, see the underlying [code for the API here](https://github.com/zeke/zisper)! 👀

In [None]:
import os
from google.colab import userdata

###############################################################################
# Token should be set as REPLICATE_API_TOKEN in the Secrets section of Colab
# Click the 🔑 icon on the left sidebar to add it/check if you already added it
###############################################################################

replicate_api_token = userdata.get("REPLICATE_API_TOKEN")
os.environ["REPLICATE_API_TOKEN"] = replicate_api_token

Note - you might run into a [cold boot](https://replicate.com/docs/how-does-replicate-work#cold-boots) if the model hasn't been used in a while. Once it boots up, it'll respond much faster in subsequent runs.

In [None]:
import json
from scipy.io.wavfile import write
import replicate


client = replicate.Client(api_token=replicate_api_token)

model = client.models.get("zeke/zisper")
version = model.latest_version

# scipy expects (samples, channels), while PyAV gives us (channels, samples)
audio_file_to_transcribe = "audio.wav"
write(audio_file_to_transcribe, info["audio_sample_rate"], audio_arr.T)

with open(audio_file_to_transcribe, "rb") as audio_file:
    prediction = client.predictions.create(
        version=version,
        input=dict(
            audio=audio_file,
            debug=False,
            batch_size=2,
            only_text=False,
            align_output=False,
        ),
    )

# Block until the prediction finishes (the first run may hit a cold boot)
prediction.wait()
if prediction.status != "succeeded":
    raise RuntimeError(f"Transcription failed: {prediction.error}")

# The output is a JSON list of segments; join their text with single spaces.
# Warning: not yet tested on responses with many segments.
transcription = " ".join(x["text"].strip() for x in json.loads(prediction.output))
transcription
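The `.T` in the `write` call matters: PyAV returns audio as `(channels, samples)`, while WAV files store interleaved `(samples, channels)` frames. Because PyAV typically decodes AAC audio to float samples, scipy writes a 32-bit float WAV here. If a transcription service turns out to prefer 16-bit PCM, a standard-library sketch like the one below could be swapped in. `encode_wav_int16` is a hypothetical helper, and it assumes samples already lie in [-1, 1].

```python
import io
import wave

import numpy as np


def encode_wav_int16(audio: np.ndarray, sample_rate: int) -> bytes:
    """Encode float audio shaped (channels, samples), values in [-1, 1], as 16-bit PCM WAV."""
    if audio.ndim == 1:
        audio = audio[np.newaxis, :]  # treat 1-D input as mono
    num_channels = audio.shape[0]
    # (channels, samples) -> (samples, channels): WAV stores interleaved frames
    pcm = (np.clip(audio.T, -1.0, 1.0) * 32767).astype("<i2")
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(num_channels)
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())  # tobytes() emits C order, i.e. interleaved
    return buffer.getvalue()
```

In the notebook this would be used as `Path("audio.wav").write_bytes(encode_wav_int16(audio_arr, info["audio_sample_rate"]))`.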

# Hit the Vision API with the images AND transcription from the video

In [None]:
prompt_with_transcriptions = f"""\
These are frames from a video. Generate a description/summary that explains what is happening based on what you can see, as well as the video's audio transcription below.

Transcription: {transcription}

Your response should **STRICTLY and ONLY** be the summary, no titles, headers, or anything else.\
"""

out = call_vision_api(
    video_arr,
    prompt_with_transcriptions,
    max_tokens=max_tokens
)
out

# Conclusion

There's a LOT more you can do with this API. Here are a few ideas you can go try for yourself :)

I'll try to keep updating this notebook as I find time. Feel free to drop issues/feature requests on GitHub [here](https://github.com/nateraw/openai-vision-api-for-videos/issues).

#### Some ideas:

- Set up a live webcam and have the model answer questions about what you're doing
  - 🔥 implemented [here](https://twitter.com/skalskip92/status/1721694286440468849) by [Piotr Skalski](https://github.com/SkalskiP)
- Embed summaries of clips from videos and do video search (on one or a LOT of videos)
- Use the API to extract good frames from videos with a characteristic style, then use them to fine-tune an image generative model, such as SDXL.
- Use a workflow similar to this one to write executive summaries of TED Talks/podcasts, or review notes for lectures.
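Several of these ideas (clip-level search, long-form summaries) start by cutting a long video into fixed-length windows and summarizing each one. A minimal sketch of that first step, assuming fixed-length windows with an optional overlap; `clip_windows` is a hypothetical helper, not part of the notebook:

```python
from typing import List, Optional, Tuple


def clip_windows(video_duration: float, clip_len: float,
                 stride: Optional[float] = None) -> List[Tuple[float, float]]:
    """Return (start_sec, end_sec) pairs covering [0, video_duration).

    stride defaults to clip_len (back-to-back windows); a smaller stride
    gives overlapping windows.
    """
    stride = clip_len if stride is None else stride
    if clip_len <= 0 or stride <= 0:
        raise ValueError("clip_len and stride must be positive")
    windows = []
    start = 0.0
    while start < video_duration:
        windows.append((start, min(start + clip_len, video_duration)))
        start += stride
    return windows
```

Each `(start_sec, end_sec)` pair could then go through `EncodedVideo.get_clip`, `transform`, and `call_vision_api` exactly as in the "Run it!" cell, producing one summary per window.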

