# Joining and Grouping

In this notebook, we cover two operations we commonly perform on data frames: linking (joining) data frames together, and grouping data (then performing operations on the resulting `GroupedData` object).

In [18]:
# Let's restore the state of the code from where we left off in the previous notebook

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import os

spark = SparkSession.builder.getOrCreate()

DIRECTORY = "data/broadcast_logs"
logs = (
    spark.read.csv(
        os.path.join(DIRECTORY, "BroadcastLogs_2018_Q3_M8_sample.CSV"),
        sep="|",
        header=True,
        inferSchema=True,
        timestampFormat="yyyy-MM-dd",
    )
    .drop("BroadcastLogID", "SequenceNO")
    .withColumn(
        "duration_seconds",
        (
            F.col("Duration").substr(1, 2).cast("int") * 60 * 60
            + F.col("Duration").substr(4, 2).cast("int") * 60
            + F.col("Duration").substr(7, 2).cast("int")
        ),
    )
)

In [19]:
log_identifier = spark.read.csv(
    os.path.join(DIRECTORY, "ReferenceTables/LogIdentifier.csv"),
    sep="|",
    header=True,
    inferSchema=True
)

log_identifier.printSchema()

root
 |-- LogIdentifierID: string (nullable = true)
 |-- LogServiceID: integer (nullable = true)
 |-- PrimaryFG: integer (nullable = true)



In [20]:
# Once the table is ingested, we filter the data frame 
# to keep only the primary channels, as per the data documentation. 
log_identifier = log_identifier.where(F.col("PrimaryFG") == 1)
print(log_identifier.count())

758


In [21]:
log_identifier.show(5)

+---------------+------------+---------+
|LogIdentifierID|LogServiceID|PrimaryFG|
+---------------+------------+---------+
|           13ST|        3157|        1|
|         2000SM|        3466|        1|
|           70SM|        3883|        1|
|           80SM|        3590|        1|
|           90SM|        3470|        1|
+---------------+------------+---------+
only showing top 5 rows



We have two data frames, `logs` and `log_identifier`, each containing a set of columns. We are ready to start joining! The join operation has three major ingredients:
- Two tables, called a *left* and a *right* table, respectively
- One or more *predicates*, which are the series of conditions that determine how records between the two tables are joined
- A *method* to indicate how we perform the join when the predicate succeeds and when it fails

```python
[LEFT].join(
    [RIGHT],
    on=[PREDICATES],
    how=[METHOD]
)
```
There are two important points to highlight:
- If one record in the left table matches more than one record in the right table (that is, the predicate is true for several pairs of records), this record will be duplicated in the joined table. The same applies in the other direction.
- If one record in the left or right table does not match any record in the other table, it will not be present in the resulting table unless the join method specifies how to handle failed predicates.

>**Note**: if you are performing an “equi-join,” where you are testing for equality between identically named columns, you can simply pass the name of the column as a string (or a list of strings, for several columns) as the predicate.

```python
logs.join(
    log_identifier,
    on="LogServiceID",
    how=[METHOD]
)
```

**LEFT AND RIGHT OUTER JOIN**

Left (`how="left"` or `how="left_outer"`) and right (`how="right"` or `how="right_outer"`) joins are like an inner join in that they generate a record for a successful predicate. The difference is what happens when the predicate is false:
- A left (also called a left outer) join will add the unmatched records from the left table to the joined table, filling the columns coming from the right table with `null`.
- A right (also called a right outer) join will add the unmatched records from the right table to the joined table, filling the columns coming from the left table with `null`.

In practice, this means that your joined table is guaranteed to contain all the records of the table named by the join method (the left table for a left join, the right table for a right join).

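For our `logs` and `log_identifier` tables, we use an inner join, so that only the log records whose `LogServiceID` matches one of the primary channels are kept:

```python
logs_and_channels = logs.join(
    log_identifier,
    on="LogServiceID",
    how="inner"
)
```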

In [22]:
from pyspark.sql.utils import AnalysisException

logs_and_channels_verbose = logs.join(
    log_identifier, logs["LogServiceID"] == log_identifier["LogServiceID"]
)
logs_and_channels_verbose.printSchema()

try:
    logs_and_channels_verbose.select("LogServiceID")
except AnalysisException as err:
    print(err)

root
 |-- LogServiceID: integer (nullable = true)
 |-- LogDate: timestamp (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: timestamp (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttenti

We prefer using the simplified syntax, since it takes care of removing the second instance of the predicate column. This only works with an equality comparison: the data is identical in both predicate columns, so keeping only one of them loses no information.

In [23]:
logs_and_channels = logs.join(log_identifier, "LogServiceID")

logs_and_channels.printSchema()

root
 |-- LogServiceID: integer (nullable = true)
 |-- LogDate: timestamp (nullable = true)
 |-- AudienceTargetAgeID: integer (nullable = true)
 |-- AudienceTargetEthnicID: integer (nullable = true)
 |-- CategoryID: integer (nullable = true)
 |-- ClosedCaptionID: integer (nullable = true)
 |-- CountryOfOriginID: integer (nullable = true)
 |-- DubDramaCreditID: integer (nullable = true)
 |-- EthnicProgramID: integer (nullable = true)
 |-- ProductionSourceID: integer (nullable = true)
 |-- ProgramClassID: integer (nullable = true)
 |-- FilmClassificationID: integer (nullable = true)
 |-- ExhibitionID: integer (nullable = true)
 |-- Duration: string (nullable = true)
 |-- EndTime: string (nullable = true)
 |-- LogEntryDate: timestamp (nullable = true)
 |-- ProductionNO: string (nullable = true)
 |-- ProgramTitle: string (nullable = true)
 |-- StartTime: string (nullable = true)
 |-- Subtitle: string (nullable = true)
 |-- NetworkAffiliationID: integer (nullable = true)
 |-- SpecialAttenti

The second approach relies on the fact that PySpark-joined data frames remember the origin of the columns. Because of this, we can refer to each `LogServiceID` column using the same nomenclature as before (`log_identifier["LogServiceID"]`) and drop the redundant one.

In [24]:
logs_and_channels_verbose = logs.join(
    log_identifier, logs["LogServiceID"] == log_identifier["LogServiceID"]
)

logs_and_channels_verbose.drop(log_identifier["LogServiceID"]).select("LogServiceID")

DataFrame[LogServiceID: int]

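This approach is convenient if you use the Column object directly (as in `log_identifier["LogServiceID"]`). However, PySpark will not resolve the origin of a column when you rely on `F.col()` to refer to it. To solve this in the most general way, we need to `alias()` our tables when performing the join, and then qualify column names with the alias (for example, `F.col("right.LogServiceID")`).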

In [25]:
logs_and_channels_verbose = logs.alias("left").join(
    log_identifier.alias("right"),
    logs["LogServiceID"] == log_identifier["LogServiceID"],
)

logs_and_channels_verbose.drop(F.col("right.LogServiceID")).select(
    "LogServiceID"
)

DataFrame[LogServiceID: int]

Now that the first join is done, we will link two additional tables to continue our data discovery and processing. The `CD_Category` table (joined on `CategoryID`) contains information about the types of programs, and the `CD_ProgramClass` table (joined on `ProgramClassID`) contains the data that allows us to pinpoint the commercials.

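This time, we are performing left joins since we are not entirely certain that every key in our logs exists in the reference tables; a left join keeps a log record even when its key has no match, filling the reference columns with `null`.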

In [26]:
cd_category = spark.read.csv(
    os.path.join(DIRECTORY, "ReferenceTables/CD_Category.csv"),
    sep="|",
    header=True,
    inferSchema=True,
).select(
    "CategoryID",
    "CategoryCD",
    F.col("EnglishDescription").alias("Category_Description"),
)

cd_program_class = spark.read.csv(
    os.path.join(DIRECTORY, "ReferenceTables/CD_ProgramClass.csv"),
    sep="|",
    header=True,
    inferSchema=True,
).select(
    "ProgramClassID",
    "ProgramClassCD",
    F.col("EnglishDescription").alias("ProgramClass_Description"),
)

full_log = logs_and_channels.join(cd_category, "CategoryID", how="left").join(
    cd_program_class, "ProgramClassID", how="left"
)


## Summarizing the table with `groupby()` and `GroupedData`

This section covers how to summarize a data frame into more granular dimensions
(versus the entire data frame) via the groupby() method.

Going back to the question: *what are the channels with the greatest and least proportion of commercials?* To answer this, we have to take each channel and sum the duration_seconds in two ways:
- One to get the number of seconds when the program is a commercial
- One to get the number of seconds of total programming

Since we are already acquainted with the basic syntax of `groupby()`, this section starts by presenting a full code block that computes the total duration (in seconds) of each program class. In the next cell, we perform the grouping, compute the aggregate function, and present the results in decreasing order.

In [27]:
(
    full_log.groupby("ProgramClassCD", "ProgramClass_Description")
    .agg(F.sum("duration_seconds").alias("duration_total"))
    .orderBy("duration_total", ascending=False)
    .show(20, False)
)

+--------------+--------------------------------------+--------------+
|ProgramClassCD|ProgramClass_Description              |duration_total|
+--------------+--------------------------------------+--------------+
|PGR           |PROGRAM                               |20992510      |
|COM           |COMMERCIAL MESSAGE                    |3519163       |
|PFS           |PROGRAM FIRST SEGMENT                 |1344762       |
|SEG           |SEGMENT OF A PROGRAM                  |1205998       |
|PRC           |PROMOTION OF UPCOMING CANADIAN PROGRAM|880600        |
|PGI           |PROGRAM INFOMERCIAL                   |679182        |
|PRO           |PROMOTION OF NON-CANADIAN PROGRAM     |335701        |
|OFF           |SCHEDULED OFF AIR TIME PERIOD         |142279        |
|ID            |NETWORK IDENTIFICATION MESSAGE        |74926         |
|NRN           |No recognized nationality             |59686         |
|MAG           |MAGAZINE PROGRAM                      |57622         |
|PSA  

#### Using `agg()` with custom column definitions

When grouping and aggregating columns in PySpark, we have all the power of the Column object at our fingertips. This means that we can group by and aggregate on
custom columns! For this section, we will start by building a definition of `duration_commercial`, which takes the duration of a program only if it is a commercial, and use this in our `agg()` statement to seamlessly compute both the total duration and the commercial duration.

In [28]:
F.when(
    F.trim(F.col("ProgramClassCD")).isin(
        ["COM", "PRC", "PGI", "PRO", "PSA", "MAG", "LOC", "SPO", "MER", "SOL"]
    ),
    F.col("duration_seconds"),
).otherwise(0)


Column<'CASE WHEN (trim(ProgramClassCD) IN (COM, PRC, PGI, PRO, PSA, MAG, LOC, SPO, MER, SOL)) THEN duration_seconds ELSE 0 END'>

I think that the best way to describe the code this time is to literally translate it into plain English.

*__When__ the field of the column `ProgramClassCD`, trimmed of spaces at the beginning and end of the field, __is in__ our list of commercial codes, __then__ take the value of the field in the column `duration_seconds`. __Otherwise__, use zero as a value.*


In [29]:
answer = (
    full_log.groupby("LogIdentifierID")
    .agg(
        F.sum(
            F.when(
                F.trim(F.col("ProgramClassCD")).isin(
                    ["COM", "PRC", "PGI", "PRO", "LOC", "SPO", "MER", "SOL"]
                ),
                F.col("duration_seconds"),
            ).otherwise(0)
        ).alias("duration_commercial"),
        F.sum("duration_seconds").alias("duration_total"),
    )
    .withColumn(
        "commercial_ratio", F.col("duration_commercial") / F.col("duration_total")
    )
)

answer.orderBy("commercial_ratio", ascending=False).show(50, False)

+---------------+-------------------+--------------+-------------------+
|LogIdentifierID|duration_commercial|duration_total|commercial_ratio   |
+---------------+-------------------+--------------+-------------------+
|CIMT           |775                |775           |1.0                |
|TLNSP          |15480              |15480         |1.0                |
|TELENO         |17790              |17790         |1.0                |
|MSET           |2700               |2700          |1.0                |
|HPITV          |13                 |13            |1.0                |
|TANG           |8125               |8125          |1.0                |
|MMAX           |23333              |23582         |0.9894410991434145 |
|MPLU           |20587              |20912         |0.9844586840091814 |
|INVST          |20094              |20470         |0.9816316560820714 |
|ZT�L�          |21542              |21965         |0.9807420896881403 |
|RAPT           |17916              |18279         

### Drop and fill

Our first option is to simply ignore the records that have `null` values. Let's see different ways to use the `dropna()` method to drop records based on the presence of `null` values. `dropna()` is pretty easy to use. This data frame method takes three parameters:
- `how`, which can take the value `any` or `all`. If `any` is selected, PySpark will drop records where at least one of the fields is `null`. In the case of `all`, only the records where all fields are `null` will be removed. By default, PySpark will take the `any` mode.
- `thresh` takes an integer value. If set (its default is `None`), PySpark will ignore the `how` parameter and only drop the records with fewer than `thresh` non-null values.
- `subset` will take an optional list of columns that `dropna()` will use to make its decision.

In [30]:
answer_no_null = answer.dropna(subset=["commercial_ratio"])

answer_no_null.orderBy("commercial_ratio", ascending=False).show(50, False)

+---------------+-------------------+--------------+-------------------+
|LogIdentifierID|duration_commercial|duration_total|commercial_ratio   |
+---------------+-------------------+--------------+-------------------+
|CIMT           |775                |775           |1.0                |
|TLNSP          |15480              |15480         |1.0                |
|TELENO         |17790              |17790         |1.0                |
|MSET           |2700               |2700          |1.0                |
|HPITV          |13                 |13            |1.0                |
|TANG           |8125               |8125          |1.0                |
|MMAX           |23333              |23582         |0.9894410991434145 |
|MPLU           |20587              |20912         |0.9844586840091814 |
|INVST          |20094              |20470         |0.9816316560820714 |
|ZT�L�          |21542              |21965         |0.9807420896881403 |
|RAPT           |17916              |18279         

In [31]:
print(answer_no_null.count())

322


`fillna()` is even simpler than `dropna()`. This data frame method takes two parameters:
- The value, which is a Python `int`, `float`, `str`, or `bool`. PySpark will only fill the compatible columns; for instance, if we were to call `fillna("zero")`, our `commercial_ratio` column, being a `double`, would not be filled.
- The same `subset` parameter we encountered in `dropna()`. We can limit the scope of our filling to only the columns we want.

In [32]:
answer_no_null = answer.fillna(0)
answer_no_null.orderBy("commercial_ratio", ascending=False).show(50, False)

+---------------+-------------------+--------------+-------------------+
|LogIdentifierID|duration_commercial|duration_total|commercial_ratio   |
+---------------+-------------------+--------------+-------------------+
|CIMT           |775                |775           |1.0                |
|TLNSP          |15480              |15480         |1.0                |
|TELENO         |17790              |17790         |1.0                |
|MSET           |2700               |2700          |1.0                |
|HPITV          |13                 |13            |1.0                |
|TANG           |8125               |8125          |1.0                |
|MMAX           |23333              |23582         |0.9894410991434145 |
|MPLU           |20587              |20912         |0.9844586840091814 |
|INVST          |20094              |20470         |0.9816316560820714 |
|ZT�L�          |21542              |21965         |0.9807420896881403 |
|RAPT           |17916              |18279         

In [33]:
print(answer_no_null.count()) 

324
