<a href="https://colab.research.google.com/github/lestermartin/starburst-dataframes-exploration/blob/main/StarburstPythonOptions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Python options when using Trino & Starburst

This notebook shows the primary options available to Python programmers when using Trino and Starburst.

1.   [Python UDFs](https://trino.io/docs/current/udf/python.html)
2.   [Python client](https://github.com/trinodb/trino-python-client)
3.   [PyStarburst](https://docs.starburst.io/clients/python/pystarburst.html) (ONLY available with Starburst)
4.   [Ibis](https://ibis-project.org/)



## Trino Python user-defined functions

Trino supports user-defined functions (UDFs) written in Python. With Starburst, Python UDFs are only available in Starburst Enterprise (SEP), and even there only as a public preview feature; Starburst Galaxy does not support them.

Python UDFs are not meant to be run from this notebook. Try them in the Starburst UI or the Trino CLI instead.
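As an illustration of porting a SQL UDF to a Python UDF, here is a minimal sketch, assuming a cluster that supports inline `LANGUAGE PYTHON` routines. The two SQL strings are meant to be pasted into the Starburst UI or the Trino CLI; the only thing executed in Python is a local check of the handler function. The function name `doubleup` and handler name `twice` are illustrative.

```python
import textwrap

# The Python handler as source text. Trino calls the function named by
# `handler`; here the same source is executed locally as a quick check.
HANDLER_SOURCE = textwrap.dedent("""
    def twice(a):
        return a * 2
""")

# SQL UDF: the function body is a single SQL expression
SQL_UDF = """
WITH FUNCTION doubleup(x integer)
  RETURNS integer
  RETURN x * 2
SELECT doubleup(21);
"""

# The same function ported to a Python UDF: LANGUAGE PYTHON, a handler
# name, and the Python source between the $$ delimiters
PYTHON_UDF = f"""
WITH FUNCTION doubleup(x integer)
  RETURNS integer
  LANGUAGE PYTHON
  WITH (handler = 'twice')
  AS $$
{HANDLER_SOURCE}$$
SELECT doubleup(21);
"""

# Local check of the handler logic (no Trino involved)
namespace = {}
exec(HANDLER_SOURCE, namespace)
assert namespace["twice"](21) == 42

# Copy these into the Starburst UI or the Trino CLI
print(SQL_UDF)
print(PYTHON_UDF)
```

Testing the handler as plain Python first makes it easier to separate Python bugs from SQL declaration mistakes.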

In [None]:
import getpass

# prompt the notebook user for connection details; the client examples below reuse them
my_host = input("Host name: ")
my_username = input("User name: ")
my_password = getpass.getpass("Password: ")
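If you'd rather not type the connection details every time (for example, when running the notebook unattended), a small helper can read them from environment variables first and fall back to prompting. This is a sketch: `read_credentials` is a hypothetical helper, and `TRINO_HOST`, `TRINO_USER`, and `TRINO_PASSWORD` are hypothetical variable names, not something the clients below read on their own.

```python
import getpass
import os


def read_credentials(prompt=input, secret_prompt=getpass.getpass):
    """Return (host, user, password), preferring environment variables.

    TRINO_HOST, TRINO_USER and TRINO_PASSWORD are hypothetical names for
    this sketch; any that are unset fall back to an interactive prompt.
    """
    host = os.environ.get("TRINO_HOST") or prompt("Host name: ")
    user = os.environ.get("TRINO_USER") or prompt("User name: ")
    password = os.environ.get("TRINO_PASSWORD") or secret_prompt("Password: ")
    return host, user, password
```

Calling `my_host, my_username, my_password = read_credentials()` in place of the three prompts leaves every later cell unchanged.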

## Trino Python client

Python client for Trino, as described in its [GitHub repo](https://github.com/trinodb/trino-python-client).

In [None]:
# install Trino Python client

%pip install trino

In [None]:
# boilerplate code for setup

from trino.dbapi import connect
from trino.auth import BasicAuthentication

# sanity check
print('\n Make sure the phrase ** CONNECTION IS GOOD ** displays \n')


# build the connection object with the hostname & credentials entered earlier
conn = connect(
    host=my_host,
    port=443,
    user=my_username,
    auth=BasicAuthentication(my_username, my_password),
    http_scheme="https",
    catalog="system",
    schema="runtime",
)
cur = conn.cursor()
cur.execute("SELECT '** CONNECTION IS GOOD **'")
rows = cur.fetchall()
print(rows)

### Select a full table

This first example runs a simple SQL statement that fetches the entire contents of a single table.

NOTE: Pandas is used here only to make the output look nicer in the notebook.

In [None]:
import pandas as pd

cur = conn.cursor()
cur.execute("SELECT * FROM tpch.tiny.nation")
rows = cur.fetchall()

col_name = [desc[0] for desc in cur.description]
pandasDF = pd.DataFrame(rows, columns=col_name)

pandasDF
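The three lines that turn a cursor's results into a DataFrame come up again below, so a small helper can save some repetition. This sketch relies only on the DB-API 2.0 (PEP 249) contract that `trino.dbapi` implements: each entry in `cursor.description` starts with the column name. The name `fetch_with_columns` is hypothetical, and the standard-library `sqlite3` driver stands in for Trino so the cell runs without a cluster.

```python
import sqlite3


def fetch_with_columns(cursor, sql):
    """Execute `sql` on any DB-API 2.0 cursor; return (column_names, rows)."""
    cursor.execute(sql)
    rows = cursor.fetchall()
    # PEP 249: each cursor.description entry is a sequence whose first item is the column name
    column_names = [desc[0] for desc in cursor.description]
    return column_names, rows


# Demo with sqlite3 standing in for trino.dbapi, so no cluster is needed
demo_cur = sqlite3.connect(":memory:").cursor()
demo_cols, demo_rows = fetch_with_columns(demo_cur, "SELECT 0 AS nationkey, 'ALGERIA' AS name")
print(demo_cols, demo_rows)
```

Against Trino it would be called as `cols, rows = fetch_with_columns(conn.cursor(), "SELECT * FROM tpch.tiny.nation")`, followed by `pd.DataFrame(rows, columns=cols)`.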

### Running multiple SQL statements (not great)

This example runs two different SQL statements and keeps each result in a local variable (again using Pandas for the additional functionality).

Then several [Pandas API](https://pandas.pydata.org/docs/reference/index.html) calls join and sort the data for review.

Using a local DataFrame API like this forces the client to run multiple queries and to pull every result set, however large, into the local machine's memory.

The PyStarburst and Ibis solutions allow all of this to be executed lazily, so only one SQL statement ultimately runs in the Trino engine.

In [None]:
cur.execute("""
         SELECT name AS c_name, acctbal, nationkey AS c_nationkey
           FROM tpch.tiny.customer
          WHERE acctbal > 9900.0
""")
rows = cur.fetchall()
col_name = [desc[0] for desc in cur.description]
custPandasDF = pd.DataFrame(rows, columns=col_name)

cur.execute("""
         SELECT name AS n_name, nationkey AS n_nationkey
           FROM tpch.tiny.nation
""")
rows = cur.fetchall()
col_name = [desc[0] for desc in cur.description]
nationPandasDF = pd.DataFrame(rows, columns=col_name)


joinedPDF = custPandasDF.merge(nationPandasDF, left_on='c_nationkey', right_on='n_nationkey')

sortedPDF = joinedPDF.sort_values(by='acctbal', ascending=False)

cleanedUpPDF = sortedPDF.drop(columns=['c_nationkey', 'n_nationkey'])

cleanedUpPDF

### Running a single SQL (much more efficient)

This example skips the local DataFrame API for the data processing (Pandas is used only to display the result) and sticks with plain SQL, so exactly one statement runs and only the final rows come back to the client.

In [None]:
cur.execute("""
         SELECT c.name, c.acctbal, n.name AS nation_name
           FROM tpch.tiny.customer c
           JOIN tpch.tiny.nation n
             ON c.nationkey = n.nationkey
          WHERE c.acctbal > 9900.0
          ORDER BY c.acctbal DESC
""")
rows = cur.fetchall()
col_name = [desc[0] for desc in cur.description]
pandasDF = pd.DataFrame(rows, columns=col_name)

pandasDF
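To make the difference between the two approaches concrete without a cluster, the sketch below uses the standard-library `sqlite3` module as a stand-in engine, loaded with a few made-up rows shaped like `tpch.tiny.customer` and `tpch.tiny.nation`. It runs the two-query version (joining and sorting in Python) and the single-query version, checks that they agree, and counts how many rows each pulls back to the client. With only three nations the gap is small; it grows with the size of any table fetched whole.

```python
import sqlite3

# sqlite3 stands in for Trino; the rows are made up, shaped like tpch.tiny.*
db = sqlite3.connect(":memory:")
db.executescript("""
    CREATE TABLE nation (nationkey INTEGER, name TEXT);
    CREATE TABLE customer (name TEXT, acctbal REAL, nationkey INTEGER);
    INSERT INTO nation VALUES (0, 'ALGERIA'), (1, 'ARGENTINA'), (2, 'BRAZIL');
    INSERT INTO customer VALUES
        ('cust_a', 9950.0, 1), ('cust_b', 120.0, 0),
        ('cust_c', 9990.0, 2), ('cust_d', 9901.5, 0);
""")

# Approach 1: two queries, then join and sort on the client
cust = db.execute(
    "SELECT name, acctbal, nationkey FROM customer WHERE acctbal > 9900.0"
).fetchall()
nations = db.execute("SELECT name, nationkey FROM nation").fetchall()
fetched_two_queries = len(cust) + len(nations)

nation_name_by_key = {key: name for name, key in nations}
client_side = sorted(
    # inner join: keep only customers whose nationkey has a matching nation
    [(name, acctbal, nation_name_by_key[key])
     for name, acctbal, key in cust if key in nation_name_by_key],
    key=lambda row: row[1],
    reverse=True,
)

# Approach 2: one query; the engine does the join and the sort
server_side = db.execute("""
    SELECT c.name, c.acctbal, n.name AS nation_name
      FROM customer c
      JOIN nation n ON c.nationkey = n.nationkey
     WHERE c.acctbal > 9900.0
     ORDER BY c.acctbal DESC
""").fetchall()
fetched_one_query = len(server_side)

print("two queries fetched", fetched_two_queries, "rows; one query fetched", fetched_one_query)
```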

## PyStarburst

Python DataFrame API modeled after PySpark, as described in the [PyStarburst documentation](https://docs.starburst.io/clients/python/pystarburst.html).

NOTE: Only works with Starburst; not supported on open-source Trino.

In [None]:
# install PyStarburst

%pip install pystarburst

In [None]:
# boilerplate code for setup

import trino

from pystarburst import Session
from pystarburst import functions as F
from pystarburst.functions import col, row_number
from pystarburst.window import Window as W

# PyStarburst setup
session_properties = {
    "host":my_host,
    "port": 443,
    "http_scheme": "https",
    "auth": trino.auth.BasicAuthentication(my_username, my_password)
}
session = Session.builder.configs(session_properties).create()

# validate PyStarburst working
print('\n Make sure the phrase ** CONNECTION IS GOOD ** displays \n')
session.sql("select '** CONNECTION IS GOOD **' as conn_check").collect()

### Walk before running

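The following cells build up a query one step at a time, calling `show()` after each step so you can inspect the intermediate result. Transformations such as `select()` and `filter()` only add to the DataFrame's definition; the query runs in Trino when an action such as `show()` or `collect()` is called.

Note: This code was originally published in [pystarburst (the dataframe api)](https://lestermartin.blog/2023/09/12/pystarburst-the-dataframe-api/).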

#### Select a full table

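Use `session.table()` to create a DataFrame for the `tpch.tiny.customer` table, and `show()` to display some of its rows.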

In [None]:
custDF = session.table("tpch.tiny.customer")
custDF.show()

#### Use projection

Use `select()` to keep only the `name`, `acctbal`, and `nationkey` columns.

In [None]:
projectedDF = custDF.select(custDF.name, custDF.acctbal, custDF.nationkey)
projectedDF.show()

#### Filter the rows

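Use `filter()` to keep only the rows whose `acctbal` is greater than 9900.0; `show(100)` displays up to 100 rows.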

In [None]:
filteredDF = projectedDF.filter(projectedDF.acctbal > 9900.0)
filteredDF.show(100)

#### Select a second table

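Create a DataFrame for `tpch.tiny.nation`, drop the `regionkey` and `comment` columns, and rename `name` and `nationkey` so they don't clash with the customer columns of the same names after the join.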

In [None]:
nationDF = session.table("tpch.tiny.nation") \
                  .drop("regionkey", "comment") \
                  .rename("name", "nation_name") \
                  .rename("nationkey", "n_nationkey")
nationDF.show()

#### Join the tables

Join the filtered customers to the nations where `nationkey` matches `n_nationkey`.

In [None]:
joinedDF = filteredDF.join(nationDF, filteredDF.nationkey == nationDF.n_nationkey)
joinedDF.show()

#### Project the joined result

Drop the two join-key columns, which aren't needed in the final output.

In [None]:
projectedJoinDF = joinedDF.drop("nationkey").drop("n_nationkey")
projectedJoinDF.show()

#### Apply a sort

Sort the rows by `acctbal` in descending order.

In [None]:
orderedDF = projectedJoinDF.sort(col("acctbal"), ascending=False)
orderedDF.show()

#### Put it all together

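The same steps can be chained into one expression. Nothing runs until `show()` is called, at which point the whole chain is submitted to Trino as a single query.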

In [None]:
nationDF = session.table("tpch.tiny.nation") \
            .drop("regionkey", "comment") \
            .rename("name", "nation_name") \
            .rename("nationkey", "n_nationkey")

apiSQL = session.table("tpch.tiny.customer") \
            .select("name", "acctbal", "nationkey") \
            .filter(col("acctbal") > 9900.0) \
            .join(nationDF, col("nationkey") == nationDF.n_nationkey) \
            .drop("nationkey").drop("n_nationkey") \
            .sort(col("acctbal"), ascending=False)
apiSQL.show()
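To see why chaining doesn't run anything until the end, here is a toy, hypothetical `ToyFrame` class (not how PyStarburst is implemented) whose methods only record what was asked for and return a new frame; a SQL string is produced once, at the end. It handles just `select`, `filter`, and `sort`, and leaves joins out to stay short.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyFrame:
    """Hypothetical lazy frame: each method returns a new, immutable
    description of the query; nothing executes until to_sql() is called."""
    table: str
    columns: tuple = ("*",)
    predicates: tuple = ()
    order: tuple = ()

    def select(self, *cols):
        return replace(self, columns=cols)

    def filter(self, predicate):
        return replace(self, predicates=self.predicates + (predicate,))

    def sort(self, col, ascending=True):
        direction = "ASC" if ascending else "DESC"
        return replace(self, order=self.order + (f"{col} {direction}",))

    def to_sql(self):
        # Compose every recorded step into one statement
        sql = f"SELECT {', '.join(self.columns)} FROM {self.table}"
        if self.predicates:
            sql += " WHERE " + " AND ".join(self.predicates)
        if self.order:
            sql += " ORDER BY " + ", ".join(self.order)
        return sql


query = (ToyFrame("tpch.tiny.customer")
         .select("name", "acctbal", "nationkey")
         .filter("acctbal > 9900.0")
         .sort("acctbal", ascending=False))
print(query.to_sql())
```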

#### Or... just run some SQL

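If you'd rather write SQL, `session.sql()` also returns a DataFrame, so the result can be displayed (or transformed further) the same way.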

In [None]:
sqlDF = session.sql("""
         SELECT c.name, c.acctbal, n.name AS nation_name
           FROM tpch.tiny.customer c
           JOIN tpch.tiny.nation n
             ON c.nationkey = n.nationkey
          WHERE c.acctbal > 9900.0
          ORDER BY c.acctbal DESC
""")
sqlDF.show()

### Richer examples

The next two examples go a step further: a three-table join with an aggregation, and a window function.


#### Example 1 (joining 3 tables)

Description: Aggregate total customer acctbal by region name

Tables: tpch.tiny.customer, tpch.tiny.nation, tpch.tiny.region

In [None]:
nation = session.table("tpch.tiny.nation")
region = session.table("tpch.tiny.region")
customer = session.table("tpch.tiny.customer")

nr = nation.join(region, nation.regionkey == region.regionkey) \
           .select(nation.nationkey.alias("nationkey"),
                   region.name.alias("region_name"))

result = customer.join(nr, customer.nationkey == nr.nationkey) \
                 .groupBy("region_name") \
                 .agg(("acctbal", "sum")) \
                 .rename("sum(acctbal)", "total_acctbal") \
                 .sort(col("total_acctbal").desc())

result.show()
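For checking what the chain above computes, this sketch performs the same join-then-aggregate logic in plain Python on a few made-up rows shaped like `region`, `nation`, and `customer`. It only illustrates the semantics; on real data the work happens in Trino.

```python
from collections import defaultdict

# Made-up rows, shaped like tpch.tiny.region, nation and customer
regions = {0: "AFRICA", 1: "AMERICA"}                        # regionkey -> name
nations = {0: 0, 1: 1, 2: 1}                                 # nationkey -> regionkey
customers = [(100.0, 0), (250.5, 1), (49.5, 2), (-10.0, 0)]  # (acctbal, nationkey)

# nation >< region: map each nationkey straight to its region name
region_name_by_nation = {nk: regions[rk] for nk, rk in nations.items()}

# customer >< (nation, region), then GROUP BY region_name with SUM(acctbal)
totals = defaultdict(float)
for acctbal, nationkey in customers:
    totals[region_name_by_nation[nationkey]] += acctbal

# ORDER BY total_acctbal DESC
result_rows = sorted(totals.items(), key=lambda item: item[1], reverse=True)
print(result_rows)
```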

#### Example 2 (windowing example)

For each nation, get the top N customers by `acctbal` (here N is 1, set by the `rn <= 1` filter).

Tables: tpch.tiny.customer, tpch.tiny.nation

In [None]:
from pystarburst.window import Window

customer = session.table("tpch.tiny.customer")
nation = session.table("tpch.tiny.nation") \
            .drop("regionkey", "comment") \
            .rename("name", "nation_name") \
            .rename("nationkey", "n_nationkey")

filtered = customer.select("custkey", "name", "acctbal", "nationkey") \
            .filter(col("acctbal") > 8000.0)
joined = filtered.join(nation, col("nationkey") == nation.n_nationkey) \
            .drop("nationkey", "n_nationkey")

w = Window.partitionBy("nation_name") \
          .orderBy(col("acctbal").desc())

ranked = joined.select("*", row_number().over(w).alias("rn"))
top_x = ranked.filter(col("rn") <= 1) \
              .sort(col("acctbal").desc(), col("nation_name"))
top_x.show(25)
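`row_number()` over a window partitioned by `nation_name` and ordered by `acctbal` descending numbers each nation's customers 1, 2, 3, ... from the highest balance down; keeping `rn <= N` then leaves the top N per nation. The sketch below spells that out in plain Python on a few made-up rows, using a hypothetical helper `top_n_per_group`, as an illustration of the semantics rather than of how Trino evaluates it.

```python
from itertools import groupby
from operator import itemgetter

# Made-up (name, acctbal, nation_name) rows, already filtered to acctbal > 8000
sample_rows = [
    ("cust_a", 9500.0, "BRAZIL"),
    ("cust_b", 8700.0, "BRAZIL"),
    ("cust_c", 9900.0, "CANADA"),
    ("cust_d", 8100.0, "CANADA"),
    ("cust_e", 9100.0, "CANADA"),
]


def top_n_per_group(rows, n):
    """Hypothetical helper mimicking
    row_number() OVER (PARTITION BY nation_name ORDER BY acctbal DESC) AS rn
    followed by a filter on rn <= n."""
    ranked = []
    # PARTITION BY nation_name: sort on the key so groupby sees each nation once
    by_nation = sorted(rows, key=itemgetter(2))
    for _nation, partition in groupby(by_nation, key=itemgetter(2)):
        # ORDER BY acctbal DESC inside the partition; numbering starts at 1
        ordered = sorted(partition, key=itemgetter(1), reverse=True)
        for rn, row in enumerate(ordered, start=1):
            ranked.append(row + (rn,))
    kept = [row for row in ranked if row[3] <= n]
    # final ORDER BY acctbal DESC, nation_name
    return sorted(kept, key=lambda row: (-row[1], row[2]))


print(top_n_per_group(sample_rows, 1))
```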

## Ibis

Python DataFrame API as described at the [Ibis project website](https://ibis-project.org/).

NOTE: Ibis can run against many different SQL engines, not just Trino.

In [None]:
# install Ibis with its Trino backend (the extra also installs the trino client)

%pip install 'ibis-framework[trino]'

In [None]:
# boilerplate code for setup

import ibis
from trino.auth import BasicAuthentication

ibis.options.interactive = True

user = my_username
trino_auth_obj = BasicAuthentication(my_username, my_password)
host = my_host
port = 443
http_scheme = "https"
catalog = "tpch"
schema = "tiny"

con = ibis.trino.connect(
    user=user, auth=trino_auth_obj, host=host, port=port, http_scheme=http_scheme, database=catalog, schema=schema
)

print('\n Make sure the phrase ** CONNECTION IS GOOD ** displays \n')
con.sql("select '** CONNECTION IS GOOD **' as conn_check")

### Walk before running

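As with PyStarburst, the following cells build up a query one step at a time. Because `ibis.options.interactive` is set to `True`, displaying an expression runs it and shows the result; slicing with `[0:10]` limits how many rows come back.

Note: This code was originally published in [ibis & trino (dataframe api part deux)](https://lestermartin.blog/2023/10/27/ibis-trino-dataframe-api-part-deux/).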

#### Select a full table

Use `con.table()` to reference the `customer` table; the connection's default catalog and schema are `tpch` and `tiny`.

In [None]:
custDF = con.table("customer")
custDF[0:10]

#### Use projection

Use `select()` to keep only the `name`, `acctbal`, and `nationkey` columns.

In [None]:
projectedDF = custDF.select("name", "acctbal", "nationkey")
projectedDF[0:10]

#### Filter the rows

Use `filter()` to keep only the rows whose `acctbal` is greater than 9900.0.

In [None]:
filteredDF = projectedDF.filter(projectedDF["acctbal"] > 9900.0)
filteredDF[0:100]

#### Select a second table

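Reference the `nation` table, drop the `regionkey` and `comment` columns, and rename `name` and `nationkey`. Note that Ibis's `rename()` takes a mapping from new names to old names.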

In [None]:
# Grab new table, drop 2 cols, and rename 2 others
nationDF = con.table("nation") \
            .drop("regionkey", "comment") \
            .rename(
                dict(
                    nation_name="name",
                    n_nationkey="nationkey"
                )
            )
nationDF[0:10]

#### Join the tables

Join the filtered customers to the nations where `nationkey` matches `n_nationkey`.

In [None]:
joinedDF = filteredDF.join(nationDF,
    filteredDF.nationkey == nationDF.n_nationkey)
joinedDF[0:10]

#### Project the joined result

Drop the two join-key columns.

In [None]:
projectedJoinDF = joinedDF.drop("nationkey", "n_nationkey")
projectedJoinDF[0:10]

#### Apply a sort

Sort by `acctbal` in descending order using `order_by()` with `ibis.desc()`.

In [None]:
orderedDF = projectedJoinDF.order_by([ibis.desc("acctbal")])
orderedDF[0:10]

#### Put it all together

Chain the same steps into a single expression.

In [None]:
nationDF = con.table("nation") \
            .drop("regionkey", "comment") \
            .rename(
                dict(
                    nation_name="name",
                    n_nationkey="nationkey"
                )
            )

custDF = con.table("customer") \
            .select("name", "acctbal", "nationkey") \
            .filter(ibis._.acctbal > 9900.0)

apiSQL = custDF.join(nationDF,
    custDF.nationkey == nationDF.n_nationkey) \
            .drop("nationkey", "n_nationkey") \
            .order_by([ibis.desc("acctbal")])

apiSQL[0:10]

#### Or... just run some SQL

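Or run SQL directly. `con.sql()` returns an Ibis table expression, so it can be sliced and displayed like the others.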

In [None]:
sqlDF = con.sql("""
         SELECT c.name, c.acctbal, n.name AS nation_name
           FROM tpch.tiny.customer c
           JOIN tpch.tiny.nation n
             ON c.nationkey = n.nationkey
          WHERE c.acctbal > 9900.0
          ORDER BY c.acctbal DESC
""")
sqlDF[0:10]

### Richer examples

As in the PyStarburst section, these examples cover a three-table join with an aggregation and a window function.

#### Example 1 (joining 3 tables)

Description: Aggregate total customer acctbal by region name

Tables: tpch.tiny.customer, tpch.tiny.nation, tpch.tiny.region

In [None]:
nation = con.table("nation", database="tpch.tiny")
region = con.table("region", database="tpch.tiny")
customer = con.table("customer", database="tpch.tiny")

# nation >< region, n.regionkey = r.regionkey
nr = nation.join(region, nation["regionkey"] == region["regionkey"]).select(
    nation["nationkey"], region["name"].name("region_name")
)

# customer >< (nation, region) nr, c.nationkey = nr.nationkey
agg = (
    customer.join(nr, customer["nationkey"] == nr["nationkey"])
            .group_by("region_name")
            .aggregate(total_acctbal=customer["acctbal"].sum())
            .order_by(ibis.desc("total_acctbal"))
)

agg[0:10]

#### Example 2 (windowing example)

For each nation, get the top N customers by `acctbal`.

Tables: tpch.tiny.customer, tpch.tiny.nation

In [None]:
expr = con.sql("""
        SELECT
          c.custkey,
          c.name,
          c.acctbal,
          n.name AS nation_name,
          row_number() OVER (
            PARTITION BY
              n.name
            ORDER BY
              c.acctbal DESC
          ) AS rn
        FROM
          tpch.tiny.customer c
          JOIN tpch.tiny.nation n ON c.nationkey = n.nationkey
        WHERE
          c.acctbal > 8000.0
""")
top_n = 10  # N: how many customers to keep per nation
top_x = expr.filter(expr.rn <= top_n).order_by([ibis.desc("acctbal"), "nation_name"])

top_x[0:10]
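The filter on `rn` decides what "top N" means: `rn <= N` keeps every customer ranked 1 through N in each nation, while `rn == N` keeps only the customer ranked exactly Nth. The sketch below shows the difference by running the same window query shape on the standard-library `sqlite3` module (window functions need SQLite 3.25 or later) with a few made-up rows and N = 2; `sqlite3` is only a stand-in for Trino, and `names_with_rank` is a hypothetical helper.

```python
import sqlite3

# sqlite3 stands in for Trino; the rows are made up
demo = sqlite3.connect(":memory:")
demo.executescript("""
    CREATE TABLE nation (nationkey INTEGER, name TEXT);
    CREATE TABLE customer (custkey INTEGER, name TEXT, acctbal REAL, nationkey INTEGER);
    INSERT INTO nation VALUES (2, 'BRAZIL'), (3, 'CANADA');
    INSERT INTO customer VALUES
        (1, 'cust_a', 9500.0, 2), (2, 'cust_b', 8700.0, 2),
        (3, 'cust_c', 9900.0, 3), (4, 'cust_d', 8100.0, 3),
        (5, 'cust_e', 9100.0, 3);
""")

# Same shape as the SQL in the Ibis cell: number customers within each nation
RANKED_SQL = """
    SELECT c.name, n.name AS nation_name,
           row_number() OVER (PARTITION BY n.name ORDER BY c.acctbal DESC) AS rn
      FROM customer c
      JOIN nation n ON c.nationkey = n.nationkey
     WHERE c.acctbal > 8000.0
"""


def names_with_rank(op, n):
    """Return customer names whose rn satisfies `rn <op> n`.

    `op` is interpolated into the SQL, so it must be a trusted literal
    such as "<=" or "="; `n` is passed as a bound parameter."""
    sql = f"SELECT name FROM ({RANKED_SQL}) AS ranked WHERE rn {op} ? ORDER BY nation_name, rn"
    return [name for (name,) in demo.execute(sql, (n,))]


top_two = names_with_rank("<=", 2)     # ranks 1 and 2 in every nation
second_only = names_with_rank("=", 2)  # only the customer ranked exactly 2nd
print(top_two, second_only)
```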
