In [None]:
!pip install llama-index qdrant_client pyMuPDF tools frontend git+https://github.com/openai/CLIP.git easyocr

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Patch
import io
from PIL import Image, ImageDraw
import numpy as np
import csv
import pandas as pd

from torchvision import transforms

from transformers import AutoModelForObjectDetection
import torch
import openai
import os
import fitz

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Read the OpenAI key from the OPENAI_API_KEY environment variable so that a real
# secret never has to be stored in the notebook. "sk-" is only the original,
# non-functional placeholder and is used when the variable is not set.
OPENAI_API_TOKEN = os.environ.get("OPENAI_API_KEY", "sk-")
openai.api_key = OPENAI_API_TOKEN

In [None]:
# Download a PDF with images (the "NSAsecurityPosters_1950s-60s" file hosted on archive.org).
# The request is sent with a "Mozilla" user agent, and -O writes the result to
# "nsa_posters.pdf" in the current working directory.

!wget --user-agent "Mozilla" "https://archive.org/download/NSAsecurityPosters1950s60s/NSAsecurityPosters_1950s-60s.pdf" -O "nsa_posters.pdf"

In [None]:
# Convert each page of the PDF to an image for multimodal indexing.

pdf_file = "nsa_posters.pdf"

# Upper bound on the number of pages rendered to PNG files.
MAX_PAGES = 32

# The output directory is the PDF file name without its extension.
output_directory_path, _ = os.path.splitext(pdf_file)

# exist_ok=True makes the cell safe to re-run and avoids a check-then-create race.
os.makedirs(output_directory_path, exist_ok=True)

# Open the PDF file
pdf_document = fitz.open(pdf_file)
try:
    # Render at most MAX_PAGES pages, starting with the first one.
    for page_number in range(min(pdf_document.page_count, MAX_PAGES)):
        page = pdf_document[page_number]

        # Rasterise the page and wrap the raw samples in a Pillow image.
        pix = page.get_pixmap()
        image = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

        # Files are numbered from 1: page_1.png, page_2.png, ...
        image.save(
            os.path.join(output_directory_path, f"page_{page_number + 1}.png")
        )
finally:
    # Release the PDF even if rendering or saving one of the pages fails.
    pdf_document.close()

In [None]:
# Summarize images with gpt4-v
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index import SimpleDirectoryReader

# Upper bound on the number of images sent to the multimodal LLM (one call per image).
MAX_IMAGES = 32

# Maps image filename -> {"filename", "img_path", "img_summary"}
image_metadata_dict = {}

openai_mm_llm = OpenAIMultiModal(
    model="gpt-4-vision-preview", api_key=OPENAI_API_TOKEN, max_new_tokens=1500
)

image_prompt = """
    Please summarize what you see in this picture. Ignore the header and footer. Just describe the actual picture.
"""

image_documents = SimpleDirectoryReader(output_directory_path).load_data()
print(len(image_documents))

# Summarize at most MAX_IMAGES images; the slice replaces a manual counter and break.
img_count = 0
for image_file in image_documents[:MAX_IMAGES]:
    image_path = image_file.metadata["file_path"]
    image_filename = image_file.metadata["file_name"]
    print("summarizing ", image_filename)

    # Ask the multimodal LLM to describe this single image.
    image_summary = openai_mm_llm.complete(
        prompt=image_prompt,
        image_documents=[image_file],
    )

    # Keep the summary together with the information needed to find the image again.
    image_metadata_dict[image_filename] = {
        "filename": image_filename,
        "img_path": image_path,
        "img_summary": image_summary,
    }
    img_count += 1

print(len(image_metadata_dict))

In [None]:
# Plot an image and print its summary
from PIL import Image
import matplotlib.pyplot as plt
import os

# Plot one image and its corresponding summary.
# The entry is selected explicitly (the last one stored in image_metadata_dict)
# instead of relying on a loop variable left behind by an earlier cell.
if not image_metadata_dict:
    raise ValueError("image_metadata_dict is empty; run the summarization cell first")

sample_filename = list(image_metadata_dict)[-1]
sample_metadata = image_metadata_dict[sample_filename]

img_path = sample_metadata["img_path"]
image = Image.open(img_path).convert("RGB")

plt.figure(figsize=(16, 9))
plt.imshow(image)

print(sample_metadata["img_summary"])

In [None]:
# Plot all the images as thumbnails
from PIL import Image
import matplotlib.pyplot as plt
import os

# plot multiple images in a grid
def plot_images(image_metadata_dict, n_rows=4, n_cols=8):
    """Plot the images described by ``image_metadata_dict`` as a thumbnail grid.

    Args:
        image_metadata_dict: Mapping of image filename to a metadata dict that
            contains at least an ``"img_path"`` entry.
        n_rows: Number of rows in the grid (default 4).
        n_cols: Number of columns in the grid (default 8).

    Entries whose ``"img_path"`` is not an existing file are skipped. At most
    ``n_rows * n_cols`` images are drawn, so the subplot index can never exceed
    the number of cells in the grid.
    """
    max_images = n_rows * n_cols
    images_shown = 0

    for metadata in image_metadata_dict.values():
        if images_shown >= max_images:
            break

        img_path = metadata["img_path"]
        if not os.path.isfile(img_path):
            continue

        # open the image file and convert it to RGB colorspace.
        image = Image.open(img_path).convert("RGB")

        # draw into the next free grid cell and hide the axis ticks for a cleaner plot.
        plt.subplot(n_rows, n_cols, images_shown + 1)
        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

        images_shown += 1

    plt.tight_layout()

