# Visualisation with Seaborn
By the end of this section you will be able to:
- work with Seaborn directly from Polars and via Pandas
- create a range of charts with Seaborn

Seaborn does not have explicit support for Polars. However, some Seaborn plotting functions accept a Polars `DataFrame` directly, while others need it converted to Pandas first with `to_pandas()`.

In [None]:
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
csvFile = '../data/titanic.csv'

In [None]:
df = pl.read_csv(csvFile)
df.head(3)

## Bar chart
We make a bar chart with `sns.barplot`, passing the `DataFrame` and naming the columns to plot with the `x` and `y` arguments. Here we convert to Pandas with `to_pandas()` first.

In [None]:
classCounts = (
    df['Pclass']
    .value_counts()
    .sort("Pclass")
    .with_columns(
        pl.col("Pclass").cast(pl.Utf8)
    )
)
sns.barplot(
    classCounts.to_pandas(),
    x="Pclass",
    y="counts"
)

Note that we cast the `Pclass` column from `Int64` to `Utf8` (string) so that Seaborn treats `Pclass` as a categorical variable rather than a numeric one. We sort before casting so that the classes stay in numeric order. This point recurs below: when visualising, it is often worth casting integer columns that really represent categories to strings.

### Adding title and axis labels
We add a title and axis labels with the `set` method of the Matplotlib `Axes` object that `sns.barplot` returns.

In [None]:
(
    sns.barplot(
        classCounts.to_pandas(),
        x="Pclass",
        y="counts",
    )
    .set(
        title="Number of passengers by class",
        xlabel="Passenger class",
        ylabel="Number of passengers"
    )
)

## Row chart
We plot the same data as a horizontal row chart by switching the `x` and `y` arguments to `sns.barplot`

In [None]:
sns.barplot(
    classCounts.to_pandas(),
    y="Pclass",
    x="counts"
)

## Grouped bar chart
We reuse the example from the lecture on group operations, where we calculated the survival rate by passenger class with the `over` window expression.

In [None]:
survivedPercentageDf = (
    pl.read_csv(csvFile)
    .groupby(["Pclass","Survived"])
    .agg(
        pl.col("Name").count().alias("counts")
    )
    .with_columns(
        (100*(pl.col("counts")/pl.col("counts").sum().over("Pclass"))).round(3).alias("% Survived")
    )
    .sort(["Pclass","Survived"])
)

We create a grouped bar chart with the `sns.catplot` function

In [None]:
# Draw a grouped barplot by passenger class and survival
g = sns.catplot(
    data=survivedPercentageDf.to_pandas(), 
    kind="bar",
    x="Pclass", 
    y="% Survived", 
    hue="Survived",
    palette="dark", 
    alpha=.6, # Set transparency
    height=6
)
g.set_axis_labels("Passenger class", "% Survived")
g.legend.set_title("Survival")
g.set(title="Survival by passenger class")

## Scatter plot and plot size
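We make a scatter plot of `log(Age)` and `log(Fare)` with `sns.scatterplot`, where `log` is the natural logarithm (the default base of Polars' `log`). We colour the points by passenger class by setting `hue="Pclass"`.

In this case we do not need to convert to Pandas: we pass the Polars `DataFrame` directly.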

In [None]:
sns.scatterplot(
    (
        df
        .with_columns(
            [
                pl.col("Pclass").cast(pl.Utf8),
                pl.col("Age").log(),
                pl.col("Fare").log(),
            ]
        )
    ),
    x="Age",
    y="Fare",
    hue="Pclass"
)

## Other Seaborn charts
Seaborn comes with a variety of charts for more advanced visualisations.

In this example we use `sns.jointplot` to show the scatter plot of `log(Age)` versus `log(Fare)` together with the marginal distribution of each variable, split by passenger class.

We can pass a Polars `DataFrame` directly to this function

In [None]:
sns.jointplot(
    data=(
        df
        .with_columns(
            [
                pl.col(pl.Float64).log(),
                pl.col("Pclass").cast(pl.Utf8)
            ]
        )
    ),
    x="Age", 
    y="Fare", 
    hue="Pclass",
)

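We can also use `sns.pairplot` to create a grid of charts showing the pairwise relationships between many columns at once. With this many columns the result can be overwhelming. For this function we convert the `DataFrame` to Pandas first.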

In [None]:
sns.pairplot(
    data=(
        df
        .with_columns(
            [
                pl.col(pl.Float64).log(),
                pl.col("Pclass").cast(pl.Utf8)
            ]
        )
        .to_pandas()
    ),
    hue="Pclass"
)

## Line chart
For the line chart we bring in some new data: a time series of ocean wave heights from the buoys near the stormy coast of Ireland

In [None]:
waveCsvFile = "../data/wave_data.csv"

We are primarily interested in looking at time series of the `significant_wave_height` column. This measures the wave height in metres at each station (i.e. at each buoy).

Here are the first two rows of this `DataFrame`, with the `time` column parsed as a datetime:

In [None]:
(
    pl.read_csv(waveCsvFile)
    .with_columns(pl.col("time").str.strptime(pl.Datetime, "%Y-%m-%dT%TZ"))
    .select(["time","stationID","significant_wave_height"])
    .head(2)
)

We use `groupby_dynamic` to group the wave data by station into 1-hour windows, taking the mean of each `Float64` column within each window.

At the end of this query we apply a filter that restricts the plot to the 6 stations with the largest average significant wave height.

In [None]:
stationAggs = (
    pl.read_csv(waveCsvFile)
    .with_columns(pl.col("time").str.strptime(pl.Datetime, "%Y-%m-%dT%TZ"))
    .groupby_dynamic('time',every="1h",by="stationID")
    .agg(
        pl.col(pl.Float64).mean().suffix("_mean"),
    )
    # Apply a filter to output the stations with the largest waves
    .filter(
        pl.col('significant_wave_height_mean').mean().over("stationID").rank(method='dense',descending=True) < 7
    )
)
stationAggs.head(3)

We can now produce a time series plot of this hourly data by station.

To do this we can again pass a Polars `DataFrame` directly to `sns.lineplot`:

In [None]:
sns.lineplot(
    stationAggs,
    x="time",
    y="significant_wave_height_mean",
    hue="stationID",
)

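With the default figure size the x-axis labels are crowded, so we make the plot wider. We create the Matplotlib figure and axes ourselves with `plt.subplots(figsize=(12, 5))` (width and height in inches) and pass the axes to `sns.lineplot` with the `ax` argument.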

In [None]:
fig, ax = plt.subplots(figsize=(12, 5))

sns.lineplot(
    stationAggs,
    x="time",
    y="significant_wave_height_mean",
    hue="stationID",
    ax=ax
)

The chart shows the arrival of some large waves on 25th September 2022

## Exercises
In the exercises you will develop your understanding of:
- creating a range of charts with Seaborn
- creating charts with control over axis labels and sizing

### Exercise 1
Make a bar chart of the `SibSp` column showing how many passengers have 0, 1, 2, etc. siblings or spouses aboard. Ensure the chart is correctly ordered by this number (that is, the x-axis must run in order 0, 1, 2, etc.).

Expand the following cell if you want some hints

In [None]:
# Hint 1: Do a value counts on the `SibSp` column
# Hint 2: Cast the `SibSp` column to string

In [None]:
df = pl.read_csv(csvFile)

### Exercise 2
Inspect the columns in the `stationAggs` `DataFrame` below. Note the `mean_wave_period_mean` column, which holds the mean wave period (the time between wave crests) in seconds.

In [None]:
stationAggs = (
    pl.read_csv(waveCsvFile)
    .with_columns(pl.col("time").str.strptime(pl.Datetime, "%Y-%m-%dT%TZ"))
    .groupby_dynamic('time',every="1h",by="stationID")
    .agg(
        pl.col(pl.Float64).mean().suffix("_mean"),
    )
    # Apply a filter to output the stations with the largest waves
    .filter(
        pl.col('significant_wave_height_mean').mean().over("stationID").rank(method='dense',descending=True) < 7
    )
)
stationAggs.head(3)

Make a time series plot of the mean wave period to show that it follows a similar pattern to the wave height plot above.

Now make a scatter plot with:
- the significant wave height on the x-axis
- the mean wave period on the y-axis
- coloured by station

Add labels and a title:
- on the x-axis "Wave height (m)"
- on the y-axis "Wave period (s)"
- on the colour legend "Station ID"
- for the title "Wave height versus period"

## Solutions
### Solution to exercise 1
Make a bar chart of the `SibSp` column showing how many passengers have 0, 1, 2, etc. siblings or spouses aboard. Ensure the chart is correctly ordered by this number (that is, the x-axis must run in order 0, 1, 2, etc.).

Expand the following cell if you want some hints

In [None]:
# Hint 1: Do a value counts on the `SibSp` column
# Hint 2: Cast the `SibSp` column to string

In [None]:
siblingCount = (
    df["SibSp"]
    .value_counts()
    # Sort while SibSp is still numeric so the bars run 0, 1, 2, ...
    .sort("SibSp")
    .with_columns(
        pl.col("SibSp").cast(pl.Utf8)
    )
)
sns.barplot(
    data=siblingCount.to_pandas(),
    x="SibSp",
    y="counts"
)

### Solution to Exercise 2

Inspect the columns in the `stationAggs` `DataFrame`. Note the `mean_wave_period_mean` column, which holds the mean wave period (the time between wave crests) in seconds.

Make a time series plot of the mean wave period to show that it follows a similar pattern to the wave height plot above.

In [None]:
sns.lineplot(
    stationAggs,
    x="time",
    y="mean_wave_period_mean",
    hue="stationID"
)

Make a scatter plot with:
- the significant wave height on the x-axis
- the mean wave period on the y-axis
- coloured by station with a legend

In [None]:
sns.scatterplot(
    stationAggs,
    x="significant_wave_height_mean",
    y="mean_wave_period_mean",
    hue="stationID"
)

Add labels and a title:
- on the x-axis "Wave height (m)"
- on the y-axis "Wave period (s)"
- on the colour legend "Station ID"
- for the title "Wave height versus period"

In [None]:
ax = sns.scatterplot(
    stationAggs,
    x="significant_wave_height_mean",
    y="mean_wave_period_mean",
    hue="stationID",
)
ax.set(xlabel="Wave height (m)", ylabel="Wave period (s)", title="Wave height versus period")
# Retitle the legend that Seaborn created for the hue variable
ax.get_legend().set_title("Station ID")