<a href="https://colab.research.google.com/github/kangj12/testrepository/blob/main/MIE524_A1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MIE524 - Assignment 1


## Setup

Let's set up Spark on your Colab environment.  Run the cell below!

In [1]:
!pip install pyspark
!pip install -U -q PyDrive
!apt install openjdk-8-jdk-headless -qq
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

Collecting pyspark
  Downloading pyspark-3.5.3.tar.gz (317.3 MB)
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.5.3-py2.py3-none-any.whl size=317840625 sha256=7e06a11fae49c7f0e2114a284cfc462924d719fa16db1a47f6603d8090b96c8d
  Stored in directory: /root/.cache/pip/wheels/1b/3a/92/28b93e2fbfdbb07509ca4d6f50c5e407f48dce4ddbda69a4ab
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.3
The following additional packages will be installed:
  libxtst6 openjdk-8-jre-headless
Suggested packages:
  openjdk-8-demo openjdk-8-source libnss-mdns fonts-dejavu-extra fonts-nanum fonts-ipafont-gothic
  fonts-ipafont-mincho fonts-wqy-microhei fonts-wqy-zenhei fonts-indic

Next we import PySpark and create the Spark session and context. The Google Drive client used to download the dataset is authenticated further down, in the "Load the dataset" cell.

**Make sure to follow the interactive instructions when that cell runs.**

In [2]:
import pyspark
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf
import pandas as pd

# create the Spark Session
spark = SparkSession.builder.getOrCreate()

# create the Spark Context
sc = spark.sparkContext

Put all your imports and path constants in the next cells.

## Q1 - Message Count in Spark

### Load the dataset

In [3]:
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)



In [4]:
# first make sure to upload the werkmap_messages.csv file from the data folder of your starter repository to colab's files
werkmap_message_data = 'werkmap_messages.csv'

# load csv as Spark Dataframe and show result
df = spark.read.options(delimiter=";", header=True).csv(werkmap_message_data)
df.show()

+----------+-----------+------+--------+--------+--------------------+---------------+-----------------+
|CustomerID|AgeCategory|Gender|Office_U|Office_W|       EventDateTime|      EventType|HandlingChannelID|
+----------+-----------+------+--------+--------+--------------------+---------------+-----------------+
|   2032131|      18-29|     M|     271|     271|2015-11-02 01:36:...|Werkmap message|                1|
|   2032131|      18-29|     M|     271|     271|2015-11-05 21:35:...|Werkmap message|                1|
|   2032131|      18-29|     M|     271|     271|2015-11-06 15:47:...|Werkmap message|                1|
|   2085395|      18-29|     V|     280|     280|2015-10-20 23:44:...|Werkmap message|                2|
|   2085395|      18-29|     V|     280|     280|2016-01-20 23:04:...|Werkmap message|                2|
|   2090314|      50-65|     V|     238|     238|2015-12-24 12:12:...|Werkmap message|                2|
|   2088317|      18-29|     V|     269|     269|2015-1

In [5]:
# Print the data schema of the Dataframe
df.printSchema()

root
 |-- CustomerID: string (nullable = true)
 |-- AgeCategory: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Office_U: string (nullable = true)
 |-- Office_W: string (nullable = true)
 |-- EventDateTime: string (nullable = true)
 |-- EventType: string (nullable = true)
 |-- HandlingChannelID: string (nullable = true)



In [6]:
# Parse EventDateTime from string to timestamp
df = df.withColumn("EventDateTime", to_timestamp(col("EventDateTime"), "yyyy-MM-dd HH:mm:ss.SSSSSSS"))

# Keep only the year and month (yyyy-MM), which is the grouping key for the message count
df = df.withColumn("EventDateTime", date_format(col("EventDateTime"), "yyyy-MM"))
df.show(truncate=False)

+----------+-----------+------+--------+--------+-------------+---------------+-----------------+
|CustomerID|AgeCategory|Gender|Office_U|Office_W|EventDateTime|EventType      |HandlingChannelID|
+----------+-----------+------+--------+--------+-------------+---------------+-----------------+
|2032131   |18-29      |M     |271     |271     |2015-11      |Werkmap message|1                |
|2032131   |18-29      |M     |271     |271     |2015-11      |Werkmap message|1                |
|2032131   |18-29      |M     |271     |271     |2015-11      |Werkmap message|1                |
|2085395   |18-29      |V     |280     |280     |2015-10      |Werkmap message|2                |
|2085395   |18-29      |V     |280     |280     |2016-01      |Werkmap message|2                |
|2090314   |50-65      |V     |238     |238     |2015-12      |Werkmap message|2                |
|2088317   |18-29      |V     |269     |269     |2015-11      |Werkmap message|1                |
|2088317   |18-29   

In [7]:
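# Register the DataFrame as a temporary view so it can be queried with Spark SQL
df.createOrReplaceTempView("werkmap_messages")

# Count the messages per month (EventDateTime now holds yyyy-MM)
query = """
SELECT EventDateTime, COUNT(*) AS total_messages
FROM werkmap_messages
GROUP BY EventDateTime
ORDER BY EventDateTime ASC
"""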

message_counts = spark.sql(query)
message_counts.show()

+-------------+--------------+
|EventDateTime|total_messages|
+-------------+--------------+
|      2015-07|          2933|
|      2015-08|          3954|
|      2015-09|          7079|
|      2015-10|          8483|
|      2015-11|          9695|
|      2015-12|          9572|
|      2016-01|         12582|
|      2016-02|         11760|
+-------------+--------------+



In [8]:
message_counts = message_counts.toPandas()
message_counts.to_csv("Q1.txt", index=False)

If you executed the cells above, you should be able to see the file *Q1.txt* under the "Files" tab on the left panel.

### Write your function in the next cells

In [9]:
def get_month_year(row):
    """
    INPUT:
        row : a row of the input data
    OUTPUT:
        month_year : string
    """

    # YOUR CODE HERE

# You may have additional functions or modify the provided functions as necessary
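One way the `get_month_year` stub above could be filled in is sketched here. It assumes `row` supports lookup by column name (as a PySpark `Row` or a plain `dict` does) and that `EventDateTime` is a string beginning with `yyyy-MM`; neither is stated by the assignment. The `count_by_month` helper and the sample rows are hypothetical. The helper only mimics, in plain Python, what a map followed by a reduce-by-key would do on an RDD.

```python
from collections import Counter

def get_month_year(row):
    """Return the 'yyyy-MM' prefix of the row's EventDateTime string."""
    # Works for a full timestamp string and for a value that has
    # already been shortened to 'yyyy-MM'.
    return row["EventDateTime"][:7]

def count_by_month(rows):
    """Hypothetical helper: count rows per month, like map + reduceByKey."""
    return dict(Counter(get_month_year(r) for r in rows))

# Illustrative rows only; the time-of-day parts are made up.
sample_rows = [
    {"EventDateTime": "2015-11-02 01:36:00.0000000"},
    {"EventDateTime": "2015-11-05 21:35:00.0000000"},
    {"EventDateTime": "2016-01-20 23:04:00.0000000"},
]
month_counts = count_by_month(sample_rows)
print(month_counts)
```

On an RDD, the same function could be passed to `map` to produce `(month_year, 1)` pairs before reducing by key.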

Run your function in the next cells to output required content.

## PART 2 - Oxford Covid-19 Government Response Tracker

### Load the dataset

In [10]:
# Google Drive file ID of the OxCGRT dataset
file_id = '1J_2ido9_-LiasNi8xlzk5-DeHu-2_4Zs'
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('OxCGRT_USA_latest.csv')

### Q2 - Computing Index Score with Spark
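Written out, the computation that the cells in this section implement appears to be the following. Here $v_{j,t}$ is the monthly value of indicator $j$, $f_{j,t}$ is its flag (treated as 0 when missing), $F_j$ is 1 if the indicator has a flag and 0 otherwise, and $N_j$ is the indicator's maximum value. This is read off the notebook's own scoring code rather than quoted from the assignment handout, so treat it as a summary under that assumption.

$$
I_{j,t} =
\begin{cases}
0 & \text{if } v_{j,t} = 0 \\
100 \cdot \dfrac{v_{j,t} - 0.5\,(F_j - f_{j,t})}{N_j} & \text{otherwise}
\end{cases}
\qquad
\text{GovernmentResponseIndex}_t = \frac{1}{16} \sum_{j=1}^{16} I_{j,t}
$$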

In [11]:
indicators = ["C1M_School closing",
"C2M_Workplace closing",
"C3M_Cancel public events",
"C4M_Restrictions on gatherings",
"C5M_Close public transport",
"C6M_Stay at home requirements",
"C7M_Restrictions on internal movement",
"C8EV_International travel controls",
"E1_Income support",
"E2_Debt/contract relief",
"H1_Public information campaigns",
"H2_Testing policy",
"H3_Contact tracing",
"H6M_Facial Coverings",
"H7_Vaccination policy",
"H8M_Protection of elderly people"]

def clean_data(df):
    """
    INPUT:
        df: spark dataframe
    OUTPUT:
        cleaned data: spark dataframe

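    NOTE: output the dataframe with unneeded rows removed (for example, dates outside 2020-2021 and rows without a RegionName, as filtered in the cells below).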
    """
    # YOUR CODE HERE


def impute_data(df):
    """
    INPUT:
        df: spark dataframe
    OUTPUT:
        imputed data: spark dataframe

    NOTE: output the dataframe with NaN values replaced with the minimal value of the given indicator.
    """
    # Exploratory step: show the value counts of each indicator (no imputation is done yet)
    for indicator in indicators:
        df.groupBy(indicator).count().show()



def group_and_aggregate_data(df):
    """
    INPUT:
        df: spark dataframe
    OUTPUT:
        grouped and aggregated data: spark dataframe

    NOTE: output the dataframe with grouped (by month) and aggregated (based on the algorithm) data.
    """
    # YOUR CODE HERE

def compute_index_score(df):
    """
    INPUT:
        df: spark dataframe
    OUTPUT:
        list of index scores per region and period: list

    NOTE: output a list of computed scores per region and period based on the algorithm.
    """
    # YOUR CODE HERE

# You may have additional functions

In [12]:
# Stop the current Spark session before starting a new one
spark.stop()

# Configure the Spark UI port
conf = SparkConf().set("spark.ui.port", "4050")

# Create the context with this configuration, then the session on top of it
sc = pyspark.SparkContext(conf=conf)
spark = SparkSession.builder.getOrCreate()

In [13]:
# Load Dataset
OxCGRT = spark.read.option("header", True).csv("OxCGRT_USA_latest.csv")
OxCGRT.show()

# C1M = Indicator, C1M_Flag = Flag for C1M Indicator, etc....

+-------------+-----------+----------+----------+------------+--------+------------------+--------+---------------------+--------+------------------------+--------+------------------------------+--------+--------------------------+--------+-----------------------------+--------+-------------------------------------+--------+----------------------------------+-----------------+-------+-----------------------+------------------+------------------------+-------------------------------+-------+-----------------+------------------+-------------------------------------+-------------------------+--------------------+--------+---------------------+-------+--------------------------------+--------+-----------+-----------------------------------+----------------------------------+-------------------------------------------------------------------------------+--------------------------------------------------------------------+--------------------------------------------------+-------------+-----

In [14]:
OxCGRT.printSchema()
# Notation (j = indicator, t = period):
# v_jt = recorded policy value of indicator j in period t on its ordinal scale (e.g. "C1M_School closing")
# N_j  = maximum value of indicator j (e.g. the maximum of "C1M_School closing")
# f_jt = recorded binary flag of indicator j in period t (e.g. "C1M_Flag")

root
 |-- CountryName: string (nullable = true)
 |-- CountryCode: string (nullable = true)
 |-- RegionName: string (nullable = true)
 |-- RegionCode: string (nullable = true)
 |-- Jurisdiction: string (nullable = true)
 |-- Date: string (nullable = true)
 |-- C1M_School closing: string (nullable = true)
 |-- C1M_Flag: string (nullable = true)
 |-- C2M_Workplace closing: string (nullable = true)
 |-- C2M_Flag: string (nullable = true)
 |-- C3M_Cancel public events: string (nullable = true)
 |-- C3M_Flag: string (nullable = true)
 |-- C4M_Restrictions on gatherings: string (nullable = true)
 |-- C4M_Flag: string (nullable = true)
 |-- C5M_Close public transport: string (nullable = true)
 |-- C5M_Flag: string (nullable = true)
 |-- C6M_Stay at home requirements: string (nullable = true)
 |-- C6M_Flag: string (nullable = true)
 |-- C7M_Restrictions on internal movement: string (nullable = true)
 |-- C7M_Flag: string (nullable = true)
 |-- C8EV_International travel controls: string (nullabl

In [15]:
# Keep only the year and month (yyyyMM) of the Date column, then drop rows that are not from 2020 or 2021
OxCGRT = OxCGRT.withColumn("Date", substring(col("Date"), 1, 6))
OxCGRT = OxCGRT.filter((col("Date").startswith('2020')) | col("Date").startswith('2021'))

In [16]:
# Check for nan values in indicator columns
OxCGRT.select([count(when(isnan(c), c)).alias(c) for c in indicators]).show()

+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+
|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|
+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------

In [17]:
# Filter out Null values for RegionName
OxCGRT = OxCGRT.where(col("RegionName").isNotNull())
OxCGRT.show()

+-------------+-----------+----------+----------+------------+------+------------------+--------+---------------------+--------+------------------------+--------+------------------------------+--------+--------------------------+--------+-----------------------------+--------+-------------------------------------+--------+----------------------------------+-----------------+-------+-----------------------+------------------+------------------------+-------------------------------+-------+-----------------+------------------+-------------------------------------+-------------------------+--------------------+--------+---------------------+-------+--------------------------------+--------+-----------+-----------------------------------+----------------------------------+-------------------------------------------------------------------------------+--------------------------------------------------------------------+--------------------------------------------------+-------------+-------

In [18]:
# Step 1: Find the mode of each indicator j for each time period t (v_jt), using the built-in "mode" aggregate
all_column_expr = {x: "mode" for x in indicators}
indicator_time_mode_df = OxCGRT.groupBy("RegionName", "Date").agg(all_column_expr)
indicator_time_mode_df.show()


+----------+------+------------------------------+-----------------------------+------------------------+-------------------------------------------+--------------------------------------+------------------------+--------------------------+----------------------------------------+-----------------------+--------------------------------+---------------------------+---------------------------+-----------------------------------+-----------------------+-------------------------------------+------------------------------------+
|RegionName|  Date|mode(C3M_Cancel public events)|mode(E2_Debt/contract relief)|mode(H3_Contact tracing)|mode(C7M_Restrictions on internal movement)|mode(H8M_Protection of elderly people)|mode(C1M_School closing)|mode(H6M_Facial Coverings)|mode(C8EV_International travel controls)|mode(H2_Testing policy)|mode(C5M_Close public transport)|mode(H7_Vaccination policy)|mode(C2M_Workplace closing)|mode(C6M_Stay at home requirements)|mode(E1_Income support)|mode(H1_Public i

In [34]:
# Step 1: Find the mode of each indicator j for each time period t (v_jt); ties are broken in favour of the larger value
def get_mode_for_indicator(df, indicator):
  # Count how often each indicator value occurs per region and month
  df_grouped = df.groupBy("RegionName", "Date", indicator).agg(count(indicator).alias(f"{indicator}_count"))
  # Rank by count (descending), then by value (descending) to break ties
  window = Window.partitionBy("RegionName", "Date").orderBy(desc(f"{indicator}_count"), desc(indicator))
  df_grouped = df_grouped.withColumn("rank", row_number().over(window))
  df_grouped = df_grouped.where("rank == 1")
  df_grouped = df_grouped.select("RegionName", "Date", indicator)
  return df_grouped

full_df = OxCGRT
# Compute the mode of every indicator and join the results on region and month
df_indicator = None
for colm in indicators:
  temp_df = get_mode_for_indicator(full_df, colm)
  if df_indicator is None:
    df_indicator = temp_df
  else:
    df_indicator = df_indicator.join(temp_df, ["RegionName", "Date"])
df_indicator.show()


+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+
|RegionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly people|
+----------+------+------------------+---------------------+------------------------+------------------------------+----------------
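The tiebreak rule used in the window above (highest count first, then the larger value) can be checked on a small list in plain Python. This is only a sketch of the rule, not the Spark implementation. `mode_with_tiebreak` is a hypothetical helper, and it compares integers, whereas the notebook's columns are still strings at this point and so compare lexicographically.

```python
from collections import Counter

def mode_with_tiebreak(values):
    """Most frequent non-null value; ties go to the larger value."""
    counts = Counter(v for v in values if v is not None)
    if not counts:
        return None
    # The key mirrors orderBy(desc(count), desc(value)) in the window.
    return max(counts, key=lambda v: (counts[v], v))

# Hypothetical daily values of one indicator within one month.
daily_values = [1, 1, 2, 2, 0]
print(mode_with_tiebreak(daily_values))
```

Here 1 and 2 both occur twice, so the larger value, 2, is chosen.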

In [20]:
# List the distinct year-month periods that remain after filtering
OxCGRT.select("Date").distinct().show(50)

+------+
|  Date|
+------+
|202110|
|202006|
|202009|
|202101|
|202011|
|202007|
|202111|
|202108|
|202109|
|202005|
|202012|
|202004|
|202102|
|202105|
|202103|
|202107|
|202003|
|202112|
|202010|
|202001|
|202008|
|202002|
|202104|
|202106|
+------+



In [43]:
# Step 2: Find the mode of each indicator flag (f_jt), with the same tiebreaker as for the indicators
def get_mode(df, indicator):
  # Count how often each flag value occurs per region and month
  df_grouped = df.groupBy("RegionName", "Date", indicator).agg(count(indicator).alias(f"{indicator}_count"))
  # Rank by count (descending), then by value (descending) to break ties
  window = Window.partitionBy("RegionName", "Date").orderBy(desc(f"{indicator}_count"), desc(indicator))
  df_grouped = df_grouped.withColumn("rank", row_number().over(window))
  df_grouped = df_grouped.where("rank == 1")
  df_grouped = df_grouped.select("RegionName", "Date", indicator)
  return df_grouped

full_df = OxCGRT
# Compute the mode of every flag and join the results on region and month
df_flags = None
for colm in indicator_flags:
  test_df = get_mode(full_df, colm)
  if df_flags is None:
    df_flags = test_df
  else:
    df_flags = df_flags.join(test_df, ["RegionName", "Date"])

# Add null columns for the indicators that have no flag in the dataset
null_flags = ['C8EV_Flag', 'E2_Flag', 'H2_Flag', 'H3_Flag']
for f in null_flags:
  df_flags = df_flags.withColumn(f, lit(None))
df_flags.show()

+----------+------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+---------+-------+-------+-------+
|RegionName|  Date|C1M_Flag|C2M_Flag|C3M_Flag|C4M_Flag|C5M_Flag|C6M_Flag|C7M_Flag|E1_Flag|H1_Flag|H6M_Flag|H7_Flag|H8M_Flag|C8EV_Flag|E2_Flag|H2_Flag|H3_Flag|
+----------+------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+---------+-------+-------+-------+
|   Alabama|202001|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|   NULL|   NULL|    NULL|   NULL|    NULL|     NULL|   NULL|   NULL|   NULL|
|   Alabama|202002|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|    NULL|   NULL|      1|    NULL|   NULL|    NULL|     NULL|   NULL|   NULL|   NULL|
|   Alabama|202003|       1|       0|       0|       0|       0|       0|       1|      0|      0|    NULL|   NULL|       1|     NULL|   NULL|   NULL|   NULL|
|   Alabama|202004|       1|       1|       1|

In [52]:
# Join two dataframes together
merged_df = df_indicator.join(df_flags, ["RegionName", "Date"])
merged_df.show()

+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+---------+-------+-------+-------+
|RegionName|  Date|C1M_School closing|C2M_Workplace closing|C3M_Cancel public events|C4M_Restrictions on gatherings|C5M_Close public transport|C6M_Stay at home requirements|C7M_Restrictions on internal movement|C8EV_International travel controls|E1_Income support|E2_Debt/contract relief|H1_Public information campaigns|H2_Testing policy|H3_Contact tracing|H6M_Facial Coverings|H7_Vaccination policy|H8M_Protection of elderly p

In [115]:
# Calculate the indicator scores for each RegionName-Date combination
score_df = merged_df
for i in range(len(indicators)):
  indicator, flag = indicators[i], full_indicator_flags[i]
  F = flag_dict[flag]
  N = ind_dict[indicator]

  # The first score column has no underscore after "score"; the SQL query in the next cell relies on that name
  score_name = f'score{indicator}' if i == 0 else f'score_{indicator}'
  score_df = score_df.withColumn(
    score_name,
    when(score_df[indicator] == 0, 0)
    # A missing flag is treated as 0
    .otherwise(100 * (score_df[indicator] - 0.5 * (F - coalesce(score_df[flag], lit(0)))) / N)
  )

# Show the updated dataframe with the score columns
score_df.show()

+----------+------+------------------+---------------------+------------------------+------------------------------+--------------------------+-----------------------------+-------------------------------------+----------------------------------+-----------------+-----------------------+-------------------------------+-----------------+------------------+--------------------+---------------------+--------------------------------+--------+--------+--------+--------+--------+--------+--------+-------+-------+--------+-------+--------+---------+-------+-------+-------+-----------------------+---------------------------+------------------------------+------------------------------------+--------------------------------+-----------------------------------+-------------------------------------------+----------------------------------------+-----------------------+-----------------------------+-------------------------------------+-----------------------+------------------------+-------------
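To sanity-check individual values in the score columns, the per-indicator formula from the scoring cell above can be evaluated by hand. The sketch below restates it in plain Python and assumes a missing flag counts as 0, which is what the `coalesce` call in that cell does. The function names are hypothetical, and the inputs in the example are illustrative rather than taken from the dataset.

```python
def sub_index_score(v, f, F, N):
    """Score of one indicator in one period, on a 0-100 scale."""
    if v == 0:
        return 0.0
    f = 0 if f is None else f  # a missing flag is treated as 0
    return 100 * (v - 0.5 * (F - f)) / N

def response_index(scores):
    """Mean of the indicator scores, rounded to two decimals."""
    return round(sum(scores) / len(scores), 2)

# An indicator with a flag (F=1) and maximum value 3:
general = sub_index_score(v=2, f=1, F=1, N=3)    # policy applies generally
targeted = sub_index_score(v=2, f=0, F=1, N=3)   # policy is targeted
print(general, targeted)
```

A targeted policy (`f=0`) loses half a point on the ordinal scale before normalisation, which is why the second score is lower than the first.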

In [143]:
# Calculate the final index score through SQL: the mean of the 16 indicator scores, rounded to 2 decimals
score_df.createOrReplaceTempView("final_score")

query = """
select RegionName, Date, round((`scoreC1M_school closing` + `score_C2M_Workplace closing` + `score_C3M_Cancel public events` + `score_C4M_Restrictions on gatherings` +
`score_C5M_Close public transport` + `score_C6M_Stay at home requirements` + `score_C7M_Restrictions on internal movement` + `score_C8EV_International travel controls` +
`score_E1_Income support` + `score_E2_Debt/contract relief` + `score_H1_Public information campaigns` + `score_H2_Testing policy` +
`score_H3_Contact tracing` + `score_H6M_Facial Coverings` + `score_H7_Vaccination policy` + `score_H8M_Protection of elderly people`) / 16, 2) as GovernmentResponseIndex
from final_score
"""
full_df = spark.sql(query)
full_df.show(20)

+----------+------+-----------------------+
|RegionName|  Date|GovernmentResponseIndex|
+----------+------+-----------------------+
|   Alabama|202001|                    0.0|
|   Alabama|202002|                    9.9|
|   Alabama|202003|                   40.1|
|   Alabama|202004|                  67.71|
|   Alabama|202005|                  60.68|
|   Alabama|202006|                  59.64|
|   Alabama|202007|                  56.25|
|   Alabama|202008|                  46.35|
|   Alabama|202009|                   52.6|
|   Alabama|202010|                   47.4|
|   Alabama|202011|                   47.4|
|   Alabama|202012|                   47.4|
|   Alabama|202101|                   49.9|
|   Alabama|202102|                  47.81|
|   Alabama|202103|                  47.81|
|   Alabama|202104|                  48.18|
|   Alabama|202105|                  46.09|
|   Alabama|202106|                  42.71|
|   Alabama|202107|                  40.63|
|   Alabama|202108|             
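The hand-written column sum in the query above could also be generated from the list of indicator names, which avoids typos in the long column list. The sketch below builds such a query string in plain Python. `build_index_query` is a hypothetical helper, and it assumes the score columns follow the notebook's naming: `score` plus the name for the first indicator, `score_` plus the name for the rest.

```python
def build_index_query(indicator_names, view="final_score"):
    """Build a SQL query that averages the score columns of a view."""
    # The first score column has no underscore after "score".
    names = [
        f"score{name}" if i == 0 else f"score_{name}"
        for i, name in enumerate(indicator_names)
    ]
    # Backticks are needed because the column names contain spaces.
    total = " + ".join(f"`{n}`" for n in names)
    return (
        f"select RegionName, Date, "
        f"round(({total}) / {len(names)}, 2) as GovernmentResponseIndex "
        f"from {view}"
    )

# Shortened list for illustration; the notebook has 16 indicators.
demo_query = build_index_query(["C1M_School closing", "C2M_Workplace closing"])
print(demo_query)
```

The resulting string can be passed to `spark.sql` in the same way as the hand-written query.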

In [144]:
# Save Q2 to text file
q2_df = full_df.toPandas()
q2_df.to_csv("Q2.txt", index=False)

In [49]:
# Define F_j (1 if indicator j has a flag, 0 otherwise) and N_j (maximum value of indicator j)
flag_values = [1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1]
ind_values = [3, 3, 2, 4, 2, 3, 2, 4, 2, 2, 2, 3, 2, 4, 5, 3]
flag_dict = dict(zip(full_indicator_flags, flag_values))
ind_dict = dict(zip(indicators, ind_values))
print(flag_dict, ind_dict)

{'C1M_Flag': 1, 'C2M_Flag': 1, 'C3M_Flag': 1, 'C4M_Flag': 1, 'C5M_Flag': 1, 'C6M_Flag': 1, 'C7M_Flag': 1, 'C8EV_Flag': 0, 'E1_Flag': 1, 'E2_Flag': 0, 'H1_Flag': 1, 'H2_Flag': 0, 'H3_Flag': 0, 'H6M_Flag': 1, 'H7_Flag': 1, 'H8M_Flag': 1} {'C1M_School closing': 3, 'C2M_Workplace closing': 3, 'C3M_Cancel public events': 2, 'C4M_Restrictions on gatherings': 4, 'C5M_Close public transport': 2, 'C6M_Stay at home requirements': 3, 'C7M_Restrictions on internal movement': 2, 'C8EV_International travel controls': 4, 'E1_Income support': 2, 'E2_Debt/contract relief': 2, 'H1_Public information campaigns': 2, 'H2_Testing policy': 3, 'H3_Contact tracing': 2, 'H6M_Facial Coverings': 4, 'H7_Vaccination policy': 5, 'H8M_Protection of elderly people': 3}


In [48]:
# You may use the lists below to get the indicator and flag column headers
indicators = ["C1M_School closing",
"C2M_Workplace closing",
"C3M_Cancel public events",
"C4M_Restrictions on gatherings",
"C5M_Close public transport",
"C6M_Stay at home requirements",
"C7M_Restrictions on internal movement",
"C8EV_International travel controls",
"E1_Income support",
"E2_Debt/contract relief",
"H1_Public information campaigns",
"H2_Testing policy",
"H3_Contact tracing",
"H6M_Facial Coverings",
"H7_Vaccination policy",
"H8M_Protection of elderly people"]

indicator_flags = ["C1M_Flag",
"C2M_Flag",
"C3M_Flag",
"C4M_Flag",
"C5M_Flag",
"C6M_Flag",
"C7M_Flag",
"E1_Flag",
"H1_Flag",
"H6M_Flag",
"H7_Flag",
"H8M_Flag"]

full_indicator_flags = ["C1M_Flag",
"C2M_Flag",
"C3M_Flag",
"C4M_Flag",
"C5M_Flag",
"C6M_Flag",
"C7M_Flag",
"C8EV_Flag",
"E1_Flag",
"E2_Flag",
"H1_Flag",
"H2_Flag",
"H3_Flag",
"H6M_Flag",
"H7_Flag",
"H8M_Flag"]

In [None]:
# Group and aggregate the data by Date using SQL
OxCGRT.show()

Run your function in the next cells to output required content.

### Q3 - Association Rules

In [None]:
def transform_to_items(df):
    """
    INPUT:
        df: spark dataframe
    OUTPUT:
        list of itemsets: list

    NOTE: output a list of itemsets from the given dataframe.
    """
    # YOUR CODE HERE

In [None]:
def apriori(items, min_sup, itemset_size):
    """
    INPUT:
        items: list
        min_sup: the min support
        itemset_size: the size of the itemsets
    OUTPUT:
        list of frequent itemsets: list

    NOTE: output a list of frequent itemsets.
    """
    # YOUR CODE HERE

# You may have additional functions
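For orientation, a level-wise Apriori pass can be sketched in plain Python as below. It is not the required solution and makes several assumptions: each transaction is a collection of hashable items, `min_sup` is an absolute count rather than a fraction, and the function returns only itemsets of exactly `itemset_size` items together with their counts. The name `apriori_sketch` and the sample transactions are hypothetical.

```python
from collections import Counter
from itertools import combinations

def apriori_sketch(transactions, min_sup, itemset_size):
    """Frequent itemsets of exactly `itemset_size` items, with counts."""
    transactions = [frozenset(t) for t in transactions]

    # Level 1: count single items and keep the frequent ones.
    item_counts = Counter(item for t in transactions for item in t)
    frequent = {frozenset([i]): c for i, c in item_counts.items() if c >= min_sup}

    k = 1
    while k < itemset_size and frequent:
        k += 1
        # Candidate generation: unions of frequent (k-1)-itemsets of size k.
        candidates = {a | b for a, b in combinations(list(frequent), 2) if len(a | b) == k}
        # Pruning: every (k-1)-subset of a candidate must itself be frequent.
        candidates = {
            c for c in candidates
            if all(frozenset(s) in frequent for s in combinations(c, k - 1))
        }
        # Support counting over all transactions.
        counts = Counter()
        for t in transactions:
            for c in candidates:
                if c <= t:
                    counts[c] += 1
        frequent = {c: n for c, n in counts.items() if n >= min_sup}
    return frequent

# Made-up transactions for illustration.
transactions = [{"a", "b", "c"}, {"a", "b"}, {"a", "c"}, {"b", "c"}, {"a", "b", "c"}]
pairs = apriori_sketch(transactions, min_sup=3, itemset_size=2)
print(pairs)
```

In this example every pair appears in three transactions and is therefore frequent, while the triple `{a, b, c}` appears in only two and is not.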

Run your function in the next cells to output required content.