# Spark LLM Assistant

## Initialization

In [1]:
from langchain.chat_models import ChatOpenAI
from spark_llm import SparkLLMAssistant

# Chat model that backs the assistant.
llm = ChatOpenAI(model_name='gpt-4') # using gpt-4 can achieve better results
# NOTE(review): verbose=True is presumably what produces the INFO log lines in the cell outputs — confirm.
assistant=SparkLLMAssistant(llm=llm, verbose=True)
assistant.activate() # activate partial functions for Spark DataFrames

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/06/21 12:02:36 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Example 1: Auto sales by brand in the US, 2022

In [2]:
# Search the web for the description and ingest the matching page into a DataFrame.
# Per the log below, the assistant parses a URL and loads the rows through a generated SQL temp view.
auto_df = assistant.create_df("2022 USA national auto sales by brand")
# Print the first rows to check the ingested columns.
auto_df.show()

[92mINFO: [0mParsing URL: https://www.carpro.com/blog/full-year-2022-national-auto-sales-by-brand

[92mINFO: [0mSQL query for the ingestion:

[92mINFO: [0m[34mCREATE[39;49;00m[37m [39;49;00m[34mOR[39;49;00m[37m [39;49;00m[34mREPLACE[39;49;00m[37m [39;49;00mTEMP[37m [39;49;00m[34mVIEW[39;49;00m[37m [39;49;00mauto_sales_2022[37m [39;49;00m[34mAS[39;49;00m[37m [39;49;00m[34mSELECT[39;49;00m[37m [39;49;00m*[37m [39;49;00m[34mFROM[39;49;00m[37m [39;49;00m[34mVALUES[39;49;00m[37m[39;49;00m
([34m1[39;49;00m,[37m [39;49;00m[33m'Toyota'[39;49;00m,[37m [39;49;00m[34m1849751[39;49;00m,[37m [39;49;00m-[34m9[39;49;00m),[37m[39;49;00m
([34m2[39;49;00m,[37m [39;49;00m[33m'Ford'[39;49;00m,[37m [39;49;00m[34m1767439[39;49;00m,[37m [39;49;00m-[34m2[39;49;00m),[37m[39;49;00m
([34m3[39;49;00m,[37m [39;49;00m[33m'Chevrolet'[39;49;00m,[37m [39;49;00m[34m1502389[39;49;00m,[37m [39;49;00m[34m6[39;49;00m),[37m[39;49;00m

+----+-------------+--------+-------+
|rank|        brand|us_sales|vs_2021|
+----+-------------+--------+-------+
|   1|       Toyota| 1849751|     -9|
|   2|         Ford| 1767439|     -2|
|   3|    Chevrolet| 1502389|      6|
|   4|        Honda|  881201|    -33|
|   5|      Hyundai|  724265|     -2|
|   6|          Kia|  693549|     -1|
|   7|         Jeep|  684612|    -12|
|   8|       Nissan|  682731|    -25|
|   9|       Subaru|  556581|     -5|
|  10|   Ram Trucks|  545194|    -16|
|  11|          GMC|  517649|      7|
|  12|Mercedes-Benz|  350949|      7|
|  13|          BMW|  332388|     -1|
|  14|   Volkswagen|  301069|    -20|
|  15|        Mazda|  294908|    -11|
|  16|        Lexus|  258704|    -15|
|  17|        Dodge|  190793|    -12|
|  18|         Audi|  186875|     -5|
|  19|     Cadillac|  134726|     14|
|  20|     Chrysler|  112713|     -2|
+----+-------------+--------+-------+
only showing top 20 rows



In [5]:
# With no description given, the chart type is left to the assistant; the log below shows it generating Plotly code.
auto_df.llm.plot()

[92mINFO: [0mTo visualize the result of `df` using Plotly, you can follow these steps:

1. First, make sure you have the Plotly package installed. You can install it using `!pip install plotly` if you are using Jupyter Notebook or `pip install plotly` in your terminal.

2. Convert the PySpark DataFrame to a Pandas DataFrame.

3. Create a plot using Plotly's graph objects.

Here's the code:


```
[34mimport[39;49;00m [04m[36mplotly[39;49;00m[04m[36m.[39;49;00m[04m[36mgraph_objects[39;49;00m [34mas[39;49;00m [04m[36mgo[39;49;00m
[34mimport[39;49;00m [04m[36mplotly[39;49;00m[04m[36m.[39;49;00m[04m[36mexpress[39;49;00m [34mas[39;49;00m [04m[36mpx[39;49;00m

[37m# Convert the PySpark DataFrame to a Pandas DataFrame[39;49;00m
pdf = df.toPandas()

[37m# Create a bar plot using Plotly[39;49;00m
fig = px.bar(pdf, x=[33m'[39;49;00m[33mbrand[39;49;00m[33m'[39;49;00m, y=[33m'[39;49;00m[33mus_sales[39;49;00m[33m'[39;49;00m, text=[33m'[39;49;00m[3

In [6]:
# Pass a natural-language description to request a specific chart.
# NOTE(review): "char" looks like a typo for "chart"; the prompt is left as-is because it is runtime text
# passed to the assistant and the recorded output below was produced with it.
auto_df.llm.plot("pie char for market shares, show the top 5 brands and the sum of others")

[92mINFO: [0mTo visualize the result of the `df` using Plotly, you can follow these steps:

1. Convert the PySpark DataFrame to a Pandas DataFrame.
2. Filter the top 5 brands and calculate the sum of others.
3. Create a pie chart using Plotly and show the market shares.

Here's the code to achieve that:


```
[34mimport[39;49;00m [04m[36mpandas[39;49;00m [34mas[39;49;00m [04m[36mpd[39;49;00m
[34mimport[39;49;00m [04m[36mplotly[39;49;00m[04m[36m.[39;49;00m[04m[36mgraph_objs[39;49;00m [34mas[39;49;00m [04m[36mgo[39;49;00m

[37m# Convert the PySpark DataFrame to a Pandas DataFrame[39;49;00m
df_pd = df.toPandas()

[37m# Filter the top 5 brands and calculate the sum of others[39;49;00m
top_5_brands = df_pd.nlargest([34m5[39;49;00m, [33m'[39;49;00m[33mus_sales[39;49;00m[33m'[39;49;00m)
other_brands_total = df_pd[~df_pd.index.isin(top_5_brands.index)][[33m'[39;49;00m[33mus_sales[39;49;00m[33m'[39;49;00m].sum()

[37m# Create a pie chart using Plot

In [9]:
# Apply a transform described in natural language to a DataFrame.
# The log below shows the description being turned into a SQL query over a temporary view.
auto_top_growth_df=auto_df.llm.transform("top brand with the highest growth")
auto_top_growth_df.show()

[92mINFO: [0mSQL query for the transform:
[92mINFO: [0m[34mSELECT[39;49;00m[37m [39;49;00mbrand[37m[39;49;00m
[34mFROM[39;49;00m[37m [39;49;00mtemp_view_for_transform[37m[39;49;00m
[34mORDER[39;49;00m[37m [39;49;00m[34mBY[39;49;00m[37m [39;49;00mvs_2021[37m [39;49;00m[34mDESC[39;49;00m[37m[39;49;00m
[34mLIMIT[39;49;00m[37m [39;49;00m[34m1[39;49;00m[37m[39;49;00m



+-------+
|  brand|
+-------+
|Genesis|
+-------+



In [10]:
# Ask the assistant for a plain-English summary of what this DataFrame retrieves.
# The summary comes back as a string (it is the cell's displayed value below).
auto_top_growth_df.llm.explain()

'In summary, this dataframe is retrieving the brand with the highest value for the `vs_2021` metric from a temporary view named `temp_view_for_transform`. The result will only include the top brand based on the `vs_2021` value.'

In [11]:
# Check a data-quality expectation stated in natural language.
# The check runs against auto_df because it holds the year-over-year change column (vs_2021);
# auto_top_growth_df only contains `brand`, so it has no percentage values to validate.
auto_df.llm.verify("expect sales change percentage to be between -100 to 100")

[92mINFO: [0mGenerated code:
[92mINFO: [0m[34mdef[39;49;00m [32mhas_sales_change_in_range[39;49;00m(df) -> [36mbool[39;49;00m:
    [34mfrom[39;49;00m [04m[36mpyspark[39;49;00m[04m[36m.[39;49;00m[04m[36msql[39;49;00m[04m[36m.[39;49;00m[04m[36mfunctions[39;49;00m [34mimport[39;49;00m col

    [37m# Check if the dataframe has the column "sales_change_percentage"[39;49;00m
    [34mif[39;49;00m [33m"[39;49;00m[33msales_change_percentage[39;49;00m[33m"[39;49;00m [35mnot[39;49;00m [35min[39;49;00m df.columns:
        [34mreturn[39;49;00m [34mFalse[39;49;00m

    [37m# Check if the sales change percentage is between -100 and 100[39;49;00m
    df_filtered = df.filter((col([33m"[39;49;00m[33msales_change_percentage[39;49;00m[33m"[39;49;00m) >= -[34m100[39;49;00m) & (col([33m"[39;49;00m[33msales_change_percentage[39;49;00m[33m"[39;49;00m) <= [34m100[39;49;00m))

    [37m# If the number of rows in the filtered dataframe is equal to 

## Example 2: USA Presidents

In [12]:
# You can also specify the expected columns for the ingestion.
# The second argument lists the requested column names; they match the header of the table printed below.
df=assistant.create_df("USA presidents", ["president", "vice_president"])
df.show()

[92mINFO: [0mParsing URL: https://en.wikipedia.org/wiki/List_of_presidents_of_the_United_States

[92mINFO: [0mSQL query for the ingestion:

[92mINFO: [0m[34mCREATE[39;49;00m[37m [39;49;00m[34mOR[39;49;00m[37m [39;49;00m[34mREPLACE[39;49;00m[37m [39;49;00mTEMP[37m [39;49;00m[34mVIEW[39;49;00m[37m [39;49;00mpresidents[37m [39;49;00m[34mAS[39;49;00m[37m [39;49;00m[34mSELECT[39;49;00m[37m [39;49;00m*[37m [39;49;00m[34mFROM[39;49;00m[37m [39;49;00m[34mVALUES[39;49;00m[37m [39;49;00m
([33m'George Washington'[39;49;00m,[37m [39;49;00m[33m'John Adams'[39;49;00m),[37m[39;49;00m
([33m'John Adams'[39;49;00m,[37m [39;49;00m[33m'Thomas Jefferson'[39;49;00m),[37m[39;49;00m
([33m'Thomas Jefferson'[39;49;00m,[37m [39;49;00m[33m'Aaron Burr'[39;49;00m),[37m[39;49;00m
([33m'Thomas Jefferson'[39;49;00m,[37m [39;49;00m[33m'George Clinton'[39;49;00m),[37m[39;49;00m
([33m'James Madison'[39;49;00m,[37m [39;49;00m[33m'George Clint

+--------------------+--------------------+
|           president|      vice_president|
+--------------------+--------------------+
|   George Washington|          John Adams|
|          John Adams|    Thomas Jefferson|
|    Thomas Jefferson|          Aaron Burr|
|    Thomas Jefferson|      George Clinton|
|       James Madison|      George Clinton|
|       James Madison|      Elbridge Gerry|
|        James Monroe|  Daniel D. Tompkins|
|   John Quincy Adams|     John C. Calhoun|
|      Andrew Jackson|     John C. Calhoun|
|      Andrew Jackson|    Martin Van Buren|
|    Martin Van Buren|Richard Mentor Jo...|
|William Henry Har...|          John Tyler|
|          John Tyler|                null|
|       James K. Polk|    George M. Dallas|
|      Zachary Taylor|    Millard Fillmore|
|    Millard Fillmore|                null|
|     Franklin Pierce|     William R. King|
|      James Buchanan|John C. Breckinridge|
|     Abraham Lincoln|     Hannibal Hamlin|
|     Abraham Lincoln|      Andr

In [13]:
# Derive a DataFrame from a natural-language description.
# The generated SQL (logged below) keeps presidents whose name also appears in the vice_president column.
presidents_who_were_vp = df.llm.transform("presidents who were also vice presidents")
presidents_who_were_vp.show()

[92mINFO: [0mSQL query for the transform:
[92mINFO: [0m[34mSELECT[39;49;00m[37m [39;49;00m[34mDISTINCT[39;49;00m[37m [39;49;00mpresident[37m[39;49;00m
[34mFROM[39;49;00m[37m [39;49;00mtemp_view_for_transform[37m[39;49;00m
[34mWHERE[39;49;00m[37m [39;49;00mpresident[37m [39;49;00m[34mIN[39;49;00m[37m [39;49;00m([34mSELECT[39;49;00m[37m [39;49;00mvice_president[37m [39;49;00m[34mFROM[39;49;00m[37m [39;49;00mtemp_view_for_transform)[37m[39;49;00m



+------------------+
|         president|
+------------------+
|        John Adams|
|  Thomas Jefferson|
|  Martin Van Buren|
|  Millard Fillmore|
|        John Tyler|
|    Andrew Johnson|
| Chester A. Arthur|
|Theodore Roosevelt|
|   Calvin Coolidge|
|   Harry S. Truman|
+------------------+



In [14]:
# Ask for a plain-English summary of the DataFrame.
# NOTE(review): the recorded summary below says "never been vice presidents", which contradicts the
# generated SQL for this DataFrame (president IN (SELECT vice_president ...)). Treat the explanation
# as unverified model output.
presidents_who_were_vp.llm.explain()

'In summary, this dataframe is retrieving the distinct list of presidents who have never been vice presidents from the `presidents` table.'

In [15]:
# Verify an expectation: per the log below, the assistant generates a checking function
# and reports its boolean result.
presidents_who_were_vp.llm.verify("expect no NULL values")

[92mINFO: [0mGenerated code:
[92mINFO: [0m[34mdef[39;49;00m [32mhas_no_null_values[39;49;00m(df) -> [36mbool[39;49;00m:
    [34mfrom[39;49;00m [04m[36mpyspark[39;49;00m[04m[36m.[39;49;00m[04m[36msql[39;49;00m[04m[36m.[39;49;00m[04m[36mfunctions[39;49;00m [34mimport[39;49;00m col
    
    [37m# Check if there are any NULL values in the DataFrame[39;49;00m
    null_counts = df.select([col(c).isNull().alias(c) [34mfor[39;49;00m c [35min[39;49;00m df.columns]).filter([33m"[39;49;00m[33m or [39;49;00m[33m"[39;49;00m.join(df.columns)).count()

    [37m# If there are no NULL values, return True, otherwise return False[39;49;00m
    [34mif[39;49;00m null_counts == [34m0[39;49;00m:
        [34mreturn[39;49;00m [34mTrue[39;49;00m
    [34melse[39;49;00m:
        [34mreturn[39;49;00m [34mFalse[39;49;00m

result = has_no_null_values(df)

[92mINFO: [0m
Result: True


## Example 3: Top 10 tech companies

In [16]:
# Search and ingest web content into a DataFrame, requesting three named columns.
# NOTE(review): the unit of `cap` is not stated anywhere in this notebook — confirm against the source page.
company_df=assistant.create_df("Top 10 tech companies by market cap", ['company', 'cap', 'country'])
company_df.show()

[92mINFO: [0mParsing URL: https://www.statista.com/statistics/1350976/leading-tech-companies-worldwide-by-market-cap/

[92mINFO: [0mSQL query for the ingestion:

[92mINFO: [0m[34mCREATE[39;49;00m[37m [39;49;00m[34mOR[39;49;00m[37m [39;49;00m[34mREPLACE[39;49;00m[37m [39;49;00mTEMP[37m [39;49;00m[34mVIEW[39;49;00m[37m [39;49;00mtop_tech_companies[37m [39;49;00m[34mAS[39;49;00m[37m [39;49;00m[34mSELECT[39;49;00m[37m [39;49;00m*[37m [39;49;00m[34mFROM[39;49;00m[37m [39;49;00m[34mVALUES[39;49;00m[37m[39;49;00m
([33m'Apple'[39;49;00m,[37m [39;49;00m[34m2242[39;49;00m,[37m [39;49;00m[33m'United States'[39;49;00m),[37m[39;49;00m
([33m'Microsoft'[39;49;00m,[37m [39;49;00m[34m1821[39;49;00m,[37m [39;49;00m[33m'United States'[39;49;00m),[37m[39;49;00m
([33m'Alphabet (Google)'[39;49;00m,[37m [39;49;00m[34m1229[39;49;00m,[37m [39;49;00m[33m'United States'[39;49;00m),[37m[39;49;00m
([33m'Amazon'[39;49;00m,[37m [39;4

+--------------------+------+-------------+
|             company|   cap|      country|
+--------------------+------+-------------+
|               Apple|2242.0|United States|
|           Microsoft|1821.0|United States|
|   Alphabet (Google)|1229.0|United States|
|              Amazon| 902.4|United States|
|               Tesla| 541.4|United States|
|                TSMC| 410.9|       Taiwan|
|              NVIDIA| 401.7|United States|
|             Tencent| 377.8|        China|
|Meta Platforms (F...| 302.1|United States|
|             Samsung| 301.7|  South Korea|
+--------------------+------+-------------+



In [17]:
# Filter with a natural-language condition; the generated SQL below uses WHERE country = 'United States'.
us_company_df=company_df.llm.transform("companies in United States")
us_company_df.show()

[92mINFO: [0mSQL query for the transform:
[92mINFO: [0m[34mSELECT[39;49;00m[37m [39;49;00mcompany,[37m [39;49;00mcap,[37m [39;49;00mcountry[37m[39;49;00m
[34mFROM[39;49;00m[37m [39;49;00mtemp_view_for_transform[37m[39;49;00m
[34mWHERE[39;49;00m[37m [39;49;00mcountry[37m [39;49;00m=[37m [39;49;00m[33m'United States'[39;49;00m[37m[39;49;00m



+--------------------+------+-------------+
|             company|   cap|      country|
+--------------------+------+-------------+
|               Apple|2242.0|United States|
|           Microsoft|1821.0|United States|
|   Alphabet (Google)|1229.0|United States|
|              Amazon| 902.4|United States|
|               Tesla| 541.4|United States|
|              NVIDIA| 401.7|United States|
|Meta Platforms (F...| 302.1|United States|
+--------------------+------+-------------+



In [18]:
# Ask for a plain-English summary of what the filtered DataFrame retrieves.
us_company_df.llm.explain()

'In summary, this dataframe is retrieving the company name, market capitalization, and country for all the top tech companies that are based in the United States.'

In [19]:
# Verify uniqueness: the generated check below compares the distinct count of `company` with the row count.
us_company_df.llm.verify("expect all company names to be unique")

[92mINFO: [0mGenerated code:
[92mINFO: [0m[34mdef[39;49;00m [32mhas_unique_company_names[39;49;00m(df) -> [36mbool[39;49;00m:
    [34mfrom[39;49;00m [04m[36mpyspark[39;49;00m[04m[36m.[39;49;00m[04m[36msql[39;49;00m[04m[36m.[39;49;00m[04m[36mfunctions[39;49;00m [34mimport[39;49;00m countDistinct

    [37m# Get the number of unique company names[39;49;00m
    unique_companies = df.agg(countDistinct([33m"[39;49;00m[33mcompany[39;49;00m[33m"[39;49;00m)).collect()[[34m0[39;49;00m][[34m0[39;49;00m]

    [37m# Get the total number of rows in the DataFrame[39;49;00m
    total_rows = df.count()

    [37m# Check if the number of unique company names is equal to the total number of rows[39;49;00m
    [34mif[39;49;00m unique_companies == total_rows:
        [34mreturn[39;49;00m [34mTrue[39;49;00m
    [34melse[39;49;00m:
        [34mreturn[39;49;00m [34mFalse[39;49;00m

result = has_unique_company_names(df)

[92mINFO: [0m
Result: True


## Example 4: Ingestion from a URL
Instead of searching for a web page, you can also ask the assistant to ingest directly from a URL.

In [20]:
# Pass a URL instead of a search phrase; the log below shows that exact URL being parsed.
best_albums_df = assistant.create_df('https://time.com/6235186/best-albums-2022/', ["album", "artist", "year"])
best_albums_df.show()

[92mINFO: [0mParsing URL: https://time.com/6235186/best-albums-2022/

[92mINFO: [0mSQL query for the ingestion:

[92mINFO: [0m[34mCREATE[39;49;00m[37m [39;49;00m[34mOR[39;49;00m[37m [39;49;00m[34mREPLACE[39;49;00m[37m [39;49;00mTEMP[37m [39;49;00m[34mVIEW[39;49;00m[37m [39;49;00mbest_albums_2022[37m [39;49;00m[34mAS[39;49;00m[37m [39;49;00m[34mSELECT[39;49;00m[37m [39;49;00m*[37m [39;49;00m[34mFROM[39;49;00m[37m [39;49;00m[34mVALUES[39;49;00m[37m[39;49;00m
([33m'Motomami'[39;49;00m,[37m [39;49;00m[33m'Rosalía'[39;49;00m,[37m [39;49;00m[34m2022[39;49;00m),[37m[39;49;00m
([33m'You Can\'[39;49;00mt[37m [39;49;00mKill[37m [39;49;00mMe[33m', '[39;49;00m[34m070[39;49;00m[37m [39;49;00mShake[33m', 2022),[39;49;00m
[33m('[39;49;00mMr.[37m [39;49;00mMorale[37m [39;49;00m&[37m [39;49;00mThe[37m [39;49;00mBig[37m [39;49;00mSteppers[33m', '[39;49;00mKendrick[37m [39;49;00mLamar[33m', 2022),[39;49;00m
[33m('[3

+--------------------+--------------------+----+
|               album|              artist|year|
+--------------------+--------------------+----+
|            Motomami|             Rosalía|2022|
|   You Can't Kill Me|           070 Shake|2022|
|Mr. Morale & The ...|      Kendrick Lamar|2022|
|            Big Time|         Angel Olsen|2022|
|         Electricity|Ibibio Sound Machine|2022|
|     It's Almost Dry|             Pusha T|2022|
|Chloe and the Nex...|   Father John Misty|2022|
|         Renaissance|             Beyoncé|2022|
|          19 Masters|           Saya Gray|2022|
|    Un Verano Sin Ti|           Bad Bunny|2022|
+--------------------+--------------------+----+



In [21]:
# Verify a value constraint: the generated check below counts rows whose `year` is not 2022.
best_albums_df.llm.verify("expect each year to be 2022")

[92mINFO: [0mGenerated code:
[92mINFO: [0m[34mdef[39;49;00m [32mhas_2022_years[39;49;00m(df) -> [36mbool[39;49;00m:
    [34mfrom[39;49;00m [04m[36mpyspark[39;49;00m[04m[36m.[39;49;00m[04m[36msql[39;49;00m[04m[36m.[39;49;00m[04m[36mfunctions[39;49;00m [34mimport[39;49;00m col

    [37m# Filter the rows where the year is not 2022[39;49;00m
    not_2022_rows = df.filter(col([33m"[39;49;00m[33myear[39;49;00m[33m"[39;49;00m) != [34m2022[39;49;00m)

    [37m# Check if there are any rows with a year other than 2022[39;49;00m
    [34mif[39;49;00m not_2022_rows.count() == [34m0[39;49;00m:
        [34mreturn[39;49;00m [34mTrue[39;49;00m
    [34melse[39;49;00m:
        [34mreturn[39;49;00m [34mFalse[39;49;00m

result = has_2022_years(df)

[92mINFO: [0m
Result: True


## Example 5: UDF Generation

You can also ask the assistant to generate the code for a Spark UDF by providing only the function signature and a docstring that describes the desired behavior.

In [22]:
# Only the signature and docstring are written by hand; the body is a placeholder (`...`).
# The log below shows the assistant creating the actual implementation.
# NOTE(review): the docstring is presumably the description handed to the LLM, so it is left
# untouched here — confirm before rewording it.
@assistant.udf
def convert_grades(grade_percent: float) -> str:
    """Convert the grade percent to a letter grade using standard cutoffs"""
    ...

[92mINFO: [0mCreating following Python UDF:
[92mINFO: [0m[34mdef[39;49;00m [32mconvert_grades[39;49;00m(grade_percent: [36mfloat[39;49;00m) -> [36mstr[39;49;00m:
    [34mif[39;49;00m grade_percent [35mis[39;49;00m [35mnot[39;49;00m [34mNone[39;49;00m:
        [34mif[39;49;00m grade_percent >= [34m90.0[39;49;00m:
            [34mreturn[39;49;00m [33m"[39;49;00m[33mA[39;49;00m[33m"[39;49;00m
        [34melif[39;49;00m grade_percent >= [34m80.0[39;49;00m:
            [34mreturn[39;49;00m [33m"[39;49;00m[33mB[39;49;00m[33m"[39;49;00m
        [34melif[39;49;00m grade_percent >= [34m70.0[39;49;00m:
            [34mreturn[39;49;00m [33m"[39;49;00m[33mC[39;49;00m[33m"[39;49;00m
        [34melif[39;49;00m grade_percent >= [34m60.0[39;49;00m:
            [34mreturn[39;49;00m [33m"[39;49;00m[33mD[39;49;00m[33m"[39;49;00m
        [34melse[39;49;00m:
            [34mreturn[39;49;00m [33m"[39;49;00m[33mF[39;49;00m[33m"[39;

In [23]:
from pyspark.sql import SparkSession

# Get the active Spark session (or create one) so the UDF can be registered for SQL use.
spark = SparkSession.builder.getOrCreate()
# Register under the same name so SQL expressions can call convert_grades(...).
spark.udf.register("convert_grades", convert_grades)

# Sample rows as (student_id, grade_percent).
percent_grades = [(1, 97.8), (2, 72.3), (3, 81.2)]
# Use a dedicated name instead of `df`: `df` already holds the presidents DataFrame from Example 2,
# and overwriting it would break re-running the Example 2 cells later in the same kernel.
grades_df = spark.createDataFrame(percent_grades, ["student_id", "grade_percent"])
grades_df.selectExpr("student_id", "convert_grades(grade_percent)").show()

[Stage 18:>                                                         (0 + 1) / 1]

+----------+-----------------------------+
|student_id|convert_grades(grade_percent)|
+----------+-----------------------------+
|         1|                            A|
|         2|                            C|
|         3|                            B|
+----------+-----------------------------+



                                                                                

## Cache
The `SparkLLMAssistant` supports a simple cache system with an in-memory staging layer and a persistent layer. The in-memory staging cache is updated with LLM and web search results, and it can be persisted through the `commit()` method. Cache lookup is always performed on the persistent cache only.

In [7]:
# Persist the in-memory staging cache, as described in the Cache section above.
assistant.commit()