## Programming for Big Data Project
### by Amy Reidy

Set up an Apache Spark instance in Google Colab



In [None]:
# install Java (OpenJDK 8, headless build); install output is discarded
!apt-get install openjdk-8-jdk-headless -qq > /dev/null

# download Spark 3.0.2 built for Hadoop 2.7
# (if the version is changed here, change it in the tar command and SPARK_HOME below as well)
!wget -q https://archive.apache.org/dist/spark/spark-3.0.2/spark-3.0.2-bin-hadoop2.7.tgz

# unpack the downloaded Spark archive into the current folder
!tar xf spark-3.0.2-bin-hadoop2.7.tgz

# point the SPARK_HOME environment variable at the unpacked Spark folder
# (the path assumes the archive was unpacked in /content)
import os
os.environ["SPARK_HOME"] = "/content/spark-3.0.2-bin-hadoop2.7"

# install findspark using pip, then initialise it so that pyspark can be imported in the next cells
!pip install -q findspark
import findspark
findspark.init()

Create a Spark session

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *


# building (or reusing) a local Spark session for this notebook;
# the Spark UI port is set explicitly to 4050
spark = (
    SparkSession.builder
    .master("local")
    .appName("Food_Prices")
    .config('spark.ui.port', '4050')
    .getOrCreate()
)

# displaying the session object as the cell output
spark

In [None]:
# loading the food prices dataset from CSV; the first line is used as the header
# and the column types are inferred from the data
raw_df = spark.read.csv("wfpvam_foodprices.csv", inferSchema=True, header=True)
# printing a sample of rows as a quick sanity check
raw_df.show()

+-------+-----------+-------+----------+------+--------+-----+--------------+------+--------+-----+-------+-----+-------+--------+-------+--------+------------------+
|adm0_id|  adm0_name|adm1_id| adm1_name|mkt_id|mkt_name|cm_id|       cm_name|cur_id|cur_name|pt_id|pt_name|um_id|um_name|mp_month|mp_year|mp_price|mp_commoditysource|
+-------+-----------+-------+----------+------+--------+-----+--------------+------+--------+-----+-------+-----+-------+--------+-------+--------+------------------+
|    1.0|Afghanistan|    272|Badakhshan|   266|Fayzabad|   55|Bread - Retail|   0.0|     AFN|   15| Retail|    5|     KG|       1|   2014|    50.0|              null|
|    1.0|Afghanistan|    272|Badakhshan|   266|Fayzabad|   55|Bread - Retail|   0.0|     AFN|   15| Retail|    5|     KG|       2|   2014|    50.0|              null|
|    1.0|Afghanistan|    272|Badakhshan|   266|Fayzabad|   55|Bread - Retail|   0.0|     AFN|   15| Retail|    5|     KG|       3|   2014|    50.0|              null

In [None]:
  # extracting number of rows from the Dataframe
  # (the variables are deliberately NOT called `row`/`col`: assigning to `col` would
  #  overwrite pyspark's `col` function that the later cells rely on)
  n_rows = raw_df.count()

  # extracting number of columns from the Dataframe
  n_cols = len(raw_df.columns)

  print(f'Dimension of the Dataframe is: {(n_rows, n_cols)}')
  print(f'Number of Rows are: {n_rows}')
  print(f'Number of Columns are: {n_cols}')

Dimension of the Dataframe is: (2050638, 18)
Number of Rows are: 2050638
Number of Columns are: 18


In [None]:
# displaying the schema (column names, types and nullability) of the raw dataframe
raw_df.printSchema()

root
 |-- adm0_id: double (nullable = true)
 |-- adm0_name: string (nullable = true)
 |-- adm1_id: integer (nullable = true)
 |-- adm1_name: string (nullable = true)
 |-- mkt_id: integer (nullable = true)
 |-- mkt_name: string (nullable = true)
 |-- cm_id: integer (nullable = true)
 |-- cm_name: string (nullable = true)
 |-- cur_id: double (nullable = true)
 |-- cur_name: string (nullable = true)
 |-- pt_id: integer (nullable = true)
 |-- pt_name: string (nullable = true)
 |-- um_id: integer (nullable = true)
 |-- um_name: string (nullable = true)
 |-- mp_month: integer (nullable = true)
 |-- mp_year: integer (nullable = true)
 |-- mp_price: double (nullable = true)
 |-- mp_commoditysource: string (nullable = true)



