In [1]:
from pathlib import Path
from typing import List

import flair
import torch
from flair.data import Corpus
from flair.datasets import ColumnCorpus
from flair.embeddings import TokenEmbeddings, WordEmbeddings, StackedEmbeddings, FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

In [2]:
# --- Corpus -------------------------------------------------------------------
# Column layout handed to the reader: column 0 -> 'text', 1 -> 'pos', 2 -> 'upos'.
# NOTE(review): the test_score recorded at the end of this notebook is 0.5935;
# worth confirming that the files in the data folder really use this column order.
column_format = {0: 'text', 1: 'pos', 2: 'upos'}
corpus: Corpus = ColumnCorpus('dataset/Ontonotes-conll-formatted/', column_format=column_format)

# --- Prediction target --------------------------------------------------------
# The annotation layer the tagger will predict, and the dictionary of its
# labels as collected from the corpus.
tag_type = 'pos'
tag_dictionary = corpus.make_tag_dictionary(tag_type=tag_type)

# --- Embeddings ---------------------------------------------------------------
# 'crawl' word embeddings plus the forward and backward 'news' Flair embeddings,
# created in that order and then combined into a single stacked embedding.
crawl_vectors = WordEmbeddings('crawl')
flair_forward = FlairEmbeddings('news-forward')
flair_backward = FlairEmbeddings('news-backward')
embedding_types: List[TokenEmbeddings] = [crawl_vectors, flair_forward, flair_backward]

embeddings: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)


2020-11-02 23:56:21,875 Reading data from dataset/Ontonotes-conll-formatted
2020-11-02 23:56:21,876 Train: dataset/Ontonotes-conll-formatted/train.english.v4_gold_conll
2020-11-02 23:56:21,876 Dev: dataset/Ontonotes-conll-formatted/dev.english.v4_gold_conll
2020-11-02 23:56:21,876 Test: dataset/Ontonotes-conll-formatted/test.english.v4_gold_conll


In [None]:
# Sequence tagger over the stacked embeddings (hidden size 256) that predicts
# `tag_type` labels taken from `tag_dictionary`.
tagger: SequenceTagger = SequenceTagger(
    hidden_size=256,
    embeddings=embeddings,
    tag_dictionary=tag_dictionary,
    tag_type=tag_type,
)

# Trainer pairing the tagger with the corpus.
trainer: ModelTrainer = ModelTrainer(tagger, corpus)

# Options for the training run, collected in one place.
train_options = dict(
    learning_rate=0.1,
    train_with_dev=True,
    # large dataset: 'none' means embeddings are not kept in memory
    embeddings_storage_mode='none',
    # a later cell resumes from checkpoint.pt under the output directory
    checkpoint=True,
)

# Kept as the last expression so the value returned by train() is displayed.
trainer.train('models/taggers/ontonotes-pos', **train_options)

