In [None]:
import warnings
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore")
#plt.style.use('dark_background')

In [None]:
from utils import datasets
from utils.active_forests import ALIF, ALEIF, BALEIF, BALIF, RandomForest
from utils.utils import *

In [None]:
import imageio
import snakeviz
from sklearn.manifold import TSNE
%load_ext snakeviz

In [None]:
def _evaluate_scores(estimator, X_eval, y_eval):
    """Return ``(ROC-AUC, average precision)`` of ``estimator.predict(X_eval)`` against ``y_eval``.

    NOTE(review): ``roc_auc_score`` / ``average_precision_score`` are not imported
    explicitly in this notebook; presumably they come from ``from utils.utils import *``.
    """
    prediction = estimator.predict(X_eval)
    return (roc_auc_score(y_eval, prediction), average_precision_score(y_eval, prediction))


def _save_prediction_frame(estimator, queried_points, iteration, resolution, out_name):
    """Draw one prediction-heatmap frame on a new 4x4 in / 80 dpi figure, save it, return the figure.

    The file is written to ``images/example_st/{out_name}/{out_name}_iteration{iteration}``;
    the caller decides whether to display or close the returned figure.
    """
    fig = plt.figure(figsize=(4, 4), dpi=80)
    plt_scatter_predictions(estimator, queried_points, adaptive_range=False, resolution=resolution)
    fig.savefig(
        f"images/example_st/{out_name}/{out_name}_iteration{iteration}",
        edgecolor="auto",
        bbox_inches='tight',
    )
    return fig


def run_sq_example(resolution=50, performance_logs=None, runid=0, size_query=0, dt=5, n_iterations=30):
    """Run a sequential active-learning loop and save one prediction heatmap per step.

    Reads the notebook-level globals ``data``, ``labels``, ``model`` and ``path``.
    ``model`` is fit on a stratified 50/50 split of ``data``/``labels``. It is then
    updated with one queried training point per iteration: the candidate with the
    highest ``model.interest_on_info_for`` score, which is removed from the pool.

    After the initial fit and after every query, the (ROC-AUC, average precision) on the
    held-out half is appended to ``performance_logs[f"run n:{runid}"]``. A heatmap frame
    is also saved under ``images/example_st/{path}/``; the directory is created if missing.

    Parameters
    ----------
    resolution : int
        Grid points per axis for the heatmap frames.
    performance_logs : dict, optional
        Dict to record results in (mutated and returned); a new one is created when None.
    runid : int
        Used to build the log key ``f"run n:{runid}"``.
    size_query : int
        Controls which remaining training points are candidates. If falsy, or larger
        than the remaining pool, every remaining point is a candidate. Otherwise
        ``size_query`` distinct candidates are drawn at random from the whole pool.
    dt : int
        Display every ``dt``-th frame inline; all frames are still saved to disk. A falsy
        ``dt`` displays only the initial frame.
    n_iterations : int
        Maximum number of points to query (default 30). The loop stops early if the
        pool runs out.

    Returns
    -------
    dict
        ``performance_logs``, holding this run's list of (ROC-AUC, AP) tuples.
    """
    import os

    # Seeds the global NumPy RNG. Candidate sampling below draws from it, and so,
    # presumably, does train_test_split (no random_state is passed).
    np.random.seed(42)
    if performance_logs is None:
        performance_logs = {}
    log_key = f"run n:{runid}"

    X_train, X_test, y_train, y_test = train_test_split(
        data, labels, test_size=0.5, stratify=labels
    )
    os.makedirs(f"images/example_st/{path}", exist_ok=True)

    model.fit(X_train, seed=42)
    performance_logs[log_key] = [_evaluate_scores(model, X_test, y_test)]

    _save_prediction_frame(model, None, 0, resolution, path)
    plt.show()

    queried_points = []
    for i in tqdm(range(n_iterations)):
        if len(X_train) == 0:
            break  # pool exhausted: np.argmax would fail on an empty candidate set
        if not size_query or len(X_train) < size_query:
            eligible = np.arange(len(X_train))
        else:
            # Draw distinct candidates from the whole remaining pool, not just from
            # its first `size_query` indices, and without duplicates.
            eligible = np.random.choice(len(X_train), size_query, replace=False)

        interest = model.interest_on_info_for(X_train[eligible])
        selected = eligible[np.argmax(interest)]

        model.update(X_train[selected], y_train[selected])
        queried_points.append((*X_train[selected], y_train[selected]))
        keep = np.arange(len(X_train)) != selected
        X_train, y_train = X_train[keep], y_train[keep]

        performance_logs[log_key].append(_evaluate_scores(model, X_test, y_test))

        fig = _save_prediction_frame(model, np.array(queried_points), i + 1, resolution, path)
        if dt and i % dt == 0:
            plt.show()
        else:
            # Close frames that are not displayed so open figures don't accumulate
            # across iterations.
            plt.close(fig)

    return performance_logs