# Draw the thumbnail grid for the images recorded in image_metadata_dict.
plot_images(image_metadata_dict)

In [None]:
# save metadata as text Documents for embedding
from llama_index.schema import Document

# Build one text Document per summarized image: the summary is the text that
# gets embedded, and the metadata points back to the source image file.
# Files without an entry in image_metadata_dict are skipped, so an image that
# was never summarized cannot cause a KeyError here.
text_docs = [
    Document(
        text=str(image_metadata_dict[image_file.metadata["file_name"]]["img_summary"]),
        metadata={
            "image_path": image_file.metadata["file_path"],
            "image_name": image_file.metadata["file_name"],
        },
    )
    for image_file in image_documents
    if image_file.metadata["file_name"] in image_metadata_dict
]

# Report only the count; printing every Document would flood the cell output.
print(len(text_docs))

In [None]:
# index the text embeddings
import qdrant_client
from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
)
from llama_index.vector_stores.qdrant import QdrantVectorStore
from llama_index import VectorStoreIndex, StorageContext
from llama_index.llms import OpenAI

# LLM and embedding model used for the text index: gpt-3.5-turbo answers the
# queries, a local BAAI/bge-base-en-v1.5 model embeds the summaries.
service_context = ServiceContext.from_defaults(
    llm=OpenAI(model="gpt-3.5-turbo", temperature=0.1),
    embed_model="local:BAAI/bge-base-en-v1.5",
)

# Client for the Qdrant server at localhost:6333
text_client = qdrant_client.QdrantClient(url="http://localhost:6333")

# Vector store holding the summary embeddings in the "nsa_posters_text" collection
text_vector_store = QdrantVectorStore(
    collection_name="nsa_posters_text",
    client=text_client,
)

# Storage context that routes the index's vectors into the Qdrant store
storage_context = StorageContext.from_defaults(vector_store=text_vector_store)

# Build the index from the summary documents
# NOTE(review): re-running this cell presumably adds the same documents to the
# collection again - confirm before running it repeatedly.
nsa_posters_text_index = VectorStoreIndex.from_documents(
    text_docs,
    service_context=service_context,
    storage_context=storage_context,
)

# Query engine over the summary index
text_query_engine = nsa_posters_text_index.as_query_engine()

In [None]:
# Query the text embeddings
from llama_index.response.notebook_utils import display_source_node

MAX_TOKENS = 50

# Retriever that returns the three summaries closest to the question
retriever_engine = nsa_posters_text_index.as_retriever(similarity_top_k=3)

# Look up the summaries that best match a holiday-related question
retrieval_results = retriever_engine.retrieve(
    "Are any of these posters holiday themed? If yes, please explain."
)

retrieved_image = []

# Display every retrieved node (called with source_length=1000)
for res_node in retrieval_results:
    display_source_node(res_node, source_length=1000)

In [None]:
# Load and initialize the CLIP image embedding model
import torch
import clip
import numpy as np

# Load the "ViT-B/32" CLIP model together with its image preprocessing function
model, preprocess = clip.load("ViT-B/32")

# Properties of the loaded model that are reported below
input_resolution = model.visual.input_resolution  # expected input image resolution
context_length = model.context_length  # maximum length of the input text
vocab_size = model.vocab_size  # size of the model's vocabulary

# Total parameter count: the element counts of all parameter tensors added up
n_params = sum(p.numel() for p in model.parameters())

