In [1]:
from pyspark.sql import SparkSession

# Create (or reuse) the notebook's SparkSession against the cluster master
# below, with Hive support enabled so tables can be saved via saveAsTable.
spark = (
    SparkSession.builder
    .appName("spark-nb")
    .master("spark://spark-master:7077")
    .enableHiveSupport()
    .getOrCreate()
)

In [2]:
# Load every Parquet file under the raw-data bucket into one Spark DataFrame.
df = spark.read.parquet("s3a://raw-data/")

In [3]:
# Print column names, Spark data types and nullability as read from the Parquet files.
df.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- lpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- lpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- ehail_fee: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- trip_type: long (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



In [4]:
# Preview the first five rows as a plain-text table.
df.show(5)

+--------+--------------------+---------------------+------------------+----------+------------+------------+---------------+-------------+-----------+-----+-------+----------+------------+---------+---------------------+------------+------------+---------+--------------------+
|VendorID|lpep_pickup_datetime|lpep_dropoff_datetime|store_and_fwd_flag|RatecodeID|PULocationID|DOLocationID|passenger_count|trip_distance|fare_amount|extra|mta_tax|tip_amount|tolls_amount|ehail_fee|improvement_surcharge|total_amount|payment_type|trip_type|congestion_surcharge|
+--------+--------------------+---------------------+------------------+----------+------------+------------+---------------+-------------+-----------+-----+-------+----------+------------+---------+---------------------+------------+------------+---------+--------------------+
|       2| 2024-03-01 00:10:52|  2024-03-01 00:26:12|                 N|         1|         129|         226|              1|         1.72|       12.8|  1.0|    0.

In [5]:
# toPandas() materialises every row in driver memory, so cut the Spark
# DataFrame down to the rows we actually display before converting.
# Collecting the whole dataset just to look at five rows risks exhausting
# the driver's memory on larger inputs.
pandas_df = df.limit(5).toPandas()
print("type: ", type(pandas_df))
pandas_df.head(5)

type:  <class 'pandas.core.frame.DataFrame'>


Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,extra,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge
0,2,2024-03-01 00:10:52,2024-03-01 00:26:12,N,1.0,129,226,1.0,1.72,12.8,1.0,0.5,3.06,0.0,,1.0,18.36,1.0,1.0,0.0
1,2,2024-03-01 00:22:21,2024-03-01 00:35:15,N,1.0,130,218,1.0,3.25,17.7,1.0,0.5,0.0,0.0,,1.0,20.2,2.0,1.0,0.0
2,2,2024-03-01 00:45:27,2024-03-01 01:04:32,N,1.0,255,107,2.0,4.58,23.3,1.0,0.5,3.5,0.0,,1.0,32.05,1.0,1.0,2.75
3,1,2024-03-01 00:02:00,2024-03-01 00:23:45,N,1.0,181,71,1.0,0.0,22.5,0.0,1.5,0.0,0.0,,1.0,24.0,1.0,1.0,0.0
4,2,2024-03-01 00:16:45,2024-03-01 00:23:25,N,1.0,95,135,1.0,1.15,8.6,1.0,0.5,1.0,0.0,,1.0,12.1,1.0,1.0,0.0


In [6]:
# Total number of rows loaded; count() is an action, so this runs a Spark job.
df.count()

167585

## Example of DataFrame API and SparkSQL

In [7]:
# DataFrame API version of the query: keep VendorID 2 trips and order them
# by drop-off time, newest first.
# NOTE: `df` is rebound here, and the later cells (temp view, CSV export,
# Hive table) all operate on this filtered, sorted DataFrame.
df = df.filter("VendorID = 2")
df = df.sort(df.lpep_dropoff_datetime.desc())
# first() fetches only the top row; collect()[0] would ship every filtered
# row to the driver just to read a single value.
df.first().lpep_pickup_datetime

datetime.datetime(2024, 3, 31, 16, 40)

In [8]:
# Spark SQL version of the same query: expose the DataFrame as a temporary
# view and select the pickup time of the trip with the latest drop-off.
df.createOrReplaceTempView("taxi_view")
sql = """
SELECT lpep_pickup_datetime
FROM taxi_view
WHERE VendorID = 2
ORDER BY lpep_dropoff_datetime DESC
"""
# first() returns just the top row instead of collecting the full result
# set to the driver and indexing into it.
spark.sql(sql).first().lpep_pickup_datetime

datetime.datetime(2024, 3, 31, 16, 40)

## Write Some Results to S3

In [9]:
# Export the first ten rows as CSV (with a header row) to the warehouse
# bucket, replacing whatever a previous run left at this path.
result_path = "s3a://spark-warehouse/raw_data/save_from_notebook"
(
    df.limit(10)
    .write
    .mode("overwrite")
    .option("header", "true")
    .csv(result_path)
)

## Write Data as Hive Table

In [10]:
from pyspark.sql.functions import col

# Hive tables do not support timestamp_ntz columns, so convert both trip
# timestamps to the regular timestamp type before saving.
for ts_col in ("lpep_pickup_datetime", "lpep_dropoff_datetime"):
    df = df.withColumn(ts_col, col(ts_col).cast("timestamp"))

In [11]:
# Make sure the target database exists; nothing earlier in this notebook
# creates it, so saveAsTable would fail against a fresh metastore.
spark.sql("CREATE DATABASE IF NOT EXISTS local_db")

# Overwrite instead of append so that re-running the notebook rebuilds the
# table rather than duplicating every row on each run. This also matches the
# overwrite mode used for the CSV export above.
# NOTE(review): this replaces any existing contents of the table; switch
# back to "append" if accumulating data across runs is intended.
df.write.format("parquet").mode("overwrite").saveAsTable("local_db.sample_hive_table")

In [12]:
# Ask the catalog for the table's DDL and print it to confirm the stored
# column types (SHOW CREATE TABLE yields one row with a createtab_stmt column).
sql = "SHOW CREATE TABLE local_db.sample_hive_table"
createtab_stmt = spark.sql(sql).first()["createtab_stmt"]
print(createtab_stmt)

CREATE TABLE spark_catalog.local_db.sample_hive_table (
  VendorID INT,
  lpep_pickup_datetime TIMESTAMP,
  lpep_dropoff_datetime TIMESTAMP,
  store_and_fwd_flag STRING,
  RatecodeID BIGINT,
  PULocationID INT,
  DOLocationID INT,
  passenger_count BIGINT,
  trip_distance DOUBLE,
  fare_amount DOUBLE,
  extra DOUBLE,
  mta_tax DOUBLE,
  tip_amount DOUBLE,
  tolls_amount DOUBLE,
  ehail_fee DOUBLE,
  improvement_surcharge DOUBLE,
  total_amount DOUBLE,
  payment_type BIGINT,
  trip_type BIGINT,
  congestion_surcharge DOUBLE)
USING parquet



In [13]:
# Count the rows now stored in the saved table and display the result as a one-row table.
spark.sql("SELECT COUNT(*) FROM local_db.sample_hive_table").show()

+--------+
|count(1)|
+--------+
|  146630|
+--------+



In [14]:
# Read five rows back from the saved table and render them as a pandas
# DataFrame; the limit is applied in Spark, so only those rows reach the driver.
spark.table("local_db.sample_hive_table").limit(5).toPandas()

  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):
  if not is_datetime64tz_dtype(pser.dtype):
  if is_datetime64tz_dtype(s.dtype):


