-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
/
clustering.py
120 lines (91 loc) · 4.26 KB
/
clustering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import logging
import pickle
from collections import OrderedDict
from pathlib import Path
from typing import Optional, Union
import joblib
from sklearn.base import BaseEstimator, ClusterMixin
from sklearn.metrics import normalized_mutual_info_score
from tqdm import tqdm
from flair.data import Corpus, _iter_dataset
from flair.datasets import DataLoader
from flair.embeddings import DocumentEmbeddings
log = logging.getLogger("flair")
class ClusteringModel:
    """A wrapper class to apply sklearn clustering models on DocumentEmbeddings."""

    def __init__(self, model: Union[ClusterMixin, BaseEstimator], embeddings: DocumentEmbeddings) -> None:
        """Instantiate the ClusteringModel.

        Args:
            model: the clustering algorithm from sklearn this wrapper will use.
            embeddings: the flair DocumentEmbedding this wrapper uses to calculate a vector for each sentence.
        """
        self.model = model
        self.embeddings = embeddings

    def fit(self, corpus: Corpus, **kwargs):
        """Trains the model.

        Args:
            corpus: the flair corpus this wrapper will use for fitting the model.
            **kwargs: parameters propagated to the models `.fit()` method.
        """
        X = self._convert_dataset(corpus)

        # Lazy %-style args: the message is only formatted if INFO is enabled.
        log.info("Start clustering %s with %d Datapoints.", self.model, len(X))
        self.model.fit(X, **kwargs)
        log.info("Finished clustering.")

    def predict(self, corpus: Corpus):
        """Predict labels given a list of sentences and returns the respective class indices.

        Args:
            corpus: the flair corpus this wrapper will use for predicting the labels.

        Returns:
            The array of cluster indices produced by the underlying sklearn model.
        """
        X = self._convert_dataset(corpus)
        log.info("Start the prediction %s with %d Datapoints.", self.model, len(X))
        predictions = self.model.predict(X)

        # Attach each predicted cluster index to its sentence as a "cluster" label.
        # Iteration order matches _convert_dataset, which also walks
        # corpus.get_all_sentences(), so index idx lines up with row idx of X.
        for idx, sentence in enumerate(_iter_dataset(corpus.get_all_sentences())):
            sentence.set_label("cluster", str(predictions[idx]))

        log.info("Finished prediction and labeled all sentences.")
        return predictions

    def save(self, model_file: Union[str, Path]):
        """Saves current model.

        Args:
            model_file: path where to save the model.
        """
        # The model is pickled to bytes first, then those bytes are dumped with
        # joblib; load() mirrors this exactly, so the on-disk format must not change.
        joblib.dump(pickle.dumps(self), str(model_file))
        log.info("Saved the model to: %s", model_file)

    @staticmethod
    def load(model_file: Union[str, Path]):
        """Loads a model from a given path.

        Args:
            model_file: path to the file where the model is saved.

        Returns:
            The deserialized ClusteringModel instance.

        Warning:
            This deserializes with ``pickle``, which can execute arbitrary code.
            Only load model files from trusted sources.
        """
        log.info("Loading model from: %s", model_file)
        return pickle.loads(joblib.load(str(model_file)))

    def _convert_dataset(
        self, corpus, label_type: Optional[str] = None, batch_size: int = 32, return_label_dict: bool = False
    ):
        """Makes a flair-corpus sklearn compatible.

        Turns the corpora into X, y datasets as required for most sklearn clustering models.
        Ref.: https://scikit-learn.org/stable/modules/classes.html#module-sklearn.cluster

        Args:
            corpus: the flair corpus whose sentences are embedded.
            label_type: if given, gold labels of this type are extracted as y.
            batch_size: number of sentences embedded per batch.
            return_label_dict: if True, also return the label->index mapping.

        Returns:
            X if ``label_type`` is None; otherwise (X, y), or (X, y, label_dict)
            when ``return_label_dict`` is True.
        """
        log.info("Embed sentences...")
        sentences = []
        for batch in tqdm(DataLoader(corpus.get_all_sentences(), batch_size=batch_size)):
            self.embeddings.embed(batch)
            sentences.extend(batch)

        # Move embeddings off the accelerator and detach so sklearn gets plain numpy.
        X = [sentence.embedding.cpu().detach().numpy() for sentence in sentences]

        if label_type is None:
            return X

        # NOTE(review): assumes every sentence carries at least one label of
        # label_type — an unlabeled sentence would raise IndexError here.
        labels = [sentence.get_labels(label_type)[0].value for sentence in sentences]

        # OrderedDict.fromkeys deduplicates while preserving first-seen order,
        # so label indices are assigned in order of first appearance.
        label_dict = {v: k for k, v in enumerate(OrderedDict.fromkeys(labels))}

        # Direct indexing (not .get): every label is a key by construction, and
        # indexing fails fast instead of silently yielding None.
        y = [label_dict[label] for label in labels]

        if return_label_dict:
            return X, y, label_dict
        return X, y

    def evaluate(self, corpus: Corpus, label_type: str):
        """This method calculates some evaluation metrics for the clustering.

        Also, the result of the evaluation is logged.

        Args:
            corpus: the flair corpus this wrapper will use for evaluation.
            label_type: the label from the sentence will be used for the evaluation.
        """
        X, y = self._convert_dataset(corpus, label_type=label_type)
        predictions = self.model.predict(X)

        # NMI is symmetric in its two arguments, so the (pred, true) order
        # produces the same score as (true, pred).
        log.info("NMI - Score: %s", normalized_mutual_info_score(predictions, y))