import hashlib
from abc import abstractmethod
from pathlib import Path
from typing import List, Union
from collections import Counter
from functools import lru_cache
import torch
from bpemb import BPEmb
from transformers import XLNetTokenizer, T5Tokenizer, GPT2Tokenizer, AutoTokenizer, AutoConfig, AutoModel
import flair
import gensim
import os
import re
import logging
import numpy as np
from flair.data import Sentence, Token, Corpus, Dictionary
from flair.embeddings.base import Embeddings, ScalarMix
from flair.file_utils import cached_path, open_inside_zip
log = logging.getLogger("flair")
class TokenEmbeddings(Embeddings):
"""Abstract base class for all token-level embeddings. Ever new type of word embedding must implement these methods."""
@property
@abstractmethod
def embedding_length(self) -> int:
"""Returns the length of the embedding vector."""
pass
@property
def embedding_type(self) -> str:
return "word-level"
class StackedEmbeddings(TokenEmbeddings):
"""A stack of embeddings, used if you need to combine several different embedding types."""
def __init__(self, embeddings: List[TokenEmbeddings]):
"""The constructor takes a list of embeddings to be combined."""
super().__init__()
self.embeddings = embeddings
# IMPORTANT: add embeddings as torch modules
for i, embedding in enumerate(embeddings):
embedding.name = f"{str(i)}-{embedding.name}"
self.add_module(f"list_embedding_{str(i)}", embedding)
self.name: str = "Stack"
self.static_embeddings: bool = True
self.__embedding_type: str = embeddings[0].embedding_type
self.__embedding_length: int = 0
for embedding in embeddings:
self.__embedding_length += embedding.embedding_length
def embed(
self, sentences: Union[Sentence, List[Sentence]], static_embeddings: bool = True
):
# if only one sentence is passed, convert it to a list of sentences
if type(sentences) is Sentence:
sentences = [sentences]
for embedding in self.embeddings:
embedding.embed(sentences)
@property
def embedding_type(self) -> str:
return self.__embedding_type
@property
def embedding_length(self) -> int:
return self.__embedding_length
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
for embedding in self.embeddings:
embedding._add_embeddings_internal(sentences)
return sentences
def __str__(self):
return f'StackedEmbeddings [{",".join([str(e) for e in self.embeddings])}]'
def get_names(self) -> List[str]:
"""Returns a list of embedding names. In most cases, it is just a list with one item, namely the name of
this embedding. But in some cases, the embedding is made up by different embeddings (StackedEmbedding).
Then, the list contains the names of all embeddings in the stack."""
names = []
for embedding in self.embeddings:
names.extend(embedding.get_names())
return names
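# Usage sketch: combining static and contextual embeddings in one stack (assumes the
# 'glove' and 'news-forward' identifiers can be downloaded on first use):
#
#   from flair.data import Sentence
#   from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings
#
#   stacked = StackedEmbeddings([WordEmbeddings("glove"), FlairEmbeddings("news-forward")])
#   sentence = Sentence("The grass is green .")
#   stacked.embed(sentence)
#   for token in sentence:
#       print(token.text, token.get_embedding().size())  # concatenation of both embedding lengths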
class WordEmbeddings(TokenEmbeddings):
"""Standard static word embeddings, such as GloVe or FastText."""
def __init__(self, embeddings: str, field: str = None):
"""
Initializes classic word embeddings. The constructor downloads required files if they are not present.
:param embeddings: one of 'glove', 'extvec', 'crawl', 'news', a two-letter language code, or a custom path.
To use a custom embedding file, pass the path to that file as the embeddings parameter.
"""
self.embeddings = embeddings
old_base_path = (
"https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings/"
)
base_path = (
"https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.3/"
)
embeddings_path_v4 = (
"https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4/"
)
embeddings_path_v4_1 = "https://s3.eu-central-1.amazonaws.com/alan-nlp/resources/embeddings-v0.4.1/"
cache_dir = Path("embeddings")
# GLOVE embeddings
if embeddings.lower() == "glove" or embeddings.lower() == "en-glove":
cached_path(f"{old_base_path}glove.gensim.vectors.npy", cache_dir=cache_dir)
embeddings = cached_path(
f"{old_base_path}glove.gensim", cache_dir=cache_dir
)
# TURIAN embeddings
elif embeddings.lower() == "turian" or embeddings.lower() == "en-turian":
cached_path(
f"{embeddings_path_v4_1}turian.vectors.npy", cache_dir=cache_dir
)
embeddings = cached_path(
f"{embeddings_path_v4_1}turian", cache_dir=cache_dir
)
# KOMNINOS embeddings
elif embeddings.lower() == "extvec" or embeddings.lower() == "en-extvec":
cached_path(
f"{old_base_path}extvec.gensim.vectors.npy", cache_dir=cache_dir
)
embeddings = cached_path(
f"{old_base_path}extvec.gensim", cache_dir=cache_dir
)
# FT-CRAWL embeddings
elif embeddings.lower() == "crawl" or embeddings.lower() == "en-crawl":
cached_path(
f"{base_path}en-fasttext-crawl-300d-1M.vectors.npy", cache_dir=cache_dir
)
embeddings = cached_path(
f"{base_path}en-fasttext-crawl-300d-1M", cache_dir=cache_dir
)
# FT-CRAWL embeddings
elif (
embeddings.lower() == "news"
or embeddings.lower() == "en-news"
or embeddings.lower() == "en"
):
cached_path(
f"{base_path}en-fasttext-news-300d-1M.vectors.npy", cache_dir=cache_dir
)
embeddings = cached_path(
f"{base_path}en-fasttext-news-300d-1M", cache_dir=cache_dir
)
# twitter embeddings
elif embeddings.lower() == "twitter" or embeddings.lower() == "en-twitter":
cached_path(
f"{old_base_path}twitter.gensim.vectors.npy", cache_dir=cache_dir
)
embeddings = cached_path(
f"{old_base_path}twitter.gensim", cache_dir=cache_dir
)
# two-letter language code wiki embeddings
elif len(embeddings.lower()) == 2:
cached_path(
f"{embeddings_path_v4}{embeddings}-wiki-fasttext-300d-1M.vectors.npy",
cache_dir=cache_dir,
)
embeddings = cached_path(
f"{embeddings_path_v4}{embeddings}-wiki-fasttext-300d-1M",
cache_dir=cache_dir,
)
# two-letter language code wiki embeddings (with explicit '-wiki' suffix)
elif len(embeddings.lower()) == 7 and embeddings.endswith("-wiki"):
cached_path(
f"{embeddings_path_v4}{embeddings[:2]}-wiki-fasttext-300d-1M.vectors.npy",
cache_dir=cache_dir,
)
embeddings = cached_path(
f"{embeddings_path_v4}{embeddings[:2]}-wiki-fasttext-300d-1M",
cache_dir=cache_dir,
)
# two-letter language code crawl embeddings
elif len(embeddings.lower()) == 8 and embeddings.endswith("-crawl"):
cached_path(
f"{embeddings_path_v4}{embeddings[:2]}-crawl-fasttext-300d-1M.vectors.npy",
cache_dir=cache_dir,
)
embeddings = cached_path(
f"{embeddings_path_v4}{embeddings[:2]}-crawl-fasttext-300d-1M",
cache_dir=cache_dir,
)
elif not Path(embeddings).exists():
raise ValueError(
f'The given embeddings "{embeddings}" is not available or is not a valid path.'
)
self.name: str = str(embeddings)
self.static_embeddings = True
if str(embeddings).endswith(".bin"):
self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
str(embeddings), binary=True
)
else:
self.precomputed_word_embeddings = gensim.models.KeyedVectors.load(
str(embeddings)
)
self.field = field
self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
super().__init__()
@property
def embedding_length(self) -> int:
return self.__embedding_length
@lru_cache(maxsize=10000, typed=False)
def get_cached_vec(self, word: str) -> torch.Tensor:
if word in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word]
elif word.lower() in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[word.lower()]
elif re.sub(r"\d", "#", word.lower()) in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "#", word.lower())
]
elif re.sub(r"\d", "0", word.lower()) in self.precomputed_word_embeddings:
word_embedding = self.precomputed_word_embeddings[
re.sub(r"\d", "0", word.lower())
]
else:
word_embedding = np.zeros(self.embedding_length, dtype="float")
word_embedding = torch.tensor(
word_embedding.tolist(), device=flair.device, dtype=torch.float
)
return word_embedding
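# Lookup fallback illustration (comment only): for an out-of-vocabulary token such as
# "Release-2024", get_cached_vec tries "Release-2024", then "release-2024", then the
# digit-masked variants "release-####" and "release-0000", and finally falls back to a
# zero vector of length embedding_length.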
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
for i, sentence in enumerate(sentences):
for token, token_idx in zip(sentence.tokens, range(len(sentence.tokens))):
if "field" not in self.__dict__ or self.field is None:
word = token.text
else:
word = token.get_tag(self.field).value
word_embedding = self.get_cached_vec(word=word)
token.set_embedding(self.name, word_embedding)
return sentences
def __str__(self):
return self.name
def extra_repr(self):
# fix serialized models
if "embeddings" not in self.__dict__:
self.embeddings = self.name
return f"'{self.embeddings}'"
class CharacterEmbeddings(TokenEmbeddings):
"""Character embeddings of words, as proposed in Lample et al., 2016."""
def __init__(
self,
path_to_char_dict: str = None,
char_embedding_dim: int = 25,
hidden_size_char: int = 25,
):
"""Uses the default character dictionary if none provided."""
super().__init__()
self.name = "Char"
self.static_embeddings = False
# use list of common characters if none provided
if path_to_char_dict is None:
self.char_dictionary: Dictionary = Dictionary.load("common-chars")
else:
self.char_dictionary: Dictionary = Dictionary.load_from_file(
path_to_char_dict
)
self.char_embedding_dim: int = char_embedding_dim
self.hidden_size_char: int = hidden_size_char
self.char_embedding = torch.nn.Embedding(
len(self.char_dictionary.item2idx), self.char_embedding_dim
)
self.char_rnn = torch.nn.LSTM(
self.char_embedding_dim,
self.hidden_size_char,
num_layers=1,
bidirectional=True,
)
self.__embedding_length = self.hidden_size_char * 2
self.to(flair.device)
@property
def embedding_length(self) -> int:
return self.__embedding_length
def _add_embeddings_internal(self, sentences: List[Sentence]):
for sentence in sentences:
tokens_char_indices = []
# translate words in sentence into ints using dictionary
for token in sentence.tokens:
char_indices = [
self.char_dictionary.get_idx_for_item(char) for char in token.text
]
tokens_char_indices.append(char_indices)
# sort words by length, for batching and masking
tokens_sorted_by_length = sorted(
tokens_char_indices, key=lambda p: len(p), reverse=True
)
d = {}
for i, ci in enumerate(tokens_char_indices):
for j, cj in enumerate(tokens_sorted_by_length):
if ci == cj:
d[j] = i
continue
chars2_length = [len(c) for c in tokens_sorted_by_length]
longest_token_in_sentence = max(chars2_length)
tokens_mask = torch.zeros(
(len(tokens_sorted_by_length), longest_token_in_sentence),
dtype=torch.long,
device=flair.device,
)
for i, c in enumerate(tokens_sorted_by_length):
tokens_mask[i, : chars2_length[i]] = torch.tensor(
c, dtype=torch.long, device=flair.device
)
# chars for rnn processing
chars = tokens_mask
character_embeddings = self.char_embedding(chars).transpose(0, 1)
packed = torch.nn.utils.rnn.pack_padded_sequence(
character_embeddings, chars2_length
)
lstm_out, self.hidden = self.char_rnn(packed)
outputs, output_lengths = torch.nn.utils.rnn.pad_packed_sequence(lstm_out)
outputs = outputs.transpose(0, 1)
chars_embeds_temp = torch.zeros(
(outputs.size(0), outputs.size(2)),
dtype=torch.float,
device=flair.device,
)
for i, index in enumerate(output_lengths):
chars_embeds_temp[i] = outputs[i, index - 1]
character_embeddings = chars_embeds_temp.clone()
for i in range(character_embeddings.size(0)):
character_embeddings[d[i]] = chars_embeds_temp[i]
for token_number, token in enumerate(sentence.tokens):
token.set_embedding(self.name, character_embeddings[token_number])
def __str__(self):
return self.name
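# Usage sketch: character embeddings are not pre-trained (static_embeddings is False), so they
# are normally trained jointly with a downstream model, typically inside a StackedEmbeddings
# together with word-level embeddings:
#
#   embeddings = StackedEmbeddings([WordEmbeddings("glove"), CharacterEmbeddings()])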
class FlairEmbeddings(TokenEmbeddings):
"""Contextual string embeddings of words, as proposed in Akbik et al., 2018."""
def __init__(self,
model,
fine_tune: bool = False,
chars_per_chunk: int = 512,
with_whitespace: bool = True,
tokenized_lm: bool = True,
):
"""
Initializes contextual string embeddings using a character-level language model.
:param model: model string, one of 'news-forward', 'news-backward', 'news-forward-fast', 'news-backward-fast',
'mix-forward', 'mix-backward', 'german-forward', 'german-backward', 'polish-backward', 'polish-forward',
etc (see https://github.com/flairNLP/flair/blob/master/resources/docs/embeddings/FLAIR_EMBEDDINGS.md)
depending on which character language model is desired.
:param fine_tune: if set to True, the gradient will propagate into the language model. This dramatically slows
down training and often leads to overfitting, so use with caution.
:param chars_per_chunk: max number of chars per rnn pass to control speed/memory tradeoff. Higher means faster
but requires more memory. Lower means slower but less memory.
:param with_whitespace: If True, use hidden state after whitespace after word. If False, use hidden
state at last character of word.
:param tokenized_lm: Whether this lm is tokenized. Default is True, but for LMs trained over unprocessed text
False might be better.
"""
super().__init__()
cache_dir = Path("embeddings")
aws_path: str = "https://s3.eu-central-1.amazonaws.com/alan-nlp/resources"
hu_path: str = "https://flair.informatik.hu-berlin.de/resources"
clef_hipe_path: str = "https://files.ifi.uzh.ch/cl/siclemat/impresso/clef-hipe-2020/flair"
self.PRETRAINED_MODEL_ARCHIVE_MAP = {
# multilingual models
"multi-forward": f"{aws_path}/embeddings-v0.4.3/lm-jw300-forward-v0.1.pt",
"multi-backward": f"{aws_path}/embeddings-v0.4.3/lm-jw300-backward-v0.1.pt",
"multi-v0-forward": f"{aws_path}/embeddings-v0.4/lm-multi-forward-v0.1.pt",
"multi-v0-backward": f"{aws_path}/embeddings-v0.4/lm-multi-backward-v0.1.pt",
"multi-v0-forward-fast": f"{aws_path}/embeddings-v0.4/lm-multi-forward-fast-v0.1.pt",
"multi-v0-backward-fast": f"{aws_path}/embeddings-v0.4/lm-multi-backward-fast-v0.1.pt",
# English models
"en-forward": f"{aws_path}/embeddings-v0.4.1/big-news-forward--h2048-l1-d0.05-lr30-0.25-20/news-forward-0.4.1.pt",
"en-backward": f"{aws_path}/embeddings-v0.4.1/big-news-backward--h2048-l1-d0.05-lr30-0.25-20/news-backward-0.4.1.pt",
"en-forward-fast": f"{aws_path}/embeddings/lm-news-english-forward-1024-v0.2rc.pt",
"en-backward-fast": f"{aws_path}/embeddings/lm-news-english-backward-1024-v0.2rc.pt",
"news-forward": f"{aws_path}/embeddings-v0.4.1/big-news-forward--h2048-l1-d0.05-lr30-0.25-20/news-forward-0.4.1.pt",
"news-backward": f"{aws_path}/embeddings-v0.4.1/big-news-backward--h2048-l1-d0.05-lr30-0.25-20/news-backward-0.4.1.pt",
"news-forward-fast": f"{aws_path}/embeddings/lm-news-english-forward-1024-v0.2rc.pt",
"news-backward-fast": f"{aws_path}/embeddings/lm-news-english-backward-1024-v0.2rc.pt",
"mix-forward": f"{aws_path}/embeddings/lm-mix-english-forward-v0.2rc.pt",
"mix-backward": f"{aws_path}/embeddings/lm-mix-english-backward-v0.2rc.pt",
# Arabic
"ar-forward": f"{aws_path}/embeddings-stefan-it/lm-ar-opus-large-forward-v0.1.pt",
"ar-backward": f"{aws_path}/embeddings-stefan-it/lm-ar-opus-large-backward-v0.1.pt",
# Bulgarian
"bg-forward-fast": f"{aws_path}/embeddings-v0.3/lm-bg-small-forward-v0.1.pt",
"bg-backward-fast": f"{aws_path}/embeddings-v0.3/lm-bg-small-backward-v0.1.pt",
"bg-forward": f"{aws_path}/embeddings-stefan-it/lm-bg-opus-large-forward-v0.1.pt",
"bg-backward": f"{aws_path}/embeddings-stefan-it/lm-bg-opus-large-backward-v0.1.pt",
# Czech
"cs-forward": f"{aws_path}/embeddings-stefan-it/lm-cs-opus-large-forward-v0.1.pt",
"cs-backward": f"{aws_path}/embeddings-stefan-it/lm-cs-opus-large-backward-v0.1.pt",
"cs-v0-forward": f"{aws_path}/embeddings-v0.4/lm-cs-large-forward-v0.1.pt",
"cs-v0-backward": f"{aws_path}/embeddings-v0.4/lm-cs-large-backward-v0.1.pt",
# Danish
"da-forward": f"{aws_path}/embeddings-stefan-it/lm-da-opus-large-forward-v0.1.pt",
"da-backward": f"{aws_path}/embeddings-stefan-it/lm-da-opus-large-backward-v0.1.pt",
# German
"de-forward": f"{aws_path}/embeddings/lm-mix-german-forward-v0.2rc.pt",
"de-backward": f"{aws_path}/embeddings/lm-mix-german-backward-v0.2rc.pt",
"de-historic-ha-forward": f"{aws_path}/embeddings-stefan-it/lm-historic-hamburger-anzeiger-forward-v0.1.pt",
"de-historic-ha-backward": f"{aws_path}/embeddings-stefan-it/lm-historic-hamburger-anzeiger-backward-v0.1.pt",
"de-historic-wz-forward": f"{aws_path}/embeddings-stefan-it/lm-historic-wiener-zeitung-forward-v0.1.pt",
"de-historic-wz-backward": f"{aws_path}/embeddings-stefan-it/lm-historic-wiener-zeitung-backward-v0.1.pt",
"de-historic-rw-forward": f"{hu_path}/embeddings/redewiedergabe_lm_forward.pt",
"de-historic-rw-backward": f"{hu_path}/embeddings/redewiedergabe_lm_backward.pt",
# Spanish
"es-forward": f"{aws_path}/embeddings-v0.4/language_model_es_forward_long/lm-es-forward.pt",
"es-backward": f"{aws_path}/embeddings-v0.4/language_model_es_backward_long/lm-es-backward.pt",
"es-forward-fast": f"{aws_path}/embeddings-v0.4/language_model_es_forward/lm-es-forward-fast.pt",
"es-backward-fast": f"{aws_path}/embeddings-v0.4/language_model_es_backward/lm-es-backward-fast.pt",
# Basque
"eu-forward": f"{aws_path}/embeddings-stefan-it/lm-eu-opus-large-forward-v0.2.pt",
"eu-backward": f"{aws_path}/embeddings-stefan-it/lm-eu-opus-large-backward-v0.2.pt",
"eu-v1-forward": f"{aws_path}/embeddings-stefan-it/lm-eu-opus-large-forward-v0.1.pt",
"eu-v1-backward": f"{aws_path}/embeddings-stefan-it/lm-eu-opus-large-backward-v0.1.pt",
"eu-v0-forward": f"{aws_path}/embeddings-v0.4/lm-eu-large-forward-v0.1.pt",
"eu-v0-backward": f"{aws_path}/embeddings-v0.4/lm-eu-large-backward-v0.1.pt",
# Persian
"fa-forward": f"{aws_path}/embeddings-stefan-it/lm-fa-opus-large-forward-v0.1.pt",
"fa-backward": f"{aws_path}/embeddings-stefan-it/lm-fa-opus-large-backward-v0.1.pt",
# Finnish
"fi-forward": f"{aws_path}/embeddings-stefan-it/lm-fi-opus-large-forward-v0.1.pt",
"fi-backward": f"{aws_path}/embeddings-stefan-it/lm-fi-opus-large-backward-v0.1.pt",
# French
"fr-forward": f"{aws_path}/embeddings/lm-fr-charlm-forward.pt",
"fr-backward": f"{aws_path}/embeddings/lm-fr-charlm-backward.pt",
# Hebrew
"he-forward": f"{aws_path}/embeddings-stefan-it/lm-he-opus-large-forward-v0.1.pt",
"he-backward": f"{aws_path}/embeddings-stefan-it/lm-he-opus-large-backward-v0.1.pt",
# Hindi
"hi-forward": f"{aws_path}/embeddings-stefan-it/lm-hi-opus-large-forward-v0.1.pt",
"hi-backward": f"{aws_path}/embeddings-stefan-it/lm-hi-opus-large-backward-v0.1.pt",
# Croatian
"hr-forward": f"{aws_path}/embeddings-stefan-it/lm-hr-opus-large-forward-v0.1.pt",
"hr-backward": f"{aws_path}/embeddings-stefan-it/lm-hr-opus-large-backward-v0.1.pt",
# Indonesian
"id-forward": f"{aws_path}/embeddings-stefan-it/lm-id-opus-large-forward-v0.1.pt",
"id-backward": f"{aws_path}/embeddings-stefan-it/lm-id-opus-large-backward-v0.1.pt",
# Italian
"it-forward": f"{aws_path}/embeddings-stefan-it/lm-it-opus-large-forward-v0.1.pt",
"it-backward": f"{aws_path}/embeddings-stefan-it/lm-it-opus-large-backward-v0.1.pt",
# Japanese
"ja-forward": f"{aws_path}/embeddings-v0.4.1/lm__char-forward__ja-wikipedia-3GB/japanese-forward.pt",
"ja-backward": f"{aws_path}/embeddings-v0.4.1/lm__char-backward__ja-wikipedia-3GB/japanese-backward.pt",
# Malayalam
"ml-forward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-forward.pt",
"ml-backward": f"https://raw.githubusercontent.com/qburst/models-repository/master/FlairMalayalamModels/ml-backward.pt",
# Dutch
"nl-forward": f"{aws_path}/embeddings-stefan-it/lm-nl-opus-large-forward-v0.1.pt",
"nl-backward": f"{aws_path}/embeddings-stefan-it/lm-nl-opus-large-backward-v0.1.pt",
"nl-v0-forward": f"{aws_path}/embeddings-v0.4/lm-nl-large-forward-v0.1.pt",
"nl-v0-backward": f"{aws_path}/embeddings-v0.4/lm-nl-large-backward-v0.1.pt",
# Norwegian
"no-forward": f"{aws_path}/embeddings-stefan-it/lm-no-opus-large-forward-v0.1.pt",
"no-backward": f"{aws_path}/embeddings-stefan-it/lm-no-opus-large-backward-v0.1.pt",
# Polish
"pl-forward": f"{aws_path}/embeddings/lm-polish-forward-v0.2.pt",
"pl-backward": f"{aws_path}/embeddings/lm-polish-backward-v0.2.pt",
"pl-opus-forward": f"{aws_path}/embeddings-stefan-it/lm-pl-opus-large-forward-v0.1.pt",
"pl-opus-backward": f"{aws_path}/embeddings-stefan-it/lm-pl-opus-large-backward-v0.1.pt",
# Portuguese
"pt-forward": f"{aws_path}/embeddings-v0.4/lm-pt-forward.pt",
"pt-backward": f"{aws_path}/embeddings-v0.4/lm-pt-backward.pt",
# Pubmed
"pubmed-forward": f"{aws_path}/embeddings-v0.4.1/pubmed-2015-fw-lm.pt",
"pubmed-backward": f"{aws_path}/embeddings-v0.4.1/pubmed-2015-bw-lm.pt",
# Slovenian
"sl-forward": f"{aws_path}/embeddings-stefan-it/lm-sl-opus-large-forward-v0.1.pt",
"sl-backward": f"{aws_path}/embeddings-stefan-it/lm-sl-opus-large-backward-v0.1.pt",
"sl-v0-forward": f"{aws_path}/embeddings-v0.3/lm-sl-large-forward-v0.1.pt",
"sl-v0-backward": f"{aws_path}/embeddings-v0.3/lm-sl-large-backward-v0.1.pt",
# Swedish
"sv-forward": f"{aws_path}/embeddings-stefan-it/lm-sv-opus-large-forward-v0.1.pt",
"sv-backward": f"{aws_path}/embeddings-stefan-it/lm-sv-opus-large-backward-v0.1.pt",
"sv-v0-forward": f"{aws_path}/embeddings-v0.4/lm-sv-large-forward-v0.1.pt",
"sv-v0-backward": f"{aws_path}/embeddings-v0.4/lm-sv-large-backward-v0.1.pt",
# Tamil
"ta-forward": f"{aws_path}/embeddings-stefan-it/lm-ta-opus-large-forward-v0.1.pt",
"ta-backward": f"{aws_path}/embeddings-stefan-it/lm-ta-opus-large-backward-v0.1.pt",
# CLEF HIPE Shared task
"de-impresso-hipe-v1-forward": f"{clef_hipe_path}/de-hipe-flair-v1-forward/best-lm.pt",
"de-impresso-hipe-v1-backward": f"{clef_hipe_path}/de-hipe-flair-v1-backward/best-lm.pt",
"en-impresso-hipe-v1-forward": f"{clef_hipe_path}/en-flair-v1-forward/best-lm.pt",
"en-impresso-hipe-v1-backward": f"{clef_hipe_path}/en-flair-v1-backward/best-lm.pt",
"fr-impresso-hipe-v1-forward": f"{clef_hipe_path}/fr-hipe-flair-v1-forward/best-lm.pt",
"fr-impresso-hipe-v1-backward": f"{clef_hipe_path}/fr-hipe-flair-v1-backward/best-lm.pt",
}
if type(model) == str:
# load model if in pretrained model map
if model.lower() in self.PRETRAINED_MODEL_ARCHIVE_MAP:
base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[model.lower()]
# Fix for CLEF HIPE models (avoid overwriting best-lm.pt in cache_dir)
if "impresso-hipe" in model.lower():
cache_dir = cache_dir / model.lower()
model = cached_path(base_path, cache_dir=cache_dir)
elif replace_with_language_code(model) in self.PRETRAINED_MODEL_ARCHIVE_MAP:
base_path = self.PRETRAINED_MODEL_ARCHIVE_MAP[
replace_with_language_code(model)
]
model = cached_path(base_path, cache_dir=cache_dir)
elif not Path(model).exists():
raise ValueError(
f'The given model "{model}" is not available or is not a valid path.'
)
from flair.models import LanguageModel
if type(model) == LanguageModel:
self.lm: LanguageModel = model
self.name = f"Task-LSTM-{self.lm.hidden_size}-{self.lm.nlayers}-{self.lm.is_forward_lm}"
else:
self.lm: LanguageModel = LanguageModel.load_language_model(model)
self.name = str(model)
# embeddings are static if we don't do finetuning
self.fine_tune = fine_tune
self.static_embeddings = not fine_tune
self.is_forward_lm: bool = self.lm.is_forward_lm
self.with_whitespace: bool = with_whitespace
self.tokenized_lm: bool = tokenized_lm
self.chars_per_chunk: int = chars_per_chunk
# embed a dummy sentence to determine embedding_length
dummy_sentence: Sentence = Sentence()
dummy_sentence.add_token(Token("hello"))
embedded_dummy = self.embed(dummy_sentence)
self.__embedding_length: int = len(
embedded_dummy[0].get_token(1).get_embedding()
)
# set to eval mode
self.eval()
def train(self, mode=True):
# make compatible with serialized models (TODO: remove)
if "fine_tune" not in self.__dict__:
self.fine_tune = False
if "chars_per_chunk" not in self.__dict__:
self.chars_per_chunk = 512
if not self.fine_tune:
pass
else:
super(FlairEmbeddings, self).train(mode)
@property
def embedding_length(self) -> int:
return self.__embedding_length
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
# make compatible with serialized models (TODO: remove)
if "with_whitespace" not in self.__dict__:
self.with_whitespace = True
if "tokenized_lm" not in self.__dict__:
self.tokenized_lm = True
# gradients are enabled only if fine-tuning is enabled
gradient_context = torch.enable_grad() if self.fine_tune else torch.no_grad()
with gradient_context:
# use the language model to compute embeddings; first, get the text of each sentence
text_sentences = [sentence.to_tokenized_string() for sentence in sentences] if self.tokenized_lm \
else [sentence.to_plain_string() for sentence in sentences]
start_marker = self.lm.document_delimiter if "document_delimiter" in self.lm.__dict__ else '\n'
end_marker = " "
# get hidden states from language model
all_hidden_states_in_lm = self.lm.get_representation(
text_sentences, start_marker, end_marker, self.chars_per_chunk
)
if not self.fine_tune:
all_hidden_states_in_lm = all_hidden_states_in_lm.detach()
# take first or last hidden states from language model as word representation
for i, sentence in enumerate(sentences):
sentence_text = sentence.to_tokenized_string() if self.tokenized_lm else sentence.to_plain_string()
offset_forward: int = len(start_marker)
offset_backward: int = len(sentence_text) + len(start_marker)
for token in sentence.tokens:
offset_forward += len(token.text)
if self.is_forward_lm:
offset_with_whitespace = offset_forward
offset_without_whitespace = offset_forward - 1
else:
offset_with_whitespace = offset_backward
offset_without_whitespace = offset_backward - 1
# offset mode that extracts at whitespace after last character
if self.with_whitespace:
embedding = all_hidden_states_in_lm[offset_with_whitespace, i, :]
# offset mode that extracts at last character
else:
embedding = all_hidden_states_in_lm[offset_without_whitespace, i, :]
if self.tokenized_lm or token.whitespace_after:
offset_forward += 1
offset_backward -= 1
offset_backward -= len(token.text)
# only clone if optimization mode is 'gpu'
if flair.embedding_storage_mode == "gpu":
embedding = embedding.clone()
token.set_embedding(self.name, embedding)
del all_hidden_states_in_lm
return sentences
def __str__(self):
return self.name
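# Usage sketch (assumes the 'news-forward' and 'news-backward' character LMs can be downloaded;
# forward and backward models are commonly stacked for sequence labeling):
#
#   from flair.data import Sentence
#   flair_fw = FlairEmbeddings("news-forward")
#   flair_bw = FlairEmbeddings("news-backward")
#   sentence = Sentence("I love Berlin .")
#   StackedEmbeddings([flair_fw, flair_bw]).embed(sentence)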
class PooledFlairEmbeddings(TokenEmbeddings):
def __init__(
self,
contextual_embeddings: Union[str, FlairEmbeddings],
pooling: str = "min",
only_capitalized: bool = False,
**kwargs,
):
super().__init__()
# use the character language model embeddings as basis
if type(contextual_embeddings) is str:
self.context_embeddings: FlairEmbeddings = FlairEmbeddings(
contextual_embeddings, **kwargs
)
else:
self.context_embeddings: FlairEmbeddings = contextual_embeddings
# length is twice the original character LM embedding length
self.embedding_length = self.context_embeddings.embedding_length * 2
self.name = self.context_embeddings.name + "-context"
# these fields are for the embedding memory
self.word_embeddings = {}
self.word_count = {}
# whether to add only capitalized words to memory (faster runtime and lower memory consumption)
self.only_capitalized = only_capitalized
# we re-compute embeddings dynamically at each epoch
self.static_embeddings = False
# set the memory method
self.pooling = pooling
def train(self, mode=True):
super().train(mode=mode)
if mode:
# memory is wiped each time we do a training run
print("train mode resetting embeddings")
self.word_embeddings = {}
self.word_count = {}
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
self.context_embeddings.embed(sentences)
# if we keep a pooling, it needs to be updated continuously
for sentence in sentences:
for token in sentence.tokens:
# update embedding
local_embedding = token._embeddings[self.context_embeddings.name].cpu()
# only consider tokens with non-empty text
if token.text:
if token.text[0].isupper() or not self.only_capitalized:
if token.text not in self.word_embeddings:
self.word_embeddings[token.text] = local_embedding
self.word_count[token.text] = 1
else:
# set aggregation operation
if self.pooling == "mean":
aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding)
elif self.pooling == "fade":
aggregated_embedding = torch.add(self.word_embeddings[token.text], local_embedding)
aggregated_embedding /= 2
elif self.pooling == "max":
aggregated_embedding = torch.max(self.word_embeddings[token.text], local_embedding)
elif self.pooling == "min":
aggregated_embedding = torch.min(self.word_embeddings[token.text], local_embedding)
self.word_embeddings[token.text] = aggregated_embedding
self.word_count[token.text] += 1
# add embeddings after updating
for sentence in sentences:
for token in sentence.tokens:
if token.text in self.word_embeddings:
base = (
self.word_embeddings[token.text] / self.word_count[token.text]
if self.pooling == "mean"
else self.word_embeddings[token.text]
)
else:
base = token._embeddings[self.context_embeddings.name]
token.set_embedding(self.name, base)
return sentences
def embedding_length(self) -> int:
return self.embedding_length
def get_names(self) -> List[str]:
return [self.name, self.context_embeddings.name]
def __setstate__(self, d):
self.__dict__ = d
if flair.device != 'cpu':
for key in self.word_embeddings:
self.word_embeddings[key] = self.word_embeddings[key].cpu()
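# Usage sketch: the pooled variant accumulates a memory of all embeddings seen for each word form
# and combines the pooled vector with the current contextual one (hence twice the base embedding
# length); the memory is reset whenever train(True) is called:
#
#   pooled = PooledFlairEmbeddings("news-forward", pooling="min")
#   pooled.embed(Sentence("I love Berlin ."))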
class TransformerWordEmbeddings(TokenEmbeddings):
def __init__(
self,
model: str = "bert-base-uncased",
layers: str = "-1,-2,-3,-4",
pooling_operation: str = "first",
batch_size: int = 1,
use_scalar_mix: bool = False,
fine_tune: bool = False,
allow_long_sentences: bool = False,
**kwargs
):
"""
Bidirectional transformer embeddings of words from various transformer architectures.
:param model: name of transformer model (see https://huggingface.co/transformers/pretrained_models.html for
options)
:param layers: string indicating which layers to take for embedding (-1 is topmost layer)
:param pooling_operation: how to get from token piece embeddings to token embedding. Either take the first
subtoken ('first'), the last subtoken ('last'), both first and last ('first_last') or a mean over all ('mean')
:param batch_size: How many sentences to push through the transformer at once. Set to 1 by default since transformer
models tend to be huge.
:param use_scalar_mix: If True, uses a scalar mix of layers as embedding
:param fine_tune: If True, allows transformers to be fine-tuned during training
"""
super().__init__()
# load tokenizer and transformer model
self.tokenizer = AutoTokenizer.from_pretrained(model, **kwargs)
config = AutoConfig.from_pretrained(model, output_hidden_states=True, **kwargs)
self.model = AutoModel.from_pretrained(model, config=config, **kwargs)
self.allow_long_sentences = allow_long_sentences
if allow_long_sentences:
self.max_subtokens_sequence_length = self.tokenizer.model_max_length
self.stride = self.tokenizer.model_max_length//2
else:
self.max_subtokens_sequence_length = None
self.stride = 0
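# Note: with allow_long_sentences=True and a model_max_length of e.g. 512, sentences whose
# subtoken sequence exceeds the limit are split into overlapping windows of up to 512 subtokens
# with an overlap (stride) of 256, and the overflowing parts are embedded in additional passes.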
# model name
self.name = 'transformer-word-' + str(model)
# when initializing, embeddings are in eval mode by default
self.model.eval()
self.model.to(flair.device)
# embedding parameters
if layers == 'all':
# send mini-token through to check how many layers the model has
hidden_states = self.model(torch.tensor([1], device=flair.device).unsqueeze(0))[-1]
self.layer_indexes = [int(x) for x in range(len(hidden_states))]
else:
self.layer_indexes = [int(x) for x in layers.split(",")]
# self.mix = ScalarMix(mixture_size=len(self.layer_indexes), trainable=False)
self.pooling_operation = pooling_operation
self.use_scalar_mix = use_scalar_mix
self.fine_tune = fine_tune
self.static_embeddings = not self.fine_tune
self.batch_size = batch_size
self.special_tokens = []
# check if special tokens exist to circumvent error message
if self.tokenizer._bos_token:
self.special_tokens.append(self.tokenizer.bos_token)
if self.tokenizer._cls_token:
self.special_tokens.append(self.tokenizer.cls_token)
# most models have an initial BOS token, except for XLNet, T5 and GPT2
self.begin_offset = 1
if type(self.tokenizer) == XLNetTokenizer:
self.begin_offset = 0
if type(self.tokenizer) == T5Tokenizer:
self.begin_offset = 0
if type(self.tokenizer) == GPT2Tokenizer:
self.begin_offset = 0
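# Usage sketch (assumes a Hugging Face model id such as 'bert-base-uncased' is available):
#
#   from flair.data import Sentence
#   bert = TransformerWordEmbeddings("bert-base-uncased", layers="-1,-2,-3,-4",
#                                    pooling_operation="first", fine_tune=False)
#   sentence = Sentence("I love Berlin .")
#   bert.embed(sentence)
#   # without scalar mix, each token holds the concatenation of the selected hidden layers
#   print(sentence.tokens[0].get_embedding().size())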
def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
"""Add embeddings to all words in a list of sentences."""
# split into micro batches of size self.batch_size before pushing through transformer
sentence_batches = [sentences[i * self.batch_size:(i + 1) * self.batch_size]
for i in range((len(sentences) + self.batch_size - 1) // self.batch_size)]
# embed each micro-batch
for batch in sentence_batches:
self._add_embeddings_to_sentences(batch)
return sentences
@staticmethod
def _remove_special_markup(text: str):
# remove special markup
text = re.sub('^Ġ', '', text) # RoBERTa models
text = re.sub('^##', '', text) # BERT models
text = re.sub('^▁', '', text) # XLNet models
text = re.sub('</w>$', '', text) # XLM models
return text
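# For example (comment only): '##ing' -> 'ing' for BERT WordPiece, 'Ġhello' -> 'hello' for
# RoBERTa byte-level BPE, and '▁hello' -> 'hello' for XLNet/SentencePiece subtokens.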
def _get_processed_token_text(self, token: Token) -> str:
pieces = self.tokenizer.tokenize(token.text)
token_text = ''
for piece in pieces:
token_text += self._remove_special_markup(piece)
token_text = token_text.lower()
return token_text
def _add_embeddings_to_sentences(self, sentences: List[Sentence]):
"""Match subtokenization to Flair tokenization and extract embeddings from transformers for each token."""
# first, subtokenize each sentence and find out into how many subtokens each token was divided
subtokenized_sentences = []
subtokenized_sentences_token_lengths = []
sentence_parts_lengths = []
# TODO: keep for backwards compatibility, but remove in future
# some pretrained models do not have this property, applying default settings now.
# can be set manually after loading the model.
if not hasattr(self, 'max_subtokens_sequence_length'):
self.max_subtokens_sequence_length = None
self.allow_long_sentences = False
self.stride = 0
non_empty_sentences = []
empty_sentences = []
for sentence in sentences:
tokenized_string = sentence.to_tokenized_string()
# method 1: subtokenize sentence
# subtokenized_sentence = self.tokenizer.encode(tokenized_string, add_special_tokens=True)
# method 2:
# transformer specific tokenization
subtokenized_sentence = self.tokenizer.tokenize(tokenized_string)
if len(subtokenized_sentence) == 0:
empty_sentences.append(sentence)
continue
else:
non_empty_sentences.append(sentence)
token_subtoken_lengths = self.reconstruct_tokens_from_subtokens(sentence, subtokenized_sentence)
subtokenized_sentences_token_lengths.append(token_subtoken_lengths)
subtoken_ids_sentence = self.tokenizer.convert_tokens_to_ids(subtokenized_sentence)
nr_sentence_parts = 0
while subtoken_ids_sentence:
nr_sentence_parts += 1
encoded_inputs = self.tokenizer.prepare_for_model(subtoken_ids_sentence,
max_length=self.max_subtokens_sequence_length,
stride=self.stride,
return_overflowing_tokens=self.allow_long_sentences)
subtoken_ids_split_sentence = encoded_inputs['input_ids']
subtokenized_sentences.append(torch.tensor(subtoken_ids_split_sentence, dtype=torch.long))
if 'overflowing_tokens' in encoded_inputs:
subtoken_ids_sentence = encoded_inputs['overflowing_tokens']
else:
subtoken_ids_sentence = None
sentence_parts_lengths.append(nr_sentence_parts)
# empty sentences get zero embeddings
for sentence in empty_sentences:
for token in sentence:
token.set_embedding(self.name, torch.zeros(self.embedding_length))
# only embed non-empty sentences, and only if at least one remains
sentences = non_empty_sentences
if len(sentences) == 0: return
# find longest sentence in batch
longest_sequence_in_batch: int = len(max(subtokenized_sentences, key=len))
total_sentence_parts = sum(sentence_parts_lengths)
# initialize batch tensors and mask
input_ids = torch.zeros(
[total_sentence_parts, longest_sequence_in_batch],
dtype=torch.long,
device=flair.device,
)
mask = torch.zeros(
[total_sentence_parts, longest_sequence_in_batch],
dtype=torch.long,
device=flair.device,
)
for s_id, sentence in enumerate(subtokenized_sentences):
sequence_length = len(sentence)
input_ids[s_id][:sequence_length] = sentence
mask[s_id][:sequence_length] = torch.ones(sequence_length)