Main resource: https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_7_TRAINING_A_MODEL.md

In [1]:
import flair
import numpy as np
import pandas as pd
import torch
from torch.optim.adam import Adam

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Point flair at the same device that was selected above.
flair.device = device

print(flair.device)

cuda


In [2]:
# Print the installed PyTorch version so the environment is recorded next to the results.
print(torch.__version__)

1.7.1+cu110


In [3]:
from flair.data import Corpus
from flair.datasets import CSVClassificationCorpus
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# Folder holding the corpus CSV files (the log below reports train.csv,
# dev.csv and test.csv being read from it).
data_folder = '../data/corpus_10042021'

# Column index -> role of that column in the CSV files.
column_name_map = {
    1: "text",
    2: "label_topic",
}

# Step 1: load the classification corpus, skipping the header row.
corpus: Corpus = CSVClassificationCorpus(
    data_folder,
    column_name_map,
    skip_header=True,
)

# Step 2: derive the label dictionary from the corpus.
label_dict = corpus.make_label_dictionary()

print(label_dict)

2021-04-18 16:30:48,905 Reading data from ..\data\corpus_10042021
2021-04-18 16:30:48,953 Train: ..\data\corpus_10042021\train.csv
2021-04-18 16:30:48,953 Dev: ..\data\corpus_10042021\dev.csv
2021-04-18 16:30:48,954 Test: ..\data\corpus_10042021\test.csv
2021-04-18 16:30:48,989 Computing label dictionary. Progress:


100%|███████████████████████████████████████| 976/976 [00:01<00:00, 863.33it/s]

2021-04-18 16:31:08,157 [b'High', b'Low', b'Medium']
Dictionary with 3 tags: High, Low, Medium





In [7]:
# Step 4: document embeddings from the pre-trained 'bert-base-uncased'
# transformer, with fine-tuning enabled.
document_embeddings = TransformerDocumentEmbeddings(
    'bert-base-uncased',
    fine_tune=True,
)

# Step 5: text classifier built on the document embeddings and the label dictionary.
classifier = TextClassifier(
    document_embeddings,
    label_dictionary=label_dict,
)

# Step 6: trainer for the classifier on the corpus, with Adam as the optimizer.
trainer = ModelTrainer(
    classifier,
    corpus,
    optimizer=Adam,
)

# Step 7: run the training. The call is kept as the last expression so that
# its returned result is displayed as the cell output.
# NOTE(review): in the recorded output the dev loss climbs from the second
# epoch on while the train loss approaches zero -- the epoch count and
# learning rate may deserve a second look.
trainer.train(
    './flair/transformers',
    learning_rate=3e-5,       # very small learning rate
    mini_batch_size=4,
    mini_batch_chunk_size=2,  # optional: set when the transformer is too much for the machine
    max_epochs=10,            # stop after 10 epochs
)

Downloading:   0%|          | 0.00/433 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

2021-04-18 16:48:34,774 ----------------------------------------------------------------------------------------------------
2021-04-18 16:48:34,777 Model: "TextClassifier(
  (document_embeddings): TransformerDocumentEmbeddings(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
               

{'test_score': 0.422,
 'dev_score_history': [0.4259,
  0.4259,
  0.4074,
  0.4074,
  0.5093,
  0.4352,
  0.4722,
  0.4352,
  0.4722,
  0.4352],
 'train_loss_history': [1.106940983764587,
  1.0478954512265422,
  0.7803794679254473,
  0.3641364173836085,
  0.1156075713953646,
  0.04744413032194318,
  0.012784025995174052,
  0.019352154888300978,
  0.0008776853299023337,
  0.002259977813497576],
 'dev_loss_history': [1.0516793727874756,
  1.17844557762146,
  1.6175380945205688,
  2.8741862773895264,
  3.242818593978882,
  4.196915626525879,
  4.7205705642700195,
  4.628772258758545,
  5.626125335693359,
  5.3988423347473145]}