In [None]:
!pip install simpletransformers
!pip install spacy==3.2.1
!python -m spacy download pt_core_news_lg
!pip install seqval

In [None]:
import requests
from tqdm import tqdm
from zipfile import ZipFile
import os 
import shutil

In [None]:
BASE_MODEL_DIR = './drive/Shareddrives/MEC - Correção textual/PLN/Notebooks/Pontuação/models'
MODEL_NAME = 'bert-portuguese-tedtalk2012'
# Unpack <BASE_MODEL_DIR>/<MODEL_NAME>.zip into ./<MODEL_NAME>. The context
# manager closes the archive handle as soon as extraction is done.
with ZipFile(os.path.join(BASE_MODEL_DIR, f'{MODEL_NAME}.zip')) as zipfile:
    zipfile.extractall(path=MODEL_NAME)

In [None]:
from simpletransformers.ner import NERModel, NERArgs
import torch

def get_model(model_path,
              model_type="bert",
              labels=None,
              max_seq_length=512):
    """Build a simpletransformers ``NERModel`` configured for punctuation tags.

    :param model_path: model location/identifier forwarded to ``NERModel``
    :param model_type: architecture name forwarded to ``NERModel``
    :param labels: tag set to use; when ``None`` the default
        ``["O", "COMMA", "PERIOD", "QUESTION"]`` is applied
    :param max_seq_length: value stored in ``NERArgs.max_seq_length``
    :return: the constructed ``NERModel``; CUDA is requested only when
        ``torch.cuda.is_available()`` is true
    """
    default_labels = ["O", "COMMA", "PERIOD", "QUESTION"]

    ner_args = NERArgs()
    ner_args.labels_list = default_labels if labels is None else labels
    ner_args.silent = True
    ner_args.max_seq_length = max_seq_length

    gpu_available = torch.cuda.is_available()
    return NERModel(model_type, model_path, args=ner_args, use_cuda=gpu_available)

In [None]:
import json 
filename = '/content/drive/Shareddrives/MEC - Correção textual/PLN/Notebooks/Pontuação/Dataset/student_entities.json'
# Read the JSON inside a context manager so the file handle is closed, and
# decode it explicitly as UTF-8 instead of relying on the locale default.
with open(filename, encoding='utf-8') as student_file:
    student_entities = json.load(student_file)


In [None]:
import nltk 
# Download the "punkt" tokenizer data package; presumably needed by the
# word_tokenize calls in later cells (TODO confirm).
nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
from nltk.tokenize import word_tokenize, wordpunct_tokenize

# Quick check of the tokenizer: the trailing period comes back as its own token.
wordpunct_tokenize("Ela vai dormir.")

['Ela', 'vai', 'dormir', '.']

In [None]:
import re
def split_paragraphs(text):
    """Split ``text`` into paragraphs on newline characters.

    :param text: raw text
    :return: list of the pieces found between newlines (empty pieces are kept)
    """
    return text.split('\n')

In [None]:
import string
def remove_punctuation(text):
    """Return ``text`` rebuilt from its tokens, without punctuation tokens.

    The text is tokenized with ``word_tokenize`` and the surviving tokens are
    joined with single spaces. A token is dropped when it occurs inside
    ``string.punctuation`` (a substring test against that string).

    :param text: text to clean
    :return: space-joined string of the kept tokens
    """
    kept_tokens = []
    for token in word_tokenize(text):
        if token in string.punctuation:
            continue
        kept_tokens.append(token)
    return ' '.join(kept_tokens)

In [None]:

anno_files = '/content/drive/Shareddrives/MEC - Correção textual/PLN/Notebooks/Pontuação/Dataset/annotator1_entities.json'
# Load inside a context manager so the file handle is closed, and decode
# explicitly as UTF-8 instead of relying on the locale default.
with open(anno_files, encoding='utf-8') as annotator_file:
    annotator = json.load(annotator_file)
# Peek at the first three labels of the first annotated text.
annotator[0]['labels'][:3]

['O', 'O', 'I-PERIOD']

In [None]:
def merge_dicts(dict_args):
    """
    Merge a sequence of dictionaries into one new dict.

    Precedence goes to key-value pairs in latter dictionaries. The result is
    a fresh shallow copy, so none of the input dictionaries is modified, and
    an empty sequence yields an empty dict.

    :param dict_args: iterable of dictionaries to merge
    :return: new dict holding the merged key-value pairs
    """
    # Start from an empty dict rather than dict_args[0]: updating the first
    # input in place would mutate the caller's data and fail on empty input.
    merged = {}
    for dictionary in dict_args:
        merged.update(dictionary)
    return merged

