# EmbeddingPlotter Demo
The `EmbeddingPlotter` creates 2D or 3D scatter plots of data.
It automatically adapts the plot to properties of the data such as its dimensionality, its data type, and the kind of target (regression or classification).
A default look and feel is provided, but it can easily be modified to fulfill specific requirements.
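
The notebook does not show how the plotter makes this decision internally. As a rough illustration of the behaviour described above, here is a hypothetical sketch (`choose_plot_kind` is not part of `hyperpyper`). It picks the plot dimensionality from the number of coordinates per point and treats numeric color values as a regression target and anything else as class labels:

```python
from numbers import Number


def choose_plot_kind(data, color):
    """Hypothetical sketch: pick plot dimensionality, task and color scale.

    data:  sequence of points, each with 2 or 3 coordinates
    color: one value per point
    """
    if len(data) == 0:
        raise ValueError("data must contain at least one point")
    n_dims = len(data[0])
    if n_dims not in (2, 3):
        raise ValueError(f"expected 2D or 3D data, got {n_dims} dimensions")
    if len(color) != len(data):
        raise ValueError("color must provide one value per data point")

    # Numbers (but not booleans) are treated as a continuous regression target;
    # anything else, e.g. strings, is treated as class labels.
    numeric = all(isinstance(c, Number) and not isinstance(c, bool) for c in color)
    if numeric:
        return f"{n_dims}d", "regression", "colorbar"
    return f"{n_dims}d", "classification", "legend"


print(choose_plot_kind([[5.1, 3.5], [4.9, 3.0]], [0, 1]))
# ('2d', 'regression', 'colorbar')
print(choose_plot_kind([[5.1, 3.5, 1.4], [4.9, 3.0, 1.4]], ["setosa", "virginica"]))
# ('3d', 'classification', 'legend')
```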

In [1]:
import numpy as np
from pathlib import Path
from collections import Counter

from sklearn.datasets import load_iris

import matplotlib.pyplot as plt

import plotly.graph_objs as go
import plotly.express as px
import ipywidgets as widgets

from IPython.display import display


from hyperpyper.plotting import EmbeddingPlotter

random_state = 23

## Load some data
For demonstration purposes, we simply load the Iris toy dataset that ships with scikit-learn.

In [2]:
data = load_iris(return_X_y=False)

X = data.data
y = data.target

X.shape

(150, 4)

## Interactive 2D plots
We start by passing just the first two dimensions to the `EmbeddingPlotter`.
Because the target `y` consists of integer class indices, it is interpreted as a regression target and shown with a colorbar. That is not what we want here, so let's switch to a classification target.

In [3]:
plotter = EmbeddingPlotter(data=X[:,0:2],
                           color=y)
display(plotter.plot())

### Classification Targets
Converting the target values to strings makes the plotter treat them as classification labels.
Let's map the class indices to the corresponding target names.
The plot then automatically uses a legend instead of a colorbar.

In [4]:
data.target_names

array(['setosa', 'versicolor', 'virginica'], dtype='<U10')

In [5]:
# Map each class index to its name (a separate loop variable avoids shadowing `y`)
y_str = [data.target_names[label] for label in y]
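
Because `data.target_names` is a NumPy array, the same conversion can also be done without a Python loop by using integer-array indexing. A small self-contained sketch, with made-up stand-in arrays instead of the real Iris targets:

```python
import numpy as np

# Stand-ins for data.target_names and data.target from the Iris example
target_names = np.array(["setosa", "versicolor", "virginica"])
y = np.array([0, 0, 1, 2, 2])

# Integer-array indexing looks up all labels at once; .tolist() returns plain str objects
y_str = target_names[y].tolist()
print(y_str)  # ['setosa', 'setosa', 'versicolor', 'virginica', 'virginica']
```

The resulting list has the same form as the one built with the list comprehension above.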

In [6]:
plotter = EmbeddingPlotter(data=X[:,0:2],
                           color=y_str)
display(plotter.plot())

### Dynamic Axis Scaling
By default, the axis ranges are fixed according to the parameters `fix_axis=True` and `axis_margin`.
This prevents the axes from being rescaled when a class is selected or deselected in the legend.
If dynamic rescaling is needed instead, simply set `fix_axis=False`.
Try clicking a class in the legend and observe the difference.
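
Fixing the axes essentially means computing each range once from all data points, so that hiding a class cannot change it. A minimal sketch of such a computation, assuming that `axis_margin` is a fraction of the data span (the notebook does not state its exact meaning); `fixed_axis_range` is a hypothetical helper:

```python
def fixed_axis_range(values, margin=0.05):
    """Return a (low, high) axis range padded by a fraction of the data span.

    Assumption: `margin` is relative to the span, which may differ from how
    EmbeddingPlotter interprets its `axis_margin` parameter.
    """
    low, high = min(values), max(values)
    span = high - low
    if span == 0:
        # Constant data: fall back to a unit span so the range is not empty
        span = 1.0
    pad = margin * span
    return low - pad, high + pad


print(fixed_axis_range([0.0, 2.0, 10.0], margin=0.1))  # (-1.0, 11.0)
```

With Plotly, a range computed this way could be pinned with `fig.update_xaxes(range=[low, high])`; an explicit range disables autoscaling, so the view no longer follows the subset of points that is currently visible.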

In [7]:
plotter = EmbeddingPlotter(data=X[:,0:2],
                           color=y_str,
                           fix_axis=False)
display(plotter.plot())

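### Updating Item Order in Legend
The order of the entries in the legend can be changed with `update_legend_order`, which takes the class names in the desired order.
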
In [8]:
plotter.update_legend_order([data.target_names[2], data.target_names[1], data.target_names[0]])

display(plotter.plot())
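
In Plotly, legend entries are usually listed in the order in which the traces were added (newer Plotly versions also offer a per-trace `legendrank` attribute). How `update_legend_order` performs the reordering is not shown here; the following hypothetical sketch (`order_traces` is an illustrative name) simply sorts a list of trace-like dicts by a given name order:

```python
def order_traces(traces, order):
    """Sort trace dicts so that their legend entries follow `order`.

    Hypothetical helper: names missing from `order` are placed after the
    listed ones and keep their original relative order.
    """
    rank = {name: i for i, name in enumerate(order)}
    # sorted() is stable, so traces with equal rank keep their original order
    return sorted(traces, key=lambda t: rank.get(t["name"], len(order)))


traces = [{"name": "setosa"}, {"name": "versicolor"}, {"name": "virginica"}]
reordered = order_traces(traces, ["virginica", "versicolor", "setosa"])
print([t["name"] for t in reordered])  # ['virginica', 'versicolor', 'setosa']
```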

### Modifying Markers
Marker modifications can be applied to specific classes, whose traces are addressed by name via the `name_pattern` parameter.

In [9]:
custom_marker_style = {'marker': {'symbol': 'square', 'size': 12, 'opacity': 0.5, 'color': 'dodgerblue'}}
plotter.update_traces(name_pattern='virginica', update_params=custom_marker_style)

display(plotter.plot())
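
Plotly's own `update_traces` applies nested updates recursively by default (`overwrite=False`): keys that are not mentioned keep their previous values. Assuming `EmbeddingPlotter.update_traces` forwards `update_params` to that mechanism, the `virginica` markers in the next cell should keep their blue color and 0.5 opacity. A stdlib-only sketch of such a recursive merge (`deep_update` is a hypothetical helper, not Plotly's implementation):

```python
import copy


def deep_update(base, updates):
    """Recursively merge `updates` into a copy of `base` and return the copy.

    Nested dicts are merged key by key; all other values are overwritten.
    """
    result = copy.deepcopy(base)
    for key, value in updates.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = deep_update(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result


virginica = {"name": "virginica",
             "marker": {"symbol": "square", "size": 12, "opacity": 0.5, "color": "dodgerblue"}}
updated = deep_update(virginica, {"marker": {"symbol": "square", "size": 10}})
print(updated["marker"])
# {'symbol': 'square', 'size': 10, 'opacity': 0.5, 'color': 'dodgerblue'}
```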


Markers can also be addressed with patterns.
Here is an example that changes all markers corresponding to classes beginning with the letter "v".
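
The notebook does not spell out the exact matching rule. Assuming `name_pattern` is treated as a regular expression matched at the start of each class name (as with Python's `re.match`), the selection could be sketched as follows; `select_trace_names` is a hypothetical helper:

```python
import re


def select_trace_names(names, pattern):
    """Return the names that match `pattern` at their beginning.

    Assumption: EmbeddingPlotter might instead use re.search or a plain
    substring test; this sketch uses re.match.
    """
    regex = re.compile(pattern)
    return [name for name in names if regex.match(name)]


classes = ["setosa", "versicolor", "virginica"]
print(select_trace_names(classes, "v"))          # ['versicolor', 'virginica']
print(select_trace_names(classes, "virginica"))  # ['virginica']
print(select_trace_names(classes, "(set|vir)"))  # ['setosa', 'virginica']
```

With an unanchored `re.search`, the pattern `'v'` would select the same two classes here, since "setosa" contains no "v".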

In [10]:
custom_marker_style = {'marker': {'symbol': 'square', 'size': 10,}}
plotter.update_traces(name_pattern='v', update_params=custom_marker_style)

display(plotter.plot())


### Modifying Plot Appearance with the Figure Handle
To customize even more aspects of a plot, you can retrieve the underlying Plotly figure with the `_get_fig()` method, which gives you full control over all available properties.

More information about updating figures can be found here: https://plotly.com/python/creating-and-updating-figures/

In [11]:
fig = plotter._get_fig()

# Modify the properties of the zero line (x and y axes)
fig.update_xaxes(
    zeroline=True,  # Show the zero line on the x-axis
    zerolinecolor='gray',  # Set the color of the zero line
    zerolinewidth=1.,  # Set the width of the zero line
)
fig.update_yaxes(
    zeroline=True,  # Show the zero line on the y-axis
    zerolinecolor='gray',  # Set the color of the zero line
    zerolinewidth=1.,  # Set the width of the zero line
)

# Modify the properties of the grid lines
fig.update_xaxes(
    showgrid=True,  # Show x-axis grid lines
    gridcolor='gray',  # Set the color of the grid lines
    gridwidth=1.,  # Set the width of the grid lines
    griddash='dot',  # Set the dash style of the grid lines ('solid', 'dot', 'dash', etc.)
)
fig.update_yaxes(
    showgrid=True,  # Show y-axis grid lines
    gridcolor='gray',  # Set the color of the grid lines
    gridwidth=1.,  # Set the width of the grid lines
    griddash='dot',  # Set the dash style of the grid lines ('solid', 'dot', 'dash', etc.)
)

# Modify font style and size for axis labels and tick labels
# (`title_font` replaces the deprecated `titlefont` argument)
fig.update_xaxes(
    title_text='x-axis',  # Set the x-axis label text
    title_font=dict(size=24, color='black', family='Times New Roman'),
    tickfont=dict(size=24, color='black', family='Times New Roman'),
    showticklabels=True,
)
fig.update_yaxes(
    title_text='',  # Hide the y-axis label by using an empty text
    title_font=dict(size=30, color='red', family='Times New Roman'),
    tickfont=dict(size=30, color='red', family='Times New Roman'),
    showticklabels=True,
)

# Modify the layout and position of the legend
fig.update_traces(showlegend=True)
fig.update_layout(
    title_text='This is a Title Text', title_x=0.5, title_y=.97, title_font_size=36,
    legend_title_text='Iris Dataset',
    legend=dict(font=dict(color='black', size=16, family='Times New Roman'),
                x=0.5,  # Horizontal position of the legend (middle of the plot)
                xanchor='center',  # Anchor the legend at its center, so it is centered at x=0.5
                y=-.2,  # A negative y places the legend below the plotting area
                orientation='h'),  # Horizontal orientation (entries in a single row)
)

# Modify the properties of the background
fig.update_layout(plot_bgcolor='white')

display(fig)

### Exporting to .html
Exporting the plot as an .html file preserves its interactivity, such as zooming or toggling classes on and off.

In [12]:
RESULT_PATH = Path.home() / "tmp" / "results"
RESULT_PATH.mkdir(parents=True, exist_ok=True)  # Make sure the output directory exists

# Save the plot as HTML file
plotter.to_html(RESULT_PATH / 'iris_plot.html')

## Interactive 3D plots
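Switching to 3D happens automatically as soon as we pass 3D data to the `EmbeddingPlotter`.
Be aware that the available modifications differ from those of a 2D plot. In Plotly, for example, the axes of a 3D plot belong to the `scene` and are configured with `fig.update_scenes(...)` rather than `fig.update_xaxes(...)` and `fig.update_yaxes(...)`.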

More information about updating figures can be found here: https://plotly.com/python/creating-and-updating-figures/
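
Marker symbols are another difference: Plotly's 3D scatter traces (`scatter3d`) accept only a small set of symbols, namely `circle`, `circle-open`, `cross`, `diamond`, `diamond-open`, `square`, `square-open` and `x`. A hypothetical pre-check (not part of `hyperpyper`) that validates a style dict before it is passed to `update_traces` could look like this:

```python
# Marker symbols accepted by Plotly's 3D scatter traces (scatter3d)
SCATTER3D_SYMBOLS = {
    "circle", "circle-open", "cross", "diamond",
    "diamond-open", "square", "square-open", "x",
}


def check_3d_marker_style(style):
    """Raise ValueError if a marker style uses symbols unsupported in 3D.

    Hypothetical helper; only style['marker']['symbol'] is inspected, which
    may be a single symbol or a list of symbols.
    """
    symbol = style.get("marker", {}).get("symbol")
    if symbol is None:
        return style
    symbols = symbol if isinstance(symbol, (list, tuple)) else [symbol]
    unsupported = [s for s in symbols if s not in SCATTER3D_SYMBOLS]
    if unsupported:
        raise ValueError(f"not available for 3D scatter plots: {unsupported}")
    return style


check_3d_marker_style({"marker": {"symbol": "square", "size": 10}})  # passes
try:
    check_3d_marker_style({"marker": {"symbol": "triangle-up"}})
except ValueError as err:
    print(err)  # not available for 3D scatter plots: ['triangle-up']
```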

In [13]:
plotter = EmbeddingPlotter(data=X[:, 0:3],
                           color=y_str)
display(plotter.plot())

In [14]:
custom_marker_style = {'marker': {'symbol': 'square', 'size': 10,}}
plotter.update_traces(name_pattern='v', update_params=custom_marker_style)

display(plotter.plot())
