In [7]:
from pathlib import Path


#spaCy imports
import spacy
from tqdm import tqdm # loading bar
from spacy import displacy
from spacy.scorer import Scorer
import random
from spacy.gold import GoldParse


#flair imports
from flair.data import Corpus
from flair.datasets import CSVClassificationCorpus, ColumnCorpus
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentRNNEmbeddings,DocumentLSTMEmbeddings, BertEmbeddings, StackedEmbeddings, TokenEmbeddings
from flair.trainers import ModelTrainer
from flair.models import SequenceTagger
from typing import List
from flair.data import Sentence



class Model(object):
    """Base class holding the settings shared by the NER model wrappers.

    Attributes set by ``__init__``:
        model_name: name given to the model.
        nb_iter: number of training iterations.
        training_data: initially an object exposing ``get_file()``; replaced by
            the converted data once ``convert_format`` has run.
        out_dir: output directory (may be ``None``).
        is_ready: flag initialised to ``False``.
        model_format: lower-cased format identifier
            (``"spacy_format"`` or ``"bio_format"``).
    """

    def __init__(self, model_format, model_name, training_data, nb_iter, out_dir):
        self.model_name = model_name
        self.nb_iter = nb_iter
        self.training_data = training_data
        self.out_dir = out_dir
        self.is_ready = False
        self.model_format = model_format.lower()

    @staticmethod
    def _bio_tag(tok_beg, tok_end, entities):
        """Return the single BIO tag of the token spanning [tok_beg, tok_end).

        ``entities`` is a sequence of ``(start, end, label)`` triples. The first
        entity that the token starts (``B-``) or lies inside (``I-``) wins;
        a token covered by no entity is tagged ``O``.
        """
        for e in entities:
            ent_beg, ent_end, label = e[0], e[1], e[2]
            # Token starts exactly where the entity starts
            if tok_beg == ent_beg:
                return "B-" + label
            # Token lies strictly after the entity start and ends within it
            if ent_beg < tok_beg and tok_end <= ent_end:
                return "I-" + label
        return "O"

    def convert_format(self):
        """Convert ``self.training_data`` into the layout named by ``self.model_format``.

        * ``"spacy_format"``: ``self.training_data`` becomes a list of
          ``(text, {'entities': [(start, end, label), ...]})`` tuples.
        * ``"bio_format"``: one ``token tag`` line per token is written to
          ``./data/format.txt`` (sentences separated by a blank line) and
          ``self.training_data`` becomes that path.

        Any other format leaves the data untouched. The method is a no-op when
        the data has already been converted (it no longer exposes ``get_file``),
        so calling it twice is safe.
        """
        # Already converted (list of tuples or file path): nothing left to do.
        if not hasattr(self.training_data, "get_file"):
            return

        # spaCy
        if self.model_format == "spacy_format":
            self.training_data = [
                (obj['text'], {'entities': [tuple(e) for e in obj['entities']]})
                for obj in self.training_data.get_file()
            ]

        # flair
        elif self.model_format == "bio_format":
            # Local import: the notebook only imports ``re`` in a later cell.
            import re

            # An elided letter ("n'"), a word, or a single punctuation character
            token_pattern = re.compile(r"\w'|\w+|[^\w\s]")
            bio_path = "./data/format.txt"
            Path(bio_path).parent.mkdir(parents=True, exist_ok=True)

            # Explicit UTF-8 so accented text is written identically on every platform
            with open(bio_path, "w", encoding="utf-8") as bio_file:
                for obj in self.training_data.get_file():
                    for m in token_pattern.finditer(obj['text']):
                        # Exactly one tag per token, whatever the number of entities
                        tag = self._bio_tag(m.start(), m.end(), obj['entities'])
                        bio_file.write(m.group() + " " + tag + "\n")
                    # Blank line marks the end of the sentence
                    bio_file.write("\n")

            self.training_data = bio_path

    








