In [None]:
#  %pip install --upgrade xarray numpy matplotlib pandas seaborn

In [None]:
# Ignore future warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# import external packages
import xarray as xr
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

# Add path with self-created packages and import them
import sys
sys.path.append('./src')
import sciebo

# Wheel Speed Analysis with Pandas and Seaborn 

In the experiment reported by [Steinmetz et al., 2019 in Nature](https://www.nature.com/articles/s41586-019-1787-x), mice performed a discrimination task in which they had to determine which of two stimuli (left vs. right) had the higher contrast. They reported their decision by turning a wheel so that the stimulus with the higher contrast moved to the center of the screen.

##### Analysis Goals
In this notebook we will be analyzing the speed at which the mouse turned the wheel to report its decision.

##### Learning Goals
First, we will be exploring the [**Xarray**](https://docs.xarray.dev/) Python package to load the data files in `.nc` (i.e. netCDF) format. We then will continue our exploration of the [**Pandas**](https://pandas.pydata.org/) and [**Seaborn**](https://seaborn.pydata.org/) Python packages to visualise and analyse this data.

---

#### Download Dataset

This data has been pre-processed from Steinmetz et al. (2019) and is hosted on Sciebo: https://uni-bonn.sciebo.de/s/Y8C1TJcuk1GFg3V. The code below should download one of the files (namely "steinmetz_2016-12-14_Cori.nc") to the folder `data`.

In [None]:
sciebo.download_from_sciebo('https://uni-bonn.sciebo.de/s/qmuIZfZC2bRe2iR', 'data/steinmetz_2016-12-14_Cori.nc')

---

## XArray Datasets: From netCDF files to pandas DataFrames

We will first explore a new file format: **NetCDF**, which is designed for storing large, multi-dimensional scientific data. Unlike a CSV file, which is essentially a plain table, a netCDF file holds not only your data but also metadata describing it (e.g. dimension names, coordinates and units). This is very useful when you are dealing with many measurements recorded along several dimensions.

**xarray**

| Code               | Description                                                                                                  |
|--------------------|--------------------------------------------------------------------------------------------------------------|
| `dset = xr.load_dataset()`| Loads a dataset from a specified file path using the xarray library, ideal for multi-dimensional data arrays. |
| `dset["variable_name"]`   | Extracts a single variable from the dataset using indexing.                                                   |
| `dset["variable_name"].to_dataframe()`     | Creates a Pandas DataFrame from the selected variable, for familiar data manipulation and analysis.   |

Let's start by loading the file we downloaded, using the `xarray` library.

In [None]:
dataset = xr.load_dataset('data/steinmetz_2016-12-14_Cori.nc')
dataset

The movement of the wheel is only recorded during "active trials". Hence, we use the `where(condition, drop=True)` method of the Xarray `Dataset` to *drop* all data points where the value of `active_trials` is not 1.

In [None]:
dataset = dataset.where(dataset.active_trials == 1, drop=True)
dataset
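To see what `where(..., drop=True)` does on a small scale, here is a minimal sketch on a made-up dataset. The variable names mimic the real file, but the coordinates and values are hypothetical. Note that `where` fills masked values with `NaN`, so integer variables come back as floats.

```python
import numpy as np
import xarray as xr

# Hypothetical toy dataset: 4 trials x 3 time points, plus a per-trial "active" flag
toy = xr.Dataset(
    {
        "wheel": (("trial", "time"), np.arange(12).reshape(4, 3)),
        "active_trials": ("trial", [1, 0, 1, 1]),
    },
    coords={"trial": [1, 2, 3, 4], "time": [0.01, 0.02, 0.03]},
)

# Keep only trials where active_trials == 1; drop=True removes trial 2 entirely
# instead of just filling it with NaN
active = toy.where(toy.active_trials == 1, drop=True)
print(active.sizes["trial"])  # the trial dimension shrinks from 4 to 3
```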

**Note** that the number of trials is now lower. Let's continue by getting a better understanding of the Xarray dataset.

While you can interactively explore the variables that are available in this dataset, you can also view the variables using the `data_vars` attribute:

In [None]:
dataset.data_vars

**Exercises:**

Let's explore the dataset we loaded with Xarray and see how its variables can be turned into Pandas DataFrames.

**Example:** As with a Pandas DataFrame, we can use indexing with `[]` and a variable name to extract a single variable from the dataset (or a list of names to extract multiple variables):

In [None]:
dataset["contrast_left"]

Select only the `response_time` variable from the dataset.

Using the `.to_dataframe()` method create a Pandas DataFrame from `dataset["response_time"]`.

Use `.reset_index()` on the resulting DataFrame to turn all index levels (e.g. `mouse` and `session_date`) into regular columns, repeated across all rows.

Using the methods above, create a Pandas DataFrame for the wheel speed data (the variable name is `wheel`). Let's call it `wheel_df`.

The resulting DataFrame will look similar to the following:

| mouse | session_date | trial | time | wheel |
|-------|--------------|-------|------|-------|
| Cori  | 2016-12-14   | 1     | 0.01 | -1.0  |
| Cori  | 2016-12-14   | 1     | 0.02 | 0.0   |
| Cori  | 2016-12-14   | 1     | 0.03 | 0.0   |
| Cori  | 2016-12-14   | 1     | 0.04 | 0.0   |
| Cori  | 2016-12-14   | 1     | 0.05 | 0.0   |
| ...  | ...  | ...     | ... | ...  |
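Here is a minimal sketch of the `.to_dataframe()` followed by `.reset_index()` chain on a small, made-up `DataArray` (its name, coordinates and values are hypothetical). A `DataArray` needs a `name` to become a DataFrame column; variables taken from a Dataset with `dataset["..."]` already carry their variable name.

```python
import numpy as np
import xarray as xr

# Hypothetical toy DataArray with two dimensions (trial, time)
wheel = xr.DataArray(
    np.array([[-1.0, 0.0, 0.0], [0.0, 2.0, 1.0]]),
    dims=("trial", "time"),
    coords={"trial": [1, 2], "time": [0.01, 0.02, 0.03]},
    name="wheel",  # becomes the name of the value column
)

# to_dataframe() puts the dimensions into a (trial, time) MultiIndex ...
df = wheel.to_dataframe()
# ... and reset_index() turns those index levels into ordinary columns
df = df.reset_index()
print(df.head())
```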

---

## Visualizing Wheel Speed with Seaborn

In this section, we will use the [**Seaborn**](https://seaborn.pydata.org/) plotting library to visualize the speed at which the mouse turned the wheel in every trial. Seaborn is designed to work well with Pandas DataFrames, and it produces informative, good-looking plots with minimal code.

The experimental data contains the wheel speed over time, across all trials. A positive turning speed means that the wheel is being moved to the right, while a negative speed tells us the wheel is being moved to the left. When the wheel is not being moved at all, the turning speed is zero.

Let's visualise this data using line plots from the Seaborn library.


**pandas**

| Code                      | Description                                      |
|---------------------------|--------------------------------------------------|
| `df['column_name']`       | Access an individual column.                     |
| `df.reset_index()`        | Moves the index levels into regular columns and replaces them with a default integer index. |
| `df.set_index(new_index)` | Sets the index of the DataFrame to `new_index`.  |
| `df.unstack()`            | Pivots a level of the index labels to the columns, reshaping the DataFrame. |
| `df.stack()`              | Pivots a level of the column labels to the index, reshaping the DataFrame. |

**seaborn** 

| Code                                          | Description                                   |
| --------------------------------------------- | --------------------------------------------- |
| `sns.lineplot(data)`                          | Plot data as a line.      |
| `sns.lineplot(data, label='my_label')`        | Show a legend label associated with data.     |
| `sns.lineplot(data, color='blue')`            | Plot a line in blue.                          |
| `sns.lineplot(data, linestyle='dotted')`      | Plot a dotted line.                           |

Before we start visualizing the wheel speed, let's restructure `wheel_df` so that it is easier to analyse: we want `time` as the index and one column per `trial`.

Set both `time` and `trial` as the index of `wheel_df` using the `.set_index()` method.

In [None]:
wheel_df_indexed = wheel_df.set_index(['time', 'trial'])
wheel_df_indexed

Use `[]` to select the `"wheel"` column, and apply the `.unstack()` method.

In [None]:
wheel_df = wheel_df_indexed["wheel"].unstack()
wheel_df
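The `set_index` plus `unstack` reshaping (from long to wide format) can be sketched on a tiny, made-up table; the trial numbers and values below are hypothetical. By default, `unstack()` pivots the innermost index level (here `trial`) into the columns.

```python
import pandas as pd

# Hypothetical long-format table: one row per (trial, time) sample
long_df = pd.DataFrame({
    "trial": [1, 1, 2, 2],
    "time": [0.01, 0.02, 0.01, 0.02],
    "wheel": [-1.0, 0.0, 3.0, 2.0],
})

# Move time and trial into a MultiIndex, pick the wheel column,
# and unstack the innermost level (trial) into columns
wide_df = long_df.set_index(["time", "trial"])["wheel"].unstack()
print(wide_df)  # one row per time point, one column per trial
```

`stack()` performs the reverse operation, turning the trial columns back into an index level.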

**Exercises**

**Example:** Use a line plot to plot the wheel speed for trial 7.

In [None]:
sns.lineplot(wheel_df[7])

Use a line plot to plot the wheel speed for trial 17.

Create a line plot of the wheel speed in two different trials (e.g. trials 9 and 21) by calling the plotting function twice.

Let's now explore some of the other arguments that we can pass into the `sns.lineplot()` function to change the style of the plot or add more information to it (e.g. a legend).

Recreate the plot above, this time showing one of the trials as a dotted line.

Recreate the plot above, this time adding a legend to the plot. **Hint:** specify a `label` for each line plot.

Plot three trials, each with the same line color but a different linestyle.

Plot three trials, each with a different line color but the same linestyle.

Instead of using a single value to select a specific trial, we can specify a list of trial numbers using `[]`, and plot all of them at the same time. 

**Example:** Plot wheel speed for trials 9, 34, and 21.

In [None]:
sns.lineplot(wheel_df[[9, 34, 21]])

And now let's plot the wheel speeds for all trials. **Hint:** pass the argument `legend=None` to `lineplot` to hide the legend.

Remake the above plot changing the opacity of the lines. **Hint:** `lineplot` has an argument `alpha` that takes values between 0 and 1. 

Experiment with the `alpha` value to find which value best displays the lines.

**BONUS:** This plot is overplotted: with so many overlapping lines, it is hard to see where most of the data lies. Instead, we can make a heatmap, which shows how many data points fall into each region of the plot.

Here is the idea: we turn the line plot into a 2D histogram. One dimension is time and the other is wheel speed, and the count in each cell is how many times that combination of the two variables occurs in the dataset.

Here are the steps:

1. Bin the two columns shown in the line plot, namely time and wheel speed.
2. Group the rows of the DataFrame by each unique combination of these two bins.
3. Count the number of rows in each group.
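The three steps can be sketched on a handful of made-up (time, wheel speed) samples. The sample values are hypothetical, and the bin sizes simply match the ones used in the notebook below.

```python
import pandas as pd

# Hypothetical long-format samples: (time, wheel speed) pairs
samples = pd.DataFrame({
    "time": [0.01, 0.05, 0.12, 0.15, 0.18, 0.31],
    "wheel": [0.0, 1.0, 3.0, 3.5, -1.0, 0.0],
})
time_bin_size = 0.1
wheel_bin_size = 2

# Step 1: floor division maps each value to the index of its bin
time_bin = samples["time"] // time_bin_size     # e.g. 0.12 -> 1.0, the bin [0.1, 0.2)
wheel_bin = samples["wheel"] // wheel_bin_size  # e.g. -1.0 -> -1.0, the bin [-2, 0)

# Steps 2 and 3: group rows by each (wheel bin, time bin) pair and count them
counts = samples.groupby([wheel_bin, time_bin]).size()

# Reshape into a 2D table (rows: wheel bins, columns: time bins); empty cells -> 0
hist2d = counts.unstack(fill_value=0)
print(hist2d)
```

The labels of the resulting table are bin indices; multiplying them by the bin size gives the lower edge of each bin in the original units. Keep in mind that floor division on floats can put a value lying exactly on a bin edge into the neighboring bin (for example, `0.3 // 0.1` evaluates to `2.0`).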

Pandas and seaborn give us all the tools we need for this:

In [None]:
# convert the wide table (one column per trial) back to long format,
# so that time, trial and wheel speed each become a single column
wheel_df_reset_index = wheel_df.stack().reset_index(name="wheel")
wheel_df_reset_index

In [None]:
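time_bin_size = .1
wheel_bin_size = 2

# floor division assigns every sample to the index of its time / wheel speed bin
time_bin_index = wheel_df_reset_index['time'] // time_bin_size
wheel_bin_index = wheel_df_reset_index['wheel'] // wheel_bin_size

# count samples per (wheel bin, time bin) pair; reshape to rows = wheel bins, columns = time bins
counts = wheel_df_reset_index.groupby([wheel_bin_index, time_bin_index]).size().unstack(fill_value=0)
# a logarithmic color scale keeps rarely visited bins visible next to very full ones
sns.heatmap(counts, norm=LogNorm(), cmap=plt.cm.Greys);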

What happens when you change the bin size for time and/or for the wheel speed?


---

## Describing Data with Metrics: Determining Turning Direction

Data analysis is all about making sense of data. Summary statistics such as averages are one example of this. Such measures reduce the amount of data we need to think about and allow us to answer more general questions. In the context of the wheel speed data that we have been working with so far, a natural question is: **which way did the subject turn the wheel?**

In this section, we will combine aggregation methods such as `.mean()` with filtering to further analyze the wheel speed data, for instance to look only at trials in which the wheel was turned to the left.

**Pandas**

| Code                             | Description                                                      |
| -------------------------------- | ---------------------------------------------------------------- |
| `df.mean()`                      | Calculate the mean of every column of a dataframe.               |
| `df[df < 0]`                     | Filter values based on a condition (most useful on a Pandas Series). |
| `df[["column1", "column2"]]`     | Select multiple specific columns from a dataframe.               |
| `len(df)`                        | Count the number of rows in a dataframe.                         | 

Let's start by re-creating the `wheel_df` using the methods we used earlier in this notebook:

In [None]:
wheel_df = dataset['wheel'].to_dataframe().reset_index().set_index(['time', 'trial'])['wheel'].unstack()
wheel_df

**Exercises**

Calculate the mean wheel speed for each trial from `wheel_df`. Name the result `mean_wheel_speeds`.

In the next few steps, we will use `mean_wheel_speeds` and `wheel_df` to visualize the trials in which the average wheel speed was greater than 0.

First, identify the trial numbers where the mean wheel speed is greater than 0. **Hint:** the index of `mean_wheel_speeds` is the trial number.

How many trials have a mean wheel speed greater than 0?

Now that we have the trial numbers where the wheel speed was positive on average:
1. use it to index the corresponding columns of our `wheel_df` dataframe
2. plot the wheel speed for these trials. **Sanity check:** the lines should mostly be above 0.
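The mean-then-filter pattern can be sketched on a tiny made-up wide table with three hypothetical trials; the same steps apply to `wheel_df` and `mean_wheel_speeds`. By default, `.mean()` skips `NaN` values.

```python
import pandas as pd

# Hypothetical wide table: rows are time points, columns are trial numbers
toy_wheel_df = pd.DataFrame(
    {1: [0.0, 2.0, 4.0], 2: [0.0, -3.0, -1.0], 3: [0.0, 0.0, 0.0]},
    index=[0.01, 0.02, 0.03],
)

# .mean() averages each column, giving one value per trial (index = trial number)
mean_speeds = toy_wheel_df.mean()

# Boolean filtering keeps trials with a positive average; .index gives their numbers
positive_trials = mean_speeds[mean_speeds > 0].index
print(list(positive_trials), len(positive_trials))

# Selecting those trial numbers as columns gives the data to plot
positive_df = toy_wheel_df[positive_trials]
```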

Find the number of trials where the mean wheel speed is less than 0.

Find the number of trials where the mean wheel speed is 0.

Let's now repeat the analysis, but this time find the trials whose maximum wheel speed is greater than 10.

Determine the maximum wheel speed for each trial. Call this `max_wheel_speeds`.


Find the trial numbers where the maximum wheel speed is greater than 10.

Plot the wheel speeds of trials where the maximum wheel speed is greater than 10.

How many trials have a maximum wheel speed greater than 10?

---

## Result Validation using Python's Print Function

Python's `print` function is the standard way to output human-readable results. In this section, we will use print statements to compare different metrics for determining the wheel turning direction, and to produce cleanly formatted sentences.

We have two metrics to determine the wheel turning direction:
* when the mean wheel speed is greater than 0 -> the wheel was turned to the right
* when the maximum wheel speed is greater than 10 -> the wheel was turned to the right

Luckily, the authors have also provided information in the dataset that specifies when they considered the turn to be left, right, or no turn. This information is in the `response_type` variable. <br>
A `response_type` of:
- `1`: corresponds to a right turn
- `-1`: corresponds to a left turn
- `0`: corresponds to no turning at all

Let's compare our metrics for determining the turning direction with the information provided by the authors.


**Python**

| Code                                                          | Description                                                      |
|---------------------------------------------------------------|------------------------------------------------------------------|
| `print(f"This is a formatted string with {variable}")`         | Print a string with variable value embedded.                     |
| `print("The mean speed is {:.2f}".format(mean_speed))`         | Print a string with a formatted floating-point number.          |

**Pandas**

| Code                                                          | Description                                                      |
|---------------------------------------------------------------|------------------------------------------------------------------|
| `df[df['column1'] == 8]`                                      | Filter dataframe rows based on a single condition.               |
| `df['column2'].isin(values)`                                  | Create a boolean mask that is True for rows whose value in `column2` is in the list `values`. |
| `df[(df['column2'] == 4) & (df['column1'] > 2)]`              | Filter dataframe based on multiple conditions.                   |
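A short sketch of these three filtering patterns on a made-up stand-in for `response_df` (the trial numbers and response types are hypothetical).

```python
import pandas as pd

# Hypothetical stand-in for response_df: one row per trial
toy_response_df = pd.DataFrame({
    "trial": [1, 2, 3, 4, 5],
    "response_type": [1, -1, 0, 1, 1],
})

# Single condition: rows where a right turn (response_type == 1) was recorded
right_turns = toy_response_df[toy_response_df["response_type"] == 1]

# isin() builds a boolean mask: True where the trial number is in the given list
mask = toy_response_df["trial"].isin([2, 3, 5])

# Multiple conditions are combined with & (and) or | (or); each needs parentheses
right_and_listed = toy_response_df[mask & (toy_response_df["response_type"] == 1)]
print(right_and_listed)
```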

First, let's get the response types from the data.

In [None]:
response_df = dataset["response_type"].to_dataframe()
response_df = response_df.reset_index()
response_df

**Exercises**

**Example:** Print the value of the variable `var_a` using an f-string (formatted string literal).

In [None]:
var_a = 83
print(f'the value of my variable is {var_a}.')

Print the number of trials where the mean speed is less than zero using an f-string.

In [None]:
num_trials_with_mean_less_than_zero = len(mean_wheel_speeds[mean_wheel_speeds<0])

Print the number of trials where the mean speed is equal to zero.

Print the percentage of trials where the mean speed is equal to zero.

Are there too many decimal places in the number? We can round to 4 decimal places by writing the print statement like so:

```python
print(f"{my_variable:.4f}")
```

Display the percentage of trials with zero mean speed rounded to 3 decimal places.
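Percentages and rounding can be combined in a single f-string. A small sketch with made-up counts: the `.3f` specifier rounds to three decimal places, and the `%` format specifier multiplies by 100 and appends a percent sign, so it can replace the manual `100 * ...`.

```python
# Hypothetical counts, for illustration only
n_zero = 7
n_total = 214

percentage = 100 * n_zero / n_total
print(f"{percentage}% of trials have a mean speed of zero")      # many decimal places
print(f"{percentage:.3f}% of trials have a mean speed of zero")  # 3 decimal places
print(f"{n_zero / n_total:.2%} of trials have a mean speed of zero")  # % specifier
```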

Now let's move on to comparing our predicted turning directions with those of the dataset authors, Steinmetz et al. We will be focusing on the **right turns**.

Find trials in `response_df` where `response_type` is 1 (that means wheel was turned to the right). Call this `steinmetz_trials`.

Get the rows of `mean_wheel_speeds` whose index (`mean_wheel_speeds.index`) is among the trial numbers in `steinmetz_trials`. These are the mean wheel speeds of trials where the dataset authors recorded a right turn of the wheel.

If the authors used the same criterion as we did (mean wheel speed greater than 0), all of these mean wheel speeds should be greater than 0. Is that the case?


Using a print statement, describe the result, including the percentage of values that are greater than zero.
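One compact way to get this percentage uses the fact that the mean of a boolean Series is the fraction of `True` values. A sketch with made-up mean speeds and trial numbers:

```python
import pandas as pd

# Hypothetical mean wheel speeds of trials the authors labelled as right turns
right_turn_means = pd.Series([1.5, 0.2, -0.3, 2.0], index=[3, 8, 12, 20])

# (series > 0) is a boolean Series; its mean is the fraction of True values
fraction_positive = (right_turn_means > 0).mean()
print(f"{fraction_positive:.1%} of the right-turn trials have a mean wheel speed > 0")
```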

Let's repeat the above analysis, but this time using the other criterion: a maximum wheel speed greater than 10. Is this closer to what the authors might have used to identify right turns?