<a href="https://colab.research.google.com/github/lestermartin/starburst-dataframes-exploration/blob/main/StarburstPythonOptions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Python options when using Trino & Starburst

This notebook shows the primary options available to Python programmers working with Trino and Starburst.

1.   [Python UDFs](https://trino.io/docs/current/udf/python.html)
2.   [Python client](https://github.com/trinodb/trino-python-client)
3.   [PyStarburst](https://docs.starburst.io/clients/python/pystarburst.html) (ONLY available with Starburst)
4.   [Ibis](https://ibis-project.org/)



## Trino Python user-defined functions

Python user-defined functions (UDFs) are available in Trino. In Starburst they are only available in Starburst Enterprise (SEP), and even there only as a public preview. They are not available in Starburst Galaxy.

Do not try to run Python UDFs from this notebook. Create and call them from the Starburst UI or the Trino CLI instead.
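A natural way to start is to take an existing SQL UDF and port it to Python. The sketch below pairs a tiny SQL UDF with a Python UDF that does the same thing. It follows the inline `WITH FUNCTION ... LANGUAGE PYTHON` syntax from the Trino Python UDF documentation, so check it against your Trino or SEP version. The SQL is kept in Python strings purely for reference, to be pasted into the Starburst UI or the Trino CLI; the names `doubleup` and `twice` are illustrative. Because the handler is ordinary Python, its logic can be checked locally before it goes into the `$$ ... $$` body.

```python
# Illustrative names: doubleup (the UDF that SQL calls) and twice (the Python handler).

# 1. A SQL UDF, defined inline for a single query
SQL_UDF = """
WITH
  FUNCTION doubleup(x integer)
    RETURNS integer
    RETURN x * 2
SELECT doubleup(21)
"""

# 2. The same function ported to a Python UDF; the handler property names the
#    Python function inside the $$ ... $$ block that Trino calls for each row
PYTHON_UDF = """
WITH
  FUNCTION doubleup(x integer)
    RETURNS integer
    LANGUAGE PYTHON
    WITH (handler = 'twice')
    AS $$
    def twice(a):
        return a * 2
    $$
SELECT doubleup(21)
"""


# The handler is plain Python, so its logic can be exercised locally first
def twice(a):
    return a * 2


print(twice(21))  # 42
```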

In [3]:
import getpass

# prompt for the connection details used by every connection made later in this notebook
my_host = input("Host name: ")
my_username = input("User name: ")
my_password = getpass.getpass("Password: ")



Host name: lester-free-cluster.trino.galaxy.starburst.io
User name: lester.martin@starburstdata.com/accountadmin
Password: ··········
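Prompting works well in an interactive notebook. For unattended runs, one alternative is to read the same three values from environment variables and prompt only for whatever is missing. This is a minimal sketch; `get_credentials` and the variable names `TRINO_HOST`, `TRINO_USER`, and `TRINO_PASSWORD` are hypothetical, chosen for this example rather than defined by Trino or Starburst.

```python
import getpass
import os


def get_credentials():
    """Return (host, user, password), preferring the hypothetical TRINO_* variables."""
    # Fall back to an interactive prompt only when a variable is unset or empty
    host = os.environ.get("TRINO_HOST") or input("Host name: ")
    user = os.environ.get("TRINO_USER") or input("User name: ")
    # getpass hides the typed password
    password = os.environ.get("TRINO_PASSWORD") or getpass.getpass("Password: ")
    return host, user, password
```

Calling `my_host, my_username, my_password = get_credentials()` would then stand in for the three prompts.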


## Trino Python client

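The `trino` package is a Python DB-API 2.0 (PEP 249) client: open a connection, create a cursor, execute SQL, and fetch the rows. The cells below connect to the cluster entered above and then load query results into a pandas DataFrame.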


In [22]:
%pip install trino



In [3]:
from trino.dbapi import connect
from trino.auth import BasicAuthentication

# sanity check
print('\n Make sure the phrase ** CONNECTION IS GOOD ** displays \n')


# build the connection object with the host name & credentials entered earlier
conn = connect(
    host=my_host,
    port=443,
    user=my_username,
    auth=BasicAuthentication(my_username, my_password),
    http_scheme="https",
    catalog="system",
    schema="runtime",
)
cur = conn.cursor()
cur.execute("SELECT '** CONNECTION IS GOOD **'")
rows = cur.fetchall()
print(rows)


 Make sure the phrase ** CONNECTION IS GOOD ** displays 

[['** CONNECTION IS GOOD **']]


In [4]:
import pandas as pd


cur = conn.cursor()
cur.execute("SELECT * FROM tpch.tiny.nation")
rows = cur.fetchall()
col_name = [desc[0] for desc in cur.description]
pdf = pd.DataFrame(rows, columns=col_name)

pdf




|   | nationkey | name | regionkey | comment |
|---|---|---|---|---|
| 0 | 0 | ALGERIA | 0 | haggle. carefully final deposits detect slyly... |
| 1 | 1 | ARGENTINA | 1 | al foxes promise slyly according to the regula... |
| 2 | 2 | BRAZIL | 1 | y alongside of the pending deposits. carefully... |
| 3 | 3 | CANADA | 1 | eas hang ironic, silent packages. slyly regula... |
| 4 | 4 | EGYPT | 4 | y above the carefully unusual theodolites. fin... |
| 5 | 5 | ETHIOPIA | 0 | ven packages wake quickly. regu |
| 6 | 6 | FRANCE | 3 | refully final requests. regular, ironi |
| 7 | 7 | GERMANY | 3 | l platelets. regular accounts x-ray: unusual, ... |
| 8 | 8 | INDIA | 2 | ss excuses cajole slyly across the packages. d... |
| 9 | 9 | INDONESIA | 2 | slyly express asymptotes. regular deposits ha... |
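The `cur.description` step above is plain DB-API: each entry is a sequence whose first item is the column name. Here is a minimal sketch of the same idea without pandas, turning fetched rows into dictionaries. The helpers `records_from` and `rows_as_dicts` are hypothetical, not part of the `trino` package.

```python
def records_from(description, rows):
    """Pair each row with the column names taken from a DB-API cursor description."""
    # PEP 249: each description entry is a sequence whose first item is the column name
    columns = [desc[0] for desc in description]
    return [dict(zip(columns, row)) for row in rows]


def rows_as_dicts(cursor):
    """Fetch all remaining rows from a DB-API cursor (e.g. a trino cursor) as dicts."""
    rows = cursor.fetchall()
    return records_from(cursor.description, rows)
```

With the connection above, calling `rows_as_dicts(cur)` right after `cur.execute("SELECT * FROM tpch.tiny.nation")` would return one dict per nation.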


## PyStarburst

The next two cells install PyStarburst and run some boilerplate setup code.

In [5]:
%pip install pystarburst



In [28]:
import trino

from pystarburst import Session
from pystarburst import functions as F
from pystarburst.functions import *  # brings col(), row_number(), etc. into scope for the cells below
from pystarburst.window import Window as W

# PyStarburst setup
session_properties = {
    "host":my_host,
    "port": 443,
    "http_scheme": "https",
    "auth": trino.auth.BasicAuthentication(my_username, my_password)
}
session = Session.builder.configs(session_properties).create()

# validate PyStarburst working
print('\n Make sure the phrase ** CONNECTION IS GOOD ** displays \n')
session.sql("select '** CONNECTION IS GOOD **' as conn_check").collect()


 Make sure the phrase ** CONNECTION IS GOOD ** displays 



[Row(conn_check='** CONNECTION IS GOOD **')]

### Walk before running

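The cells below build a query one step at a time, saving each intermediate result as a new DataFrame, and then chain all of the steps into a single expression. PyStarburst DataFrames are evaluated lazily: each step only describes the query, which runs in the cluster when an action such as `show()` or `collect()` is called.

Note: This code was originally published in [pystarburst (the dataframe api)](https://lestermartin.blog/2023/09/12/pystarburst-the-dataframe-api/).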

#### Select a full table

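`session.table()` returns a DataFrame for an existing table, and `show()` runs the query and prints the first rows (10 by default).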

In [7]:
custDF = session.table("tpch.tiny.customer")
custDF.show()

-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|"custkey"  |"name"              |"address"                             |"nationkey"  |"phone"          |"acctbal"  |"mktsegment"  |"comment"                                           |
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|751        |Customer#000000751  |e OSrreG6sx7l1t3wAg8u11DWk D 9        |0            |10-658-550-2257  |2130.98    |FURNITURE     |ges sleep furiously bold deposits. furiously re...  |
|752        |Customer#000000752  |KtdEacPUecPdPLt99kwZrnH9oIxUxpw       |8            |18-924-993-6038  |8363.66    |MACHINERY     |mong the ironic, final waters. regular deposits...  |
|753        |Customer#000000753  |9k2PLlDRbMq4oSvW5Hh7Ak5iRDH         

#### Use projection

`select()` keeps only the listed columns, just like the column list of a SQL `SELECT`.

In [8]:
projectedDF = custDF.select(custDF.name, custDF.acctbal, custDF.nationkey)
projectedDF.show()

------------------------------------------------
|"name"              |"acctbal"  |"nationkey"  |
------------------------------------------------
|Customer#000000751  |2130.98    |0            |
|Customer#000000752  |8363.66    |8            |
|Customer#000000753  |8114.44    |17           |
|Customer#000000754  |-566.86    |0            |
|Customer#000000755  |7631.94    |16           |
|Customer#000000756  |8116.99    |14           |
|Customer#000000757  |9334.82    |3            |
|Customer#000000758  |6352.14    |17           |
|Customer#000000759  |3477.59    |1            |
|Customer#000000760  |2883.24    |2            |
------------------------------------------------



#### Filter the rows

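`filter()` keeps only the rows that satisfy a condition, like a SQL `WHERE` clause. Passing `100` to `show()` raises the number of rows printed so that every match appears.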

In [9]:
filteredDF = projectedDF.filter(projectedDF.acctbal > 9900.0)
filteredDF.show(100)

------------------------------------------------
|"name"              |"acctbal"  |"nationkey"  |
------------------------------------------------
|Customer#000000043  |9904.28    |19           |
|Customer#000000045  |9983.38    |9            |
|Customer#000000140  |9963.15    |4            |
|Customer#000000200  |9967.6     |16           |
|Customer#000000213  |9987.71    |24           |
|Customer#000000381  |9931.71    |5            |
|Customer#000001106  |9977.62    |21           |
------------------------------------------------



#### Select a second table

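Load the `nation` table, drop the two columns that are not needed, and rename the remaining two so they do not clash with the `customer` column names after the join.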

In [10]:
nationDF = session.table("tpch.tiny.nation") \
                  .drop("regionkey", "comment") \
                  .rename("name", "nation_name") \
                  .rename("nationkey", "n_nationkey")
nationDF.show()

---------------------------------
|"n_nationkey"  |"nation_name"  |
---------------------------------
|0              |ALGERIA        |
|1              |ARGENTINA      |
|2              |BRAZIL         |
|3              |CANADA         |
|4              |EGYPT          |
|5              |ETHIOPIA       |
|6              |FRANCE         |
|7              |GERMANY        |
|8              |INDIA          |
|9              |INDONESIA      |
---------------------------------



#### Join the tables

`join()` matches each filtered customer to its nation on the nation key. Both key columns are kept in the result.

In [11]:
joinedDF = filteredDF.join(nationDF, filteredDF.nationkey == nationDF.n_nationkey)
joinedDF.show()

--------------------------------------------------------------------------------
|"name"              |"acctbal"  |"nationkey"  |"n_nationkey"  |"nation_name"  |
--------------------------------------------------------------------------------
|Customer#000000140  |9963.15    |4            |4              |EGYPT          |
|Customer#000000381  |9931.71    |5            |5              |ETHIOPIA       |
|Customer#000000045  |9983.38    |9            |9              |INDONESIA      |
|Customer#000000200  |9967.6     |16           |16             |MOZAMBIQUE     |
|Customer#000000043  |9904.28    |19           |19             |ROMANIA        |
|Customer#000001106  |9977.62    |21           |21             |VIETNAM        |
|Customer#000000213  |9987.71    |24           |24             |UNITED STATES  |
--------------------------------------------------------------------------------



#### Project the joined result

With the join done, drop the two key columns.

In [12]:
projectedJoinDF = joinedDF.drop("nationkey").drop("n_nationkey")
projectedJoinDF.show()

--------------------------------------------------
|"name"              |"acctbal"  |"nation_name"  |
--------------------------------------------------
|Customer#000000140  |9963.15    |EGYPT          |
|Customer#000000381  |9931.71    |ETHIOPIA       |
|Customer#000000045  |9983.38    |INDONESIA      |
|Customer#000000200  |9967.6     |MOZAMBIQUE     |
|Customer#000000043  |9904.28    |ROMANIA        |
|Customer#000001106  |9977.62    |VIETNAM        |
|Customer#000000213  |9987.71    |UNITED STATES  |
--------------------------------------------------



#### Apply a sort

`sort()` orders the rows; `ascending=False` puts the highest account balances first.

In [13]:
orderedDF = projectedJoinDF.sort(col("acctbal"), ascending=False)
orderedDF.show()

--------------------------------------------------
|"name"              |"acctbal"  |"nation_name"  |
--------------------------------------------------
|Customer#000000213  |9987.71    |UNITED STATES  |
|Customer#000000045  |9983.38    |INDONESIA      |
|Customer#000001106  |9977.62    |VIETNAM        |
|Customer#000000200  |9967.6     |MOZAMBIQUE     |
|Customer#000000140  |9963.15    |EGYPT          |
|Customer#000000381  |9931.71    |ETHIOPIA       |
|Customer#000000043  |9904.28    |ROMANIA        |
--------------------------------------------------



#### Put it all together

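The same steps can also be chained into one expression. `col()` refers to a column by name, so no variable is needed for each intermediate DataFrame.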

In [15]:
nationDF = session.table("tpch.tiny.nation") \
            .drop("regionkey", "comment") \
            .rename("name", "nation_name") \
            .rename("nationkey", "n_nationkey")

apiSQL = session.table("tpch.tiny.customer") \
            .select("name", "acctbal", "nationkey") \
            .filter(col("acctbal") > 9900.0) \
            .join(nationDF, col("nationkey") == nationDF.n_nationkey) \
            .drop("nationkey").drop("n_nationkey") \
            .sort(col("acctbal"), ascending=False)
apiSQL.show()

--------------------------------------------------
|"name"              |"acctbal"  |"nation_name"  |
--------------------------------------------------
|Customer#000000213  |9987.71    |UNITED STATES  |
|Customer#000000045  |9983.38    |INDONESIA      |
|Customer#000001106  |9977.62    |VIETNAM        |
|Customer#000000200  |9967.6     |MOZAMBIQUE     |
|Customer#000000140  |9963.15    |EGYPT          |
|Customer#000000381  |9931.71    |ETHIOPIA       |
|Customer#000000043  |9904.28    |ROMANIA        |
--------------------------------------------------
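To make the semantics of the chained pipeline concrete, here is the same filter, join, drop, and sort sequence over plain Python lists. It only illustrates what the query computes, using a handful of rows copied from the outputs above; PyStarburst itself does not process rows in Python but sends the work to the cluster as SQL. The function name `high_balance_customers` is hypothetical.

```python
# A few (name, acctbal, nationkey) rows copied from the customer outputs above
customers = [
    ("Customer#000000751", 2130.98, 0),
    ("Customer#000000757", 9334.82, 3),
    ("Customer#000000043", 9904.28, 19),
    ("Customer#000000045", 9983.38, 9),
    ("Customer#000000213", 9987.71, 24),
]
# n_nationkey -> nation_name, also copied from the outputs above
nations = {0: "ALGERIA", 3: "CANADA", 9: "INDONESIA", 19: "ROMANIA", 24: "UNITED STATES"}


def high_balance_customers(customers, nations, threshold=9900.0):
    # .filter(col("acctbal") > 9900.0)
    filtered = [c for c in customers if c[1] > threshold]
    # .join(nationDF, ...): inner join that keeps both key columns
    joined = [(name, bal, nk, nk, nations[nk]) for name, bal, nk in filtered if nk in nations]
    # .drop("nationkey").drop("n_nationkey")
    projected = [(name, bal, nation) for name, bal, _, _, nation in joined]
    # .sort(col("acctbal"), ascending=False)
    return sorted(projected, key=lambda r: r[1], reverse=True)


for row in high_balance_customers(customers, nations):
    print(row)
```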



#### Or... just run some SQL

`session.sql()` turns a SQL string into a DataFrame, so the same result can be produced with plain SQL.

In [29]:
sqlDF = session.sql("""
    SELECT c.name, c.acctbal, n.name AS nation_name
      FROM tpch.tiny.customer c
      JOIN tpch.tiny.nation n
        ON c.nationkey = n.nationkey
     WHERE c.acctbal > 9900.0
     ORDER BY c.acctbal DESC
""")
sqlDF.show()

--------------------------------------------------
|"name"              |"acctbal"  |"nation_name"  |
--------------------------------------------------
|Customer#000000213  |9987.71    |UNITED STATES  |
|Customer#000000045  |9983.38    |INDONESIA      |
|Customer#000001106  |9977.62    |VIETNAM        |
|Customer#000000200  |9967.6     |MOZAMBIQUE     |
|Customer#000000140  |9963.15    |EGYPT          |
|Customer#000000381  |9931.71    |ETHIOPIA       |
|Customer#000000043  |9904.28    |ROMANIA        |
--------------------------------------------------



### Richer examples

The examples below go a step further: a three-table join with an aggregation, and a window function.


#### Example 1 (joining 3 tables)

Description: Aggregate total customer acctbal by region name

Tables: tpch.tiny.customer, tpch.tiny.nation, tpch.tiny.region

In [56]:
nation = session.table("tpch.tiny.nation")
region = session.table("tpch.tiny.region")
customer = session.table("tpch.tiny.customer")

nr = nation.join(region, nation.regionkey == region.regionkey) \
           .select(nation.nationkey.alias("nationkey"),
                   region.name.alias("region_name"))

result = customer.join(nr, customer.nationkey == nr.nationkey) \
                 .groupBy("region_name") \
                 .agg(("acctbal", "sum")) \
                 .rename("sum(acctbal)", "total_acctbal") \
                 .sort(col("total_acctbal").desc())

result.show()

--------------------------------------
|"region_name"  |"total_acctbal"     |
--------------------------------------
|ASIA           |1499764.889999999   |
|MIDDLE EAST    |1437184.9000000008  |
|AFRICA         |1374136.539999999   |
|AMERICA        |1264568.9199999995  |
|EUROPE         |1106210.3400000012  |
--------------------------------------
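Conceptually, the aggregation looks up each customer's region through `nation` and then sums `acctbal` per region name. Here is a minimal sketch of that logic in plain Python. The function name and the sample keys, names, and balances are hypothetical, not TPC-H data.

```python
from collections import defaultdict


def total_acctbal_by_region(customers, nation_region, region_names):
    """Sum acctbal per region name, largest total first.

    customers:     iterable of (nationkey, acctbal) pairs
    nation_region: dict nationkey -> regionkey (the nation/region join)
    region_names:  dict regionkey -> region name
    """
    totals = defaultdict(float)
    for nationkey, acctbal in customers:
        if nationkey in nation_region:  # inner join on nationkey
            region = region_names[nation_region[nationkey]]
            totals[region] += acctbal  # group by region_name, sum(acctbal)
    # sort(col("total_acctbal").desc())
    return sorted(totals.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical sample data
sample_customers = [(1, 100.0), (2, 50.0), (1, 25.0), (3, 10.0)]
sample_nation_region = {1: 0, 2: 1, 3: 1}
sample_region_names = {0: "REGION_A", 1: "REGION_B"}
print(total_acctbal_by_region(sample_customers, sample_nation_region, sample_region_names))
# [('REGION_A', 125.0), ('REGION_B', 60.0)]
```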



#### Example 2 (windowing example)

For each nation, get top N customers by acctbal

Tables: tpch.tiny.customer, tpch.tiny.nation

In [20]:
from pystarburst.window import Window

customer = session.table("tpch.tiny.customer")
nation = session.table("tpch.tiny.nation") \
            .drop("regionkey", "comment") \
            .rename("name", "nation_name") \
            .rename("nationkey", "n_nationkey")

filtered = customer.select("custkey", "name", "acctbal", "nationkey") \
            .filter(col("acctbal") > 8000.0)
joined = filtered.join(nation, col("nationkey") == nation.n_nationkey) \
            .drop("nationkey", "n_nationkey")

# number the customers within each nation, highest acctbal first
w = Window.partitionBy("nation_name").orderBy(col("acctbal").desc())

top_n = 1  # the "N" in "top N customers per nation"
ranked = joined.select("*", row_number().over(w).alias("rn"))
top_x = ranked.filter(col("rn") <= top_n) \
              .sort(col("acctbal").desc(), col("nation_name"))
top_x.show(25)
top_x.show(25)

----------------------------------------------------------------------
|"custkey"  |"name"              |"acctbal"  |"nation_name"   |"rn"  |
----------------------------------------------------------------------
|213        |Customer#000000213  |9987.71    |UNITED STATES   |1     |
|45         |Customer#000000045  |9983.38    |INDONESIA       |1     |
|1106       |Customer#000001106  |9977.62    |VIETNAM         |1     |
|200        |Customer#000000200  |9967.6     |MOZAMBIQUE      |1     |
|140        |Customer#000000140  |9963.15    |EGYPT           |1     |
|381        |Customer#000000381  |9931.71    |ETHIOPIA        |1     |
|43         |Customer#000000043  |9904.28    |ROMANIA         |1     |
|100        |Customer#000000100  |9889.89    |SAUDI ARABIA    |1     |
|780        |Customer#000000780  |9874.12    |INDIA           |1     |
|518        |Customer#000000518  |9871.66    |PERU            |1     |
|197        |Customer#000000197  |9860.22    |ARGENTINA       |1     |
|219  
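`row_number()` over a window partitioned by nation and ordered by `acctbal` descending numbers each nation's customers 1, 2, 3, and so on, restarting for every nation; keeping `rn <= N` then leaves the top N per nation. Here is a plain-Python sketch of that logic. The function name `top_n_per_group` and the sample rows are hypothetical. It does not reproduce the final `sort()`, so rows come out grouped by nation.

```python
from itertools import groupby
from operator import itemgetter


def top_n_per_group(rows, n, part="nation_name", order="acctbal"):
    """Keep the n rows with the largest `order` value in each `part` group, adding rn."""
    # Mirror PARTITION BY nation_name ORDER BY acctbal DESC
    ordered = sorted(rows, key=lambda r: (r[part], -r[order]))
    result = []
    for _, group in groupby(ordered, key=itemgetter(part)):
        # enumerate restarts at 1 for each partition, like row_number()
        for rn, row in enumerate(group, start=1):
            if rn > n:
                break
            result.append({**row, "rn": rn})
    return result


# Hypothetical sample rows
sample = [
    {"name": "cust_a", "acctbal": 9000.0, "nation_name": "NATION_1"},
    {"name": "cust_b", "acctbal": 9500.0, "nation_name": "NATION_1"},
    {"name": "cust_c", "acctbal": 8500.0, "nation_name": "NATION_2"},
]
for row in top_n_per_group(sample, 1):
    print(row)
```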

## Ibis

The next two cells install Ibis and run some boilerplate setup code.

In [None]:
%pip install trino
%pip install 'ibis-framework[trino]'
%pip install pystarburst



In [4]:
import ibis
from trino.auth import BasicAuthentication

# display expressions as executed results instead of expression trees
ibis.options.interactive = True

user = my_username
trino_auth_obj = BasicAuthentication(my_username, my_password)
host = my_host
port = 443
http_scheme = "https"
catalog = "tpch"
schema = "tiny"

# for the Trino backend, Ibis's `database` argument is the Trino catalog
con = ibis.trino.connect(
    user=user, auth=trino_auth_obj, host=host, port=port,
    http_scheme=http_scheme, database=catalog, schema=schema
)

print('\n Make sure the phrase ** CONNECTION IS GOOD ** displays \n')
con.sql("select '** CONNECTION IS GOOD **' as conn_check")


 Make sure the phrase ** CONNECTION IS GOOD ** displays 



### Walk before running

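As with PyStarburst, the cells below build the query one step at a time and then chain the steps together. Because `ibis.options.interactive` is `True`, displaying an expression (for example, as the last line of a cell) executes it and shows the result.

Note: This code was originally published in [ibis & trino (dataframe api part deux)](https://lestermartin.blog/2023/10/27/ibis-trino-dataframe-api-part-deux/).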

#### Select a full table

`con.table()` returns a table expression; slicing it with `[0:10]` limits the output to the first 10 rows.

In [8]:
custDF = con.table("customer")
custDF[0:10]

#### Use projection

`select()` keeps only the listed columns.

In [11]:
projectedDF = custDF.select("name", "acctbal", "nationkey")
projectedDF[0:10]

#### Filter the rows

`filter()` keeps only the rows that satisfy the predicate.

In [13]:
filteredDF = projectedDF.filter(projectedDF["acctbal"] > 9900.0)
filteredDF[0:100]

#### Select a second table

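Load `nation`, drop two columns, and rename the other two. Ibis's `rename()` takes a mapping from each new name to the existing name.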

In [15]:
# Grab new table, drop 2 cols, and rename 2 others
nationDF = con.table("nation") \
            .drop("regionkey", "comment") \
            .rename(
                dict(
                    nation_name="name",
                    n_nationkey="nationkey"
                )
            )
nationDF[0:10]
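The two libraries take rename arguments in opposite directions: PyStarburst's `rename("name", "nation_name")` goes from old to new, while the mapping passed to Ibis goes from new to old. When porting code between them, a tiny helper can convert PyStarburst-style pairs into the Ibis form. `ibis_rename_mapping` is hypothetical and not part of either library.

```python
def ibis_rename_mapping(pairs):
    """Convert (old_name, new_name) pairs into the {new_name: old_name} dict Ibis's rename() accepts."""
    return {new: old for old, new in pairs}


# The renames used for the nation table in both sections
mapping = ibis_rename_mapping([("name", "nation_name"), ("nationkey", "n_nationkey")])
print(mapping)  # {'nation_name': 'name', 'n_nationkey': 'nationkey'}
```

`con.table("nation").drop("regionkey", "comment").rename(mapping)` would then produce the same `nationDF` as above.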

#### Join the tables

Join the filtered customers to the nations on the nation key.

In [17]:
joinedDF = filteredDF.join(nationDF,
    filteredDF.nationkey == nationDF.n_nationkey)
joinedDF[0:10]

#### Project the joined result

Drop the two key columns from the joined result.

In [19]:
projectedJoinDF = joinedDF.drop("nationkey", "n_nationkey")
projectedJoinDF[0:10]

#### Apply a sort

`order_by()` with `ibis.desc()` sorts by account balance, highest first.

In [21]:
orderedDF = projectedJoinDF.order_by([ibis.desc("acctbal")])
orderedDF[0:10]

#### Put it all together

Chain all of the steps into a single expression.

In [22]:
nationDF = con.table("nation") \
            .drop("regionkey", "comment") \
            .rename(
                dict(
                    nation_name="name",
                    n_nationkey="nationkey"
                )
            )

custDF = con.table("customer") \
            .select("name", "acctbal", "nationkey") \
            .filter(ibis._.acctbal > 9900.0)  # ibis._ refers to the table being filtered

apiSQL = custDF.join(nationDF,
    custDF.nationkey == nationDF.n_nationkey) \
            .drop("nationkey", "n_nationkey") \
            .order_by([ibis.desc("acctbal")])

apiSQL[0:10]

#### Or... just run some SQL

`con.sql()` wraps a SQL query as an Ibis table expression, so plain SQL is still an option.

In [55]:
sqlDF = con.sql("""
         SELECT c.name, c.acctbal, n.name AS nation_name
           FROM tpch.tiny.customer c
           JOIN tpch.tiny.nation n
             ON c.nationkey = n.nationkey
          WHERE c.acctbal > 9900.0
          ORDER BY c.acctbal DESC
""")
sqlDF[0:10]

### Richer examples

The same two richer examples, this time with Ibis.

#### Example 1 (joining 3 tables)

Description: Aggregate total customer acctbal by region name

Tables: tpch.tiny.customer, tpch.tiny.nation, tpch.tiny.region

In [57]:
nation = con.table("nation", database="tpch.tiny")
region = con.table("region", database="tpch.tiny")
customer = con.table("customer", database="tpch.tiny")

# nation >< region, n.regionkey = r.regionkey
nr = nation.join(region, nation["regionkey"] == region["regionkey"]).select(
    nation["nationkey"], region["name"].name("region_name")
)

# customer >< (nation, region) nr, c.nationkey = nr.nationkey
agg = (
    customer.join(nr, customer["nationkey"] == nr["nationkey"])
            .group_by("region_name")
            .aggregate(total_acctbal=customer["acctbal"].sum())
            .order_by(ibis.desc("total_acctbal"))
)

agg[0:10]

#### Example 2 (windowing example)

For each nation, get top N customers by acctbal

Tables: tpch.tiny.customer, tpch.tiny.nation

In [59]:
expr = con.sql("""
        SELECT
          c.custkey,
          c.name,
          c.acctbal,
          n.name AS nation_name,
          row_number() OVER (
            PARTITION BY
              n.name
            ORDER BY
              c.acctbal DESC
          ) AS rn
        FROM
          tpch.tiny.customer c
          JOIN tpch.tiny.nation n ON c.nationkey = n.nationkey
        WHERE
          c.acctbal > 8000.0
""")
top_n = 1  # the "N" in "top N customers per nation", matching the PyStarburst version
top_x = expr.filter(expr.rn <= top_n).order_by([ibis.desc("acctbal"), "nation_name"])

top_x[0:10]
