# TwelveLabs / Amazon OpenSearch Demonstration

### Prerequisites

1. Establish a free TwelveLabs account and obtain an [API Key](https://playground.twelvelabs.io/dashboard/api-keys).
2. Create an [Amazon OpenSearch Serverless Collection](https://us-east-1.console.aws.amazon.com/aos/home?region=us-east-1#opensearch/collections) and note the OpenSearch endpoint, omitting the `https://` prefix.

### Sample Video Content

Download free videos from any number of sites, including [Pexels](https://www.pexels.com/videos/). Videos must meet the TwelveLabs [requirements](https://docs.twelvelabs.io/docs/get-started/quickstart/create-embeddings#prerequisites):

- Video resolution: Must be at least 360x360 and must not exceed 3840x2160.
- Aspect ratio: Must be one of 1:1, 4:3, 4:5, 5:4, 16:9, or 9:16.
- I suggest starting with Pexels' small SD size videos (640 x 360 pixels, 16:9) for speed and cost.
- I used 25 videos as a minimum to obtain reasonable search results.

### Workflow Diagram

![Architecture](./twelve_labs_bedrock.png)


In [None]:
# Install packages using pip
%pip install pip -Uq
%pip install twelvelabs boto3 opensearch-py -Uq
%pip install matplotlib Pillow scikit-learn plotly nbformat -Uq

#### Restart Kernel

If this is the first time you are installing these packages, restart your Jupyter Notebook's kernel before continuing.


In [None]:
# Test that the Twelve Labs package is installed
%pip show twelvelabs

In [None]:
# Import the required libraries
import json
import os

from twelvelabs import TwelveLabs
from twelvelabs.models import Video
from twelvelabs.exceptions import NotFoundError

## TwelveLabs API Key and AWS Credentials

Set TwelveLabs API Key and AWS Credentials as environment variables.


In [None]:
# *** Make sure to update these variables before running the code ***
%env AWS_REGION=<Your AWS Region>
%env AWS_ACCESS_KEY_ID=<Your AWS Access Key ID>
%env AWS_SECRET_ACCESS_KEY=<Your AWS Secret Access Key>
%env AWS_SESSION_TOKEN=<Your AWS Session Token>
%env TL_API_KEY=<Your TL API Key>
%env OPENSEARCH_ENDPOINT=<Your OpenSearch Endpoint>

In [None]:
# Read the TwelveLabs API key from the environment, removing any quote
# characters in case the value was entered with surrounding quotes.
# Defaulting to "" avoids an AttributeError on None when the variable is unset.
TL_API_KEY = os.getenv("TL_API_KEY", "").replace("'", "").replace('"', "")
if not TL_API_KEY:
    print("Warning: TL_API_KEY is not set; TwelveLabs API calls will fail.")

# Name of the TwelveLabs index to create or reuse
TL_INDEX_NAME = "pexels_sample_index"

## Create TwelveLabs Index


In [None]:
# Create the TwelveLabs SDK client, authenticated with the API key read above
tl_client = TwelveLabs(api_key=TL_API_KEY)

In [None]:
def create_index(index_name: str) -> str:
    """Return the ID of the TwelveLabs index named ``index_name``, creating it if needed.

    Args:
        index_name (str): The name of the index to look up or create.

    Returns:
        str: The ID of the existing or newly created index.
    """
    # Ask the API for at most one index matching the requested name
    matches = tl_client.index.list(
        name=index_name,
        sort_option="asc",
        page_limit=1,
    )

    # Reuse the first match when there is one
    if matches:
        existing = next(iter(matches), None)
        if existing is not None:
            print(f"Index '{existing.name}' already exists.")
            return existing.id

    # Nothing matched: create the index with both models and the thumbnail add-on
    print(f"Creating index '{index_name}'...")
    new_index = tl_client.index.create(
        name=index_name,
        models=[
            {"name": "marengo2.7", "options": ["visual", "audio"]},
            {"name": "pegasus1.2", "options": ["visual", "audio"]},
        ],
        addons=["thumbnail"],
    )
    return new_index.id


# Look up (or create) the index and show its ID; the ID is printed whether the
# index was newly created or already existed.
tl_index_id = create_index(TL_INDEX_NAME)
print(f"New index ID: {tl_index_id}")

## Upload Videos to Index


In [None]:
def upload_video(tl_index_id: str, video_path: str) -> None:
    """Submit one local video file to a TwelveLabs index.

    Any error raised while creating the upload task is reported on stdout
    instead of being propagated to the caller.

    Args:
        tl_index_id (str): The ID of the TwelveLabs index.
        video_path (str): The path to the video file to upload.

    Returns:
        None
    """
    try:
        upload_task = tl_client.task.create(index_id=tl_index_id, file=video_path)
        for message in (
            f"Task id={upload_task.id}",
            f"Video '{video_path}' uploaded successfully!",
        ):
            print(message)
    except Exception as error:
        print(f"Failed to upload video '{video_path}': {error}")


# Local folder that is scanned for videos to upload; created when absent
video_directory = "videos/pexels"
if not os.path.exists(video_directory):
    print(f"Video directory '{video_directory}' does not exist. Creating it.")
    os.makedirs(video_directory)

# Upload every .mp4 file found directly inside the folder
for video in os.listdir(video_directory):
    if not video.endswith(".mp4"):
        continue
    video_path = os.path.join(video_directory, video)
    upload_video(tl_index_id, video_path)

## Retrieve Embeddings and Analyses from Index


### Bulk Retrieve Embeddings from Index


In [None]:
def save_embeddings_to_json(video: Video, output_path: str) -> None:
    """Write a video's details to ``<video id>_embeddings.json`` unless the file exists.

    Args:
        video (Video): The video object containing embedding details.
        output_path (str): The directory where the JSON file will be saved.

    Returns:
        None
    """
    output_filename = f"{output_path}/{video.id}_embeddings.json"

    # Check first so an already-exported video is not serialized again
    if os.path.exists(output_filename):
        print(f"Embeddings already exist for video ID {video.id}. Skipping...")
        return

    # Round-trip through the model's own JSON text to obtain a plain dictionary,
    # then record the video ID under an explicit "video_id" key.
    video_data = json.loads(video.model_dump_json())
    video_data["video_id"] = video.id

    with open(output_filename, "w") as json_file:
        json.dump(video_data, json_file, indent=4)
    print(f"Embeddings saved to {output_filename}")


def get_videos_from_index(index_id: str, page_limit: int = 25) -> list:
    """Retrieve the IDs of videos in the specified index via a wildcard search.

    Args:
        index_id (str): The ID of the index to query.
        page_limit (int): The maximum number of search results to request.

    Returns:
        list: Unique video IDs, in the order they first appear in the results.

    Raises:
        NotFoundError: If the search reports a total count of zero.
    """
    result = tl_client.search.query(
        index_id=index_id,
        query_text="*",
        options=["visual"],
        page_limit=page_limit,
    )

    print(f"Total count of videos in index {index_id}: {result.pool.total_count}")
    if result.pool.total_count == 0:
        raise NotFoundError(f"No videos found in index {index_id}")
    print(result)

    # Several search results may refer to the same video; keep each ID once so
    # callers do not process a video twice. dict.fromkeys preserves first-seen order.
    video_ids = list(dict.fromkeys(item.video_id for item in result.data))
    return video_ids


# Directory for the per-video JSON files. It is set up before any API call so
# that `output_directory` is defined for later cells even if retrieval fails
# or returns nothing, and so the check runs once rather than once per video.
output_directory = "output/pexels"
if not os.path.exists(output_directory):
    print(f"Output directory '{output_directory}' does not exist. Creating it.")
    os.makedirs(output_directory)

# Retrieve the video IDs from the index
video_ids = get_videos_from_index(tl_index_id)

# Retrieve each video with its embeddings and save it to JSON
for video_id in video_ids:
    video = tl_client.index.video.retrieve(
        index_id=tl_index_id, id=video_id, embedding_option=["visual-text"]
    )

    print(f"Processing video ID: {video.id}")
    save_embeddings_to_json(video, output_directory)

### Bulk Create Analyses from Videos in Index


In [None]:
def summarize_video(index_id: str, video_id: str, output_path: str) -> None:
    """Generate text analyses for one video and store them as ``<video_id>_analysis.json``.

    The file combines the summary, chapters, highlights, open-ended analysis
    and gist of the video. If the file already exists, nothing is requested
    or written.

    Args:
        index_id (str): The ID of the index where the video is stored.
        video_id (str): The ID of the video to summarize.
        output_path (str): The directory where the JSON file will be saved.

    Returns:
        None
    """
    filename = f"{output_path}/{video_id}_analysis.json"

    # A previous run already produced the file: skip all API calls
    if os.path.exists(filename):
        print(f"Analysis already exists for video ID {video_id}. Skipping...")
        return
    print(f"Analyzing video ID: {video_id}")

    # Summary, chapters and highlights use the same call and differ only in
    # prompt and type; they are requested in this order.
    summarize_requests = (
        ("summary", "Summarize the video in a concise manner."),
        ("chapter", "List the chapters of the video."),
        ("highlight", "List the highlights of the video."),
    )
    summarized = {}
    for kind, prompt in summarize_requests:
        summarized[kind] = tl_client.summarize(
            video_id=video_id,
            prompt=prompt,
            temperature=0.4,
            type=kind,
        )

    # Open-ended description of the video
    res_analyze = tl_client.analyze(
        video_id=video_id,
        prompt="Describe what is happening in the video.",
        temperature=0.4,
    )

    # Title, topic and hashtags
    res_gist = tl_client.gist(video_id=video_id, types=["title", "topic", "hashtag"])

    # Merge every response into a single dictionary
    analyses = {
        "gist": res_gist.model_dump(),
        "video_id": video_id,
        "index_id": index_id,
        "summary": summarized["summary"].summary,
        "analysis": res_analyze.data,
        "chapters": summarized["chapter"].chapters.model_dump(),
        "highlights": summarized["highlight"].highlights.model_dump(),
    }

    # Persist the combined analyses
    with open(filename, "w") as f:
        f.write(json.dumps(analyses))


# Retrieve the video IDs from the index
video_ids = get_videos_from_index(tl_index_id)

# Create and save the analyses for each video. Only the video ID is needed, so
# the video (and its embeddings) is not retrieved from the API again.
for video_id in video_ids:
    summarize_video(tl_index_id, video_id, output_directory)

### Merge Embeddings and Analyses


In [None]:
def extract_video_ids(output_path: str) -> list:
    """Collect video IDs from the ``*_analysis.json`` filenames in a directory.

    Args:
        output_path (str): Directory containing the analysis JSON files

    Returns:
        list: The extracted video IDs; empty when the directory does not exist
    """
    suffix = "_analysis.json"

    # A missing directory yields no IDs
    if not os.path.exists(output_path):
        print(f"Directory {output_path} doesn't exist")
        return []

    # The ID is the part of the filename before the first "_analysis.json"
    return [
        entry.partition(suffix)[0]
        for entry in os.listdir(output_path)
        if entry.endswith(suffix)
    ]


# Extract video IDs from the analysis files; this replaces the `video_ids`
# list with the IDs for which an analysis file exists on disk.
video_ids = extract_video_ids(output_directory)
print(f"Found {len(video_ids)} video IDs: {video_ids}")

In [None]:
def combine_segments_to_documents(
    output_path: str, document_path: str, video_ids: list
) -> None:
    """Merge each video's embeddings and analysis files into one document file.

    For every video ID, ``<id>_embeddings.json`` and ``<id>_analysis.json`` are
    read from ``output_path`` and written as ``<id>_document.json`` in
    ``document_path``. A video whose embeddings or analysis file is missing is
    skipped with a message instead of aborting the whole batch.

    Args:
        output_path (str): Directory containing the analysis and embeddings JSON files
        document_path (str): Directory to save the combined document files
        video_ids (list): List of video IDs to process

    Returns:
        None
    """
    for video_id in video_ids:
        embeddings_file = f"{output_path}/{video_id}_embeddings.json"
        analysis_file = f"{output_path}/{video_id}_analysis.json"

        try:
            with open(embeddings_file, "r") as f:
                embeddings = json.load(f)
            with open(analysis_file, "r") as f:
                analyses = json.load(f)
        except FileNotFoundError as ex:
            print(f"Skipping video ID {video_id}: {ex}")
            continue

        # Analysis fields come first; embeddings-file fields win on key clashes
        document = {}
        document.update(analyses)
        document.update(embeddings)

        # Remove unneeded keys
        document["gist"].pop("id", None)
        document["gist"].pop("usage", None)

        # Lift the segments to the top level of the document and rename each
        # segment's vector from "embeddings_float" to "segment_embedding".
        segments = document.pop("embedding")["video_embedding"]["segments"]
        for segment in segments:
            segment["segment_embedding"] = segment.pop("embeddings_float")
        document["segments"] = segments

        # One document file per video, holding all of its segments
        filename = f"{document_path}/{document['video_id']}_document.json"
        with open(filename, "w") as f:
            f.write(json.dumps(document, indent=4))


# Directory that receives the combined per-video documents; created when absent
document_directory = "documents/pexels"
if not os.path.exists(document_directory):
    print(f"Document directory '{document_directory}' does not exist. Creating it.")
    os.makedirs(document_directory)

# Merge the embeddings and analysis files of every video into document files
combine_segments_to_documents(output_directory, document_directory, video_ids)

## Amazon OpenSearch Serverless


### Instantiate OpenSearch Client


In [None]:
import boto3

from opensearchpy import (
    AWSV4SignerAuth,
    NotFoundError,
    OpenSearch,
    RequestsHttpConnection,
)

In [None]:
# Amazon OpenSearch configuration
# AWS credentials, read from the environment (None when a variable is unset)
aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
aws_session_token = os.getenv("AWS_SESSION_TOKEN")

# Region and index name fall back to defaults when not configured
aws_region = os.getenv("AWS_REGION", "us-east-1")
aoss_index = os.getenv("OPENSEARCH_INDEX", "video-search-nested")

# The endpoint is used as a bare host name, so tolerate a value that was
# pasted with its scheme (e.g. "https://") or a trailing slash.
aoss_host = os.getenv("OPENSEARCH_ENDPOINT")
if aoss_host:
    aoss_host = aoss_host.strip().split("://", 1)[-1].rstrip("/")

In [None]:
# Create OpenSearch client
# https://opensearch.org/docs/latest/clients/python-low-level/#connecting-to-amazon-opensearch-serverless

# Service name used when signing requests ("aoss" = OpenSearch Serverless)
service = "aoss"

# Build SigV4 signing credentials from the values read from the environment
credentials = boto3.Session(
    aws_access_key_id=aws_access_key_id,
    aws_secret_access_key=aws_secret_access_key,
    aws_session_token=aws_session_token,
    region_name=aws_region,
).get_credentials()
auth = AWSV4SignerAuth(credentials, aws_region, service)

# HTTPS client for the collection endpoint; `aoss_host` is passed as a host
# name together with port 443.
aoss_client = OpenSearch(
    hosts=[{"host": aoss_host, "port": 443}],
    http_auth=auth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection,
    pool_maxsize=20,
)

# Last expression of the cell: displays the client's representation in the notebook
aoss_client

### Create New OpenSearch Vector Index


In [None]:
# Create new nested field search index (multiple vector fields)
# https://docs.opensearch.org/docs/latest/vector-search/specialized-operations/nested-search-knn/

# Start from a clean slate: delete the index if it already exists. A missing
# index is expected on the first run; any other error is reported and ignored.
try:
    response = aoss_client.indices.delete(index=aoss_index)
except NotFoundError as ex:
    print(f"Index {aoss_index} not found, skipping deletion.")
except Exception as ex:
    print(f"Error deleting index: {ex}")

# Index definition: k-NN enabled, with each document holding a nested list of
# "segments", each carrying a 1024-dimension vector in "segment_embedding"
# (HNSW via the faiss engine, L2 distance).
index_body = {
    "settings": {
        "index": {
            "knn": True,
            "number_of_shards": 2,
        }
    },
    "mappings": {
        "properties": {
            "segments": {
                "type": "nested",
                "properties": {
                    "segment_embedding": {
                        "type": "knn_vector",
                        "dimension": 1024,
                        "method": {
                            "engine": "faiss",
                            "name": "hnsw",
                            "space_type": "l2",
                        },
                    }
                },
            }
        }
    },
}

# Create the index and print the service response; failures are only printed
try:
    response = aoss_client.indices.create(index=aoss_index, body=index_body)
    print(json.dumps(response, indent=4))
except Exception as ex:
    print(ex)

In [None]:
# Confirm the index exists by fetching and printing its definition
try:
    response = aoss_client.indices.get(index=aoss_index)
    print(json.dumps(response, indent=4))
except NotFoundError as ex:
    print(f"Index not found: {ex}")
except Exception as ex:
    # Not every exception has an `.error` attribute; fall back to the
    # exception itself so this handler cannot raise AttributeError.
    print(getattr(ex, "error", ex))

### Bulk Index OpenSearch Documents


In [None]:
# Bulk indexing documents from JSON files in the local directory
def load_and_index_documents(document_path: str) -> int:
    """Bulk-index every ``*_document.json`` file in a directory into OpenSearch.

    Args:
        document_path (str): Directory containing the document JSON files

    Returns:
        int: Number of documents OpenSearch accepted; 0 when no document files
            were found or the bulk request itself failed.
    """
    action_line = f'{{ "create": {{ "_index": "{aoss_index}" }} }}'

    # The bulk body alternates an action line and a document line
    bulk_lines = []
    for file in os.listdir(document_path):
        if file.endswith("_document.json"):
            with open(os.path.join(document_path, file), "r") as f:
                bulk_lines.append(action_line)
                bulk_lines.append(json.dumps(json.load(f)))

    document_count = len(bulk_lines) // 2
    if document_count == 0:
        # Do not send an empty bulk body
        print(f"No document files found in '{document_path}'.")
        return 0

    # Every line of a bulk body, including the last, ends with a newline
    payload = "\n".join(bulk_lines) + "\n"

    try:
        response = aoss_client.bulk(
            index=aoss_index,
            body=payload,
        )
        print(json.dumps(response, indent=4))
    except Exception as ex:
        print(f"Error indexing documents: {ex}")
        return 0

    # When the response flags errors, count only the items that succeeded so
    # callers do not wait for documents that were never indexed.
    if response.get("errors"):
        document_count = sum(
            1
            for item in response.get("items", [])
            if 200 <= item.get("create", {}).get("status", 0) < 300
        )
        print(f"Warning: some documents were rejected; {document_count} were indexed.")

    return document_count


# Index the combined documents; `row_count` is the value returned by the
# bulk-indexing helper.
row_count = load_and_index_documents(document_directory)
print(f"Total rows to index: {row_count}")

In [None]:
from time import monotonic, sleep

# Poll the document count until every indexed document is visible. The wait is
# bounded so the cell cannot spin forever if fewer documents than expected
# ever show up.
max_wait_seconds = 300
poll_interval_seconds = 10
deadline = monotonic() + max_wait_seconds

response = aoss_client.count(index=aoss_index)
while response["count"] < row_count and monotonic() < deadline:
    print(f"Current indexed documents: {response['count']}")
    sleep(poll_interval_seconds)
    response = aoss_client.count(index=aoss_index)

if response["count"] >= row_count:
    print(f"Indexing completed. Total indexed documents: {response['count']}")
else:
    print(
        f"Timed out after {max_wait_seconds} seconds: "
        f"{response['count']} of {row_count} documents are indexed."
    )

## Query the Amazon OpenSearch Index


### Convert Query to Embedding


In [None]:
def get_embedding_from_query(query: str) -> list:
    """Convert a text query to an embedding using TwelveLabs.

    Args:
        query (str): The text query to convert.

    Returns:
        list: The embedding vector of the first returned segment.

    Raises:
        ValueError: If the response has no text embedding or no segments.
    """
    res = tl_client.embed.create(
        model_name="Marengo-retrieval-2.7",
        text_truncate="start",
        text=query,
    )

    text_embedding = res.text_embedding
    # The truthiness test covers both missing and empty segments, so an empty
    # result raises ValueError rather than IndexError on segments[0].
    if text_embedding is None or not text_embedding.segments:
        raise ValueError("Failed to retrieve embedding from the response.")
    return text_embedding.segments[0].embeddings_float


# Example natural-language query, converted to the vector used by the search cells
query = "bustling street scene from a low-angle perspective"
query_embedding = get_embedding_from_query(query)
print(f"Embedding: {query_embedding[:5]}...")  # Print first 5 elements for brevity

### Semantic Search


In [None]:
# Reference: https://docs.opensearch.org/docs/latest/vector-search/filter-search-knn/efficient-knn-filtering/#faiss-k-nn-filter-implementation


def semantic_search(aoss_index: str, embedding: list) -> dict:
    """Run a nested k-NN search (k=6) over the segment embeddings of an index.

    Args:
        aoss_index (str): The name of the Amazon OpenSearch index to search.
        embedding (list): The embedding vector to use for the query.

    Returns:
        dict: The search response from OpenSearch, or an empty dict if the
            request raised an exception.
    """
    # Nearest-neighbour clause on the vector field inside each segment
    knn_clause = {
        "segments.segment_embedding": {
            "vector": embedding,
            "k": 6,
        }
    }

    # Return up to six documents, leaving the stored vectors out of the response
    search_body = {
        "query": {"nested": {"path": "segments", "query": {"knn": knn_clause}}},
        "size": 6,
        "_source": {"excludes": ["segments.segment_embedding"]},
    }

    try:
        return aoss_client.search(body=search_body, index=aoss_index)
    except Exception as error:
        print(f"Error querying index: {error}")
        return {}


# Query the index with the embedding
search_results_1 = semantic_search(aoss_index, query_embedding)

# semantic_search returns {} on failure, so default to "no hits" rather than
# raising KeyError.
for hit in search_results_1.get("hits", {}).get("hits", []):
    source = hit["_source"]
    print(f"Video ID: {source.get('video_id', 'N/A')}")
    # The title is stored inside the document's "gist" object, not at the top level
    print(f"Title: {source.get('gist', {}).get('title', 'N/A')}")
    print(f"Score: {hit['_score']}")
    print(f"Duration: {source['system_metadata']['duration']:.2f} seconds")
    print("\r")

### Semantic Search with Filters


In [None]:
# Reference: https://docs.opensearch.org/docs/latest/vector-search/filter-search-knn/efficient-knn-filtering/#step-3-search-your-data-with-a-filter


def semantic_search_with_filter(aoss_index: str, embedding: list) -> dict:
    """Run a nested k-NN search (k=5) restricted by a duration range filter.

    The k-NN clause carries a filter requiring ``system_metadata.duration``
    to be between 20 and 60, inclusive.

    Args:
        aoss_index (str): The name of the Amazon OpenSearch index to search.
        embedding (list): The embedding vector to use for the query.

    Returns:
        dict: The search response from OpenSearch, or an empty dict if the
            request raised an exception.
    """
    # Range condition applied as part of the k-NN clause
    duration_filter = {
        "bool": {
            "must": [
                {
                    "range": {
                        "system_metadata.duration": {
                            "gte": 20,
                            "lte": 60,
                        }
                    }
                },
            ]
        }
    }

    knn_clause = {
        "segments.segment_embedding": {
            "vector": embedding,
            "k": 5,
            "filter": duration_filter,
        }
    }

    # Return up to five documents, leaving the stored vectors out of the response
    search_body = {
        "query": {"nested": {"path": "segments", "query": {"knn": knn_clause}}},
        "size": 5,
        "_source": {"excludes": ["segments.segment_embedding"]},
    }

    try:
        return aoss_client.search(body=search_body, index=aoss_index)
    except Exception as error:
        print(f"Error querying index: {error}")
        return {}


# Query the index with the embedding, keeping only videos in the duration range
search_results_2 = semantic_search_with_filter(aoss_index, query_embedding)

# The search helper returns {} on failure, so default to "no hits" rather than
# raising KeyError.
for hit in search_results_2.get("hits", {}).get("hits", []):
    source = hit["_source"]
    print(f"Video ID: {source.get('video_id', 'N/A')}")
    # The title is stored inside the document's "gist" object, not at the top level
    print(f"Title: {source.get('gist', {}).get('title', 'N/A')}")
    print(f"Score: {hit['_score']}")
    print(f"Duration: {source['system_metadata']['duration']:.2f} seconds")
    print("\r")

### Radial Search


In [None]:
# Reference: https://docs.opensearch.org/docs/latest/vector-search/specialized-operations/radial-search-knn/


def radial_search(aoss_index: str, embedding: list) -> dict:
    """Run a nested radial k-NN search bounded by ``max_distance`` = 2.

    Instead of asking for a fixed number of neighbours, the k-NN clause sets
    a maximum distance from the query vector.

    Args:
        aoss_index (str): The name of the Amazon OpenSearch index to search.
        embedding (list): The embedding vector to use for the query.

    Returns:
        dict: The search response from OpenSearch, or an empty dict if the
            request raised an exception.
    """
    knn_clause = {
        "segments.segment_embedding": {
            "vector": embedding,
            "max_distance": 2,
        }
    }

    # Return up to six documents, leaving the stored vectors out of the response
    search_body = {
        "query": {"nested": {"path": "segments", "query": {"knn": knn_clause}}},
        "size": 6,
        "_source": {"excludes": ["segments.segment_embedding"]},
    }

    try:
        return aoss_client.search(body=search_body, index=aoss_index)
    except Exception as error:
        print(f"Error querying index: {error}")
        return {}


# Query the index with the embedding using the radial (max-distance) search
# defined in this cell.
search_results_3 = radial_search(aoss_index, query_embedding)

# The search helper returns {} on failure, so default to "no hits" rather than
# raising KeyError.
for hit in search_results_3.get("hits", {}).get("hits", []):
    source = hit["_source"]
    print(f"Video ID: {source.get('video_id', 'N/A')}")
    # The title is stored inside the document's "gist" object, not at the top level
    print(f"Title: {source.get('gist', {}).get('title', 'N/A')}")
    print(f"Score: {hit['_score']}")
    print(f"Duration: {source['system_metadata']['duration']:.2f} seconds")
    print("\r")

### Displaying Previews of Search Results


In [None]:
from matplotlib import pyplot as plt
from PIL import Image
from urllib import request
import io


def load_image_from_url(url):
    """Download an image and open it with Pillow.

    Args:
        url (str): The URL of the image to load.

    Returns:
        PIL.Image.Image: The loaded image, or None if downloading or opening
            it raised an exception.
    """
    try:
        with request.urlopen(url) as http_response:
            # Wrap the downloaded bytes in a file-like object for Pillow
            raw_bytes = io.BytesIO(http_response.read())
            return Image.open(raw_bytes)
    except Exception as error:
        print(f"Error loading video thumbnail from URL: {error}")
        return None


# Thumbnails are laid out on a rows x columns grid
rows = 3
columns = 3

fig = plt.figure(figsize=(10, 7))
fig.set_dpi(300)  # set once for the whole figure

# Default to "no hits" if the search failed, and never draw more thumbnails
# than the grid has cells.
hits = search_results_1.get("hits", {}).get("hits", [])
for index, hit in enumerate(hits[: rows * columns], start=1):
    fig.add_subplot(rows, columns, index)
    image_url = hit["_source"]["hls"]["thumbnail_urls"][0]
    image = load_image_from_url(image_url)
    plt.axis("off")
    # load_image_from_url returns None on failure; leave the cell blank then
    if image is not None:
        plt.imshow(image)
    plt.title(
        f'Video: {hit["_source"]["system_metadata"]["filename"]}\nScore: {hit["_score"]}',
        fontdict=dict(family="Arial", size=8),
        color="black",
    )

### 2D/3D Visualizations Using PCA


In [None]:
def semantic_search_pca(aoss_index: str, embedding: list) -> dict:
    """Run a nested k-NN search (k=9) that keeps the stored vectors in the response.

    Unlike the other search helpers, no ``_source`` exclusion is applied, so
    each hit still contains its segment embeddings.

    Args:
        aoss_index (str): The name of the Amazon OpenSearch index to search.
        embedding (list): The embedding vector to use for the query.

    Returns:
        dict: The search response from OpenSearch, or an empty dict if the
            request raised an exception.
    """
    knn_clause = {
        "segments.segment_embedding": {
            "vector": embedding,
            "k": 9,
        },
    }

    search_body = {
        "query": {"nested": {"path": "segments", "query": {"knn": knn_clause}}},
        "size": 9,
    }

    try:
        return aoss_client.search(body=search_body, index=aoss_index)
    except Exception as error:
        print(f"Error querying index: {error}")
        return {}


# Query the index with the embedding, keeping the stored vectors in the hits
search_results_4 = semantic_search_pca(aoss_index, query_embedding)
pca_hits = search_results_4["hits"]["hits"]

# One vector per hit (the first segment of each video), followed by the query
# vector as the final entry.
embeddings = [h["_source"]["segments"][0]["segment_embedding"] for h in pca_hits]
embeddings.append(query_embedding)

# Matching labels, in the same order as the vectors
video_names = [h["_source"]["system_metadata"]["filename"] for h in pca_hits]
video_names.append("User query")

#### 2D Visualization Using PCA


In [None]:
from sklearn.decomposition import PCA

# Reduce dimensions from 1,024 to 2 using PCA for visualization
# Fit PCA on the result vectors plus the query vector and project them to 2 components
pca = PCA(n_components=2)
vis_dims_2d = pca.fit_transform(embeddings)
print(f"Reduced dimensions shape (2d): {vis_dims_2d.shape}")

In [None]:
import plotly.graph_objs as go
import numpy as np

fig = go.Figure()

# Search results: one trace per video so each gets its own legend entry.
# The final entry of video_names is the user query and is drawn separately.
for i, video_name in enumerate(video_names[0:-1]):
    point = vis_dims_2d[i]
    x = np.array([point[0]])
    y = np.array([point[1]])

    result_trace = go.Scatter(
        x=x,
        y=y,
        mode="markers",
        marker=dict(
            size=15,
            colorscale="Viridis",
            opacity=1.0,
            symbol="circle",
        ),
        name=video_name,
    )
    fig.add_trace(result_trace)

# User query: the last projected point, drawn as a labelled black square
query_point = vis_dims_2d[-1]
x = np.array([query_point[0]])
y = np.array([query_point[1]])

query_trace = go.Scatter(
    x=x,
    y=y,
    mode="text+markers",
    marker=dict(
        size=15,
        color="black",
        colorscale="Viridis",
        opacity=1.0,
        symbol="square",
    ),
    name=video_names[-1],
    text=video_names[-1],
    textposition="bottom left",
    showlegend=False,
)
fig.add_trace(query_trace)

# Titles, fonts and margins
fig.update_layout(
    autosize=True,
    font=dict(size=12, color="black", family="Arial, sans-serif"),
    title="2D Scatter Plot of Search Results using PCA",
    margin=dict(l=30, r=30, b=30, t=60, pad=10),
    xaxis=dict(title="x"),
    yaxis=dict(title="y"),
    legend=dict(title="   Search Results"),
)
fig.show()

#### 3D Visualization Using PCA


In [None]:
# Reduce dimensions from 1,024 to 3 using PCA for visualization
# Refit PCA on the same vectors, this time keeping 3 components
pca = PCA(n_components=3)
vis_dims_3d = pca.fit_transform(embeddings)

print(f"Reduced dimensions shape (3d): {vis_dims_3d.shape}")

In [None]:
fig = go.Figure()

# Results: one trace per video so each gets its own legend entry. The final
# entry of video_names is the user query and is drawn separately below.
for i, video_name in enumerate(video_names[0:-1]):
    x = np.array([vis_dims_3d[i, 0]])
    y = np.array([vis_dims_3d[i, 1]])
    z = np.array([vis_dims_3d[i, 2]])

    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode="markers",
            marker=dict(size=7, colorscale="Viridis", opacity=1.0, symbol="circle"),
            name=video_name,
            text=video_name,
            textposition="top center",
        )
    )

# User query: the last projected point, drawn as a black square
x = np.array([vis_dims_3d[-1, 0]])
y = np.array([vis_dims_3d[-1, 1]])
z = np.array([vis_dims_3d[-1, 2]])

fig.add_trace(
    go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode="markers",
        marker=dict(
            size=7, color="black", colorscale="Viridis", opacity=1.0, symbol="square"
        ),
        # Use the actual label, not the literal text "video_names[-1]"
        name=video_names[-1],
        text=video_names[-1],
        textposition="bottom left",
        showlegend=False,
    )
)

