In [1]:
import joblib
from pathlib import Path
import kachery_p2p as kp
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from utils import parse_sf_results

In [2]:
models_path = Path('models')
models_path.mkdir(exist_ok=True)

In [3]:
# Load in the data from the sf url
khash = 'sha1://b3444629251cafda919af535f0e9837279151c6e/spikeforest-full-gt-qm.json?manifest=cf73c99d06c11e328e635e14dc24b8db7372db3d'
sf_data = kp.load_json(khash)

In [4]:
# Get 'paired' study names to exclude
paired_study_names = list(set([entry['studyName'] for entry in sf_data if 'paired' in entry['studyName']]))

In [5]:
# Get metrics-fp dataset, excluding all paired studies
dataset = parse_sf_results(sf_data=sf_data, exclude_study_names=paired_study_names, train_test_split=True)

In [6]:
model = make_pipeline(RandomForestClassifier())
model.fit(dataset['X_train'], dataset['y_train'])

Pipeline(steps=[('randomforestclassifier', RandomForestClassifier())])

In [7]:
y_test_preds = model.predict(dataset['X_test'])
f1 = f1_score(dataset['y_test'], y_test_preds)
print(f'F1-Score is {f1}')

F1-Score is 0.9052069425901201


In [8]:
joblib.dump(model, models_path / 'random_forest_general_clf.joblib')


['models/random_forest_general_clf.joblib']