RF-PHATE is a supervised dimensionality reduction tool for exploratory data analysis. To install it, run: `pip install git+https://github.com/jakerhodes/RF-PHATE`.

In [None]:
import rfphate
import seaborn as sns
import pandas as pd

Here we provide a simple example of using RF-PHATE for dimensionality reduction. We use the `titanic` dataset with survival as the response. The `dataprep` function encodes categorical variables as numeric and normalizes all numeric variables, returning the features `x` and the labels `y`.

In [None]:
data = rfphate.load_data('titanic')
x, y = rfphate.dataprep(data)
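
To make the preprocessing step concrete, here is a minimal sketch of what "encode categorical variables as numeric and normalize numeric variables" can look like in pandas. This is a hypothetical stand-in (`simple_dataprep` is not part of `rfphate`), and `rfphate.dataprep` may differ in details such as how it scales values or handles missing entries. It assumes the response lives in a single column named by `label_col`.

```python
import pandas as pd


def simple_dataprep(df: pd.DataFrame, label_col: str):
    """Hypothetical stand-in for rfphate.dataprep (not the library's implementation).

    Every column except `label_col` is treated as a feature.
    """
    features = df.drop(columns=[label_col]).copy()
    for col in features.columns:
        if pd.api.types.is_numeric_dtype(features[col]):
            values = features[col].astype(float)
            span = values.max() - values.min()
            # Min-max scale numeric columns to [0, 1]; constant columns become 0.
            features[col] = (values - values.min()) / span if span > 0 else 0.0
        else:
            # Integer-encode categorical columns; missing values become -1.
            features[col] = pd.Categorical(features[col]).codes
    # Integer-encode the response as well.
    labels = pd.Categorical(df[label_col]).codes
    return features, labels


toy = pd.DataFrame({
    "Survived": ["died", "survived", "survived"],
    "Sex": ["male", "female", "female"],
    "Fare": [7.25, 71.28, 7.93],
})
x_toy, y_toy = simple_dataprep(toy, "Survived")
```

Min-max scaling is used here only for simplicity; tree-based models such as random forests are largely insensitive to monotone rescaling of individual features.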

We instantiate an RF-PHATE object and generate the 2-dimensional embedding with the `fit_transform` method, which returns the embedding as a NumPy array. This should take only a few seconds to run.

In [None]:
rfphate_op = rfphate.RFPHATE(random_state = 42)
emb = rfphate_op.fit_transform(x, y)
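
Because `emb` is a plain NumPy array, it can be convenient to attach its coordinates to the original rows so that plots can refer to columns by name. The sketch below uses a hypothetical helper, `embedding_frame`, and assumes that `emb` has shape `(n_samples, 2)` and that its rows are in the same order as `data` (i.e., that `dataprep` does not drop or reorder rows); the shape checks guard against a mismatch. A random array stands in for the real embedding.

```python
import numpy as np
import pandas as pd


def embedding_frame(emb: np.ndarray, data: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical helper: prepend 2-D embedding coordinates to the original rows."""
    if emb.ndim != 2 or emb.shape[1] != 2:
        raise ValueError("expected an (n_samples, 2) embedding")
    if emb.shape[0] != len(data):
        raise ValueError("embedding and data must have the same number of rows")
    # Reuse the data's index so the concatenation aligns row by row.
    coords = pd.DataFrame(emb, columns=["RF-PHATE 1", "RF-PHATE 2"], index=data.index)
    return pd.concat([coords, data], axis=1)


rng = np.random.default_rng(0)
toy_data = pd.DataFrame({"Survived": ["died", "survived", "died"], "Fare": [7.25, 71.28, 8.05]})
toy_emb = rng.normal(size=(3, 2))  # placeholder for the real `emb`
framed = embedding_frame(toy_emb, toy_data)
```

With the real objects, `embedding_frame(emb, data)` yields a tidy frame that can be passed to `px.scatter(frame, x="RF-PHATE 1", y="RF-PHATE 2", ...)` using column names only.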

We can explore variable relationships by coloring the scatterplot of the embedding by different features. We start by coloring by the class labels, `Survived`.

In [None]:
sns.scatterplot(x = emb[:, 0], y = emb[:, 1], hue = y, markers = {'survived': '.', 'died': 'X'}, style = data['Survived'], alpha = .8)

To choose which features to color by, we can check which features are most important for the prediction problem using the random forest's `feature_importances_` attribute.

In [None]:
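data.iloc[:, 1:].columns  # feature names, assuming the response `Survived` is the first column

In [None]:
importances = rfphate_op.feature_importances_  # one importance score per feature
sns.barplot(x = data.iloc[:, 1:].columns, y = importances)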
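
The bar chart is easier to read when the features are sorted. The following sketch pairs each feature name with its score and sorts from most to least important. `ranked_importances` is a hypothetical helper, and it assumes (as in scikit-learn) that `feature_importances_` is ordered like the feature columns; the scores in the example are placeholders, not results from the Titanic data.

```python
import pandas as pd


def ranked_importances(feature_names, importances) -> pd.Series:
    """Hypothetical helper: feature importances sorted from most to least important."""
    if len(feature_names) != len(importances):
        raise ValueError("one importance score is needed per feature")
    scores = pd.Series(list(importances), index=list(feature_names), name="importance")
    return scores.sort_values(ascending=False)


# Placeholder scores for illustration only.
ranked = ranked_importances(["Pclass", "Sex", "Age"], [0.2, 0.5, 0.3])
```

With the real objects, `ranked_importances(data.iloc[:, 1:].columns, importances)` can be plotted with `sns.barplot(x=ranked.values, y=ranked.index)` to get a sorted horizontal bar chart.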

In [None]:
sns.scatterplot(x = emb[:, 0], y = emb[:, 1], hue = data['Fare'], markers = {'survived': '.', 'died': 'X'}, style = data['Survived'], alpha = .8)

In [None]:
sns.scatterplot(x = emb[:, 0], y = emb[:, 1], hue = data['Age'], markers = {'survived': '.', 'died': 'X'}, style = data['Survived'], alpha = .8)

In [None]:
sns.scatterplot(x = emb[:, 0], y = emb[:, 1], hue = data['Sex'], markers = {'survived': '.', 'died': 'X'}, style = data['Survived'], alpha = .8)

In [None]:
sns.scatterplot(x = emb[:, 0], y = emb[:, 1], hue = pd.Categorical(data['Pclass']), markers = {'survived': '.', 'died': 'X'}, style = data['Survived'], alpha = .8)
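
The four plots above differ only in the coloring feature, so they can also be drawn in a loop as a grid of panels. This is a simplified sketch using plain matplotlib: `plot_embedding_grid` is hypothetical, it drops the survival marker styles and legends of the seaborn versions, and the cutoff of 10 distinct values for treating a numeric feature as continuous is an arbitrary assumption. Random placeholder data stands in for `emb` and `data`.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_embedding_grid(emb, data, features, ncols=2):
    """Hypothetical helper: one scatterplot of the 2-D embedding per coloring feature."""
    nrows = int(np.ceil(len(features) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False)
    for ax, feature in zip(axes.flat, features):
        values = data[feature]
        if pd.api.types.is_numeric_dtype(values) and values.nunique() > 10:
            # Treat as continuous: sequential colormap on the raw values.
            colors, cmap = values.to_numpy(dtype=float), "viridis"
        else:
            # Treat as categorical: integer codes with a qualitative colormap.
            colors, cmap = pd.Categorical(values).codes, "tab10"
        ax.scatter(emb[:, 0], emb[:, 1], c=colors, cmap=cmap, s=10, alpha=0.8)
        ax.set_title(feature)
    # Hide unused panels when len(features) is not a multiple of ncols.
    for ax in axes.flat[len(features):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig


rng = np.random.default_rng(0)
toy_emb = rng.normal(size=(6, 2))  # placeholder for the real `emb`
toy_data = pd.DataFrame({
    "Fare": [7.25, 71.28, 7.93, 53.1, 8.05, 8.46],
    "Sex": ["male", "female", "female", "female", "male", "male"],
    "Pclass": [3, 1, 3, 1, 3, 3],
})
fig = plot_embedding_grid(toy_emb, toy_data, ["Fare", "Sex", "Pclass"])
```

With the real objects, `plot_embedding_grid(emb, data, ['Fare', 'Age', 'Sex', 'Pclass'])` draws all four colorings side by side.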

We can take a closer look with an interactive `plotly` scatterplot (install it with `pip install plotly`).

In [None]:
import plotly.express as px

px.scatter(data, x = emb[:, 0], y = emb[:, 1], color = data['Pclass'].astype(str), symbol = data['Survived'],
           symbol_map = {'died': 'x', 'survived': 'circle'}, hover_data = ['Survived', 'Pclass', 'Sex', 'Age', 'Fare'])
