In [1]:
from SamBA.samba import NeighborHoodClassifier, ExpTrainWeighting, ZeroOneTrainWeighting
from SamBA.distances import *
from SamBA.relevances import *
from sklearn.datasets import make_moons, make_blobs, make_gaussian_quantiles
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import numpy as np
import plotly.express as px

# Notebook-wide legacy NumPy RNG, seeded once.
# NOTE: this object is stateful — every consumer advances it, so any cell that
# draws from it gives results that depend on how often / in which order cells ran.
RANDOM_SEED = 42
rs = np.random.RandomState(seed=RANDOM_SEED)



In [8]:
# Two-class toy problem: sklearn's Gaussian-quantiles generator (isotropic
# Gaussian, cov=2.0 times identity, 2 features). shuffle=False keeps the samples
# grouped by class.
X, y = make_gaussian_quantiles(
    mean=None,
    cov=2.0,
    n_samples=100,
    n_features=2,
    n_classes=2,
    shuffle=False,
    random_state=42,
)

# SamBA neighbourhood classifier with 10 depth-1 decision trees (stumps) as
# base estimators.
clf = NeighborHoodClassifier(
    n_estimators=10,
    base_estimator=DecisionTreeClassifier(max_depth=1),
    train_weighting=ExpTrainWeighting(),
    distance=EuclidianDist(),
    relevance=ExpRelevance(),
    normalizer=None,
)
# save_data=True is what makes clf.saved_data (read by the animated scatter
# plot) available after fitting.
clf.fit(X, y, save_data=True)

NeighborHoodClassifier(distance=<SamBA.distances.EuclidianDist object at 0x7f47e2ad01d0>,
                       n_estimators=10, normalizer=None,
                       relevance=<SamBA.relevances.ExpRelevance object at 0x7f47e2ad0198>,
                       train_weighting=<SamBA.samba.ExpTrainWeighting object at 0x7f47e2ad0160>)

In [11]:
# Animated view of training on the Gaussian-quantiles data: one frame per value
# of the "Iteration" column. Colour shows the "Margin" column, marker symbol the
# "Class" column, marker size the "Weight" column, and the hover label shows
# "Pred". A title is set so the figure is self-explanatory when skimmed.
px.scatter(
    clf.saved_data,
    x="X",
    y="Y",
    animation_frame="Iteration",
    color="Margin",
    symbol="Class",
    color_continuous_scale="Bluered",
    size="Weight",
    hover_name="Pred",
    title="SamBA on Gaussian quantiles: margin and sample weight per iteration",
)

In [10]:
# Dump the structure of every fitted stump: per-node split thresholds and split
# feature indices (sklearn stores -2 in both arrays for leaf nodes).
for stump in clf.estimators_:
    tree_struct = stump.tree_
    print(tree_struct.threshold, tree_struct.feature)

[-1.63458222 -2.         -2.        ] [ 1 -2 -2]
[-1.25396329 -2.         -2.        ] [ 0 -2 -2]
[-1.63458222 -2.         -2.        ] [ 1 -2 -2]
[-1.25396329 -2.         -2.        ] [ 0 -2 -2]
[-1.63458222 -2.         -2.        ] [ 1 -2 -2]
[-1.25396329 -2.         -2.        ] [ 0 -2 -2]
[-1.63458222 -2.         -2.        ] [ 1 -2 -2]
[-1.25396329 -2.         -2.        ] [ 0 -2 -2]
[-1.63458222 -2.         -2.        ] [ 1 -2 -2]
[-1.25396329 -2.         -2.        ] [ 0 -2 -2]


In [5]:
# Four Gaussian blobs with centres drawn in [-1, 1]^2. Their labels are folded
# pairwise into a binary problem: blobs 0/2 -> class 0, blobs 1/3 -> class 1.
#
# An integer seed is passed instead of the shared, stateful `rs`. sklearn turns
# 42 into a fresh RandomState(42), i.e. exactly what `rs` is right after it is
# created. Unlike `rs`, re-running this cell (or running cells in another order)
# always yields the same dataset.
X, blob_labels = make_blobs(
    n_samples=100,
    n_features=2,
    centers=4,
    center_box=(-1.0, 1.0),
    random_state=42,
)
# Same mapping as 2 -> 0 and 3 -> 1, but builds a new array instead of mutating
# the labels in place.
y = blob_labels % 2