class SpacyModel(Model):
    """NER model wrapper built on a spaCy pipeline.

    The constructor loads the pipeline named by ``model`` (or creates a blank
    French one when ``model`` is ``None``), makes sure it has an ``ner``
    component, registers every label returned by ``training_data.get_labels()``
    and creates the optimizer used by ``train``.
    """

    def __init__(self, model_format, model_name, training_data, nb_iter, out_dir,model):
        Model.__init__(self,model_format, model_name, training_data, nb_iter, out_dir)
        # Rendered displaCy markup, one entry per document scored by ``test``
        self.visuals = []
        if model is not None:
            self.nlp = spacy.load(model)
            print("Loaded model '%s'" % model)
        else:
            self.nlp = spacy.blank('fr')
            print("Created new model")

        # Reuse the existing NER component when there is one, otherwise add it
        if 'ner' not in self.nlp.pipe_names:
            self.ner = self.nlp.create_pipe('ner')
            self.nlp.add_pipe(self.ner)
        else:
            self.ner = self.nlp.get_pipe('ner')
        for label in training_data.get_labels():
            self.ner.add_label(label)

        # A blank pipeline starts training from scratch; a loaded one gets a
        # fresh optimizer from its entity recognizer.
        if model is None:
            self.optimizer = self.nlp.begin_training()
        else:
            self.optimizer = self.nlp.entity.create_optimizer()

    def get_visuals(self):
        """Return the list of renderings collected by ``test``."""
        return self.visuals

    def train(self):
        """Convert the training data, then update the pipeline ``nb_iter`` times.

        Every pipe other than ``ner`` is disabled during the updates. The data
        is shuffled at the start of each iteration and the losses of that
        iteration are printed. Sets ``is_ready`` to ``True`` when done.
        """
        self.convert_format()
        other_pipes = [pipe for pipe in self.nlp.pipe_names if pipe != 'ner']
        with self.nlp.disable_pipes(*other_pipes):
            for itn in range(self.nb_iter):
                random.shuffle(self.training_data)
                losses = {}
                for text, annotations in tqdm(self.training_data):
                    self.nlp.update([text], [annotations], sgd=self.optimizer, drop=0.35,
                        losses=losses)
                print(losses)
        self.is_ready = True

    def test(self, test_data):
        """Score the pipeline on ``test_data`` and return ``scorer.scores``.

        ``test_data`` is an iterable of ``(text, {'entities': [...]})`` pairs.
        For each pair the prediction is rendered with displaCy (style "ent"),
        stored in ``self.visuals``, and the running scores are printed.
        """
        scorer = Scorer()
        for sents, ents in test_data:
            doc_gold = self.nlp.make_doc(sents)
            gold = GoldParse(doc_gold, entities=ents['entities'])
            pred_value = self.nlp(sents)
            visual = displacy.render(pred_value, style="ent")
            # Collapse the blank lines contained in the rendered markup
            visual = visual.replace("\n\n","\n")
            self.visuals.append(visual)
            scorer.score(pred_value, gold)
            print(scorer.scores)
        return scorer.scores

    def save(self):
        """Write the pipeline to ``self.out_dir``, creating the directory if needed.

        ``self.out_dir`` is converted to a ``Path`` as a side effect.

        Raises:
            ValueError: if ``self.out_dir`` is ``None`` (previously this
                surfaced as an ``AttributeError`` on ``None.exists()``).
        """
        if self.out_dir is None:
            raise ValueError("out_dir must be set before calling save()")
        self.out_dir = Path(self.out_dir)
        # parents/exist_ok: nested output paths work and re-saving is harmless
        self.out_dir.mkdir(parents=True, exist_ok=True)
        self.nlp.to_disk(self.out_dir)
        print("Modele saved in :", self.out_dir)

                
