**Earth Engine Dataset Retrieval and Visualizer Agent**

Author: Renee Johnston (reneejohnston@google.com), Eliot Cowan (eliotc@google.com)

Inspired by: Simon Ilyushchenko's (simonf@google.com) [EE Genie](https://github.com/google/earthengine-community/blob/master/experimental/ee_genie.ipynb)

**Annoyance**

Due to limitations in the JavaScript/Python interaction, the agent has to stop running after it moves or pans the geemap map. When this happens, the agent icon changes to 🙏. Just press Enter in the chat box when you see this to continue the analysis.

**Installation**

To use it, you need two things:
1. Earth Engine access
2. Generative AI API key

You need a Google Cloud Project to associate your requests with. [Use these instructions](https://developers.google.com/earth-engine/cloud/earthengine_cloud_project_setup) to set one up, then set `project_name` to your Google Cloud Project ID in the initialization cell below.

Next, get a Generative AI API key [here](https://aistudio.google.com/app/prompts/new_chat). Be aware that you might need to pay for use of the Generative AI API.

To save this key in the notebook, click on the key icon in Colab on the left-hand side and add your key as a secret with the name GOOGLE_API_KEY. Make sure the value has no newlines.

Finally, run the first cell, which installs the required libraries. You only need to do this once. When you run the initialization cell, the Earth Engine client will ask you to authenticate.

To use the dataset search agent, run every cell, then scroll to the app at the bottom. If the app is stuck waiting for an LLM, you can reset it: click the `Runtime / Interrupt execution` menu item, then run all cells again.

# Imports

In [None]:
#@title Install Python Libraries

%%capture
!pip install google_cloud_aiplatform langchain-community langchain_google_genai chromadb langchain

In [None]:
import contextlib
import dateutil
import io
import json
import math
import os
import re
import requests
import shutil
import sys
import threading
import traceback
import time

import ee
import geemap
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from langchain.embeddings.base import Embeddings
from langchain_google_genai import ChatGoogleGenerativeAI
import numpy as np
import pandas as pd
import PIL
import vertexai
from vertexai.preview.language_models import TextEmbeddingModel

import google.ai.generativelanguage as glm
import google.api_core
from google.cloud import storage
from google.colab import userdata
import google.generativeai as genai

# Initialization

In [None]:
# @title Please update the following fields with your cloud project.
project_name = "your_gcp_project" #@param {type:"string"}
vertex_ai_zone = "us-central1" #@param {type:"string"}
#@markdown If you created a new .jsonl file of the Earth Engine catalog using the accompanying notebook, set the local or gcs path to that file here.
ee_catalog_jsonl_path = "gs://science-ai-ee-catalog-index/catalog_summaries.jsonl" #@param {type:"string"}
# Configure the google.generativeai client with the API key stored in the
# Colab secret GOOGLE_API_KEY.
genai.configure(api_key=userdata.get('GOOGLE_API_KEY'))


# Authenticate and initialize Earth Engine, Cloud Storage and Vertex AI using
# project_name.
ee.Authenticate()
ee.Initialize(project=project_name)
storage_client = storage.Client(project=project_name)
# Bucket holding the Earth Engine STAC catalog JSON files;
# get_dataset_description() reads dataset entries from it.
bucket = storage_client.get_bucket('earthengine-stac')
vertexai.init(project=project_name, location=vertex_ai_zone)

# Copy the catalog summaries .jsonl (from GCS or a local path) to
# ./catalog_summaries.jsonl, which is the path the embeddings cell loads.
if ee_catalog_jsonl_path.startswith('gs://'):
  # Split gs://<bucket>/<object> into the bucket name and the object name.
  # user_project bills the request to project_name (NOTE(review): presumably
  # required because the bucket is requester-pays — confirm).
  ee_client_bucket = google.cloud.storage.bucket.Bucket(storage_client, name=ee_catalog_jsonl_path.split('/')[2],user_project=project_name)
  blob = ee_client_bucket.blob('/'.join(ee_catalog_jsonl_path.split('/')[3:]))
  blob.download_to_filename('catalog_summaries.jsonl')
else:
  shutil.copyfile(ee_catalog_jsonl_path, 'catalog_summaries.jsonl')
# Score to aim for (on the 0-1 scale). The exact meaning of what "score" means
# is left to the LLM.
target_score = 0.8

# Count of analysis rounds (incremented by score_response()).
iteration = 1

# The interactive geemap map the agent draws on, with a layer manager control.
Map = geemap.Map()
Map.add("layer_manager")

# Chat model used by score_response(); presumably assigned later in the
# notebook — TODO confirm.
analysis_model = None
# Set by set_center(). While True, analyze_image() refuses to run because the
# client-side map bounds may not be updated yet.
map_dirty = False

# Gemini model used for the image analysis in _analyze_image().
# NOTE(review): 'gemini-1.5-pro-exp-0801' is an experimental model name that
# may no longer be served — confirm it is still available.
image_model = genai.GenerativeModel('gemini-1.5-pro-exp-0801')


# UI widget definitions

In [None]:
# Single-line text box where the user types requests for the agent.
command_input = widgets.Text(
    value='what datasets have information about active fires in California during the 2020 fire season?',
    description='❓',
    layout=widgets.Layout(width='100%', height='50px'),
)

# Shows the most recent layer command issued by the agent.
command_output = widgets.Label(value='Last command will be here')

status_label = widgets.Textarea(
    value='LLM response will be here',
    layout=widgets.Layout(width='50%', height='100px'),
)

widget_height = "400px"
# Scrollable log of tool calls, intermediate results and errors.
debug_output = widgets.Output(layout=widgets.Layout(
    border='1px solid black',
    height=widget_height,
    overflow='scroll',
    width='500px',
    padding='5px',
))
with debug_output:
  print('DEBUG COLUMN\n')

# Scrollable chat transcript.
chat_output = widgets.Output(layout=widgets.Layout(
    border='1px solid black',
    height='600px',
    overflow='scroll',
    width='300px',
))

with chat_output:
  print('CHAT COLUMN\n')


# Output area that holds the dataset search results table.
dataset_widget = widgets.Output(layout=widgets.Layout(overflow='scroll', width='600px'))

# Function to update the DataFrame display
def update_dataset_widget(df):
    """Replaces the contents of dataset_widget with df rendered as an HTML table.

    Cells are emitted unescaped (escape=False) so that columns which already
    contain HTML, such as image tags and links, render as markup. The embedded
    CSS targets the table through its 'dataframe-table' id.
    """
    with dataset_widget:
        # Drop whatever table was displayed before.
        dataset_widget.clear_output()
        table_markup = df.to_html(escape=False, index=False, table_id='dataframe-table')
        display(HTML(f"""
<style>
    #dataframe-table {{
        width: 100%;
        table-layout: auto;
        max-width: 100%;
        overflow-x: auto;
    }}
    #dataframe-table th:nth-child(2), #dataframe-table td:nth-child(2) {{
        max-width: 50px; /* Adjust as needed */
        white-space: normal; /* Allows text wrapping */
        word-wrap: break-word; /* Breaks long words */
    }}
    #dataframe-table a {{ /* Adds styling for hrefs */
            color: blue;
            text-decoration: underline;
        }}
</style>
{table_markup}
"""))


# Placeholder table shown before the first search.
update_dataset_widget(
    pd.DataFrame(
        [['Welcome', 'to', 'Science'],
         ['AI\'s', 'Earth', 'Engine'],
         ['dataset', 'search', 'agent']],
        columns=['Title', 'id', 'description'],
    )
)

In [None]:
# Code Viewer Widget
import pygments
from pygments.lexers import PythonLexer
from pygments.formatters import HtmlFormatter

# Refresh interval in seconds for the code viewer below.
refresh_interval = 3

# File that show_layer() writes the agent's latest layer code into; the code
# viewer widget displays its contents.
code_editor_file = "code_editor_file.py"
# (Re)create the file with a welcome message. Writing it from Python (instead of
# shelling out to `rm ...; echo ...`) avoids an `rm` error when the file does
# not exist yet and any dependence on shell quoting.
with open(code_editor_file, 'w', encoding='utf-8') as code_file:
  code_file.write(
      "# Welcome to the Science AI Dataset Search Agent Code Editor.\n"
      "# You can view the code being executed by the agent here\n")

# Function to read the file content
def read_file(file_path):
    """Returns the full text content of the file at file_path, decoded as UTF-8.

    The encoding is explicit so it matches the UTF-8 that show_layer() writes,
    independent of the platform's default locale encoding.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return file.read()

# Function to write the file content
def write_file(file_path, content):
    """Overwrites the file at file_path with content, encoded as UTF-8.

    The encoding is explicit so it matches read_file() and show_layer(),
    independent of the platform's default locale encoding.
    """
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(content)

# Function to highlight and display code
def highlight_code(code):
    """Renders Python source as a standalone, line-numbered HTML document."""
    lexer = PythonLexer()
    html_formatter = HtmlFormatter(style='default', full=True, linenos=True)
    return pygments.highlight(code, lexer, html_formatter)

# Define a global flag for stopping the thread
stop_event = threading.Event()

# Background loop that keeps the code viewer in sync with code_editor_file.
def refresh_code_editor():
    """Re-renders code_editor_file into the code_editor widget until stop_event is set.

    Polls every refresh_interval seconds and re-renders only when the file
    content changed. A read error (e.g. the file is briefly missing) skips that
    tick instead of terminating the thread.
    """
    last_code = None
    # Event.wait() returns True as soon as the event is set, so a stop request
    # ends the loop right away instead of after a full sleep interval.
    while not stop_event.wait(refresh_interval):
        try:
            code = read_file(code_editor_file)
        except OSError:
            continue
        if code != last_code:
            code_editor.value = highlight_code(code)  # Update editor with highlighted code
            last_code = code

# Create a widget to display highlighted code
# HTML widget showing the syntax-highlighted contents of code_editor_file;
# refresh_code_editor() keeps it up to date.
code_editor = widgets.HTML(
    layout=widgets.Layout(width='300px', height='300px'),
    value=highlight_code(read_file(code_editor_file)),
)

# Ensure that the previous thread is cleaned up
def cleanup_thread():
    """Stops and forgets a code-editor refresh thread left over from a previous run.

    Does nothing if no such thread is registered in the module globals.
    Otherwise it signals stop_event, waits for the thread to finish, removes
    the global reference and re-arms stop_event for the next thread.
    """
    thread_name = 'code_editor_refresh_thread'
    if thread_name not in globals():
        return
    stop_event.set()
    globals()[thread_name].join()
    globals().pop(thread_name)
    stop_event.clear()

# Clean up any existing thread before starting a new one. This is a Colab
# notebook, so without a cleanup, running the cell twice would cause multiple
# competing threads.
cleanup_thread()

# Start a new background thread to refresh the code editor. daemon=True so the
# thread does not keep the interpreter alive at shutdown.
code_editor_refresh_thread = threading.Thread(target=refresh_code_editor, daemon=True)
code_editor_refresh_thread.start()

# Simple functions that the LLM will call

In [None]:
import tempfile
import runpy
from bs4 import BeautifulSoup

def set_center(x: float, y: float, zoom: int) -> str:
    """Sets the map center to longitude x and latitude y at the given zoom
    level and returns instructions on what to do next.

    Args:
      x: Longitude of the new map center, in degrees.
      y: Latitude of the new map center, in degrees.
      zoom: Map zoom level.
    """
    # The docstring is the tool description Gemini sees; it names the axes
    # explicitly so the model does not pass (lat, lon).
    global map_dirty
    with debug_output:
      print(f"SET_CENTER({x}, {y}, {zoom})\n")
    Map.set_center(x, y)
    Map.zoom = zoom
    # The bounds update on the client side only after control returns to the
    # user, so analyze_image() refuses to run while the map is dirty.
    map_dirty = True
    return (
      'Do not call any more functions in this request to let geemap bounds '
      'update. Wait for user input.')

def add_image_layer(image_id: str) -> str:
    """Clears the map and adds the ee.Image with the given asset id as a layer.

    Returns 'success', or the error message if the image could not be added.
    """
    Map.clear()
    command_output.value = f"add_image_layer('{image_id}')"
    try:
      Map.addLayer(ee.Image(image_id))
    except Exception as e:  # Report the failure as the status, like show_layer().
      with debug_output:
        print(f"ERROR: {e}")
      return str(e)
    return 'success'

def _get_dataset_code_sample(catalog_url: str, timeout: float = 30):
  """Fetches the Javascript and/or Python code sample on a EE catalog page.

  Args:
    catalog_url: URL of the dataset's Earth Engine catalog page.
    timeout: Seconds to wait for the page before giving up (previously the
      request could block forever).

  Returns:
    The code samples, each preceded by a '--- <Language> Code ---' header, or
    "No code snippets found." if the page has none. If a language has several
    <pre> blocks, only the last one is kept.

  Raises:
    RuntimeError: If the page responds with a status code other than 200.
  """
  response = requests.get(catalog_url, timeout=timeout)
  if response.status_code != 200:
    raise RuntimeError(f"Failed to retrieve the page. Status code: {response.status_code}")

  soup = BeautifulSoup(response.content, 'lxml')

  # Collect the JavaScript and Python Earth Engine code, keyed by language.
  codes = {}
  for pre in soup.find_all('pre'):
    css_classes = pre.get('class', [])
    if 'lang-javascript' in css_classes:
      codes['JavaScript'] = pre.get_text(strip=True)
    elif 'lang-python' in css_classes:
      codes['Python'] = pre.get_text(strip=True)

  if not codes:
    return "No code snippets found."
  return "".join(f"--- {lang} Code ---\n{code}\n" for lang, code in codes.items())

def get_dataset_description(dataset_id: str) -> str:
  """Fetches JSON STAC description for the given Earth Engine dataset id.
  This function can be used to access all available metadata about an Earth
  Engine dataset, including examples of visualization parameters."""
  # The docstring above is the tool description Gemini sees. The function
  # returns the STAC entry as a JSON string with an added "code_sample" field,
  # or a 'dataset file not found: ...' message if there is no catalog file.
  # It raises ValueError if the entry has no Google Earth Engine provider link.
  with debug_output:
    print(f'LOOKING UP {dataset_id}\n')
  parent = dataset_id.split('/')[0]

  # GCS object names always use '/', so don't build them with os.path.join.
  path = f"catalog/{parent}/{dataset_id.replace('/', '_')}.json"
  blob = bucket.blob(path)

  if not blob.exists():
    return 'dataset file not found: ' + path

  # download_as_string is a deprecated alias of download_as_bytes; json.loads
  # decodes UTF-8 bytes directly.
  entry = json.loads(blob.download_as_bytes())

  catalog_url = next(
      (provider['url'] for provider in entry.get('providers', [])
       if provider.get('name') == 'Google Earth Engine'),
      None)
  if catalog_url is None:
    raise ValueError(f"No catalog link found for {dataset_id}")
  entry["code_sample"] = _get_dataset_code_sample(catalog_url)
  return json.dumps(entry)

def _get_image(image_url: str, timeout: float = 60) -> bytes:
  """Fetches from Earth Engine the content of the given URL as bytes.

  Args:
    image_url: URL to download (e.g. a map tile URL).
    timeout: Seconds to wait for the server before giving up.

  Raises:
    ValueError: If the request fails or the server responds with a status
      other than 200. The message includes the error text from a JSON
      {"error": {"message": ...}} body, if present. analyze_image() catches
      ValueError and reports it to the LLM.
  """
  try:
    response = requests.get(image_url, timeout=timeout)
  except requests.RequestException as e:
    # Keep the ValueError contract so callers can report the failure.
    raise ValueError(f"URL {image_url} causes {e}") from e

  if response.status_code == 200:
    return response.content

  error_message = f'Error downloading image: {response}'
  try:
    error_details = (
        json.loads(response.content.decode()).get('error', {}).get('message')
    )
  except (ValueError, AttributeError):
    # ValueError: undecodable bytes or invalid JSON. AttributeError: valid JSON
    # that is not shaped like {"error": {...}}.
    error_details = None
  if error_details:
    error_message += f' - {error_details}'
  with debug_output:
    print(error_message)
  raise ValueError(f"URL {image_url} causes {error_message}")

def show_layer(python_code: str) -> str:
    """Execute the given Earth Engine Python client code and add the result to
    the map. Returns the status message (success or error message)."""
    # Keep only the first two map layers, dropping any previously shown data
    # layers. (Presumably the first two are basemaps — TODO confirm.)
    Map.layers = Map.layers[:2]
    # Undo quote escaping added by the LLM. Repeat because one pass can expose
    # a new \" (for example, \\" becomes \").
    while '\\"' in python_code:
      python_code = python_code.replace('\\"', '"')
    command_output.value = f"show_layer('{python_code}')"
    with debug_output:
      print(f'IMAGE:\n {python_code}\n')
    try:
      # SECURITY NOTE: this executes model-generated code with full interpreter
      # privileges. Acceptable only in a personal notebook session; never expose
      # this tool to untrusted input.
      namespace = {}
      exec(f"import ee; im = {python_code}", {}, namespace)
      # Mirror the code into the code viewer, one chained call per line.
      with open(code_editor_file, 'w', encoding='utf-8') as f:
        f.write(python_code.replace(").", ")\n."))
      Map.addLayer(namespace['im'])
    except Exception as e:  # Any failure is reported back to the LLM to retry.
      with debug_output:
        print(f"ERROR: {e}")
      return str(e)
    return 'success'

def inner_monologue(thoughts: str) -> str:
  """Sends the current thinking of the LLM model to the user so that they are
  aware of what the model is thinking between function calls."""
  # The docstring is the tool description Gemini sees, so it stays verbatim.
  # The thoughts are echoed into the debug column.
  note = f'THOUGHTS:\n {thoughts}\n'
  with debug_output:
    print(note)
  return 'success'

# Functions for textual analysis of images

In [None]:
def _lat_lon_to_tile(lon, lat, zoom_level):
    # Convert latitude and longitude to Mercator coordinates
    x_merc = (lon + 180) / 360
    y_merc = (1 - math.log(math.tan(math.radians(lat)) + 1 / math.cos(math.radians(lat))) / math.pi) / 2

    # Calculate number of tiles
    n = 2 ** zoom_level

    # Convert to tile coordinates
    X = int(x_merc * n)
    Y = int(y_merc * n)

    return X, Y

def analyze_image(additional_instructions:str='') -> str:
    """Returns GenAI image analysis describing the current map image.
    Optional additional instructions might be passed to target the analysis
    more precisely.
    """
    # set_center() marks the map dirty because the client-side bounds update
    # only after control returns to the user; analyzing now would use stale
    # bounds. (map_dirty is only read here, so no `global` is needed.)
    if map_dirty:
        with debug_output:
          print('MAP DIRTY')
        return 'Map is not ready. Stop further processing and ask for user input'

    try:
      return _analyze_image(additional_instructions)
    except ValueError as e:
      # E.g. tile download failures; report them to the LLM as text.
      return str(e)

def _fetch_tile_mosaic(url_template: str, zoom: int, min_tile_x: int,
                       max_tile_x: int, min_tile_y: int, max_tile_y: int,
                       tile_size: int = 256):
    """Downloads an inclusive range of XYZ tiles and stitches them into one RGB PIL image.

    url_template must contain {x}, {y} and {z} placeholders. paste() converts
    tiles in other modes (e.g. RGBA) to the mosaic's RGB mode.

    Raises:
      ValueError: If a tile cannot be downloaded (raised by _get_image).
    """
    import PIL.Image  # `import PIL` alone does not load the Image submodule.

    mosaic = PIL.Image.new(
        "RGB",
        ((max_tile_x - min_tile_x + 1) * tile_size,
         (max_tile_y - min_tile_y + 1) * tile_size))
    for y in range(min_tile_y, max_tile_y + 1):
        for x in range(min_tile_x, max_tile_x + 1):
            tile_url = url_template.format(x=x, y=y, z=zoom)
            tile_img = PIL.Image.open(io.BytesIO(_get_image(tile_url)))
            mosaic.paste(tile_img, ((x - min_tile_x) * tile_size,
                                    (y - min_tile_y) * tile_size))
    return mosaic


def _analyze_image(additional_instructions:str='') -> str:
    """Describes the visible map area of the top data layer with Gemini.

    Downloads every tile of the last Earth Engine layer in Map.ee_layer_dict
    that intersects the current map bounds at the current zoom level, stitches
    them into one image and sends it to image_model with a fixed prompt.

    Args:
      additional_instructions: Optional text appended to the analysis prompt.

    Returns:
      The model's textual analysis, or a short status message when there is
      nothing to analyze (no layers, no valid tiles, uniform image) or the
      model response has no readable text.

    Raises:
      ValueError: If a tile cannot be downloaded.
    """
    bounds = Map.bounds
    south, west = bounds[0]
    north, east = bounds[1]
    zoom = int(Map.zoom)

    # Tile rows grow southwards: the south-west corner gives the last row and
    # the north-east corner the first row. _lat_lon_to_tile takes (lon, lat).
    min_tile_x, max_tile_y = _lat_lon_to_tile(west, south, zoom)
    max_tile_x, min_tile_y = _lat_lon_to_tile(east, north, zoom)
    last_tile = 2**zoom - 1
    min_tile_x = max(0, min_tile_x)
    max_tile_x = min(last_tile, max_tile_x)
    min_tile_y = max(0, min_tile_y)
    max_tile_y = min(last_tile, max_tile_y)

    with debug_output:
        if additional_instructions:
            print(f"RUNNING IMAGE ANALYSIS: {additional_instructions}...\n")
        else:
            print("RUNNING IMAGE ANALYSIS...\n")

    layers = list(Map.ee_layer_dict.values())
    if not layers:
        return 'No data layers loaded'
    if min_tile_x > max_tile_x or min_tile_y > max_tile_y:
        # Defensive: a mosaic with a non-positive size cannot be built.
        return 'The visible map area does not cover a valid range of tiles.'
    # Analyze the last entry of ee_layer_dict (presumably the most recently
    # added layer).
    url_template = layers[-1]['ee_layer'].url
    image = _fetch_tile_mosaic(url_template, zoom, min_tile_x, max_tile_x,
                               min_tile_y, max_tile_y)

    image_array = np.array(image)
    image_min = np.min(image_array)
    image_max = np.max(image_array)

    # Skip an LLM call when we can simply tell that something is wrong.
    # (Also, LLMs might hallucinate on uniform images.)
    if image_min == image_max:
        return (
            f'The image tile has a single uniform color with value '
            f'{image_min}.'
        )

    query = """You are an objective, precise overhead imagery analyst.
Describe what the provided map tile depicts in terms of:

1. The colors, textures, and patterns visible in the image.
2. The spatial distribution, shape, and extent of distinct features or regions.
3. Any notable contrasts, boundaries, or gradients between different areas.

Avoid making assumptions about the specific geographic location, time period,
or cause of the observed features. Focus solely on the literal contents of the
image itself. Clearly indicate which features look natural, which look human-made,
and which look like image artifacts. (Eg, a completely straight blue line
is unlikely to be a river.)

If the image is ambiguous or unclear, state so directly. Do not speculate or
hypothesize beyond what is directly visible.

Do not address a lack of text or captions in the image. These are provided elsewhere.

If the image is of mostly the same color (white, gray, or black) with little
contrast, just report that and do not describe the features.

Use clear, concise language. Avoid subjective interpretations or analogies.
Organize your response into structured paragraphs.
"""
    if additional_instructions:
        query += additional_instructions
    req = {
        'parts': [
            {'text': query},
            {'inline_data': image},
        ]
    }
    image_response = image_model.generate_content(req)
    try:
        analysis = image_response.text
    except ValueError as e:
        # .text raises ValueError when the response has no usable text (for
        # example, when it was blocked). Report it instead of dropping into the
        # debugger, which would hang the tool call.
        with debug_output:
            print(f'UNEXPECTED IMAGE RESPONSE: {e}')
            print(image_response)
        return f'Image analysis failed: {e}'
    with debug_output:
        print(f'ANALYSIS RESULT: {analysis}\n')
    return analysis

# Load EE Dataset Embeddings
This takes a few minutes the first time. Re-running the cell afterwards takes only about 10 seconds.

In [None]:
from langchain.embeddings.base import Embeddings
from langchain.indexes import VectorstoreIndexCreator
from langchain.indexes.vectorstore import VectorstoreIndexCreator
from langchain.schema import Document
from langchain_google_genai import ChatGoogleGenerativeAI

# LangChain chat model wrapper for Gemini; find_dataset() passes it to the
# vector store search.
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key=userdata.get('GOOGLE_API_KEY'))

def rate_limit(max_per_minute):
  """Generator that throttles a loop to at most max_per_minute iterations.

  Call next() on it once after each request. The first call only starts the
  generator. Each later call sleeps until at least 60 / max_per_minute seconds
  have passed since the previous call returned.

  Intervals are measured with time.monotonic(), which (unlike time.time())
  cannot jump when the system clock is adjusted.
  """
  period = 60 / max_per_minute
  while True:
    before = time.monotonic()
    yield
    sleep_time = period - (time.monotonic() - before)
    if sleep_time > 0:
      print(f'Sleeping {sleep_time:.1f} seconds')
      time.sleep(sleep_time)

class VertexEmbeddings(Embeddings):
  """LangChain Embeddings adapter for a Vertex AI text embedding model.

  Texts are sent to the model in small batches, throttled with rate_limit().
  """

  def __init__(self, model, *, requests_per_minute=15, batch_size=2):
    """
    Args:
      model: Object with a get_embeddings(list_of_texts) method, such as a
        vertexai TextEmbeddingModel.
      requests_per_minute: Maximum number of embedding requests per minute.
      batch_size: Texts sent per request. Defaults to 2 because the API
        apparently won't accept more than 2 documents per embedding request.
    """
    self.model = model
    self.requests_per_minute = requests_per_minute
    self.batch_size = batch_size

  def embed_documents(self, texts):
    """Returns the `.values` of each text's embedding, in input order."""
    limiter = rate_limit(self.requests_per_minute)
    docs = list(texts)
    results = []
    # Index-based batching avoids copying the remaining list on every
    # iteration (repeated docs = docs[2:] is quadratic in the number of texts).
    for start in range(0, len(docs), self.batch_size):
      results.extend(self.model.get_embeddings(docs[start:start + self.batch_size]))
      next(limiter)
    return [r.values for r in results]

  def embed_query(self, text):
    """Returns the embedding values for a single query string."""
    return self.embed_documents([text])[0]

  def embed_query(self, text):
    single_result = self.embed_documents([text])
    return single_result[0]

# Local copy of the catalog summaries written by the initialization cell.
jsonl_path = 'catalog_summaries.jsonl'
# Vertex AI text embedding model, wrapped for LangChain and throttled to
# 600 requests per minute.
embedding_model = TextEmbeddingModel.from_pretrained("google/text-embedding-004")
embedding = VertexEmbeddings(embedding_model, requests_per_minute=600)

def load_docs_from_jsonl(file_path):
  """Loads LangChain Documents from a JSON Lines file.

  Each non-blank line must be a JSON object whose keys are Document keyword
  arguments. Blank lines (e.g. a trailing empty line) are skipped instead of
  failing json.loads. The file is read as UTF-8.
  """
  with open(file_path, 'r', encoding='utf-8') as jsonl_file:
    return [Document(**json.loads(line)) for line in jsonl_file if line.strip()]

# Create the embeddings only if the index isn't already defined, so re-running
# this cell does not re-embed the whole catalog.
if 'index' not in globals():
  # `import dateutil` alone may not load the parser submodule used below.
  import dateutil.parser
  if not os.path.isfile(jsonl_path):
    raise ValueError(f"No catalog found at {jsonl_path}")
  print(f"Found existing jsonl index at {jsonl_path}")
  documents = load_docs_from_jsonl(jsonl_path)
  # Load the temporal bounds (date strings, or None for open-ended) as
  # datetime objects so find_dataset() can compare them.
  for document in documents:
    document.metadata['temporal'] = [
        dateutil.parser.parse(date_str) if date_str is not None else None
        for date_str in document.metadata['temporal']]
  index = VectorstoreIndexCreator(embedding=embedding).from_documents(documents)


## Dataset filtering functions

In [None]:
import shapely
import pandas as pd
import datetime
from typing import Optional, Union

def _temporal_bounds_intersect(tbox1, tbox2):
    start1, end1 = tbox1
    start2, end2 = tbox2

    # Convert dates to datetimes
    start_date_to_start_datetime = lambda start_date: datetime.datetime(start_date.year, start_date.month, start_date.day, 0, 0, 0) if isinstance(start_date, datetime.date) else start_date
    end_date_to_end_datetime = lambda end_date: start_date_to_start_datetime(end_date) + datetime.timedelta(days=1) if isinstance(end_date, datetime.date) else end_date

    start1 = start_date_to_start_datetime(start1).timestamp() if start1 is not None else float('-inf')
    start2 = start_date_to_start_datetime(start2).timestamp() if start2 is not None else float('-inf')
    end1 = end_date_to_end_datetime(end1).timestamp() if end1 is not None else float('inf')
    end2 = end_date_to_end_datetime(end2).timestamp() if end2 is not None else float('inf')

    return not (start1 > end2 or start2 > end1)

def _spatial_bounds_intersect(sbox1, sbox2):
  """Returns whether two bounding boxes, each unpacked into shapely.box(), intersect."""
  first_box = shapely.box(*sbox1)
  second_box = shapely.box(*sbox2)
  return first_box.intersects(second_box)

def _raise_error(e):
  """Utility to support error raising in python expressions"""
  raise e

def float_to_int_without_truncation(x):
  """Converts x to int, raising ValueError if that would drop a fractional part.

  Numeric function-call arguments can arrive as floats (e.g. 4.0); this
  accepts those but rejects values like 2.5.
  """
  as_int = int(x)
  if x == as_int:
    return as_int
  raise ValueError(f"{x} is not an integer")

from io import BytesIO
from PIL import Image
import base64


def _image_to_html(img_url, width=100, height=100, timeout=30):
    """Downloads an image and returns it as an inline base64 PNG <img> tag.

    Args:
      img_url: URL of the image. Any format PIL can read is re-encoded as PNG.
      width, height: Display size written into the <img> tag.
      timeout: Seconds to wait for the download before giving up.

    Raises:
      requests.HTTPError: If the server returns an HTTP error status. Without
        this check, the error page would reach PIL and fail with a confusing
        "cannot identify image file" error.
    """
    response = requests.get(img_url, timeout=timeout)
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))

    # Re-encode as PNG and embed it as a base64 data URI.
    buffer = BytesIO()
    img.save(buffer, format='PNG')
    img_str = base64.b64encode(buffer.getvalue()).decode('ascii')

    return f'<img src="data:image/png;base64,{img_str}" width="{width}" height="{height}">'

