scikit-learn wrapper for BERT

A scikit-learn model for text classification/regression based on the Hugging Face PyTorch port of Google's BERT (Bidirectional Encoder Representations from Transformers) model.

  • Added an MSE loss for regression tasks
  • Added configurable MLP as final classifier/regressor

basic operation

model.fit(X, y) where

  • X: list, pandas DataFrame, or numpy array of texts or text pairs

  • y: list, pandas DataFrame, or numpy array of labels/targets
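For example, X can hold single texts or text pairs; a minimal sketch with made-up data (the list-of-pairs layout for sentence-pair tasks is assumed here for illustration):

# single-text classification: X is a list of strings
X_train = ["the acting was wonderful",
           "a complete waste of two hours"]
y_train = ["pos", "neg"]

# sentence-pair tasks: each element of X is a pair of texts
# (pair format assumed for illustration)
X_pairs = [["a man is eating food", "a person is having a meal"],
           ["a man is eating food", "nobody is eating"]]
y_pairs = ["entailment", "contradiction"]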

from bert_sklearn import BertClassifier
from bert_sklearn import BertRegressor
from bert_sklearn import load_model

# define model
model = BertClassifier()   # for classification 
# model = BertRegressor()  # for regression 
 
# fit model
model.fit(X_train, y_train)

# make predictions
y_pred = model.predict(X_test)

# make probability predictions
y_prob = model.predict_proba(X_test)

# score model on test data
model.score(X_test, y_test)

# save model to disk
savefile = '/data/mymodel.bin'
model.save(savefile)

# load model from disk
new_model = load_model(savefile)

# do stuff with new model
new_model.score(X_test, y_test)
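As a usage note, since the wrapper follows scikit-learn conventions, the columns of predict_proba presumably line up with the fitted classes; a sketch assuming the standard classes_ attribute:

import numpy as np

probs = model.predict_proba(X_test)   # shape: (n_samples, n_classes)

# assuming the standard sklearn classes_ attribute,
# recover the most likely label per sample
labels = model.classes_[np.argmax(probs, axis=1)]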

See demo notebook.

model options

# try different options...
model.bert_model = 'bert-large-uncased'    # which pretrained BERT to fine-tune
model.num_mlp_layers = 3                   # depth of the final MLP classifier/regressor
model.max_seq_length = 196                 # inputs longer than this are truncated
model.epochs = 4                           # number of fine-tuning epochs
model.learning_rate = 4e-5                 # optimizer learning rate
model.gradient_accumulation_steps = 4      # accumulate grads to emulate a larger batch

# fit model
model.fit(X_train, y_train)

# do stuff...
model.score(X_test, y_test)

See options for the full list of configurable parameters.
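Since the estimator is scikit-learn compatible (it works inside GridSearchCV below), the same options can presumably also be passed as constructor keyword arguments; a sketch under that assumption:

from bert_sklearn import BertClassifier

# equivalent to setting the attributes after construction,
# assuming standard sklearn-style constructor parameters
model = BertClassifier(bert_model='bert-large-uncased',
                       num_mlp_layers=3,
                       max_seq_length=196,
                       epochs=4,
                       learning_rate=4e-5,
                       gradient_accumulation_steps=4)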

hyperparameter tuning

from sklearn.model_selection import GridSearchCV
from bert_sklearn import BertClassifier

params = {'epochs':[3, 4], 'learning_rate':[2e-5, 3e-5, 5e-5]}

# wrap classifier in GridSearchCV
# (validation_fraction=0 turns off the internal validation split,
#  since GridSearchCV manages the held-out data itself)
clf = GridSearchCV(BertClassifier(validation_fraction=0),
                   params,
                   scoring='accuracy',
                   verbose=True)

# fit gridsearch 
clf.fit(X_train, y_train)
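When the search finishes, results are available through the standard GridSearchCV attributes:

# best hyperparameter combination and its cross-validated accuracy
print(clf.best_params_)
print(clf.best_score_)

# with the default refit=True, the best estimator is used directly
y_pred = clf.predict(X_test)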

See demo_tuning_hyperparameters notebook.

GLUE datasets

The train and dev sets from the GLUE (General Language Understanding Evaluation) benchmark were run with the bert-base-uncased model and compared against the results reported in the Google paper and on the GLUE leaderboard.

                          MNLI(m/mm)   QQP    QNLI   SST-2   CoLA   STS-B   MRPC   RTE
BERT base (leaderboard)   84.6/83.4    89.2   90.1   93.5    52.1   87.1    84.8   66.4
bert-sklearn              83.7/83.9    90.2   88.6   92.3    58.1   89.7    86.8   64.6

Individual runs can be found here.

installation

Requires Python >= 3.5 and PyTorch >= 0.4.1.

# install pytorch-pretrained-bert from PyPI
pip install pytorch-pretrained-bert

# setup bert-sklearn locally
git clone -b master https://github.com/charles9n/bert-sklearn
cd bert-sklearn
pip install .
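To sanity-check the installation, a minimal sketch from a Python shell:

# quick smoke test (assumes the install above succeeded)
from bert_sklearn import BertClassifier

model = BertClassifier()
print(model)   # should print the estimator and its default parameters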

references

  • Devlin, J., Chang, M.-W., Lee, K., Toutanova, K. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. arXiv:1810.04805.
  • Wang, A., Singh, A., Michael, J., Hill, F., Levy, O., Bowman, S. R. (2018). GLUE: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding. arXiv:1804.07461.
