# MIE524 - Lab 1 - Spark Warm-up


## Setup

In [None]:
!pip install pyspark
!pip install -U -q PyDrive
!apt install openjdk-8-jdk-headless -qq
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"

In [None]:
# Authenticate the Colab user, then build a PyDrive client that reuses the
# application-default credentials so Drive files can be fetched by file ID.
from google.colab import auth
from oauth2client.client import GoogleCredentials
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

auth.authenticate_user()

gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Word Count with Spark
We will work with the *pg100.txt* file which contains a copy of the complete works of Shakespeare.

In [None]:
# Download pg100.txt (complete works of Shakespeare) from Google Drive.
# A descriptive constant is used instead of `id`, which would shadow the builtin.
PG100_FILE_ID = '1FV9oO0opIaww85HGR0Oe_mv6FSOmzVZ6'
downloaded = drive.CreateFile({'id': PG100_FILE_ID})
downloaded.GetContentFile('pg100.txt')

In [None]:
# Let's import the libraries we will need
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import pyspark
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark import SparkContext, SparkConf

In [None]:
# create the session
# Spark configuration: serve the Spark UI on port 4050.
conf = SparkConf().set("spark.ui.port", "4050")

# getOrCreate() returns the already-running context if there is one, so this
# cell can be re-run without "Cannot run multiple SparkContexts at once".
sc = SparkContext.getOrCreate(conf=conf)
spark = SparkSession.builder.getOrCreate()

In [None]:
# Display the active SparkSession's summary as this cell's output.
spark

Load the dataset:

In [None]:
# Read the text file as an RDD whose elements are the file's lines.
rdd = spark.sparkContext.textFile("pg100.txt")

In [None]:
# Peek at the first 10 raw lines.
rdd.take(10)

### Remove punctuation and transform all words to lower case using **map()**

In [None]:
def lower_str(x):
  """Return the string ``x`` converted to lower case."""
  return x.lower()

# Lower-case every line so that e.g. "The" and "the" count as the same word.
rdd = rdd.map(lower_str)

def strip_punc(x, punc='!"#$%&\'()*+,.:;<=>?@[\\]^_`{|}-~'):
  """Return ``x`` with every character listed in ``punc`` removed.

  Args:
    x: the line of text to clean.
    punc: characters to delete. The default is the notebook's original set,
      which does not include ``/``.

  Returns:
    ``x`` without any of the ``punc`` characters (unchanged if ``punc`` is empty).
  """
  # One translate() pass deletes every listed character; the old loop restarted
  # from the original string on each iteration and so only removed the last one.
  return x.translate(str.maketrans('', '', punc))

# Apply strip_punc to every (already lower-cased) line.
rdd = rdd.map(strip_punc)

In [None]:
# Inspect the first 10 lines after lower-casing and strip_punc.
rdd.take(10)

### Split sentences in words using **flatMap()**

In [None]:
# Split each line on single spaces; flatMap flattens the per-line lists into a
# single RDD of tokens (consecutive spaces produce empty-string tokens).
rdd = rdd.flatMap(lambda sentence: sentence.split(" "))
rdd.take(10)

### Exclude whitespaces using **filter()**

In [None]:
# Keep only non-empty tokens (splitting on " " leaves empty strings behind).
rdd = rdd.filter(lambda token: len(token) > 0)
rdd.take(10)

### Count how many times each word occurs using **reduceByKey()**

In [None]:
# initialize (key,val) pair RDD
rdd_count = rdd.map(lambda word:(word,1))
rdd_count.take(10)

In [None]:
# Sum the 1s per word, then order the (word, count) pairs alphabetically by word.
rdd_count_rbk = rdd_count.reduceByKey(lambda left, right: left + right)
rdd_count_rbk = rdd_count_rbk.sortByKey()
rdd_count_rbk.take(10)

### Rank by frequency of occurrence

In [None]:
# switch (key,val) pairs as (val,key)
rdd_count_rbk = rdd_count_rbk.map(lambda x:(x[1],x[0]))
rdd_count_rbk.take(10)

In [None]:
# The 10 most frequent words: sort by the count key in descending order.
rdd_count_rbk.sortByKey(ascending=False).take(10)

In [None]:
import shutil

# saveAsTextFile refuses to write into an existing directory, so remove output
# left by a previous run first; this keeps the cell re-runnable.
COUNTS_DIR = 'counts'
shutil.rmtree(COUNTS_DIR, ignore_errors=True)
rdd_count_rbk.saveAsTextFile(COUNTS_DIR)

# Oxford Covid-19 Government Response Tracker

We will analyze the Oxford Covid-19 Government Response Tracker data available [here](https://github.com/OxCGRT/covid-policy-tracker/tree/master).


The Oxford Covid-19 Government Response Tracker (OxCGRT) collects systematic information on policy measures that governments have taken to tackle COVID-19.

The different policy responses have been tracked since 1 January 2020. They cover more than 180 countries and are coded into 23 indicators, such as school closures, travel restrictions and vaccination policy. These policies are recorded on a scale that reflects the extent of government action, and the scores are aggregated into a suite of policy indices. The data can help decision-makers and citizens understand governmental responses in a consistent way, aiding efforts to fight the pandemic.

https://www.bsg.ox.ac.uk/research/covid-19-government-response-tracker


**OxCGRT** [Get the dataset here](https://drive.google.com/file/d/1ECXsyH6HtWjTa8VpHweFQs1ceq29niFo/view?usp=share_link)


In [None]:
# Download the national OxCGRT CSV from Google Drive.
# A descriptive constant is used instead of `id`, which would shadow the builtin.
OXCGRT_FILE_ID = '1ECXsyH6HtWjTa8VpHweFQs1ceq29niFo'
downloaded = drive.CreateFile({'id': OXCGRT_FILE_ID})
downloaded.GetContentFile('OxCGRT_nat_latest.csv')

In [None]:
# create the session
conf = SparkConf().set("spark.ui.port", "4050")

# create the context
sc = pyspark.SparkContext(conf=conf)
spark = SparkSession.builder.getOrCreate()

Load the dataset:

In [None]:
# Load the national OxCGRT CSV, taking column names from the first row.
# No schema inference is requested, so every column is read as a string.
OxCGRT_latest = spark.read.csv("OxCGRT_nat_latest.csv", header=True)


Check the schema:

In [None]:
# Print the column names and their Spark types.
OxCGRT_latest.printSchema()

Get a sample with take():

In [None]:
# Return the first 3 rows as a list of Row objects.
OxCGRT_latest.take(3)

Get a formatted sample with `show()`:

In [None]:
# Print the first rows as a formatted text table (20 rows by default).
OxCGRT_latest.show()

In [None]:
# Total number of rows in the dataset.
print(f"In total there are {OxCGRT_latest.count()} records")

You can check the levels on each policy [here](https://github.com/OxCGRT/covid-policy-tracker/blob/master/documentation/codebook.md).

## Q1: Which are the top 20 countries that had their schools closed for the longest period of time?
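We count, for each country, the number of days on which `C1M_School closing` equals 3 (see the codebook linked above for the meaning of each level).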

In [None]:
# Per country, count the rows whose school-closing level is 3, most first.
school_closing_counts = (
    OxCGRT_latest
    .filter(col("C1M_School closing") == 3)
    .groupBy("CountryName")
    .agg(count("*").alias("C1M_School_closing_count"))
    .orderBy(col("C1M_School_closing_count").desc())
)
school_closing_counts.show()

In [None]:
# Same aggregation as above, expressed in Spark SQL over a temporary view.
OxCGRT_latest.createOrReplaceTempView("OxCGRT_latest")

school_closing_query = """
SELECT CountryName, count(*) as C1M_School_closing_count
FROM OxCGRT_latest
WHERE `C1M_School closing` = 3
GROUP BY CountryName
ORDER BY C1M_School_closing_count DESC
"""

school_closing_counts = spark.sql(school_closing_query)
school_closing_counts.show()

Move the aggregated result to pandas for plotting:

In [None]:
# Collect the per-country aggregate (one row per country) into pandas.
school_closing_counts_pd = school_closing_counts.toPandas()
school_closing_counts_pd.head(5)

In [None]:
# Bar chart of the first 20 rows, i.e. the 20 countries with the highest count
# (the query already sorted by count, descending).
# NOTE(review): "days" assumes one row per country per date — confirm.
ax = school_closing_counts_pd.head(20).plot(
    kind="bar",
    x="CountryName",
    y="C1M_School_closing_count",
    figsize=(10, 7),
    alpha=0.5,
    color="olive",
    legend=False,  # a single series; the legend would only repeat the column name
)
ax.set_title("Top 20 countries by number of days with C1M_School closing = 3")
ax.set_xlabel("Country")
ax.set_ylabel("Number of Days Schools Closed")
plt.show()


## Q2: What is the total number of confirmed cases?
Find these values for CountryName IN [Canada, United States, India, United Kingdom, China, Iran, Brazil, Australia, and South Africa].




First, select the date, confirmed-case and response-index columns for these countries.

In [None]:
# Countries of interest for Q2-Q4.
COUNTRIES = ['Canada', 'United States', 'India', 'United Kingdom', 'China',
             'Iran', 'Brazil', 'Australia', 'South Africa']

# The CSV was loaded without schema inference, so every column is a string.
# Parse Date (yyyyMMdd) into a date and cast the numeric columns to double so
# that later arithmetic and plots treat them as numbers rather than text.
confirmed_cases_countries = (
    OxCGRT_latest
    .selectExpr(
        "CountryName",
        "to_date(Date, 'yyyyMMdd') AS Date",
        "CAST(ConfirmedCases AS DOUBLE) AS ConfirmedCases",
        "CAST(GovernmentResponseIndex_Average AS DOUBLE) AS GovernmentResponseIndex_Average",
    )
    .where(col("CountryName").isin(COUNTRIES))
)
confirmed_cases_countries.show()

In [None]:
# confirmed_cases_countries is already restricted to the countries of interest,
# so collect it to pandas directly instead of re-applying the same filter.
# Note: despite its name, ConfirmedCases here is still the cumulative total.
confirmed_cases_daily = confirmed_cases_countries.toPandas()
confirmed_cases_daily.head()

In [None]:
# Cumulative confirmed cases over time, one line per country.
fig, ax = plt.subplots(figsize=(10, 6))

for country, country_rows in confirmed_cases_daily.groupby("CountryName"):
    # Coerce to numbers: a text column would otherwise be drawn as categorical
    # labels. Missing values are plotted as 0, as before.
    cumulative_cases = pd.to_numeric(country_rows["ConfirmedCases"], errors="coerce").fillna(0)
    ax.plot(country_rows["Date"], cumulative_cases, label=country)

ax.set_title("Cumulative confirmed COVID-19 cases")
ax.set_xlabel("Date")
ax.set_ylabel("Confirmed cases (cumulative)")
ax.legend(loc='best')
plt.show()

## Q3: What are the daily confirmed cases?
The *ConfirmedCases* column is a cumulative sum, so we need to convert it to daily values first.

Find these values for CountryName IN [Canada, United States, India, United Kingdom, China, Iran, Brazil, Australia, and South Africa].

In [None]:
from pyspark.sql.window import Window
import pyspark.sql.functions as f

# Within each country, order rows by date so lag() sees the previous day's row.
window = Window.partitionBy("CountryName").orderBy("Date")

# ConfirmedCases is cumulative: new cases = this row's total - previous row's total.
# The lag default of 0 makes the first row's value equal to its cumulative total.
# The explicit cast keeps the arithmetic numeric even if the column is still text,
# rather than relying on implicit string-to-number coercion.
cumulative = f.col("ConfirmedCases").cast("double")
daily_confirmed_cases = confirmed_cases_countries.withColumn(
    "ConfirmedCases",
    cumulative - f.lag(cumulative, 1, 0).over(window),
)

daily_confirmed_cases.show()

In [None]:
# Collect the per-country daily values to pandas for plotting.
daily_confirmed_cases_df = daily_confirmed_cases.toPandas()

fig, ax = plt.subplots(figsize=(10, 6))

# One line per country; missing values are plotted as 0, as before.
for country, country_rows in daily_confirmed_cases_df.groupby("CountryName"):
    daily_cases = pd.to_numeric(country_rows["ConfirmedCases"], errors="coerce").fillna(0)
    ax.plot(country_rows["Date"], daily_cases, label=country)

ax.set_title("Daily new confirmed COVID-19 cases")
ax.set_xlabel("Date")
ax.set_ylabel("New confirmed cases (daily)")
ax.legend(loc='best')
plt.show()

## Q4: Plot the Government Response Index vs the daily number of confirmed cases.

Create a plot for each of the following countries [Canada, United States, India, United Kingdom, China, Iran, Brazil, Australia, and South Africa].




In [None]:
# One figure per country: daily new cases (left axis, green) against the
# Government Response Index average (right axis, blue).
for country, country_rows in daily_confirmed_cases_df.groupby("CountryName"):
    # Coerce both series to numbers: text columns would otherwise be drawn as
    # categorical labels. Missing values are plotted as 0, as before.
    daily_cases = pd.to_numeric(country_rows["ConfirmedCases"], errors="coerce").fillna(0)
    response_index = pd.to_numeric(
        country_rows["GovernmentResponseIndex_Average"], errors="coerce"
    ).fillna(0)

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(country_rows["Date"], daily_cases, 'g-')
    ax2.plot(country_rows["Date"], response_index, 'b-')
    ax1.set_title(country)
    ax1.set_xlabel("Date")
    ax1.set_ylabel("Daily confirmed cases", color='g')
    ax2.set_ylabel("Government Response Index (average)", color='b')
    plt.show()