In [1]:
import boto3
import tabula
import faiss
import os
import json
import base64
import pymupdf
import numpy as np
from tqdm import tqdm
import logging
from botocore.exceptions import ClientError
from langchain.text_splitter import RecursiveCharacterTextSplitter
from IPython import display


# Module-level logger; only ERROR and above are emitted (used by the Bedrock helpers).
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.ERROR)


### Data Loading

In [2]:
# Input PDF, read from the local data/ directory.
filename = "attention_paper.pdf"
filepath = f"data/{filename}"


### Data Extraction

In [3]:
from utils.utils import pdf2imgs

doc = pymupdf.open(filepath)
num_pages = len(doc)

# Define the directories to store the extracted text, tables, images and page images
image_save_dir = "data/processed_images"
text_save_dir = "data/processed_text"
table_save_dir = "data/processed_tables"
page_images_save_dir = "data/processed_page_images"

# Create the output directories once, up front, instead of before every write.
for save_dir in (image_save_dir, text_save_dir, table_save_dir):
    os.makedirs(save_dir, exist_ok=True)

# Chunk the text for effective retrieval
chunk_size = 700
overlap = 200

# Process chunks with LangChain's RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=chunk_size,
    chunk_overlap=overlap,
    length_function=len,
)


def _file_to_base64(path):
    """Return the bytes of the file at ``path`` as a base64-encoded UTF-8 string."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf8")


items = []
# Process all pages of the PDF
for page_num in tqdm(range(num_pages), desc="Processing PDF pages"):
    page = doc[page_num]
    text = page.get_text()

    # Step 1: extract tables with Tabula (Tabula page numbers are 1-based).
    try:
        tables = tabula.read_pdf(filepath, pages=page_num + 1, multiple_tables=True)
        for table_idx, table in enumerate(tables or []):
            # Tabula promotes the first table row to the DataFrame header, so
            # include the column labels; table.values alone would drop that row.
            rows = [table.columns, *table.values]
            table_text = "\n".join(" | ".join(map(str, row)) for row in rows)

            # Save the table text as a file
            table_file_name = f"{table_save_dir}/{filename}_table_{page_num}_{table_idx}.txt"
            with open(table_file_name, "w", encoding="utf-8") as f:
                f.write(table_text)

            # Add table information to items (text form is what gets embedded)
            items.append({
                "page": page_num,
                "type": "table",
                "text": table_text,
                "path": table_file_name,
            })

            # Best-effort de-duplication: this only removes anything when the page
            # text contains this exact " | "-joined string, which is uncommon.
            text = text.replace(table_text, "")
    except Exception as e:
        print(f"Error extracting tables from page {page_num}: {e}")

    # Step 2: split the (remaining) page text into overlapping chunks.
    for chunk_idx, chunk in enumerate(text_splitter.split_text(text)):
        text_file_name = f"{text_save_dir}/{filename}_text_{page_num}_{chunk_idx}.txt"
        with open(text_file_name, "w", encoding="utf-8") as f:
            f.write(chunk)
        items.append({
            "page": page_num,
            "type": "text",
            "text": chunk,
            "path": text_file_name,
        })

    # Step 3: save every embedded image on the page as a PNG.
    for img_idx, img_info in enumerate(page.get_images()):
        xref = img_info[0]
        pix = pymupdf.Pixmap(doc, xref)
        # PNG cannot hold CMYK (or other >3-colour) pixmaps; convert to RGB first.
        if pix.n - pix.alpha > 3:
            pix = pymupdf.Pixmap(pymupdf.csRGB, pix)
        image_name = f"{image_save_dir}/{filename}_image_{page_num}_{img_idx}_{xref}.png"
        pix.save(image_name)
        items.append({
            "page": page_num,
            "type": "image",
            "path": image_name,
            "image": _file_to_base64(image_name),
        })

# Step 4: render each PDF page to an image and index the whole page as well.
# NOTE(review): assumes pdf2imgs writes page_{NNN}.png files into the directory it
# returns, as the path format below expects — confirm against utils.utils.
page_images_save_dir = pdf2imgs(filepath, page_images_save_dir)

for page_num in range(num_pages):
    page_path = os.path.join(page_images_save_dir, f"page_{page_num:03d}.png")
    items.append({
        "page": page_num,
        "type": "page",
        "path": page_path,
        # Encode the rendered page itself (previously the last extracted image was used).
        "image": _file_to_base64(page_path),
    })

Processing PDF pages:  27%|██▋       | 3/11 [00:02<00:05,  1.52it/s]Oct 02, 2024 6:02:07 PM org.apache.pdfbox.pdmodel.font.PDSimpleFont toUnicode
Processing PDF pages:  55%|█████▍    | 6/11 [00:02<00:01,  3.32it/s]Oct 02, 2024 6:02:07 PM org.apache.pdfbox.pdmodel.font.PDSimpleFont toUnicode
Processing PDF pages:  64%|██████▎   | 7/11 [00:03<00:01,  3.52it/s]Oct 02, 2024 6:02:08 PM org.apache.pdfbox.pdmodel.font.PDSimpleFont toUnicode
Processing PDF pages:  73%|███████▎  | 8/11 [00:03<00:00,  4.23it/s]Oct 02, 2024 6:02:08 PM org.apache.pdfbox.pdmodel.font.PDSimpleFont toUnicode
Processing PDF pages: 100%|██████████| 11/11 [00:03<00:00,  3.18it/s]


### Embedding generation function

In [4]:
# Embedding Generation Code
def generate_multimodal_embeddings(prompt=None, image=None, output_embedding_length=384, client=None):
    """
    Invoke the Amazon Titan Multimodal Embeddings model using AWS Bedrock runtime.

    Args:
        prompt (str): Text to embed.
        image (str): Base64-encoded image data to embed.
        output_embedding_length (int): Requested embedding dimension, sent as
            ``embeddingConfig.outputEmbeddingLength``.
        client: Optional pre-built ``bedrock-runtime`` client. Pass one in when
            embedding many items to avoid constructing a new client per call.

    Returns:
        list[float] | None: The embedding vector, or None if the response carries
        no embedding or the Bedrock call failed with a ClientError (the error is logged).

    Raises:
        ValueError: If neither ``prompt`` nor ``image`` is provided.
    """
    if not prompt and not image:
        raise ValueError("Please provide either a text prompt, base64 image, or both as input")

    # Initialize the Amazon Bedrock runtime client unless the caller supplied one
    if client is None:
        client = boto3.client(service_name="bedrock-runtime")
    model_id = "amazon.titan-embed-image-v1"

    body = {"embeddingConfig": {"outputEmbeddingLength": output_embedding_length}}
    if prompt:
        body["inputText"] = prompt
    if image:
        body["inputImage"] = image

    try:
        response = client.invoke_model(
            modelId=model_id,
            body=json.dumps(body),
            accept="application/json",
            contentType="application/json"
        )

        # Process and return the response
        result = json.loads(response["body"].read())
        return result.get("embedding")

    except ClientError as err:
        # Same error channel as invoke_claude_3_multimodal; callers must check for None.
        logger.error("Couldn't invoke Titan embedding model. Error: %s", err.response['Error']['Message'])
        return None




### Generating embeddings

In [5]:
# Set embedding vector dimension (must match the FAISS index built below).
# Titan Multimodal Embeddings G1 accepts 256, 384 or 1024.
embedding_vector_dimension = 384

# Generate embeddings for all items. Remember any failures so we stop here,
# rather than letting a None embedding surface later as an obscure NumPy/FAISS error.
failed_paths = []
for item in tqdm(items, desc="Generating embeddings"):
    if item['type'] in ('text', 'table'):
        # For text or table, use the formatted text representation
        item['embedding'] = generate_multimodal_embeddings(prompt=item['text'], output_embedding_length=embedding_vector_dimension)
    else:
        # For extracted images and page renders, use the base64-encoded image data
        item['embedding'] = generate_multimodal_embeddings(image=item['image'], output_embedding_length=embedding_vector_dimension)
    if item['embedding'] is None:
        failed_paths.append(item['path'])

if failed_paths:
    raise RuntimeError(f"Embedding generation failed for {len(failed_paths)} item(s): {failed_paths}")

Generating embeddings: 100%|██████████| 85/85 [00:23<00:00,  3.61it/s]


In [6]:
all_embeddings = np.array(list(map(lambda it: it['embedding'], items)))  # one row per item

In [7]:
np.shape(all_embeddings)  # (num_items, embedding dimension)


(85, 384)

In [8]:
# Create an exact (brute-force) L2 FAISS index over all item embeddings.
# FAISS expects a C-contiguous float32 matrix whose width equals the index dimension.
embedding_matrix = np.ascontiguousarray(all_embeddings, dtype=np.float32)
assert embedding_matrix.ndim == 2 and embedding_matrix.shape[1] == embedding_vector_dimension, (
    f"Expected embeddings of shape (n, {embedding_vector_dimension}), got {embedding_matrix.shape}"
)
# A freshly constructed index is already empty, so no reset() is needed.
index = faiss.IndexFlatL2(embedding_vector_dimension)
index.add(embedding_matrix)

In [9]:
def invoke_claude_3_multimodal(prompt, images, image_types,
                               model_id="anthropic.claude-3-sonnet-20240229-v1:0",
                               max_tokens=2048, client=None):
    """
    Invoke an Anthropic Claude 3 multimodal model through the AWS Bedrock runtime.

    Args:
        prompt (str): The text prompt; it is placed after the images in the user message.
        images (list[str]): Base64-encoded image data.
        image_types (list[str]): MIME type for each entry of ``images`` (same length).
        model_id (str): Bedrock model identifier. Defaults to Claude 3 Sonnet.
        max_tokens (int): Maximum number of tokens to generate. Defaults to 2048.
        client: Optional pre-built ``bedrock-runtime`` client; a new one is created if None.

    Returns:
        str: Text of the first content block of the model's response.

    Raises:
        ValueError: If ``images`` and ``image_types`` differ in length.
        botocore.exceptions.ClientError: If the Bedrock call fails (logged, then re-raised).
    """
    images = list(images)
    image_types = list(image_types)
    # zip() would silently drop unmatched images or types; fail loudly instead.
    if len(images) != len(image_types):
        raise ValueError(
            f"images and image_types must have the same length "
            f"(got {len(images)} and {len(image_types)})"
        )

    # Initialize the Amazon Bedrock runtime client unless the caller supplied one
    if client is None:
        client = boto3.client(service_name="bedrock-runtime")

    # Images first, then the text prompt, all in a single user turn
    message_content = [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": img_type,
                "data": image,
            },
        }
        for image, img_type in zip(images, image_types)
    ]
    message_content.append({"type": "text", "text": prompt})

    request_body = {
        "anthropic_version": "bedrock-2023-05-31",
        "max_tokens": max_tokens,
        "temperature": 0.2,
        "top_p": 1.0,
        "top_k": 250,
        "messages": [
            {
                "role": "user",
                "content": message_content,
            }
        ],
    }

    try:
        response = client.invoke_model(
            modelId=model_id,
            body=json.dumps(request_body),
        )

        # Process and return the response
        result = json.loads(response["body"].read())
        return result['content'][0]['text']

    except ClientError as err:
        logger.error(
            "Couldn't invoke Claude 3 %s model. Here's why: %s: %s",
            model_id.split('.')[-1].capitalize(),
            err.response["Error"]["Code"],
            err.response["Error"]["Message"],
        )
        raise

In [10]:
def generate_rag_response(prompt, matched_items, max_images=5):
    """
    Answer ``prompt`` with Claude 3, using retrieved items as context.

    Text and table items are joined into the textual context, each prefixed with
    the item's 0-based page index. All other items ("image" and "page") are
    attached as PNG images, in retrieval order.

    Args:
        prompt (str): The user's question.
        matched_items (list[dict]): Retrieved items. "text"/"table" items must have
            a "text" key; all other items must have a base64 "image" key.
        max_images (int): Non-negative cap on the number of attached images; any
            beyond it are dropped. Defaults to 5.

    Returns:
        str: The model's answer text.
    """
    text_lines = []
    image_context = []
    for item in matched_items:
        # Tables carry text, not image data; sending them down the image branch
        # raised KeyError on item['image'].
        if item['type'] in ('text', 'table'):
            text_lines.append(f"{item['page']}. {item['text']}\n")
        else:
            image_context.append(item['image'])
    text_context = "".join(text_lines)

    # Keep the request within the image budget.
    # NOTE(review): the limit of 5 was originally attributed to Claude 3 models;
    # check the current Bedrock per-request image limit before raising it.
    image_context = image_context[:max_images]

    final_prompt = f"""You are a helpful assistant for question answering.
    The text context is relevant information retrieved.
    The provided image(s) are relevant information retrieved.
    
    <context>
    {text_context}
    </context>
    
    Answer the following question using the relevant context and images.
    
    <question>
    {prompt}
    </question>
    
    Answer:"""

    return invoke_claude_3_multimodal(final_prompt, image_context, ['image/png'] * len(image_context))
    

In [11]:
query = "How is the scaled-dot-product attention is calculated?"
top_k = 5  # number of nearest items to retrieve

query_embedding = generate_multimodal_embeddings(prompt=query, output_embedding_length=embedding_vector_dimension)
# A failed Bedrock call returns None, which would otherwise become a 1x1 NaN query vector.
if query_embedding is None:
    raise RuntimeError("Failed to embed the query; see the error reported above.")
distances, result = index.search(np.array(query_embedding, dtype=np.float32).reshape(1, -1), k=top_k)

In [12]:
result.ravel()  # retrieved item indices for the single query, as a flat array

array([21, 19, 20, 23, 40])

In [13]:
# FAISS pads missing neighbours with -1, which would silently wrap to items[-1]; skip them.
matched_items = [items[i] for i in result.flatten() if i != -1]

In [14]:
# Ask Claude 3 to answer the query from the retrieved context.
response = generate_rag_response(prompt=query, matched_items=matched_items)

In [15]:
display.Markdown(data=response)  # render the model's answer as Markdown

According to the context, the scaled dot-product attention is calculated as follows:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

Where:

- Q is the matrix of queries 
- K is the matrix of keys
- V is the matrix of values
- d_k is the dimension of the keys

The key steps are:

1. Take the dot product of the query matrix Q with the transpose of the key matrix K^T. This gives a score showing the similarity between each query and key.

2. Scale the scores by dividing by sqrt(d_k), where d_k is the dimension of the keys. This helps prevent the softmax function from being pushed into regions with extremely small gradients when d_k is large.

3. Apply the softmax function to the scaled scores to obtain the attention weights.

4. Multiply the attention weights with the value matrix V to get the weighted sum of the values, which are the attended outputs.

The scaling factor 1/sqrt(d_k) is a key difference from the standard dot-product attention. It helps counteract the effect of large dot products when the key dimension d_k is high, which can make the softmax saturate and have tiny gradients.

In [16]:
query = "What is the BLEU score of the model in English to German translation (EN-DE)?"

query_embedding = generate_multimodal_embeddings(prompt=query, output_embedding_length=embedding_vector_dimension)
# A failed Bedrock call returns None, which would otherwise become a 1x1 NaN query vector.
if query_embedding is None:
    raise RuntimeError("Failed to embed the query; see the error reported above.")
distances, result = index.search(np.array(query_embedding, dtype=np.float32).reshape(1, -1), k=5)
# FAISS pads missing neighbours with -1, which would silently wrap to items[-1]; skip them.
matched_items = [items[i] for i in result.flatten() if i != -1]
response = generate_rag_response(query, matched_items)
display.Markdown(response)

According to Table 2 in the provided context, the Transformer (big) model achieves a BLEU score of 28.4 on the WMT 2014 English-to-German (EN-DE) translation task.

<div style="background-color:#f0f8ff; padding: 15px; border-radius: 10px; border-left: 6px solid #4682B4;">
  <p>Nice. We have now seen a couple of example questions and answers. Try asking more questions, for example:</p>
  <ul style="text-align: left;">
    <li>"How long were the base and big models trained?"</li>
    <li>"Which optimizer was used when training the models?"</li>
    <li>"What is the position-wise feed-forward neural network mentioned in the paper?"</li>
    <li>"What is the BLEU score of the model in English to French translation (EN-FR)?"</li>
    <li>"What is the BLEU score of the model in English to German translation (EN-DE)?"</li>
  </ul>
</div>
