In [None]:
import warnings

import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Install a blanket "ignore" filter so no warnings are printed in cell outputs.
# NOTE(review): this hides every warning category, including deprecation
# notices from the plotting/ML libraries — consider narrowing the filter.
warnings.filterwarnings("ignore")
#plt.style.use('dark_background')

In [None]:
import utils.datasets
from utils.active_forests import ALIF, ALEIF, BALEIF, BALIF, RandomForest
from utils import *

In [None]:
from sklearn.manifold import TSNE

In [None]:
def plot_belief_distr(model,data,labels):
    """Plot one prediction-distribution curve per sample, coloured by label.

    For every row of ``data`` the model is asked twice: once for its plain
    prediction and once (``getdistr=True``) for a callable that is evaluated
    on a 1000-point grid over [0, 1]. Each curve is drawn as a line and the
    plain prediction is marked on it with a dot.

    Parameters
    ----------
    model : object
        Must expose ``predict(data)`` and ``predict(data, getdistr=True)``;
        the latter is expected to yield one callable per sample
        (presumably a density over the anomaly score — confirm).
    data : array-like
        Samples handed to ``model.predict``.
    labels : iterable
        One label per sample; ``0`` is drawn as an inlier (blue, faint),
        anything else as an anomaly (red, stronger).

    A new matplotlib figure is opened on every call; nothing is returned.
    """
    plt.figure(figsize=(6, 4), dpi=120)
    point_predictions = np.array(model.predict(data))
    belief_curves = np.array(model.predict(data, getdistr=True))

    grid = np.linspace(0, 1, 1000)
    plt.title("prediction distributions")

    # Line styles shared by the per-sample curves and the legend proxies.
    inlier_style = {"c": "midnightblue", "alpha": 0.2}
    anomaly_style = {"c": "firebrick", "alpha": 0.5}

    for curve, point, label in zip(belief_curves, point_predictions, labels):
        style = inlier_style if label == 0 else anomaly_style

        plt.plot(grid, curve(grid), **style)
        # Marker at the point prediction, always fully opaque.
        plt.plot(point, curve(point), "o", c=style["c"])

    # Empty proxy lines so the legend shows one entry per class.
    plt.plot([], [], label="inliers", **inlier_style)
    plt.plot([], [], label="anomalies", **anomaly_style)
    plt.legend()
    plt.yticks([])
    plt.xlim(0, 1)
    plt.ylim(bottom=0.1)

In [None]:
def plt_scatter_predictions_tsne(model, dataset_name, adaptive_range=True):
    """Scatter a 2-D t-SNE embedding of a dataset, coloured by labels or predictions.

    Parameters
    ----------
    model : object or the string "target"
        The literal string ``"target"`` colours each point by its dataset
        label. Any other value is treated as a model and the colours come
        from ``model.predict(data)``.
    dataset_name : str
        Name forwarded to ``datasets.load_dataset``.
    adaptive_range : bool, default True
        When True the colour scale spans the min/max of the plotted values;
        otherwise it is fixed to [0, 1].

    Notes
    -----
    Points with label 0 are drawn as circles and points with label 1 as
    larger "X" markers; other label values are not drawn. A new matplotlib
    figure is opened on every call and nothing is returned. The t-SNE
    embedding is recomputed on each call with ``random_state=0``.
    """
    plt.figure(figsize=(8, 8), dpi=80)
    # Honour the requested dataset (previously hard-coded to "wine").
    data, labels = datasets.load_dataset(dataset_name)
    labels = np.asarray(labels)
    embedded = TSNE(perplexity=5, early_exaggeration=5, random_state=0).fit_transform(data)

    # Only a real string can be the "target" sentinel; this avoids invoking a
    # model object's own __eq__ with a string operand.
    use_labels = isinstance(model, str) and model == "target"
    # Coerce to an array so the boolean masks below work even when predict()
    # returns a plain list (plot_belief_distr wraps predict() the same way).
    heatmap = np.asarray(labels if use_labels else model.predict(data))

    if adaptive_range:
        vmin, vmax = np.min(heatmap), np.max(heatmap)
    else:
        vmin, vmax = 0, 1

    # (label value, marker, marker area) — label 0 first so label 1 is drawn on top.
    class_styles = ((0, "o", 2 * 100), (1, "X", 2 * 150))
    for class_value, marker, size in class_styles:
        mask = labels == class_value
        plt.scatter(
            embedded[mask, 0],
            embedded[mask, 1],
            marker=marker,
            edgecolors="w",
            s=size,
            c=heatmap[mask],
            vmin=vmin,
            vmax=vmax,
            cmap="coolwarm",
        )
    plt.xticks([])
    plt.yticks([])

In [None]:
# Load the "wine" dataset as a (data, labels) pair; both names are reused by later cells.
data, labels = datasets.load_dataset("wine")
# BALEIF forest with 100 estimators and the "margin" query strategy.
# NOTE: `ensamble_prediction` (sic) is the keyword exactly as passed here —
# check BALEIF's signature before "correcting" the spelling.
model = BALEIF(n_estimators=100, query_strategy="margin", ensamble_prediction="naive")
# Fit on the full dataset; no labels are passed.
# NOTE(review): the model is fitted again on X_train in the following cell —
# confirm whether this first fit is still needed.
model.fit(data)
#path="wine"
# Passing the string "target" instead of a model colours the embedding by the dataset labels.
plt_scatter_predictions_tsne("target","wine")
#plt.savefig(f"images/example_st/{path}_target", edgecolor="auto")

In [None]:
# Seed NumPy's global RNG.
# Presumably this makes the split below repeatable — confirm that
# utils.train_test_split draws from the global NumPy RNG.
np.random.seed(42)
# Half of the samples go to each side, stratified on the labels.
X_train, X_test, y_train, y_test = utils.train_test_split(
    data, labels, test_size=0.5, stratify=labels
)
# Zero array shaped like y_train. It is not read anywhere in the visible
# cells — presumably a "was queried" mask; confirm it is still used.
queried = np.zeros_like(y_train)
# Receives one (features..., label) tuple per sample chosen by the query loop.
points = []
# Fit on the training half only; no labels are passed.
model.fit(X_train)

#plt_scatter_predictions_tsne(model,"wine")
#plt.savefig(f"images/example_st/{path}_iteration{0}", edgecolor="auto")
# Baseline plot before any model.update call. Note it is drawn over the
# full dataset (data, labels), not just the training half.
plot_belief_distr(model,data,labels)
#plt.savefig(f"images/example_st/{path}_distr_iteration{0}", edgecolor="auto")

In [None]:
# Query loop: 10 rounds of select -> update -> re-plot.
# NOTE: X_train / y_train are shrunk in place, so re-running this cell
# continues from the already-reduced pool instead of starting over.
for i in tqdm(range(10)):
    # Every remaining training sample is a candidate.
    eligible = np.arange(len(X_train))
    # Score the candidates with the model and take the highest-scoring index.
    interest = model.interest_on_info_for(X_train[eligible])
    selected = eligible[np.argmax(interest)]

    # Give the model the chosen sample with its label, and record the pair.
    model.update(X_train[selected], y_train[selected])
    points.append((*X_train[selected], y_train[selected]))
    # Remove the chosen sample from the pool so it cannot be picked again.
    X_train = X_train[np.arange(len(X_train)) != selected]
    y_train = y_train[np.arange(len(y_train)) != selected]

    #plt_scatter_predictions_tsne(model,"wine")
    # Opens a new figure each round, drawn over the full dataset.
    plot_belief_distr(model,data,labels)