<a href="https://colab.research.google.com/github/google/applied-machine-learning-intensive/blob/master/v2/02_data/03_visualizations/colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Copyright 2020 Google LLC.

In [0]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Visualizations

Visualizations play an important role in the field of machine learning and data science. Often we need to distill key information found in large quantities of data down into meaningful and digestible forms. Good visualizations can tell a story about your data in a way that prose cannot.

In this lab we will explore some common visualization techniques. We will use toolkits such as [Matplotlib's Pyplot](https://matplotlib.org/api/pyplot_api.html) and [Seaborn](https://seaborn.pydata.org/) to create images that give us insight into our data.

## Pie Charts

A pie chart shows how much each class of data in a dataset contributes to the whole. It is a circular chart in which each class is drawn as a slice sized in proportion to its share of the total.

Let's create a pie chart using a sample dataset. 

The `flavors` variable contains a tuple of ice cream flavors that we have available.

The `votes` variable contains a tuple of vote counts. They represent the number of votes each flavor got when we asked a group of people what their favorite flavor of ice cream is.

We can create the chart using [Matplotlib's Pyplot](https://matplotlib.org/api/pyplot_api.html) library. We pass [`plt.pie()`](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.pie.html) the number of votes for each flavor and then the labels that we want to use.

In [0]:
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)

plt.pie(
    votes,
    labels=flavors,
)
plt.show()

Given this chart we can easily see that chocolate is the most popular flavor, with vanilla not falling far behind. Admittedly we could also tell this by looking at the raw data. However, the data in the pie chart format also allows us to easily see other information, such as the fact that chocolate and vanilla combined represent more than half of the votes.

What we don't see are actual percentages.

If we want to see what percentage each flavor contributes, we can use the `autopct` argument. For the argument value we provide a format string that can be used to set the precision of the number that is shown.

Try changing the value to `%1.0f%%` and `%1.2f%%`. What happens?

In [0]:
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)

plt.pie(
    votes,
    labels=flavors,
    autopct='%1.1f%%',
)
plt.show()

Now we can see the percent that each ice cream flavor contributes to the whole.
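The `autopct` string is an old-style printf format: `%1.1f` prints the percentage with one decimal place and `%%` prints a literal percent sign. `autopct` also accepts a function, which Matplotlib calls with each slice's percentage. A minimal sketch below; `pct_with_count` is our own hypothetical helper that adds the raw vote count under each percentage, and the `matplotlib.use('Agg')` line only lets the sketch run outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)
total = sum(votes)

# '%1.1f' formats the number with one decimal place; '%%' is a literal '%'.
print('%1.1f%%' % (100 * votes[0] / total))  # 28.6%

def pct_with_count(pct):
    """Hypothetical helper: label a slice with its percentage and vote count."""
    count = round(pct * total / 100)
    return f'{pct:.1f}%\n({count})'

plt.pie(votes, labels=flavors, autopct=pct_with_count)
plt.show()
```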

One thing that's still a little confusing about this chart is the choice of color. We have an idea of what color each of these ice cream flavors has in real life, but what is shown on the chart doesn't match up with those real-world colors.

We can fix this!

Matplotlib's pie chart allows you to change the colors shown on the chart by passing in an iterable of color values. You can use one of a small number of pre-programmed values such as 'b' for blue and 'g' for green.

In our case we pass in HTML hex color codes. Each code is a `#` followed by six hexadecimal characters: the first two give the amount of red in the color, the next two the amount of green, and the final two the amount of blue. You can find many tables and pickers for these by searching for ['html color codes'](https://www.google.com/search?q=html+color+codes).
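To make the encoding concrete, here is a small sketch that splits a hex code into its red, green, and blue parts. `hex_to_rgb` is a hypothetical helper written for illustration; Matplotlib's own `matplotlib.colors.to_rgb()` does the same job but scales each channel to the range 0 to 1.

```python
from matplotlib.colors import to_rgb

def hex_to_rgb(code):
    """Hypothetical helper: split '#RRGGBB' into (red, green, blue) integers 0-255."""
    code = code.lstrip('#')
    # Each pair of hex digits is one channel, read in base 16.
    return tuple(int(code[i:i + 2], 16) for i in (0, 2, 4))

print(hex_to_rgb('#8B4513'))  # (139, 69, 19), the brown used for chocolate
print(to_rgb('#8B4513'))      # the same color as fractions of 255
```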

Below we picked custom colors for each of the flavors.

In [0]:
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)
colors = ('#8B4513', '#FFF8DC', '#93C572', '#E67F0D', '#D53032')

plt.pie(
    votes,
    labels=flavors,
    autopct='%1.1f%%',
    colors=colors,
)
plt.show()

Excellent! Now the colors have a closer relationship to the data that they represent. We won't always have such a close relationship, but you might find yourself in situations where you need to use a fixed color palette for some other reason, such as corporate branding. The `colors` argument is great for that.

Now let's imagine we're preparing this chart for a presentation, and we want to call out one of the flavors in particular. Maybe mango is new to market, and we want to call out how much popularity it has already captured.

To do this we can use the `explode` argument. This allows us to set an offset for each slice of the pie from the center. In the example below we pushed mango out by `0.1` while keeping all of the rest of the pieces tied to the center.

In [0]:
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)
colors = ('#8B4513', '#FFF8DC', '#93C572', '#E67F0D', '#D53032')
explode = (0, 0, 0, 0.1, 0)

plt.pie(
    votes,
    labels=flavors,
    autopct='%1.1f%%',
    colors=colors,
    explode=explode,
)
plt.show()

We now have mango pulled out a bit from the pie, so that we can highlight its impact.

Notice that we could set offsets for every piece of the chart, and those offsets can be arbitrary numbers. Play around a bit with different and multiple offsets. Do negative numbers work?
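Typing the offsets by hand ties the highlight to a position in the tuple. A small sketch that builds `explode` from the flavor names instead, so the highlighted slice follows its label even if the order changes; `highlight` is just an illustrative variable name.

```python
flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')

# Offset only the slice we want to call out; every other slice stays at 0.
highlight = 'Mango'
explode = tuple(0.1 if flavor == highlight else 0 for flavor in flavors)
print(explode)  # (0, 0, 0, 0.1, 0)
```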

Our pie chart looks pretty nice now, but it is very flat. We can give it a bit of a three-dimensional look by adding a shadow with the `shadow` argument.

In [0]:
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)
colors = ('#8B4513', '#FFF8DC', '#93C572', '#E67F0D', '#D53032')
explode = (0, 0, 0, 0.1, 0)

plt.pie(
    votes,
    labels=flavors,
    autopct='%1.1f%%',
    colors=colors,
    explode=explode,
    shadow=True,
)
plt.show()

To wrap it up, we can add a title using `plt.title()`. Notice that this is not an argument to `plt.pie()`, but a separate function call on `plt`.

In [0]:
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)
colors = ('#8B4513', '#FFF8DC', '#93C572', '#E67F0D', '#D53032')
explode = (0, 0, 0, 0.1, 0)

plt.title('Favorite Ice Cream Flavors')
plt.pie(
    votes,
    labels=flavors,
    autopct='%1.1f%%',
    colors=colors,
    explode=explode,
    shadow=True,
)
plt.show()

We now have a nice pie chart that shows all of the favorite ice cream flavors in our poll!

Remember pie charts are good for showing how distinct classes of data (in this case, ice cream flavors) contribute to the whole.

Pie charts also work best when there are only a few classes. Imagine if we had 100 flavors of ice cream: the slices for the less popular flavors would be too thin to read meaningfully.
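The pie examples above use pyplot's state-based functions such as `plt.title()`. Matplotlib also has an object-oriented interface, where you create a figure and axes explicitly and call methods on the axes; `ax.set_title()` is the axes counterpart of `plt.title()`. A sketch of the final pie chart written that way (the `matplotlib.use('Agg')` line is only there so it runs outside a notebook):

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt

flavors = ('Chocolate', 'Vanilla', 'Pistachio', 'Mango', 'Strawberry')
votes = (12, 11, 4, 8, 7)
colors = ('#8B4513', '#FFF8DC', '#93C572', '#E67F0D', '#D53032')
explode = (0, 0, 0, 0.1, 0)

fig, ax = plt.subplots()
# With autopct set, ax.pie() returns the wedges, labels, and percentage texts.
wedges, texts, autotexts = ax.pie(
    votes,
    labels=flavors,
    autopct='%1.1f%%',
    colors=colors,
    explode=explode,
    shadow=True,
)
ax.set_title('Favorite Ice Cream Flavors')
plt.show()
```

Holding on to `ax` makes it easier to add several charts to one figure later, since each chart gets its own axes object.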

## Bar Charts

Bar charts are another powerful tool for comparing categorical data. Similar to pie charts, they can be used to compare categories of data against each other.

However, pie charts are also good for seeing how one category of data compares against the whole. Bar charts aren't very good for this.

Also, bar charts can meaningfully display more categories of data than pie charts.

Let's start by taking a look at a bar chart showing the population of each country and territory in South America.

To do this we will use Matplotlib again. This time we will use the `bar()` method.

`bar()` has two required arguments. The first argument contains the x-coordinates of the data. Since we want to plot country names on the x-axis, there aren't any natural x-coordinates.

In cases like this we can use NumPy's [`arange()`](https://docs.scipy.org/doc/numpy/reference/generated/numpy.arange.html) function to create evenly spaced numbers. `np.arange(len(countries))` returns an array of whole numbers from 0 up to, but not including, the length of the data, so in this example it runs from 0 to 13 (`len(countries) - 1`).

The next argument is the numeric data to plot. In this example we plot the population data.
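To see exactly which x-coordinates `np.arange()` produces for the 14 entries, you can print them:

```python
import numpy as np

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

# One evenly spaced position per bar: 0, 1, ..., len(countries) - 1.
x_coords = np.arange(len(countries))
print(x_coords.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
```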

In [0]:
import matplotlib.pyplot as plt
import numpy as np

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

x_coords = np.arange(len(countries))
plt.bar(x_coords, populations)
plt.show()

You can see in the chart above that the x-axis labels aren't meaningful. We can fix this by passing a `tick_label` argument to `bar()`. Since the country names are relatively long, it is also useful to rotate the labels by 90 degrees so that they are more readable. We do this with a `plt.xticks(rotation=90)` call.

In [0]:
import matplotlib.pyplot as plt
import numpy as np

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

x_coords = np.arange(len(countries))
plt.bar(x_coords, populations, tick_label=countries)
plt.xticks(rotation=90)  # Rotate the x-axis labels so they don't overlap
plt.show()

We can add labels to bar charts to help make the charts more readable. In the example below we add a y-label using the `ylabel()` method and a chart title using the `title()` method.

In [0]:
import matplotlib.pyplot as plt
import numpy as np

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

x_coords = np.arange(len(countries))
# Divide by one million so the values match the 'Population (Millions)' label.
plt.bar(x_coords, np.array(populations) / 1e6, tick_label=countries)
plt.xticks(rotation=90)
plt.ylabel('Population (Millions)')
plt.title('South American Populations')
plt.show()

The chart is looking pretty good.

But what if you were asked the question: *What is the second most populous country in South America?*

You would likely have to stare a bit at Argentina and Colombia.

This is because the data is sorted alphabetically, which isn't the most helpful order for answering questions about the data. Matplotlib doesn't sort data for you, so instead we can load the data into a Pandas DataFrame and sort it before plotting.

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

df = pd.DataFrame({
    'Country': countries,
    'Population': populations,
})
df.sort_values(by='Population', inplace=True)

x_coords = np.arange(len(df))
# Divide by one million so the values match the 'Population (Millions)' label.
plt.bar(x_coords, df['Population'] / 1e6, tick_label=df['Country'])
plt.xticks(rotation=90)
plt.ylabel('Population (Millions)')
plt.title('South American Populations')
plt.show()

Now we can easily see that Colombia is the second most populous country.
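Pandas isn't the only way to sort. A minimal sketch using only built-in Python: `sorted()` with a `key` function orders the (country, population) pairs, and `zip(*pairs)` splits them back into two tuples that could be passed to `plt.bar()` in place of the DataFrame columns.

```python
countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')
populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

# Pair each country with its population, then sort the pairs by population.
pairs = sorted(zip(countries, populations), key=lambda pair: pair[1])
sorted_countries, sorted_populations = zip(*pairs)
print(sorted_countries[-2:])  # ('Colombia', 'Brazil')
```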

If we wanted to call that out, we could pass a list of bar colors to the `bar()` method.

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

df = pd.DataFrame({
    'Country': countries,
    'Population': populations,
})
df.sort_values(by='Population', inplace=True)

x_coords = np.arange(len(df))
colors = ['#0000FF' for _ in range(len(df))]
colors[-2] = '#FF0000'
# Divide by one million so the values match the 'Population (Millions)' label.
plt.bar(x_coords, df['Population'] / 1e6, tick_label=df['Country'], color=colors)
plt.xticks(rotation=90)
plt.ylabel('Population (Millions)')
plt.title('South American Populations')
plt.show()
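Setting `colors[-2]` works only because we know Colombia lands second-to-last after sorting. A sketch that picks each bar's color from the country name instead, so the highlight survives a different sort order or updated data:

```python
import pandas as pd

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')
populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

df = pd.DataFrame({'Country': countries, 'Population': populations})
df.sort_values(by='Population', inplace=True)

# Red for the country we want to call out, blue for everything else.
colors = ['#FF0000' if country == 'Colombia' else '#0000FF'
          for country in df['Country']]
```

The resulting list can be passed to `plt.bar()` through `color=colors` exactly as in the cell above.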

We can also make the chart larger with the `figure()` function. Its `figsize=` argument is a `(width, height)` tuple in inches. Call `plt.figure()` before `plt.bar()` so that the bars are drawn on the new, larger figure.

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')

populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

df = pd.DataFrame({
    'Country': countries,
    'Population': populations,
})
df.sort_values(by='Population', inplace=True)

x_coords = np.arange(len(df))
colors = ['#0000FF' for _ in range(len(df))]
colors[-2] = '#FF0000'
plt.figure(figsize=(20,10))
# Divide by one million so the values match the 'Population (Millions)' label.
plt.bar(x_coords, df['Population'] / 1e6, tick_label=df['Country'], color=colors)
plt.xticks(rotation=90)
plt.ylabel('Population (Millions)')
plt.title('South American Populations')
plt.show()
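Rotating the tick labels is one way to fit long country names. Another option is a horizontal bar chart with `barh()`, which keeps the names horizontal and readable. A sketch below; the figure size is an arbitrary choice, and the `matplotlib.use('Agg')` line only lets it run outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt
import pandas as pd

countries = ('Argentina', 'Bolivia', 'Brazil', 'Chile', 'Colombia', 'Ecuador',
             'Falkland Islands', 'French Guiana', 'Guyana', 'Paraguay', 'Peru',
             'Suriname', 'Uruguay', 'Venezuela')
populations = (45076704, 11626410, 212162757, 19109629, 50819826, 17579085,
               3481, 287750, 785409, 7107305, 32880332, 585169, 3470475,
               28258770)

df = pd.DataFrame({'Country': countries, 'Population': populations})
df.sort_values(by='Population', inplace=True)

fig, ax = plt.subplots(figsize=(10, 8))
# barh() draws the first row at the bottom, so after an ascending sort
# the most populous country ends up at the top.
ax.barh(range(len(df)), df['Population'] / 1e6, tick_label=df['Country'])
ax.set_xlabel('Population (Millions)')
ax.set_title('South American Populations')
plt.show()
```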

## Line Graphs

Line graphs are another useful visualization. While pie charts and bar charts are useful in showing how classes of data relate to each other, line graphs are more useful for showing how data progresses over some period. For example, line graphs can be useful in charting temperature over time, stock prices over time, weight by day, or any other continuous metric.

We'll create a very simple line graph below. The data we have is the temperature in degrees Celsius at two-hour intervals over a single day at a single location.

You can see that to create the line graph we use the [`plt.plot()`](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.plot.html) method.

In [0]:
import matplotlib.pyplot as plt

temperature_c = [2, 1, 0, 0, 1, 5, 8, 9, 8, 5, 3, 2, 2]
hour = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]

plt.plot(
  hour,
  temperature_c
)
plt.show()

We can see that the temperature starts at around 2 degrees Celsius at midnight, drops to freezing between 04:00 and 06:00, peaks at around 9 degrees Celsius at 14:00, and then falls back to about 2 degrees by midnight.
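Reading values off a chart is approximate, so it can be worth checking them against the data. A small sketch in plain Python:

```python
temperature_c = [2, 1, 0, 0, 1, 5, 8, 9, 8, 5, 3, 2, 2]
hour = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]

# Pairing (temperature, hour) makes max() and min() compare temperatures
# first; ties are broken by the earlier hour.
warmest = max(zip(temperature_c, hour))
coldest = min(zip(temperature_c, hour))
print(warmest)  # (9, 14)
print(coldest)  # (0, 4)
```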

We can, of course, add the standard chart elements of `title()`, `ylabel()`, and `xlabel()`.

In [0]:
import matplotlib.pyplot as plt

temperature_c = [2, 1, 0, 0, 1, 5, 8, 9, 8, 5, 3, 2, 2]
hour = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]

plt.plot(
  hour,
  temperature_c,
)
plt.title('Temperatures in Kirkland, WA, USA on 2 Feb 2020')
plt.ylabel('Temperature Celsius')
plt.xlabel('Hour')
plt.show()

We can also add markers at each of the data points. In the example below we add a dot marker at each data point using the `marker='o'` argument.

In [0]:
import matplotlib.pyplot as plt

temperature_c = [2, 1, 0, 0, 1, 5, 8, 9, 8, 5, 3, 2, 2]
hour = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]

plt.plot(
  hour,
  temperature_c,
  marker='o',
)
plt.title('Temperatures in Kirkland, WA, USA on 2 Feb 2020')
plt.ylabel('Temperature Celsius')
plt.xlabel('Hour')
plt.show()

We can even have multiple lines on the same chart. Say, for instance, that we wanted to illustrate actual and predicted temperature values. We can just call `plot()` twice, once with each set of values.

Notice that in the second call, we use another argument to `plot()`, `linestyle='--'`. This causes the predicted line to look like a dashed-line while the actual values stay solid.

You can find all of the many line formatting options at the [Matplotlib `pyplot.plot()` documentation](https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.plot.html).

In [0]:
import matplotlib.pyplot as plt

temperature_c_actual = [2, 1, 0, 0, 1, 5, 8, 9, 8, 5, 3, 2, 2]
temperature_c_predicted = [2, 2, 1, 0, 1, 3, 7, 8, 8, 6, 4, 3, 3]
hour = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]

plt.plot(hour, temperature_c_actual, label='Actual')
plt.plot(hour, temperature_c_predicted, linestyle='--', label='Predicted')
plt.legend()  # Show which line is actual and which is predicted
plt.title('Temperatures in Kirkland, WA, USA on 2 Feb 2020')
plt.ylabel('Temperature Celsius')
plt.xlabel('Hour')
plt.show()
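`plot()` also accepts a short format string that sets the marker and line style in one argument, for example `'o-'` for dots joined by a solid line and `'x--'` for crosses joined by a dashed line. A sketch using the axes interface, with `label=` arguments feeding the legend (the `matplotlib.use('Agg')` line only lets it run outside a notebook):

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt

temperature_c_actual = [2, 1, 0, 0, 1, 5, 8, 9, 8, 5, 3, 2, 2]
temperature_c_predicted = [2, 2, 1, 0, 1, 3, 7, 8, 8, 6, 4, 3, 3]
hour = [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]

fig, ax = plt.subplots()
# 'o-': circle markers joined by a solid line.
actual, = ax.plot(hour, temperature_c_actual, 'o-', label='Actual')
# 'x--': cross markers joined by a dashed line.
predicted, = ax.plot(hour, temperature_c_predicted, 'x--', label='Predicted')
ax.set_title('Temperatures in Kirkland, WA, USA on 2 Feb 2020')
ax.set_ylabel('Temperature Celsius')
ax.set_xlabel('Hour')
ax.legend()
plt.show()
```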

## Scatter Plots

Scatter plots work great for data with two numeric components. They provide a great way to get a quick look at your data to see if you notice any patterns or outliers.

In the example below, we have data on gross domestic product (GDP) per capita and population for countries with a population of more than one hundred million. GDP is the total value of goods and services produced by a country over the course of a year; GDP per capita divides that total by the country's population.

We then use [`plt.scatter()`](https://matplotlib.org/3.1.3/api/_as_gen/matplotlib.pyplot.scatter.html) to create a scatter plot of population and GDP.


In [0]:
import matplotlib.pyplot as plt

country = ['Bangladesh', 'Brazil', 'China', 'India', 'Indonesia', 'Japan', 
           'Mexico', 'Nigeria', 'Pakistan', 'Russia', 'United States']
gdp = [2421, 13418, 9475, 4353, 7378, 35477, 14276, 5087, 4133, 20255, 49267]
population = [148692131, 194946470, 1341335152, 1224614327, 239870937,
              126535920, 113423047, 158423182, 173593383, 142958164, 310383948]

plt.scatter(population, gdp)
plt.show()

The scatter plot is interesting because we can gather some insights about our data. We can see that there are two population outliers (China and India) and one, arguably two, GDP outliers (the United States and, to a lesser extent, Japan).

This information can help us decide if we need to correct for or exclude the outliers in our analysis.
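To see which countries the outliers are, you can label each point with `annotate()`. A sketch below; the font size is an arbitrary choice to limit overlap, and the `matplotlib.use('Agg')` line only lets it run outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt

country = ['Bangladesh', 'Brazil', 'China', 'India', 'Indonesia', 'Japan',
           'Mexico', 'Nigeria', 'Pakistan', 'Russia', 'United States']
gdp = [2421, 13418, 9475, 4353, 7378, 35477, 14276, 5087, 4133, 20255, 49267]
population = [148692131, 194946470, 1341335152, 1224614327, 239870937,
              126535920, 113423047, 158423182, 173593383, 142958164, 310383948]

fig, ax = plt.subplots()
ax.scatter(population, gdp)
# Write each country's name next to its point.
for name, x, y in zip(country, population, gdp):
    ax.annotate(name, (x, y), fontsize=8)
ax.set_xlabel('Population')
ax.set_ylabel('GDP per capita')
plt.show()
```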

We can also add more than one set of data to a scatter plot. In the example below, we plot the diameters and weights of a batch of lemons and limes to see if we can determine a pattern.

In [0]:
import matplotlib.pyplot as plt

lemon_diameter = [6.44, 6.87, 7.7, 8.85, 8.15, 9.96, 7.21, 10.04, 10.2, 11.06]
lemon_weight = [112.05, 114.58, 116.71, 117.4, 128.93, 
                132.93, 138.92, 145.98, 148.44, 152.81]

lime_diameter = [6.15, 7.0, 7.0, 7.69, 7.95, 7.51, 10.46, 8.72, 9.53, 10.09]
lime_weight = [112.76, 125.16, 131.36, 132.41, 138.08,
               142.55, 156.86, 158.67, 163.28, 166.74]

plt.scatter(lemon_diameter, lemon_weight)
plt.scatter(lime_diameter, lime_weight)
plt.show()

Looking at our sample, there isn't a very clear pattern. However, one of the citruses does seem to be a little heavier per centimeter of diameter.

But which one?

It is really difficult to tell. Let's clean this chart up a bit.

First we'll add a title using `plt.title()`, an x-label using `plt.xlabel()`, and a y-label using `plt.ylabel()`.

In [0]:
import matplotlib.pyplot as plt

lemon_diameter = [6.44, 6.87, 7.7, 8.85, 8.15, 9.96, 7.21, 10.04, 10.2, 11.06]
lemon_weight = [112.05, 114.58, 116.71, 117.4, 128.93, 
                132.93, 138.92, 145.98, 148.44, 152.81]

lime_diameter = [6.15, 7.0, 7.0, 7.69, 7.95, 7.51, 10.46, 8.72, 9.53, 10.09]
lime_weight = [112.76, 125.16, 131.36, 132.41, 138.08,
               142.55, 156.86, 158.67, 163.28, 166.74]

plt.title('Lemons vs. Limes')
plt.xlabel('Diameter (cm)')
plt.ylabel('Weight (g)')
plt.scatter(lemon_diameter, lemon_weight)
plt.scatter(lime_diameter, lime_weight)
plt.show()

Now we can add some color and a legend to make our scatter plot a little more intuitive.

We add color by passing the `color=` argument to `plt.scatter()`. In this case we just set the lemon points to be yellow using `color='y'` and the lime points to be green using `color='g'`.

To add the legend we call `plt.legend()` and pass it a list with one label per `scatter()` call, in the same order as the calls.

In [0]:
import matplotlib.pyplot as plt

lemon_diameter = [6.44, 6.87, 7.7, 8.85, 8.15, 9.96, 7.21, 10.04, 10.2, 11.06]
lemon_weight = [112.05, 114.58, 116.71, 117.4, 128.93, 
                132.93, 138.92, 145.98, 148.44, 152.81]

lime_diameter = [6.15, 7.0, 7.0, 7.69, 7.95, 7.51, 10.46, 8.72, 9.53, 10.09]
lime_weight = [112.76, 125.16, 131.36, 132.41, 138.08,
               142.55, 156.86, 158.67, 163.28, 166.74]

plt.title('Lemons vs. Limes')
plt.xlabel('Diameter (cm)')
plt.ylabel('Weight (g)')
plt.scatter(lemon_diameter, lemon_weight, color='y')
plt.scatter(lime_diameter, lime_weight, color='g')
plt.legend(['lemons', 'limes'])
plt.show()
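Passing a list to `plt.legend()` depends on the labels being in the same order as the `scatter()` calls. An alternative is to attach a `label=` to each `scatter()` call and then call `legend()` with no arguments, as in this sketch (the `matplotlib.use('Agg')` line only lets it run outside a notebook):

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt

lemon_diameter = [6.44, 6.87, 7.7, 8.85, 8.15, 9.96, 7.21, 10.04, 10.2, 11.06]
lemon_weight = [112.05, 114.58, 116.71, 117.4, 128.93,
                132.93, 138.92, 145.98, 148.44, 152.81]
lime_diameter = [6.15, 7.0, 7.0, 7.69, 7.95, 7.51, 10.46, 8.72, 9.53, 10.09]
lime_weight = [112.76, 125.16, 131.36, 132.41, 138.08,
               142.55, 156.86, 158.67, 163.28, 166.74]

fig, ax = plt.subplots()
ax.set_title('Lemons vs. Limes')
ax.set_xlabel('Diameter (cm)')
ax.set_ylabel('Weight (g)')
ax.scatter(lemon_diameter, lemon_weight, color='y', label='lemons')
ax.scatter(lime_diameter, lime_weight, color='g', label='limes')
ax.legend()  # Picks up the label= values from the scatter() calls
plt.show()
```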

Now we can see more clearly that our limes tend to be a little heavier per centimeter of diameter than our lemons.
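If you want a number to back up that impression, one rough check is to average each fruit's weight-to-diameter ratio. This ratio is our own choice of metric for illustration, not a standard measure of citrus density.

```python
from statistics import mean

lemon_diameter = [6.44, 6.87, 7.7, 8.85, 8.15, 9.96, 7.21, 10.04, 10.2, 11.06]
lemon_weight = [112.05, 114.58, 116.71, 117.4, 128.93,
                132.93, 138.92, 145.98, 148.44, 152.81]
lime_diameter = [6.15, 7.0, 7.0, 7.69, 7.95, 7.51, 10.46, 8.72, 9.53, 10.09]
lime_weight = [112.76, 125.16, 131.36, 132.41, 138.08,
               142.55, 156.86, 158.67, 163.28, 166.74]

# Grams of weight per centimeter of diameter, averaged over each batch.
lemon_ratio = mean(w / d for w, d in zip(lemon_weight, lemon_diameter))
lime_ratio = mean(w / d for w, d in zip(lime_weight, lime_diameter))
print(f'lemons: {lemon_ratio:.1f} g/cm, limes: {lime_ratio:.1f} g/cm')
```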

## Heatmaps

Heatmaps are a type of visualization that uses color coding to represent the relative value/density of data across a surface. Often this is a tabular chart, but it doesn't have to be limited to that.

For tabular data, there are labels on the x and y axes. The values at the intersections of those labels map to colors.

These colors can then be used to visually inspect the data to find clusters of similar values and detect trends in the data.
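Roughly speaking, a heatmap scales each value into the range 0 to 1 and then looks that fraction up in a colormap. A sketch of that mapping with Matplotlib's `Normalize` and the `coolwarm` colormap; the 0 to 40 range is an assumption chosen to cover the temperatures used below.

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# Map 0-40 degrees Celsius onto 0.0-1.0.
norm = Normalize(vmin=0, vmax=40)
cmap = plt.get_cmap('coolwarm')

for temp in (0, 20, 40):
    rgba = cmap(norm(temp))  # (red, green, blue, alpha), each between 0 and 1
    print(temp, float(norm(temp)), rgba)
```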

Let's start with a sample dataset that will literally map heat. We will be working with data about the average high temperatures each month for the 12 largest cities in the world.

To create this heatmap we will use a new library, [Seaborn](https://seaborn.pydata.org/). Seaborn is a visualization library that is built on top of [Matplotlib](https://matplotlib.org/). It provides a higher-level interface and can create more attractive charts with less effort. Any of the visualizations that we have seen in this lab so far could have also been created in Seaborn.

You'll see both Matplotlib and Seaborn in use in real data analytics projects, so we want to introduce you to both of them in this lab.

Anyway, let's build a heatmap!

In the code below, we first import seaborn. We then create lists containing the names of the 12 largest cities in the world and the 12 months in the year.

Next we assign a list-of-lists to the `temperatures` variable. Each row in the list represents a city. Each column is a month. The values are the average high temperatures for the city for the month.

Finally we call `sns.heatmap()` to create the heatmap. We pass in the temperature data, the city names as y-labels, and the month abbreviations as x-labels.

In [0]:
import seaborn as sns

cities = ['Tokyo', 'Delhi', 'Shanghai', 'Sao Paulo', 'Mumbai', 'Mexico City',
          'Beijing', 'Osaka', 'Cairo', 'New York', 'Dhaka', 'Karachi']

months = ['J', 'F', 'M', 'A', 'M', 'J', 'J', 'A', 'S', 'O', 'N', 'D']

temperatures = [
  [10, 10, 14, 19, 23, 26, 30, 31, 27, 22, 17, 12], # Tokyo
  [20, 24, 30, 37, 40, 39, 35, 34, 34, 33, 28, 22], # Delhi
  [ 8, 10, 14, 20, 24, 28, 32, 32, 27, 23, 17, 11], # Shanghai
  [29, 29, 28, 27, 23, 23, 23, 25, 25, 26, 27, 28], # Sao Paulo
  [31, 32, 33, 33, 34, 32, 30, 30, 31, 34, 34, 32], # Mumbai
  [22, 24, 26, 27, 27, 26, 24, 25, 24, 24, 23, 23], # Mexico City
  [ 2,  5, 12, 21, 27, 30, 31, 30, 26, 19, 10,  4], # Beijing
  [ 9, 10, 14, 20, 25, 28, 32, 33, 29, 23, 18, 12], # Osaka
  [19, 21, 24, 29, 33, 35, 35, 35, 34, 30, 25, 21], # Cairo
  [ 4,  6, 11, 18, 22, 27, 29, 29, 25, 18, 13,  7], # New York
  [25, 29, 32, 33, 33, 32, 32, 32, 32, 31, 29, 26], # Dhaka
  [26, 28, 32, 35, 36, 35, 33, 32, 33, 35, 32, 28], # Karachi
]

sns.heatmap(temperatures, yticklabels=cities, xticklabels=months)

We can see the data in the resultant chart. But how do we interpret it?

It is actually fairly difficult to make sense of the data. The left and right edges of the chart might contain somewhat darker colors, which map to cooler temperatures, but even that is hard to tell.

If you think about it, this makes sense: the cities are sorted by population, largest to smallest, which has little to do with climate. Let's sort them by latitude instead, from north to south.

In [0]:
import seaborn as sns

cities = ['New York', 'Beijing', 'Tokyo', 'Osaka', 'Shanghai', 'Cairo', 'Delhi',
          'Karachi', 'Dhaka', 'Mexico City', 'Mumbai', 'Sao Paulo']

months = ['J', 'F', 'M', 'A', 'M', 'J', 'J', 'A', 'S', 'O', 'N', 'D']

temperatures = [
  [ 4,  6, 11, 18, 22, 27, 29, 29, 25, 18, 13,  7], # New York
  [ 2,  5, 12, 21, 27, 30, 31, 30, 26, 19, 10,  4], # Beijing
  [10, 10, 14, 19, 23, 26, 30, 31, 27, 22, 17, 12], # Tokyo
  [ 9, 10, 14, 20, 25, 28, 32, 33, 29, 23, 18, 12], # Osaka
  [ 8, 10, 14, 20, 24, 28, 32, 32, 27, 23, 17, 11], # Shanghai
  [19, 21, 24, 29, 33, 35, 35, 35, 34, 30, 25, 21], # Cairo
  [20, 24, 30, 37, 40, 39, 35, 34, 34, 33, 28, 22], # Delhi
  [26, 28, 32, 35, 36, 35, 33, 32, 33, 35, 32, 28], # Karachi
  [25, 29, 32, 33, 33, 32, 32, 32, 32, 31, 29, 26], # Dhaka
  [22, 24, 26, 27, 27, 26, 24, 25, 24, 24, 23, 23], # Mexico City
  [31, 32, 33, 33, 34, 32, 30, 30, 31, 34, 34, 32], # Mumbai
  [29, 29, 28, 27, 23, 23, 23, 25, 25, 26, 27, 28], # Sao Paulo
]

sns.heatmap(temperatures, yticklabels=cities, xticklabels=months)

This makes much more sense. We can see that the cities at higher latitudes are colder from September through March and that the temperature tends to rise as the latitude gets smaller.

Also notice that Sao Paulo, the only southern-hemisphere city in the list, shows the opposite pattern: its coolest months fall mid-year (May through July) and its warmest fall around December through February.
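Reordering rows by hand is error-prone. A sketch that loads the data into a Pandas DataFrame and reorders the rows programmatically. We have no latitude column, so this example sorts by each city's mean monthly high instead, which gives a similar but not identical order. When given a DataFrame, `sns.heatmap()` takes its tick labels from the index and columns; the three-letter month names are our own choice to keep the column names unique.

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Cities in their original order (largest population first).
cities = ['Tokyo', 'Delhi', 'Shanghai', 'Sao Paulo', 'Mumbai', 'Mexico City',
          'Beijing', 'Osaka', 'Cairo', 'New York', 'Dhaka', 'Karachi']
month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
temperatures = [
    [10, 10, 14, 19, 23, 26, 30, 31, 27, 22, 17, 12],  # Tokyo
    [20, 24, 30, 37, 40, 39, 35, 34, 34, 33, 28, 22],  # Delhi
    [ 8, 10, 14, 20, 24, 28, 32, 32, 27, 23, 17, 11],  # Shanghai
    [29, 29, 28, 27, 23, 23, 23, 25, 25, 26, 27, 28],  # Sao Paulo
    [31, 32, 33, 33, 34, 32, 30, 30, 31, 34, 34, 32],  # Mumbai
    [22, 24, 26, 27, 27, 26, 24, 25, 24, 24, 23, 23],  # Mexico City
    [ 2,  5, 12, 21, 27, 30, 31, 30, 26, 19, 10,  4],  # Beijing
    [ 9, 10, 14, 20, 25, 28, 32, 33, 29, 23, 18, 12],  # Osaka
    [19, 21, 24, 29, 33, 35, 35, 35, 34, 30, 25, 21],  # Cairo
    [ 4,  6, 11, 18, 22, 27, 29, 29, 25, 18, 13,  7],  # New York
    [25, 29, 32, 33, 33, 32, 32, 32, 32, 31, 29, 26],  # Dhaka
    [26, 28, 32, 35, 36, 35, 33, 32, 33, 35, 32, 28],  # Karachi
]

df = pd.DataFrame(temperatures, index=cities, columns=month_names)

# Reorder the rows by each city's mean monthly high, coolest first.
order = df.mean(axis=1).sort_values().index
by_mean = df.loc[order]

fig, ax = plt.subplots()
sns.heatmap(by_mean, cmap='coolwarm', ax=ax)
plt.show()
```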

Admittedly, the default color scheme is difficult to read. We can change it with the `cmap=` argument, which accepts the name of a Matplotlib colormap, a colormap object, or a list of colors. You can find the built-in colormaps in the [Matplotlib colormap documentation](https://matplotlib.org/3.1.0/tutorials/colors/colormaps.html).

In [0]:
import seaborn as sns

cities = ['New York', 'Beijing', 'Tokyo', 'Osaka', 'Shanghai', 'Cairo', 'Delhi',
          'Karachi', 'Dhaka', 'Mexico City', 'Mumbai', 'Sao Paulo']

months = ['J', 'F', 'M', 'A', 'M', 'J', 'J', 'A', 'S', 'O', 'N', 'D']

temperatures = [
  [ 4,  6, 11, 18, 22, 27, 29, 29, 25, 18, 13,  7], # New York
  [ 2,  5, 12, 21, 27, 30, 31, 30, 26, 19, 10,  4], # Beijing
  [10, 10, 14, 19, 23, 26, 30, 31, 27, 22, 17, 12], # Tokyo
  [ 9, 10, 14, 20, 25, 28, 32, 33, 29, 23, 18, 12], # Osaka
  [ 8, 10, 14, 20, 24, 28, 32, 32, 27, 23, 17, 11], # Shanghai
  [19, 21, 24, 29, 33, 35, 35, 35, 34, 30, 25, 21], # Cairo
  [20, 24, 30, 37, 40, 39, 35, 34, 34, 33, 28, 22], # Delhi
  [26, 28, 32, 35, 36, 35, 33, 32, 33, 35, 32, 28], # Karachi
  [25, 29, 32, 33, 33, 32, 32, 32, 32, 31, 29, 26], # Dhaka
  [22, 24, 26, 27, 27, 26, 24, 25, 24, 24, 23, 23], # Mexico City
  [31, 32, 33, 33, 34, 32, 30, 30, 31, 34, 34, 32], # Mumbai
  [29, 29, 28, 27, 23, 23, 23, 25, 25, 26, 27, 28], # Sao Paulo
]

sns.heatmap(
    temperatures,
    yticklabels=cities,
    xticklabels=months,
    cmap='coolwarm',
)

There are many more options available. Check out the [heatmap documentation](https://seaborn.pydata.org/generated/seaborn.heatmap.html) for more.
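For example, `annot=True` writes each cell's value on the heatmap, which helps when exact numbers matter, and `annot_kws` passes font settings to those value labels. A sketch applying both to the latitude-sorted data; the figure size and font size are arbitrary choices, and the `matplotlib.use('Agg')` line only lets it run outside a notebook.

```python
import matplotlib
matplotlib.use('Agg')  # Off-screen rendering; not needed in Colab.
import matplotlib.pyplot as plt
import seaborn as sns

cities = ['New York', 'Beijing', 'Tokyo', 'Osaka', 'Shanghai', 'Cairo', 'Delhi',
          'Karachi', 'Dhaka', 'Mexico City', 'Mumbai', 'Sao Paulo']
months = ['J', 'F', 'M', 'A', 'M', 'J', 'J', 'A', 'S', 'O', 'N', 'D']
temperatures = [
    [ 4,  6, 11, 18, 22, 27, 29, 29, 25, 18, 13,  7],  # New York
    [ 2,  5, 12, 21, 27, 30, 31, 30, 26, 19, 10,  4],  # Beijing
    [10, 10, 14, 19, 23, 26, 30, 31, 27, 22, 17, 12],  # Tokyo
    [ 9, 10, 14, 20, 25, 28, 32, 33, 29, 23, 18, 12],  # Osaka
    [ 8, 10, 14, 20, 24, 28, 32, 32, 27, 23, 17, 11],  # Shanghai
    [19, 21, 24, 29, 33, 35, 35, 35, 34, 30, 25, 21],  # Cairo
    [20, 24, 30, 37, 40, 39, 35, 34, 34, 33, 28, 22],  # Delhi
    [26, 28, 32, 35, 36, 35, 33, 32, 33, 35, 32, 28],  # Karachi
    [25, 29, 32, 33, 33, 32, 32, 32, 32, 31, 29, 26],  # Dhaka
    [22, 24, 26, 27, 27, 26, 24, 25, 24, 24, 23, 23],  # Mexico City
    [31, 32, 33, 33, 34, 32, 30, 30, 31, 34, 34, 32],  # Mumbai
    [29, 29, 28, 27, 23, 23, 23, 25, 25, 26, 27, 28],  # Sao Paulo
]

fig, ax = plt.subplots(figsize=(10, 6))
sns.heatmap(
    temperatures,
    yticklabels=cities,
    xticklabels=months,
    cmap='coolwarm',
    annot=True,             # Write each cell's value on top of its color
    annot_kws={'size': 7},  # Font settings for the value labels
    ax=ax,
)
plt.show()
```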

# Exercises: Which Visualization?

There are five exercises in this Colab. Each exercise contains a dataset and a query about that dataset. Using Matplotlib or Seaborn, create a visualization that allows any viewer to easily answer the question.

For each exercise, choose one of the following visualizations:
* Pie chart
* Bar chart
* Line graph
* Scatter plot
* Heatmap

You may only use each visualization one time. When you are done, you should have one of each of the five types of visualizations.

In some cases, there may be more than one "right answer", as there are often multiple good ways to visualize data. Use your judgement to choose which one you think is best for each question.

Add titles, labels, color coding, and other visual aids when you can to help the viewer interpret the charts.

## Exercise 1: Bitcoin



We have a list of Bitcoin prices recorded at the end of each week (Sunday) in 2018 and 2019. Create a visualization that allows you to answer the question: *Which year, 2018 or 2019, tended to provide better returns for Bitcoin holders?*

**Student Solution**

In [0]:
prices = [14292.2, 12858.9, 11467.5, 9241.1, 8559.6, 11073.5, 9704.3, 11402.3,
          8762.0, 7874.9, 8547.4, 6938.2, 6905.7, 8004.4, 8923.1, 9352.4,
          9853.5, 8459.5, 8245.1, 7361.3, 7646.6, 7515.8, 6505.8, 6167.3, 
          6398.9, 6765.5, 6254.8, 7408.7, 8234.1, 7014.3, 6231.6, 6379.1,
          6734.8, 7189.6, 6184.3, 6519.0, 6729.6, 6603.9, 6596.3, 6321.7,
          6572.2, 6494.2, 6386.2, 6427.1, 5621.8, 3920.4, 4196.2, 3430.4,
          3228.7, 3964.4, 3706.8, 3785.4, 3597.2, 3677.8, 3570.9, 3502.5,
          3661.4, 3616.8, 4120.4, 3823.1, 3944.3, 4006.4, 4002.5, 4111.8,
          5046.2, 5051.8, 5290.2, 5265.9, 5830.9, 7190.3, 7262.6, 8027.4,
          8545.7, 7901.4, 8812.5, 10721.7, 11906.5, 11268.0, 11364.9, 10826.7,
          9492.1, 10815.7, 11314.5, 10218.1, 10131.0, 9594.4, 10461.1, 10337.3,
          9993.0, 8208.5, 8127.3, 8304.4, 7957.3, 9230.6, 9300.6, 8804.5,
          8497.3, 7324.1, 7546.6, 7510.9, 7080.8, 7156.2, 7321.5, 7376.8]
          
# Your Solution Goes Here

**Explanation**

Which chart did you choose and why?

> *Your solution goes here.*

Which year seemed to be better for Bitcoin holders?

> *Your solution goes here.*

---

### Answer Key

In [0]:
import matplotlib.pyplot as plt

prices = [14292.2, 12858.9, 11467.5, 9241.1, 8559.6, 11073.5, 9704.3, 11402.3,
          8762.0, 7874.9, 8547.4, 6938.2, 6905.7, 8004.4, 8923.1, 9352.4,
          9853.5, 8459.5, 8245.1, 7361.3, 7646.6, 7515.8, 6505.8, 6167.3, 
          6398.9, 6765.5, 6254.8, 7408.7, 8234.1, 7014.3, 6231.6, 6379.1,
          6734.8, 7189.6, 6184.3, 6519.0, 6729.6, 6603.9, 6596.3, 6321.7,
          6572.2, 6494.2, 6386.2, 6427.1, 5621.8, 3920.4, 4196.2, 3430.4,
          3228.7, 3964.4, 3706.8, 3785.4, 3597.2, 3677.8, 3570.9, 3502.5,
          3661.4, 3616.8, 4120.4, 3823.1, 3944.3, 4006.4, 4002.5, 4111.8,
          5046.2, 5051.8, 5290.2, 5265.9, 5830.9, 7190.3, 7262.6, 8027.4,
          8545.7, 7901.4, 8812.5, 10721.7, 11906.5, 11268.0, 11364.9, 10826.7,
          9492.1, 10815.7, 11314.5, 10218.1, 10131.0, 9594.4, 10461.1, 10337.3,
          9993.0, 8208.5, 8127.3, 8304.4, 7957.3, 9230.6, 9300.6, 8804.5,
          8497.3, 7324.1, 7546.6, 7510.9, 7080.8, 7156.2, 7321.5, 7376.8]
weeks = list(range(1, len(prices) + 1))  # Week numbers 1 through 104

plt.title('Bitcoin Prices 2018-2019')
plt.ylabel('Price (USD)')
plt.xlabel('Week')
plt.plot(weeks, prices)
plt.show()

**Explanation**

Which chart did you choose and why?

> We chose a line chart, though a bar chart or even a heat map might be arguable in this scenario. The line chart works well with the relatively large amount of data that we have.

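Which year seemed to be better for Bitcoin holders?

> 2019. Over 2018 the price fell from about 14,300 USD to about 3,800 USD, while over 2019 it roughly doubled, from about 3,600 USD to about 7,400 USD. Someone who bought early in 2019 and simply held the coin did well; someone who bought early in 2018 lost most of the investment's value.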
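To put numbers on the chart's impression, the sketch below compares the first and last weekly prices of each year. It assumes the first 52 prices belong to 2018 and the remaining 52 to 2019, and it measures from the first recorded Sunday rather than from 1 January, so treat the results as approximate. `yearly_return` is a hypothetical helper name.

```python
prices = [14292.2, 12858.9, 11467.5, 9241.1, 8559.6, 11073.5, 9704.3, 11402.3,
          8762.0, 7874.9, 8547.4, 6938.2, 6905.7, 8004.4, 8923.1, 9352.4,
          9853.5, 8459.5, 8245.1, 7361.3, 7646.6, 7515.8, 6505.8, 6167.3,
          6398.9, 6765.5, 6254.8, 7408.7, 8234.1, 7014.3, 6231.6, 6379.1,
          6734.8, 7189.6, 6184.3, 6519.0, 6729.6, 6603.9, 6596.3, 6321.7,
          6572.2, 6494.2, 6386.2, 6427.1, 5621.8, 3920.4, 4196.2, 3430.4,
          3228.7, 3964.4, 3706.8, 3785.4, 3597.2, 3677.8, 3570.9, 3502.5,
          3661.4, 3616.8, 4120.4, 3823.1, 3944.3, 4006.4, 4002.5, 4111.8,
          5046.2, 5051.8, 5290.2, 5265.9, 5830.9, 7190.3, 7262.6, 8027.4,
          8545.7, 7901.4, 8812.5, 10721.7, 11906.5, 11268.0, 11364.9, 10826.7,
          9492.1, 10815.7, 11314.5, 10218.1, 10131.0, 9594.4, 10461.1, 10337.3,
          9993.0, 8208.5, 8127.3, 8304.4, 7957.3, 9230.6, 9300.6, 8804.5,
          8497.3, 7324.1, 7546.6, 7510.9, 7080.8, 7156.2, 7321.5, 7376.8]

# Assumption: 52 weekly closes for 2018 followed by 52 for 2019.
prices_2018, prices_2019 = prices[:52], prices[52:]

def yearly_return(weekly_prices):
    """Fractional change from the first to the last weekly price."""
    return weekly_prices[-1] / weekly_prices[0] - 1

return_2018 = yearly_return(prices_2018)
return_2019 = yearly_return(prices_2019)
print(f'2018: {return_2018:+.1%}')
print(f'2019: {return_2019:+.1%}')
```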



---

## Exercise 2: Candy

We have a bag of candy. It contains five different kinds of candy, each named below. Create a chart that shows the percent chance that we would pull a Snickers candy out of the bag if we did a blind selection. Call out the chance of choosing a Snickers candy.

**Student Solution**

In [0]:
candy_names = ['Kit Kat', 'Snickers', 'Milky Way', 'Toblerone', 'Twix']
candy_counts = [52, 39, 78, 13, 78]

# Your Solution Goes Here

**Explanation**

Which chart did you choose and why?

> *Your solution goes here.*

What are the percentage odds that you'll choose a Snickers bar when randomly pulling a candy out of the bag?

> *Your solution goes here.*

---

### Answer Key

In [0]:
import matplotlib.pyplot as plt

candy_names = ['Kit Kat', 'Snickers', 'Milky Way', 'Toblerone', 'Twix']
candy_counts = [52, 39, 78, 13, 78]
explode = [0, 0.1, 0, 0, 0]

plt.title('Candy Distribution')
plt.pie(candy_counts,
        labels=candy_names,
        explode=explode,
        autopct='%1.1f%%')
plt.show()

**Explanation**

Which chart did you choose and why?

> We chose a pie chart since there are only a few classes of data and a pie chart shows each class as a percentage of the whole.

What is the percent chance that you'll choose a Snickers bar when randomly pulling a candy out of the bag?

> 15% (39 of the 260 candies in the bag are Snickers)
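The percentages on the pie chart are just each count divided by the total number of candies. This short sketch computes them directly, which is a handy way to check the labels that `autopct='%1.1f%%'` draws on the chart.

```python
candy_names = ['Kit Kat', 'Snickers', 'Milky Way', 'Toblerone', 'Twix']
candy_counts = [52, 39, 78, 13, 78]

# The chance of drawing a candy blindly is its count over the total count.
total = sum(candy_counts)  # 260 candies in the bag
chances = {name: count / total for name, count in zip(candy_names, candy_counts)}

for name, chance in chances.items():
    # Format as a percentage with one decimal place, like the pie chart labels.
    print(f'{name}: {chance:.1%}')
# Kit Kat: 20.0%
# Snickers: 15.0%
# Milky Way: 30.0%
# Toblerone: 5.0%
# Twix: 30.0%
```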

---

## Exercise 3: Dessert Popularity

A restaurant we're consulting for has a dessert menu that's too big. They want to cut a few items from the menu. In order to keep most of their customers happy, they want to remove only the three least popular desserts from the menu.

We have a list of the desserts that the restaurant serves, as well as a count of the number of times each dessert was sold in the last week.

Create a visualization that shows the relative popularities of the desserts. Call out the three desserts that should be removed.

**Student Solution**

In [0]:
dessert_sales = {
  'Lava Cake': 14,
  'Mousse': 5,
  'Chocolate Cake': 12,
  'Ice Cream': 19,
  'Truffles': 6,
  'Brownie': 8,
  'Chocolate Chip Cookie': 12,
  'Chocolate Pudding': 9,
  'Souffle': 10,
  'Chocolate Cheesecake': 17,
  'Chocolate Chips': 2,
  'Fudge': 9,
  'Mochi': 13,
}

# Your Solution Goes Here

**Explanation**

Which chart did you choose and why?

> *Your solution goes here.*

Which three desserts should be removed from the menu?

> *Your solution goes here.*

---

### Answer Key

In [0]:
import matplotlib.pyplot as plt
import numpy as np

dessert_sales = {
  'Lava Cake': 14,
  'Mousse': 5,
  'Chocolate Cake': 12,
  'Ice Cream': 19,
  'Truffles': 6,
  'Brownie': 8,
  'Chocolate Chip Cookie': 12,
  'Chocolate Pudding': 9,
  'Souffle': 10,
  'Chocolate Cheesecake': 17,
  'Chocolate Chips': 2,
  'Fudge': 9,
  'Mochi': 13,
}

names_sorted = sorted(dessert_sales, key=dessert_sales.get, reverse=True)
sales_sorted = [dessert_sales[n] for n in names_sorted]

dessert_count = len(dessert_sales)
ticks = np.arange(dessert_count)

plt.title('Dessert Sales')

plt.ylabel('Number Sold')
bar_color = ['g' if i < dessert_count-3 else 'r' for i in range(dessert_count)]
plt.bar(ticks, sales_sorted, color=bar_color)

plt.xlabel('Dessert')
plt.xticks(ticks, names_sorted, rotation=90)

plt.show()

**Explanation**

Which chart did you choose and why?

> We chose a bar chart since it clearly shows relative values. With 13 desserts, there are too many slices for a pie chart to be readable. Sorting the bars and coloring the bottom three red also makes the point at a glance.

Which three desserts should be removed from the menu?

> Truffles, Mousse, Chocolate Chips
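The chart is what convinces the restaurant, but the answer itself can be double-checked in code. This sketch reuses the answer key's `sorted(..., key=dessert_sales.get)` idiom and also shows `heapq.nsmallest` from the standard library, which returns the n entries with the smallest key in ascending order.

```python
import heapq

dessert_sales = {
  'Lava Cake': 14,
  'Mousse': 5,
  'Chocolate Cake': 12,
  'Ice Cream': 19,
  'Truffles': 6,
  'Brownie': 8,
  'Chocolate Chip Cookie': 12,
  'Chocolate Pudding': 9,
  'Souffle': 10,
  'Chocolate Cheesecake': 17,
  'Chocolate Chips': 2,
  'Fudge': 9,
  'Mochi': 13,
}

# Same ordering as the answer key's bar chart: best seller first.
names_sorted = sorted(dessert_sales, key=dessert_sales.get, reverse=True)

# Iterating a dict yields its keys, so this returns the three dessert names
# with the lowest sales, least popular first.
to_remove = heapq.nsmallest(3, dessert_sales, key=dessert_sales.get)

print(names_sorted[-3:])  # ['Truffles', 'Mousse', 'Chocolate Chips']
print(to_remove)          # ['Chocolate Chips', 'Mousse', 'Truffles']
```

Both approaches are stable, so if two desserts tied at the cutoff, the one listed first in the dictionary would be kept in its original relative order; it is worth checking the chart for ties before making a final call.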

---

## Exercise 4: CPU Usage

We have the hourly average CPU usage for a worker's computer over the course of a week. Each row of data represents a day of the week, starting with Monday. Each column is an hour of the day, starting with hour 0 (midnight).

Create a chart that shows the CPU usage over the week. You should be able to answer the following questions using the chart:

* When does the worker typically take lunch?
* Did the worker do work on the weekend?
* On which weekday did the worker start working on their computer at the latest hour?

**Student Solution**

In [0]:
cpu_usage = [
  [2, 2, 4, 2, 4, 1, 1, 4, 4, 12, 22, 23, 
   45, 9, 33, 56, 23, 40, 21, 6, 6, 2, 2, 3], # Monday
  [1, 2, 3, 2, 3, 2, 3, 2, 7, 22, 45, 44,
   33, 9, 23, 19, 33, 56, 12, 2, 3, 1, 2, 2], # Tuesday
  [2, 3, 1, 2, 4, 4, 2, 2, 1, 2,  5, 31,
   54, 7, 6, 34, 68, 34, 49, 6, 6, 2, 2, 3], # Wednesday
  [1, 2, 3, 2, 4, 1, 2, 4, 1, 17, 24, 18,
   41, 3, 44, 42, 12, 36, 41, 2, 2, 4, 2, 4], # Thursday
  [4, 1, 2, 2, 3, 2, 5, 1, 2, 12, 33, 27,
   43, 8, 38, 53, 29, 45, 39, 3, 1, 1, 3, 4], # Friday
  [2, 3, 1, 2, 2, 5, 2, 8, 4, 2, 3,
   1, 5, 1, 2, 3, 2, 6, 1, 2, 2, 1, 4, 3], # Saturday
  [1, 2, 3, 1, 1, 3, 4, 2, 3, 1, 2,
   2, 5, 3, 2, 1, 4, 2, 45, 26, 33, 2, 2, 1], # Sunday
]

# Your Solution Goes Here

**Explanation**

Which chart did you choose and why?

> *Your solution goes here.*

When does the worker typically take lunch?

> *Your solution goes here.*

Did the worker do work on the weekend?

> *Your solution goes here.*

On which weekday did the worker start working on their computer at the latest hour?

> *Your solution goes here.*

---

### Answer Key

In [0]:
import seaborn as sns

cpu_usage = [
  [2, 2, 4, 2, 4, 1, 1, 4, 4, 12, 22, 23, 
   45, 9, 33, 56, 23, 40, 21, 6, 6, 2, 2, 3], # Monday
  [1, 2, 3, 2, 3, 2, 3, 2, 7, 22, 45, 44,
   33, 9, 23, 19, 33, 56, 12, 2, 3, 1, 2, 2], # Tuesday
  [2, 3, 1, 2, 4, 4, 2, 2, 1, 2,  5, 31,
   54, 7, 6, 34, 68, 34, 49, 6, 6, 2, 2, 3], # Wednesday
  [1, 2, 3, 2, 4, 1, 2, 4, 1, 17, 24, 18,
   41, 3, 44, 42, 12, 36, 41, 2, 2, 4, 2, 4], # Thursday
  [4, 1, 2, 2, 3, 2, 5, 1, 2, 12, 33, 27,
   43, 8, 38, 53, 29, 45, 39, 3, 1, 1, 3, 4], # Friday
  [2, 3, 1, 2, 2, 5, 2, 8, 4, 2, 3,
   1, 5, 1, 2, 3, 2, 6, 1, 2, 2, 1, 4, 3], # Saturday
  [1, 2, 3, 1, 1, 3, 4, 2, 3, 1, 2,
   2, 5, 3, 2, 1, 4, 2, 45, 26, 33, 2, 2, 1], # Sunday
]
# Three-letter labels avoid the ambiguity of repeated 'T' and 'S'.
days = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
hours = list(range(24))

sns.heatmap(
    cpu_usage,
    yticklabels=days,
    xticklabels=hours,
    cmap='coolwarm',
)

**Explanation**

Which chart did you choose and why?

> We chose a heatmap since it gives a grid view of CPU usage by day and hour. If we assume that high CPU usage means the worker is working, we can use the colors on the map to answer the questions.

When does the worker typically take lunch?

> 13:00. CPU usage drops sharply during that hour every weekday.

Did the worker do work on the weekend?

> Yes, there seems to be some work done on the computer on Sunday evening. You can tell by the "hotter" colors in the 18:00, 19:00, and 20:00 hours that day.

On which weekday did the worker start working on their computer at the latest hour?

> Wednesday seems to be the latest start. Usage stays at its overnight level through the 10:00 hour and doesn't really pick up until 11:00, whereas the other weekdays get going around 9:00 or 10:00.
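The heatmap answers these questions visually. As a cross-check, the sketch below answers them directly from the numbers. It assumes that a reading of 10 or more means the worker is at the computer (outside working hours every reading is below 10) and that lunch falls between 11:00 and 16:00; both the threshold and the window are our own choices, not part of the exercise.

```python
# Assumptions (ours): readings >= 10 mean the worker is active, and lunch
# happens somewhere in the 11:00-16:00 window.
BUSY_THRESHOLD = 10
MIDDAY_HOURS = range(11, 17)

cpu_usage = [
  [2, 2, 4, 2, 4, 1, 1, 4, 4, 12, 22, 23,
   45, 9, 33, 56, 23, 40, 21, 6, 6, 2, 2, 3], # Monday
  [1, 2, 3, 2, 3, 2, 3, 2, 7, 22, 45, 44,
   33, 9, 23, 19, 33, 56, 12, 2, 3, 1, 2, 2], # Tuesday
  [2, 3, 1, 2, 4, 4, 2, 2, 1, 2,  5, 31,
   54, 7, 6, 34, 68, 34, 49, 6, 6, 2, 2, 3], # Wednesday
  [1, 2, 3, 2, 4, 1, 2, 4, 1, 17, 24, 18,
   41, 3, 44, 42, 12, 36, 41, 2, 2, 4, 2, 4], # Thursday
  [4, 1, 2, 2, 3, 2, 5, 1, 2, 12, 33, 27,
   43, 8, 38, 53, 29, 45, 39, 3, 1, 1, 3, 4], # Friday
  [2, 3, 1, 2, 2, 5, 2, 8, 4, 2, 3,
   1, 5, 1, 2, 3, 2, 6, 1, 2, 2, 1, 4, 3], # Saturday
  [1, 2, 3, 1, 1, 3, 4, 2, 3, 1, 2,
   2, 5, 3, 2, 1, 4, 2, 45, 26, 33, 2, 2, 1], # Sunday
]
day_names = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday',
             'Saturday', 'Sunday']
weekdays = dict(zip(day_names[:5], cpu_usage[:5]))
weekend = dict(zip(day_names[5:], cpu_usage[5:]))


def mean_weekday_usage(hour):
    """Average reading for `hour` across Monday through Friday."""
    return sum(day[hour] for day in weekdays.values()) / len(weekdays)


# Lunch: the midday hour with the lowest average weekday usage.
lunch_hour = min(MIDDAY_HOURS, key=mean_weekday_usage)

# Weekend work: every weekend hour at or above the threshold.
weekend_busy = {
    name: [hour for hour, usage in enumerate(day) if usage >= BUSY_THRESHOLD]
    for name, day in weekend.items()
}

# Start time: the first hour of each weekday at or above the threshold.
start_hours = {
    name: next(hour for hour, usage in enumerate(day) if usage >= BUSY_THRESHOLD)
    for name, day in weekdays.items()
}
latest_day = max(start_hours, key=start_hours.get)

print('Lunch hour:', lunch_hour)            # 13
print('Busy weekend hours:', weekend_busy)  # {'Saturday': [], 'Sunday': [18, 19, 20]}
print('Start hours:', start_hours)
print('Latest start:', latest_day)          # Wednesday
```

The printed start hours are 9 for every weekday except Wednesday, which starts at 11. A different threshold could shift individual start times, so the heatmap remains the better tool for seeing the overall pattern.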

---

## Exercise 5: Mushrooms

A researcher is studying mushrooms. They have found a ring of mushrooms and have recorded the coordinates of each mushroom on a plane. Such rings typically radiate outward from a single original mushroom at the center. Given the coordinates below, the researcher wants to answer the question: *Approximately where on the plane was the original mushroom?*

Create a chart that will allow the researcher to approximate the center of the growth.

**Student Solution**

In [0]:
x = [4.61, 5.08, 5.18, 7.82, 10.46, 7.66, 7.6, 9.32, 14.04, 9.95, 4.95, 7.23, 
     5.21, 8.64, 10.08, 8.32, 12.83, 7.51, 7.82, 6.29, 0.04, 6.62, 13.16, 6.34, 
     0.09, 10.04, 13.06, 9.54, 11.32, 7.12, -0.67, 10.5, 8.37, 7.24, 9.18, 
     10.12, 12.29, 8.53, 11.11, 9.65, 9.42, 8.61, -0.67, 5.94, 6.49, 7.57, 3.11,
     8.7, 5.28, 8.28, 9.55, 8.33, 13.7, 6.65, 2.4, 3.54, 9.19, 7.51, -0.68,
     8.47, 14.82, 5.31, 14.01, 8.75, -0.57, 5.35, 10.51, 3.11, -0.26, 5.74,
     8.33, 6.5, 13.85, 9.78, 4.91, 4.19, 14.8, 10.04, 13.47, 3.28]

y = [-2.36, -3.41, 13.01, -2.91, -2.28, 12.83, 13.13, 11.94, 0.93, -2.76, 13.31,
     -3.57, -2.33, 12.43, -1.83, 12.32, -0.42, -3.08, -2.98, 12.46, 8.34, -3.19,
     -0.47, 12.78, 2.12, -2.72, 10.64, 11.98, 12.21, 12.52, 5.53, 11.72, 12.91,
     12.56, -2.49, 12.08, -1.09, -2.89, -1.78, -2.47, 12.77, 12.41, 5.33, -3.23,
     13.45, -3.41, 12.46, 12.1, -2.56, 12.51, -2.37, 12.76, 9.69, 12.59, -1.12,
     -2.8, 12.94, -3.55, 7.33, 12.59, 2.92, 12.7, 0.5, 12.57, 6.39, 12.84,
     -1.95, 11.76, 6.82, 12.44, 13.28, -3.46, 0.7, -2.55, -2.37, 12.48, 7.26,
     -2.45, 0.31, -2.51]

# Your Solution Goes Here

**Explanation**

Which chart did you choose and why?

> *Your solution goes here.*

What are the approximate x,y coordinates of the initial mushroom?

> *Your solution goes here.*

---

### Answer Key

In [0]:
import matplotlib.pyplot as plt

x = [4.61, 5.08, 5.18, 7.82, 10.46, 7.66, 7.6, 9.32, 14.04, 9.95, 4.95, 7.23, 
     5.21, 8.64, 10.08, 8.32, 12.83, 7.51, 7.82, 6.29, 0.04, 6.62, 13.16, 6.34, 
     0.09, 10.04, 13.06, 9.54, 11.32, 7.12, -0.67, 10.5, 8.37, 7.24, 9.18, 
     10.12, 12.29, 8.53, 11.11, 9.65, 9.42, 8.61, -0.67, 5.94, 6.49, 7.57, 3.11,
     8.7, 5.28, 8.28, 9.55, 8.33, 13.7, 6.65, 2.4, 3.54, 9.19, 7.51, -0.68,
     8.47, 14.82, 5.31, 14.01, 8.75, -0.57, 5.35, 10.51, 3.11, -0.26, 5.74,
     8.33, 6.5, 13.85, 9.78, 4.91, 4.19, 14.8, 10.04, 13.47, 3.28]

y = [-2.36, -3.41, 13.01, -2.91, -2.28, 12.83, 13.13, 11.94, 0.93, -2.76, 13.31,
     -3.57, -2.33, 12.43, -1.83, 12.32, -0.42, -3.08, -2.98, 12.46, 8.34, -3.19,
     -0.47, 12.78, 2.12, -2.72, 10.64, 11.98, 12.21, 12.52, 5.53, 11.72, 12.91,
     12.56, -2.49, 12.08, -1.09, -2.89, -1.78, -2.47, 12.77, 12.41, 5.33, -3.23,
     13.45, -3.41, 12.46, 12.1, -2.56, 12.51, -2.37, 12.76, 9.69, 12.59, -1.12,
     -2.8, 12.94, -3.55, 7.33, 12.59, 2.92, 12.7, 0.5, 12.57, 6.39, 12.84,
     -1.95, 11.76, 6.82, 12.44, 13.28, -3.46, 0.7, -2.55, -2.37, 12.48, 7.26,
     -2.45, 0.31, -2.51]

plt.title('Mushroom Ring')
plt.scatter(x, y)
# Use equal axis scaling so the ring isn't stretched into an ellipse.
plt.axis('equal')
plt.show()

**Explanation**

Which chart did you choose and why?

> We chose a scatter plot since it shows individual points on an x,y plane, which makes the shape of the ring, and therefore its center, easy to see.

What are the approximate x,y coordinates of the initial mushroom?

> (7.5, 5.0)
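Eyeballing the scatter plot works, but the center can also be estimated numerically by fitting a circle to the points. The sketch below uses an algebraic least-squares circle fit (often called the Kasa fit), which is our choice of technique rather than part of the original exercise. `fit_circle` is a hypothetical helper built on `np.linalg.lstsq`, and it assumes the ring is roughly circular.

```python
import numpy as np


def fit_circle(x, y):
    """Fit a circle to points with an algebraic least-squares (Kasa) fit.

    Returns the estimated center (cx, cy) and radius r.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # Expanding (x - cx)^2 + (y - cy)^2 = r^2 gives a model that is linear
    # in the unknowns:  x^2 + y^2 = 2*cx*x + 2*cy*y + c,  c = r^2 - cx^2 - cy^2
    A = np.column_stack([2 * x, 2 * y, np.ones_like(x)])
    b = x ** 2 + y ** 2
    (cx, cy, c), *_ = np.linalg.lstsq(A, b, rcond=None)
    r = np.sqrt(c + cx ** 2 + cy ** 2)
    return cx, cy, r


# Sanity check on 12 evenly spaced points of a known circle.
angles = np.linspace(0, 2 * np.pi, 12, endpoint=False)
check = fit_circle(7.5 + 8 * np.cos(angles), 5.0 + 8 * np.sin(angles))
print(check)  # approximately (7.5, 5.0, 8.0)
```

With the answer key's `x` and `y` lists defined, `cx, cy, r = fit_circle(x, y)` gives an estimate to compare with the visual answer, and `plt.scatter([cx], [cy], marker='x')` would mark it on the plot. Because the fit minimizes an algebraic error rather than true geometric distance, treat the result as an approximation, just like the eyeballed one.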