In [None]:
# printing the list of column names of the raw dataframe
print(raw_df.columns)

['adm0_id', 'adm0_name', 'adm1_id', 'adm1_name', 'mkt_id', 'mkt_name', 'cm_id', 'cm_name', 'cur_id', 'cur_name', 'pt_id', 'pt_name', 'um_id', 'um_name', 'mp_month', 'mp_year', 'mp_price', 'mp_commoditysource']


In [None]:
# computing basic summary statistics for the columns of the raw dataframe and printing them
raw_df.describe().show()

+-------+------------------+-----------+------------------+---------+------------------+--------+------------------+--------------------+-------+--------+------------------+---------+-----------------+-------+------------------+-----------------+------------------+------------------+
|summary|           adm0_id|  adm0_name|           adm1_id|adm1_name|            mkt_id|mkt_name|             cm_id|             cm_name| cur_id|cur_name|             pt_id|  pt_name|            um_id|um_name|          mp_month|          mp_year|          mp_price|mp_commoditysource|
+-------+------------------+-----------+------------------+---------+------------------+--------+------------------+--------------------+-------+--------+------------------+---------+-----------------+-------+------------------+-----------------+------------------+------------------+
|  count|           2050638|    2050638|           2050638|  1439622|           2050638| 2050638|           2050638|             2050638|2050638|

In [None]:
from pyspark.sql.functions import col,isnan,when,count

# building a one-row dataframe that holds, for every column, the number of values
# that are missing: null, NaN, an empty string, or text containing 'None' / 'NULL'
nulls_df = raw_df.select([
    count(
        when(
            col(name).isNull()
            | isnan(name)
            | (col(name) == '')
            | col(name).contains('None')
            | col(name).contains('NULL'),
            name,
        )
    ).alias(name)
    for name in raw_df.columns
])
nulls_df.show()

+-------+---------+-------+---------+------+--------+-----+-------+------+--------+-----+-------+-----+-------+--------+-------+--------+------------------+
|adm0_id|adm0_name|adm1_id|adm1_name|mkt_id|mkt_name|cm_id|cm_name|cur_id|cur_name|pt_id|pt_name|um_id|um_name|mp_month|mp_year|mp_price|mp_commoditysource|
+-------+---------+-------+---------+------+--------+-----+-------+------+--------+-----+-------+-----+-------+--------+-------+--------+------------------+
|      0|        0|      0|   611016|     0|       0|    0|      0|     0|       0|    0|      0|    0|      0|       0|      0|       0|           2050638|
+-------+---------+-------+---------+------+--------+-----+-------+------+--------+-----+-------+-----+-------+--------+-------+--------+------------------+



### Function 1: Preprocessing

In [None]:
from pyspark.sql import functions as F

def preprocess_dataframe(df):
  """Clean the raw food prices dataframe.

  Drops the id and other unused columns, renames the remaining columns to
  readable names and adds a `date` column (first day of the month) built from
  the month and year columns.

  Args:
    df: Spark dataframe with the raw column names (adm0_name, mkt_name,
      cm_name, cur_name, um_name, mp_month, mp_year, mp_price, ...).

  Returns:
    A new Spark dataframe with the columns country, market, commodity,
    currency, unit, month, year, price and date. The input is not modified.
  """
  # making list of all the unnecessary columns
  unnecessary_columns = ('adm0_id', 'adm1_id', 'adm1_name', 'mkt_id', 'cm_id',
                         'cur_id', 'pt_id', 'pt_name', 'um_id', 'mp_commoditysource')

  # mapping of raw column names to the readable names used in the rest of the notebook
  new_column_names = {
      "adm0_name": "country",
      "mkt_name": "market",
      "cm_name": "commodity",
      "cur_name": "currency",
      "um_name": "unit",
      "mp_month": "month",
      "mp_year": "year",
      "mp_price": "price",
  }

  # dropping unnecessary columns, then renaming the remaining ones
  renamed_df = df.drop(*unnecessary_columns)
  for old_name, new_name in new_column_names.items():
    renamed_df = renamed_df.withColumnRenamed(old_name, new_name)

  # building the text "1-<month>-<year>" and parsing it straight into a date column;
  # doing it in one expression avoids a temporary (and previously duplicated) text column,
  # and using the F.* names avoids depending on wildcard-imported globals such as `col`
  day_month_year = F.concat_ws('-', F.lit(1), F.col("month"), F.col("year"))
  clean_df = renamed_df.withColumn("date", F.to_date(day_month_year, "d-M-yyyy"))

  #returning the processed dataframe
  return clean_df