# NOTE(review): the x/y/z axes are titled "z"/"x"/"y" respectively; kept as in
# the original, confirm whether that labelling is intentional.
fig.update_layout(
    autosize=True,
    font=dict(size=12, color="black", family="Arial, sans-serif"),
    title="3D Scatter Plot of Search Results using PCA",
    margin=dict(l=30, r=30, b=20, t=50, pad=10),
    scene=dict(
        xaxis=dict(title="z"),
        yaxis=dict(title="x"),
        zaxis=dict(title="y"),
    ),
    legend=dict(
        title="   Search Results",
    ),
)
fig.show()

#### Animate the 3D Visualization


In [None]:
# Build the same 3D scatter plot as a fresh figure, then add a camera position
# and a "Play" button so the view can be animated.
fig = go.Figure()

# Results: one trace per video; the final entry of video_names is the user
# query and is drawn separately below.
for i, video_name in enumerate(video_names[0:-1]):
    x = np.array([vis_dims_3d[i, 0]])
    y = np.array([vis_dims_3d[i, 1]])
    z = np.array([vis_dims_3d[i, 2]])

    fig.add_trace(
        go.Scatter3d(
            x=x,
            y=y,
            z=z,
            mode="markers",
            marker=dict(size=7, colorscale="Viridis", opacity=1.0, symbol="circle"),
            name=video_name,
            text=video_name,
            textposition="top center",
        )
    )

