# Visualize Gaussian Naive Bayes
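This notebook fits a Gaussian Naive Bayes classifier to the Iris dataset and visualizes its predicted class probabilities over a grid spanning the feature space. For background on the model, see [1.9.1. Gaussian Naive Bayes - scikit-learn](https://scikit-learn.org/stable/modules/naive_bayes.html#gaussian-naive-bayes).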


In [1]:
import numpy as np
import pandas as pd
import plotly.express as px

In [2]:
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
# Load the Iris data and separate the feature matrix from the target labels.
data = load_iris()
X, y = data.data, data.target

# Hold out half of the samples; a fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.5,
    random_state=0,
)

# Train the Gaussian Naive Bayes model on the training half.
gnb = GaussianNB()
gnb_fit = gnb.fit(X_train, y_train)

In [4]:
# Evaluate the fitted classifier on a regular grid that spans the observed
# feature ranges, then plot the predicted probability of every class.

# Points per feature axis; the full grid has grid_num ** n_features points.
grid_num = 5

# One evenly spaced axis per feature, from that feature's min to max in X.
# (`lo`/`hi` are used instead of `min`/`max` so the builtins are not shadowed.)
grid_list = [
    np.linspace(lo, hi, grid_num)
    for lo, hi in zip(X.min(axis=0), X.max(axis=0))
]
mg = np.meshgrid(*grid_list)

# Flatten each coordinate array and stack them column-wise into one row per grid point.
mg_ravel = [axis.ravel() for axis in mg]
X_mesh = np.column_stack(mg_ravel)
y_mesh = gnb_fit.predict_proba(X_mesh)

df = (
    pd.DataFrame(
        np.c_[X_mesh, y_mesh],
        columns=data.feature_names + list(data.target_names),
    )
    # Long format: one row per (grid point, class); the melt defaults name the
    # new columns 'variable' (class name) and 'value' (probability).
    .melt(data.feature_names)
    # Make the third and fourth features discrete so they can drive the facets.
    .astype({data.feature_names[2]: 'category', data.feature_names[3]: 'category'})
)

fig = px.scatter(
    df,
    x=data.feature_names[0],
    y=data.feature_names[1],
    facet_col=data.feature_names[2],
    facet_row=data.feature_names[3],
    color='variable',
    size='value',
    # Readable legend/hover names without renaming the df columns.
    labels={'variable': 'class', 'value': 'probability'},
)
# Explicit show() renders the figure both in a notebook and when run as a script.
fig.show()