In [None]:
# using the preprocessing function to clean the raw food prices dataframe;
# raw_df itself is left unchanged and the result is kept in clean_df
clean_df = preprocess_dataframe(raw_df)

In [None]:
# showing the schema of the cleaned dataframe
clean_df.printSchema()

root
 |-- country: string (nullable = true)
 |-- market: string (nullable = true)
 |-- commodity: string (nullable = true)
 |-- currency: string (nullable = true)
 |-- unit: string (nullable = true)
 |-- month: integer (nullable = true)
 |-- year: integer (nullable = true)
 |-- price: double (nullable = true)
 |-- date: date (nullable = true)



In [None]:
# showing a sample of rows of the cleaned dataframe
clean_df.show()

+-----------+--------+--------------+--------+----+-----+----+-----+----------+
|    country|  market|     commodity|currency|unit|month|year|price|      date|
+-----------+--------+--------------+--------+----+-----+----+-----+----------+
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    1|2014| 50.0|2014-01-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    2|2014| 50.0|2014-02-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    3|2014| 50.0|2014-03-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    4|2014| 50.0|2014-04-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    5|2014| 50.0|2014-05-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    6|2014| 50.0|2014-06-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    7|2014| 50.0|2014-07-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    8|2014| 50.0|2014-08-01|
|Afghanistan|Fayzabad|Bread - Retail|     AFN|  KG|    9|2014| 50.0|2014-09-01|
|Afghanistan|Fayzabad|Bread - Retail|   

In [None]:
# showing a sample of rows for Rwanda (the filter is written as a SQL expression string)
clean_df.filter("country=='Rwanda'").show()

+-------+-------+--------------+--------+----+-----+----+--------+----------+
|country| market|     commodity|currency|unit|month|year|   price|      date|
+-------+-------+--------------+--------+----+-----+----+--------+----------+
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    7|2010|   107.5|2010-07-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    8|2010|121.6667|2010-08-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    9|2010|  108.75|2010-09-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|   10|2010|   177.5|2010-10-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|   11|2010|  181.25|2010-11-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|   12|2010|   150.0|2010-12-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    1|2011|   188.6|2011-01-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    2|2011|   180.0|2011-02-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    3|2011|  140.75|2011-03-01|
| Rwanda|Gahanga|Maize - Retail|     RWF|  KG|    4|2011|   197.

### Function 2 - Get Names of Countries

In [None]:
def get_country_names(df):
  """Return the distinct values of the `country` column as a list.

  Also prints how many countries were found. The names are returned in the
  order in which the dataframe hands them back.
  """
  # one row per distinct country, collected to the driver
  country_rows = df.select("country").distinct().collect()

  # pulling the country name out of each collected row
  countries = []
  for record in country_rows:
    countries.append(record["country"])

  # reporting the number of countries before returning the names
  total = len(countries)
  print("There are {} countries in the dataset:".format(total))
  return countries

In [None]:
# using the function to print the number of countries and list all of their names
get_country_names(clean_df)

There are 98 countries in the dataset:


