## Set Up the Agent

In [1]:
from langchain_community.agent_toolkits import SparkSQLToolkit, create_spark_sql_agent
from langchain_community.utilities.spark_sql import SparkSQL
from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline

In [63]:
from pyspark.sql import SparkSession
import glob

# Rebuild the `gdelt` schema from the raw CSV files: clear any leftovers from a
# previous run, load the event exports and the CAMEO lookup, join them, and
# persist the result as the `events` table.
spark = SparkSession.builder.getOrCreate()
schema = "gdelt"


def _drop_if_exists(statement):
    """Run one best-effort cleanup statement against the Spark catalog.

    The statements passed in use ``IF EXISTS``, so a missing schema or table
    is not an error. Any other failure is reported and skipped so the rest of
    the cell still runs. ``Exception`` is caught instead of a bare ``except``
    so that ``KeyboardInterrupt`` and ``SystemExit`` are not swallowed.
    """
    try:
        spark.sql(statement)
    except Exception as exc:
        print(f"Skipping cleanup step {statement!r}: {exc}")


# Same three cleanup steps, in the same order, as before.
_drop_if_exists(f"DROP SCHEMA IF EXISTS {schema} CASCADE")
_drop_if_exists("DROP TABLE IF EXISTS events")
_drop_if_exists("DROP TABLE IF EXISTS cameo")

spark.sql(f"CREATE DATABASE IF NOT EXISTS {schema}")
spark.sql(f"USE {schema}")
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
directory = "/home/tiago/factored-datathon-2024-seed42/files"

# Collect every "*export.CSV" file below `directory` (recursive search).
gkg_files = glob.glob(directory + '/**/*export.CSV', recursive=True)
csv_files = [file for file in gkg_files if file.lower().endswith('.csv')]
if not csv_files:
    # Fail early with a clear message instead of handing Spark an empty path list.
    raise FileNotFoundError(f"No '*export.CSV' files found under {directory}")
events_df = spark.read.csv(csv_files, sep='\t', header=True, inferSchema=True)
events_df.show(5)

cameo_df = spark.read.csv(f"{directory}/cameo.csv", sep='\t', header=True, inferSchema=True)
cameo_df.show(5)

# A schema dump captured elsewhere in this notebook lists the lookup's key
# column as `EventCode ` (with a trailing space). Strip whitespace from the
# header names so the rename below always finds "EventCode"; this is a no-op
# when the header is already clean.
cameo_df = cameo_df.toDF(*[name.strip() for name in cameo_df.columns])

# Rename the lookup key so the joined frame does not carry two `EventCode`
# columns, then drop the helper column once the left join is done.
cameo_df = cameo_df.withColumnRenamed("EventCode", "EventCode_join")
merged_df = events_df.join(cameo_df, events_df['EventCode'] == cameo_df['EventCode_join'], "left")
merged_df = merged_df.drop("EventCode_join")

merged_df.write.saveAsTable("events")
spark.sql("SELECT * FROM events").show()


+-------------+--------+---------+----+------------+----------+----------+-----------------+--------------------+----------------+-------------------+-------------------+---------------+---------------+---------------+----------+-------------+-----------------+--------------------+----------------+-------------------+-------------------+---------------+---------------+---------------+-----------+---------+-------------+-------------+---------+--------------+-----------+----------+-----------+-----------------+--------------+--------------------+---------------------+------------------+-------------+--------------+-------------------+--------------+--------------------+---------------------+------------------+-------------+--------------+-------------------+--------------+--------------------+---------------------+------------------+-------------+--------------+-------------------+---------+--------------------+
|GLOBALEVENTID| SQLDATE|MonthYear|Year|FractionDate|Actor1Code|Actor1Name|A

24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 69.09% for 11 writers
24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 63.33% for 12 writers
24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 58.46% for 13 writers
24/08/17 18:18:11 WARN MemoryManager: Total allocation exceeds 95.

+-------------+--------+---------+----+------------+----------+--------------+-----------------+--------------------+----------------+-------------------+-------------------+---------------+---------------+---------------+----------+----------+-----------------+--------------------+----------------+-------------------+-------------------+---------------+---------------+---------------+-----------+---------+-------------+-------------+---------+--------------+-----------+----------+-----------+-----------------+--------------+--------------------+---------------------+------------------+-------------+--------------+-------------------+--------------+--------------------+---------------------+------------------+-------------+--------------+-------------------+--------------+--------------------+---------------------+------------------+-------------+--------------+-------------------+---------+--------------------+--------------------+
|GLOBALEVENTID| SQLDATE|MonthYear|Year|FractionDate|A

In [7]:
import getpass
import os

# Ask for the key only when it is not already set (or is empty), so a full
# re-run of the notebook does not block on the prompt or overwrite a key that
# was exported in the environment. getpass keeps the value out of the output.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter your API token: ")

In [64]:
from langchain.chat_models import ChatOpenAI

# Build the Spark SQL agent: a SparkSQL wrapper bound to `schema`, one chat
# model shared by the toolkit and the agent, and an executor capped at ten
# iterations with verbose tracing switched on.
spark_sql = SparkSQL(
    schema=schema,
)
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
)
toolkit = SparkSQLToolkit(
    llm=llm,
    db=spark_sql,
)
agent_executor = create_spark_sql_agent(
    toolkit=toolkit,
    llm=llm,
    verbose=True,
    max_iterations=10,
)


In [54]:
# Put one natural-language question to the agent and print only the final
# answer, which is the 'output' entry of the returned mapping.
sanction_answer = agent_executor.invoke(
    "Which country is most affect by events whose event description cites sanction?"
)
print(sanction_answer['output'])



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3mAction: list_tables_sql_db  
Action Input: ""  [0m
Observation: [38;5;200m[1;3mcameo, events[0m
Thought:[32;1m[1;3mI need to check the schema of the `cameo` and `events` tables to understand their structure and identify the relevant columns for my query.  
Action: schema_sql_db  
Action Input: "cameo, events"  [0m
Observation: [33;1m[1;3mCREATE TABLE spark_catalog.gdelt.cameo (
  `EventCode ` INT,
  EventName STRING)
;

/*
3 rows from cameo table:
EventCode 	EventName
1	MAKE PUBLIC STATEMENT
10	Make statement, not specified below
11	Decline comment
*/

CREATE TABLE spark_catalog.gdelt.events (
  GLOBALEVENTID INT,
  SQLDATE INT,
  MonthYear INT,
  Year INT,
  FractionDate DOUBLE,
  Actor1Code STRING,
  Actor1Name STRING,
  Actor1CountryCode STRING,
  Actor1KnownGroupCode STRING,
  Actor1EthnicCode STRING,
  Actor1Religion1Code STRING,
  Actor1Religion2Code STRING,
  Actor1Type1Code STRING,
  Actor1Type2Code STRING,
 

KeyboardInterrupt: 