In [6]:
import leafmap.maplibregl as leafmap
import numpy as np
from matplotlib import cm
import pandas as pd
import ibis
from ibis import _

import duckdb
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import ConfigurableField
from langchain_core.tools import tool
from langchain.agents import create_tool_calling_agent, AgentExecutor

In [2]:
#%pip install langchain_anthropic langchain_ollama langchain_openai

In [7]:
# Machine-specific absolute path to the parquet file registered below.
# NOTE(review): this only resolves on the original host — consider a
# configurable data directory so the notebook can run elsewhere.
h3_parquet = "/home/rstudio/huggingface/spaces/boettiger-lab/gbif/gbif_ca.geoparquet"
# Connect ibis to the DuckDB database file "duck.db".
con = ibis.duckdb.connect("duck.db")
# Register the parquet file under the table name "gbif"; `gbif` is the table
# expression that the sql_query and maplibre_plot tools run their SQL against.
gbif = con.read_parquet(h3_parquet, "gbif")

In [4]:
# One-row query against the "gbif" table — presumably a quick sanity check;
# `x` is not referenced again in the visible cells.
x = gbif.sql("SELECT * FROM gbif LIMIT 1;")

In [5]:
#con.list_tables()
#con.table("gbif").schema()

In [3]:
@tool
def sql_query(sql: str):
    """Execute the SQL query 'sql' on the current database and return the result as a pandas data frame.

    At most the first 100 rows of the result are returned, so use aggregate
    queries (COUNT, GROUP BY, ...) instead of counting returned rows when the
    size of a result set matters.

    Args:
        sql: A complete SQL statement that names the table explicitly,
            for example "SELECT * FROM gbif LIMIT 1;".

    Returns:
        The first 100 rows of the query result.
    """
    # Cap the row count before executing so a broad SELECT cannot flood the
    # agent's context with the whole table.
    capped = gbif.sql(sql).limit(100)
    return capped.execute()

In [4]:
def get_h3point_df(_df, zoom: float) -> pd.DataFrame:
    """Count rows per H3 cell at one resolution and attach a viridis colour to each cell.

    Args:
        _df: ibis table expression containing a column named "h<zoom>"
            (e.g. "h4"); that column is renamed to "hex" and used as the
            grouping key.
        zoom: Resolution suffix selecting the column. Values given as float
            or numeric string (4.0, "4") are normalised to the integer form
            ("h4"); any fractional part is dropped.

    Returns:
        pandas DataFrame with one row per distinct "hex" value and the columns
        hex, n (row count), v (log of n), normalized_values (v divided by the
        largest v) and rgb (a list of four ints in 0-255 taken from the
        viridis colormap).
    """
    # str(4.0) is "4.0", which would look up a non-existent "h4.0" column, so
    # coerce to an integer before building the column name.
    column = f"h{int(float(zoom))}"
    df = (
        _df
        .rename(hex=column)
        .group_by(_.hex)
        .agg(n=_.count())
        .mutate(v=_.n.log())
        # NOTE(review): when every cell has n == 1, v.max() is 0 and this
        # divides by zero — confirm how the backend surfaces that case.
        .mutate(normalized_values=_.v / _.v.max())
        .to_pandas()
    )
    # The colormap returns floats in [0, 1]; scale them to 0-255 integers for
    # the layer's fill-colour accessor.
    rgba = cm.viridis(df.normalized_values)
    df['rgb'] = np.round(rgba * 255).astype(int).clip(0, 255).tolist()
    return df


import pydeck as pdk
def hex_layer(m, df: pd.DataFrame, v_scale = 1):
    """Add an extruded "H3HexagonLayer" built from `df` to the map `m`.

    Args:
        m: Map object exposing `add_deck_layers`.
        df: Data frame with the columns read by the layer accessors:
            "hex" (hexagon id), "rgb" (fill colour) and
            "normalized_values" (elevation).
        v_scale: Value passed to the layer as `elevation_scale`.

    Returns:
        Whatever `m.add_deck_layers` returns.
    """
    # Accessor names refer to columns of `df`.
    layer_options = dict(
        get_hexagon="hex",
        get_fill_color="rgb",
        extruded=True,
        get_elevation="normalized_values",
        elevation_scale=v_scale,
        elevation_range=[0, 1],
    )
    h3_layer = pdk.Layer("H3HexagonLayer", df, **layer_options)
    return m.add_deck_layers([h3_layer])