def _build_dataset_widget(id: str):
  """Builds one dataset-table row (a pd.Series) for the given Earth Engine dataset id.

  The row has these columns:
    - 'Image': inline preview <img> tag;
    - 'ID': link to the dataset's Earth Engine catalog page;
    - 'Title', 'Temporal Span' and 'Keywords'.
  The table is rendered as HTML without escaping (update_dataset_widget uses
  escape=False), so the plain-text catalog fields are HTML-escaped here.

  Raises:
    ValueError: If the dataset description cannot be loaded, or if it has no
      preview image or no Earth Engine catalog link.
  """
  import html

  description = get_dataset_description(id)
  try:
    entry = json.loads(description)
  except json.JSONDecodeError:
    # get_dataset_description returns a plain-text message when the dataset
    # file is missing.
    raise ValueError(description) from None

  # Any preview format works: _image_to_html re-encodes the image as PNG.
  image_url = next(
      (link['href'] for link in entry.get('links', []) if link.get('rel') == 'preview'),
      None)
  if image_url is None:
    raise ValueError(f"No preview image found for {id}")

  catalog_url = next(
      (provider['url'] for provider in entry.get('providers', [])
       if provider.get('name') == 'Google Earth Engine'),
      None)
  if catalog_url is None:
    raise ValueError(f"No catalog link found for {id}")

  id_table_entry = (
      f'<a href="{html.escape(catalog_url)}" target="_blank">{html.escape(id)}</a>')
  return pd.Series({
      'Image': _image_to_html(image_url),
      "ID": id_table_entry,
      "Title": html.escape(entry['title']),
      "Temporal Span": html.escape(str(entry['extent']['temporal']['interval'])),
      "Keywords": html.escape(", ".join(entry.get('keywords', []))),
  })

