In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

import altair as alt

import seaborn as sns

from sklearn import datasets


In [2]:
## Load Iris dataset
# Keep the raw sklearn Bunch under its own name so that `iris_df` only ever
# refers to a DataFrame (the original rebound one name to two kinds of object).
iris_bunch = datasets.load_iris(as_frame=True)

# Class names in label order; also used below as the legend/colour domain for plots.
target_names = [str(name) for name in iris_bunch.target_names]

X = iris_bunch.data
y = pd.DataFrame({"label": iris_bunch.target})

## Concatenate X and y horizontally, then rename the measurement columns to snake_case
iris_df = (
    pd.concat([X, y], axis=1)
    .rename(columns={"sepal length (cm)": "sepal_length",
                     "sepal width (cm)": "sepal_width",
                     "petal length (cm)": "petal_length",
                     "petal width (cm)": "petal_width"})
)

# Derive label -> class-name mapping from the dataset itself instead of hard-coding it,
# so it cannot drift out of sync with `target_names`.
class_dict = dict(enumerate(target_names))

iris_df["class_str"] = iris_df["label"].map(class_dict)
# Every integer label must have a class name; an unmapped label would become NaN.
assert iris_df["class_str"].notna().all(), "unmapped class labels in iris_df['label']"
iris_df.head()


Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,label,class_str
0,5.1,3.5,1.4,0.2,0,setosa
1,4.9,3.0,1.4,0.2,0,setosa
2,4.7,3.2,1.3,0.2,0,setosa
3,4.6,3.1,1.5,0.2,0,setosa
4,5.0,3.6,1.4,0.2,0,setosa


In [64]:
def scatter_plot(data, x_var, y_var, x_lim, y_lim, class_label, legend_label,
                 color_map, point_size, x_label, y_label, width=250, height=250):
    """Build an interactive Altair scatter plot of two columns, coloured by class.

    Parameters
    ----------
    data : pandas.DataFrame
        Source table containing the ``x_var``, ``y_var`` and ``class_label`` columns.
    x_var, y_var : str
        Column names plotted on the x and y axes.
    x_lim, y_lim : sequence of two numbers
        Fixed axis domains ``(min, max)``.
    class_label : str
        Column whose values determine point colour.
    legend_label : list
        Class values used as the colour-scale domain, paired position-wise with ``color_map``.
    color_map : list of str
        Colours, one per entry of ``legend_label``.
    point_size : int
        Marker size passed to ``mark_circle(size=...)``.
    x_label, y_label : str
        Axis titles.
    width, height : int, optional
        Chart size in pixels (defaults 250 x 250).

    Returns
    -------
    altair.Chart
        Configured chart with pan/zoom enabled via ``.interactive()``.
    """
    plot = alt.Chart(data).mark_circle(size=point_size).encode(
        alt.X(x_var, scale=alt.Scale(domain=x_lim), title=x_label),
        alt.Y(y_var, scale=alt.Scale(domain=y_lim), title=y_label),
        color=alt.Color(class_label, legend=alt.Legend(title="Class"),
                        scale=alt.Scale(domain=legend_label, range=color_map))
        ).configure_axis(
            grid=False,
            domainWidth=1.5, domainColor="black",  # axis line stroke width and colour
            tickWidth=1.5, tickColor="black",      # tick stroke width and colour
            offset=10,    # pixels by which each axis is displaced from the plot area edge
            tickCount=6,  # desired number of ticks; Vega-Lite may adjust it to "nice" values
        ).configure_view(
            strokeOpacity=0  # hide the bounding box around the plot area
        ).properties(width=width, height=height).interactive()

    return plot

# One colour per class, paired position-wise with `target_names` in the legend scale.
color_map = ["#8dd3c7", "orange", "#bebada"]

scatter_plot(
    data=iris_df,
    x_var="sepal_length",
    y_var="petal_length",
    x_lim=(0, 10),
    y_lim=(0, 10),
    class_label="class_str",
    legend_label=target_names,
    color_map=color_map,
    point_size=100,
    x_label="sepal length (cm)",
    y_label="petal length (cm)",
)

In [62]:
def scatter_plot_multiple(data, row_features, col_features, class_label, legend_label,
                          color_map, point_size, width=200, height=200):
    """Build a grid of interactive Altair scatter plots, coloured by class.

    Uses Altair's ``repeat`` operator. Each grid cell plots one column feature
    (x-axis) against one row feature (y-axis).

    Parameters
    ----------
    data : pandas.DataFrame
        Source table containing every feature in ``row_features`` and
        ``col_features`` plus the ``class_label`` column.
    row_features : list of str
        Columns plotted on the y-axis, one grid row per entry.
    col_features : list of str
        Columns plotted on the x-axis, one grid column per entry.
    class_label : str
        Column whose values determine point colour.
    legend_label : list
        Class values used as the colour-scale domain, paired position-wise with ``color_map``.
    color_map : list of str
        Colours, one per entry of ``legend_label``.
    point_size : int
        Marker size passed to ``mark_circle(size=...)``.
    width, height : int, optional
        Size of each individual panel in pixels (defaults 200 x 200).

    Returns
    -------
    altair.RepeatChart
        Configured repeated chart with pan/zoom enabled via ``.interactive()``.
    """
    plot = alt.Chart(data).mark_circle(size=point_size).encode(
        alt.X(alt.repeat("column"), type='quantitative'),
        alt.Y(alt.repeat("row"), type='quantitative'),
        color=alt.Color(class_label, legend=alt.Legend(title="Class"),
                        scale=alt.Scale(domain=legend_label, range=color_map))
        ).properties(
            width=width,
            height=height,
        ).repeat(
            row=row_features,
            column=col_features
        ).configure_axis(
            grid=False,
            domainWidth=1.5, domainColor="black",  # axis line stroke width and colour
            tickWidth=1.5, tickColor="black",      # tick stroke width and colour
            offset=10,    # pixels by which each axis is displaced from the plot area edge
            tickCount=6,  # desired number of ticks; Vega-Lite may adjust it to "nice" values
        ).configure_view(
            strokeOpacity=0  # hide the bounding box around each panel
        ).interactive()

    return plot

# Features repeated down the grid rows and across the grid columns, respectively.
row_features = ["sepal_length", "sepal_width"]
col_features = ["petal_length", "petal_width"]

scatter_plot_multiple(
    iris_df,
    row_features=row_features,
    col_features=col_features,
    class_label="class_str",
    legend_label=target_names,
    color_map=color_map,
    point_size=80,
)