<a href="https://colab.research.google.com/github/dymiyata/intro-to-ml-and-ai-2025-2026/blob/main/intro_to_data_visualization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Visualization

Now that we know how to use pandas to work with dataframes, the next important skill is visualizing the data.

Moreover, when we do machine learning, it's important for us to be able to visualize our results, especially if we're trying to show our results to others.

**Important Reminder**: Make sure you save a copy of this notebook to your own Google Drive. Otherwise, any changes you make won't be saved. You can do so by clicking

File -> Save a copy in Drive

## Importing Libraries

We begin by importing pandas in order to actually work with the data.

In [None]:
import pandas as pd

Next we import two new libraries.  
- Matplotlib: A comprehensive library for creating visualizations in Python
- Seaborn: A library built on top of Matplotlib that hides much of Matplotlib's complexity behind simpler functions.


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

### The Dataset

Today we will work with a dataset of flowers called *irises*.
- Each row is one iris flower, described by four measurements: `sepal_length`, `sepal_width`, `petal_length`, and `petal_width`.
- There are three species of iris in this dataset, as we will see.

This is one of Seaborn's example datasets, so we don't have to download a file ourselves. We can simply load it with `sns.load_dataset`, which fetches it for us (this needs an internet connection).

(Normally we have to obtain a .csv file or similar, then load it manually with pandas, as we did with the Pokémon dataset last week.)



In [None]:
df = sns.load_dataset("iris")

If we want to do a rough overview of the dataset, what commands can we run?
Here are some questions we can try to answer:

- How many rows are there?
- What are the columns of the dataframe?
- What are the datatypes for each of the entries in the dataset?

- What is the average `sepal_length`?

- What is the average `sepal_length` for just the `setosa` `species`?

- What are the different species and how many of each are there?
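
One possible way to answer these questions with pandas is sketched below. So that it runs on its own, it uses a small hand-typed stand-in frame called `sample` (a hypothetical name, with illustrative values only). In this notebook you would call the same methods on `df` instead.

```python
import pandas as pd

# Hand-typed stand-in with the same column names as the iris dataframe.
# Values are for illustration only; in the notebook, use `df` instead.
sample = pd.DataFrame({
    "sepal_length": [5.1, 4.9, 7.0, 6.4, 6.3, 5.8],
    "sepal_width": [3.5, 3.0, 3.2, 3.2, 3.3, 2.7],
    "petal_length": [1.4, 1.3, 4.7, 4.5, 6.0, 5.1],
    "petal_width": [0.2, 0.3, 1.4, 1.5, 2.5, 1.9],
    "species": ["setosa", "setosa", "versicolor",
                "versicolor", "virginica", "virginica"],
})

n_rows, n_cols = sample.shape          # number of rows and columns
column_names = list(sample.columns)    # column labels
column_types = sample.dtypes           # datatype of each column

# Average of one column
mean_sepal_length = sample["sepal_length"].mean()

# A boolean mask keeps only the setosa rows before averaging
setosa_mean = sample.loc[sample["species"] == "setosa", "sepal_length"].mean()

# How many rows belong to each species
species_counts = sample["species"].value_counts()

print(n_rows, column_names)
print(column_types)
print(mean_sepal_length, setosa_mean)
print(species_counts)
```

`sample.info()` and `sample.describe()` give similar overviews in a single call.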

## Histograms

Who knows what a histogram is?

We can make a histogram to view the distribution of a single column in our dataframe.  Let's do so with the `petal_length` column.
- What does changing the number of bins do?
- What does `kde=True` do?
- What do you think about the shape of the distribution?
- Let's add the line `sns.set_theme()`

In [None]:
sns.histplot(df["petal_length"], bins=20, kde=True)
plt.show()
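
To see what the `bins` argument controls without looking at a plot, here is a minimal sketch that counts values per bin directly with NumPy. The array `lengths` is made up for illustration and is not taken from the iris dataset.

```python
import numpy as np

# Made-up petal lengths (cm), for illustration only
lengths = np.array([1.3, 1.4, 1.5, 1.7, 3.9, 4.2, 4.4, 4.7, 5.1, 5.6, 6.0, 6.1])

# np.histogram splits the range [min, max] into equal-width bins and counts
# how many values land in each one. A histogram plot draws these counts as bars.
counts_coarse, edges_coarse = np.histogram(lengths, bins=3)
counts_fine, edges_fine = np.histogram(lengths, bins=6)

print(counts_coarse, edges_coarse)
print(counts_fine, edges_fine)
```

Fewer bins give a smoother but coarser picture; more bins show more detail, but each bar is based on fewer values.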


Let's add a title and some more appropriate labels to the axes.

In [None]:
sns.histplot(df["petal_length"], bins=20)
plt.title("Distribution of Petal Lengths of Irises")
plt.xlabel("Petal Length (cm)")
plt.ylabel("Count")
plt.show()

## Exploring Relationships Between Variables

Who knows what a scatter plot is?

Let's create a scatter plot exploring `petal_length` vs `petal_width`.

Any guesses for an overall trend?

In [None]:
sns.scatterplot(x="petal_length", y="petal_width", data=df)
plt.show()

Notice that some of the values seem to cluster together. Why might this be the case?

## Correlation

Trends between two quantitative variables can be measured using a value called the *correlation coefficient*, or just *correlation*, usually denoted by the letter $r$.

- It is always the case that
$$-1 \leq r \leq 1$$
- The closer $r$ is to $1$ or $-1$, the *stronger* the correlation.
- A positive $r$ value means the data tends to slope upward: as one variable increases, so does the other.
- A negative $r$ value means the data tends to slope downward: as one variable increases, the other tends to decrease.
- An $r$ value near $0$ means there is little or no *linear* relationship between the two variables.
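
The coefficient described here is usually the *Pearson* correlation. As a minimal sketch of where the number comes from, the code below computes it from its definition on made-up arrays `x` and `y`, then compares it with NumPy's built-in `np.corrcoef`. The helper name `pearson_r` is hypothetical and introduced only for this example.

```python
import numpy as np

# Made-up measurements, for illustration only
x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y = np.array([2.1, 3.9, 6.2, 8.1, 9.8])


def pearson_r(a, b):
    """Pearson correlation: how much a and b vary together, scaled by
    how much each varies on its own."""
    a_centered = a - a.mean()
    b_centered = b - b.mean()
    numerator = (a_centered * b_centered).sum()
    denominator = np.sqrt((a_centered ** 2).sum() * (b_centered ** 2).sum())
    return numerator / denominator


r_up = pearson_r(x, y)      # y rises with x, so r is positive
r_down = pearson_r(x, -y)   # flipping the sign of y flips the sign of r

print(r_up, r_down)
print(np.corrcoef(x, y)[0, 1])  # NumPy's built-in value, for comparison
```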

We can use `DataFrame.corr()` to explore the correlations between the variables. First we need to only keep the quantitative variables.


In [None]:
numerical_df = df.select_dtypes("number")
numerical_df.corr()


We can use a heatmap to make this table easier to read at a glance.

In [None]:
sns.heatmap(numerical_df.corr(), annot=True, cmap='crest')
plt.show()

Let's make a scatter plot with variables that seem strongly correlated.  Then another with variables that are less strongly correlated.

In [None]:
# Strong Correlation Here

In [None]:
# Weak Correlation Here


Maybe `sepal_width` and `sepal_length` are not that correlated.  However, there still seems to be some *clustering*.  Maybe within a certain species they are correlated?

Let's explore this with the same plot, but we add colors to separate the different species using the argument `hue="species"`.
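
A minimal sketch of such a plot follows. To stay self-contained it uses a small stand-in frame named `sample` (hypothetical, with illustrative values only). In this notebook you would pass `data=df` instead.

```python
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Stand-in rows for illustration only; in the notebook, use `df`.
sample = pd.DataFrame({
    "sepal_length": [5.1, 4.7, 5.4, 7.0, 6.4, 5.5, 6.3, 5.8, 7.1],
    "sepal_width": [3.5, 3.2, 3.9, 3.2, 3.2, 2.3, 3.3, 2.7, 3.0],
    "species": ["setosa"] * 3 + ["versicolor"] * 3 + ["virginica"] * 3,
})

# hue="species" draws each species in its own color and adds a legend
ax = sns.scatterplot(x="sepal_length", y="sepal_width", hue="species", data=sample)
ax.set_title("Sepal Width vs. Sepal Length by Species")
plt.show()
```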

## Pair plots

Maybe we want scatterplots for all the pairs of numerical variables at once. We can do so with a *pair plot*.

## Many Other Types of Plots to Explore

Here are just some random examples of other types of plots you can make with Seaborn.  There's too much to go over everything in depth, so feel free to play around with things to explore what's possible.

In [None]:
sns.countplot(x="species", data=df)
plt.show()

In [None]:
sns.barplot(x="species", y="sepal_length", data=df, hue="species")
plt.show()

In [None]:
sns.boxplot(x="species", y="sepal_length", data=df)
plt.show()

In [None]:
sns.swarmplot(x="species", y="petal_length", data=df, hue="species", size=4.5)
plt.show()

In [None]:
grid = sns.FacetGrid(df, col="species")
grid.map_dataframe(sns.scatterplot, x="sepal_length", y="petal_length")
plt.show()

## Trends and Patterns

Notice that by plotting the data, we can begin to see many trends.
It seems that if we know some information about an iris, we can predict other information about it.

- For example, if we know the `petal_width` maybe we can *predict* the `petal_length`.  

- Or even better, if we know `petal_width`, `petal_length`, `sepal_width` and `sepal_length` perhaps we can *predict* what `species` of iris it is.  

In the coming weeks we will finally start to do some *actual* machine learning to train a computer to learn these patterns in order to be able to make these predictions.

We will begin with (multiple) linear regression.  Here's a preview:

In [None]:
sns.regplot(x="petal_length", y="petal_width", data=df)
plt.show()
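
By default, `regplot` draws a straight line fitted by least squares. As a rough sketch of the numbers behind such a line, the code below fits one with `np.polyfit` and uses it to make a prediction. The arrays are made-up illustrative values, not the real dataset, and the 5.0 cm petal is hypothetical.

```python
import numpy as np

# Made-up (petal_length, petal_width) pairs in cm, for illustration only
petal_length = np.array([1.4, 1.7, 4.0, 4.5, 4.7, 5.1, 5.9, 6.0])
petal_width = np.array([0.2, 0.4, 1.3, 1.5, 1.4, 1.9, 2.1, 2.5])

# Fit a straight line  width ≈ slope * length + intercept  by least squares
slope, intercept = np.polyfit(petal_length, petal_width, deg=1)

# Use the fitted line to predict the width of a hypothetical 5.0 cm petal
predicted_width = slope * 5.0 + intercept

print(slope, intercept, predicted_width)
```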