Main resource: https://github.com/flairNLP/flair/blob/master/resources/docs/TUTORIAL_7_TRAINING_A_MODEL.md

In [1]:
import flair
import numpy as np
import pandas as pd
import torch
from torch.optim.adam import Adam

# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hand the chosen device to flair through its module-level `device` attribute,
# then echo it back to confirm the setting.
flair.device = device
print(flair.device)

cuda


In [2]:
# Show the installed PyTorch version.
print(f"{torch.__version__}")

1.7.1+cu110


In [3]:
from flair.data import Corpus
from flair.datasets import CSVClassificationCorpus
from flair.embeddings import WordEmbeddings, FlairEmbeddings, StackedEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

# Folder that holds the train.csv / dev.csv / test.csv splits.
data_folder = '../data/corpuslow_13042021'

# CSV column index -> role of that column for flair.
# NOTE(review): presumably 0-based indices, with any "label*" name treated as a
# label column -- confirm against the CSVClassificationCorpus documentation.
column_name_map = {
    1: "text",
    2: "label_topic",
}

# 1. get the corpus (the first row of each CSV is skipped as a header)
corpus: Corpus = CSVClassificationCorpus(
    data_folder,
    column_name_map,
    skip_header=True,
)

# 2. create the label dictionary from the corpus and show it
label_dict = corpus.make_label_dictionary()
print(label_dict)

2021-04-18 17:45:47,591 Reading data from ..\data\corpuslow_13042021
2021-04-18 17:45:47,592 Train: ..\data\corpuslow_13042021\train.csv
2021-04-18 17:45:47,592 Dev: ..\data\corpuslow_13042021\dev.csv
2021-04-18 17:45:47,593 Test: ..\data\corpuslow_13042021\test.csv
2021-04-18 17:45:47,627 Computing label dictionary. Progress:


100%|███████████████████████████████████████| 976/976 [00:01<00:00, 864.44it/s]

2021-04-18 17:46:06,960 [b'0', b'1']
Dictionary with 2 tags: 0, 1





In [None]:
# 4. initialize the document embedding from the pretrained 'bert-base-uncased'
#    transformer, with fine-tuning requested
document_embeddings = TransformerDocumentEmbeddings(
    'bert-base-uncased',
    fine_tune=True,
)

# 5. create the text classifier on top of that embedding
classifier = TextClassifier(
    document_embeddings,
    label_dictionary=label_dict,
)

# 6. initialize the text classifier trainer with Adam as the optimizer
trainer = ModelTrainer(
    classifier,
    corpus,
    optimizer=Adam,
)

# 7. start the training
trainer.train(
    './flair/transformers_low',
    max_epochs=10,            # terminate after 10 epochs
    mini_batch_size=16,
    mini_batch_chunk_size=4,  # optional: set this if the transformer is too much for your machine
    learning_rate=3e-5,       # use a very small learning rate
)

2021-04-18 17:46:16,851 ----------------------------------------------------------------------------------------------------
2021-04-18 17:46:16,854 Model: "TextClassifier(
  (document_embeddings): TransformerDocumentEmbeddings(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
               