
<div style="text-align: center; line-height: 0; padding-top: 9px;">
  <img src="https://databricks.com/wp-content/uploads/2018/03/db-academy-rgb-1200px.png" alt="Databricks Learning" style="width: 600px">
</div>

# Revenue by Traffic Lab
Get the 3 traffic sources generating the highest total revenue.
1. Aggregate revenue by traffic source
2. Get top 3 traffic sources by total revenue
3. Clean revenue columns to have two decimal places

##### Methods
- <a href="https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/dataframe.html" target="_blank">DataFrame</a>: **`groupBy`**, **`sort`**, **`limit`**
- <a href="https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/column.html" target="_blank">Column</a>: **`alias`**, **`desc`**, **`cast`**, **`operators`**
- <a href="https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html" target="_blank">Built-in Functions</a>: **`avg`**, **`sum`**

In [0]:
%run ./Includes/Classroom-Setup-00.06L

### Setup
Run the cell below to create the starting DataFrame **`df`**.

In [0]:
from pyspark.sql.functions import col

# Purchase events logged on the BedBricks website
df = (spark.table("events")
      .withColumn("revenue", col("ecommerce.purchase_revenue_in_usd"))
      .filter(col("revenue").isNotNull())
      .drop("event_name")
     )

display(df)


### 1. Aggregate revenue by traffic source
- Group by **`traffic_source`**
- Get sum of **`revenue`** as **`total_rev`**. Round this to the tenths place, i.e. one decimal place (e.g. `nnnnn.n`).
- Get average of **`revenue`** as **`avg_rev`**

Remember to import any necessary built-in functions.
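
To make the target of this step concrete, here is a minimal plain-Python sketch of what a grouped sum and average compute. It is not the Spark solution: the rows and the `aggregate_by_source` helper are invented for illustration. It calls `builtins.sum` and `builtins.round` explicitly because importing Spark's `sum` or `round` by name shadows the Python built-ins in a notebook.

```python
import builtins
from collections import defaultdict

# Hypothetical (traffic_source, revenue) rows, invented for illustration.
rows = [
    ("email", 100.0),
    ("google", 250.5),
    ("email", 50.625),
    ("google", 49.5),
    ("direct", 10.0),
]

def aggregate_by_source(rows):
    """Return {source: (total_rev, avg_rev)}; total_rev is rounded to one decimal place."""
    revenues = defaultdict(list)
    for source, revenue in rows:
        revenues[source].append(revenue)
    return {
        source: (
            builtins.round(builtins.sum(values), 1),  # total, rounded to tenths
            builtins.sum(values) / len(values),       # plain average, not rounded
        )
        for source, values in revenues.items()
    }

result = aggregate_by_source(rows)
print(result["email"])  # (150.6, 75.3125)
```

In the lab itself, the same shape of result should come from the DataFrame methods listed at the top of this notebook rather than from Python loops.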

In [0]:
# TODO

traffic_df = (df.FILL_IN
)

display(traffic_df)


**1.1: CHECK YOUR WORK**

In [0]:
# Spark's round; importing it by name shadows Python's built-in round in this notebook
from pyspark.sql.functions import round

expected1 = [(620096.0, 1049.2318), (4026578.5, 986.1814), (1200591.0, 1067.192), (2322856.0, 1093.1087), (826921.0, 1086.6242), (404911.0, 1091.4043)]
test_df = traffic_df.sort("traffic_source").select(round("total_rev", 4).alias("total_rev"), round("avg_rev", 4).alias("avg_rev"))
result1 = [(row.total_rev, row.avg_rev) for row in test_df.collect()]

assert(expected1 == result1)
print("All tests pass")


### 2. Get top three traffic sources by total revenue
- Sort by **`total_rev`** in descending order
- Limit to first three rows
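
As a rough analogy for the two bullets above, the sketch below sorts made-up totals in descending order and keeps the first three. The `totals` values and the `top_n_by_total` helper are hypothetical and only illustrate the logic; the lab expects the equivalent DataFrame operations.

```python
# Hypothetical {traffic_source: total_rev} values, invented for illustration.
totals = {"email": 150.6, "google": 300.0, "direct": 10.0, "youtube": 75.0}

def top_n_by_total(totals, n):
    """Sort (source, total) pairs by total, largest first, and keep the first n."""
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return ranked[:n]

top_three = top_n_by_total(totals, 3)
print(top_three)  # [('google', 300.0), ('email', 150.6), ('youtube', 75.0)]
```

The order of the two operations matters: limiting before sorting would keep three arbitrary rows rather than the three largest.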

In [0]:
# TODO
top_traffic_df = (traffic_df.FILL_IN
)
display(top_traffic_df)


**2.1: CHECK YOUR WORK**

In [0]:
expected2 = [(4026578.5, 986.1814), (2322856.0, 1093.1087), (1200591.0, 1067.192)]
test_df = top_traffic_df.select(round("total_rev", 4).alias("total_rev"), round("avg_rev", 4).alias("avg_rev"))
result2 = [(row.total_rev, row.avg_rev) for row in test_df.collect()]

assert(expected2 == result2)
print("All tests pass")


### 3. Limit revenue columns to two decimal places
- Modify columns **`avg_rev`** and **`total_rev`** to contain numbers with two decimal places
  - Use **`withColumn()`** with the same names to replace these columns
  - To limit to two decimal places, multiply each column by 100, cast to long, and then divide by 100. The cast drops the fractional part, so this truncates rather than rounds
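
The multiply, cast, divide recipe can be traced in plain Python, where `int()` plays the role of the cast to long. This is only a sketch of the arithmetic, not the Spark solution: the `truncate_two_decimals` helper is hypothetical, and the inputs are the four-decimal `avg_rev` values that appear in the check for step 2.

```python
def truncate_two_decimals(value):
    """Scale by 100, drop the fractional part with an integer cast, then scale back."""
    return int(value * 100) / 100

avg_revs = [986.1814, 1093.1087, 1067.192]
truncated = [truncate_two_decimals(v) for v in avg_revs]
print(truncated)  # [986.18, 1093.1, 1067.19]
```

Note the middle value: 1093.1087 becomes 1093.10 (displayed as `1093.1`), not 1093.11, because the digits after the second decimal place are discarded instead of rounded.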

In [0]:
# TODO
final_df = (top_traffic_df.FILL_IN
)

display(final_df)


**3.1: CHECK YOUR WORK**

In [0]:
expected3 = [(4026578.5, 986.18), (2322856.0, 1093.1), (1200591.0, 1067.19)]
result3 = [(row.total_rev, row.avg_rev) for row in final_df.collect()]

assert(expected3 == result3)
print("All tests pass")

### 4. Bonus: Rewrite using a built-in math function
Find a built-in math function that rounds to a specified number of decimal places.
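
For intuition about what rounding to two decimal places should produce, the plain-Python sketch below rounds the four-decimal `avg_rev` values from the check for step 2. It uses `builtins.round` explicitly, since a notebook that has imported Spark's `round` by name no longer sees the Python built-in under that name. Python's built-in and Spark's function can differ on exact ties (Spark documents HALF_UP rounding), but none of these values is a tie.

```python
import builtins

avg_revs = [986.1814, 1093.1087, 1067.192]

# Round each value to two decimal places with Python's built-in round.
rounded = [builtins.round(v, 2) for v in avg_revs]
print(rounded)  # [986.18, 1093.11, 1067.19]
```

Here 1093.1087 becomes 1093.11, which is the value the check for this step expects.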

In [0]:
# TODO
bonus_df = (top_traffic_df.FILL_IN
)

display(bonus_df)


**4.1: CHECK YOUR WORK**

In [0]:
expected4 = [(4026578.5, 986.18), (2322856.0, 1093.11), (1200591.0, 1067.19)]
result4 = [(row.total_rev, row.avg_rev) for row in bonus_df.collect()]

assert(expected4 == result4)
print("All tests pass")


### 5. Chain all the steps above

In [0]:
# TODO
chain_df = (df.FILL_IN
)

display(chain_df)


**5.1: CHECK YOUR WORK**

In [0]:
expected5 = [(4026578.5, 986.18), (2322856.0, 1093.11), (1200591.0, 1067.19)]
result5 = [(row.total_rev, row.avg_rev) for row in chain_df.collect()]

assert(expected5 == result5)
print("All tests pass")


Run the following cell to delete the tables and files associated with this lesson.

In [0]:
DA.cleanup()

&copy; 2023 Databricks, Inc. All rights reserved.<br/>
Apache, Apache Spark, Spark and the Spark logo are trademarks of the <a href="https://www.apache.org/">Apache Software Foundation</a>.<br/>
<br/>
<a href="https://databricks.com/privacy-policy">Privacy Policy</a> | <a href="https://databricks.com/terms-of-use">Terms of Use</a> | <a href="https://help.databricks.com/">Support</a>