We chose to filter out rows with missing or invalid values for key fields.

This decision is based on guidance provided in the forum and the homework instructions, which state that if a row lacks the necessary information to determine viewing behavior or wealth characteristics, it can be safely disregarded.

# Section 1.1

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql.functions import sum as spark_sum
from pyspark.sql.types import *
from functools import reduce

In [0]:
def load_csv_file(filename, schema):
  """Read one of the known pipe-delimited course datasets from DBFS.

  Args:
    filename: dataset name; only 'Daily program data' and 'demographic'
      are recognised.
    schema: StructType applied to the headerless CSV (columns are mapped
      by position).

  Returns:
    A Spark DataFrame for the dataset, or None (after printing a message)
    when ``filename`` is not one of the recognised names.
  """
  # dataset name -> (path below the course mount, field delimiter)
  known_files = {
      'Daily program data': ('Daily program data', "|"),
      'demographic': ('demographic', "|"),
  }

  entry = known_files.get(filename)
  if entry is None:
    print(f'You were trying to access unknown file \"{filename}\". Only valid options are {known_files.keys()}')
    return None

  relative_path, sep = entry
  reader = (
      spark.read.format("csv")
      .option("header", "false")
      .option("delimiter", sep)
      .schema(schema)
  )
  return reader.load(f"dbfs:/mnt/coursedata2024/fwm-stb-data/{relative_path}")
# Explicit Spark schemas for the raw datasets, keyed by dataset name.
# Headerless CSVs (see load_csv_file) are mapped onto these fields by position,
# so each StructType must list columns in the same order as the source file.
schemas_dict = {'Daily program data':
                  StructType([
                    StructField('prog_code', StringType()),
                    StructField('title', StringType()),
                    StructField('genre', StringType()),  # comma-separated list; split on ",\s*" when building cond_6
                    StructField('air_date', StringType()),  # parsed later with to_date(..., "yyyyMMdd")
                    StructField('air_time', StringType()),
                    StructField('Duration', FloatType())  # referenced later as "duration" (Spark resolves names case-insensitively by default)
                  ]),
                # NOTE(review): 'viewing' is not used by the visible loading code; the viewing file is read with 'viewing_full'.
                'viewing':
                  StructType([
                    StructField('device_id', StringType()),
                    StructField('event_date', StringType()),
                    StructField('event_time', IntegerType()),
                    StructField('mso_code', StringType()),
                    StructField('prog_code', StringType()),
                    StructField('station_num', StringType())
                  ]),
                # Schema for the 10m_viewing CSV; event_date is read as an integer and converted with to_date later.
                'viewing_full':
                  StructType([
                    StructField('mso_code', StringType()),
                    StructField('device_id', StringType()),
                    StructField('event_date', IntegerType()),
                    StructField('event_time', IntegerType()),
                    StructField('station_num', StringType()),
                    StructField('prog_code', StringType())
                  ]),
                # Household demographics; household_id stays a string so it can be joined
                # against the zero-padded household_id derived from the reference data.
                'demographic':
                  StructType([StructField('household_id',StringType()),
                    StructField('household_size',IntegerType()),
                    StructField('num_adults',IntegerType()),
                    StructField('num_generations',IntegerType()),
                    StructField('adult_range',StringType()),
                    StructField('marital_status',StringType()),
                    StructField('race_code',StringType()),
                    StructField('presence_children',StringType()),
                    StructField('num_children',IntegerType()),
                    StructField('age_children',StringType()), #format like range - 'bitwise'
                    StructField('age_range_children',StringType()),
                    StructField('dwelling_type',StringType()),
                    StructField('home_owner_status',StringType()),
                    StructField('length_residence',IntegerType()),
                    StructField('home_market_value',StringType()),
                    StructField('num_vehicles',IntegerType()),
                    StructField('vehicle_make',StringType()),
                    StructField('vehicle_model',StringType()),
                    StructField('vehicle_year',IntegerType()),
                    StructField('net_worth',IntegerType()),
                    StructField('income',StringType()),  # string: values may be digits or letters A-D (mapped to numbers later)
                    StructField('gender_individual',StringType()),
                    StructField('age_individual',IntegerType()),
                    StructField('education_highest',StringType()),
                    StructField('occupation_highest',StringType()),
                    StructField('education_1',StringType()),
                    StructField('occupation_1',StringType()),
                    StructField('age_2',IntegerType()),
                    StructField('education_2',StringType()),
                    StructField('occupation_2',StringType()),
                    StructField('age_3',IntegerType()),
                    StructField('education_3',StringType()),
                    StructField('occupation_3',StringType()),
                    StructField('age_4',IntegerType()),
                    StructField('education_4',StringType()),
                    StructField('occupation_4',StringType()),
                    StructField('age_5',IntegerType()),
                    StructField('education_5',StringType()),
                    StructField('occupation_5',StringType()),
                    StructField('polit_party_regist',StringType()),
                    StructField('polit_party_input',StringType()),
                    StructField('household_clusters',StringType()),
                    StructField('insurance_groups',StringType()),
                    StructField('financial_groups',StringType()),
                    StructField('green_living',StringType())
                  ])
}

