# Visualisation with Plotly
By the end of this lecture you will be able to:
- work with Plotly via Pandas or directly from Polars
- create bar, row, grouped bar and scatter charts with Plotly
- create time series charts with Plotly

In [None]:
import polars as pl
import plotly.express as px

In [None]:
csvFile = '../data/titanic.csv'

In [None]:
df = pl.read_csv(csvFile)
df.head(3)

## Using Plotly via Pandas
With a **Pandas** `DataFrame` we can pass the `DataFrame` to Plotly as the first argument and then give column names to the subsequent arguments such as `x` and `y`.

In [None]:
px.scatter(
    df.to_pandas(),
    x="Age",
    y="Fare"
)

This approach is handy as it saves us from writing the name of the `DataFrame` variable for each column. 

Unfortunately, at the time of writing we cannot pass a **Polars** `DataFrame` in the same way. This is because if Plotly does not recognise the type of the first argument, it calls `pd.DataFrame` on it. Calling `pd.DataFrame` on a Polars `DataFrame` does not work, as it:
- transposes the data and
- drops the column names

We see this below for the Titanic data:

In [None]:
import pandas as pd
pd.DataFrame(df)

This behaviour is expected to be corrected at some point within the internals of Pandas and Polars.

If you want to use the approach of converting to Pandas, be aware that:
- converting to Pandas requires copying your data and
- Plotly also copies your data internally

Therefore it is best to limit the columns to those needed for the plot with `select`.

In [None]:
px.scatter(
    (
        df
        .select(["Age","Fare"])
        .to_pandas()
    ),
    x="Age",
    y="Fare"
)

## Working with Plotly directly
In the rest of this notebook we see how to work with Plotly directly from a Polars `DataFrame`.

The key point is that we pass Polars `Series` (columns) directly to the `x`, `y` and `color` arguments without passing the `DataFrame` first.

## Bar chart
We can make a bar chart with Plotly by passing the `DataFrame` columns to the `x` and `y` arguments.

In [None]:
classCounts = (
    df['Pclass']
    .value_counts()
    .sort("Pclass")
    .with_columns(
        pl.col("Pclass").cast(pl.Utf8)
    )
)
px.bar(
    x=classCounts["Pclass"],
    y=classCounts["counts"]
)

Note that we cast the `Pclass` column from `Int64` to `Utf8` (string) so that Plotly treats the classes as discrete categories; with an integer column Plotly uses a continuous numeric axis instead.

We can format the chart in the usual Plotly way. In this example we add a title and axis labels.

In [None]:
px.bar(
    x=classCounts["Pclass"],
    y=classCounts["counts"],
    title="Number of passengers per class",
    labels = {
        "x":"Passenger class",
        "y":"Number of passengers",
    },
)

## Row chart
We plot the same data as a horizontal row chart by:
- switching the `x` and `y` arguments and
- setting `orientation='h'`

In [None]:
px.bar(
    y=classCounts["Pclass"],
    x=classCounts["counts"],
    orientation='h'
)

## Grouped bar chart
We use the example of the survival rate broken down by passenger class from the lecture on group operations using `over`.

In [None]:
survivedPercentageDf = (
    pl.read_csv(csvFile)
    .groupby(["Pclass","Survived"])
    .agg(
        pl.col("Name").count().alias("counts")
    )
    .with_columns(
        (100*(pl.col("counts")/pl.col("counts").sum().over("Pclass"))).round(3).alias("% Survived")
    )
    .sort(["Pclass","Survived"])
)

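We create a grouped bar chart by passing the column that splits the bars within each class to the `color` argument and setting `barmode='group'` so that the bars sit side by side rather than stacked.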

In [None]:
fig = px.histogram(
          x = survivedPercentageDf["Pclass"], 
          y = survivedPercentageDf["% Survived"],
          color = survivedPercentageDf['Survived'], 
          barmode = 'group',
          title="% survival by class",
          labels = {
              "x":"Passenger class",
              "y":"% Survived",
              "color":"Survived"
          },
          height=400
)
fig.show()

## Scatter plot and plot size
We make a scatter plot of `log10(Age)` and `log10(Fare)`.

We colour the points by survival. If we keep `Survived` as an integer column, Plotly uses a continuous color scale. Here we cast `Survived` to string so that Plotly uses a discrete color scale.

In [None]:
px.scatter(
    x=df["Age"].log(10),
    y=df["Fare"].log(10),
    color=df["Survived"].cast(pl.Utf8),
    labels = {
        "x":"Age",
        "y":"Fare",
        "color":"Survived"
    },
    width=800,
    height=600
)

To control the shape of the chart in Jupyter we must set both the `width` **and** `height` arguments.

## Line chart
For the line chart we bring in some new data: a time series of ocean wave heights from buoys off the stormy coast of Ireland.

In [None]:
waveCsvFile = "../data/wave_data.csv"

We are primarily interested in looking at time series of the `significant_wave_height` column. This measures the wave height in metres at each station (i.e. at each buoy).

We show the first two rows of this `DataFrame` here:

In [None]:
(
    pl.read_csv(waveCsvFile,parse_dates=True)
    .select(["time","stationID","significant_wave_height"])
    .head(2)
)

We use `groupby_dynamic` to group the wave data by station into 1-hour windows, and take the mean of each `Float64` column within each window.
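Conceptually, each reading is assigned to the hourly window containing its timestamp and the readings are then averaged per station and window. A plain-Python sketch of that idea, assuming left-closed windows that start on the hour (which we take to match the behaviour of `every="1h"` here); the station IDs and heights are hypothetical:

```python
from collections import defaultdict
from datetime import datetime
from statistics import mean

# Hypothetical readings: (time, stationID, significant_wave_height)
readings = [
    (datetime(2022, 9, 25, 10, 0), "A", 2.0),
    (datetime(2022, 9, 25, 10, 30), "A", 3.0),
    (datetime(2022, 9, 25, 11, 0), "A", 4.0),
    (datetime(2022, 9, 25, 10, 15), "B", 1.0),
]

# Assign each reading to the window starting at the top of its hour
buckets = defaultdict(list)
for timestamp, station, height in readings:
    windowStart = timestamp.replace(minute=0, second=0, microsecond=0)
    buckets[(station, windowStart)].append(height)

# Mean wave height per (station, hourly window)
hourlyMeans = {key: mean(values) for key, values in buckets.items()}
```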

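We apply a filter at the end of this query to keep only the 6 stations with the largest mean wave height: `rank(method='dense', reverse=True)` gives rank 1 to the station with the largest mean, so `< 7` keeps ranks 1 to 6.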

In [None]:
stationAggs = (
    pl.read_csv(waveCsvFile,parse_dates=True)
    .groupby_dynamic('time',every="1h",by="stationID")
    .agg(
            pl.col(pl.Float64).mean().suffix("_mean"),
    )
    # Apply a filter to output the stations with the largest waves
    .filter(
        pl.col('significant_wave_height_mean').mean().over("stationID").rank(method='dense',reverse=True) < 7
    )
)
stationAggs.head(3)
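The ranking logic of that filter can be sketched in plain Python on hypothetical station data. One consequence of a dense rank is that tied stations share a rank, so the filter can keep more stations than the cut-off suggests:

```python
from statistics import mean

# Hypothetical hourly mean wave heights per station
stationHeights = {
    "A": [1.0, 2.0],
    "B": [4.0, 6.0],
    "C": [3.0, 3.0],
    "D": [5.0, 5.0],
}

# Overall mean per station (what .mean().over("stationID") attaches to each row)
stationMeans = {station: mean(h) for station, h in stationHeights.items()}

# Dense rank in descending order: the largest mean gets rank 1, ties share a rank
distinctDesc = sorted(set(stationMeans.values()), reverse=True)
denseRank = {station: distinctDesc.index(m) + 1 for station, m in stationMeans.items()}

# Keep ranks below 3; B and D tie on 5.0, so three stations pass, not two
topStations = sorted(s for s, r in denseRank.items() if r < 3)
print(topStations)  # ['B', 'C', 'D']
```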

We can now produce a time series plot of this hourly data by station. We pass:
- the `time` column as the `x`-axis
- the `significant_wave_height_mean` as the `y`-axis
- the `stationID` as the `color`-axis

In [None]:
px.line(
    x=stationAggs["time"],
    y=stationAggs["significant_wave_height_mean"],
    color=stationAggs["stationID"],
    title="Mean wave height in hourly intervals",
    labels = {
        "x":"Date",
        "y":"Wave height (m)",
    },
    width=800,
    height=400
)

The chart shows the arrival of some large waves on 25th September 2022.

## Exercises
In the exercises you will develop your understanding of:
- creating charts via Pandas or directly from Polars
- creating charts with control over axis labels and sizing

### Exercise 1
Make a bar chart of the `SibSp` column showing how many passengers there are with 0, 1, 2, etc. siblings. Ensure the chart is correctly ordered by the number of siblings (that is, the x-axis must run in order 0, 1, 2, etc.).

Do this first by converting to Pandas

Expand the following cell if you want some hints

In [None]:
#Hint 1: Do a value counts on the `SibSp` column
#Hint 2: Cast the `SibSp` column to string

In [None]:
df = pl.read_csv(csvFile)

Do this again working directly with Polars

### Exercise 2
Inspect the columns in the `stationAggs` `DataFrame`. Note the `mean_wave_period_mean` column, which has the mean wave period (the time between wave crests) in seconds.

In [None]:
stationAggs = (
    pl.read_csv(waveCsvFile,parse_dates=True)
    .groupby_dynamic('time',every="1h",by="stationID")
    .agg(
            pl.col(pl.Float64).mean().suffix("_mean"),
    )
    # Apply a filter to output the stations with the largest waves
    .filter(
        pl.col('significant_wave_height_mean').mean().over("stationID").rank(method='dense',reverse=True) < 7
    )
)
stationAggs.head(3)

Make a time series plot of the mean wave period showing it has a similar pattern to the wave height plot above

Now make a scatter plot with:
- the significant wave height on the x-axis
- the mean wave period on the y-axis
- coloured by station

Add labels and a title:
- on the x-axis "Wave height (m)"
- on the y-axis "Wave period (s)"
- on the color axis "Station ID"
- for the title "Wave height versus period"

Make the plot area have equal width and height

## Solutions

### Solution to exercise 1
Make a bar chart of the `SibSp` column showing how many passengers there are with 0, 1, 2, etc. siblings. Ensure the chart is correctly ordered by the number of siblings (that is, the x-axis must run in order 0, 1, 2, etc.).

Do this first by converting to Pandas

Expand the following cell if you want some hints

In [None]:
#Hint 1: Do a value counts on the `SibSp` column
#Hint 2: Cast the `SibSp` column to string

In [None]:
(
    px.bar(
        df["SibSp"]
        .value_counts()
        .with_columns(
            pl.col("SibSp").cast(pl.Utf8)
        )
        .sort("SibSp")
        .to_pandas(),
        x="SibSp",
        y="counts"
    )
)        

Do this again working directly with Polars

In [None]:
siblingCount = (
    df["SibSp"]
    .value_counts()
    .with_columns(
        pl.col("SibSp").cast(pl.Utf8)
    )
    .sort("SibSp")
)
(
    px.bar(
        x=siblingCount["SibSp"],
        y=siblingCount["counts"]
    )
)        

### Solution to Exercise 2

Inspect the columns in the `stationAggs` `DataFrame`. Note the `mean_wave_period_mean` column, which has the mean wave period (the time between wave crests) in seconds.

Make a time series plot of the mean wave period (we will add in labels etc below)

In [None]:
px.line(
    x=stationAggs["time"],
    y=stationAggs["mean_wave_period_mean"],
    color=stationAggs["stationID"]
)

Make a scatter plot with:
- the significant wave height on the x-axis
- the mean wave period on the y-axis
- coloured by station

In [None]:
(
    px.scatter(
        x=stationAggs["significant_wave_height_mean"],
        y=stationAggs["mean_wave_period_mean"],        
        color=stationAggs["stationID"]
    )
)

Add labels and a title:
- on the x-axis "Wave height (m)"
- on the y-axis "Wave period (s)"
- on the color axis "Station ID"
- for the title "Wave height versus period"

In [None]:
(
    px.scatter(
        x=stationAggs["significant_wave_height_mean"],
        y=stationAggs["mean_wave_period_mean"],        
        color=stationAggs["stationID"],
        labels={
            "x":"Wave height (m)",
            "y":"Wave period (s)",
            "color":"Station ID",
        },
        title="Wave height versus period",
    )
)

Make the plot area have equal width and height

In [None]:
(
    px.scatter(
        x=stationAggs["significant_wave_height_mean"],
        y=stationAggs["mean_wave_period_mean"],        
        color=stationAggs["stationID"],
        labels={
            "x":"Wave height (m)",
            "y":"Wave period (s)",
            "color":"Station ID",
        },
        title="Wave height versus period",
        width=600,
        height=500
    )
)