# Test Notebook for the PR (Precision-Recall) Thresholder
## Import Relevant Packages

In [1]:
import numpy as np
from thresholder import PR_Thresholder
from sklearn.ensemble import RandomForestClassifier

## Generate Test Data

In [2]:
# Seed the global NumPy RNG so the synthetic data (and everything derived from
# it further down the notebook) is identical on every re-run.
np.random.seed(0)

n_samples, n_features, n_classes = 1000, 4, 3

# Random Gaussian features for a train split and a test split.
X_train = np.random.randn(n_samples, n_features)
X_test = np.random.randn(n_samples, n_features)

# Random integer labels in [0, n_classes). Redraw until both label vectors
# contain every class, so each class appears in training and evaluation
# (the cutoffs shown below are keyed per class).
while True:
    train_Y = np.random.randint(n_classes, size=n_samples)
    true_Y = np.random.randint(n_classes, size=n_samples)
    if all(len(np.unique(y)) == n_classes for y in (train_Y, true_Y)):
        break

## Train Model

In [3]:
# Fit a random forest on the synthetic training split, then score the test
# split. An explicit random_state pins the forest's bootstrap and feature
# sampling, so re-running the notebook yields the same probabilities.
model = RandomForestClassifier(random_state=0)
model.fit(X_train, train_Y)
# Class-membership probabilities for X_test: one row per sample, one column
# per class.
predict_proba = model.predict_proba(X_test)

## fit Method

In [4]:
# Learn per-class probability cutoffs with the "prec_rec" strategy.
# NOTE(review): the cutoffs are fitted on X_test / true_Y, and the transform
# cell below is applied to probabilities for that same X_test, so its labels
# are in-sample with respect to the threshold choice.
# save_folder is presumably where fit writes its PR artefacts; confirm in
# thresholder.PR_Thresholder.fit.
fit_options = dict(method="prec_rec", save_folder="./PR/")
thres = PR_Thresholder()
thres.fit(model, X_test, true_Y, **fit_options)

## Access the Optimal Cutoffs for Probabilities

In [5]:
# Display the cutoffs learned by fit, as a dict mapping class label -> cutoff.
# Kept as a bare trailing expression so the notebook echoes the value.
thres.cutoffs

{0: 0.4, 1: 0.39, 2: 0.38}

## transform Method

In [6]:
# Apply the learned cutoffs to the test-set probabilities. The result shown
# below is a list of class labels (0/1/2), presumably one per row of
# predict_proba; confirm in thresholder.PR_Thresholder.transform.
thres.transform(predict_proba)

[1,
 0,
 1,
 1,
 0,
 1,
 1,
 2,
 2,
 1,
 0,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 0,
 1,
 1,
 1,
 2,
 0,
 2,
 2,
 1,
 1,
 0,
 0,
 2,
 2,
 2,
 1,
 0,
 1,
 1,
 0,
 2,
 0,
 0,
 2,
 2,
 0,
 0,
 2,
 2,
 2,
 1,
 0,
 1,
 0,
 2,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 2,
 2,
 0,
 0,
 2,
 1,
 0,
 0,
 1,
 0,
 2,
 1,
 1,
 0,
 1,
 1,
 2,
 0,
 2,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 2,
 2,
 0,
 0,
 2,
 2,
 1,
 2,
 2,
 0,
 2,
 0,
 2,
 0,
 2,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 2,
 1,
 2,
 0,
 1,
 2,
 0,
 0,
 2,
 2,
 2,
 0,
 0,
 1,
 0,
 0,
 0,
 2,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 0,
 1,
 0,
 0,
 1,
 2,
 0,
 0,
 1,
 1,
 1,
 0,
 2,
 0,
 0,
 2,
 0,
 2,
 2,
 2,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 2,
 2,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 2,
 2,
 0,
 0,
 0,
 0,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 2,
 2,
 2,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 2,
 0,
 2,
 0,
 2,
 0,
 1,
 0,
 0,
 2,
 1,
 2,
 1,
 1,
 2,
 0,
 2,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 2,
 1,
 1,
 1,
 0,
 2,
 1,
 1,
 2,
 1,
 0,


## fit_transform Method

In [8]:
# Single-call variant: presumably fits the cutoffs on X_test / true_Y and then
# applies them to predict_proba (confirm in thresholder). Left as a bare
# trailing expression so the notebook echoes the returned labels.
thres.fit_transform(
    model,
    X_test,
    true_Y,
    predict_proba,
    method="prec_rec",
    save_folder="./PR/",
)

[1,
 0,
 1,
 1,
 0,
 1,
 1,
 2,
 2,
 1,
 0,
 2,
 2,
 1,
 1,
 2,
 2,
 2,
 0,
 1,
 1,
 1,
 2,
 0,
 2,
 2,
 1,
 1,
 0,
 0,
 2,
 2,
 2,
 1,
 0,
 1,
 1,
 0,
 2,
 0,
 0,
 2,
 2,
 0,
 0,
 2,
 2,
 2,
 1,
 0,
 1,
 0,
 2,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 2,
 2,
 0,
 0,
 2,
 1,
 0,
 0,
 1,
 0,
 2,
 1,
 1,
 0,
 1,
 1,
 2,
 0,
 2,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 2,
 2,
 0,
 0,
 2,
 2,
 1,
 2,
 2,
 0,
 2,
 0,
 2,
 0,
 2,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 2,
 1,
 2,
 0,
 1,
 2,
 0,
 0,
 2,
 2,
 2,
 0,
 0,
 1,
 0,
 0,
 0,
 2,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 2,
 1,
 1,
 1,
 2,
 0,
 1,
 0,
 0,
 1,
 2,
 0,
 0,
 1,
 1,
 1,
 0,
 2,
 0,
 0,
 2,
 0,
 2,
 2,
 2,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 2,
 2,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 2,
 2,
 0,
 0,
 0,
 0,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 0,
 2,
 2,
 2,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 2,
 0,
 2,
 0,
 2,
 0,
 1,
 0,
 0,
 2,
 1,
 2,
 1,
 1,
 2,
 0,
 2,
 2,
 1,
 0,
 0,
 0,
 0,
 1,
 2,
 1,
 1,
 1,
 0,
 2,
 1,
 1,
 2,
 1,
 0,
