-
Notifications
You must be signed in to change notification settings - Fork 0
/
strength_thesaurus.py
105 lines (82 loc) · 3.03 KB
/
strength_thesaurus.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
from unittest import result
import numpy as np
try:
from .text_transform import get_progressbar
except (ModuleNotFoundError, ImportError):
from text_transform import get_progressbar
class StrengthThesaurus:
    """Association thesaurus built from two parallel collections of texts.

    ``fit`` records, for every term, the indices of the ``X`` texts and of
    the ``Y`` texts in which it occurs.  The strength of the directed link
    ``tj -> ti`` is then::

        (|X_docs(tj) & Y_docs(ti)| + 1) / (|Y_docs(ti)| + 2)

    and ``0`` when the two index sets have nothing in common.  Terms are
    addressed by their position in the internal term list ``_list_``.

    Attributes:
        alpha: Minimum strength a term needs to be returned by
            ``expansion_term``.
        length: Maximum number of terms returned by ``expansion_term``.
        vocabulary: Maps each term to ``{'X': [...], 'Y': [...]}``, the
            indices of the texts that contain it.
    """

    def __init__(self, alpha=0.5, length=10) -> None:
        """Create an empty thesaurus.

        Args:
            alpha: Strength threshold used when expanding a term.
            length: Maximum number of expansion terms per input term.
        """
        self.alpha = alpha
        self.length = length
        # Cache of already computed non-zero strengths, keyed by (tj, ti).
        self._strength_ = {}
        # Start empty so len() and the expansion methods work before
        # fit()/load_model() instead of raising AttributeError.
        self.vocabulary = {}
        self._list_ = []

    def __getitem__(self, index: tuple) -> float:
        """Return the strength for ``index == (tj, ti)`` (term positions).

        Delegates to ``strength`` so the value is computed on demand; the
        cache is keyed by tuples and cannot be indexed as a nested mapping.
        """
        j, i = index
        return self.strength(j, i)

    def __len__(self):
        """Return the number of distinct terms in the vocabulary."""
        return len(self.vocabulary)

    def expansion_query(self, query):
        """Expand every term of ``query`` and merge the results.

        Args:
            query: Iterable of terms.

        Returns:
            A list with the union of the expansions of all terms, without
            duplicates and in no particular order.
        """
        terms = set()
        for term in query:
            terms.update(self.expansion_term(term))
        return list(terms)

    def expansion_term(self, term):
        """Return the terms most strongly associated with ``term``.

        Args:
            term: The term to expand.

        Returns:
            Up to ``self.length`` terms whose strength with respect to
            ``term`` is at least ``self.alpha``, strongest first.  An empty
            list when ``term`` is not in the vocabulary.
        """
        if term not in self.vocabulary:
            return []

        bar = get_progressbar(len(self._list_), f' expansion {term} ')
        bar.start()
        index = self._list_.index(term)
        candidates = []
        for j in range(len(self._list_)):
            value = self.strength(index, j)
            # Filtering before the (stable) sort yields the same ranking as
            # sorting everything first, but sorts fewer items.
            if value >= self.alpha:
                candidates.append((j, value))
            bar.update(j + 1)
        bar.finish()

        candidates.sort(key=lambda pair: pair[1], reverse=True)
        return [self._list_[j] for j, _ in candidates[:self.length]]

    def strength(self, tj, ti):
        """Return the strength of the link from term ``tj`` to term ``ti``.

        Args:
            tj: Position of the source term in the term list.
            ti: Position of the target term in the term list.

        Returns:
            ``(n + 1) / (m + 2)`` where ``n`` is the number of text indices
            shared by the ``X`` occurrences of ``tj`` and the ``Y``
            occurrences of ``ti`` and ``m`` is the number of distinct ``Y``
            occurrences of ``ti``; ``0`` when they share nothing.

        Raises:
            IndexError: If ``tj`` or ``ti`` is not a valid term position.
        """
        key = (tj, ti)
        cached = self._strength_.get(key)
        if cached is not None:
            return cached

        docs_tj = set(self.vocabulary[self._list_[tj]]['X'])
        docs_ti = set(self.vocabulary[self._list_[ti]]['Y'])
        shared = docs_ti & docs_tj
        if not shared:
            # Zero strengths are deliberately not cached.
            return 0

        value = (len(shared) + 1) / (len(docs_ti) + 2)
        self._strength_[key] = value
        return value

    def fit(self, X, Y):
        """Build the vocabulary from two collections of tokenised texts.

        Args:
            X: Iterable of texts, each an iterable of terms.
            Y: Iterable of texts, each an iterable of terms.

        Any previously cached strengths are discarded, because they are
        keyed by term positions that are only valid for the old vocabulary.
        """
        vocabulary = {}
        for side, corpus in (('X', X), ('Y', Y)):
            for i, text in enumerate(corpus):
                for word in text:
                    entry = vocabulary.get(word)
                    if entry is None:
                        entry = vocabulary[word] = {'X': [], 'Y': []}
                    entry[side].append(i)

        self.vocabulary = vocabulary
        self._list_ = list(vocabulary)
        self._strength_ = {}

    def dumps_path(self, path, key=""):
        """Return the file path used to store the model under ``path/key``."""
        return f'{path}/{key}/data_strength_t.json'

    def save_model(self, path, key=""):
        """Write the vocabulary and term list as JSON to ``dumps_path``.

        The target directory must already exist.
        """
        with open(self.dumps_path(path, key), 'w', encoding='utf-8') as f:
            json.dump({'v': self.vocabulary, 'l': self._list_}, f)

    def load_model(self, path, key=''):
        """Load the vocabulary and term list written by ``save_model``.

        Cached strengths are discarded because they belong to whatever
        vocabulary was loaded before.
        """
        with open(self.dumps_path(path, key), 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.vocabulary = data['v']
        self._list_ = data['l']
        self._strength_ = {}
def print_expansion(query):
    """Print the terms of ``query`` on one line.

    Every term is followed by a single space; no newline is written.

    Args:
        query: Iterable of terms to print.
    """
    for word in query:
        print(str(word) + " ", end="")