class FlairModel(Model):
    """NER model wrapper that trains and evaluates a flair ``SequenceTagger``."""

    def __init__(self,model_format, model_name, training_data, nb_iter=10, lr=0.1, batch=32, mode='cpu', out_dir=None):
        """Store the training settings.

        ``lr``, ``batch`` and ``mode`` are kept as ``learning_rate``,
        ``batch_size`` and ``mode``; they are forwarded to the flair trainer
        by ``train``. ``model_name`` is stored prefixed with ``./``.
        """
        super().__init__(model_format, model_name, training_data, nb_iter, out_dir)
        # Replaces the bare name stored by the base class with a relative path
        self.model_name = './' + model_name
        self.training_data = training_data
        self.learning_rate = lr
        self.batch_size = batch
        self.mode = mode

    def train(self):
        """Convert the training data, build the tagger and run the flair trainer.

        The corpus is read from the file produced by ``convert_format`` with
        column 0 as text and column 1 as the NER tag. The trainer is kept on
        ``self.trainer`` and writes into ``self.model_name``.
        """
        self.convert_format()

        column_map = {0: 'text', 1: 'ner'}
        corpus: Corpus = ColumnCorpus(".", column_map, train_file=self.training_data)
        tag_dictionary = corpus.make_tag_dictionary(tag_type='ner')

        # French word embeddings stacked with forward/backward flair embeddings
        embedding_types: List[TokenEmbeddings] = [
            WordEmbeddings('fr'),
            FlairEmbeddings('fr-forward'),
            FlairEmbeddings('fr-backward'),
        ]
        stacked: StackedEmbeddings = StackedEmbeddings(embeddings=embedding_types)

        tagger: SequenceTagger = SequenceTagger(
            hidden_size=256,
            embeddings=stacked,
            tag_dictionary=tag_dictionary,
            tag_type='ner',
            use_crf=True,
        )

        self.trainer = ModelTrainer(tagger, corpus)
        self.trainer.train(
            self.model_name,
            learning_rate=self.learning_rate,
            mini_batch_size=self.batch_size,
            max_epochs=self.nb_iter,
            embeddings_storage_mode=self.mode,
        )

    def test(self, test_data):
        """Evaluate the saved best model and return per-label scores.

        ``test_data`` is forwarded as ``test_file`` to ``ColumnCorpus``.
        The textual report in ``result.detailed_results`` is parsed by dropping
        its first 9 and last 4 lines and reading the remaining tokens in groups
        of five: label, precision, recall, f-score (the fifth token is ignored).
        NOTE(review): this presumably matches a fixed report layout of the
        installed flair version -- confirm before upgrading flair.

        Returns:
            dict mapping each label to ``{'p': ..., 'r': ..., 'f': ...}``; the
            values are the strings found in the report.
        """
        model = SequenceTagger.load(self.model_name+'/best-model.pt')

        column_map = {0: 'text', 1: 'ner'}
        corpus: Corpus = ColumnCorpus(".", column_map, train_file=None, test_file=test_data)
        result, _eval_loss = model.evaluate(corpus.test)

        # Builds a dictionary shaped like the one provided by spaCy
        report_lines = result.detailed_results.split('\n')[9:-4]
        # Flatten the repr of the line list into one whitespace-normalised string
        flattened = ' '.join(str(report_lines).split())
        for fragment in ("[\'", "\']", "', '"):
            flattened = flattened.replace(fragment, '')
        tokens = flattened.split()

        scores = {}
        for start in range(0, len(tokens), 5):
            scores[tokens[start]] = {
                'p': tokens[start+1],
                'r': tokens[start+2],
                'f': tokens[start+3],
            }
        return scores
        



In [2]:
""" Dataset : Parent class - checks and stores the data structure of an annotated JSON file.
    TrainData : Child class - specific to training dataset: allows to store metadata.
"""

import json
import re
import hashlib

#### Dataset

class Dataset(object):
    """Holds the ``text`` / ``entities`` records extracted from an annotated JSON file."""

    # Textual shape of one entity as rendered by str(): [int, int, 'label'] or
    # [int, int, "label"]. Raw strings avoid invalid-escape warnings; the
    # pattern is compiled once instead of on every is_correct() call.
    _QUOTED = r"""(("[^"]+")|('[^']+'))"""
    _ENTITY_RE = re.compile(r"\[\d+,\s*\d+,\s*" + _QUOTED + r"\]")

    def __init__(self, title):
        self.title = title
        self.file = []

    def get_title(self):
        """Return the dataset title."""
        return self.title

    def get_file(self):
        """Return the list of filtered records (empty until ``filter_json`` succeeds)."""
        return self.file

    def filter_json(self, json_file):
        """Keep only the ``text`` and ``entities`` elements of each JSON record.

        Args:
            json_file: iterable of mappings, each expected to provide
                ``'text'`` and ``'entities'``.

        Returns:
            ``True`` when every record was kept and ``self.file`` was replaced;
            ``None`` as soon as one record lacks a key, is not subscriptable by
            key, or has an empty text / entity list (``self.file`` is then left
            unchanged).
        """
        file = []
        for o in json_file:
            try:
                text = o['text']
                entities = o['entities']
            except (KeyError, TypeError):
                # Missing key, or the record is not a mapping: reject the file.
                # Unrelated errors are deliberately left to propagate.
                return None
            if not text or not entities:
                return None
            file.append({'text': text, 'entities': entities})
        self.file = file
        return True

    def is_correct(self):
        """Check that every stored entity looks like ``[start, end, "label"]``.

        Each entity is validated through its ``str()`` form, so it must be a
        list of two non-negative integers followed by a non-empty quoted
        label, with ``start < end``. A record with no entity is invalid.

        Returns:
            ``True`` when all records pass, ``False`` otherwise.
        """
        for obj in self.file:
            entity = obj['entities']
            if not entity:
                return False
            for e in entity:
                if not self._ENTITY_RE.fullmatch(str(e)) or e[0] >= e[1]:
                    return False
        return True


#### TrainData