def build_dataset_widget(ids: list[str]):
  return pd.DataFrame([_build_dataset_widget(id) for id in ids])

def assert_condition(condition, exp, message=None):
  """Returns exp if condition is truthy, otherwise raises AssertionError.

  Lets a check be embedded in an expression (e.g. a lambda). Unlike an
  `assert` statement, the check is not stripped when Python runs with -O.
  Note that exp is evaluated by the caller before the check runs.

  Args:
    condition: Value that must be truthy.
    exp: Value to return when the check passes.
    message: Optional error message. Defaults to "Assertion failed, <condition>".
  """
  if not condition:
    raise AssertionError(message or f"Assertion failed, {condition}")
  return exp

def find_dataset(query: str, results: int = 4, threshold: float = 0.7, spatial: Optional[list[float]]=None, temporal: Optional[list[Optional[list[int]]]]=None) -> list[list[str]]:
  """
  Retrieve relevant datasets from the Earth Engine data catalog.

  query: str. The kind of data being searched for, e.g. 'population' or 'Cheese Futures'.
  results: int. The maximum number of datasets to return. 4 is recommended.
  threshold: float. The maximum search score a dataset may have to be
    returned; lower scores are treated as better matches. Recommended 0.7.
  spatial: Optional[list[float]]. The spatial bounding box for the query, in the
    format [lon1, lat1, lon2, lat2]. If None then no spatial filter is applied.
  temporal: Optional[list[Optional[list[int]]]]. If provided, temporal
    constraints are provided as a list of two int lists following the structure
    [[year, month, day], [year, month, day]]. A None can be used to set no
    start or end date. For example [None, [2022,12,31]] will return all datasets
    that have data before 2022-12-31.

  Returns a list of [title, id, description] entries in search-result order,
  and shows the matching datasets in the dataset table.
  """
  # Function-calling arguments may arrive as floats (e.g. 4.0); convert them
  # to ints, rejecting values with a fractional part.
  results = float_to_int_without_truncation(results)

  def _to_date(ymd):
    if ymd is None:
      return None
    return datetime.date(*map(float_to_int_without_truncation, ymd))

  if temporal is not None:
    temporal = [_to_date(temporal[0]), _to_date(temporal[1])]
    if temporal[0] is not None and temporal[1] is not None and temporal[0] > temporal[1]:
      raise ValueError(f"Temporal bounds must be in chronological order: {temporal}")

  # Score every catalog entry, then apply the filters in order.
  scored_docs = index.vectorstore.similarity_search_with_score(query, llm=llm, k=len(documents))
  matches = []
  for doc, score in scored_docs:
    metadata = doc.metadata
    # Relevance threshold. Written as `not <=` so a NaN score is dropped.
    if not score <= threshold:
      continue
    # Time filter.
    if temporal is not None and not _temporal_bounds_intersect(metadata['temporal'], temporal):
      continue
    # Remove community catalog entries.
    if metadata['id'].startswith("projects/"):
      continue
    # Space filter; each entry is expected to have exactly one bounding box.
    if spatial is not None and not assert_condition(
        len(metadata['spatial']) == 1,
        _spatial_bounds_intersect(metadata['spatial'][0], spatial)):
      continue
    matches.append(doc)

  outputs = [[doc.metadata['title'], doc.metadata['id'], doc.page_content]
             for doc in matches[:results]]
  update_dataset_widget(build_dataset_widget([output[1] for output in outputs]))
  return outputs

