In [61]:
import lancedb
from utils import load_json_file
from model_components.bridgetower_embeddings import (
    BridgeTowerEmbeddings
)

from model_components.multimodal_lancedb import MultimodalLanceDB
# from model_components.client import PredictionGuardClient
from model_components.lvlm import LVLM
from PIL import Image
from langchain_core.runnables import (
    RunnableParallel,
    RunnablePassthrough,
    RunnableLambda
)


In [62]:
# for prepopulated data
# TBL_NAME = "demo_tbl"

# initializeing the vector store that we have previouslly constructed

# declare host file
LANCEDB_HOST_FILE = "./shared_data/.lancedb"

# declare table name 
TBL_NAME = "test_tbl"

In [63]:
# now lets initialize the bridge tower emmbedding model
embedder = BridgeTowerEmbeddings()


In [64]:
# now lets create the retriveval
# Creating a LanceDB vector store
vectorstore = MultimodalLanceDB(
    uri=LANCEDB_HOST_FILE,
    embedding = embedder,
    table_name = TBL_NAME
)

# creating a retriver for the vector store
# with search type="similarity" and search_kwargs={"k":1} 
retriever_module = vectorstore.as_retriever(
    search_type='similarity',
    search_kwargs={"k":1}
)



In [67]:
# Invoke Retrival with USer query

# invoke the retriever for a query
query = "what do the astronauts feel about their work?"
print(retriever_module)
retrieved_video_segments = retriever_module.invoke(query)

# get the first retrieved video segment
retrieved_video_segment = retrieved_video_segments[0]

tags=['MultimodalLanceDB', 'BridgeTowerEmbeddings'] vectorstore=<model_components.multimodal_lancedb.MultimodalLanceDB object at 0x000001C9A58F21F0> search_kwargs={'k': 1}


ValueError: No api_key provided or in environment. Please provide the api_key as client = PredictionGuard(api_key=<your_api_key>) or as PREDICTIONGUARD_API_KEY in your environment.

In [32]:
# get all metadata of the retrieved video segment
retrieved_metadata = retrieved_video_segment.metadata['metadata']

# get the extracted frame
frame_path = retrieved_metadata['extracted_frame_path']

# get the corresponding transcript
transcript = retrieved_metadata['transcript']

# get the path to video when the frame was extracted
timestamp = retrieved_metadata['mid_time_ms']

# display
print(f"Transcript:\n{transcript}\n")
print(f"Path to extracted frame: {frame_path}")
print(f"Path to video: {video_path}")
print(f"Timestamp in ms when the frame was extracted: {timestamp}")
display(Image.open(frame_path))

NameError: name 'retrieved_video_segment' is not defined

If LVLM doesn't strictly require a client to be passed, and you want to handle inference directly without an external API, you can simply omit the client initialization. In this case, modify the LVLM initialization to not expect a client, or pass None if it's optional.

In [20]:
# LVLM Inference Module
# Initialize Client and LVL< for Inference 

# initialize a client as PredictioGuardClient
# client = PredictionGuardClient()
client = None

# initialize LVLM with the given client
lvlm_inference_module = LVLM(client= client)


In [21]:
# Invoke LVLM Inference with User Query

# This new query is the augumentation of the previous query
# with the transcript retrived above 
augumented_query_template = (
    "The transcript associated with the image is '{transcript}'."
    "{previous_query}"
)

augumented_query = augumented_query_template.format(
    transcript=transcript,
    previous_query = query,
)

print(f"Augmented query is:\n {augumented_query}")

NameError: name 'transcript' is not defined

In [None]:
# we use the augmented query and the retrieved path-to-image
# as the input to LVLM inference module

inpute = {'prompt':augumented_query, 'image':frame_path}
response = lvlm_inference_module.invoke(input)

# display the response
print('LVLM Response:')
print(response)

## Prompt processing module


In [None]:
def prompt_processing(input):
    # get the retrieved results and user's query
    retrieved_results = input['retrieved_results']
    user_query = input['user_query']

    # get the first retrieved result by default
    retrieved_result = retrieved_results[0]
    prompt_template = (
        "The transcript associated with the image is '{transcript}'."
        "{user_query}"
    )

    # get all metadata of the retrieved video segment
    retrieved_metadata = retrieved_result.metadata['metadata']

    # get coressponding transcript
    transcript_metadata = retrieved_metadata['transcript']
    # get the extracted frame
    frame_path = retrieved_metadata['extracted_frame_path']

    return {
        'prompt': prompt_template.format(
            transcript=transcript,
            user_query = user_query 

        ), 
        'image': frame_path
    }

