/
_sbert.py
83 lines (66 loc) · 2.6 KB
/
_sbert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer as SBERT
from embetter.base import EmbetterBase
class SentenceEncoder(EmbetterBase):
    """
    Encoder that can numerically encode sentences.

    Arguments:
        name: name of model, see available options
        device: manually override cpu/gpu device, tries to grab gpu automatically when available

    The following model names should be supported:

    - `all-mpnet-base-v2`
    - `multi-qa-mpnet-base-dot-v1`
    - `all-distilroberta-v1`
    - `all-MiniLM-L12-v2`
    - `multi-qa-distilbert-cos-v1`
    - `all-MiniLM-L6-v2`
    - `multi-qa-MiniLM-L6-cos-v1`
    - `paraphrase-multilingual-mpnet-base-v2`
    - `paraphrase-albert-small-v2`
    - `paraphrase-multilingual-MiniLM-L12-v2`
    - `paraphrase-MiniLM-L3-v2`
    - `distiluse-base-multilingual-cased-v1`
    - `distiluse-base-multilingual-cased-v2`

    You can find more options, and information, on the [sentence-transformers docs page](https://www.sbert.net/docs/pretrained_models.html#model-overview).

    **Usage**:

    ```python
    import pandas as pd
    from sklearn.pipeline import make_pipeline
    from sklearn.linear_model import LogisticRegression

    from embetter.grab import ColumnGrabber
    from embetter.text import SentenceEncoder

    # Let's suppose this is the input dataframe
    dataf = pd.DataFrame({
        "text": ["positive sentiment", "super negative"],
        "label_col": ["pos", "neg"]
    })

    # This pipeline grabs the `text` column from a dataframe
    # which then get fed into Sentence-Transformers' all-MiniLM-L6-v2.
    text_emb_pipeline = make_pipeline(
        ColumnGrabber("text"),
        SentenceEncoder('all-MiniLM-L6-v2')
    )
    X = text_emb_pipeline.fit_transform(dataf, dataf['label_col'])

    # This pipeline can also be trained to make predictions, using
    # the embedded features.
    text_clf_pipeline = make_pipeline(
        text_emb_pipeline,
        LogisticRegression()
    )

    # Prediction example
    text_clf_pipeline.fit(dataf, dataf['label_col']).predict(dataf)
    ```
    """

    def __init__(self, name="all-MiniLM-L6-v2", device=None):
        """Load the sentence-transformers model, auto-detecting a device when none is given."""
        # Compare against None explicitly: `if not device` would also match
        # falsy-but-valid values such as the integer GPU index 0, silently
        # overriding the caller's choice with the auto-detected device.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.name = name
        self.device = device
        self.tfm = SBERT(name, device=self.device)

    def transform(self, X, y=None):
        """Transforms the text into a numeric representation."""
        # Convert pd.Series objects into an encode-compatible array.
        if isinstance(X, pd.Series):
            X = X.to_numpy()
        return self.tfm.encode(X)