# SamBA with 10 decision stumps, zero-one training weighting and the
# ExpEuclidianDist distance.
clf = NeighborHoodClassifier(
    n_estimators=10,
    base_estimator=DecisionTreeClassifier(max_depth=1),
    train_weighting=ZeroOneTrainWeighting(),
    distance=ExpEuclidianDist(),
    relevance=ExpRelevance(),
    normalizer=None,
)
# save_data=True makes clf.saved_data (read by the animated scatter plot) available.
clf.fit(X, y, save_data=True)

[2 0 2 0 0 0 0 0 0 2 2 2 0 2 2 0 0 0 0 0 2 0 0 0 0 2 0 0 0 0 0 0 0 0 2 2 0
 2 0 2 0 0 0 2 2 2 0 0 0 0 0 0 0 2 0 0 2 0 0 2 0 0 0 0 0 2 2 0 0 0 2 0 0 0
 0 0 0 2 0 0 2 0 2 0 2 0 2 2 2 2 0 2 0 2 2 2 2 0 0 0]
[ 2. -0. -0. -0.  2. -0.  2.  2.  2. -0. -0. -0. -0.  2. -0.  2.  2.  2.
  2.  2. -0.  2.  2.  2.  2. -0.  2.  2.  2. -0. -0.  2. -0.  2. -0. -0.
 -0. -0.  2.  2.  2.  2. -0. -0. -0. -0. -0.  2.  2.  2.  2.  2.  2. -0.
  2.  2. -0. -0.  2. -0. -0.  2.  2. -0.  2.  2. -0.  2.  2. -0.  2. -0.
  2. -0. -0.  2. -0.  2. -0. -0.  2. -0. -0.  2. -0. -0. -0. -0. -0.  2.
  2. -0. -0. -0. -0. -0. -0.  2.  2.  2.]
[-0.  2.  2.  2. -0.  2. -0. -0. -0.  2.  2.  2.  2. -0.  2. -0. -0. -0.
 -0. -0.  2. -0. -0. -0. -0.  2. -0. -0. -0.  2.  2. -0.  2. -0.  2.  2.
  2.  2. -0. -0. -0. -0.  2.  2.  2.  2.  2. -0. -0. -0. -0. -0. -0.  2.
 -0. -0.  2.  2. -0.  2.  2. -0. -0.  2. -0. -0.  2. -0. -0.  2. -0.  2.
 -0.  2.  2. -0.  2. -0.  2.  2. -0.  2.  2. -0.  2.  2.  2.  2.  2. -0.
 -0.  2.  2.  2.  2.  2.

NeighborHoodClassifier(distance=<SamBA.distances.ExpEuclidianDist object at 0x7f47eb21bcf8>,
                       n_estimators=10, normalizer=None,
                       relevance=<SamBA.relevances.ExpRelevance object at 0x7f47eb21bcc0>,
                       train_weighting=<SamBA.samba.ZeroOneTrainWeighting object at 0x7f47ec58def0>)

In [6]:
# Animated view of training on the merged-blobs data: one frame per value of the
# "Iteration" column. Colour shows the "Pred" column, marker symbol the "Class"
# column and marker size the "Weight" column. A title is set so the figure can
# be told apart from the Gaussian-quantiles animation above.
px.scatter(
    clf.saved_data,
    x="X",
    y="Y",
    animation_frame="Iteration",
    color="Pred",
    symbol="Class",
    color_continuous_scale="Bluered",
    size="Weight",
    title="SamBA on four merged blobs: predictions and sample weight per iteration",
)

In [7]:
# Print the (threshold, feature) arrays of each fitted base tree. sklearn uses
# -2 in both arrays for leaves, so a tree that made no split at all (a single
# leaf) prints as `[-2.] [-2]`.
for tree_struct in (fitted.tree_ for fitted in clf.estimators_):
    print(tree_struct.threshold, tree_struct.feature)

[-0.331258 -2.       -2.      ] [ 1 -2 -2]
[-0.69064876 -2.         -2.        ] [ 0 -2 -2]
[-2.] [-2]
[-2.] [-2]
[-2.] [-2]
[-2.] [-2]
[-2.] [-2]
[-2.] [-2]
[-2.] [-2]
[-2.] [-2]
