# PySpark Transformations Lab

## 🎯 Lab Scenario

You're a data engineer at a taxi analytics company. Your manager has asked you to analyze NYC taxi trip data and create a summary table that will be used by the business intelligence team for reporting.

**Your Task:** Transform raw taxi trip data into a business-ready summary table.

---

## 📋 Learning Objectives

By completing this lab, you will:

* Read data from Unity Catalog tables
* Apply PySpark transformations: `select`, `filter`, `withColumn`, `groupBy`, `orderBy`
* Create calculated columns for business metrics
* Write transformed data to a new table in Unity Catalog

---

## 📊 Business Requirements

Create a summary table that shows:
* Trip statistics by hour of day
* Only include trips with valid data (positive fares and distances)
* Calculate average fare, total revenue, and trip counts
* Add a revenue category column (Low/Medium/High)
* Sort results for easy analysis

---

## ⚠️ Lab Format

This is a **challenge lab** - you'll need to figure out the solutions yourself!

* Each section has **requirements** but not step-by-step instructions
* **Hints** are available if you get stuck
* **Solutions** are provided at the end for verification

**Ready? Let's begin!** 🚀

## Task 1: Load the Source Data 📚

**Your Challenge:**

Load the NYC taxi trips data from the Unity Catalog samples.

**Requirements:**
* Use the table: `samples.nyctaxi.trips`
* Store it in a variable called `taxi_df`
* Display the first few rows to verify

**Questions to answer:**
1. How many columns does the dataset have?
2. What are the data types?
3. What columns will you need for the analysis?

---

**Write your code in the cell below:**

```python
# TODO: Load the taxi trips data from samples.nyctaxi.trips
# Store it in a variable called taxi_df
```



### 💡 Hints for Task 1

<details>
<summary><b>Hint 1:</b> How to read from Unity Catalog (click to expand)</summary>

Use `spark.table()` to read from Unity Catalog tables:
```python
df = spark.table("catalog.schema.table_name")
```
</details>

<details>
<summary><b>Hint 2:</b> Exploring the data (click to expand)</summary>

Useful methods:
* `df.printSchema()` - See column names and types
* `display(df)` - View the data interactively
* `df.count()` - Count total rows
* `df.columns` - List all column names
</details>

## Task 2: Explore the Data Schema 🔍

**Your Challenge:**

Examine the structure of the data to understand what you're working with.

**Requirements:**
* Print the schema to see all columns and data types
* Identify which columns you'll need:
  * Pickup datetime (for extracting hour)
  * Fare amount (for revenue calculations)
  * Trip distance (for filtering valid trips)
  * Any other relevant fields

---

**Write your code in the cell below:**

```python
# TODO: Print the schema of taxi_df
# TODO: Examine the columns and identify the ones you need
```



## Task 3: Select Relevant Columns 🎯

**Your Challenge:**

Select only the columns the analysis needs, so later steps work with a smaller, easier-to-read DataFrame.

**Requirements:**
* Use the `select()` transformation
* Include these columns:
  * `tpep_pickup_datetime` - for extracting hour
  * `fare_amount` - for revenue calculations
  * `trip_distance` - for filtering
  * `passenger_count` - for additional analysis (optional)
* Store the result in a new variable called `selected_df`
* Verify you have the correct columns

---

**Write your code in the cell below:**

```python
# TODO: Select only the columns you need
# Store in selected_df
```



### 💡 Hints for Task 3

<details>
<summary><b>Hint 1:</b> Using select() (click to expand)</summary>

The `select()` method takes column names as arguments:
```python
df.select("column1", "column2", "column3")
```
</details>

<details>
<summary><b>Hint 2:</b> Verifying your selection (click to expand)</summary>

Check your work:
```python
print(selected_df.columns)  # See column names
display(selected_df)        # View the data
```
</details>

## Task 4: Filter for Valid Trips 📊

**Your Challenge:**

Filter out invalid or suspicious trip records.

**Requirements:**
* Use the `filter()` transformation
* Keep only trips where:
  * `fare_amount` > 0 (positive fares)
  * `trip_distance` > 0 (positive distances)
  * `fare_amount` < 500 (remove outliers)
* Store the result in a variable called `filtered_df`
* Check how many rows remain after filtering

**Bonus Challenge:** Can you write the filter in a single statement?

---

**Write your code in the cell below:**

```python
# TODO: Filter for valid trips
# Store in filtered_df
```



### 💡 Hints for Task 4

<details>
<summary><b>Hint 1:</b> Using filter() (click to expand)</summary>

You can use column references and comparison operators:
```python
df.filter(df.column_name > value)
```
</details>

<details>
<summary><b>Hint 2:</b> Multiple conditions (click to expand)</summary>

Combine conditions with `&` (and) or `|` (or):
```python
df.filter((df.col1 > 0) & (df.col2 < 100))
```
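Note: Wrap each condition in parentheses! Python's `&` and `|` bind more tightly than `>` and `<`, so without parentheses the comparisons are grouped incorrectly.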
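
To see the grouping Python actually uses, here is a minimal sketch that parses both forms with the standard-library `ast` module (`col1` and `col2` are placeholder names, not real columns). Without parentheses the expression becomes a chained comparison around `0 & col2`; with PySpark columns that form typically raises an error, because Python tries to convert a `Column` to a boolean.

```python
import ast


def top_level_node(expr: str) -> str:
    """Return the node type Python builds for the outermost part of an expression."""
    return type(ast.parse(expr, mode="eval").body).__name__


# & binds tighter than > and <, so this parses as col1 > (0 & col2) < 100
unparenthesized = "col1 > 0 & col2 < 100"
# Parentheses force two separate comparisons joined by &
parenthesized = "(col1 > 0) & (col2 < 100)"

print(top_level_node(unparenthesized))  # Compare
print(top_level_node(parenthesized))    # BinOp
```
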
</details>

<details>
<summary><b>Hint 3:</b> Alternative syntax (click to expand)</summary>

You can also use SQL-style strings:
```python
df.filter("column1 > 0 AND column2 < 100")
```
</details>

## Task 5: Add Calculated Columns 🧮

**Your Challenge:**

Create new columns with calculated values.

**Requirements:**

Use `withColumn()` to add these new columns:

1. **`pickup_hour`** - Extract the hour from `tpep_pickup_datetime`
   * Hint: Use the `hour()` function from `pyspark.sql.functions`

2. **`revenue_category`** - Categorize fares as:
   * "Low" if fare_amount < 10
   * "Medium" if fare_amount is between 10 and 30 (inclusive)
   * "High" if fare_amount > 30
   * Hint: Use the `when()` function for conditional logic

* Store the result in a variable called `enriched_df`
* Verify the new columns were added correctly
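
The `when()` chain is checked top to bottom and the first matching condition wins, which is why the second test only needs `<= 30`. A plain-Python mirror of the rule (the `revenue_category` function here is an illustration, not a Spark API) is a quick way to confirm the boundary values before writing the Spark version:

```python
def revenue_category(fare_amount: float) -> str:
    """Plain-Python mirror of the when()/otherwise() chain for revenue_category."""
    if fare_amount < 10:
        return "Low"
    if fare_amount <= 30:  # only reached when fare_amount >= 10
        return "Medium"
    return "High"  # the otherwise() branch: fare_amount > 30


# Boundary values are where off-by-one mistakes usually hide
for fare in [9.99, 10.0, 30.0, 30.01]:
    print(fare, revenue_category(fare))
```

Fares of exactly 10 and exactly 30 both land in "Medium".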

---

**Write your code in the cell below:**

```python
# TODO: Import necessary functions
# from pyspark.sql.functions import ...

# TODO: Add pickup_hour column
# TODO: Add revenue_category column
# Store in enriched_df
```



### 💡 Hints for Task 5

<details>
<summary><b>Hint 1:</b> Importing functions (click to expand)</summary>

You'll need these functions:
```python
from pyspark.sql.functions import hour, when, col
```
</details>

<details>
<summary><b>Hint 2:</b> Extracting hour (click to expand)</summary>

Use the `hour()` function:
```python
df.withColumn("hour_column", hour("datetime_column"))
```
</details>

<details>
<summary><b>Hint 3:</b> Conditional logic with when() (click to expand)</summary>

Chain `when()` statements for multiple conditions:
```python
from pyspark.sql.functions import when

df.withColumn("category",
    when(col("amount") < 10, "Low")
    .when(col("amount") <= 30, "Medium")
    .otherwise("High")
)
```
</details>

<details>
<summary><b>Hint 4:</b> Chaining withColumn() (click to expand)</summary>

You can chain multiple `withColumn()` calls:
```python
df.withColumn("col1", ...).withColumn("col2", ...)
```
</details>

## Task 6: Aggregate Data by Hour 📈

**Your Challenge:**

Create summary statistics grouped by hour of day.

**Requirements:**

Use `groupBy()` and aggregation functions to calculate:

* Group by: `pickup_hour`
* Calculate:
  * `trip_count` - Total number of trips (use `count()`)
  * `total_revenue` - Sum of all fares (use `sum()`)
  * `avg_fare` - Average fare amount (use `avg()`)
  * `avg_distance` - Average trip distance (use `avg()`)

* Store the result in a variable called `hourly_summary`
* Round `total_revenue` and the averages to 2 decimal places
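
Conceptually, `groupBy().agg()` buckets rows by key and reduces each bucket to one output row. The sketch below does the same thing in plain Python on a few made-up rows (the sample values and the `summarize_by_hour` helper are hypothetical, for illustration only), so you can predict the shape of your Spark output:

```python
import builtins
from collections import defaultdict

# Made-up sample rows for illustration: (pickup_hour, fare_amount, trip_distance)
trips = [
    (8, 12.5, 2.0),
    (8, 7.5, 1.0),
    (9, 30.0, 8.0),
]


def summarize_by_hour(rows):
    """Plain-Python equivalent of groupBy("pickup_hour").agg(count, sum, avg, avg)."""
    groups = defaultdict(list)
    for pickup_hour, fare, distance in rows:
        groups[pickup_hour].append((fare, distance))  # bucket rows by key

    # builtins.* is used explicitly so this still works in a notebook where
    # Spark's sum/round were imported under the same names
    total, rnd = builtins.sum, builtins.round
    summary = {}
    for pickup_hour, values in groups.items():
        fares = [fare for fare, _ in values]
        distances = [distance for _, distance in values]
        summary[pickup_hour] = {  # one output row per group
            "trip_count": len(values),
            "total_revenue": rnd(total(fares), 2),
            "avg_fare": rnd(total(fares) / len(fares), 2),
            "avg_distance": rnd(total(distances) / len(distances), 2),
        }
    return summary


print(summarize_by_hour(trips))
```

Each key becomes exactly one output row, which is why the Spark result should have at most 24 rows. One small difference: Python's built-in `round()` rounds exact halves to the nearest even digit, while Spark's `round()` rounds halves up (`bround()` is Spark's half-even variant), so an occasional value can differ in the last decimal place.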

---

**Write your code in the cell below:**

```python
# TODO: Import aggregation functions
# from pyspark.sql.functions import ...

# TODO: Group by pickup_hour and calculate aggregations
# Store in hourly_summary
```



### 💡 Hints for Task 6

<details>
<summary><b>Hint 1:</b> Importing aggregation functions (click to expand)</summary>

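You'll need:
```python
from pyspark.sql.functions import count, sum, avg, round
```
Note: importing `sum` and `round` this way replaces Python's built-in functions of the same name for the rest of the notebook. The solution imports them as `spark_sum` and `spark_round` to avoid that.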
</details>

<details>
<summary><b>Hint 2:</b> Using groupBy() with agg() (click to expand)</summary>

Basic syntax:
```python
df.groupBy("column").agg(
    count("*").alias("count_name"),
    sum("column2").alias("sum_name"),
    avg("column3").alias("avg_name")
)
```
</details>

<details>
<summary><b>Hint 3:</b> Rounding values (click to expand)</summary>

Wrap aggregations with `round()`:
```python
round(avg("column_name"), 2).alias("avg_column")
```
</details>

## Task 7: Sort the Results 🔽

**Your Challenge:**

Sort the summary data for easy analysis.

**Requirements:**

* Use `orderBy()` to sort by `pickup_hour` in ascending order
* This will show the data chronologically from hour 0 (midnight) to hour 23 (11 PM)
* Store the result in a variable called `final_df`
* Display the results to verify

**Bonus Challenge:** Can you also sort by `total_revenue` descending to see which hours generate the most revenue?
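
Sorting by revenue can produce ties, so adding a secondary key keeps the order deterministic. In Spark that would be `orderBy(col("total_revenue").desc(), "pickup_hour")`; the plain-Python sketch below applies the same two orderings to a few made-up rows (not real trip data) so you can see the expected result:

```python
# Made-up summary rows, shaped like the hourly_summary output
summary = [
    {"pickup_hour": 0, "total_revenue": 500.0},
    {"pickup_hour": 18, "total_revenue": 900.0},
    {"pickup_hour": 9, "total_revenue": 900.0},
]

# Ascending by hour, like orderBy("pickup_hour")
by_hour = sorted(summary, key=lambda r: r["pickup_hour"])

# Descending by revenue, ties broken by ascending hour,
# like orderBy(col("total_revenue").desc(), "pickup_hour")
by_revenue = sorted(summary, key=lambda r: (-r["total_revenue"], r["pickup_hour"]))

print([r["pickup_hour"] for r in by_hour])     # [0, 9, 18]
print([r["pickup_hour"] for r in by_revenue])  # [9, 18, 0]
```
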

---

**Write your code in the cell below:**

```python
# TODO: Sort hourly_summary by pickup_hour
# Store in final_df
```



### 💡 Hints for Task 7

<details>
<summary><b>Hint 1:</b> Using orderBy() (click to expand)</summary>

Basic syntax:
```python
df.orderBy("column_name")  # Ascending by default
```
</details>

<details>
<summary><b>Hint 2:</b> Descending order (click to expand)</summary>

For descending order:
```python
df.orderBy(col("column_name").desc())
# or
df.orderBy("column_name", ascending=False)
```
</details>

<details>
<summary><b>Hint 3:</b> Multiple sort columns (click to expand)</summary>

Sort by multiple columns:
```python
df.orderBy("column1", col("column2").desc())
```
</details>

## Task 8: Save Results to Unity Catalog 💾

**Your Challenge:**

Write your transformed data to a new table in Unity Catalog.

**Requirements:**

* Create a table in the `main` catalog
* Use a schema you have write access to, such as `default` (some workspaces also provide a personal schema for each user)
* Table name: `taxi_hourly_summary`
* Full table name format: `main.<your_schema>.taxi_hourly_summary`
* Use the `saveAsTable()` method
* Mode: `overwrite` (so you can re-run the lab)

**Important Notes:**
* The table will be created in the schema you chose
* You can verify it was created by querying it
* This table can now be used by other notebooks and queries!

---

**Write your code in the cell below:**

```python
# TODO: Write final_df to Unity Catalog
# Table: main.default.taxi_hourly_summary (or use your schema name)
```



### 💡 Hints for Task 8

<details>
<summary><b>Hint 1:</b> Using saveAsTable() (click to expand)</summary>

Basic syntax:
```python
df.write.mode("overwrite").saveAsTable("catalog.schema.table_name")
```
</details>

<details>
<summary><b>Hint 2:</b> Finding your schema name (click to expand)</summary>

You can use:
```python
# Get current database/schema
spark.sql("SELECT current_database()").show()

# Or just use 'default' schema
"main.default.taxi_hourly_summary"
```
</details>

<details>
<summary><b>Hint 3:</b> Write modes (click to expand)</summary>

Common modes:
* `"overwrite"` - Replace table if it exists
* `"append"` - Add to existing table
* `"error"` - Fail if table exists (default)
* `"ignore"` - Do nothing if table exists
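
A toy model can make the differences concrete. The sketch below applies each mode's rule to an in-memory table; the `save_as_table` function and the dict standing in for the catalog are hypothetical, and this is not how Spark implements writes:

```python
def save_as_table(catalog: dict, name: str, rows: list, mode: str = "error") -> None:
    """Toy model of the saveAsTable() write modes; `catalog` is a plain dict, not Spark."""
    exists = name in catalog
    if mode == "overwrite":
        catalog[name] = list(rows)  # replace any existing contents
    elif mode == "append":
        catalog.setdefault(name, []).extend(rows)  # create if missing, else add rows
    elif mode in ("error", "errorifexists"):
        if exists:
            raise ValueError(f"Table {name} already exists")
        catalog[name] = list(rows)
    elif mode == "ignore":
        if not exists:
            catalog[name] = list(rows)  # an existing table is left untouched
    else:
        raise ValueError(f"Unknown mode: {mode}")


catalog = {}
save_as_table(catalog, "summary", [1, 2], mode="overwrite")
save_as_table(catalog, "summary", [3], mode="append")
save_as_table(catalog, "summary", [9], mode="ignore")
print(catalog["summary"])  # [1, 2, 3]
```

This is why the lab uses `overwrite`: re-running the notebook with `append` would add another copy of every hourly row.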
</details>

## Task 9: Verify Your Table ✅

**Your Challenge:**

Confirm that your table was created successfully.

**Requirements:**

* Query your new table using SQL or PySpark
* Display the results
* Verify:
  * All 24 hours are present (0-23)
  * The columns are correct
  * The data looks reasonable
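
To check the first point in code, you could collect the hour values, for example `hours = [row.pickup_hour for row in verify_df.collect()]` (collecting is fine here because the table holds only a couple dozen rows). A minimal plain-Python sketch of the check itself, using a hypothetical `check_hours` helper:

```python
def check_hours(hours: list) -> dict:
    """Report which hours 0-23 are missing and which appear more than once."""
    missing = sorted(set(range(24)) - set(hours))
    duplicated = sorted({h for h in hours if hours.count(h) > 1})
    return {"missing": missing, "duplicated": duplicated, "ok": not missing and not duplicated}


print(check_hours(list(range(24))))
print(check_hours([h for h in range(24) if h != 3]))
```
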

**Bonus:** Try querying it with SQL using `%sql` magic command!

---

**Write your code in the cell below:**

```python
# TODO: Read your table back and display it
# Verify it was created correctly
```



### 💡 Hints for Task 9

<details>
<summary><b>Hint 1:</b> Reading a table (click to expand)</summary>

Read it back:
```python
verify_df = spark.table("main.default.taxi_hourly_summary")
display(verify_df)
```
</details>

<details>
<summary><b>Hint 2:</b> Using SQL (click to expand)</summary>

You can also use SQL:
```sql
%sql
SELECT * FROM main.default.taxi_hourly_summary
ORDER BY pickup_hour
```
</details>

<details>
<summary><b>Hint 3:</b> Checking row count (click to expand)</summary>

Verify you have 24 rows (one per hour):
```python
print(f"Row count: {verify_df.count()}")
```
</details>

---

# 📝 Complete Solutions

**Only look at these if you're stuck or want to verify your work!**

Try to solve the challenges yourself first. Learning happens through struggle and problem-solving!

---

## ✅ Solution: Tasks 1-2 (Load and Explore Data)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
# Task 1: Load the data
taxi_df = spark.table("samples.nyctaxi.trips")

# Verify it loaded
print(f"Total rows: {taxi_df.count():,}")
print(f"Total columns: {len(taxi_df.columns)}")

# Task 2: Explore the schema
taxi_df.printSchema()

# Display sample data
display(taxi_df.limit(10))
```

**Key columns we need:**
* `tpep_pickup_datetime` - Timestamp for pickup
* `fare_amount` - Fare charged
* `trip_distance` - Distance traveled
* `passenger_count` - Number of passengers

</details>

## ✅ Solution: Task 3 (Select Columns)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
# Select only the columns we need
selected_df = taxi_df.select(
    "tpep_pickup_datetime",
    "fare_amount",
    "trip_distance",
    "passenger_count"
)

# Verify
print("Selected columns:", selected_df.columns)
display(selected_df)
```

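**Why select?**
* Keeps only the relevant columns, which makes later transformations easier to read
* Reduces the amount of data flowing through the pipeline (Spark's optimizer can also prune unused columns on its own, so the main win is often clarity)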

</details>

## ✅ Solution: Task 4 (Filter Valid Trips)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
# Filter for valid trips
filtered_df = selected_df.filter(
    (selected_df.fare_amount > 0) & 
    (selected_df.trip_distance > 0) & 
    (selected_df.fare_amount < 500)
)

# Check how many rows remain
print(f"Rows after filtering: {filtered_df.count():,}")
display(filtered_df)
```

**Alternative SQL-style syntax:**
```python
filtered_df = selected_df.filter(
    "fare_amount > 0 AND trip_distance > 0 AND fare_amount < 500"
)
```

**Why filter?**
* Removes invalid/suspicious data
* Improves data quality
* Reduces processing time

</details>

## ✅ Solution: Task 5 (Add Calculated Columns)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
# Import necessary functions
from pyspark.sql.functions import hour, when, col

# Add calculated columns
enriched_df = filtered_df \
    .withColumn("pickup_hour", hour("tpep_pickup_datetime")) \
    .withColumn("revenue_category",
        when(col("fare_amount") < 10, "Low")
        .when(col("fare_amount") <= 30, "Medium")
        .otherwise("High")
    )

# Verify new columns
print("Columns after enrichment:", enriched_df.columns)
display(enriched_df)
```

**Key concepts:**
* `hour()` extracts hour from timestamp
* `when()` provides conditional logic (like IF-THEN-ELSE)
* `otherwise()` is the final ELSE clause
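* Chain multiple `withColumn()` calls; each trailing `\` continues the expression onto the next line (wrapping the whole chain in parentheses works too)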

</details>

## ✅ Solution: Task 6 (Group and Aggregate)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
# Import aggregation functions
from pyspark.sql.functions import count, sum as spark_sum, avg, round as spark_round

# Group by hour and calculate aggregations
hourly_summary = enriched_df.groupBy("pickup_hour").agg(
    count("*").alias("trip_count"),
    spark_round(spark_sum("fare_amount"), 2).alias("total_revenue"),
    spark_round(avg("fare_amount"), 2).alias("avg_fare"),
    spark_round(avg("trip_distance"), 2).alias("avg_distance")
)

# Display results
display(hourly_summary)
```

**Key concepts:**
* `groupBy()` groups rows by column values
* `agg()` applies aggregation functions
* `alias()` renames the result columns
* Use `spark_sum` and `spark_round` to avoid conflicts with Python built-ins

</details>

## ✅ Solution: Task 7 (Sort Results)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
from pyspark.sql.functions import col

# Sort by pickup_hour
final_df = hourly_summary.orderBy("pickup_hour")

# Display sorted results
display(final_df)

# Bonus: Sort by revenue (descending)
revenue_sorted = hourly_summary.orderBy(col("total_revenue").desc())
print("\nTop revenue hours:")
display(revenue_sorted)
```

**Key concepts:**
* `orderBy()` sorts the DataFrame
* Default is ascending order
* Use `.desc()` for descending
* Can sort by multiple columns

</details>

## ✅ Solution: Tasks 8-9 (Write and Verify Table)

<details>
<summary><b>Click to reveal solution</b></summary>

```python
# Task 8: Write to Unity Catalog
final_df.write.mode("overwrite").saveAsTable("main.default.taxi_hourly_summary")

print("✅ Table created successfully!")

# Task 9: Verify the table
verify_df = spark.table("main.default.taxi_hourly_summary")

print(f"\nTable has {verify_df.count()} rows (should be 24)")
print("\nTable contents:")
display(verify_df)
```

**Using SQL to verify:**
```sql
%sql
SELECT * 
FROM main.default.taxi_hourly_summary
ORDER BY pickup_hour
```

**Key concepts:**
* `saveAsTable()` writes to Unity Catalog
* `mode("overwrite")` replaces existing table
* Table is now available to all users with permissions
* Can query with SQL or PySpark

</details>

## 📜 Complete Solution - All Steps Combined

<details>
<summary><b>Click to reveal complete solution</b></summary>

```python
# Import all necessary functions
from pyspark.sql.functions import (
    hour, when, col, count, sum as spark_sum, 
    avg, round as spark_round
)

# Step 1: Load data
taxi_df = spark.table("samples.nyctaxi.trips")

# Step 2: Select columns
selected_df = taxi_df.select(
    "tpep_pickup_datetime",
    "fare_amount",
    "trip_distance",
    "passenger_count"
)

# Step 3: Filter valid trips
filtered_df = selected_df.filter(
    (col("fare_amount") > 0) & 
    (col("trip_distance") > 0) & 
    (col("fare_amount") < 500)
)

# Step 4: Add calculated columns
enriched_df = filtered_df \
    .withColumn("pickup_hour", hour("tpep_pickup_datetime")) \
    .withColumn("revenue_category",
        when(col("fare_amount") < 10, "Low")
        .when(col("fare_amount") <= 30, "Medium")
        .otherwise("High")
    )

# Step 5: Group and aggregate
hourly_summary = enriched_df.groupBy("pickup_hour").agg(
    count("*").alias("trip_count"),
    spark_round(spark_sum("fare_amount"), 2).alias("total_revenue"),
    spark_round(avg("fare_amount"), 2).alias("avg_fare"),
    spark_round(avg("trip_distance"), 2).alias("avg_distance")
)

# Step 6: Sort results
final_df = hourly_summary.orderBy("pickup_hour")

# Step 7: Write to Unity Catalog
final_df.write.mode("overwrite").saveAsTable("main.default.taxi_hourly_summary")

# Step 8: Verify by reading the saved table back
verify_df = spark.table("main.default.taxi_hourly_summary")
print("✅ Lab completed successfully!")
print(f"\nCreated table with {verify_df.count()} rows")
display(verify_df)
```

**Compact version (chained transformations):**

Chaining does not change performance: transformations are lazy, so Spark optimizes the same plan whether or not you name the intermediate DataFrames. The explicit imports avoid `from pyspark.sql.functions import *`, which would silently replace Python's built-in `sum`, `round`, and others.

```python
from pyspark.sql.functions import (
    hour, when, col, count, avg,
    sum as spark_sum, round as spark_round
)

# All transformations in one chain
final_df = (
    spark.table("samples.nyctaxi.trips")
    .select("tpep_pickup_datetime", "fare_amount", "trip_distance", "passenger_count")
    .filter((col("fare_amount") > 0) & (col("trip_distance") > 0) & (col("fare_amount") < 500))
    .withColumn("pickup_hour", hour("tpep_pickup_datetime"))
    .withColumn("revenue_category",
        when(col("fare_amount") < 10, "Low")
        .when(col("fare_amount") <= 30, "Medium")
        .otherwise("High")
    )
    .groupBy("pickup_hour")
    .agg(
        count("*").alias("trip_count"),
        spark_round(spark_sum("fare_amount"), 2).alias("total_revenue"),
        spark_round(avg("fare_amount"), 2).alias("avg_fare"),
        spark_round(avg("trip_distance"), 2).alias("avg_distance")
    )
    .orderBy("pickup_hour")
)

final_df.write.mode("overwrite").saveAsTable("main.default.taxi_hourly_summary")
```

</details>

---

## 🎉 Congratulations!

You've completed the PySpark Transformations Lab!

### 🎯 What You Accomplished:

✅ Read data from Unity Catalog  
✅ Applied `select()` to choose relevant columns  
✅ Applied `filter()` to clean invalid data  
✅ Applied `withColumn()` to create calculated fields  
✅ Applied `groupBy()` to aggregate data  
✅ Applied `orderBy()` to sort results  
✅ Wrote results to a new Unity Catalog table  

### 📊 Key Takeaways:

* **Transformations are lazy** - They don't execute until an action (such as `count()`, `display()`, or a write) is called
* **Chain transformations** - Build complex logic step by step
* **Unity Catalog** - Databricks' governance layer; tables are addressed by three-part names (`catalog.schema.table`) and access is controlled by permissions
* **Business value** - Raw data → Actionable insights

### 🚀 Next Steps:

* Experiment with different aggregations
* Try joining multiple tables
* Explore window functions
* Learn about partitioning and optimization
* Build more complex data pipelines

---

**Great work!** You're now ready to tackle real-world data engineering challenges! 💪