# Class 10: Practice with .group() and .pivot()

In [1]:
from datascience import *
import numpy as np
# import for plotting
%matplotlib inline
import matplotlib.pyplot as plt
plt.style.use('fivethirtyeight')
# Fix for datascience plots: restore the collections.Iterable alias removed in Python 3.10
import collections
import collections.abc
collections.Iterable = collections.abc.Iterable

To make group and pivot operations as easy to follow as possible, we will start with a very simple table of ice cream cone flavors.

In [2]:
cones = Table().with_columns(
    'Flavor', make_array('strawberry', 'chocolate', 'chocolate', 'strawberry', 'chocolate', 'bubblegum'),
    'Color', make_array('pink', 'light brown', 'dark brown', 'pink', 'dark brown', 'pink'),
    'Price', make_array(3.55, 4.75, 5.25, 5.25, 5.25, 4.75)
)
cones

Flavor,Color,Price
strawberry,pink,3.55
chocolate,light brown,4.75
chocolate,dark brown,5.25
strawberry,pink,5.25
chocolate,dark brown,5.25
bubblegum,pink,4.75


# Group

Let's group the table by `Flavor`.

In [3]:
cones.group("Flavor")

Flavor,count
bubblegum,1
chocolate,3
strawberry,2


By default the rows with the same flavor are simply counted. 

But you can also pass a function as a second argument. It is applied to the values in each group in place of counting. Suppose we wanted to know the maximum price in each group.

In [4]:
cones.group("Flavor", np.max)

Flavor,Color amax,Price amax
bubblegum,,4.75
chocolate,,5.25
strawberry,,5.25


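Notice that there were three chocolate-flavored cones in the original table. We grouped by flavor and used `np.max` to get the highest price in each flavor category. The `Color amax` column is empty, most likely because `np.max` cannot be applied to the string values in `Color`.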
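To make the mechanics concrete, here is a minimal plain-Python sketch of what grouping does conceptually: collect the prices for each flavor, then apply one function to each collection. It is not how the `datascience` library implements `group()`; the names `cones_rows` and `group_prices` are made up for this illustration.

```python
from collections import defaultdict

# The same six cones as in the table above, as (flavor, color, price) tuples.
cones_rows = [
    ("strawberry", "pink", 3.55),
    ("chocolate", "light brown", 4.75),
    ("chocolate", "dark brown", 5.25),
    ("strawberry", "pink", 5.25),
    ("chocolate", "dark brown", 5.25),
    ("bubblegum", "pink", 4.75),
]

def group_prices(rows, collect):
    """Collect the prices for each flavor, then apply `collect` to each list."""
    prices_by_flavor = defaultdict(list)
    for flavor, _color, price in rows:
        prices_by_flavor[flavor].append(price)
    # Sort by flavor so the result is in the same order as the group() output.
    return {flavor: collect(prices)
            for flavor, prices in sorted(prices_by_flavor.items())}

counts = group_prices(cones_rows, len)      # compare with cones.group("Flavor")
max_prices = group_prices(cones_rows, max)  # compare with the Price amax column
print(counts)      # {'bubblegum': 1, 'chocolate': 3, 'strawberry': 2}
print(max_prices)  # {'bubblegum': 4.75, 'chocolate': 5.25, 'strawberry': 5.25}
```

Passing `len` reproduces the default counts, and passing `max` reproduces the `Price amax` column.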

## Challenge 1
Find the average price for the cones of each flavor.

## Challenge 2
Instead of grouping by Flavor, group by Color and find the lowest price in each color category.

# Pivot
Start with the same table.

In [5]:
cones

Flavor,Color,Price
strawberry,pink,3.55
chocolate,light brown,4.75
chocolate,dark brown,5.25
strawberry,pink,5.25
chocolate,dark brown,5.25
bubblegum,pink,4.75


Now instead of grouping we will pivot.

In [6]:
cones.pivot("Flavor", "Color")

Color,bubblegum,chocolate,strawberry
dark brown,0,2,0
light brown,0,1,0
pink,1,0,2


Look carefully at what happened. The first argument is the column to pivot on. The distinct values in that column become the columns of the new table: we pivoted rows into columns.

The second argument specifies which column's values to use for the rows of the new table.

Just as with `group()`, the values in the table are counts by default. The original table had two rows with flavor "chocolate" and color "dark brown," so 2 is the value that went into the table at the intersection of the chocolate column and the dark brown row.
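The counting behavior described above can be sketched in plain Python. This is only an illustration of the idea, not the `datascience` implementation; `cone_pairs` and `pivot_counts` are hypothetical names.

```python
from collections import Counter

# (flavor, color) for each of the six cones in the table above.
cone_pairs = [
    ("strawberry", "pink"),
    ("chocolate", "light brown"),
    ("chocolate", "dark brown"),
    ("strawberry", "pink"),
    ("chocolate", "dark brown"),
    ("bubblegum", "pink"),
]

def pivot_counts(pairs):
    """Return one row per color and one column per flavor, filled with counts."""
    pair_counts = Counter((color, flavor) for flavor, color in pairs)
    flavors = sorted({flavor for flavor, _ in pairs})
    colors = sorted({color for _, color in pairs})
    # A Counter returns 0 for combinations that never occur in the data.
    return {color: {flavor: pair_counts[(color, flavor)] for flavor in flavors}
            for color in colors}

pivoted = pivot_counts(cone_pairs)
for color, row in pivoted.items():
    print(color, row)
# dark brown {'bubblegum': 0, 'chocolate': 2, 'strawberry': 0}
# light brown {'bubblegum': 0, 'chocolate': 1, 'strawberry': 0}
# pink {'bubblegum': 1, 'chocolate': 0, 'strawberry': 2}
```

Each distinct flavor became a column and each distinct color became a row, which matches the pivot output above.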

Again, just as with group, we can use a different collection function. In the example below, we pass two more arguments to pivot: the column we want to collect and the collection function. In other words, we are saying: for each combination of Flavor and Color, take the average of all the prices.

Thus, for the chocolate cones that are dark brown (recall there are two), the average price is $5.25. Combinations that have no cones, such as dark brown bubblegum, show up as 0.0 in the result.

In [7]:
cones.pivot("Flavor", "Color", "Price", np.mean)

Color,bubblegum,chocolate,strawberry
dark brown,0.0,5.25,0.0
light brown,0.0,4.75,0.0
pink,4.75,0.0,4.4


## Challenge 3
Instead of the mean, find the lowest price for each flavor-color combination.

Now that you have seen group and pivot in action on a small table, you might be wondering what all the fuss is about. Why are group and pivot so useful for data exploration? Their true power only becomes apparent when we are confronting larger data sets.

# Another look at the heart disease data

In [9]:
path = '../data/'
data = path + 'heart.csv'
heart = Table.read_table(data)
heart

age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
52,1,0,125,212,0,1,168,0,1.0,2,2,3,0
53,1,0,140,203,1,0,155,1,3.1,0,0,3,0
70,1,0,145,174,0,1,125,1,2.6,0,0,3,0
61,1,0,148,203,0,1,161,0,0.0,2,1,3,0
62,0,0,138,294,1,1,106,0,1.9,1,3,2,0
58,0,0,100,248,0,0,122,0,1.0,1,0,2,1
58,1,0,114,318,0,2,140,0,4.4,0,3,1,0
55,1,0,160,289,0,0,145,1,0.8,1,1,3,0
46,1,0,120,249,0,0,144,0,0.8,2,0,3,0
54,1,0,122,286,0,0,116,1,3.2,1,2,2,0


This data set has over 1000 rows, which is still small by today's standards, but you can no longer see the answers by simple inspection.

Let's imagine we want to know if the resting blood pressure (`trestbps`) is higher on average for patients with heart disease (`target` = 1) than for those without heart disease (`target` = 0).

If you didn't know about grouping, you might approach it this way.

In [15]:
heart_disease = heart.where("target", are.equal_to(1))
no_heart_disease = heart.where("target", are.equal_to(0))

avg_bp_heart_disease = np.mean(heart_disease.column("trestbps"))
avg_bp_no_heart_disease = np.mean(no_heart_disease.column("trestbps"))

print("Average BP for patients with no heart disease:", avg_bp_no_heart_disease)
print("Average BP for patients with heart disease:", avg_bp_heart_disease)

Average BP for patients with no heart disease: 134.106212425
Average BP for patients with heart disease: 129.245247148


Here is the same analysis using group().

In [13]:
heart.group("target", np.mean)

target,age mean,sex mean,cp mean,trestbps mean,chol mean,fbs mean,restecg mean,thalach mean,exang mean,oldpeak mean,slope mean,ca mean,thal mean
0,56.5691,0.827655,0.482966,134.106,251.293,0.164329,0.456914,139.13,0.549098,1.6002,1.16633,1.15832,2.53908
1,52.4087,0.570342,1.37833,129.245,240.979,0.134981,0.598859,158.586,0.134981,0.569962,1.59316,0.370722,2.11977


This compares the average for all columns in one shot! Think how long that would take to do for every column using the first approach.
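As a rough illustration of why this works for many columns at once, the plain-Python sketch below groups rows by `target` and averages every remaining column in a single pass. It uses only the ten rows displayed above and only two of the columns, so its numbers will not match the full-table output. The names `heart_rows`, `column_names`, and `group_means` are hypothetical, and this is not how `datascience` implements `group()`.

```python
from collections import defaultdict
from statistics import mean

# (target, trestbps, chol) for the ten rows of the heart table shown above.
heart_rows = [
    (0, 125, 212),
    (0, 140, 203),
    (0, 145, 174),
    (0, 148, 203),
    (0, 138, 294),
    (1, 100, 248),
    (0, 114, 318),
    (0, 160, 289),
    (0, 120, 249),
    (0, 122, 286),
]
column_names = ("trestbps", "chol")

def group_means(rows):
    """Group rows by target, then average each remaining column per group."""
    groups = defaultdict(list)
    for target, *values in rows:
        groups[target].append(values)
    result = {}
    for target, members in sorted(groups.items()):
        # zip(*members) turns the list of rows into one sequence per column.
        result[target] = {name: mean(column)
                          for name, column in zip(column_names, zip(*members))}
    return result

means = group_means(heart_rows)
for target, row in means.items():
    print(target, row)
```

Adding another column to the tuples and to `column_names` requires no change to the grouping logic, which is the convenience that `group()` provides on the real table.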

To get just the BP we first select the columns of interest, then group:

In [14]:
heart.select("target", "trestbps").group("target", np.mean)

target,trestbps mean
0,134.106
1,129.245


Grouping is a powerful way to summarize information about a lot of columns quickly.

**Now let's look at a pivot example.**
But first, it is confusing to use 0 and 1 for both gender and disease. Let's translate the values of `sex` into the labels "male" and "female".

In [34]:
def change_values(label):
    if label == 1:
        return "male"
    elif label == 0:
        return "female"
    else:
        return "not declared"

Note that the function checks for values other than (0, 1). First, limiting gender to male and female shows this is an *old* data set. Second, it is a good practice to check for unexpected values in your data set.
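One simple way to check for unexpected values is to count the distinct values in the column before recoding it. The sketch below does this in plain Python for the `sex` values of the ten rows displayed earlier; `sex_values` and `expected_values` are names chosen for this illustration. On the full table, `heart.group("sex")` would list every distinct value with its count.

```python
from collections import Counter

# `sex` values from the ten rows of the heart table displayed earlier.
sex_values = [1, 1, 1, 1, 0, 0, 1, 1, 1, 1]
expected_values = {0, 1}

# Count how often each distinct value occurs.
value_counts = Counter(sex_values)

# Anything outside the expected set deserves a closer look before recoding.
unexpected = sorted(set(value_counts) - expected_values)

print(dict(value_counts))  # {1: 8, 0: 2}
print(unexpected)          # []
```

An empty list means that every value in this sample is one of the two expected codes.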

In [35]:
# Test the function
print(change_values(1))
print(change_values(0))
print(change_values(5))

male
female
not declared


In [36]:
# Apply the function
heart = heart.with_column("gender", heart.apply(change_values, "sex"))
heart.show(3)

age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target,gender
52,1,0,125,212,0,1,168,0,1.0,2,2,3,0,male
53,1,0,140,203,1,0,155,1,3.1,0,0,3,0,male
70,1,0,145,174,0,1,125,1,2.6,0,0,3,0,male


Now let's pivot so that:
- the `target` becomes the columns
- the `gender` becomes the rows
- the average of the `trestbps` becomes the values


In [37]:
heart.pivot("target", "gender", "trestbps", np.mean)

gender,0,1
female,146.488,128.836
male,131.528,129.553


## Challenge 4
Make the same pivot table, but instead of blood pressure, show average cholesterol.