# initialize prompt processing module
# as a langchain RunnableLambda of function prompt_processing
prompt_processing_module = RunnableLambda(prompt_processing)

### Invoke Prompt Processing Module with query and the retrieved results above


In [None]:
# We use the user query and the retrived results above
input_to_lvlm = prompt_processing_module.invoke(
    {
        'retrieved_results': retrieved_video_segments,
        'user_query': query
})

# display output of prompt processing module
# which is the input to LVLM Inference module
print(input_to_lvlm)

#### Define Multimodal RAG System as a Chain in LangChain¶
We are going to make use of the followings from Langchain:

The RunnableParallel primitive is essentially a dict whose values are runnables (or things that can be coerced to runnables, like functions). It runs all of its values in parallel, and each value is called with the overall input of the RunnableParallel. The final return value is a dict with the results of each value under its appropriate key.
The RunnablePassthrough on its own allows you to pass inputs unchanged. This typically is used in conjuction with RunnableParallel to pass data through to a new key in the map.
The RunnableLambda converts a python function into a Runnable. Wrapping a function in a RunnableLambda makes the function usable within either a sync or async context.

In [None]:
# combine all the modules into a chain
# to create multimodal RAG system

mm_rag_chain = (
    RunnableParallel({
        "retrieved_results": retriever_module,
        "user_query": RunnablePassthrough()
    })
    | prompt_processing_module
    | lvlm_inference_module
)

In [None]:
# invoke the Multimodal RAG System with a query
query1 = "What do the astronauts feel about their work?"
final_text_response = mm_rag_chain.invoke(query1)

# display
print(f"USER Query: {query1}")
print(f"MM-RAG Response: {final_text_response}")



In [None]:
# lets try another query

query2 = "What is the name of the astronauts?"
final_text_response2 = mm_rag_chain.invoke(query2)
# display

print(f"USER Query: {query2}")
print(f"MM-RAG Response: {final_text_response2}")

In [None]:
# Multimodal RAG system showing image/frame retrived 
mm_rag_chain_with_retrieved_image = (
    RunnableParallel({
        "retreived_results": retriever_module,
        "user_query": RunnablePassthrough()
    })
    | prompt_processing_module
    | RunnableParallel({
        'final_text_output': lvlm_inference_module,
        'input_to_lvlm': RunnablePassthrough()
    })
)

In [None]:
# lets try again with the query2
response3 = mm_rag_chain_with_retrieved_image.invoke(query2)
# display
print("Type of output of mm_Rag_chain_with_retrieved _image is :")
print(type(response3))
print(f"Keys of the dict are {response3.keys()}")

In [None]:
# We now extract final text response and path to extracted frame
final_text_response3 = response3['final_text_output']
path_to_extracted_frame = response3['input_To_lvlm']['image']

# display 
print(f"USER Query: {query2}")
print(f"MM-RAG Response: {final_text_response3}")
print("Retrieved frame:")
display(Image.open(path_to_extracted_frame))


In [None]:
# lets try again with another query

query4 = "an asronaut's spacewalk"
response4 = mm_rag_chain_with_retrieved_image.invoke(query4)


In [None]:
# extract results
final_text_response4 = response4['final_text_output']
path_to_extracted_frame4 = response4['final_text_output']
path_to_extracted_frame4 = response4['input_to_lvlm']['image']

# display
print(f"USER Query: {query4}")
print()
print(f"MM-RAG Response: {final_text_response4}")
print()
print("Retrieved frame:")
display(Image.open(path_to_extracted_frame4))

In [None]:
# We would like an astronaut's spacewalk with the earth view behind
query5 = (
    "Describe the image of an astronaut's spacewalk"
    "with an amazing view of the earth from space behind"
)

response5 = mm_rag_chain_with_retrieved_image.invoke(query5)

# extract results
final_text_response5 = response5['final_text_output']
path_to_extracted_frame5 = response5['input_to_lvlm']['image']

# display
print(f"User Query:  {query5}")
print()
print(f"MM-RAG Response: {final_text_response5}")
print()
print("Retrieved Frame:")
display(Image.open(path_to_extracted_frame5))


In [None]:
# slightly change the query5

query6 = (
    "An astronaut's spacewalk with "
    "an amazing view of the earth from space behind"
)
response6 = mm_rag_chain_with_retrieved_image.invoke(query6)
# extract results
final_text_response6 = response6['final_text_output']
path_to_extracted_frame6 = response6['input_to_lvlm']['image']
# display
print(f"USER Query: {query6}")
print()
print(f"MM-RAG Response: {final_text_response6}")
print()
print("Retrieved Frame:")
display(Image.open(path_to_extracted_frame6))