class TrainData(Dataset):
    """Training dataset: a ``Dataset`` plus metadata (hash, entity count, labels)."""

    def __init__(self, title):
        Dataset.__init__(self, title)
        self.hash = ""
        self.nb_entities = 0
        # Empty list until metadata() replaces it with a {label: count} dict
        self.labels = []

    def __str__(self):
        """Return a one-line summary (used by ``print(obj)``)."""
        return 'the dataset "'+ self.title +'" has '+ str(self.nb_entities) +' entities.'

    def get_nb_entities(self):
        """Return the number of entities counted by the last ``metadata()`` call."""
        return self.nb_entities

    def get_labels(self):
        """Return the labels; after ``metadata()`` a dict ``{label: count}``."""
        return self.labels

    def get_hash(self):
        """Return the MD5 hex digest computed by ``metadata()`` (``""`` before)."""
        return self.hash

    def metadata(self):
        """Compute the entity count, the per-label counts and the data hash.

        All values are recomputed from ``self.file`` on every call, so calling
        the method again does not inflate ``nb_entities``.

        Returns:
            ``True`` once the attributes have been updated.
        """
        counts = {}
        total = 0
        for obj in self.file:
            total += len(obj['entities'])
            for e in obj['entities']:
                # e[2] is the label of the (start, end, label) triple
                counts[e[2]] = counts.get(e[2], 0) + 1
        self.nb_entities = total
        # Most frequent label first (dicts keep insertion order)
        self.labels = dict(sorted(counts.items(), key=lambda item: item[1], reverse=True))

        # MD5 hash - encoded data in hexadecimal format.
        self.hash = hashlib.md5(str(self.file).encode()).hexdigest()
        return True

In [3]:
# Load the annotated corpus. JSON is read as UTF-8 explicitly so accented
# characters decode the same way on every platform.
with open("./data/cheval_annotated.json", "r", encoding="utf-8") as file :
    content = json.load(file)

train = TrainData("data")
# Stop here when the file is unusable instead of training on bad data:
# filter_json returns a falsy value on malformed records and is_correct
# returns False on malformed entities.
if not train.filter_json(content):
    raise ValueError("cheval_annotated.json: every record needs a non-empty 'text' and 'entities'")
if not train.is_correct():
    raise ValueError("cheval_annotated.json: entities must look like [start, end, 'LABEL'] with start < end")
train.metadata()


True

In [4]:
# Flair tagger fed with BIO-formatted data, trained for a single epoch on CPU
flair_model = FlairModel(
    model_format="bio_format",
    model_name="test",
    training_data=train,
    nb_iter=1,
    lr=0.1,
    batch=32,
    mode='cpu',
)

In [5]:
# Converts the training data to the BIO file, then runs the flair trainer
# (the log printed below this cell is its output).
flair_model.train()

2020-08-12 15:03:41,309 Reading data from .
2020-08-12 15:03:41,313 Train: data/format.txt
2020-08-12 15:03:41,315 Dev: None
2020-08-12 15:03:41,316 Test: None
2020-08-12 15:03:50,202 ----------------------------------------------------------------------------------------------------
2020-08-12 15:03:50,204 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('fr')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.5, inplace=False)
        (encoder): Embedding(275, 100)
        (rnn): LSTM(100, 1024)
        (decoder): Linear(in_features=1024, out_features=275, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.5, inplace=False)
        (encoder): Embedding(275, 100)
        (rnn): LSTM(100, 1024)
        (decoder): Linear(in_features=1024, out_features=275, bias=True)
      )
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (

In [None]:
# Hand-annotated evaluation sentences: (text, {'entities': [(start, end, label), ...]})
test_text = [
    (
        "On dit qu\'un cheval est calme",
        {'entities': [(13, 19, 'ANIMAL')]},
    ),
    (
        "Un cheval endormi n\'est pas nécessairement un cheval calme",
        {'entities': [(3, 9, 'ANIMAL'), (46, 51, 'ANIMAL')]},
    ),
    (
        "souhaitez vous apprendre à monter à cheval?",
        {'entities': [(36, 41, 'ANIMAL')]},
    ),
    (
        "Pour moi les chevaux sont les meilleurs animaux après les chats",
        {'entities': [(13, 20, 'ANIMAL'), (58, 63, 'ANIMAL')]},
    ),
]

# NOTE(review): FlairModel.test forwards its argument as `test_file` to
# ColumnCorpus, so it looks like it expects the path of a column-formatted
# file rather than this list of tuples -- confirm and convert if needed.
score = flair_model.test(test_text)

2020-08-12 15:08:37,413 loading file ./test/best-model.pt