['Chad',
 'Paraguay',
 'Yemen',
 'State of Palestine',
 'Senegal',
 'Eritrea',
 'Philippines',
 'Djibouti',
 'Turkey',
 'Malawi',
 'Iraq',
 'Afghanistan',
 'Cambodia',
 'Jordan',
 'Rwanda',
 'Sudan',
 'Iran  (Islamic Republic of)',
 'Sri Lanka',
 'Algeria',
 'Togo',
 'Argentina',
 'Angola',
 'Ecuador',
 'Lesotho',
 'Madagascar',
 'Ghana',
 'Myanmar',
 'Nicaragua',
 'Benin',
 'Peru',
 'Sierra Leone',
 'China',
 'Belarus',
 'Timor-Leste',
 'Somalia',
 'Tajikistan',
 'United Republic of Tanzania',
 'Burundi',
 'Bolivia',
 'Nigeria',
 'Gabon',
 'Moldova Republic of',
 'Mauritania',
 'Central African Republic',
 'Niger',
 'Bangladesh',
 'Russian Federation',
 'Congo',
 'Swaziland',
 'Thailand',
 'Bhutan',
 'Democratic Republic of the Congo',
 'Cape Verde',
 'Panama',
 'Ukraine',
 'Venezuela',
 "Cote d'Ivoire",
 'Mexico',
 'Bassas da India',
 'Georgia',
 'Zimbabwe',
 'Indonesia',
 'Guatemala',
 'Mongolia',
 'Azerbaijan',
 'Libya',
 'Armenia',
 'Liberia',
 'Honduras',
 'Uganda',
 'Namibia',
 

### Function 3 - Get Names of Commodities in a Specific Country

In [None]:
def get_country_commodities(df, country):
  """Return the distinct commodity names recorded for one country.

  Also prints how many commodities were found for that country. The names are
  returned in the order in which the dataframe hands them back.
  """
  # keeping only the rows of the requested country
  rows_for_country = df.filter(df.country == country)

  # one row per distinct commodity, collected to the driver
  commodity_rows = rows_for_country.select("commodity").distinct().collect()

  # pulling the commodity name out of each collected row
  commodities = []
  for record in commodity_rows:
    commodities.append(record["commodity"])

  # reporting the number of commodities before returning the names
  total = len(commodities)
  print("There are {} commodities listed in the dataset for {}:".format(total, country))
  return commodities

In [None]:
# printing the number of commodities listed for Kenya and listing their names
get_country_commodities(clean_df, 'Kenya')

There are 34 commodities listed in the dataset for Kenya:


['Beans - Wholesale',
 'Meat (beef) - Retail',
 'Maize flour - Retail',
 'Cooking fat - Retail',
 'Maize - Retail',
 'Maize (white) - Retail',
 'Wheat flour - Retail',
 'Fuel (kerosene) - Retail',
 'Sugar - Retail',
 'Cabbage - Retail',
 'Maize (white) - Wholesale',
 'Bananas - Retail',
 'Beans (dry) - Retail',
 'Oil (vegetable) - Retail',
 'Salt - Retail',
 'Milk (cow, pasteurized) - Retail',
 'Kale - Retail',
 'Milk (cow, fresh) - Retail',
 'Meat (camel) - Retail',
 'Onions (red) - Retail',
 'Tomatoes - Retail',
 'Sorghum - Wholesale',
 'Sorghum - Retail',
 'Potatoes (Irish) - Retail',
 'Potatoes (Irish) - Wholesale',
 'Rice - Retail',
 'Fuel (diesel) - Retail',
 'Milk (UHT) - Retail',
 'Beans (dry) - Wholesale',
 'Meat (goat) - Retail',
 'Bread - Retail',
 'Fuel (petrol-gasoline) - Retail',
 'Milk (camel, fresh) - Retail',
 'Maize - Wholesale']

### Function 4 - Get Years Available for a Certain Commodity in a Certain Country

In [None]:
def get_years(df, country, commodity):
  """Return the years for which a commodity has prices in a country.

  Args:
    df: cleaned food prices dataframe (needs country, commodity and year columns).
    country: country name to filter on.
    commodity: commodity name to filter on.

  Returns:
    List of the distinct years, in ascending (chronological) order. The list
    is empty when the country/commodity combination has no rows.
  """
  # filtering the dataframe by country and commodity
  country_commodity_df = df.filter((df.country == country) & (df.commodity == commodity))

  # getting the distinct years; ordering them in Spark so the result is
  # chronological and identical on every run instead of following shuffle order
  year_rows = country_commodity_df.select("year").distinct().orderBy("year").collect()
  years = [row["year"] for row in year_rows]

  # showing the number of years and the years
  print("There are {} years listed in the dataset for {} in {}:".format(len(years), commodity, country))
  return years

In [None]:
# getting the years for which wholesale bean prices are listed for Kenya
get_years(clean_df, 'Kenya', 'Beans - Wholesale')

There are 16 years listed in the dataset for Beans - Wholesale in Kenya:


[2007,
 2018,
 2015,
 2006,
 2013,
 2014,
 2019,
 2020,
 2012,
 2009,
 2016,
 2010,
 2011,
 2008,
 2017,
 2021]

### Function 5 - Get Summary Statistics Table for Country and Commodity.

In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number, countDistinct, avg, stddev, format_number

def _price_extreme_market(df, highest, alias):
  """Return one row per date naming the market with the lowest (or highest) price."""
  # rows without a price are ranked last so they are never reported as the extreme
  if highest:
    price_order = F.col("price").desc_nulls_last()
  else:
    price_order = F.col("price").asc_nulls_last()

  # the market name is a tie-breaker, so equal prices always pick the same market
  by_price = Window.partitionBy("date").orderBy(price_order, F.col("market").asc())

  return (df.withColumn("row", F.row_number().over(by_price))
            .filter(F.col("row") == 1)
            .select("date", F.col("market").alias(alias)))


def get_summary_table(df, country, commodity, year=None, save=False):
  """Build a per-date price summary for one commodity in one country.

  Args:
    df: cleaned food prices dataframe (needs country, commodity, year, market,
      price and date columns).
    country: country name to filter on.
    commodity: commodity name to filter on.
    year: optional year; when given, only rows of that year are summarised.
    save: when true, the table is also written to '<country>_<commodity>.csv'
      in the working directory (the whole table is collected through pandas).

  Returns:
    Spark dataframe sorted with the most recent date first and the columns
    date, average_price, diff_from_prev_month_avg, minimum_price,
    min_price_market, maximum_price, max_price_market and
    standard_deviation_price. The price figures are text formatted to
    2 decimals. diff_from_prev_month_avg compares each date with the previous
    date present in the data (0.00 for the earliest one).
  """
  # filtering by country and commodity, and also by year when one is given
  selection = (df.country == country) & (df.commodity == commodity)
  if year is not None:
    selection = selection & (df.year == year)
  country_commodity_df = df.filter(selection)

  # grouping the data by date once and computing all the statistics in a single pass
  stats_df = country_commodity_df.groupBy("date").agg(
      F.avg("price").alias("avg_price"),
      F.min("price").alias("min_price"),
      F.max("price").alias("max_price"),
      F.stddev("price").alias("std_price"))

  # getting the previous average (in date order) and subtracting it from the current one;
  # the earliest date has no previous value, so its difference is set to 0
  chronological = Window.orderBy("date")
  stats_df = stats_df.withColumn(
      "diff_avg",
      F.coalesce(F.col("avg_price") - F.lag("avg_price").over(chronological), F.lit(0.0)))

  # names of the markets which had the lowest and the highest price for each date
  min_market_df = _price_extreme_market(country_commodity_df, False, "min_price_market")
  max_market_df = _price_extreme_market(country_commodity_df, True, "max_price_market")

  # joining everything on date and formatting the figures to 2 decimals
  joined_df = (stats_df
               .join(min_market_df, "date")
               .join(max_market_df, "date")
               .select("date",
                       F.format_number("avg_price", 2).alias("average_price"),
                       F.format_number("diff_avg", 2).alias("diff_from_prev_month_avg"),
                       F.format_number("min_price", 2).alias("minimum_price"),
                       "min_price_market",
                       F.format_number("max_price", 2).alias("maximum_price"),
                       "max_price_market",
                       F.format_number("std_price", 2).alias("standard_deviation_price")))

  # sorting the rows so the most recent data comes first
  sorted_df = joined_df.sort(F.col("date").desc())

  # option to save table as a csv file
  if save:
    sorted_df.toPandas().to_csv('{}_{}.csv'.format(country, commodity))

  # returning the summary table
  return sorted_df


                

In [None]:
# building the summary table for wholesale beans in Kenya (no year given) and saving it as a csv file
kenya_beans_df = get_summary_table(clean_df, 'Kenya', 'Beans - Wholesale', save=True)
# showing the first 20 rows of the summary table
kenya_beans_df.show(20)

+----------+-------------+------------------------+-------------+----------------+-------------+----------------+------------------------+
|      date|average_price|diff_from_prev_month_avg|minimum_price|min_price_market|maximum_price|max_price_market|standard_deviation_price|
+----------+-------------+------------------------+-------------+----------------+-------------+----------------+------------------------+
|2021-02-01|        60.16|                    0.48|        60.16|          Nakuru|        60.16|          Nakuru|                     NaN|
|2021-01-01|        59.68|                  -18.73|        59.68|          Nakuru|        59.68|          Nakuru|                     NaN|
|2020-04-01|        78.41|                   20.77|        78.41|         Nairobi|        78.41|         Nairobi|                     NaN|
|2019-09-01|        57.64|                  -17.18|        57.64|          Nakuru|        57.64|          Nakuru|                     NaN|
|2019-08-01|        74.82| 

In [None]:
#import other necessary libraries for data visualizations
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

### Function 6: Get Chart with History of Market Prices

In [None]:
def get_market_history_chart(df, country, commodity):
  """Plot all recorded prices of a commodity in a country, one line per market.

  Args:
    df: cleaned food prices dataframe (needs country, commodity, market,
      currency, unit, price and date columns).
    country: country name to filter on.
    commodity: commodity name to filter on.

  Raises:
    ValueError: if the dataset has no rows for this country and commodity.
  """
  # filtering the dataframe by country and commodity
  country_commodity_df = df.filter((df.country == country) & (df.commodity == commodity))

  # turning dataframe into a pandas dataframe to visualize; collecting once here
  # also provides the currency and unit without running extra Spark jobs
  pd_df = country_commodity_df.toPandas()
  if pd_df.empty:
    raise ValueError("No price data found for {} in {}.".format(commodity, country))

  # currency and unit for the axis label are taken from the first row
  first_row = pd_df.iloc[0]
  currency = first_row['currency']
  unit = first_row['unit']

  # turning the date column into date type
  pd_df["Date"] = pd.to_datetime(pd_df["date"])

  # the line chart joins points in row order, so the rows must be chronological
  pd_df = pd_df.sort_values("Date")

  # plotting market prices on a line chart
  fig = px.line(pd_df, x="Date", y="price", color="market",
                title='Historical Market Prices in {} for {}.'.format(country,commodity),
                labels={'price': 'Price ({}/{})'.format(currency, unit)})

  # getting a range slider that allows user to choose timeframe
  fig.update_xaxes(rangeslider_visible=True)

  # displaying the chart
  fig.show()

In [None]:
# visualizing all recorded market prices for wholesale beans in Kenya
get_market_history_chart(clean_df, 'Kenya', 'Beans - Wholesale')

### Function 7: Get Chart with Market Prices from Certain Year

In [None]:
def get_market_year_chart(df, country, commodity, year=2019):
  """Plot the prices of a commodity in a country for one year, one line per market.

  Args:
    df: cleaned food prices dataframe (needs country, commodity, year, market,
      currency, unit, price and date columns).
    country: country name to filter on.
    commodity: commodity name to filter on.
    year: year to plot (defaults to 2019).

  Raises:
    ValueError: if the dataset has no rows for this country, commodity and year.
  """
  # filtering the dataframe by country, commodity and year
  country_commodity_year_df = df.filter((df.country == country) & (df.commodity == commodity) & (df.year == year))

  # turning dataframe into a pandas dataframe to visualize; collecting once here
  # also provides the currency and unit without running extra Spark jobs
  pd_df = country_commodity_year_df.toPandas()
  if pd_df.empty:
    raise ValueError("No price data found for {} in {} in {}.".format(commodity, country, year))

  # currency and unit for the axis label are taken from the first row
  first_row = pd_df.iloc[0]
  currency = first_row['currency']
  unit = first_row['unit']

  # turning the date column into date type
  pd_df["Date"] = pd.to_datetime(pd_df["date"])

  # the line chart joins points in row order, so the rows must be chronological
  pd_df = pd_df.sort_values("Date")

  # plotting market prices for the chosen year on a line chart
  fig = px.line(pd_df, x="Date", y="price", color="market", markers=True,
                title='{} Market Prices in {} for {}.'.format(year, country,commodity),
                labels={'price': 'Price ({}/{})'.format(currency, unit)})

  # displaying the line chart
  fig.show()

In [None]:
# visualizing the market prices for wholesale beans in Kenya in 2018, one line per market
get_market_year_chart(clean_df, 'Kenya', 'Beans - Wholesale', 2018)

### Function 8 - Get Charts with Average Prices for a Certain Year and All Years

In [None]:
def get_average_charts(df, country, commodity, year=2019):
  """Plot average prices of a commodity in a country for one year and for all years.

  The top chart shows the average price per date in `year` with a shaded band
  of +1/-1 standard deviation; the bottom chart shows the average price per
  date over the whole history.

  Args:
    df: cleaned food prices dataframe (needs country, commodity, price and
      date columns).
    country: country name to filter on.
    commodity: commodity name to filter on.
    year: year shown in the top chart (defaults to 2019).
  """
  # filtering dataframe by country and commodity
  country_commodity_df = df.filter((df.country == country) & (df.commodity == commodity))

  # grouping by date once and getting the average and standard deviation rounded to
  # 2 decimal places; F.round keeps the values numeric, whereas format_number returns
  # text with thousands separators (e.g. "1,234.56") that cannot be turned back into numbers
  stats_df = country_commodity_df.groupBy("date").agg(
      F.round(F.avg("price"), 2).alias("average_price"),
      F.round(F.stddev("price"), 2).alias("standard_deviation_price"))

  # rows without a date cannot be placed on the time axis
  stats_df = stats_df.filter(F.col("date").isNotNull())

  # adding a column for average + 1 standard deviation and a column for average - 1 standard deviation
  stats_df = stats_df.withColumn("avg+std", F.col("average_price") + F.col("standard_deviation_price")) \
              .withColumn("avg-std", F.col("average_price") - F.col("standard_deviation_price"))

  # sorting the dataframe in ascending date order
  sorted_df = stats_df.sort(F.col("date").asc())

  # turning the dataframe into a pandas dataframe to visualize
  pd_df1 = sorted_df.toPandas()

  # making sure the average price is a float column
  pd_df1['average_price'] = pd_df1['average_price'].astype(float)

  # casting the date column as date type
  pd_df1["date"] = pd.to_datetime(pd_df1["date"])

  # filtering the dataframe by year
  pd_df2 = pd_df1[pd_df1['date'].dt.strftime('%Y') == str(year)]

  # making the subplots
  fig = make_subplots(rows=2, cols=1, row_heights=[0.7, 0.3],
                      subplot_titles=("Average Monthly Prices in {} (with +1/-1 Standard Deviation)".format( year), "History of Average Prices"),
                      shared_yaxes=True)

  # trace 1: average prices for year
  # (add_trace with row/col is used instead of the deprecated append_trace)
  fig.add_trace(go.Scatter(
          name='Average Price',
          x=pd_df2['date'],
          y=pd_df2['average_price'],
          mode='lines',
          line=dict(color='rgb(31, 119, 180)')),
          row=1, col=1)

  # trace 2: average + 1 standard deviation (invisible line, upper edge of the band)
  fig.add_trace(go.Scatter(
          name='Upper Bound',
          x=pd_df2['date'],
          y=pd_df2['avg+std'],
          mode='lines',
          marker=dict(color="#444"),
          line=dict(width=0),
          showlegend=False),
          row=1, col=1)

  # trace 3: average - 1 standard deviation; 'tonexty' fills the area up to the previous trace
  fig.add_trace(go.Scatter(
          name='Lower Bound',
          x=pd_df2['date'],
          y=pd_df2['avg-std'],
          marker=dict(color="#444"),
          line=dict(width=0),
          mode='lines',
          fillcolor='rgba(68, 68, 68, 0.3)',
          fill='tonexty',
          showlegend=False),
          row=1, col=1)

  # trace 4: average prices for all years
  fig.add_trace(go.Scatter(
          name='Average Price',
          x=pd_df1['date'],
          y=pd_df1['average_price'],
          mode='lines',
          line=dict(color='rgb(31, 119, 180)')),
          row=2, col=1)

  # customizing charts
  fig.update_layout(
      showlegend=False,
      height=800,
      yaxis_title='Average Price',
      title='Average Prices of {} in {}'.format(commodity, country),
      hovermode="x"
  )

  # displaying chart
  fig.show()

In [None]:
# displaying the charts of average prices for wholesale beans in Kenya: 2016 in the top chart, full history below
get_average_charts(clean_df, 'Kenya', 'Beans - Wholesale', 2016)