In [None]:
def plt_scatter_predictions(model, points=None, adaptive_range=False, resolution=50,
                            extent=0.9, vmin=0.25, vmax=0.75):
    """Draw ``model``'s prediction heatmap on the current figure, optionally overlaying points.

    The heatmap covers the square ``[-extent, extent]^2``, sampled on a
    ``resolution x resolution`` grid. Axis ticks are hidden.

    Parameters
    ----------
    model : object with ``predict``
        Called once per grid row with a ``(resolution, 2)`` array of ``(x, y)`` coordinates.
    points : array-like, optional
        Rows of ``(x, y, label)``. Label 1 is drawn in firebrick, label 0 in midnightblue.
        None or an empty array draws no points.
    adaptive_range : bool
        If True, use matplotlib's automatic color limits and add a colorbar. Otherwise,
        clamp the color scale to ``[vmin, vmax]``.
    resolution : int
        Grid points per axis.
    extent : float
        Half-width of the plotted square (default 0.9).
    vmin, vmax : float
        Fixed color limits used when ``adaptive_range`` is False (default 0.25 / 0.75).
    """
    grid = np.linspace(-extent, extent, resolution)
    # Row j holds predictions at (grid[i], grid[j]) for every i. Rows therefore index y
    # and columns index x, which is the layout imshow(origin="lower") expects.
    # NOTE(review): predict is still called row by row. If it scores samples
    # independently, one call on the full meshgrid would be faster.
    heatmap = np.stack([
        model.predict(np.column_stack([grid, np.full_like(grid, y_value)]))
        for y_value in grid
    ])

    image_extent = (-extent, extent, -extent, extent)
    if adaptive_range:
        plt.imshow(heatmap, extent=image_extent, origin="lower", cmap="coolwarm")
        plt.colorbar()
    else:
        plt.imshow(heatmap, extent=image_extent, vmin=vmin, vmax=vmax, origin="lower", cmap="coolwarm")

    points = None if points is None else np.asarray(points)
    if points is not None and points.size:
        xs, ys, point_labels = points[:, 0], points[:, 1], points[:, 2]
        plt.scatter(xs[point_labels == 1], ys[point_labels == 1], s=20, facecolors='firebrick', edgecolors='w')
        plt.scatter(xs[point_labels == 0], ys[point_labels == 0], s=20, facecolors='midnightblue', edgecolors='w')
    plt.xticks([])
    plt.yticks([])

