# manifold.py (forked from flairNLP/flair)
from sklearn.manifold import TSNE
import tqdm
import numpy


class _Transform:
    """Thin wrapper around a scikit-learn style dimensionality reducer."""

    def __init__(self):
        pass

    def fit(self, X):
        # Delegate to the wrapped reducer; subclasses set self.transform.
        return self.transform.fit_transform(X)


class tSNE(_Transform):
    def __init__(self):
        super().__init__()

        # 2-D t-SNE projection of the embedding matrix.
        self.transform = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
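

# Illustrative sketch (not in the original flair module): any scikit-learn
# reducer exposing fit_transform can reuse the _Transform pattern above.
# The PCAProjection name is ours; it is shown only as a cheaper sanity check
# to try before running the full t-SNE.
from sklearn.decomposition import PCA


class PCAProjection(_Transform):
    def __init__(self):
        super().__init__()
        # Linear 2-D projection; much faster than t-SNE on large matrices.
        self.transform = PCA(n_components=2)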


class Visualizer(object):
    def visualize_word_emeddings(self, embeddings, sentences, output_file):
        # Embed every token, project to 2-D with t-SNE and write an
        # interactive HTML scatter plot to output_file.
        X = self.prepare_word_embeddings(embeddings, sentences)
        contexts = self.word_contexts(sentences)

        trans_ = tSNE()
        reduced = trans_.fit(X)

        self.visualize(reduced, contexts, output_file)

    def visualize_char_emeddings(self, embeddings, sentences, output_file):
        # Same pipeline, but one point per character, using the character
        # language model's hidden states.
        X = self.prepare_char_embeddings(embeddings, sentences)
        contexts = self.char_contexts(sentences)

        trans_ = tSNE()
        reduced = trans_.fit(X)

        self.visualize(reduced, contexts, output_file)

    @staticmethod
    def prepare_word_embeddings(embeddings, sentences):
        # Stack one embedding row per token into a (num_tokens, dim) matrix.
        X = []
        for sentence in tqdm.tqdm(sentences):
            embeddings.embed(sentence)
            for i, token in enumerate(sentence):
                X.append(token.embedding.detach().numpy()[None, :])

        X = numpy.concatenate(X, 0)

        return X

    @staticmethod
    def word_contexts(sentences):
        # Build an HTML snippet per token: the token in red, flanked by up to
        # four neighbouring tokens on each side, used as the hover tooltip.
        contexts = []
        for sentence in sentences:
            strs = [x.text for x in sentence.tokens]
            for i, token in enumerate(strs):
                prop = '<b><font color="red"> {token} </font></b>'.format(token=token)
                prop = " ".join(strs[max(i - 4, 0) : i]) + prop
                prop = prop + " ".join(strs[i + 1 : min(len(strs), i + 5)])
                contexts.append("<p>" + prop + "</p>")

        return contexts

    @staticmethod
    def prepare_char_embeddings(embeddings, sentences):
        # One hidden state per character from the character language model.
        X = []
        for sentence in tqdm.tqdm(sentences):
            sentence = " ".join([x.text for x in sentence])
            hidden = embeddings.lm.get_representation([sentence])
            X.append(hidden.squeeze().detach().numpy())

        X = numpy.concatenate(X, 0)

        return X

    @staticmethod
    def char_contexts(sentences):
        # Build an HTML snippet per character: the character highlighted in
        # yellow, flanked by up to 30 characters of context on each side.
        contexts = []
        for sentence in sentences:
            sentence = " ".join([token.text for token in sentence])
            for i, char in enumerate(sentence):
                context = '<span style="background-color: yellow"><b>{}</b></span>'.format(
                    char
                )
                context = "".join(sentence[max(i - 30, 0) : i]) + context
                context = context + "".join(
                    sentence[i + 1 : min(len(sentence), i + 30)]
                )
                contexts.append(context)

        return contexts

    @staticmethod
    def visualize(X, contexts, file):
        # Render the 2-D points with matplotlib and attach the HTML contexts
        # as mpld3 hover tooltips, then save a standalone HTML file.
        import matplotlib.pyplot
        import mpld3

        fig, ax = matplotlib.pyplot.subplots()

        ax.grid(True, alpha=0.3)

        points = ax.plot(
            X[:, 0], X[:, 1], "o", color="b", mec="k", ms=5, mew=1, alpha=0.6
        )

        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_title("Hover mouse to reveal context", size=20)

        tooltip = mpld3.plugins.PointHTMLTooltip(
            points[0], contexts, voffset=10, hoffset=10
        )
        mpld3.plugins.connect(fig, tooltip)

        mpld3.save_html(fig, file)
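

# Illustrative usage sketch (not part of the original module). It assumes
# flair's WordEmbeddings and Sentence classes are available; the example
# sentences and the output file name "word_tsne.html" are placeholders.
if __name__ == "__main__":
    from flair.data import Sentence
    from flair.embeddings import WordEmbeddings

    sentences = [
        Sentence("The grass is green ."),
        Sentence("The sky is blue ."),
    ]
    glove = WordEmbeddings("glove")

    # Embeds every token, runs t-SNE and writes an interactive HTML plot.
    Visualizer().visualize_word_emeddings(glove, sentences, "word_tsne.html")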