print("Model parameters:", f"{n_params:,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

In [None]:
# Generate CLIP embeddings for the Posters

# use either CUDA (GPU) or CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

# img_emb_dict maps image filename -> CLIP image embedding tensor
img_emb_dict = {}

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    for image_filename, metadata in image_metadata_dict.items():
        img_file_path = metadata["img_path"]
        if not os.path.isfile(img_file_path):
            continue

        # The context manager closes the image file once preprocessing is done.
        with Image.open(img_file_path) as pil_image:
            # preprocess with CLIP's transform, add a batch dimension with
            # unsqueeze(0), then move the tensor to the selected device
            image_tensor = preprocess(pil_image).unsqueeze(0).to(device)

        # extract image features using the CLIP model's encode_image function
        image_features = model.encode_image(image_tensor)

        # store the image features in the image embedding dictionary
        img_emb_dict[image_filename] = image_features

# Report a summary; printing the raw tensors would flood the cell output.
print(f"Embedded {len(img_emb_dict)} of {len(image_metadata_dict)} images")

In [None]:
# Index the image embeddings

from llama_index.schema import ImageDocument

# Build one ImageDocument for every image that has a CLIP embedding
img_documents = []
for image_filename, image_metadata in image_metadata_dict.items():
    # images without an entry in img_emb_dict are left out of the index
    if image_filename not in img_emb_dict:
        continue

    filename = image_metadata["filename"]
    filepath = image_metadata["img_path"]
    print(filepath)

    newImgDoc = ImageDocument(text=filename, metadata={"filepath": filepath})

    # attach the precomputed embedding: the first row of the tensor as a plain list
    newImgDoc.embedding = img_emb_dict[image_filename].tolist()[0]
    img_documents.append(newImgDoc)

# Qdrant vector store for the image embeddings, collection "nsa_posters_images"
image_vector_store = QdrantVectorStore(
    collection_name="nsa_posters_images",
    client=text_client,
)

# Storage context that routes the image index into that vector store
storage_context = StorageContext.from_defaults(vector_store=image_vector_store)

# Build the image index from the ImageDocuments
image_index = VectorStoreIndex.from_documents(
    img_documents, storage_context=storage_context
)

In [None]:
# Define image query functions

from llama_index.vector_stores import VectorStoreQuery

# take a text query and return the most similar image
def retrieve_results_from_image_index(query, top_k=1):
    """Query the image vector store with the CLIP embedding of a text query.

    Args:
        query: Text to search for. It is tokenized with ``clip.tokenize`` and
            only the first resulting embedding is used.
        top_k: Number of most similar images to request (default 1).

    Returns:
        Whatever ``image_vector_store.query`` returns for the built query.
    """
    # truncate=True shortens over-long queries to the model's context length
    # instead of raising an error.
    text = clip.tokenize(query, truncate=True).to(device)

    # Inference only: no autograd graph is needed for the query embedding.
    with torch.no_grad():
        query_embedding = model.encode_text(text).tolist()[0]

    image_vector_store_query = VectorStoreQuery(
        query_embedding=query_embedding,
        similarity_top_k=top_k,
        mode="default",
    )

    return image_vector_store.query(image_vector_store_query)

# Plot the list of retrievals
def plot_image_retrieve_results(image_retrieval_results):
    """Plot every retrieved image with its similarity score as the title.

    Args:
        image_retrieval_results: Object with parallel ``nodes`` and
            ``similarities`` sequences. Each node's ``metadata`` must contain
            a ``"filepath"`` entry pointing at the image file. Either sequence
            may be ``None``, which is treated as "no results".
    """
    nodes = image_retrieval_results.nodes or []
    similarities = image_retrieval_results.similarities or []

    n_cols = 3
    # Keep the 2x3 layout for up to six results and add rows beyond that, so
    # the subplot index always fits inside the grid.
    n_rows = max(2, -(-len(nodes) // n_cols))

    # 2.5 inches per row gives the original 16x5 figure for the 2-row layout.
    plt.figure(figsize=(16, 2.5 * n_rows))

    # for each retrieval, plot the image and its score.
    for position, (returned_image, score) in enumerate(
        zip(nodes, similarities), start=1
    ):
        img_path = returned_image.metadata["filepath"]
        image = Image.open(img_path).convert("RGB")

        plt.subplot(n_rows, n_cols, position)
        plt.title("{:.4f}".format(score))

        plt.imshow(image)
        plt.xticks([])
        plt.yticks([])

# define image_query function
def image_query(query):
    """Retrieve the images matching ``query`` from the image index and plot them."""
    plot_image_retrieve_results(retrieve_results_from_image_index(query))

In [None]:
# Holiday-themed question, sent to both retrieval paths.
query = (
    "Are there any holiday themed posters? Explain your reasoning. "
    "What about this poster makes it a holiday poster?"
)

# Image path: embed the question with CLIP and plot the best-matching poster
image_query(query)

# Text path: answer the question from the indexed image summaries
text_retrieval_results = text_query_engine.query(query)
print(text_retrieval_results)

In [None]:
# Sports-themed question, sent to both retrieval paths.
# NOTE(review): the prompt wording ("makes is") is kept exactly as written.
query = (
    "Are there any sports themed posters? Please explain what about the poster makes is sports themed. "
    "And what do you think the poster is trying to convey?"
)

# Image path: embed the question with CLIP and plot the best-matching poster
image_query(query)

# Text path: answer the question from the indexed image summaries
text_retrieval_results = text_query_engine.query(query)
print(text_retrieval_results)

In [None]:
# 1970s-themed question, sent to both retrieval paths.
# NOTE(review): the prompt wording ("psychadelic") is kept exactly as written.
query = (
    "The 1970s were wild times. Everyone loved disco dancing and psychadelic drugs. "
    "Do any of these posters remind you of the 1970s?"
)

# Image path: embed the question with CLIP and plot the best-matching poster
image_query(query)

# Text path: answer the question from the indexed image summaries
text_retrieval_results = text_query_engine.query(query)
print(text_retrieval_results)