In [0]:
# Load datasets using pre-defined schema and helper
program_df = load_csv_file('Daily program data', schemas_dict['Daily program data'])
demographic_df = load_csv_file('demographic', schemas_dict['demographic'])

# Load reference data (parquet).
# Parquet files are self-describing, so Spark reads their embedded schema;
# "inferSchema" is a CSV-only option and is therefore not passed here.
# NOTE(review): ref_data_schema documents the expected layout but is not
# applied to the read; household_id is cast and zero-padded in a later cell.
ref_data_schema = StructType([
    StructField('device_id', StringType()),
    StructField('dma', StringType()),
    StructField('dma_code', StringType()),
    StructField('household_id', IntegerType()),
    StructField('zipcode', IntegerType())
])
reference_df = spark.read.format('parquet').load("dbfs:/FileStore/ddm/ref_data")

# Load viewing data (comma-separated CSV with a header row). The header line is
# skipped and columns are mapped positionally onto the 'viewing_full' schema.
viewing_df = (
    spark.read.format("csv")
    .option("header", "true")
    .option("delimiter", ",")
    .schema(schemas_dict['viewing_full'])
    .load("dbfs:/FileStore/ddm/10m_viewing")
)



In [0]:
# Remove exact duplicate rows from every input and cache the deduplicated
# frames so later actions (first(), joins, display) can reuse them.
program_df = program_df.dropDuplicates().cache()
demographic_df = demographic_df.dropDuplicates().cache()
reference_df = reference_df.dropDuplicates().cache()
viewing_df = viewing_df.dropDuplicates().cache()

In [0]:
# Normalise column types before computing conditions:
#   program_df     - duration -> int (truncates the float), air_date -> date, air_time -> int
#   viewing_df     - event_date -> date, event_time -> int
#   reference_df   - household_id -> 8-character zero-padded string, so it can be
#                    joined with demographic_df.household_id (a string column)
#   demographic_df - income letters A/B/C/D become 10/11/12/13; other values go
#                    through cast("int"); num_adults, age_individual, age_2 -> int
program_df = (
    program_df
    .withColumn("duration", col("duration").cast("int"))
    .withColumn("air_date", to_date(col("air_date"), "yyyyMMdd"))
    .withColumn("air_time", col("air_time").cast("int"))
)

viewing_df = (
    viewing_df
    .withColumn("event_date", to_date(col("event_date"), "yyyyMMdd"))
    .withColumn("event_time", col("event_time").cast("int"))
)

# NOTE(review): presumably the parquet household_id loses leading zeros; padding restores them.
reference_df = reference_df.withColumn(
    "household_id", lpad(col("household_id").cast("string"), 8, "0")
)

income_is_letter = col("income").rlike("^[A-D]$")
demographic_df = (
    demographic_df
    .withColumn("num_adults", col("num_adults").cast("int"))
    .withColumn(
        "income",
        when(
            income_is_letter,
            ascii(substring("income", 1, 1)) - ascii(lit("A")) + 10,
        ).otherwise(col("income").cast("int")),
    )
    .withColumn("age_individual", col("age_individual").cast("int"))
    .withColumn("age_2", col("age_2").cast("int"))
)

In [0]:
# Project each DataFrame down to the columns the condition logic works with.

# Program airings: identity, title/genre text, and timing.
program_df = program_df.select(
    ["prog_code", "title", "genre", "air_date", "air_time", "duration"]
)

# Household attributes for conditions 2, 3 and 5 (num_vehicles is carried
# along but is not referenced by the visible condition cells).
demographic_df = demographic_df.select(
    [
        "household_id",
        "num_adults",
        "age_individual",
        "age_2",
        "num_vehicles",
        "vehicle_make",
        "income",
    ]
)

# Device -> household mapping.
reference_df = reference_df.select(["device_id", "household_id"])

# Which device watched which program, and when.
viewing_df = viewing_df.select(["device_id", "prog_code", "event_date", "event_time"])

In [0]:
# ---- Program-level conditions ----
# avg, dayofmonth, dayofweek, split, array_contains, when and lower come from
# the star import of pyspark.sql.functions.

# cond_1: airing runs longer than the mean duration over all program rows.
avg_duration = program_df.select(avg("duration")).first()[0]
program_df = program_df.withColumn("cond_1", col("duration") > avg_duration)

# cond_4: aired on a Friday the 13th (dayofweek numbers Sunday=1 .. Saturday=7).
program_df = program_df.withColumn(
    "cond_4",
    (dayofmonth(col("air_date")) == 13) & (dayofweek(col("air_date")) == 6),
)

# cond_6: the comma-separated genre list contains at least one target genre.
program_df = program_df.withColumn(
    "cond_6",
    reduce(
        lambda either, hit: either | hit,
        [
            array_contains(split(col("genre"), ",\\s*"), genre_name)
            for genre_name in (
                "Collectibles",
                "Art",
                "Snowmobile",
                "Public affairs",
                "Animated",
                "Music",
            )
        ],
    ),
)

