# test_embeddings.py (forked from flairNLP/flair)
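"""Tests for flair embedding classes: loading errors, stacked token embeddings,
and document-level RNN and pooling embeddings.
"""
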
import pytest

from flair.data import Sentence
from flair.embeddings import (
    DocumentPoolEmbeddings,
    DocumentRNNEmbeddings,
    FlairEmbeddings,
    StackedEmbeddings,
    TokenEmbeddings,
    WordEmbeddings,
)


def test_loading_not_existing_embedding():
    # Unknown identifiers and invalid paths should both raise a ValueError.
    with pytest.raises(ValueError):
        WordEmbeddings('other')

    with pytest.raises(ValueError):
        WordEmbeddings('not/existing/path/to/embeddings')


def test_loading_not_existing_char_lm_embedding():
    # The same check for character-level language model embeddings.
    with pytest.raises(ValueError):
        FlairEmbeddings('other')


@pytest.mark.integration
def test_stacked_embeddings():
    sentence, glove, charlm = init_document_embeddings()

    embeddings: StackedEmbeddings = StackedEmbeddings([glove, charlm])
    embeddings.embed(sentence)

    for token in sentence.tokens:
        # Stacking concatenates both embeddings: 50 (turian) + 1024 (news-forward-fast) = 1074.
        assert len(token.get_embedding()) == 1074

        token.clear_embeddings()
        assert len(token.get_embedding()) == 0


@pytest.mark.integration
def test_document_lstm_embeddings():
    sentence, glove, charlm = init_document_embeddings()

    embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove, charlm], hidden_size=128, bidirectional=False
    )
    embeddings.embed(sentence)

    # A unidirectional RNN yields a document embedding of size hidden_size.
    assert len(sentence.get_embedding()) == 128
    assert len(sentence.get_embedding()) == embeddings.embedding_length

    sentence.clear_embeddings()
    assert len(sentence.get_embedding()) == 0


@pytest.mark.integration
def test_document_bidirectional_lstm_embeddings():
    sentence, glove, charlm = init_document_embeddings()

    embeddings: DocumentRNNEmbeddings = DocumentRNNEmbeddings(
        [glove, charlm], hidden_size=128, bidirectional=True
    )
    embeddings.embed(sentence)

    # In the bidirectional case the embedding length is 4 * hidden_size (128 * 4 = 512).
    assert len(sentence.get_embedding()) == 512
    assert len(sentence.get_embedding()) == embeddings.embedding_length

    sentence.clear_embeddings()
    assert len(sentence.get_embedding()) == 0


@pytest.mark.integration
def test_document_pool_embeddings():
    sentence, glove, charlm = init_document_embeddings()

    for mode in ['mean', 'max', 'min']:
        embeddings: DocumentPoolEmbeddings = DocumentPoolEmbeddings([glove, charlm], mode=mode)
        embeddings.embed(sentence)

        # Pooling keeps the stacked token dimensionality: 50 + 1024 = 1074.
        assert len(sentence.get_embedding()) == 1074

        sentence.clear_embeddings()
        assert len(sentence.get_embedding()) == 0


def init_document_embeddings():
    text = 'I love Berlin. Berlin is a great place to live.'
    sentence: Sentence = Sentence(text)

    # Note: despite the variable name, these are the small 'turian' word
    # embeddings, which keep the integration tests lightweight.
    glove: TokenEmbeddings = WordEmbeddings('turian')
    charlm: TokenEmbeddings = FlairEmbeddings('news-forward-fast')

    return sentence, glove, charlm


def load_and_apply_word_embeddings(emb_type: str):
    # Helper: load classic word embeddings by identifier and check that every
    # token receives a non-empty embedding that can be cleared again.
    text = 'I love Berlin.'
    sentence: Sentence = Sentence(text)
    embeddings: TokenEmbeddings = WordEmbeddings(emb_type)
    embeddings.embed(sentence)

    for token in sentence.tokens:
        assert len(token.get_embedding()) != 0

        token.clear_embeddings()
        assert len(token.get_embedding()) == 0


def load_and_apply_char_lm_embeddings(emb_type: str):
    # Helper: the same check for character-level language model (Flair) embeddings.
    text = 'I love Berlin.'
    sentence: Sentence = Sentence(text)
    embeddings: TokenEmbeddings = FlairEmbeddings(emb_type)
    embeddings.embed(sentence)

    for token in sentence.tokens:
        assert len(token.get_embedding()) != 0

        token.clear_embeddings()
        assert len(token.get_embedding()) == 0


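# The two helpers above are not exercised anywhere in this file. As an
# illustration only (not part of the original test suite), they could back
# integration tests like the sketch below. 'glove' is assumed to be a valid
# WordEmbeddings identifier; 'news-forward-fast' is already used above.
@pytest.mark.integration
def test_glove_word_embeddings():
    load_and_apply_word_embeddings('glove')


@pytest.mark.integration
def test_news_forward_fast_char_lm_embeddings():
    load_and_apply_char_lm_embeddings('news-forward-fast')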