# Function for scoring how well image analysis corresponds to the user query.

In [None]:
# Note that we ask for the score outside of the main agent chat to keep
# the scoring more objective.

# System prompt for the scoring model (presumably used to create
# analysis_model later in the notebook — TODO confirm).
scoring_system_prompt = """
After looking at the user query and the map tile analysis, start
your answer with a number between 0 and 1 indicating how relevant
the image is as an answer to the query. (0=irrelevant, 1=perfect answer)

Make sure you have enough justification to definitively declare the analysis
relevant - it's better to give a false negative than a false positive. However,
if the image analysis identifies specific matching landmarks (eg, the
outlines of Manhattan island for a request to show NYC), believe it.

Do not assume too much (eg, the presence of green doesn't by itself mean the
image shows forest); attempt to find multiple (at least three) independent
lines of evidence before declaring victory and cite all these lines of evidence
in your response.

Be very, very skeptical - look for specific features that match only the query
and nothing else (eg, if the query looks for a river, a completely straight blue
line is unlikely to be a river). Think about what size the features are based on
the zoom level and whether this size matches the feature size expected from
first principles.

If there is ambiguity or uncertainty, express it in your analysis and
lower the score accordingly. If the image analysis is inconclusive, try zooming
out to make sure you are looking at the right spot. Do not reduce the score if
the analysis does not mention visualization parameters - they are just given for
your reference. The image might show an area slightly larger than requested -
this is okay, do not reduce the score on this account.
"""

