## Pivot Tables in Pandas

A pivot table summarizes data by grouping rows on one or more categorical columns and aggregating a value column within each group. The core parameters of pd.pivot_table are:

In [None]:
"""
pd.pivot_table(
    data=df,
    values=...,   # what to aggregate
    index=...,    # rows (grouping)
    columns=...,  # columns (optional second grouping)
    aggfunc=...   # aggregation function
)
"""

Note that the aggfunc argument accepts the same kinds of specification as the .agg method: a single function, a list of functions, or a dict that maps each column to its own function or functions.
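
As a minimal sketch of the dict form, the example below applies a different function to each value column and checks the result against groupby(...).agg(...). The small df frame and its profit figures are hypothetical, made up only for this illustration.

```python
import pandas as pd

# Hypothetical data, used only for this illustration.
df = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Abuja"],
    "sales": [100, 80, 120, 70],
    "profit": [30, 20, 40, 10],
})

# A dict maps each value column to the function applied to it.
spec = {"sales": "sum", "profit": "mean"}

via_pivot = pd.pivot_table(
    df, values=["sales", "profit"], index="city", aggfunc=spec
)
via_agg = df.groupby("city").agg(spec)

# Column order may differ between the two results,
# so compare them column by column.
for col in spec:
    print(col, (via_pivot[col] == via_agg[col]).all())
```

The same dict could be reused with either call, which is handy when a summary has to be produced both ways.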

## One-dimensional Grouping

In [1]:
import pandas as pd

In [2]:
cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110]
})
cities

,city,month,sales
0,Lagos,Jan,100
1,Abuja,Jan,80
2,Lagos,Feb,120
3,Ibadan,Jan,90
4,Abuja,Feb,70
5,Lagos,Feb,110


In [3]:
# total sales per city
sales_per_city = pd.pivot_table(
    data=cities,
    values="sales",
    index="city",
    aggfunc="sum"
)
sales_per_city

city,sales
Abuja,150
Ibadan,90
Lagos,330


This is equivalent to the following groupby, where the double brackets keep the result a DataFrame rather than a Series:

In [5]:
cities.groupby("city")[["sales"]].sum()

city,sales
Abuja,150
Ibadan,90
Lagos,330


A pivot table is an ordinary DataFrame. The column passed to the index argument becomes the index of the new DataFrame.

In [4]:
type(sales_per_city)  # pandas.core.frame.DataFrame
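
Because the grouping column becomes the index, rows can be looked up by label. The sketch below rebuilds the same table so that it runs on its own, reads one value with .loc, and then uses reset_index() to turn city back into a regular column. The variable names lagos_total and flat are chosen here for illustration.

```python
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

sales_per_city = pd.pivot_table(
    cities, values="sales", index="city", aggfunc="sum"
)

# "city" is now the row index, so a row is selected by its city name.
lagos_total = sales_per_city.loc["Lagos", "sales"]
print(lagos_total)

# reset_index() moves "city" back into an ordinary column.
flat = sales_per_city.reset_index()
print(list(flat.columns))
```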

## Two-dimensional Grouping

In [6]:
# sales per city per month
pd.pivot_table(
    data=cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc="sum"
)

month,Feb,Jan
city,,
Abuja,70.0,80.0
Ibadan,,90.0
Lagos,230.0,100.0
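
For comparison, a two-dimensional pivot can also be sketched as a groupby on both columns followed by unstack, which moves one grouping level from the rows into the columns. The names via_pivot and via_groupby are illustrative.

```python
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

via_pivot = pd.pivot_table(
    cities, values="sales", index="city", columns="month", aggfunc="sum"
)

# Group on both keys, then move the "month" level into the columns.
via_groupby = (
    cities.groupby(["city", "month"])["sales"]
    .sum()
    .unstack("month")
)

# Compare after filling the missing cell, since NaN never equals NaN.
print(via_pivot.fillna(0).eq(via_groupby.fillna(0)).all().all())
```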


## Handling Missing Combinations

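By default, combinations that never occur in the data (here, Ibadan in Feb) become NaN, and the result is upcast to floats. Pass fill_value to substitute another value for them.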

In [7]:
# sales per city per month, missing combinations replaced with 0
pd.pivot_table(
    cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc="sum",
    fill_value=0
)

month,Feb,Jan
city,,
Abuja,70,80
Ibadan,0,90
Lagos,230,100
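
Before filling, it can be useful to see which combinations are actually missing. The following sketch lists them and shows that calling fillna(0) afterwards gives the same numbers as fill_value=0. The names missing_pairs and filled_later are illustrative.

```python
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

# Pivot without fill_value so that absent combinations stay as NaN.
table = pd.pivot_table(
    cities, values="sales", index="city", columns="month", aggfunc="sum"
)

# Collect the (city, month) pairs that have no rows in the source data.
missing_pairs = [
    (city, month)
    for city in table.index
    for month in table.columns
    if pd.isna(table.loc[city, month])
]
print(missing_pairs)

# Filling afterwards yields the same numbers as passing fill_value=0.
filled_later = table.fillna(0)
print(filled_later)
```

A zero is a sensible fill for sums and counts. For averages it can be misread as a real average of zero, so leaving NaN in place may be the safer choice there.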


## Multiple Aggregation Functions

To compute several aggregations at once, pass a list of functions to the aggfunc argument. The columns of the result gain an extra top level with one entry per function.

In [8]:
# sum and mean sales per city per month
pd.pivot_table(
    cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc=["sum", "mean"],
    fill_value=0
)

,sum,sum,mean,mean
month,Feb,Jan,Feb,Jan
city,,,,
Abuja,70,80,70.0,80.0
Ibadan,0,90,0.0,90.0
Lagos,230,100,115.0,100.0
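
With two column levels, a single function's block or a single cell can be picked out by label. This is a small sketch of that selection; the names result, means and lagos_feb_sum are illustrative.

```python
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

result = pd.pivot_table(
    cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc=["sum", "mean"],
    fill_value=0,
)

# Selecting the top-level label returns the sub-table for that function.
means = result["mean"]
print(means)

# A tuple addresses one column across both levels.
lagos_feb_sum = result.loc["Lagos", ("sum", "Feb")]
print(lagos_feb_sum)
```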


## Multiple Value Columns

To aggregate several columns at once, pass a list of column names to the values argument.

In [11]:
# sales and profit per city
import numpy as np

# random profit figures in [100, 200); no seed is set, so the values change on every run
profit = np.random.randint(100, 200, 6)
cities["profit"] = profit

pd.pivot_table(
    cities,
    values=["sales", "profit"],
    index="city",
    aggfunc="sum",
    fill_value=0
)

city,profit,sales
Abuja,260,150
Ibadan,100,90
Lagos,415,330
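
Since the profit column above is drawn at random without a seed, rerunning the cell produces different totals. One way to make it repeatable, sketched here under the assumption that any fixed seed is acceptable (the value 42 is arbitrary and not from the original notebook), is to use a seeded NumPy generator. The resulting numbers will not match the output shown above.

```python
import numpy as np
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

# Seeded generator: the same seed always yields the same draws.
# The seed value is an arbitrary choice for this sketch.
rng = np.random.default_rng(42)
cities["profit"] = rng.integers(100, 200, size=len(cities))

totals = pd.pivot_table(
    cities, values=["sales", "profit"], index="city", aggfunc="sum"
)
print(totals)
```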


## Mini-Exercise 1

Create a pivot table showing average sales per city per month.

In [12]:
# average sales per city per month
pd.pivot_table(
    cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc="mean",
    fill_value=0
)

month,Feb,Jan
city,,
Abuja,70.0,80.0
Ibadan,0.0,90.0
Lagos,115.0,100.0


## Mini-Exercise 2

Create a pivot table showing total sales per month (across all cities).

In [13]:
# total sales per month across all cities
pd.pivot_table(
    cities,
    values="sales",
    index="month",
    aggfunc="sum",
    fill_value=0
)

month,sales
Feb,300
Jan,270
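
The per-month totals can also be attached to the city-by-month table itself. As a sketch, pivot_table's margins parameter adds a row and a column of totals, labelled "All" by default. The name with_totals is illustrative.

```python
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

# margins=True appends an "All" row and an "All" column of totals.
with_totals = pd.pivot_table(
    cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc="sum",
    fill_value=0,
    margins=True,
)
print(with_totals)

# The "All" row holds the totals per month across all cities.
print(with_totals.loc["All", "Feb"], with_totals.loc["All", "Jan"])
```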


## Flattening Columns

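Pivots that combine several aggregation functions, several value columns, and a columns grouping produce column labels with multiple MultiIndex levels, which are awkward to select and export. For example: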

In [14]:
# table of summary statistics: sum, mean, max, min of sales and profit per city per month
table = pd.pivot_table(
    cities,
    values=["sales", "profit"],
    index="city",
    columns="month",
    aggfunc=["sum", "mean", "max", "min"],
    fill_value=0
)
table

,sum,sum,sum,sum,mean,mean,mean,mean,max,max,max,max,min,min,min,min
,profit,profit,sales,sales,profit,profit,sales,sales,profit,profit,sales,sales,profit,profit,sales,sales
month,Feb,Jan,Feb,Jan,Feb,Jan,Feb,Jan,Feb,Jan,Feb,Jan,Feb,Jan,Feb,Jan
city,,,,,,,,,,,,,,,,
Abuja,101,159,70,80,101.0,159.0,70.0,80.0,101,159,70,80,101,159,70,80
Ibadan,0,100,0,90,0.0,100.0,0.0,90.0,0,100,0,90,0,100,0,90
Lagos,222,193,230,100,111.0,193.0,115.0,100.0,116,193,120,100,106,193,110,100


A common cleanup is to flatten the MultiIndex by joining the levels of each column label with underscores:

In [15]:
# each column label is a tuple such as ("sum", "profit", "Feb"); join it into "sum_profit_Feb"
table.columns = [
    "_".join(col).strip() for col in table.columns.to_flat_index()
]
table

city,sum_profit_Feb,sum_profit_Jan,sum_sales_Feb,sum_sales_Jan,mean_profit_Feb,mean_profit_Jan,mean_sales_Feb,mean_sales_Jan,max_profit_Feb,max_profit_Jan,max_sales_Feb,max_sales_Jan,min_profit_Feb,min_profit_Jan,min_sales_Feb,min_sales_Jan
Abuja,101,159,70,80,101.0,159.0,70.0,80.0,101,159,70,80,101,159,70,80
Ibadan,0,100,0,90,0.0,100.0,0.0,90.0,0,100,0,90,0,100,0,90
Lagos,222,193,230,100,111.0,193.0,115.0,100.0,116,193,120,100,106,193,110,100
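
The join above assumes every level of every column label is a string. A slightly more defensive sketch converts each level with str first, and then calls reset_index() so that city also becomes a plain column. It uses only the sales column so that it runs on its own, and the name flat is illustrative.

```python
import pandas as pd

cities = pd.DataFrame({
    "city": ["Lagos", "Abuja", "Lagos", "Ibadan", "Abuja", "Lagos"],
    "month": ["Jan", "Jan", "Feb", "Jan", "Feb", "Feb"],
    "sales": [100, 80, 120, 90, 70, 110],
})

table = pd.pivot_table(
    cities,
    values="sales",
    index="city",
    columns="month",
    aggfunc=["sum", "mean"],
    fill_value=0,
)

flat = table.copy()
# map(str, ...) guards against non-string levels such as years or integers.
flat.columns = ["_".join(map(str, col)) for col in flat.columns.to_flat_index()]

# Move "city" out of the index so the frame has only ordinary columns.
flat = flat.reset_index()
print(list(flat.columns))
```

A fully flat frame like this is usually easier to write to CSV or to merge with other tables.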
