[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/earthengine-community/blob/master/experimental/scienceai_ee_dataset_explorer/scienceai_ee_dataset_explorer_embeddings_generation_v0.ipynb)

In [None]:
#@title Copyright 2024 The Earth Engine Community Authors { display-mode: "form" }
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Earth Engine Dataset Explorer - Embeddings Generation

## Overview
This notebook, built by the Science AI team in Google Research, supplements the main [Earth Engine Dataset Explorer](https://github.com/google/earthengine-community/blob/master/experimental/scienceai_ee_dataset_explorer/scienceai_ee_dataset_explorer_v0.ipynb) notebook. It demonstrates how the dataset summaries and [embeddings](https://developers.google.com/machine-learning/crash-course/embeddings) used by that notebook were generated.
For more details on the project as a whole, see the main notebook, or the [README](https://github.com/google/earthengine-community/tree/master/experimental/scienceai_ee_dataset_explorer).

The notebook uses:

 - The [Gemini 1.5 Pro language model](https://blog.google/technology/ai/gemini-1-5/) to create concise summaries of Earth Engine dataset descriptions.
 - [Google Text Embedding API](https://cloud.google.com/natural-language/docs/embedding-overview) to generate vector representations of these summaries. These embeddings are then uploaded to a Google Cloud Storage bucket for use in downstream applications, such as the [Earth Engine Dataset Explorer](https://github.com/google/earthengine-community/blob/master/experimental/scienceai_ee_dataset_explorer/scienceai_ee_dataset_explorer_v0.ipynb).
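
This notebook only generates the embeddings; how a downstream application searches them is not shown here. As a rough illustration, and assuming cosine similarity as the ranking measure (a common choice, not necessarily what the explorer uses), matching a query vector against stored embeddings could look like the sketch below. The record ids and three-dimensional vectors are hypothetical toy values.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_by_similarity(query_vec, records):
    """Returns record ids ordered from most to least similar to the query."""
    scored = [
        (cosine_similarity(query_vec, r["embedding"]), r["id"]) for r in records
    ]
    return [record_id for _, record_id in sorted(scored, reverse=True)]


# Hypothetical toy embeddings; real ones have many more dimensions.
toy_records = [
    {"id": "elevation_dataset", "embedding": [0.9, 0.1, 0.0]},
    {"id": "population_dataset", "embedding": [0.0, 0.2, 0.9]},
]
toy_query = [1.0, 0.0, 0.1]
ranking = rank_by_similarity(toy_query, toy_records)
```

The query vector would come from embedding the user's search text with the same model that embedded the summaries.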


## Setup Details and Billing

You will need:

- A Google Cloud project with the Earth Engine API enabled. ([Details](https://developers.google.com/earth-engine/cloud/earthengine_cloud_project_setup)).
- A Gemini API key. ([Details](https://ai.google.dev/gemini-api/docs/api-key)).
- (Optionally) A predefined Google Cloud Storage (GCS) bucket. ([Details](https://cloud.google.com/storage/docs/buckets)).

Each of the above can be stored in the [Colab "Secrets" panel](https://medium.com/@parthdasawant/how-to-use-secrets-in-google-colab-450c38e3ec75). Add the following secrets:

 - Use `GOOGLE_PROJECT_ID` for the Cloud project id.
 - Use `GOOGLE_API_KEY` for the Gemini API key.
 - Use `DESTINATION_BUCKET` for the GCS bucket where you want to upload embeddings.
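
If you want the same code to run outside Colab, one option is a small lookup helper that falls back to environment variables. This is a sketch only: `get_secret` is a hypothetical helper, not part of this notebook, and the environment-variable fallback is an assumption about how you might configure a non-Colab runtime.

```python
import os


def get_secret(name, default=None):
    """Reads a secret from Colab if available, else from the environment."""
    try:
        from google.colab import userdata  # Only importable inside Colab.
        value = userdata.get(name)
        if value:
            return value
    except Exception:
        # Not running in Colab, or the secret is missing or not shared
        # with this notebook.
        pass
    return os.environ.get(name, default)


# Example: fall back to an environment variable when not in Colab.
os.environ["EXAMPLE_SECRET_NAME"] = "example-value"
example_value = get_secret("EXAMPLE_SECRET_NAME")
```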

## Caveats

 - This is an early prototype; bugs and unexpected behavior are likely. Code improvements and refactors will follow.

 - The notebook currently uses LangChain for some of the dataset summarization "glue", but this will likely change in a future version.

 - The notebook makes very light use of the Vertex AI text embedding API, which requires billing to be enabled in your Cloud project. The expense should be minimal. ([Details](https://cloud.google.com/vertex-ai/generative-ai/pricing)).

 - For assistance, please email scienceai_ee_dataset_explorer@googlegroups.com.

In [None]:
%%capture
#@title Install Python Libraries
!pip install google_cloud_aiplatform langchain-community langchain_google_genai langchain iso8601

In [None]:
#@title Imports
# Standard library imports
import dataclasses
import datetime
import json
import logging
import os
import re
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence

# Third-party imports
import iso8601
import pandas as pd
import tenacity
import tqdm
import vertexai
from IPython.display import HTML, display, clear_output
from google.api_core import exceptions as google_exceptions
from google.cloud import storage
from google.colab import userdata
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.language_models.base import BaseLanguageModel
from langchain_google_genai import ChatGoogleGenerativeAI
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type, wait_fixed
from vertexai.preview.language_models import TextEmbeddingModel

# Specific exception imports
from google.api_core.exceptions import ResourceExhausted

In [None]:
#@title Setup
from google.colab import auth

# Authenticate first so the clients created below pick up the user's credentials.
auth.authenticate_user()

project_name = userdata.get('GOOGLE_PROJECT_ID')
vertex_ai_zone = "us-central1"

storage_client = storage.Client(project=project_name)
vertexai.init(project=project_name, location=vertex_ai_zone)

# Define classes for working with the Earth Engine data catalog

These will soon be broken up into their own files.

In [None]:
#@title Helper methods
def matches_interval(
    collection_interval: tuple[datetime.datetime, Optional[datetime.datetime]],
    query_interval: tuple[datetime.datetime, datetime.datetime],
) -> bool:
  """Checks if the collection's datetime interval overlaps the query interval.

  Args:
    collection_interval: Temporal interval of the collection. The end may be
      None for an open-ended collection.
    query_interval: A tuple with the query interval start and end.

  Returns:
    True if the two intervals overlap.
  """
  start_query, end_query = query_interval
  start_collection, end_collection = collection_interval
  if end_collection is None:
    # End date should always be set in STAC JSON files, but just in case...
    end_collection = datetime.datetime.now(tz=datetime.timezone.utc)
  return end_query > start_collection and start_query <= end_collection



def matches_datetime(
    collection_interval: tuple[datetime.datetime, Optional[datetime.datetime]],
    query_datetime: datetime.datetime,
):
  """Checks if the collection's datetime interval matches the query datetime.

  Args:
    collection_interval: Temporal interval of the collection.
    query_datetime: a datetime coming from a query

  Returns:
    True if the datetime interval matches
  """
  if collection_interval[1] is None:
    # End date should always be set in STAC JSON files, but just in case...
    end_date = datetime.datetime.now(tz=datetime.timezone.utc)
  else:
    end_date = collection_interval[1]
  return collection_interval[0] <= query_datetime <= end_date
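
As a quick illustration of the overlap test used by `matches_interval`, the standalone sketch below applies the same comparison to two query windows. `intervals_overlap` is a hypothetical stand-in written for this example, and the dates are illustrative. All datetimes are timezone-aware, because Python raises a `TypeError` when a naive datetime is compared with an aware one.

```python
import datetime

UTC = datetime.timezone.utc


def intervals_overlap(start_a, end_a, start_b, end_b):
    """Same comparison as matches_interval: a is the collection, b the query."""
    return end_b > start_a and start_b <= end_a


collection_start = datetime.datetime(2000, 2, 11, tzinfo=UTC)
collection_end = datetime.datetime(2000, 2, 22, tzinfo=UTC)

# A query window that starts inside the collection's interval.
overlapping = intervals_overlap(
    collection_start, collection_end,
    datetime.datetime(2000, 2, 15, tzinfo=UTC),
    datetime.datetime(2000, 3, 1, tzinfo=UTC),
)

# A query window that ends before the collection starts.
disjoint = intervals_overlap(
    collection_start, collection_end,
    datetime.datetime(1999, 1, 1, tzinfo=UTC),
    datetime.datetime(1999, 12, 31, tzinfo=UTC),
)
```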

In [None]:
# @title class BBox()
@dataclasses.dataclass
class BBox:
  """Class representing a lat/lon bounding box."""
  west: float
  south: float
  east: float
  north: float

  def is_global(self) -> bool:
    return (
        self.west == -180 and self.south == -90 and
        self.east == 180 and self.north == 90)

  @classmethod
  def from_list(cls, bbox_list: list[float]):
    """Constructs a BBox from a list of four numbers [west,south,east,north]."""
    if bbox_list[0] > bbox_list[2]:
      raise ValueError(
          'The smaller (west) coordinate must be listed first in a bounding box'
          f' corner list. Found {bbox_list}'
      )
    if bbox_list[1] > bbox_list[3]:
      raise ValueError(
          'The smaller (south) coordinate must be listed first in a bounding'
          f' box corner list. Found {bbox_list}'
      )
    return cls(bbox_list[0], bbox_list[1], bbox_list[2], bbox_list[3])

  def to_list(self) -> list[float]:
    return [self.west, self.south, self.east, self.north]

  def intersects(self, query_bbox) -> bool:
    """Checks if this bbox intersects with the query bbox.

    Doesn't handle bboxes extending past the antimeridian.

    Args:
      query_bbox: Bounding box from the query.

    Returns:
      True if the two bounding boxes intersect
    """
    return (
        query_bbox.west < self.east
        and query_bbox.east > self.west
        and query_bbox.south < self.north
        and query_bbox.north > self.south
    )
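
The intersection test in `BBox.intersects` uses strict inequalities, so boxes that merely share an edge do not count as intersecting. The sketch below restates the same test on plain `[west, south, east, north]` lists; `boxes_intersect` is a hypothetical helper and the extents are rough, illustrative values.

```python
def boxes_intersect(a, b):
    """a and b are [west, south, east, north] lists in degrees."""
    a_west, a_south, a_east, a_north = a
    b_west, b_south, b_east, b_north = b
    return (
        b_west < a_east
        and b_east > a_west
        and b_south < a_north
        and b_north > a_south
    )


europe = [-10.0, 35.0, 30.0, 70.0]
alps = [5.0, 43.0, 16.0, 48.0]         # Lies entirely inside `europe`.
touching = [30.0, 35.0, 40.0, 70.0]    # Shares only the 30 degree east edge.

inside_result = boxes_intersect(europe, alps)
edge_result = boxes_intersect(europe, touching)
```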

In [None]:
# @title class Collection()
class Collection:
  """A simple wrapper for a STAC Collection."""
  stac_json: dict[str, Any]

  def __init__(self, stac_json: dict[str, Any]):
    self.stac_json = stac_json
    if stac_json.get('gee:status') == 'deprecated':
      # Set the STAC 'deprecated' field that we don't set in the jsonnet files
      stac_json['deprecated'] = True

  def __getitem__(self, item: str) -> Any:
    return self.stac_json[item]

  def get(self, item: str, default: Optional[Any] = None) -> Optional[Any]:
    """Matches dict's get by returning None if there is no item."""
    return self.stac_json.get(item, default)

  def public_id(self) -> str:
    return self['id']

  def sanitize_id(self) -> str:
    return self['id'].replace('/', '_')

  def get_dataset_type(self) -> str:
    """Could be Image, ImageCollection, FeatureCollection, Feature."""
    return self['gee:type']

  def is_deprecated(self) -> bool:
    """Returns True for collections that are marked as deprecated."""
    if self.get('deprecated', False):
      logging.info('Skipping deprecated collection: %s', self.public_id())
      return True
    return False

  def datetime_interval(
      self,
  ) -> Iterable[tuple[datetime.datetime, Optional[datetime.datetime]]]:
    """Returns datetime objects representing temporal extents."""
    for stac_interval in self.stac_json['extent']['temporal']['interval']:
      if not stac_interval[0]:
        raise ValueError(
            'Expected a non-empty temporal interval start for '
            + self.public_id()
        )
      start_date = iso8601.parse_date(stac_interval[0])
      if stac_interval[1] is not None:
        end_date = iso8601.parse_date(stac_interval[1])
      else:
        end_date = None
      yield (start_date, end_date)

  def start(self) -> datetime.datetime:
    return list(self.datetime_interval())[0][0]

  def start_str(self) -> str:
    if not self.start():
      return ''
    return self.start().strftime("%Y-%m-%d")

  def end(self) -> Optional[datetime.datetime]:
    return list(self.datetime_interval())[0][1]

  def end_str(self) -> str:
    if not self.end():
      return ''
    return self.end().strftime("%Y-%m-%d")

  def bbox_list(self) -> Sequence[BBox]:
    if 'extent' not in self.stac_json:
      # Assume global if nothing listed.
      return (BBox(-180, -90, 180, 90),)
    return tuple([
        BBox.from_list(x)
        for x in self.stac_json['extent']['spatial']['bbox']
    ])

  def bands(self) -> List[Dict]:
    summaries = self.stac_json.get('summaries')
    if not summaries:
      return []
    return summaries.get('eo:bands', [])

  def spatial_resolution_m(self) -> float:
    summaries = self.stac_json.get('summaries')
    if not summaries:
      return -1
    if summaries.get('gsd'):
      return summaries.get('gsd')[0]

    # Hacky fallback for cases where the stac does not follow convention.
    gsd_lst = re.findall(r'"gsd": (\d+)', json.dumps(self.stac_json))

    if len(gsd_lst) > 0:
      return float(gsd_lst[0])

    return -1


  def temporal_resolution_str(self) -> str:
    interval_dict = self.stac_json.get('gee:interval')
    if not interval_dict:
      return ""
    return f"{interval_dict['interval']} {interval_dict['unit']}"


  def set_js_code(self, code: str) -> None:
    """Stores the given JavaScript code sample on the collection."""
    if not code:
      return
    self.stac_json['code'] = {'js_code': code}

  def image_preview_url(self):
    for link in self.stac_json['links']:
      if link.get('rel') == 'preview' and link.get('type') == 'image/png':
        return link['href']
    raise ValueError(f"No preview image found for {self.public_id()}")


  def catalog_url(self):
    links = self.stac_json['links']
    for link in links:
      if 'rel' in link and link['rel'] == 'catalog':
        return link['href']

      # Ideally there would be a 'catalog' link but sometimes there isn't.
      base_url = "https://developers.google.com/earth-engine/datasets/catalog/"
      if link['href'].startswith(base_url):
        return link['href'].split('#')[0]

    logging.warning(f"No catalog link found for {self.public_id()}")
    return ""
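
To make the shape of the temporal extent concrete, the sketch below parses the `extent.temporal.interval` field of a hypothetical, heavily trimmed STAC record. It uses the standard library's `datetime.fromisoformat` rather than the `iso8601` package that `Collection` relies on, so treat it as an approximation of `datetime_interval()`, not a drop-in replacement.

```python
import datetime


def parse_stac_datetime(value):
    """Parses an ISO 8601 string; returns None for an open-ended bound."""
    if value is None:
        return None
    # Older Python versions of fromisoformat() reject a trailing 'Z'.
    return datetime.datetime.fromisoformat(value.replace("Z", "+00:00"))


# Hypothetical record containing only the fields this example reads.
stac_json = {
    "id": "EXAMPLE/DATASET",
    "extent": {"temporal": {"interval": [["2000-02-11T00:00:00Z", None]]}},
}

intervals = [
    (parse_stac_datetime(start), parse_stac_datetime(end))
    for start, end in stac_json["extent"]["temporal"]["interval"]
]
```

An end bound of `None` means the collection is still being updated, which is why the helper functions substitute the current time before comparing.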

In [None]:
# @title class CollectionList()
class CollectionList(Sequence[Collection]):
  """List of stac.Collections; can be filtered to return a smaller sublist."""

  _collections: Sequence[Collection]

  def __init__(self, collections: Sequence[Collection]):
    self._collections = tuple(collections)

  def __iter__(self):
    return iter(self._collections)

  def __getitem__(self, index):
    return self._collections[index]

  def __len__(self):
    return len(self._collections)

  def __eq__(self, other: object) -> bool:
    if isinstance(other, CollectionList):
      return self._collections == other._collections
    return False

  def __hash__(self) -> int:
    return hash(self._collections)

  def filter_by_ids(self, ids: Iterable[str]):
    """Returns a sublist with only the collections matching the given ids."""
    # Materialize once so one-shot iterables (e.g. generators) also work.
    id_set = set(ids)
    return self.__class__(
        [c for c in self._collections if c.public_id() in id_set]
    )

  def filter_by_datetime(
      self,
      query_datetime: datetime.datetime,
  ):
    """Returns a sublist with the time interval matching the given time."""
    result = []
    for collection in self._collections:
      for datetime_interval in collection.datetime_interval():
        if matches_datetime(datetime_interval, query_datetime):
          result.append(collection)
          break
    return self.__class__(result)

  def filter_by_interval(
      self,
      query_interval: tuple[datetime.datetime, datetime.datetime],
  ):
    """Returns a sublist with the time interval matching the given interval."""
    result = []
    for collection in self._collections:
      for datetime_interval in collection.datetime_interval():
        if matches_interval(datetime_interval, query_interval):
          result.append(collection)
          break
    return self.__class__(result)

  def filter_by_bounding_box_list(
      self, query_bbox: BBox):
    """Alias of filter_by_bounding_box."""
    return self.filter_by_bounding_box(query_bbox)

  def filter_by_bounding_box(
      self, query_bbox: BBox):
    """Returns a sublist with the bbox matching the given bbox."""
    result = []
    for collection in self._collections:
      for collection_bbox in collection.bbox_list():
        if collection_bbox.intersects(query_bbox):
          result.append(collection)
          break
    return self.__class__(result)


  def sort_by_spatial_resolution(self, reverse=False):
    """Sorts the collections by spatial resolution (pixel size in meters).

    Collections with spatial_resolution_m() == -1 (unknown) are pushed to the
    end regardless of the sort direction.

    Args:
      reverse: If False (default), sort by ascending pixel size (finest
        resolution first). If True, sort by descending pixel size (coarsest
        resolution first).

    Returns:
      A new CollectionList instance with sorted collections.
    """
    def sort_key(collection):
      resolution = collection.spatial_resolution_m()
      if resolution == -1:
        return float('inf') if not reverse else float('-inf')
      return resolution

    sorted_collections = sorted(
        self._collections, key=sort_key, reverse=reverse
    )
    return self.__class__(sorted_collections)


  def limit(self, n: int):
    """
    Returns a new CollectionList containing the first n entries.

    Args:
        n: The number of entries to include in the new list.

    Returns:
        A new CollectionList instance with at most n collections.
    """
    return self.__class__(self._collections[:n])


  def to_df(self):
    """Converts a collection list to a dataframe with a select set of fields."""

    rows = []
    for col in self._collections:
      # Remove text in parens in dataset name.
      short_title = re.sub(r'\([^)]*\)', '', col.get('title')).strip()

      row = {
          'id': col.public_id(),
          'name': short_title,
          'temp_res': col.temporal_resolution_str(),
          'spatial_res_m': col.spatial_resolution_m(),
          'earliest': col.start_str(),
          'latest': col.end_str(),
          'url': col.catalog_url()
      }
      rows.append(row)
    return pd.DataFrame(rows)
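
The sort key in `sort_by_spatial_resolution` maps the `-1` sentinel to positive or negative infinity so that unknown resolutions land last in either direction. The sketch below applies the same idea to a plain list of numbers; `sort_resolutions` is a hypothetical stand-in for the method.

```python
def sort_resolutions(resolutions, reverse=False):
    """Sorts pixel sizes in meters, keeping the -1 sentinel at the end."""
    def sort_key(resolution):
        if resolution == -1:
            # With reverse=True the smallest key sorts last, so flip the sign.
            return float("-inf") if reverse else float("inf")
        return resolution

    return sorted(resolutions, key=sort_key, reverse=reverse)


values = [90.0, -1, 10.0, 927.67]
ascending = sort_resolutions(values)
descending = sort_resolutions(values, reverse=True)
```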

In [None]:
#@title class Catalog()
class Catalog:
  """Class containing all collections in the EE STAC catalog."""

  collections: CollectionList

  def __init__(self, storage_client: storage.Client):
    self.collections = CollectionList(self._load_collections(storage_client))

  def get_collection(self, id: str) -> Collection:
    """Returns the collection with the given id."""
    col = self.collections.filter_by_ids([id])
    if len(col) == 0:
      raise ValueError(f'No collection with id {id}')
    return col[0]


  @tenacity.retry(
    stop=tenacity.stop_after_attempt(5),
    wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
    retry=tenacity.retry_if_exception_type((
        google_exceptions.GoogleAPICallError,
        google_exceptions.RetryError,
        ConnectionError
    )),
    before_sleep=lambda retry_state: print(
        f"Error occurred: {str(retry_state.outcome.exception())}\n"
        f"Retrying in {retry_state.next_action.sleep} seconds... "
        f"(Attempt {retry_state.attempt_number}/5)"
    )
  )
  def _read_file(self, file_blob: storage.blob.Blob) -> Collection:
    """Reads the contents of a file from the specified bucket."""
    file_contents = file_blob.download_as_string().decode()
    return Collection(json.loads(file_contents))

  def _read_files(
      self, file_blobs: list[storage.blob.Blob]
  ) -> list[Collection]:
    """Processes files in parallel."""
    collections = []
    with futures.ThreadPoolExecutor(max_workers=10) as executor:
      file_futures = [
          executor.submit(self._read_file, file_blob)
          for file_blob in file_blobs
      ]
      for future in file_futures:
        collections.append(future.result())
    return collections

  def _load_collections(
      self, storage_client: storage.Client
  ) -> Sequence[Collection]:
    """Loads all EE STAC JSON files from GCS, with datetimes as objects."""
    bucket = storage_client.get_bucket('earthengine-stac')
    files = [
        x
        for x in bucket.list_blobs(prefix='catalog/')
        if x.name.endswith('.json')
        and not x.name.endswith('/catalog.json')
        and not x.name.endswith('/units.json')
    ]
    logging.warning('Found %d files, loading...', len(files))
    collections = self._read_files(files)

    code_samples_dict = self._load_all_code_samples(storage_client)

    res = []
    for c in collections:
      if c.is_deprecated():
        continue
      c.stac_json['code'] = code_samples_dict.get(c.sanitize_id())
      res.append(c)
    logging.warning(
        'Loaded %d collections (skipping deprecated ones)', len(res)
    )
    # Returning a tuple for immutability.
    return tuple(res)

  def _load_all_code_samples(self, storage_client: storage.Client):
    """Loads JavaScript example scripts from GCS into a dict keyed by dataset ID."""

    # Get json file from GCS bucket
    # 'gs://earthengine-catalog/catalog/example_scripts.json'
    bucket = storage_client.get_bucket('earthengine-catalog')
    blob = bucket.blob('catalog/example_scripts.json')
    file_contents = blob.download_as_string().decode()
    data = json.loads(file_contents)

    # Flatten json to get a map from ID (using '_' rather than '/') to code
    # sample.
    all_datasets_by_provider = data[0]['contents']
    code_samples_dict = {}
    for provider in all_datasets_by_provider:
      for dataset in provider['contents']:
        js_code = dataset['code']

        code_samples_dict[dataset['name']] = {
            'js_code': js_code}

    return code_samples_dict
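
The flattening step in `_load_all_code_samples` turns a provider-to-dataset hierarchy into a flat lookup table. The sketch below runs the same traversal on hypothetical data; only the fields the loader reads (`contents`, `name`, `code`) are modeled, and the real file may contain more.

```python
# Hypothetical data shaped like the nested structure the loader traverses.
data = [{
    "contents": [
        {"name": "PROVIDER_A", "contents": [
            {"name": "PROVIDER_A_dataset_1", "code": "var a = 1;"},
        ]},
        {"name": "PROVIDER_B", "contents": [
            {"name": "PROVIDER_B_dataset_2", "code": "var b = 2;"},
        ]},
    ]
}]

# Flatten provider -> dataset into one dict keyed by the sanitized dataset id.
code_samples = {
    dataset["name"]: {"js_code": dataset["code"]}
    for provider in data[0]["contents"]
    for dataset in provider["contents"]
}

# Keys use '_' instead of '/', matching Collection.sanitize_id().
lookup_key = "PROVIDER_A/dataset_1".replace("/", "_")
sample = code_samples.get(lookup_key)
```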

## Test catalog/collection functions


In [None]:
catalog = Catalog(storage_client)



In [None]:
col_list = catalog.collections.filter_by_ids(['CGIAR/SRTM90_V4', 'CIESIN/GPWv411/GPW_Land_Area'])
col_list
df = col_list.to_df()
HTML(df.to_html(render_links=True, escape=False))

Unnamed: 0,id,name,temp_res,spatial_res_m,earliest,latest,url
0,CGIAR/SRTM90_V4,SRTM Digital Elevation Data Version 4,,90.0,2000-02-11,2000-02-22,https://developers.google.com/earth-engine/datasets/catalog/CGIAR_SRTM90_V4
1,CIESIN/GPWv411/GPW_Land_Area,GPWv411: Land Area,,927.67,2000-01-01,2020-01-01,https://developers.google.com/earth-engine/datasets/catalog/CIESIN_GPWv411_GPW_Land_Area


# Generate Dataset summaries and embeddings

In [None]:
# @title Source code for dataset summarization and embedding modules

@retry(stop=stop_after_attempt(5), wait=wait_fixed(1))
def summarize_text(text: str, llm: BaseLanguageModel) -> str:
    """Summarize a given text using a language model.

    This function splits the input text into chunks, then uses a map-reduce
    summarization chain to generate a summary.

    Args:
        text: The text to be summarized.
        llm: The language model to use for summarization.

    Returns:
        str: The summarized text.

    Raises:
        tenacity.RetryError: If summarization still fails after 5 attempts.
    """
    # Collapse newlines (and any indentation that follows) into single spaces.
    text = re.sub(r'\n\s*', ' ', text)

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )

    docs = text_splitter.create_documents([text])
    chain = load_summarize_chain(llm, chain_type="map_reduce")
    return chain.run(docs)


def summarize_collection(collection: 'Collection', llm: BaseLanguageModel) -> Dict[str, str]:
    """Summarize the dataset description and band information for a data collection.

    Args:
        collection: The collection object containing dataset information.
        llm: The language model to use for summarization.

    Returns:
        A dictionary containing the collection's ID, name, and summarized description.
    """
    summarized_description = summarize_text(collection.get('description'), llm)

    # Adding text about individual bands improves search performance.
    band_descriptions = ""
    for band in collection.bands():
        band_descriptions += f'"{band["name"]}" represents {band["description"]}\n'
        if 'gee:classes' in band:
            band_descriptions += "    Classes:\n"
            for band_class in band['gee:classes']:
                band_descriptions += f'    {band_class["description"]}\n'

    summarized_description = summarized_description + "\n\n" + band_descriptions

    return {
        'id': collection.public_id(),
        'name': collection.get('title'),
        'summary': summarized_description
    }


def summarize_ee_catalog(
    catalog: 'Catalog', llm: BaseLanguageModel, output_path: Optional[str] = None
) -> List[Dict[str, str]]:
    """Generate summaries of all dataset descriptions in an Earth Engine data catalog.

    This function processes all collections in the catalog concurrently,
    summarizing each collection's description and band information.

    Args:
        catalog: The Earth Engine data catalog to summarize.
        llm: The language model to use for summarization.
        output_path: If provided, the path to save the summaries as a JSON
            Lines file (one record per line).

    Returns:
        A list with one dictionary per collection, each containing the
        collection's ID, name, and summary.

    Note:
        This function uses a ThreadPoolExecutor with a maximum of 8 workers to
        limit throttling by the language model API.
    """
    summarize_collection_partial = partial(summarize_collection, llm=llm)

    with ThreadPoolExecutor(max_workers=8) as executor:
        results = list(tqdm.tqdm(
            executor.map(summarize_collection_partial, catalog.collections),
            total=len(catalog.collections)))

    if output_path:
        with open(output_path, 'w') as f:
            for entry in results:
                f.write(json.dumps(entry) + '\n')
    return results



def get_embeddings_wrapper(
    texts: List[str], model: TextEmbeddingModel) -> List[List[float]]:
    """Embeds texts in batches; returns one vector per text, in order."""
    # Vertex AI lets you send a batch of 5 embedding requests at once.
    BATCH_SIZE = 5
    embs = []
    for i in tqdm.tqdm(range(0, len(texts), BATCH_SIZE)):
        time.sleep(1)  # Pause between requests to avoid the quota error.
        result = model.get_embeddings(texts[i : i + BATCH_SIZE])
        embs.extend(e.values for e in result)
    return embs


def add_embeddings_to_df(
    df: pd.DataFrame, col_to_embed: str,  model: TextEmbeddingModel) -> pd.DataFrame:
    get_embeddings_partial = partial(get_embeddings_wrapper, model=model)
    df = df.assign(embedding=get_embeddings_partial(list(df[col_to_embed])))
    return df
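
The batching pattern in `get_embeddings_wrapper` can be tried without any API access. The sketch below uses `fake_embed`, a hypothetical stand-in for the real embedding call, to show how 12 texts are split into batches of 5 while the output order is preserved. The one-second pause is omitted here because no quota applies to the fake function.

```python
def batched(items, batch_size):
    """Yields consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def fake_embed(texts):
    """Stand-in for a real embedding call: one 2-number vector per text."""
    return [[float(len(t)), float(t.count(" "))] for t in texts]


texts = [f"summary number {i}" for i in range(12)]
embeddings = []
batch_sizes = []
for batch in batched(texts, 5):
    batch_sizes.append(len(batch))
    embeddings.extend(fake_embed(batch))
```

Because each batch's results are appended in order, `embeddings[i]` always corresponds to `texts[i]`, which is what lets `add_embeddings_to_df` assign the list directly as a new column.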





In [None]:
# @title Initialize language and text embedding models plus output destinations
import google


# We use Gemini 1.5 Pro to summarize the original dataset descriptions.
gemini_llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro", google_api_key=userdata.get('GOOGLE_API_KEY'))

# We use a VertexAI model for embedding the dataset summaries to eventually be
# loaded into a Vectorstore.
embedding_model = TextEmbeddingModel.from_pretrained("google/text-embedding-004")

# We write the output to disk to reduce the risk of needing to rerun.
CATALOG_SUMMARIES_PATH = 'ee_catalog_summaries.jsonl'
EMBEDDINGS_LOCAL_PATH = 'ee_catalog_embeddings.jsonl'

# Eventually we upload embeddings and summaries to GCS.
GCP_PROJECT = userdata.get('GOOGLE_PROJECT_ID')
DESTINATION_BUCKET = userdata.get('DESTINATION_BUCKET')
EMBEDDINGS_GCS_PATH = 'ee_catalog_embeddings.jsonl'

In [None]:
# @title Load the entire EE Public data catalog from GCS:
catalog = Catalog(storage_client)



In [None]:
#@title Use an LLM to generate per-collection dataset summaries.
# This tends to take around 10-15 minutes.

summary_json_list = summarize_ee_catalog(catalog, gemini_llm)

# Write to a file so we minimize the need to repeat this time-consuming step.
with open(CATALOG_SUMMARIES_PATH, 'w') as f:
  for entry in summary_json_list:
    json.dump(entry, f)
    f.write('\n')

100%|██████████| 812/812 [09:39<00:00,  1.40it/s]
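
The summaries are stored as JSON Lines: one JSON object per line. The sketch below shows a round trip with the standard library on hypothetical records, including a summary that contains a newline. `json.dumps` escapes the newline, so each record still occupies exactly one line of the file.

```python
import json
import os
import tempfile

# Hypothetical records in the same shape as the summaries (id, name, summary).
records = [
    {"id": "EXAMPLE/A", "name": "Example A", "summary": "First summary."},
    {"id": "EXAMPLE/B", "name": "Example B", "summary": "Line one.\nLine two."},
]

path = os.path.join(tempfile.mkdtemp(), "summaries.jsonl")
with open(path, "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

with open(path) as f:
    lines = f.readlines()

# Parse each non-empty line back into a dictionary.
loaded = [json.loads(line) for line in lines if line.strip()]
```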


In [None]:
# @title View summary results
catalog_summary_df = pd.read_json(CATALOG_SUMMARIES_PATH, lines=True)
catalog_summary_df.head()

Unnamed: 0,id,name,summary
0,AAFC/ACI,Canada AAFC Annual Crop Inventory,Agriculture and Agri-Food Canada annually maps...
1,ACA/reef_habitat/v2_0,Allen Coral Atlas (ACA) - Geomorphic Zonation ...,"The Allen Coral Atlas, a global, high-resoluti..."
2,AHN/AHN2_05M_INT,"AHN Netherlands 0.5m DEM, Interpolated",The AHN DEM is a high-resolution (0.5m) digita...
3,AHN/AHN2_05M_NON,"AHN Netherlands 0.5m DEM, Non-Interpolated",The AHN DEM is a detailed (0.5m resolution) el...
4,AHN/AHN2_05M_RUW,"AHN Netherlands 0.5m DEM, Raw Samples","The AHN DEM, created from 2007-2012 LIDAR data..."


In [None]:
# @title Calculate embeddings for each dataset summary
# This takes around 3-5 minutes due to the text embedding model's rate limits.

embedding_df = add_embeddings_to_df(catalog_summary_df, 'summary', embedding_model)

# First store locally, just in case something happens to the Colab runtime.
with open(EMBEDDINGS_LOCAL_PATH, 'w') as f:
  f.write(embedding_df.to_json(orient='records', lines=True))

# Make sure we can read the embeddings that were written to file.
embeddings_df = pd.read_json(EMBEDDINGS_LOCAL_PATH, lines=True)
embeddings_df.head()

100%|██████████| 163/163 [03:21<00:00,  1.23s/it]


Unnamed: 0,id,name,summary,embedding
0,AAFC/ACI,Canada AAFC Annual Crop Inventory,Agriculture and Agri-Food Canada annually maps...,"[-0.03112766332924366, 0.022871049121022224, -..."
1,ACA/reef_habitat/v2_0,Allen Coral Atlas (ACA) - Geomorphic Zonation ...,"The Allen Coral Atlas, a global, high-resoluti...","[0.006329342722892761, 0.056551311165094376, -..."
2,AHN/AHN2_05M_INT,"AHN Netherlands 0.5m DEM, Interpolated",The AHN DEM is a high-resolution (0.5m) digita...,"[0.0030822136905044317, -0.06489657610654831, ..."
3,AHN/AHN2_05M_NON,"AHN Netherlands 0.5m DEM, Non-Interpolated",The AHN DEM is a detailed (0.5m resolution) el...,"[-0.014630626887083054, -0.07648028433322906, ..."
4,AHN/AHN2_05M_RUW,"AHN Netherlands 0.5m DEM, Raw Samples","The AHN DEM, created from 2007-2012 LIDAR data...","[-0.008305594325065613, -0.07478459924459457, ..."


In [None]:
#@title Upload embeddings and summaries to GCS
storage_client = storage.Client(project=GCP_PROJECT)

bucket = storage_client.bucket(DESTINATION_BUCKET, user_project=GCP_PROJECT)
blob = bucket.blob(EMBEDDINGS_GCS_PATH)
blob.upload_from_filename(EMBEDDINGS_LOCAL_PATH)

In [None]:
# @title Make sure we can load the new file from GCS

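# Note: this path is hardcoded and does not point at DESTINATION_BUCKET. To
# check your own upload, use f'gs://{DESTINATION_BUCKET}/{EMBEDDINGS_GCS_PATH}'.
EMBEDDINGS_CLOUD_PATH = 'gs://science-ai-ee-catalog-index/catalog_embeddings.jsonl'
EMBEDDINGS_LOCAL_PATH = 'catalog_embeddings.jsonl'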


parts = EMBEDDINGS_CLOUD_PATH.split('/')
bucket_name = parts[2]
blob_path = '/'.join(parts[3:])
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(blob_path)
blob.download_to_filename(EMBEDDINGS_LOCAL_PATH)

embeddings_df = pd.read_json(EMBEDDINGS_LOCAL_PATH, lines=True)
embeddings_df

Unnamed: 0,id,name,summary,embedding
0,AAFC/ACI,Canada AAFC Annual Crop Inventory,Agriculture and Agri-Food Canada annually maps...,"[-0.0297800228, 0.017559804000000002, -0.02272..."
1,ACA/reef_habitat/v2_0,Allen Coral Atlas (ACA) - Geomorphic Zonation ...,"The Allen Coral Atlas is a global, high-resolu...","[-0.00014006280000000002, 0.0595749207, -0.044..."
2,AHN/AHN2_05M_INT,"AHN Netherlands 0.5m DEM, Interpolated",The AHN DEM is a high-resolution (0.5m) model ...,"[0.0002864186, -0.0866151974, -0.0573001616, 0..."
3,AHN/AHN2_05M_NON,"AHN Netherlands 0.5m DEM, Non-Interpolated","The AHN DEM, created from 2007-2012 LIDAR data...","[-0.015314440200000001, -0.07613136620000001, ..."
4,AHN/AHN2_05M_RUW,"AHN Netherlands 0.5m DEM, Raw Samples","The AHN DEM, a high-resolution (0.5m) elevatio...","[-0.0136169484, -0.1123650596, -0.0680219904, ..."
...,...,...,...,...
819,projects/planet-nicfi/assets/basemaps/americas,NICFI Satellite Data Program Basemaps for Trop...,**Concise Summary:**\n\nThe Norway's Internati...,"[-0.0014470087, 0.028179975200000002, -0.02454..."
820,projects/planet-nicfi/assets/basemaps/asia,NICFI Satellite Data Program Basemaps for Trop...,This collection provides high-resolution satel...,"[0.0229713377, 0.0144163994, 0.019438674700000..."
821,projects/sat-io/open-datasets/GLOBathy/GLOBath...,GLOBathy Global lakes bathymetry dataset,GLOBathy offers detailed depth maps of over 1....,"[0.009630191100000001, 0.049831219, -0.0843064..."
822,projects/sat-io/open-datasets/ORNL/LANDSCAN_GL...,LandScan Population Data Global 1km,"LandScan, a high-resolution global population ...","[0.0689298883, 0.0050361394, -0.0488965176, -0..."