def score_response(query: str, visualization_parameters: str, analysis: str) -> str:
    """Returns how well the given analysis describes a map tile returned for
    the given query. The analysis starts with a number between 0 and 1.

    Arguments:
      query: user-specified query
      visualization_parameters: description of the bands used and visualization
        parameters applied to the map tile
      analysis: the textual description of the map tile
    """
    global iteration
    with debug_output:
      print(f"VIZ PARAMS: {visualization_parameters}\n")
    question = (
        f'For user query {query} please score the following analysis:\n'
        f'{analysis}. The answer must start with a number between 0 and 1.')
    if visualization_parameters:
      # Start a new line so this sentence is not glued onto "...0 and 1.".
      question += (
          '\nDo not assume that common bands or visualization parameters '
          'should have been used, as the visualization used the following '
          f'parameters: {visualization_parameters}')

    result = analysis_model.ask(question)
    with debug_output:
      print(f'SCORE #{iteration}:\n {result}\n')
    iteration += 1
    return result

# Main prompt for the agent

In [None]:
# Main system prompt for the agent; target_score is interpolated when this
# cell runs.
system_prompt = f"""
The client is running in a Python notebook with a geemap Map displayed. The
client also has a code editor that's initialized and authenticated with
earthengine. When given a task or a question, start by making a plan of how you
will approach solving it. Important: Only use the map if the user stated that
they want to see or visualize data.

When composing Python code for the map, do not use getMapId - just return the
single-line layer definition like 'ee.Image("USGS/SRTMGL1_003")' that we will
pass to Map.addLayer(). Do not escape quotation marks in Python code.

Be sure to use Python, not Javascript, syntax for keyword parameters in
Python code (that is, "function(arg=value)"). Using the provided functions,
respond to the user command following below (or respond why it's not possible).
If you get an Earth Engine error, attempt to fix it and then try again.

Before you choose a dataset, think about what kind of dataset would be most
suitable for the query. When using the map also think about what zoom level
would be suitable for the query, keeping in mind that for high-resolution
image collections higher zoom levels are better to speed up tile loading. You
can search for available datasets using the find_dataset function. For example
to find datasets that have information about burn scars between 2020 and 2022,
you can use find_dataset('Burn Scars', temporal=[[2020,1,1], [2022,12,31]]).

Once you have chosen a dataset, read its description using the provided function
to see what spatial and temporal range it covers, what bands it has, as well as
to find the recommended visualization parameters. Explain using the
inner_monologue function why you chose a specific dataset, zoom level and map
location.

Use Landsat Collection 2, not Landsat Collection 1 ids. If you are getting
repeated errors when filtering by a time range, read the dataset description
to confirm that the dataset has data for the selected range.

When visualizing on the map, prefer mosaicking image collections using the
mosaic() function, don't get individual images from collections via 'first()'.
Choose a tile size and zoom level that will ensure the tile has enough pixels
in it to avoid graininess, but not so many that processing becomes very
expensive. Do not use wide date ranges with collections that have many images,
but remember that Landsat and Sentinel-2 have a revisit period of several days.
Do not use sample locations - try to come up with actual locations that are
relevant to the request.

IF YOU ARE ASKED TO PROVIDE INFORMATION ABOUT A DATASET, DO NOT MAP THE
DATASET UNLESS THE USER REQUESTS THAT YOU MAP THE DATASET.

Important: after using the set_center() function, just say that you have called
this function and wait for the user to hit enter, after which you should
continue answering the original request. This will make sure the map is updated
on the client side.

Once the map is updated and the user told you to proceed, call the
analyze_image() function to describe the image for the same location that will
be shown in geemap. If you pass additional instructions to analyze_image(), do
not disclose what the image is supposed to be to discourage hallucinations - you
can only tell the analysis function to pay attention to specific areas (eg,
center or top left) or shapes (eg, a line at the bottom) in the image. You can
also tell the analysis function about the chosen bands, color palette and
min/max visualization parameters, if any, to help it interpret the colors
correctly. If the image turns out to be uniform in color with no features,
use min/max visualization parameters to enhance contrast.

Frequently call the inner_monologue() function to tell the user about your
current thought process. This is a good time to reflect if you have been running
into repeated errors of the same kind, and if so, to try a different approach.

When you are done, call the score_response() function to evaluate the analysis.
You can also tell the scoring function about the chosen bands, color palette
and min/max visualization parameters, if any. If the analysis score is below
{target_score}, keep trying to find and show a better image. You might have to
change the dataset, map location, zoom level, date range, bands, or other
parameters - think about what went wrong in the previous attempt and make the
change that's most likely to improve the score.
"""