@tool
def maplibre_plot(sql: str, zoom: int = 4, vertical_exaggeration: float = 0):
    """Run the SQL query 'sql' on the 'gbif' table and plot the matching rows as H3 hexagons on the map at the specified zoom.

    Pass the SQL text itself (the same string you would give to sql_query()),
    not a data frame. The query result must contain the column 'h<zoom>'
    (for example 'h4' when zoom is 4), so select all columns or include that
    column explicitly.

    Args:
        sql: A complete SQL statement that names the table explicitly.
        zoom: Integer suffix of the 'h<zoom>' column used to group rows into
            hexagons. Defaults to 4.
        vertical_exaggeration: Elevation scale applied to the hexagon layer.
            Defaults to 0.
    """
    # `m` is the module-level map created in the cell that invokes the agent;
    # it must exist before the agent calls this tool.
    table = gbif.sql(sql)
    cells = get_h3point_df(table, zoom = zoom)
    return hex_layer(m, cells, vertical_exaggeration)
    

In [8]:
#m = leafmap.Map(style="positron",  center=(-121.4, 37.50), zoom=5,)
#df = gbif.sql("SELECT * FROM gbif;")
#df = get_h3point_df(df, zoom = 6)
#hex_layer(m, df, 1)
#m

In [5]:
# Chat model served through Ollama, configured with temperature 0.
llm = ChatOllama(
    model="llama3-groq-tool-use:70b",
    temperature=0,
)

# Schema of the registered "gbif" table; passed as the {schema} prompt
# variable when the agent is invoked.
schema = con.table("gbif").schema()
# Prompt layout: system instructions, the user's request ({input}), then the
# {agent_scratchpad} placeholder.
# NOTE(review): the system text describes maplibre_plot as operating on the
# data frame returned by sql_query(), but the tool's first parameter is the
# SQL string itself — confirm the wording matches the tool.
prompt = ChatPromptTemplate.from_messages([
    ("system", '''
    you're a helpful assistant. You are an expert in duckdb SQL. 
    Your job is to construct a valid duckdb SQL to filter the table in the way the user describes by calling the sql_query() tool I provide. 
    The sql_query() tool will return a pandas data frame, 'df'.  If that data.frame includes a column called "hex", then I want you to call the 
    tool maplibre_plot(query, zoom, vertical_exaggeration) on the data frame returned by the sql_query().  
    If the user specifies a 'zoom' in their query, please use that value for 'zoom' in calling maplibre_plot, otherwise, omit the zoom argument to use the default.  

    You can call the sql_query() tool to help you explore the data better before you try and answer the question.  For instance, you may want to use the
    tool to call sql_query("select * from gbif limit 1;") to examine the table schema before formulating your query.
    
    If not specified, the table name is 'gbif'. Include the table name in the SQL query, it is not a parameter.

    The schema of the table is {schema}
    
    '''
    ), 
    ("human", "{input}"), 
    ("placeholder", "{agent_scratchpad}"),
])

# Tools exposed to the model, the tool-calling agent built from them, and the
# executor that runs it (verbose=True prints the chain trace).
tools = [sql_query, maplibre_plot]
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

In [19]:
# Create the map before invoking the agent: maplibre_plot reads the
# module-level name `m` when the agent calls it.
m = leafmap.Map(style="positron",  center=(-121.4, 37.50), zoom=5,)
# Run the agent on a natural-language request; "schema" fills the {schema}
# slot of the system prompt.
out = agent_executor.invoke({"input": "filter the gbif table to show only bird species. visualize results at zoom 10 and set the vertical exaggeration to 5000.", "schema": schema})
# Display the map as the cell output.
m



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `sql_query` with `{'sql': "SELECT * FROM gbif WHERE class = 'Aves'"}`


[0m[36;1m[1;3mr0 := DatabaseTable: gbif
  gbifid                           string
  datasetkey                       string
  occurrenceid                     string
  kingdom                          string
  phylum                           string
  class                            string
  order                            string
  family                           string
  genus                            string
  species                          string
  infraspecificepithet             string
  taxonrank                        string
  scientificname                   string
  verbatimscientificname           string
  verbatimscientificnameauthorship string
  countrycode                      string
  locality                         string
  stateprovince                    string
  occurrencestatus                 string
  individualcou

Map(height='600px', map_options={'bearing': 0, 'center': (-121.4, 37.5), 'pitch': 0, 'style': 'https://basemap…