In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Context Caching with the Gemini API


## Overview

### Gemini

Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases.

### Context Caching

The Gemini API provides the context caching feature for developers to store frequently used input tokens in a dedicated cache and reference them for subsequent requests, eliminating the need to repeatedly pass the same set of tokens to a model. This feature can help reduce the number of tokens sent to the model, thereby lowering the cost of requests that contain repeat content with high input token counts.

### Objectives

In this tutorial, you learn how to use the Gemini API context caching feature in Vertex AI.

You will complete the following tasks:
- Create a context cache
- Retrieve and use a context cache
- Use context caching in Chat
- Update the expire time of a context cache
- Delete a context cache


## Get started

### Install Vertex AI SDK and other required packages


In [1]:
%pip install --upgrade --user --quiet google-cloud-aiplatform
%pip install --upgrade --user --quiet PyPDF2

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m14.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m232.6/232.6 kB[0m [31m6.1 MB/s[0m eta [36m0:00:00[0m
[?25h

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [2]:
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [1]:
import sys

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [2]:
PROJECT_ID = ""  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

import vertexai

vertexai.init(project=PROJECT_ID, location=LOCATION)

## Code Examples

### Import libraries

In [22]:
import IPython.display
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

from vertexai.generative_models import (
    GenerationConfig,
    GenerativeModel,
    HarmBlockThreshold,
    HarmCategory,
    Part,
)
import pandas as pd
from rich.markdown import Markdown as rich_Markdown
from rich import print as rich_print
import pickle
from IPython.display import display, Markdown, HTML
import logging
import nest_asyncio
import warnings
import datetime

import vertexai
from vertexai.generative_models import Part
from vertexai.preview import caching
from vertexai.preview.generative_models import GenerativeModel

### Create a context cache

**Note**: Context caching is only available for stable models with fixed versions (for example, `gemini-1.5-pro-001`). You must include the version postfix (for example, the `-001` in `gemini-1.5-pro-001`).

For more information, see [Available Gemini stable model versions](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versioning#stable-versions-available).


In [4]:
MODEL_ID = "gemini-1.5-pro-001"  # @param {type:"string"}

Context caching is particularly well suited to scenarios where a substantial initial context is referenced repeatedly by shorter requests.

- Cached content can be any of the MIME types supported by Gemini multimodal models. For example, you can cache a large amount of text, audio, or video. **Note**: The minimum size of a context cache is 32,769 tokens.
- The default expiration time of a context cache is 60 minutes. You can specify a different expiration time using the `ttl` (time to live) or the `expire_time` property.

This example shows how to create a context cache using two large research papers stored in a Cloud Storage bucket, and set the `ttl` to 60 minutes.

- Paper 1: [Gemini: A Family of Highly Capable Multimodal Models](https://arxiv.org/abs/2312.11805)
- Paper 2: [Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context](https://arxiv.org/abs/2403.05530)


In [5]:
system_instruction = """
You are an expert researcher who has years of experience in conducting systematic literature surveys and meta-analyses of different topics.
You pride yourself on incredible accuracy and attention to detail. You always stick to the facts in the sources provided, and never make up new facts.
Now look at the research paper below, and answer the following questions in 1-2 sentences.
"""

contents = [
    Part.from_uri(
        "gs://cloud-samples-data/generative-ai/pdf/2312.11805v3.pdf",
        mime_type="application/pdf",
    ),
    Part.from_uri(
        "gs://cloud-samples-data/generative-ai/pdf/2403.05530.pdf",
        mime_type="application/pdf",
    ),
]

cached_content = caching.CachedContent.create(
    model_name=MODEL_ID,
    system_instruction=system_instruction,
    contents=contents,
    ttl=datetime.timedelta(minutes=60),
)

You can access the properties of the cached content as example below. You can use its `name` or `resource_name` to reference the contents of the context cache.

**Note**: The `name` of the context cache is also referred to as cache ID.

In [6]:
print(cached_content.name)
print(cached_content.resource_name)
print(cached_content.model_name)
print(cached_content.create_time)
print(cached_content.expire_time)

8050219518296850432
projects/761584077845/locations/us-central1/cachedContents/8050219518296850432
projects/lavi-llm-experiment/locations/us-central1/publishers/google/models/gemini-1.5-pro-001
2024-08-28 13:06:30.706204+00:00
2024-08-28 14:06:30.594654+00:00


### Retrieve and use a context cache

You can use the property `name` or `resource_name` to reference the contents of the context cache. For example:
```
new_cached_content = caching.CachedContent(cached_content_name=cached_content.name)
```

To use the context cache, you construct a `GenerativeModel` with the context cache.

In [7]:
model = GenerativeModel.from_cached_content(cached_content=cached_content)

Then you can query the model with a prompt, and the cached content will be used as a prefix to the prompt.

In [8]:
response = model.generate_content(
    "What is the research goal shared by these research papers?"
)

print(response.text)

The research goal of these papers is to develop a family of large language models called Gemini, which will be highly capable in understanding and responding to a variety of inputs, including text, images, audio, and video. They aim to build a model that can reason across these different modalities, enabling it to excel in a wide range of tasks. 



### Use context caching in Chat

You can use the context cache in a multi-turn chat session.


In [9]:
chat = model.start_chat()

In [10]:
prompt = """
How do the approaches to responsible AI development and mitigation strategies in Gemini 1.5 evolve from those in Gemini 1.0?
"""

response = chat.send_message(prompt)

print(response.text)

Gemini 1.5 continues the same structured approach to responsible AI development as Gemini 1.0 with additions focused on long-context understanding, such as new image-to-text safety data and consideration of the potential for longer inputs to negatively affect model safety. The mitigation strategies are largely the same, primarily using SFT and RLHF, with the most substantial new update being the incorporation of image-to-text SFT data. 



In [11]:
prompt = """
Given the advancements presented in Gemini 1.5, what are the key future research directions identified in both papers
for further improving multimodal AI models?
"""

response = chat.send_message(prompt)

print(response.text)

Both papers highlight the need for more robust evaluations of LLMs, particularly for long-context understanding, and on tasks demanding high-level reasoning like causal understanding, logical deduction, and counterfactual reasoning. 
Additionally, there is a call for research into mitigating "hallucinations" generated by LLMs, improving translation quality (especially for low-resource languages), and  developing new ways to measure bias and stereotyping beyond simple, binary notions of harm. 



You can use `print(chat.history)` to print out the chat session history.

### Update the expiration time of a context cache


The default expiration time of a context cache is 60 minutes. To update the expiration time, update one of the following properties:

`ttl` - The number of seconds and nanoseconds that the cache lives after it's created or after the `ttl` is updated before it expires. When you set the `ttl`, the cache `expire_time` is updated.

`expire_time` - A Timestamp that specifies the absolute date and time when the context cache expires.

In [12]:
cached_content.update(ttl=datetime.timedelta(hours=1))

cached_content.refresh()

print(cached_content.expire_time)

2024-08-26 14:59:05.250714+00:00


### Delete a context cache

You can remove content from the cache using the delete operation.

In [13]:
cached_content.delete()

INFO:google.cloud.aiplatform.base:Deleting CachedContent : projects/761584077845/locations/us-central1/cachedContents/83246224362176512


## Basic RAG without Context Caching

In [46]:
file_1 = "gs://genai-asset/20230426-alphabet-10q.pdf"
file_2 = "gs://genai-asset/goog-10-k-2023.pdf"
file_3 = "gs://genai-asset/goog-10-q-q2-2023-4.pdf"
file_4 = "gs://genai-asset/goog-10-q-q3-2023.pdf"
file_5 = "gs://genai-asset/goog-10-k-q4-2022.pdf"

In [13]:
import pandas as pd
from google.cloud import storage
from io import BytesIO
import PyPDF2

def extract_pdf_metadata(file_uris):
    """
    Extracts text from PDF files stored in Google Cloud Storage and returns a Pandas DataFrame.

    Args:
        file_uris: A list of GCS URIs pointing to PDF files.

    Returns:
        A Pandas DataFrame containing metadata for each page of each PDF file.
    """

    all_metadata = []
    storage_client = storage.Client()

    for file_uri in file_uris:
        try:
            # Extract file name from URI
            file_name = file_uri.split('/')[-1]

            # Read file from GCS as bytes
            bucket_name = file_uri.split('/')[2]
            blob_name = '/'.join(file_uri.split('/')[3:])
            bucket = storage_client.bucket(bucket_name)
            blob = bucket.blob(blob_name)
            pdf_bytes = blob.download_as_bytes()

            # Extract text from PDF
            pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
            for page_number in range(len(pdf_reader.pages)):
                page = pdf_reader.pages[page_number]
                text = page.extract_text()

                # Append metadata to list
                all_metadata.append({
                    'file_name': file_name,
                    'page_number': page_number + 1,
                    'text': text
                })

        except Exception as e:
            print(f"Error processing {file_uri}: {e}")

    return pd.DataFrame(all_metadata)

In [14]:
file_uris = [file_1, file_2, file_3, file_4, file_5]
df = extract_pdf_metadata(file_uris)

In [15]:
df.head()

Unnamed: 0,file_name,page_number,text
0,20230426-alphabet-10q.pdf,1,UNITED STATES\nSECURITIES AND EXCHANGE COMMISS...
1,20230426-alphabet-10q.pdf,2,Alphabet Inc.\nForm 10-Q\nFor the Quarterly Pe...
2,20230426-alphabet-10q.pdf,3,Note About Forward-Looking Statements\nThis Qu...
3,20230426-alphabet-10q.pdf,4,"•the expected timing, amount, and effect of Al..."
4,20230426-alphabet-10q.pdf,5,PART I. FINANCIAL INFORMATION\nITEM 1. FINANCI...


In [16]:
df.shape

(249, 3)

In [27]:
from typing import List, Optional

from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel


def embed_text(
    texts: List[str] = ["banana muffins? ", "banana bread? banana muffins?"],
    task: str = "RETRIEVAL_DOCUMENT",
    model_name: str = "text-embedding-004",
    dimensionality: Optional[int] = 768,
) -> List[List[float]]:
    """Embeds texts with a pre-trained, foundational model."""
    model = TextEmbeddingModel.from_pretrained(model_name)
    inputs = [TextEmbeddingInput(text, task) for text in texts]
    kwargs = dict(output_dimensionality=dimensionality) if dimensionality else {}
    embeddings = model.get_embeddings(inputs, **kwargs)
    return [embedding.values for embedding in embeddings][0]


In [28]:
%%time

df['embeddings'] = df['text'].apply(lambda x: embed_text([x]))

CPU times: user 9.45 s, sys: 1.22 s, total: 10.7 s
Wall time: 3min 48s


In [29]:
df.head()

Unnamed: 0,file_name,page_number,text,embeddings
0,20230426-alphabet-10q.pdf,1,UNITED STATES\nSECURITIES AND EXCHANGE COMMISS...,"[0.04671001061797142, 0.01689017191529274, -0...."
1,20230426-alphabet-10q.pdf,2,Alphabet Inc.\nForm 10-Q\nFor the Quarterly Pe...,"[0.050522807985544205, 0.01621415838599205, -0..."
2,20230426-alphabet-10q.pdf,3,Note About Forward-Looking Statements\nThis Qu...,"[0.03999736160039902, -0.015815766528248787, -..."
3,20230426-alphabet-10q.pdf,4,"•the expected timing, amount, and effect of Al...","[0.05650733411312103, 0.007662168703973293, -0..."
4,20230426-alphabet-10q.pdf,5,PART I. FINANCIAL INFORMATION\nITEM 1. FINANCI...,"[0.042750339955091476, 0.015332790091633797, -..."


In [39]:
from vertexai.language_models import TextEmbeddingInput, TextEmbeddingModel
from typing import List
import numpy as np


def get_gemini_response(model, generation_config=None,
                        safety_settings=None,
                        uri_path=None,mime_type=None, prompt=None):
  if not generation_config:
    generation_config = {
      "max_output_tokens": 8192,
      "temperature": 1,
      "top_p": 0.95,
    }

  if not safety_settings:
    safety_settings = {
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    }

  responses = model.generate_content(prompt,
      generation_config=generation_config,
      safety_settings=safety_settings,
      stream=True,
  )
  final_response = []
  for response in responses:
    try:
      final_response.append(response.text)
    except ValueError:
      # print("Something is blocked...")
      final_response.append("blocked")

  return "".join(final_response)


def get_cosine_score(
    dataframe: pd.DataFrame, column_name: str, input_text_embd: np.ndarray
) -> float:
    """
    Calculates the cosine similarity between the user query embedding and the dataframe embedding for a specific column.

    Args:
        dataframe: The pandas DataFrame containing the data to compare against.
        column_name: The name of the column containing the embeddings to compare with.
        input_text_embd: The NumPy array representing the user query embedding.

    Returns:
        The cosine similarity score (rounded to two decimal places) between the user query embedding and the dataframe embedding.
    """
    if dataframe[column_name]:
      text_cosine_score = round(np.dot(dataframe[column_name], input_text_embd), 2)
      return text_cosine_score
    else:
      return 0

def get_answer(question,vector_db, model, top_n=5):
  query_embedding = embed_text([question])
  #Find score
  cosine_scores = vector_db.apply(
              lambda x: get_cosine_score(x, 'embeddings', query_embedding),
              axis=1,
          )
  # print(len(cosine_scores))
  # Remove same image comparison score when user image is matched exactly with metadata image
  # cosine_scores = cosine_scores[cosine_scores < 1.00000000]
  # Get top N cosine scores and their indices
  top_n_cosine_scores = cosine_scores.nlargest(top_n).index.tolist()
  top_n_cosine_values = cosine_scores.nlargest(top_n).values.tolist()

  citations = vector_db.iloc[top_n_cosine_scores].copy()
  # citations['score'] = top_n_cosine_scores
  citations.loc[:, 'score'] = top_n_cosine_values
  citations = citations[['text','score']]

  # # print(citations)
  # gemini_content = get_gemini_content_list(query, vector_db, top_n_cosine_scores)
  context = "\n".join(citations['text'].tolist())
  prompt = f""" Task: Answer the question based on the provided context.

Question: {question}

Context: {context}

Answer:
"""
  response  =  get_gemini_response(model=model, prompt=prompt)

  return (response, context, citations.to_dict('records'), prompt)

In [40]:
question = """How does Alphabet's organizational structure aim to balance innovation with financial stability,
 particularly in the context of emerging technology investments?"""

In [41]:
%%time

response, context, citations, prompt = get_answer(question,
                                                            df,
                                                            model)

CPU times: user 339 ms, sys: 39.5 ms, total: 379 ms
Wall time: 32 s


In [43]:
rich_Markdown(response)

## Basic RAG with Context Caching

In [56]:
file_1 = "gs://genai-asset/20230426-alphabet-10q.pdf"
file_2 = "gs://genai-asset/goog-10-k-2023.pdf"
file_3 = "gs://genai-asset/goog-10-q-q2-2023-4.pdf"
file_4 = "gs://genai-asset/goog-10-q-q3-2023.pdf"
file_5 = "gs://genai-asset/goog-10-k-q4-2022.pdf"

In [None]:
#The minimum token count to start caching is 32768.

In [None]:
system_instructions = """You are a helpful and informative AI assistant.
You have been provided with a 10-K filing from a publicly traded company. Your task is to analyze the document and answer user questions about the company's business and financial performance.

Here are some guidelines to follow:

* **Focus on accuracy and evidence:** Your answers should be grounded in the information presented within the 10-K document. Cite specific sections or page numbers whenever possible to support your claims.
* **Maintain neutrality and objectivity:** Avoid expressing personal opinions or making subjective judgments about the company or its prospects. Present the facts as they are stated in the filing.
* **Acknowledge limitations:** If a question cannot be answered definitively from the 10-K, explain the limitations of the information and suggest possible alternative sources or data points that could be helpful.
"""

In [59]:
%%time

contents = [
    Part.from_uri(
        file_1,
        mime_type="application/pdf",
    ),
    Part.from_uri(
        file_2,
        mime_type="application/pdf",
    ),
]

cached_content_file1_2 = caching.CachedContent.create(
    model_name=MODEL_ID,
    system_instruction=system_instructions,
    contents=contents,
    ttl=datetime.timedelta(minutes=60),
)

CPU times: user 156 ms, sys: 34.5 ms, total: 190 ms
Wall time: 20 s


In [57]:
%%time

contents = [
    Part.from_uri(
        file_3,
        mime_type="application/pdf",
    ),
    Part.from_uri(
        file_4,
        mime_type="application/pdf",
    ),
    Part.from_uri(
        file_5,
        mime_type="application/pdf",
    )
]

cached_content_file3_4_5 = caching.CachedContent.create(
    model_name=MODEL_ID,
    system_instruction=system_instructions,
    contents=contents,
    ttl=datetime.timedelta(minutes=60),
)

CPU times: user 999 ms, sys: 174 ms, total: 1.17 s
Wall time: 2min 29s


In [58]:
cache_mapping = {
    "20230426-alphabet-10q.pdf": cached_content_file1_2,
    "goog-10-k-2023.pdf": cached_content_file1_2,
    "goog-10-q-q2-2023-4.pdf": cached_content_file3_4_5,
    "goog-10-q-q3-2023.pdf": cached_content_file3_4_5,
    "goog-10-k-q4-2022.pdf": cached_content_file3_4_5
}

In [87]:
def get_answer_with_context_cache(question,vector_db, model, top_n=5):
  query_embedding = embed_text([question])
  #Find score
  cosine_scores = vector_db.apply(
              lambda x: get_cosine_score(x, 'embeddings', query_embedding),
              axis=1,
          )
  # print(len(cosine_scores))
  # Remove same image comparison score when user image is matched exactly with metadata image
  # cosine_scores = cosine_scores[cosine_scores < 1.00000000]
  # Get top N cosine scores and their indices
  top_n_cosine_scores = cosine_scores.nlargest(top_n).index.tolist()
  top_n_cosine_values = cosine_scores.nlargest(top_n).values.tolist()

  citations = vector_db.iloc[top_n_cosine_scores].copy()
  # citations['score'] = top_n_cosine_scores
  citations.loc[:, 'score'] = top_n_cosine_values
  citations = citations[['file_name','text','score']]

  # get the top file_name
  file_name = citations['file_name'].value_counts().index[0]

  # load the model with pre-defined cached
  cached_content = cache_mapping[file_name]
  model = GenerativeModel.from_cached_content(cached_content=cached_content)

  # call the model
  response = model.generate_content(
    question
)

  return citations, response.text

In [88]:
question = """How does Alphabet's organizational structure aim to balance innovation with financial stability,
 particularly in the context of emerging technology investments?"""

In [93]:
%%time

citations, response = get_answer_with_context_cache(question,df,
                      model, top_n=3)

In [90]:
citations

Unnamed: 0,file_name,text,score
50,goog-10-k-2023.pdf,"•the expected timing, amount, and effect of Al...",0.75
54,goog-10-k-2023.pdf,•Collaboration Tools: Google Workspace and Du...,0.74
70,goog-10-k-2023.pdf,•liability for activities of the acquired comp...,0.74


In [92]:
rich_Markdown(response)