Unnamed: 0,VendorID,lpep_pickup_datetime,lpep_dropoff_datetime,store_and_fwd_flag,RatecodeID,PULocationID,DOLocationID,passenger_count,trip_distance,fare_amount,extra,mta_tax,tip_amount,tolls_amount,ehail_fee,improvement_surcharge,total_amount,payment_type,trip_type,congestion_surcharge
0,2,2024-02-15 18:06:41,2024-02-15 18:08:56,N,1,75,75,1,0.71,5.1,2.5,0.5,1.82,0.0,,1.0,10.92,1,1,0.0
1,2,2024-02-15 17:53:07,2024-02-15 18:08:49,N,1,82,138,1,3.73,19.1,7.5,0.5,5.62,0.0,,1.0,33.72,1,1,0.0
2,2,2024-02-15 18:04:33,2024-02-15 18:08:17,N,1,247,247,1,0.56,-5.8,-2.5,-0.5,0.0,0.0,,-1.0,-9.8,4,1,0.0
3,2,2024-02-15 18:04:33,2024-02-15 18:08:17,N,1,247,247,1,0.56,5.8,2.5,0.5,0.0,0.0,,1.0,9.8,4,1,0.0
4,2,2024-02-15 17:54:37,2024-02-15 18:08:08,N,1,75,42,1,2.6,14.9,2.5,0.5,0.0,0.0,,1.0,18.9,2,1,0.0


In [15]:
# Stop the SparkSession once the notebook is finished with it.
spark.stop()