# Class for LLM chat with function calling

In [None]:
import google.generativeai as genai

# Python callables exposed to the agent model for function calling; this list
# is passed as `tools` when the main `Gemini` model is constructed below.
gemini_tools = [
    set_center,
    show_layer,
    analyze_image,
    inner_monologue,
    get_dataset_description,
    score_response,
    find_dataset,
]

class Gemini():
  """Gemini LLM wrapper with retry handling.

  Holds a `genai.GenerativeModel` plus a chat session whose history is seeded
  with a system prompt. `ask()` sends one-off requests straight to the model;
  `chat()` sends messages through the chat session. Both retry on RECITATION
  stop-candidate errors (appending a rephrase hint to the prompt), on
  rate-limit / deadline errors, and on responses whose text cannot be read,
  updating the `command_input` status icon while they wait.
  """

  # Suffix appended to the prompt after a RECITATION stop. The leading space
  # keeps it from running into the last word of the question.
  _RECITATION_HINT = (
      ' Previous attempt returned a RECITATION error. '
      'Rephrase the answer to avoid it.')
  # Seconds to wait after a rate-limit / deadline error before retrying.
  _RETRY_SLEEP_SECONDS = 10

  def __init__(self, system_prompt, tools=None):
    """Creates the model and a chat session seeded with `system_prompt`.

    Args:
      system_prompt: Text placed as the first ('model' role) turn of the chat
        history used by `chat()`. `ask()` does not send it.
      tools: Optional list of callables exposed for function calling.
    """
    if not tools:
      tools = []
    self._text_model = genai.GenerativeModel(
        model_name='gemini-1.5-pro-latest',
        tools=tools
    )

    # `import google.generativeai as genai` binds only `genai`, not the
    # top-level `google` package, so resolve the retryable API errors here
    # rather than relying on a global `google` name being bound.
    import google.api_core.exceptions as api_exceptions
    self._retryable_errors = (
        api_exceptions.TooManyRequests,
        api_exceptions.DeadlineExceeded,
    )

    initial_messages = glm.Content(
        role='model',
        parts=[glm.Part(text=system_prompt)])
    self._chat_proxy = self._text_model.start_chat(
        history=initial_messages, enable_automatic_function_calling=True)

  def _handle_stop_candidate(self, error, condition):
    """Flashes the recitation icon and returns the prompt suffix to use next.

    Args:
      error: The StopCandidateException raised by the API.
      condition: The suffix currently appended to the prompt.

    Returns:
      The RECITATION hint if the candidate stopped for recitation, otherwise
      `condition` unchanged.
    """
    if glm.Candidate.FinishReason.RECITATION == error.args[0].finish_reason:
      condition = self._RECITATION_HINT
    command_input.description = '🆁'
    time.sleep(1)
    command_input.description = '🤔'
    return condition

  def _wait_for_retry(self):
    """Shows the sleeping icon and waits before retrying a transient error."""
    command_input.description = '💤'
    time.sleep(self._RETRY_SLEEP_SECONDS)
    command_input.description = '🤔'

  def ask(self, question, temperature=0):
    """Sends a single stateless request and returns the response text.

    The request goes directly to the model, outside the chat session, and is
    retried until text can be read from the response.

    Args:
      question: Prompt text.
      temperature: Sampling temperature passed in the generation config.

    Returns:
      The response text.
    """
    # NOTE(review): this bypasses the chat history, so the system prompt given
    # to __init__ is not sent with `ask()` requests — confirm that is intended.
    # `condition` lives outside the loop so a RECITATION hint set by one
    # attempt is actually sent on the next one.
    condition = ''
    while True:
      response = None
      try:
        response = self._text_model.generate_content(
            question + condition,
            generation_config={'temperature': temperature})
        return response.text
      except genai.types.generation_types.StopCandidateException as e:
        condition = self._handle_stop_candidate(e, condition)
      except self._retryable_errors:
        self._wait_for_retry()
      except ValueError as e:
        # Log and retry instead of dropping into the debugger, which would
        # block the notebook UI. The short pause avoids hammering the API.
        with debug_output:
          print(f'Response {response} led to error: {e}')
        time.sleep(1)

  def chat(self, question: str, temperature=0) -> str:
    """Adds a question to the ongoing chat session and returns the reply text.

    Args:
      question: User message to send.
      temperature: Sampling temperature for the reply.

    Returns:
      The reply text.

    Raises:
      ValueError: If the model returns a response with no parts.
    """
    # Always delay a bit to reduce the chance for rate-limiting errors.
    time.sleep(1)
    condition = ''
    while True:
      response = ''
      try:
        response = self._chat_proxy.send_message(
            question + condition,
            generation_config={
                'temperature': temperature,
                # Use a generous but limited output size to encourage in-depth
                # replies.
                'max_output_tokens': 5000,
            }
        )
      except genai.types.generation_types.StopCandidateException as e:
        condition = self._handle_stop_candidate(e, condition)
        continue
      except self._retryable_errors:
        self._wait_for_retry()
        continue
      if not response.parts:
        reason = (response.candidates[0].finish_reason.name
                  if response.candidates else 'UNKNOWN')
        raise ValueError(
            f'Cannot get analysis with reason {reason}, terminating')
      try:
        return response.text
      except ValueError as e:
        # Text could not be extracted; log it and resend the question.
        with debug_output:
          print(f'Response {response} led to the error {e}')
        time.sleep(1)