In [None]:
def plt_scatter_predictions_tsne(model, k=10, adaptive_range=True, X=None, y=None):
    """Scatter the dataset on a new 8x8 in / 80 dpi figure, colored by label or by model score.

    Despite the name, no t-SNE embedding is computed. Points are drawn at their raw
    coordinates (columns 0 and 1): label-0 points as circles, label-1 points as larger
    crosses.

    Parameters
    ----------
    model : str or object with ``predict``
        The string ``"target"`` colors each point by its label. Anything else is asked
        for ``model.predict(X)``, and points are colored by that score.
    k : int
        Unused; kept so existing calls keep working.
    adaptive_range : bool
        If True, the color scale spans the min..max of the plotted values; else [0, 1].
    X, y : array-like, optional
        Points and 0/1 labels to plot. They default to the notebook-level ``data`` and
        ``labels`` globals populated by the dataset-loading cells.
    """
    plt.figure(figsize=(8, 8), dpi=80)
    X = np.asarray(data if X is None else X)
    y = np.asarray(labels if y is None else y)

    # isinstance guard: never ask an arbitrary model object to compare itself with a string.
    color_by_target = isinstance(model, str) and model == "target"
    values = y if color_by_target else np.asarray(model.predict(X))
    if adaptive_range:
        vmin, vmax = np.min(values), np.max(values)
    else:
        vmin, vmax = 0, 1

    shared = dict(edgecolors='w', vmin=vmin, vmax=vmax, cmap="coolwarm")
    negative, positive = y == 0, y == 1
    plt.scatter(X[negative, 0], X[negative, 1], marker="o", s=2 * 100, c=values[negative], **shared)
    plt.scatter(X[positive, 0], X[positive, 1], marker="X", s=2 * 150, c=values[positive], **shared)
    # Fixed view box; presumably the datasets lie inside [-0.9, 0.9]^2. TODO confirm.
    plt.xlim(-0.9, 0.9)
    plt.ylim(-0.9, 0.9)
    plt.xticks([])
    plt.yticks([])

In [None]:
# Load the k=1 dataset into the notebook-level `data` / `labels` globals, which
# run_sq_example and plt_scatter_predictions_tsne read implicitly.
# NOTE(review): what `k` selects is defined in utils.datasets — confirm.
data, labels = datasets.load_dataset(k=1)
# Plot the ground-truth labels ("target" colors points by `labels`; the plotting
# function does not use its `k` argument).
plt_scatter_predictions_tsne("target",k=1)
# Collects (path, logs) pairs, one per experiment cell below.
performance_logs = []
# Heatmap grid points per axis, passed to run_sq_example by the experiment cells.
resolution = 100

In [None]:
# BALIF run on the dataset currently bound to `data`/`labels` (k=1 under Run All).
# `model` and `path` are globals read by run_sq_example: frames go to images/example_st/sparse/.
# NOTE(review): the `ensamble_prediction` spelling presumably matches the utils.active_forests
# signature — check there before "fixing" it.
model = BALIF(n_estimators=100, max_samples=512, query_strategy="margin", ensamble_prediction="naive")
path="sparse"
# dt=1: a frame is displayed after every query.
performance_logs.append((path,run_sq_example(resolution=resolution, dt=1)))

In [None]:
# Same experiment as the BALIF cell, but with BALEIF. This rebinds the `model` / `path`
# globals read by run_sq_example; frames go to images/example_st/sparse_baleif/.
model = BALEIF(n_estimators=100, max_samples=512, query_strategy="margin", ensamble_prediction="naive")
path="sparse_baleif"
# dt=1: a frame is displayed after every query.
performance_logs.append((path,run_sq_example(resolution=resolution, dt=1)))

In [None]:
# Switch to the k=10 dataset. This rebinds the `data` / `labels` globals that the
# plotting and experiment functions read.
data, labels = datasets.load_dataset(k=10)
# Plot the ground-truth labels. `k` is kept in sync with the dataset loaded above
# (the plotting function itself ignores `k`).
plt_scatter_predictions_tsne("target",k=10)
# Start a fresh result list for this dataset.
# NOTE(review): this drops the results collected by the earlier cells from
# `performance_logs`; save them first if they are needed later.
performance_logs = []

In [None]:
# BALIF run on the k=10 dataset. `model` / `path` are globals read by run_sq_example;
# frames go to images/example_st/anomalous/.
model = BALIF(n_estimators=100, max_samples=512, query_strategy="margin", ensamble_prediction="naive")
path="anomalous"
# Default dt=5: only every 5th query's frame is meant to be displayed.
performance_logs.append((path,run_sq_example(resolution=resolution)))

In [None]:
# BALEIF run on the k=10 dataset. `model` / `path` are globals read by run_sq_example;
# frames go to images/example_st/margin/.
# NOTE(review): the sparse BALEIF run uses path "sparse_baleif"; "margin" may have been
# meant as "anomalous_baleif" — confirm before relying on the output location.
model = BALEIF(n_estimators=100, max_samples=512, query_strategy="margin", ensamble_prediction="naive")
path="margin"
performance_logs.append((path,run_sq_example(resolution=resolution)))