In [None]:
import traceback
def get_labels(text, pred_dict):
    """Translate per-word predictions into one label per token of ``text``.

    ``text`` is tokenized with ``word_tokenize``; tokens found in
    ``string.punctuation`` are skipped and every other token contributes
    exactly one label, so the result always has one entry per
    non-punctuation token.

    :param text: original text
    :param pred_dict: mapping from word to predicted tag; "QUESTION" and
        "PERIOD" become "I-PERIOD", "COMMA" becomes "I-COMMA", anything else "O"
    :return: list of "O" / "I-COMMA" / "I-PERIOD" labels
    """
    labels = []
    missing_words = []
    # BERT's tokenization differs from the one done here, so a token produced
    # below may have no entry in pred_dict.
    for word in word_tokenize(text):
        if word in string.punctuation:
            continue
        if word not in pred_dict:
            # Fall back to "O" for this token only. Aborting the loop here
            # would return a truncated list that no longer lines up with the
            # reference labels for the same text.
            missing_words.append(word)
            labels.append("O")
            continue

        predicted = pred_dict[word]
        if predicted in ("QUESTION", "PERIOD"):
            label = "I-PERIOD"
        elif predicted == "COMMA":
            label = "I-COMMA"
        else:
            label = "O"
        labels.append(label)

    if missing_words:
        # Report once per text instead of once per token to keep output short.
        print("KeyError", missing_words)
        print(text)

    return labels

In [None]:
def preprocess_text(text):
    """
    Prepare raw text for prediction.

    The text is split into paragraphs and punctuation is removed from each.
    :param text: text to preprocess
    :return: list with one punctuation-free string per paragraph
    """
    return [remove_punctuation(paragraph) for paragraph in split_paragraphs(text)]

In [None]:
from itertools import chain

In [None]:
def predict(test_text: str, model):
    """Predict punctuation labels for ``test_text``.

    The text is preprocessed into punctuation-free paragraphs, passed to
    ``model.predict``, and the flattened predictions are merged into a single
    dict that ``get_labels`` looks words up in.

    NOTE(review): presumably ``model.predict`` returns, per paragraph, a list
    of ``{word: tag}`` dicts; confirm against simpletransformers. Since all
    of them are merged into one dict, a word that occurs more than once ends
    up with the tag of its last occurrence.

    :param test_text: original (punctuated) text
    :param model: object exposing ``predict(list_of_texts)`` that returns a
        pair whose first element is the prediction list
    :return: list of labels produced by ``get_labels``
    """
    paragraphs = preprocess_text(test_text)
    per_paragraph_predictions, _ = model.predict(paragraphs)
    word_predictions = list(chain.from_iterable(per_paragraph_predictions))
    return get_labels(test_text, merge_dicts(word_predictions))

In [None]:
def text2labels(sentence):
    """
    Derive reference labels from the punctuation already present in a text.

    The lower-cased text is tokenized with ``word_tokenize``. Each
    non-punctuation token starts as 'O'; a following '.', '?', '!' or ';'
    relabels the previous token as 'I-PERIOD' and a following ',' relabels it
    as 'I-COMMA'. Other punctuation tokens are ignored.
    :param sentence: text to convert
    :return:  list of labels, one per non-punctuation token
    :raises ValueError: if a labelled punctuation mark appears before any word
    """
    mark_to_label = {
        '.': 'I-PERIOD',
        '?': 'I-PERIOD',
        '!': 'I-PERIOD',
        ';': 'I-PERIOD',
        ',': 'I-COMMA',
    }

    labels = []
    for token in word_tokenize(sentence.lower()):
        if token not in string.punctuation:
            labels.append('O')
            continue

        mark_label = mark_to_label.get(token)
        if mark_label is None:
            # Punctuation that carries no label (e.g. brackets) is skipped.
            continue

        try:
            labels[-1] = mark_label
        except IndexError:
            raise ValueError(f"Sentence can't start with punctuation {token}")
    return labels

In [None]:
anno_file = "/content/drive/Shareddrives/MEC - Correção textual/PLN/Notebooks/Pontuação/Dataset/annotator2_entities.json"
# Load inside a context manager so the file handle is closed, and decode
# explicitly as UTF-8 instead of relying on the locale default.
with open(anno_file, "r", encoding="utf-8") as anno_handle:
    annotator_entities = json.load(anno_handle)
MODEL_PATH = "./bert-portuguese-tedtalk2012"
model = get_model(MODEL_PATH, model_type="bert", max_seq_length=512)
# Show the first annotated entry.
annotator_entities[:1]

In [None]:
from seqeval import metrics
import pandas as pd
from tqdm import tqdm
# Compare the model's labels with the labels derived from each annotated text,
# then summarise precision/recall/F1 per class in a DataFrame.
true_labels = []
bert_labels = []

progress = tqdm(annotator_entities, total=len(annotator_entities))
for item in progress:
    text_id = item["text_id"]
    ann_text = item["text"]

    # Prediction first, then the reference labels from the same text.
    bert_label = predict(ann_text, model)
    true_label = text2labels(ann_text)

    true_labels.append(true_label)
    bert_labels.append(bert_label)

report = metrics.classification_report(
    true_labels,
    bert_labels,
    output_dict=True,
)
df = pd.DataFrame.from_dict(report, orient="index")
df

100%|██████████| 149/149 [00:35<00:00,  4.21it/s]


Unnamed: 0,precision,recall,f1-score,support
COMMA,0.08978,0.302251,0.138439,311
PERIOD,0.543031,0.544275,0.543652,1310
micro avg,0.341949,0.497841,0.405426,1621
macro avg,0.316406,0.423263,0.341046,1621
weighted avg,0.456072,0.497841,0.465909,1621


In [None]:
# Persist the metrics table; a plain string is enough since nothing is interpolated.
df.to_csv("results_anotador2.csv")