2020-10-30 23:03:50,204 ----------------------------------------------------------------------------------------------------
2020-10-30 23:03:50,205 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('crawl')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.05, inplace=False)
        (encoder): Embedding(300, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=300, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.05, inplace=False)
        (encoder): Embedding(300, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=300, bias=True)
      )
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (locked_dropout): LockedDropout(p=0.5)
  (embedding2nn): Linear(in_features=4396, out_features=4396, bias=True)
  (rnn): LSTM(4396, 256, batch_first=True, b

2020-10-31 01:18:17,793 epoch 5 - iter 530/2650 - loss 2.35167194 - samples/sec: 44.87 - lr: 0.100000
2020-10-31 01:21:28,802 epoch 5 - iter 795/2650 - loss 2.33863143 - samples/sec: 44.40 - lr: 0.100000
2020-10-31 01:24:40,857 epoch 5 - iter 1060/2650 - loss 2.35425532 - samples/sec: 44.16 - lr: 0.100000
2020-10-31 01:27:57,429 epoch 5 - iter 1325/2650 - loss 2.35761821 - samples/sec: 43.14 - lr: 0.100000
2020-10-31 01:31:10,355 epoch 5 - iter 1590/2650 - loss 2.35989253 - samples/sec: 43.96 - lr: 0.100000
2020-10-31 01:34:25,896 epoch 5 - iter 1855/2650 - loss 2.36756338 - samples/sec: 43.37 - lr: 0.100000
2020-10-31 01:37:29,418 epoch 5 - iter 2120/2650 - loss 2.35625317 - samples/sec: 46.21 - lr: 0.100000
2020-10-31 01:40:41,531 epoch 5 - iter 2385/2650 - loss 2.35640405 - samples/sec: 44.14 - lr: 0.100000
2020-10-31 01:43:52,239 epoch 5 - iter 2650/2650 - loss 2.36594388 - samples/sec: 44.47 - lr: 0.100000
2020-10-31 01:43:52,239 ---------------------------------------------------

2020-10-31 04:27:19,537 epoch 11 - iter 265/2650 - loss 2.26305812 - samples/sec: 44.01 - lr: 0.100000
2020-10-31 04:30:28,710 epoch 11 - iter 530/2650 - loss 2.27788934 - samples/sec: 44.83 - lr: 0.100000
2020-10-31 04:33:39,076 epoch 11 - iter 795/2650 - loss 2.26249915 - samples/sec: 44.55 - lr: 0.100000
2020-10-31 04:36:51,666 epoch 11 - iter 1060/2650 - loss 2.28012602 - samples/sec: 44.03 - lr: 0.100000
2020-10-31 04:39:58,374 epoch 11 - iter 1325/2650 - loss 2.28141983 - samples/sec: 45.42 - lr: 0.100000
2020-10-31 04:43:12,667 epoch 11 - iter 1590/2650 - loss 2.28381920 - samples/sec: 43.65 - lr: 0.100000
2020-10-31 04:46:28,370 epoch 11 - iter 1855/2650 - loss 2.29653991 - samples/sec: 43.33 - lr: 0.100000
2020-10-31 04:49:37,874 epoch 11 - iter 2120/2650 - loss 2.29520733 - samples/sec: 44.75 - lr: 0.100000
2020-10-31 04:52:48,988 epoch 11 - iter 2385/2650 - loss 2.30810944 - samples/sec: 44.37 - lr: 0.100000
2020-10-31 04:55:58,438 epoch 11 - iter 2650/2650 - loss 2.30928524

2020-10-31 07:35:52,659 EPOCH 16 done: loss 2.2818 - lr 0.1000000
2020-10-31 07:35:52,659 BAD EPOCHS (no improvement): 1
2020-10-31 07:35:58,514 ----------------------------------------------------------------------------------------------------
2020-10-31 07:39:14,177 epoch 17 - iter 265/2650 - loss 2.27007309 - samples/sec: 43.34 - lr: 0.100000
2020-10-31 07:42:28,428 epoch 17 - iter 530/2650 - loss 2.30910558 - samples/sec: 43.66 - lr: 0.100000
2020-10-31 07:45:36,090 epoch 17 - iter 795/2650 - loss 2.30589203 - samples/sec: 45.19 - lr: 0.100000
2020-10-31 07:48:49,156 epoch 17 - iter 1060/2650 - loss 2.31545552 - samples/sec: 43.93 - lr: 0.100000
2020-10-31 07:52:00,116 epoch 17 - iter 1325/2650 - loss 2.30370109 - samples/sec: 44.41 - lr: 0.100000
2020-10-31 07:55:11,710 epoch 17 - iter 1590/2650 - loss 2.30119457 - samples/sec: 44.26 - lr: 0.100000
2020-10-31 07:58:22,781 epoch 17 - iter 1855/2650 - loss 2.29584629 - samples/sec: 44.38 - lr: 0.100000
2020-10-31 08:01:33,607 epoch

2020-10-31 10:47:59,747 epoch 22 - iter 2650/2650 - loss 1.49423668 - samples/sec: 45.22 - lr: 0.050000
2020-10-31 10:47:59,748 ----------------------------------------------------------------------------------------------------
2020-10-31 10:47:59,748 EPOCH 22 done: loss 1.4942 - lr 0.0500000
2020-10-31 10:47:59,748 BAD EPOCHS (no improvement): 0
2020-10-31 10:48:05,711 ----------------------------------------------------------------------------------------------------
2020-10-31 10:51:22,445 epoch 23 - iter 265/2650 - loss 1.51317963 - samples/sec: 43.11 - lr: 0.050000
2020-10-31 10:54:33,920 epoch 23 - iter 530/2650 - loss 1.50244934 - samples/sec: 44.29 - lr: 0.050000
2020-10-31 10:57:50,135 epoch 23 - iter 795/2650 - loss 1.49869032 - samples/sec: 43.22 - lr: 0.050000
2020-10-31 11:00:58,325 epoch 23 - iter 1060/2650 - loss 1.50031695 - samples/sec: 45.06 - lr: 0.050000
2020-10-31 11:04:12,196 epoch 23 - iter 1325/2650 - loss 1.50300252 - samples/sec: 43.74 - lr: 0.050000
2020-10-

2020-10-31 13:54:00,812 epoch 28 - iter 2120/2650 - loss 1.49233444 - samples/sec: 44.98 - lr: 0.050000
2020-10-31 13:57:13,517 epoch 28 - iter 2385/2650 - loss 1.49492994 - samples/sec: 44.01 - lr: 0.050000
2020-10-31 14:00:25,715 epoch 28 - iter 2650/2650 - loss 1.49607755 - samples/sec: 44.12 - lr: 0.050000
2020-10-31 14:00:25,716 ----------------------------------------------------------------------------------------------------
2020-10-31 14:00:25,717 EPOCH 28 done: loss 1.4961 - lr 0.0500000
2020-10-31 14:00:25,717 BAD EPOCHS (no improvement): 3
2020-10-31 14:00:31,620 ----------------------------------------------------------------------------------------------------
2020-10-31 14:03:44,242 epoch 29 - iter 265/2650 - loss 1.49367757 - samples/sec: 44.03 - lr: 0.050000
2020-10-31 14:06:55,573 epoch 29 - iter 530/2650 - loss 1.51075876 - samples/sec: 44.32 - lr: 0.050000
2020-10-31 14:10:03,376 epoch 29 - iter 795/2650 - loss 1.49596605 - samples/sec: 45.16 - lr: 0.050000
2020-10-

In [3]:
# Resume training from a checkpoint file.
# The model directory is defined once so that the checkpoint being loaded and
# the output location handed to train() cannot drift apart.
model_dir = Path('models/taggers/ontonotes-pos')
checkpoint = model_dir / 'checkpoint.pt'

# Fail early with a clear message instead of an opaque loader error.
if not checkpoint.is_file():
    raise FileNotFoundError(
        f"No checkpoint found at {checkpoint}; run the initial training cell "
        "(with checkpoint=True) first."
    )

# Rebuild the trainer from the checkpoint and attach the already-loaded corpus.
trainer = ModelTrainer.load_checkpoint(checkpoint, corpus)

# NOTE(review): learning_rate=0.1 is passed again on resume. The logs recorded
# in this notebook show lr 0.05 by epoch 22 and lr 0.1 again at epoch 81 after
# resuming -- confirm that restarting at 0.1 is intended.
trainer.train(model_dir,
              learning_rate=0.1,
              train_with_dev=True,
              # large dataset: 'none' means embeddings are not kept in memory
              embeddings_storage_mode='none',
              checkpoint=True
             )

2020-11-02 23:56:55,166 ----------------------------------------------------------------------------------------------------
2020-11-02 23:56:55,167 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('crawl')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.05, inplace=False)
        (encoder): Embedding(300, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=300, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.05, inplace=False)
        (encoder): Embedding(300, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=300, bias=True)
      )
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (locked_dropout): LockedDropout(p=0.5)
  (embedding2nn): Linear(in_features=4396, out_features=4396, bias=True)
  (rnn): LSTM(4396, 256, batch_first=True, b

2020-11-03 02:08:09,089 epoch 81 - iter 265/2650 - loss 2.27419177 - samples/sec: 42.15 - lr: 0.100000
2020-11-03 02:11:19,789 epoch 81 - iter 530/2650 - loss 2.25483339 - samples/sec: 44.47 - lr: 0.100000
2020-11-03 02:14:29,717 epoch 81 - iter 795/2650 - loss 2.25330724 - samples/sec: 44.65 - lr: 0.100000
2020-11-03 02:17:36,554 epoch 81 - iter 1060/2650 - loss 2.25417044 - samples/sec: 45.39 - lr: 0.100000
2020-11-03 02:20:42,059 epoch 81 - iter 1325/2650 - loss 2.24131235 - samples/sec: 45.72 - lr: 0.100000
2020-11-03 02:23:49,028 epoch 81 - iter 1590/2650 - loss 2.24311438 - samples/sec: 45.36 - lr: 0.100000
2020-11-03 02:26:52,600 epoch 81 - iter 1855/2650 - loss 2.23395515 - samples/sec: 46.20 - lr: 0.100000
2020-11-03 02:30:07,260 epoch 81 - iter 2120/2650 - loss 2.23587232 - samples/sec: 43.57 - lr: 0.100000
2020-11-03 02:33:23,926 epoch 81 - iter 2385/2650 - loss 2.23530808 - samples/sec: 43.12 - lr: 0.100000
2020-11-03 02:36:30,165 epoch 81 - iter 2650/2650 - loss 2.23188309

2020-11-03 05:16:15,179 EPOCH 86 done: loss 1.4498 - lr 0.0500000
2020-11-03 05:16:15,179 BAD EPOCHS (no improvement): 0
2020-11-03 05:16:21,208 ----------------------------------------------------------------------------------------------------
2020-11-03 05:19:33,265 epoch 87 - iter 265/2650 - loss 1.46422534 - samples/sec: 44.16 - lr: 0.050000
2020-11-03 05:22:51,847 epoch 87 - iter 530/2650 - loss 1.46983711 - samples/sec: 42.71 - lr: 0.050000
2020-11-03 05:26:03,974 epoch 87 - iter 795/2650 - loss 1.46093412 - samples/sec: 44.14 - lr: 0.050000
2020-11-03 05:29:09,791 epoch 87 - iter 1060/2650 - loss 1.46080808 - samples/sec: 45.64 - lr: 0.050000
2020-11-03 05:32:16,404 epoch 87 - iter 1325/2650 - loss 1.44980201 - samples/sec: 45.44 - lr: 0.050000
2020-11-03 05:35:35,667 epoch 87 - iter 1590/2650 - loss 1.45737809 - samples/sec: 42.56 - lr: 0.050000
2020-11-03 05:38:51,199 epoch 87 - iter 1855/2650 - loss 1.45427178 - samples/sec: 43.37 - lr: 0.050000
2020-11-03 05:42:02,313 epoch

2020-11-03 08:29:04,811 epoch 92 - iter 2650/2650 - loss 1.45273500 - samples/sec: 43.14 - lr: 0.050000
2020-11-03 08:29:04,811 ----------------------------------------------------------------------------------------------------
2020-11-03 08:29:04,812 EPOCH 92 done: loss 1.4527 - lr 0.0500000
2020-11-03 08:29:04,812 BAD EPOCHS (no improvement): 1
2020-11-03 08:29:10,798 ----------------------------------------------------------------------------------------------------
2020-11-03 08:32:22,647 epoch 93 - iter 265/2650 - loss 1.42957536 - samples/sec: 44.20 - lr: 0.050000
2020-11-03 08:35:29,369 epoch 93 - iter 530/2650 - loss 1.44931692 - samples/sec: 45.42 - lr: 0.050000
2020-11-03 08:38:40,718 epoch 93 - iter 795/2650 - loss 1.45013816 - samples/sec: 44.32 - lr: 0.050000
2020-11-03 08:41:57,609 epoch 93 - iter 1060/2650 - loss 1.44894879 - samples/sec: 43.07 - lr: 0.050000
2020-11-03 08:45:09,085 epoch 93 - iter 1325/2650 - loss 1.44910877 - samples/sec: 44.29 - lr: 0.050000
2020-11-

2020-11-03 11:33:41,490 epoch 98 - iter 1855/2650 - loss 1.16749422 - samples/sec: 44.34 - lr: 0.025000
2020-11-03 11:36:59,448 epoch 98 - iter 2120/2650 - loss 1.16489176 - samples/sec: 42.84 - lr: 0.025000
2020-11-03 11:40:14,392 epoch 98 - iter 2385/2650 - loss 1.16614878 - samples/sec: 43.50 - lr: 0.025000
2020-11-03 11:43:24,993 epoch 98 - iter 2650/2650 - loss 1.16703117 - samples/sec: 44.49 - lr: 0.025000
2020-11-03 11:43:24,993 ----------------------------------------------------------------------------------------------------
2020-11-03 11:43:24,994 EPOCH 98 done: loss 1.1670 - lr 0.0250000
2020-11-03 11:43:24,994 BAD EPOCHS (no improvement): 1
2020-11-03 11:43:31,005 ----------------------------------------------------------------------------------------------------
2020-11-03 11:46:51,165 epoch 99 - iter 265/2650 - loss 1.18048681 - samples/sec: 42.37 - lr: 0.025000
2020-11-03 11:50:02,650 epoch 99 - iter 530/2650 - loss 1.17623984 - samples/sec: 44.29 - lr: 0.025000
2020-11

{'test_score': 0.5935,
 'dev_score_history': [],
 'train_loss_history': [2.220548223684419,
  2.2322360849380494,
  2.212665776992744,
  2.2258442742532156,
  2.2318830942432837,
  2.234760799419205,
  2.2260387383879356,
  1.4582174939479469,
  1.4571732783879874,
  1.44979367571057,
  1.4502403410313265,
  1.4491260002356656,
  1.4633362090812538,
  1.4604455889171024,
  1.4482243834351594,
  1.452734996822645,
  1.4595477888044321,
  1.4526241859062663,
  1.4515750320452565,
  1.1691802895743892,
  1.165818441543939,
  1.1670311651589735,
  1.1688919344375719,
  1.1667681279159943],
 'dev_loss_history': []}