# User query: the last projected point, drawn as a black square
x = np.array([vis_dims_3d[-1, 0]])
y = np.array([vis_dims_3d[-1, 1]])
z = np.array([vis_dims_3d[-1, 2]])

fig.add_trace(
    go.Scatter3d(
        x=x,
        y=y,
        z=z,
        mode="markers",
        marker=dict(
            size=7, color="black", colorscale="Viridis", opacity=1.0, symbol="square"
        ),
        name=video_names[-1],
        text=video_names[-1],
        textposition="top center",
        showlegend=False,
    )
)

# Initial camera eye position; also the starting point that is rotated when
# the animation frames are generated.
x_eye = -1.25
y_eye = 1.5
z_eye = 0.5

fig.update_layout(
    autosize=True,
    font=dict(size=12, color="black", family="Arial, sans-serif"),
    title="3D Scatter Plot of Search Results using PCA",
    margin=dict(l=30, r=30, b=30, t=40, pad=10),
    # NOTE(review): the x/y/z axes are titled "z"/"x"/"y" respectively —
    # confirm whether that labelling is intentional.
    scene=dict(
        xaxis=dict(title="z"),
        yaxis=dict(title="x"),
        zaxis=dict(title="y"),
    ),
    scene_camera_eye=dict(x=x_eye, y=y_eye, z=z_eye),
    # A single "Play" button that runs the figure's animation frames
    updatemenus=[
        dict(
            type="buttons",
            showactive=True,
            y=0.9,
            x=0.9,
            xanchor="left",
            yanchor="bottom",
            pad=dict(t=45, r=10),
            buttons=[
                dict(
                    label="Play",
                    method="animate",
                    args=[
                        None,
                        dict(
                            frame=dict(duration=15, redraw=True),
                            transition=dict(duration=1),
                            fromcurrent=True,
                            mode="immediate",
                        ),
                    ],
                )
            ],
        )
    ],
    legend=dict(
        title="   Search Results",
    ),
)


def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) about the z-axis by ``theta`` radians.

    The (x, y) pair is treated as a complex number and multiplied by
    ``exp(i * theta)``; ``z`` is returned unchanged.

    Returns:
        tuple: The rotated x, the rotated y, and the original z.
    """
    rotated = np.exp(1j * theta) * (x + 1j * y)
    return np.real(rotated), np.imag(rotated), z


# One animation frame per 0.01 step of t in [0, 10): each frame moves the
# camera eye to the starting position rotated about the z-axis by -t.
frames = [
    go.Frame(
        layout=dict(
            scene_camera_eye=dict(zip("xyz", rotate_z(x_eye, y_eye, z_eye, -t)))
        )
    )
    for t in np.arange(0, 10, 0.01)
]
fig.frames = frames

fig.show()