# Agent model: chat session seeded with the agent system prompt and tools.
model = Gemini(system_prompt, tools=gemini_tools)
# Scoring model: seeded with the scoring prompt, no tools.
analysis_model = Gemini(scoring_system_prompt)

# UI functions

In [None]:
def _run_js(js_code):
    """Displays `js_code` wrapped in a <script> tag so the browser runs it."""
    display(HTML(f"<script>{js_code}</script>"))

def set_cursor_waiting():
    """Switches the page body's cursor to the busy ('wait') cursor."""
    _run_js("""
    document.querySelector('body').style.cursor = 'wait';
    """)

def set_cursor_default():
    """Restores the page body's cursor to the default cursor."""
    _run_js("""
    document.querySelector('body').style.cursor = 'default';
    """)

def on_submit(widget):
    """Sends the text in `widget` to the chat model and prints the reply.

    An empty submission is sent as 'go on'; only other commands are echoed to
    the debug log. The status icon ends as 🙏 if `map_dirty` was set during
    the model call, otherwise ❓.

    Args:
      widget: The Text widget that fired the submit event.
    """
    global map_dirty
    # Reset before the call; presumably set to True elsewhere (e.g. by
    # set_center()) while the model runs — verify against its definition.
    map_dirty = False
    command = widget.value or 'go on'
    with chat_output:
        print('> ' + command + '\n')
    if command != 'go on':
        with debug_output:
            print('> ' + command + '\n')
    widget.value = ''
    set_cursor_waiting()
    command_input.description = '🤔'
    try:
        response = model.chat(command, temperature=0)
    except Exception as e:  # Top-level UI boundary: surface, don't hide.
        # Exceptions raised in widget callbacks otherwise go to a hidden log.
        with chat_output:
            print(f'Error: {e}\n')
        with debug_output:
            traceback.print_exc()
        response = ''
    finally:
        # Always restore the UI state, even if the model call failed.
        command_input.description = '🙏' if map_dirty else '❓'
        set_cursor_default()
    with chat_output:
        print(response.strip() + '\n')
    command_input.value = ''