# cond_7: the lower-cased title contains at least two of these words.
# NOTE(review): matching is by substring (LIKE '%w%'), so e.g. "the" also
# matches "other" — confirm this is the intended reading of the rule.
title_words = ["better", "girls", "the", "call"]
program_df = program_df.withColumn(
    "cond_7",
    reduce(
        lambda total, hit: total + hit,
        [
            when(lower(col("title")).like(f"%{word}%"), 1).otherwise(0)
            for word in title_words
        ],
    ) >= 2,
)


In [0]:
# ---- Demographic-based conditions via viewings ----
# Aliased so the aggregate does not shadow Python's builtin max.
from pyspark.sql.functions import max as spark_max

# Number of distinct devices registered to each household.
device_counts_df = (
    reference_df
    .groupBy("household_id")
    .agg(countDistinct("device_id").alias("device_count"))
)
demographic_df = demographic_df.join(device_counts_df, on="household_id", how="left")

# Household-level conditions (abs/avg here are the pyspark.sql.functions versions).
mean_income = demographic_df.select(avg("income")).first()[0]
demographic_df = (
    demographic_df
    .withColumn("cond_2", col("vehicle_make") == "91")
    .withColumn(
        "cond_3",
        (col("num_adults") == 2) & (abs(col("age_individual") - col("age_2")) <= 6),
    )
    .withColumn("cond_5", (col("device_count") > 3) & (col("income") < mean_income))
)

# Push household conditions down to each device ...
device_conditions_df = (
    reference_df
    .join(demographic_df, on="household_id", how="inner")
    .select("device_id", "cond_2", "cond_3", "cond_5")
)

# ... and from devices to every program code they viewed (each pair counted once).
viewing_conditions_df = (
    viewing_df
    .select("prog_code", "device_id")
    .distinct()
    .join(device_conditions_df, on="device_id", how="inner")
)

# A program code satisfies a condition if any viewing device satisfies it (max over 0/1).
prog_demo_conditions_df = viewing_conditions_df.groupBy("prog_code").agg(
    *[
        spark_max(col(flag).cast("int")).alias(flag)
        for flag in ("cond_2", "cond_3", "cond_5")
    ]
)

# Attach to every program airing. fillna(0) zero-fills numeric NULLs (including
# cond_2/3/5 for programs with no matching viewings); it does not touch boolean columns.
program_df = program_df.join(prog_demo_conditions_df, on="prog_code", how="left").fillna(0)

# Section 1.2

In [0]:
# Number of satisfied conditions (0-7) per program airing; the individual
# cond_* columns are dropped afterwards.
#
# A NULL condition counts as "not satisfied". This is needed because
# DataFrame.fillna(0) only fills numeric columns. The boolean flags (e.g.
# cond_1/cond_4/cond_6, NULL when duration/air_date/genre is NULL) therefore
# stay NULL. A single NULL term would make the whole sum NULL, and the airing
# would then be treated as non-malicious even if it met >= 4 other conditions.
all_cond_cols = [f"cond_{i}" for i in range(1, 8)]

program_df = program_df.withColumn(
    "cond_count",
    reduce(
        lambda a, b: a + b,
        (coalesce(col(c).cast("int"), lit(0)) for c in all_cond_cols),
    ),
).drop(*all_cond_cols)

In [0]:
# Step 1: Flag program airings that meet at least 4 of the 7 conditions
program_df = program_df.withColumn("is_malicious", col("cond_count") >= 4)

# Step 2: Count total and malicious airings per title.
# spark_sum is the explicitly imported pyspark aggregate (the bare name `sum`
# only refers to it because the star import shadows the Python builtin).
malicious_stats_df = program_df.groupBy("title").agg(
    count("*").alias("total_records"),
    spark_sum(when(col("is_malicious"), 1).otherwise(0)).alias("malicious_records"),
)

# Step 3: Fraction of a title's airings that are malicious (0.0 - 1.0)
malicious_stats_df = malicious_stats_df.withColumn(
    "malicious_ratio", col("malicious_records") / col("total_records")
)

# Step 4: Keep titles with a malicious ratio above 0.4, highest ratio first.
# Many titles tie at 1.0, so secondary keys make the order (and therefore the
# top-20 shown by the next cell) deterministic across runs.
top_malicious_titles = (
    malicious_stats_df
    .filter(col("malicious_ratio") > 0.4)
    .orderBy(
        col("malicious_ratio").desc(),
        col("malicious_records").desc(),
        col("title").asc(),
    )
    .select("title", "malicious_ratio")
)
top_malicious_titles.cache()

DataFrame[title: string, malicious_ratio: double]

In [0]:
display(top_malicious_titles.limit(num=20))  # first 20 rows of the ranked titles

title,malicious_ratio
Poseidon,1.0
Notes From the Heart Healer,1.0
Local 24 News Good Day 6am,1.0
Concorde: Flying Supersonic,1.0
12 O'Clock Boys,1.0
Good Day Atlanta 7:00am,1.0
Failure to Launch,1.0
Songs That Don't Suck,1.0
MC Alternative,1.0
Jamie Marks Is Dead,1.0
