# Welcome to Databricks & PySpark: Module 1

---

## Course Introduction

Welcome to the first module of our Databricks course! In this notebook, you'll get hands-on experience with Databricks and PySpark, learning how to process and analyze data at scale.

---

## Objectives

By the end of this module, you will be able to:

1. **Set up your Databricks environment**  
   Create a catalog and schema, and load CSV files.
2. **Understand PySpark basics**  
   Create Spark DataFrames and perform basic data operations.
3. **Explore your data**  
   Generate summary statistics, aggregate data, and perform exploratory analysis.
4. **Write data efficiently**  
   Save your results in Parquet and Delta formats, and understand the benefits of Delta Lake.
5. **Apply your knowledge**  
   Complete practical exercises, including filtering transactions, grouping and counting products, and saving filtered data as Delta tables.

---

Let's get started!

## Objective 1: Setting Up Your Databricks Environment

Follow these steps to set up your Databricks environment, create a catalog and schema, and load your first CSV file as a table.

---

### 1. Create a Catalog and Schema using the Databricks SQL Editor

1. Open the **Databricks SQL Editor** from the workspace sidebar, or use a notebook cell with `%sql` at the top to run SQL code directly in your notebook.
2. Run the following SQL commands to create a new catalog and a landing schema:

   ```sql
   CREATE CATALOG IF NOT EXISTS dbx_course_catalog;
   USE CATALOG dbx_course_catalog;

   CREATE SCHEMA IF NOT EXISTS landing;
   USE SCHEMA landing;
   ```
   

   Replace `dbx_course_catalog` and `landing` with your preferred names if desired.
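
If you would rather set the names once and reuse them, the statements above can be generated from two Python variables. This is a minimal sketch: the helper name `setup_statements` is hypothetical, and the final loop assumes the `spark` session that Databricks notebooks provide, so it is left commented out here.

```python
def setup_statements(catalog: str, schema: str) -> list[str]:
    """Build the catalog/schema setup statements for the given names."""
    return [
        f"CREATE CATALOG IF NOT EXISTS {catalog}",
        f"USE CATALOG {catalog}",
        f"CREATE SCHEMA IF NOT EXISTS {schema}",
        f"USE SCHEMA {schema}",
    ]


statements = setup_statements("dbx_course_catalog", "landing")
for stmt in statements:
    print(stmt)

# In a Databricks notebook, where `spark` is predefined, you could run them with:
# for stmt in statements:
#     spark.sql(stmt)
```

Keeping the names in one place makes it harder for the catalog and schema to drift apart between cells.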

---

In [0]:
%sql
CREATE CATALOG IF NOT EXISTS dbx_course_catalog;
USE CATALOG dbx_course_catalog;

CREATE SCHEMA IF NOT EXISTS landing;
USE SCHEMA landing;

### 2. Upload and Load a CSV File as a Table

1. In the sidebar, click **Catalog** and navigate to your catalog and schema.
2. Click **+ New > Add data**.
3. Upload your CSV file to a Unity Catalog volume or select an existing file.
4. In the **Add data** UI:
   - Select your catalog (`dbx_course_catalog`) and schema (`landing`).
   - (Optional) Edit the table name.
   - Click **Create table** to load the CSV as a managed table.


Alternatively, you can use SQL to create a table from a CSV file:

```sql
CREATE TABLE landing.customer_transactions
USING CSV
OPTIONS (
  path = 'yourpath/customer_transactions.csv',
  header = 'true',
  inferSchema = 'true'
);
```

**Note:** This approach can fail in the Databricks Community (Free) Edition. DBFS (Databricks File System) is the distributed file system Databricks uses to interact with cloud storage, and the free edition restricts access to both DBFS and Unity Catalog volumes, so this method is not used in this class.

---

You have now set up your environment, created a catalog and schema, and loaded your first table!

## Objective 2: PySpark Basics

In this section, you'll learn how to use PySpark to work with data in Databricks.

---

### 1. Create a Spark DataFrame from a Table

Let's start by creating a Spark DataFrame from the `customer_transactions` table you loaded earlier using `spark.table()`, and then render it with `display()`:

In [0]:
df = spark.table("dbx_course_catalog.landing.customer_transactions")
display(df)

---

### 2. PySpark Exercises

Let's explore the data we've just loaded and get a feel for how PySpark works in Databricks. In this section, we'll walk through some basic operations you can perform on your DataFrame. These exercises will help you understand how to select specific columns, filter rows, sort data, group and aggregate, remove duplicates, rename columns, and save your results as new tables.

**Note:** The columns in your dataset are: `transaction_id`, `customer_id`, `date`, `product`, `quantity`, `price_per_unit`, `location`, and `data_issue`.  
- Each row represents a single product within a transaction. Multiple rows can share the same `transaction_id` if they are part of the same transaction.

In the following exercises, you'll practice essential PySpark DataFrame operations using the `customer_transactions` data. Each exercise specifies the columns or conditions you'll use:

1. **Select specific columns:**  
   Extract only the `customer_id` and `price_per_unit` columns from the DataFrame.

2. **Filter rows based on a condition:**  
   Retrieve all transactions where `price_per_unit` is greater than 100 euros.

3. **Order and sort data:**  
   Sort the DataFrame by `price_per_unit` in descending order to see the highest-priced items first.

4. **Group by and aggregate:**  
   Group the data by `customer_id` and calculate the total amount spent (sum of `quantity * price_per_unit` for all products per customer).

5. **Drop duplicates:**  
   Remove duplicate records.

6. **Rename columns:**  
   Rename the `price_per_unit` column to `unit_price` for clarity.

7. **Save filtered data as a new table:**  
   Save the filtered DataFrame (transactions with `price_per_unit` > 100 euros) as a new table called `high_value_transactions` in your schema.

8. **Append data to an existing table:**  
   Add this high-value transaction to the `high_value_transactions` table:

| transaction_id                           | customer_id | date       | product | quantity | price_per_unit | location  | data_issue |
|------------------------------------------|-------------|------------|---------|----------|---------------|-----------|------------|
| 23b1ca18-f88a-46ce-b22c-e7fa4673050f     | 4828        | 2025-03-06 | Monitor | 3        | 300           | Amsterdam | null       |
> **Tip for Exercise 8:**  
> To practice appending data, first create a new DataFrame containing just this single row.  
> - Use `import datetime` and `from pyspark.sql import Row` to specify the date and define the row.
> - Manually define the schema using `StructType` and `StructField`.  
> - Then use `.write.mode("append").saveAsTable(...)` to add this DataFrame to your existing `high_value_transactions` table.

These exercises will help you become comfortable with selecting, filtering, sorting, aggregating, deduplicating, renaming, and saving data using PySpark in Databricks.

In [0]:
#### a. Select specific columns

df_selected_cols = df.select("customer_id", "price_per_unit")
display(df_selected_cols)

In [0]:
#### b. Filter rows

df_filtered_price = df.filter(df.price_per_unit > 100)
display(df_filtered_price)

In [0]:
#### c. Order and sort data

df_sorted_price = df.orderBy(df.price_per_unit.desc())
display(df_sorted_price)

In [0]:
#### d. Group by and aggregate

from pyspark.sql import functions as F

df_grouped_total = df.withColumn(
    "total_amount", F.col("quantity") * F.col("price_per_unit")
).groupBy("customer_id").agg(F.sum("total_amount").alias("total_spent"))
display(df_grouped_total)

In [0]:
#### e. Drop duplicates

df_deduped_all = df.dropDuplicates()
print(f"Rows removed: {df.count() - df_deduped_all.count()}")

In [0]:
#### f. Rename columns

df_renamed_unit_price = df.withColumnRenamed("price_per_unit", "unit_price")
display(df_renamed_unit_price)

In [0]:
#### g. Create a new table

df_filtered_price.write.mode("overwrite").saveAsTable("dbx_course_catalog.landing.high_value_transactions")

In [0]:
#### h. Insert into an existing table

from pyspark.sql import Row
import datetime

new_row = Row(
    transaction_id="23b1ca18-f88a-46ce-b22c-e7fa4673050f",
    customer_id=4828,
    date=datetime.date(2025, 3, 6),
    product="Monitor",
    quantity=3,
    price_per_unit=300.0,  # float, so the value matches DoubleType in the schema below
    location="Amsterdam",
    data_issue=None
)

from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType, DoubleType

schema = StructType([
    StructField("transaction_id", StringType(), True),
    StructField("customer_id", IntegerType(), True),
    StructField("date", DateType(), True),
    StructField("product", StringType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("price_per_unit", DoubleType(), True),
    StructField("location", StringType(), True),
    StructField("data_issue", StringType(), True)
])

df_new = spark.createDataFrame([new_row.asDict()], schema=schema)
df_new.write.mode("append").saveAsTable("dbx_course_catalog.landing.high_value_transactions")

## Objective 3: Data Exploration and Analysis

In this section, you'll learn how to explore and analyze your data using PySpark. You'll generate summary statistics, perform aggregations, and conduct exploratory analysis to better understand the patterns and trends in your dataset. These skills are essential for uncovering insights and preparing your data for further processing or machine learning tasks.

---

### Objective 3 Exercises: Data Exploration and Analysis

Practice the following exercises to deepen your understanding of data exploration and analysis with PySpark.
To import PySpark functions for analysis: `from pyspark.sql import functions as F`

1. **Summary Statistics:**  
   Use the `.describe()` method to generate summary statistics (count, mean, stddev, min, max) for the `quantity` and `price_per_unit` columns.

2. **Value Counts:**  
   Count the number of transactions for each unique `product` using `.groupBy().count()`.

3. **Missing Data Analysis:**  
   Find out how many rows have missing (`null`) values in the `data_issue` column. (Hint: use `F.col()` to select the column and `.isNull()` to check for null values.)

4. **Top Customers:**  
   Identify the top 5 customers who spent the most in total (sum of `quantity * price_per_unit`), and display their `customer_id` and total spent. You can use `.alias()` to rename aggregate columns for clarity.

5. **Monthly Trends:**  
   Group transactions by month (extract month from the `date` column using `F.month`) and calculate the total quantity sold each month.

6. **Location Analysis:**  
   For each `location`, compute the average `price_per_unit` and the total number of transactions.

7. **Data Issue Investigation:**  
   List all unique values found in the `data_issue` column and count how many times each occurs.

8. **Optional (difficult) exercise: Outlier Detection:**  
   Find all transactions where `quantity` is greater than 10 or `price_per_unit` is more than 3 standard deviations above the mean.

Try to solve each exercise using PySpark DataFrame operations. These tasks will help you build skills in summarizing, grouping, filtering, and analyzing your data.

In [0]:
from pyspark.sql import functions as F

# 1. Summary Statistics for 'quantity' and 'price_per_unit'
display(df.select("quantity", "price_per_unit").describe())

In [0]:
# 2. Value Counts for each unique 'product'
df_product_counts = df.groupBy("product").count()
display(df_product_counts)

In [0]:
# 3. Missing Data Analysis for 'data_issue'
missing_data_issue_count = df.filter(F.col("data_issue").isNull()).count()
print(f"Rows with missing data_issue: {missing_data_issue_count}")

In [0]:
# 4. Top 5 Customers by Total Spent
df_top_customers = (
    df.withColumn("total_amount", F.col("quantity") * F.col("price_per_unit"))
      .groupBy("customer_id")
      .agg(F.sum("total_amount").alias("total_spent"))
      .orderBy(F.col("total_spent").desc())
      .limit(5)
)
display(df_top_customers)

In [0]:
# 5. Monthly Trends: Total quantity sold per month
df_monthly_trends = (
    df.withColumn("month", F.month("date"))
      .groupBy("month")
      .agg(F.sum("quantity").alias("total_quantity"))
      .orderBy("month")
)
display(df_monthly_trends)

In [0]:
# 6. Location Analysis: Avg price_per_unit and total transactions per location
df_location_analysis = (
    df.groupBy("location")
      .agg(
          F.avg("price_per_unit").alias("avg_price_per_unit"),
          # Counts rows (product lines); use F.countDistinct("transaction_id") for distinct transactions
          F.count("*").alias("transaction_count")
      )
)
display(df_location_analysis)

In [0]:
# 7. Data Issue Investigation: Unique values and counts in 'data_issue'
df_data_issue_counts = (
    df.groupBy("data_issue")
      .count()
      .orderBy(F.col("count").desc())
)
display(df_data_issue_counts)

In [0]:
# 8. Outlier Detection: quantity > 10 or price_per_unit > mean + 3*stddev
stats = df.select(
    F.mean("price_per_unit").alias("mean"),
    F.stddev("price_per_unit").alias("stddev")
).collect()[0]
mean_price = stats["mean"]
stddev_price = stats["stddev"]
threshold = mean_price + 3 * stddev_price

df_outliers = df.filter(
    (F.col("quantity") > 10) | (F.col("price_per_unit") > threshold)
)
display(df_outliers)