command_input.on_submit(on_submit)

# UI layout

In [None]:
def build_frontend(show_debug_column=False):
  """Builds the agent UI as a single VBox.

  Rows, top to bottom:
    * chat transcript, generated-code editor and its label;
    * the command input box;
    * the map and the dataset panel;
    * optionally, the debug output.

  Args:
    show_debug_column: If True, append `debug_output` as a final row.

  Returns:
    A `widgets.VBox` containing the whole UI.
  """
  # Fixed widths for the chat panel (left control) and the map.
  chat_output.layout = widgets.Layout(width='400px')
  Map.layout = widgets.Layout(width='600px')

  # NOTE(review): this label renders to the right of the code editor in the
  # HBox below; if it was meant as a heading, stack it above the editor.
  code_viewer_label = widgets.Label("Code Generated by Agent")

  top_row = widgets.HBox([chat_output, code_editor, code_viewer_label])
  bottom_row = widgets.HBox([Map, dataset_widget])

  ui_widgets = [top_row, command_input, bottom_row]
  if show_debug_column:
    ui_widgets.append(debug_output)
  return widgets.VBox(ui_widgets)

# Run

In [None]:
# Build the agent UI (debug output hidden) and render it in the notebook.
ui = build_frontend(show_debug_column=False)
display(ui)
# Status icons shown on the command box:
#   ❓ = waiting for user input
#   🙏 = waiting for user to hit enter after calling set_center()
#   🤔 = thinking
#   💤 = sleeping due to retries
#   🆁 = Gemini recitation error