# Lab 5: Prediction, Pivot, Group, and Joins

Welcome to Lab 5! This week, we'll learn about prediction and the `pivot`, `group`, and `join` methods from [Section 8](https://www.inferentialthinking.com/chapters/08/Functions_and_Tables.html).  
First, set up the tests and imports by running the cell below.

In [None]:
# Don't change this cell; just run it. 

import numpy as np
from datascience import *

# These lines do some fancy plotting magic.
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')

# These lines load the tests.

from gofer.ok import check

## 1. Prediction of Income given Expenditure

The [Consumer Expenditure Surveys](https://www.bls.gov/cex/) is a national survey in the U.S. It contains data on expenditures, income, and tax statistics about consumer units (CU) across the country. It provides information on the buying habits of U.S. consumers.

Let's consider a sample from the first quarter of 2017.

In [None]:
CEdata = Table.read_table("CEdata2017Q1.csv")
CEdata

**Question 1.** Because `Income` and `Expenditure` are skewed, data scientists usually transform the data first. Take the log transformation of `Income` and `Expenditure`, and create two new columns: `LogIncome` and `LogExpenditure`.

*Hint:* Use the NumPy function `np.log` to take the natural logarithm of a numerical variable.
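
As a small illustration of the hint, the sketch below applies `np.log` to a made-up array of dollar amounts (the values are hypothetical and not taken from `CEdata`). `np.log` works element-wise, so it returns an array of the same length as its input.

```python
import numpy as np

# Hypothetical dollar amounts; not values from CEdata.
amounts = np.array([1_000.0, 10_000.0, 100_000.0])

# np.log takes the natural logarithm of every element at once.
log_amounts = np.log(amounts)

# Each tenfold increase in the original adds the same constant (ln 10) on the log scale.
steps = np.diff(log_amounts)
print(log_amounts)
print(steps)
```

This is why a log transformation pulls in a long right tail: values that differ by large multiples end up evenly spaced.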

In [None]:
CEdata = ...
CEdata

In [None]:
check('tests/q1_1.py')

**Question 2.** Create a scatter plot of `LogIncome` and `LogExpenditure`. Plot `LogIncome` on the y-axis and `LogExpenditure` on the x-axis. Describe the relationship between `LogIncome` and `LogExpenditure` given your scatter plot.

In [None]:
## create a scatter plot


*Write your answer here, replacing this text.*

Now let's try to predict the `LogIncome` of a consumer unit (CU) given its `LogExpenditure`. As a first example, we calculate the prediction for a CU with `LogExpenditure` 8 by averaging the `LogIncome` of CUs with nearby `LogExpenditure` values. Revisit the T8-prediction demo from class if you need a quick refresher.

In [None]:
nearby = CEdata.where('LogExpenditure', are.between(7.5, 8.5))
nearby_mean = nearby.column('LogIncome').mean()
nearby_mean
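
The same idea can be sketched on plain NumPy arrays: keep the points whose x-value is within 0.5 of the target, then average their y-values. The arrays and the helper name `nearby_mean_of` are made up for illustration and are not part of the lab's code; the half-width of 0.5 mirrors the 7.5 to 8.5 range used above for a target of 8.

```python
import numpy as np

# Made-up (x, y) pairs; not values from CEdata.
x = np.array([7.2, 7.6, 7.9, 8.1, 8.4, 9.0])
y = np.array([9.0, 9.5, 10.0, 10.5, 11.0, 12.0])

def nearby_mean_of(target, x_values, y_values, half_width=0.5):
    """Average y over the points whose x lies in [target - half_width, target + half_width)."""
    lower, upper = target - half_width, target + half_width
    is_nearby = (x_values >= lower) & (x_values < upper)
    return y_values[is_nearby].mean()

prediction_at_8 = nearby_mean_of(8, x, y)
print(prediction_at_8)  # mean of 9.5, 10.0, 10.5, 11.0
```

Changing `target` moves the window, so the prediction follows the local average of the data rather than one overall mean.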

**Question 3.** Write a function to predict the `LogIncome` given `LogExpenditure`. Return the predicted `LogIncome` of a CU whose `LogExpenditure` is 10.

In [None]:
def predict(LogExpenditure):
    ...


In [None]:
predict(10)

In [None]:
check('tests/q1_3.py')

**Question 4.** Predict the `LogIncome` of every CU in the dataset, and add the predictions to the `CEdata` table as a new column called `predictedLogIncome`.

In [None]:
CEdata = ...
CEdata

In [None]:
check('tests/q1_4.py')

Once predictions are made, we should check the prediction errors to see how good our predictions are.

In [None]:
def difference(x, y):
    return x - y

**Question 5.** Define prediction error as (`LogIncome` - `predictedLogIncome`). Calculate the prediction error of every CU in the dataset, and add it to the `CEdata` table as a new column called `errors`.
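
For intuition, here is the same subtraction on two small made-up arrays (the numbers are hypothetical, not values from `CEdata`). With this definition, a positive error means the prediction was too low and a negative error means it was too high.

```python
import numpy as np

# Made-up actual and predicted values; not values from CEdata.
actual = np.array([10.0, 9.0, 11.5])
predicted = np.array([9.5, 9.8, 10.0])

# Error = actual - predicted, computed element-wise.
toy_errors = actual - predicted
print(toy_errors)
```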

In [None]:
errors = ...
CEdata = ...
CEdata

In [None]:
check('tests/q1_5.py')

**Question 6.** Make a histogram of the prediction errors using the given bins. How many CUs have a prediction error greater than 1?
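
Both steps can be sketched on made-up numbers: `np.histogram` counts how many values fall in each bin, and comparing an array to a number gives an array of booleans that `np.count_nonzero` can tally. The `toy_errors` values are hypothetical; only the bin edges match `my_bins` in the cell below.

```python
import numpy as np

# Made-up errors; not values from CEdata.
toy_errors = np.array([-2.5, -0.3, 0.2, 0.9, 1.0, 1.4, 3.7])

# Edges -4, -3, ..., 4: np.arange excludes the stop value 5.
bin_edges = np.arange(-4, 5, 1)

# Number of values in each of the 8 bins.
counts, _ = np.histogram(toy_errors, bins=bin_edges)

# Strictly greater than 1, so the value 1.0 itself is not counted.
num_greater_than_1 = np.count_nonzero(toy_errors > 1)
print(counts, num_greater_than_1)
```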

In [None]:
my_bins = np.arange(-4, 5, 1)
CEdata.hist(...)

In [None]:
# How many CUs have a prediction error greater than 1?
CU_prediction_error_num = ...

In [None]:
check('tests/q1_6.py')

## 2. Causes of Death by Year


This exercise is designed to give you practice using the Table method `pivot`. [Here](https://docs.google.com/presentation/d/13AjdMawjHQg8y9rA2lNancd1Nw42jdtNimVO9qU8AKI/edit#slide=id.g7ddf730208_0_95) is a link to the lecture slides in case you need a quick refresher.

Run the cell below to view a demo on how you can use pivot on a table.


In [None]:
from IPython.display import YouTubeVideo
YouTubeVideo("4WzXo8eKLAg")

We'll be looking at a [dataset](http://www.healthdata.gov/dataset/leading-causes-death-zip-code-1999-2013) from the California Department of Public Health that records the cause of death, as recorded on a death certificate, for everyone who died in California from 1999 to 2013.  The data is in the file `causes_of_death.csv`. Each row records the number of deaths by a specific cause in one year in one ZIP code.

In [None]:
causes = Table.read_table('causes_of_death.csv')
causes

The causes of death in the data are abbreviated.  We've provided a file called `abbreviations.csv` to decode the abbreviations.

In [None]:
abbreviations = Table.read_table('abbreviations.csv')
abbreviations.show()

The dataset is missing data on certain causes of death, such as homicide and hypertensive renal disease, for certain years.  It looks like those causes of death are relatively rare, so for some purposes it makes sense to remove them from consideration.  Of course, we'll have to keep in mind that we're no longer looking at a comprehensive report on all deaths in California.

**Question 1.** Let's clean up our data. First, remove rows with HOM, HYP, and NEP as the cause of death from the table for the reasons described above. Next, join together the abbreviations table and our causes of death table so that we have a more detailed description of each disease in each row. Lastly, drop the column which contains the abbreviation of the disease, and rename the column that contains the full description to 'Cause of Death'. Assign the resulting table to the name `cleaned_causes`.

*Hint:* You should expect this to take more than one line. Use many lines and many intermediate tables to complete this question. 
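
To see what the steps do without the lab's files, here is a sketch in plain Python on made-up rows. The abbreviations `AAA`, `BBB`, `CCC`, their descriptions, the column names, and the counts are all hypothetical. In the lab you would use Table methods such as `where`, `join`, `drop`, and `relabeled`; this sketch only mimics their effect with lists and dictionaries.

```python
# Made-up rows standing in for the causes table.
toy_causes = [
    {"Year": 2000, "Cause": "AAA", "Count": 5},
    {"Year": 2000, "Cause": "BBB", "Count": 2},
    {"Year": 2001, "Cause": "CCC", "Count": 7},
]
# Made-up lookup standing in for the abbreviations table.
toy_descriptions = {"AAA": "Alpha disease", "BBB": "Beta disease", "CCC": "Gamma disease"}
excluded = {"BBB"}

# Step 1: remove rows whose abbreviation is excluded.
kept = [row for row in toy_causes if row["Cause"] not in excluded]

# Steps 2 and 3: look up the full description (the "join") and keep it
# under a new label in place of the abbreviation.
toy_cleaned = [
    {"Year": row["Year"], "Count": row["Count"],
     "Cause of Death": toy_descriptions[row["Cause"]]}
    for row in kept
]
for row in toy_cleaned:
    print(row)
```

Working through intermediate names like `kept` makes each step easy to inspect, which is the same reason the hint suggests intermediate tables.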

<!--
BEGIN QUESTION
name: q1_1
-->

In [None]:
cleaned_causes = ...
cleaned_causes

In [None]:
check('tests/q2_1.py')

We're going to examine the changes in causes of death over time.  To make a plot of those numbers, we need to have a table with one row per year, and the information about all the causes of death for each year.

**Question 2.** Create a table with one row for each year and a column for each kind of death, where each cell contains the number of deaths by that cause in that year. Call the table `cleaned_causes_by_year`.
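
As a sketch of the reshaping this question asks for, the plain-Python snippet below turns made-up long-format rows into one row per year with one column per cause, summing the counts that share a year and a cause. It mimics the kind of result `pivot` produces when given a summing function, but it does not use the `datascience` library, and all names and numbers are hypothetical.

```python
from collections import defaultdict

# Made-up long-format rows: (year, cause, count).
toy_rows = [
    (2000, "Alpha", 5), (2000, "Beta", 2), (2000, "Alpha", 1),
    (2001, "Alpha", 4), (2001, "Beta", 3),
]

# totals[year][cause] accumulates the counts for that year and cause.
totals = defaultdict(lambda: defaultdict(int))
for year, cause, count in toy_rows:
    totals[year][cause] += count

# One output row per year, one key per cause.
causes_seen = sorted({cause for _, cause, _ in toy_rows})
toy_by_year = [
    {"Year": year, **{cause: totals[year][cause] for cause in causes_seen}}
    for year in sorted(totals)
]
for row in toy_by_year:
    print(row)
```

Note that the two `(2000, "Alpha", ...)` rows collapse into a single cell, which is the situation in the lab data where one year and cause appear in many ZIP codes.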

<!--
BEGIN QUESTION
name: q1_2
-->

In [None]:
cleaned_causes_by_year = ...
cleaned_causes_by_year.show(15)

In [None]:
check('tests/q2_2.py')

**Question 3.** Make a plot of all the causes of death by year, using the `cleaned_causes_by_year` table.  There should be a single plot with one line per cause of death.

*Hint:* Use the Table method `plot`.  If you pass only a single argument, a line will be made for each of the other columns.

<!--
BEGIN QUESTION
name: q1_3
manual: true
-->
<!-- EXPORT TO PDF -->

In [None]:
...

After seeing the plot above, we would now like to examine the distributions of diseases over the years using percentages. Below, we have assigned `distributions` to a table with all of the same columns, but the raw counts in the cells are replaced by the percentage of the total number of deaths by a particular disease that happened in that specific year.

Try to understand the code below. 

In [None]:
def percents(array_x):
    return np.round( (array_x/sum(array_x))*100, 2)

# We are making the labels an array; you are not expected to know the asarray function.
labels = np.asarray(cleaned_causes_by_year.labels)
distributions = Table().with_columns(labels.item(0), cleaned_causes_by_year.column(0),
                                     labels.item(1), percents(cleaned_causes_by_year.column(1)),
                                     labels.item(2), percents(cleaned_causes_by_year.column(2)),
                                     labels.item(3), percents(cleaned_causes_by_year.column(3)),
                                     labels.item(4), percents(cleaned_causes_by_year.column(4)),
                                     labels.item(5), percents(cleaned_causes_by_year.column(5)),
                                     labels.item(6), percents(cleaned_causes_by_year.column(6)),
                                     labels.item(7), percents(cleaned_causes_by_year.column(7)),
                                     labels.item(8), percents(cleaned_causes_by_year.column(8)),
                                     labels.item(9), percents(cleaned_causes_by_year.column(9)),
                                     labels.item(10), percents(cleaned_causes_by_year.column(10)),
                                     labels.item(11), percents(cleaned_causes_by_year.column(11)))
distributions.show()
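
To check your reading of `percents`, here it is applied to a small made-up array: each element is divided by the array's total, multiplied by 100, and rounded to two decimal places. The function is copied from the cell above; the input values are hypothetical.

```python
import numpy as np

def percents(array_x):
    return np.round((array_x / sum(array_x)) * 100, 2)

# Made-up counts; not values from the causes table.
toy_counts = np.array([30, 50, 120])
toy_percents = percents(toy_counts)
print(toy_percents)  # 30/200, 50/200, 120/200 as percentages
```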

**Question 4.** What is the sum (roughly) of each of the columns (except the Year column) in the table above? Why does this make sense? 

<!--
BEGIN QUESTION
name: q1_4
manual: true
-->
<!-- EXPORT TO PDF -->

*Write your answer here, replacing this text.*

**Question 5.** Over the years 1999-2013, we suspect that most stroke-related deaths happened in earlier years, while most Chronic Liver Disease-related deaths occurred in more recent years. Draw a horizontal bar chart to display the percent of total deaths related to "Cerebrovascular Disease (Stroke)" and "Chronic Liver Disease and Cirrhosis" over the time period.

*Hint*: Use the Table method `barh`. If you pass a single column label, it creates bar charts of the other columns.


<!--
BEGIN QUESTION
name: q1_5
manual: true
-->
<!-- EXPORT TO PDF -->

In [None]:
...

# Don't change the code below this comment.
plt.title("% of total deaths / disease per year")
plt.xlabel("% of total deaths")
plt.show();

## 3. NBA player salaries

Recall that in class we worked with the NBA player salaries dataset for the 2015-2016 season.

*Note: There are no checks for this section since many are about making graphs. Make sure to show your results when getting checked off.*

This exercise is designed to give you practice using the Table method `group`. [Here](http://data8.org/fa19/python-reference.html) is a link to the Python reference page in case you need a quick refresher.

Run the cell below to view a demo on how you can use group on a table.

In [None]:
from IPython.display import YouTubeVideo
YouTubeVideo("HLoYTCUP0fc")

In [None]:
# This table can be found online: 
# https://www.statcrunch.com/app/index.php?dataid=1843341

# NBA players, 2015-2016 season
nba = Table.read_table('nba_salaries.csv').relabeled(3, 'SALARY').sort('PLAYER')
nba

**Question 1.** How many players are there in each position? What is the maximum salary of a center (C)?

*Hint: you can first create a table containing only the POSITION column and the SALARY column.*
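
Grouping means collecting the rows that share a value and summarizing each collection. The sketch below does this by hand on made-up (position, salary) pairs, first counting rows per position and then taking the largest salary per position. In the lab you would use the Table method `group` instead; the data here are hypothetical and not from `nba_salaries.csv`.

```python
from collections import Counter

# Made-up (position, salary) pairs; salaries in millions.
toy_players = [("C", 4.0), ("PG", 2.5), ("C", 12.0), ("SG", 1.0), ("PG", 7.5)]

# Number of rows for each position.
players_per_position = Counter(position for position, _ in toy_players)

# Largest salary seen so far for each position.
max_salary = {}
for position, salary in toy_players:
    max_salary[position] = max(salary, max_salary.get(position, salary))

print(players_per_position)
print(max_salary)
```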

In [None]:
position_and_salary = ...


*Write your answer here, replacing this text.*

**Question 2.** How much do Atlanta Hawks point guards (PG) make on average? What about point guards on the Chicago Bulls?

*Hint: you can first create a table containing only the POSITION column, the TEAM column, and the SALARY column.*
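
When a question involves two categories at once, the grouping key becomes a pair of values. The plain-Python sketch below averages made-up salaries for each (team, position) pair; the team names and numbers are hypothetical. In the lab, a Table method such as `group` or `pivot` would do this bookkeeping for you.

```python
from collections import defaultdict
from statistics import mean

# Made-up (team, position, salary) rows; salaries in millions.
toy_rows = [
    ("Team A", "PG", 2.0), ("Team A", "PG", 4.0), ("Team A", "C", 10.0),
    ("Team B", "PG", 6.0),
]

# Collect the salaries that belong to each (team, position) pair.
salaries_by_pair = defaultdict(list)
for team, position, salary in toy_rows:
    salaries_by_pair[(team, position)].append(salary)

# Average within each pair.
mean_salary = {pair: mean(values) for pair, values in salaries_by_pair.items()}
print(mean_salary[("Team A", "PG")])
```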

In [None]:
position_team_salary = ...


*Write your answer here, replacing this text.*

**Question 3.** Make a histogram of the salaries of all point guards (PG) in this dataset.

*Hint: reuse the position_and_salary table you have created for Question 1.*

In [None]:
## create a histogram


**Question 4.** Make a bar chart of the 10 highest point guard (PG) salaries, in decreasing order. Include each PG's name in the bar chart (i.e., your categorical variable in the `barh` method should be PLAYER).

*Hint: you can first create a table containing only the POSITION column, the PLAYER column, and the SALARY column.*
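
Picking a "top k" is a sort followed by keeping the first k rows. Here is a sketch on made-up (player, salary) pairs with k = 3; the names and salaries are hypothetical. With a Table, the corresponding tools are `sort` with `descending=True` and `take`.

```python
# Made-up (player, salary) pairs; salaries in millions.
toy_salaries = [("Player A", 3.0), ("Player B", 9.5), ("Player C", 1.2), ("Player D", 6.0)]

# Sort by salary, largest first, then keep the first k rows.
k = 3
top_k = sorted(toy_salaries, key=lambda pair: pair[1], reverse=True)[:k]
for player, salary in top_k:
    print(player, salary)
```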

In [None]:
## create a bar chart


## 4. Submission


Congratulations, you're done with Lab 5!  Be sure to:
- **Run all the tests and verify that they all pass** (the next cell has a shortcut for that).
- **Save and Checkpoint** from the `File` menu.

In [None]:
# For your convenience, you can run this cell to run all the tests at once!
import glob
from gofer.ok import grade_notebook
if not globals().get('__GOFER_GRADER__', False):
    display(grade_notebook('lab05.ipynb', sorted(glob.glob('tests/q*.py'))))