<a href="https://colab.research.google.com/github/fberanizo/spelling-correction/blob/master/BEA2019_BERT_PT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Grammatical Error Correction: BERT PT Model

**Name: Fabio Beranizo Lopes**<br>
**Name: Luiz Pita Almeida**

We use the pre-trained BERT model and the BrWaC (Brazilian Web as Corpus) dataset.<br>

Evaluation metric: F0.5 score<br>
https://www.cl.cam.ac.uk/research/nl/bea2019st/#eval
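
The F0.5 score weights precision twice as much as recall. As a minimal sketch, it can be computed from counts of true-positive, false-positive and false-negative edits. The function name `f_beta`, the example counts, and the convention of returning 1.0 for precision or recall when its denominator is zero are assumptions made for illustration, not taken from the shared-task tooling.

```python
def f_beta(tp: int, fp: int, fn: int, beta: float = 0.5) -> float:
    """F-beta from edit counts; beta < 1 favours precision over recall."""
    # Assumed convention: an empty denominator counts as a perfect score.
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 1.0
    if precision + recall == 0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


# Illustrative counts: 3 correct edits, 1 spurious edit, 2 missed edits.
print(round(f_beta(tp=3, fp=1, fn=2), 4))
```

With `beta = 0.5`, a spurious correction costs more than a missed one, which suits a corrector that should avoid changing words that are already right.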

The correction method applied was suggested by the instructors:<br>
> given a sentence and a word in that sentence that may or may not need
> correcting, we mask the word, run BERT or T5, and predict the top-10
> alternative words using masked language modeling. If the original word is
> among the top predictions, no correction is suggested. Otherwise, we use
> edit distance to find the closest word and suggest it to the user.

Steps:

1. Generate tuples: `(input_ids, lm_labels, original, corrected)`
2. Apply the BERT model to predict the top-10 words.<br>
   If the original word is in the model's top 10, it is classified as correct.<br>
   Otherwise, the word is classified as incorrect.
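
The decision rule above can be sketched without loading BERT. In this hedged example the top-k predictions are hard-coded, and the helper names `levenshtein` and `suggest` are hypothetical; the notebook itself relies on `fuzzywuzzy` for the string comparison rather than this hand-written edit distance.

```python
def levenshtein(a: str, b: str) -> int:
    """Classic dynamic-programming edit distance."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            cost = 0 if ca == cb else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution
        prev = curr
    return prev[-1]


def suggest(original: str, top_k: list) -> str:
    """Keep the word if the model predicted it, else pick the nearest candidate."""
    if original in top_k or not top_k:
        return original
    return min(top_k, key=lambda cand: levenshtein(original, cand))


# Hypothetical top-k predictions, hard-coded so the sketch runs without BERT.
candidates = ["gosto", "gostava", "amo", "sou", "preciso"]
print(suggest("gsto", candidates))  # nearest candidate by edit distance
print(suggest("amo", candidates))   # already predicted, so left unchanged
```

This sketch always returns the nearest candidate; a distance threshold, like the one the corrector class uses, would avoid suggesting a word that is far from the original.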

**Note: the notebooks contain code excerpts from classmates.**<br>
**Thanks to Diedre, Gabriela, Leard, Lucas and Israel.**


In [None]:
import torch

print(f"Current GPU: {torch.cuda.get_device_name(0)}")

# don't even start if it's not a P100 GPU
# if torch.cuda.get_device_name(0) != "Tesla P100-PCIE-16GB":
#     import os
#     os.kill(os.getpid(), 9)

Current GPU: Tesla P100-PCIE-16GB


In [None]:
#@title General settings
experiment_name = "first"
model_name = "bert-base-pt-cased"  #@param ["bert-base-pt-cased", "bert-large-pt-cased"] {type:"string"}
batch_size = 25  #@param {type:"integer"}
accumulate_grad_batches = 10  #@param {type:"integer"}
source_max_length = 50  #@param {type:"integer"}
target_max_length = 100  #@param {type:"integer"}
learning_rate = 5e-3  #@param {type:"number"}
decode_mode = "topk"  #@param ["greedy", "nucleus", "topk", "beam"] {type:"string"}
k = 10  #@param {type:"integer"}

## Install dependencies

- PyTorch Lightning
- Hugging Face Transformers
- ERRANT (ERRor ANnotation Toolkit)

In [None]:
try:
    import pytorch_lightning
    import transformers
except ImportError as e:
    # can't import modules, then install
    !pip install --quiet pytorch-lightning
    !pip install --quiet transformers
    !pip install --quiet errant==2.0.0
    !python -m spacy download en
    # kill kernel (necessary for tqdm)
    import os
    os.kill(os.getpid(), 9)


In [None]:
# Import all packages at once to avoid duplicate imports throughout the notebook.
import datetime
import errant
import gzip
import json
import nvidia_smi
import os
import pickle
import psutil
import pytorch_lightning as pl
import random
import re
import spacy
import sys
import tarfile
import torch
import torch.nn.functional as F

from argparse import Namespace
from google.colab import drive

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer

from transformers import T5ForConditionalGeneration, T5Model
from transformers import T5Tokenizer
from transformers import BertTokenizer, BertForMaskedLM
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch import nn

from typing import Dict
from typing import List
from typing import Tuple

# Portuguese corpus for vocab
import nltk
from nltk.corpus import floresta
nltk.download('floresta')
floresta_vocab = floresta.words()

# Leard decoding solution
import html
import unicodedata

nlp = spacy.load("en")
annotator = errant.load("en", nlp)


def hardware_stats():
    """
    Returns a dict containing some hardware related stats
    """
    res = nvidia_smi.nvmlDeviceGetUtilizationRates(handle)
    return {"cpu": f"{str(psutil.cpu_percent())}%",
            "mem": f"{str(psutil.virtual_memory().percent)}%",
            "gpu": f"{str(res.gpu)}%",
            "gpu_mem": f"{str(res.memory)}%"}

## Define random seeds

Important: fix the seeds so that results can be replicated.

In [None]:
import random

seed = 0
random.seed(seed)
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)

TIP for real models: an optimized model should keep GPU utilization close to 100% during training.
We use the library below to monitor this. Note that with the simple model used here, utilization will not reach 100%.

In [None]:
print(f"Pytorch Lightning Version: {pl.__version__}")
nvidia_smi.nvmlInit()
handle = nvidia_smi.nvmlDeviceGetHandleByIndex(0)
print(f"Device name: {nvidia_smi.nvmlDeviceGetName(handle)}")

def gpu_usage():
    global handle
    return f"{str(nvidia_smi.nvmlDeviceGetUtilizationRates(handle).gpu)}%"

## Mount Google Drive

We save the checkpoints (model weights) to Google Drive so that training can resume from where it stopped.

In [None]:
drive.mount("/content/drive")
base_path = "/content/drive/My Drive/PF-Correcao/bea2019st"
base_path = "/content/bea2019st"
os.environ["BASE_PATH"] = base_path

In [None]:
# IME vocabulary: https://www.ime.usp.br/~pf/dicios/index.html
ime_words_path = "/content/drive/My Drive/PF-Correcao/br-utf8.txt"
# The file name suggests UTF-8, so the encoding is passed explicitly
with open(ime_words_path, encoding="utf-8") as file:
    ime_vocab = file.read().splitlines()

print(ime_vocab[:10])
print(len(ime_vocab))

## ERRANT Scorer

Evaluation command that compares a "hypothesis" M2 file against a "reference" M2 file.<br>

### **Example**
**Original**: This are gramamtical sentence .<br>
**Corrected**: This is a grammatical sentence .<br>
**Output M2**:<br>
S This are gramamtical sentence .<br>
A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0<br>
A 2 2|||M:DET|||a|||REQUIRED|||-NONE-|||0<br>
A 2 3|||R:SPELL|||grammatical|||REQUIRED|||-NONE-|||0<br>
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||1<br>

In M2 format, a line preceded by S denotes an original sentence while a line preceded by A indicates an edit annotation. Each edit line consists of the start and end token offset of the edit, the error type, and the tokenized correction string. The next two fields are included for historical reasons (see the CoNLL-2014 shared task) while the last field is the annotator id.
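
To make the field layout concrete, here is a small, hedged sketch of a parser for a single edit line. `M2Edit` and `parse_m2_edit` are hypothetical names introduced for illustration and are not part of ERRANT's API.

```python
from typing import NamedTuple


class M2Edit(NamedTuple):
    start: int          # start token offset
    end: int            # end token offset
    error_type: str     # e.g. "R:VERB:SVA"
    correction: str     # tokenized correction string
    annotator_id: int   # last field of the line


def parse_m2_edit(line: str) -> M2Edit:
    """Parse one 'A ...' line of an M2 file into its fields."""
    if not line.startswith("A "):
        raise ValueError("not an M2 edit line")
    fields = line[2:].rstrip("\n").split("|||")
    start, end = (int(x) for x in fields[0].split())
    # fields[3] and fields[4] are the two historical fields and are skipped.
    return M2Edit(start, end, fields[1], fields[2], int(fields[-1]))


edit = parse_m2_edit("A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0")
print(edit)
```

The `noop` line parses the same way, yielding offsets of -1 and the correction `-NONE-`.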

In [None]:
%%writefile ref.m2
S This are gramamtical sentence .
A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0
A 2 2|||M:DET|||a|||REQUIRED|||-NONE-|||0
A 2 3|||R:SPELL|||grammatical|||REQUIRED|||-NONE-|||0
A -1 -1|||noop|||-NONE-|||REQUIRED|||-NONE-|||1

In [None]:
%%writefile hyp.m2
S This are gramamtical sentence .
A 1 2|||R:VERB:SVA|||is|||REQUIRED|||-NONE-|||0
A 2 2|||M:DET|||a|||REQUIRED|||-NONE-|||0
A 2 3|||R:SPELL|||grammatical|||REQUIRED|||-NONE-|||0

In [None]:
original = 'This are gramamtical sentence .'
correct = 'This is a grammatical sentence .'
annotator = errant.load('en')
orig = annotator.parse(original)
cor = annotator.parse(correct)
edits = annotator.annotate(orig, cor, lev=False, merging='rules')

with open('hyp.m2', 'w') as f:
    f.write(f'S {original}\n')
    for edit in edits:
        f.write(f'{edit.to_m2()}\n')

In [None]:
!cat hyp.m2

In [None]:
!errant_compare -hyp hyp.m2 -ref ref.m2

# Prediction

In [None]:
# Input text
text = '[CLS] Eu gosto [MASK] brigadeiro [SEP]'
tokenizer = BertTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
model = BertForMaskedLM.from_pretrained("neuralmind/bert-base-portuguese-cased")
model.eval()

# Print the original sentence.
print(' Original: ', text)
# Print the sentence split into tokens.
print('Tokenized: ', tokenizer.tokenize(text))
# Print the sentence mapped to token ids.
print('Token IDs: ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)))

In [None]:
text = 'Eu [MASK] de brigadeiro.'
encoded = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')
input_ids = encoded['input_ids']

# Generating 20 sequences with maximum length set to 5
outputs = model.generate(input_ids=input_ids, num_beams=200,
                         num_return_sequences=20, max_length=5)
results = list(map(tokenizer.decode, outputs))

print(input_ids)
print(outputs)
results

In [None]:
text = 'Eu [MASK] de brigadeiro.'
encoded = tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')
input_ids = encoded['input_ids']

# Generating 20 sequences with maximum length set to 5
outputs = model.generate(input_ids=input_ids, num_beams=200,
                         num_return_sequences=20, max_length=5)

mask_idx = text.index('[MASK]')
_result_prefix = text[:mask_idx]
_result_suffix = text[mask_idx+6:] 

def _filter(output, end_token='[SEP]'):
    # Drop the first token of each output before decoding (with BERT this is [CLS])
    _txt = tokenizer.decode(output[1:], skip_special_tokens=False, clean_up_tokenization_spaces=False)
    print(_txt)
    if end_token in _txt:
        _end_token_index = _txt.index(end_token)
        return _result_prefix + _txt[:_end_token_index] + _result_suffix
    else:
        return _result_prefix + _txt + _result_suffix

results = list(map(_filter, outputs))

print(input_ids)
print(outputs)
results

In [None]:
text = "Eu [MASK] de brigadeiro."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_pos = input_ids.tolist()[0].index(tokenizer.mask_token_id)
k=10
outputs = model(input_ids=input_ids)
_, outputs = torch.topk(outputs[0], k, sorted=True)
results = list(map(tokenizer.decode, outputs.squeeze()))
print(input_ids)
print(outputs.squeeze())
results

In [None]:
text = "Eu amo João [MASK]."
input_ids = tokenizer.encode(text, return_tensors='pt')
mask_pos = input_ids.tolist()[0].index(tokenizer.mask_token_id)
k=10
outputs = model(input_ids=input_ids)
_, outputs = torch.topk(outputs[0], k, sorted=True)
results = list(map(tokenizer.decode, outputs.squeeze()))
print(input_ids)
print(outputs.squeeze())
results[mask_pos]

In [None]:
seq = "Eu gosto de brigadeiro."
k = 50
seq_tokens_id = tokenizer.encode(seq)
for idx in range(len(seq_tokens_id)):
  inputs_id = seq_tokens_id.copy()
  if inputs_id[idx] == tokenizer.cls_token_id or inputs_id[idx] == tokenizer.sep_token_id:
    continue
  inputs_id[idx] = tokenizer.mask_token_id 
  print(tokenizer.decode(inputs_id))
  # Feed the sequence masked in this iteration, with a batch dimension
  input_tensor = torch.LongTensor([inputs_id])
  outputs = model(input_ids=input_tensor)
  _, outputs = torch.topk(outputs[0], k, sorted=True)
  results = list(map(tokenizer.decode, outputs.squeeze()))[idx]
  print(results)
  print()



In [None]:
import re
seq = "Eu estou gostando de brigadeiro."
seq_splitted = re.findall(r"[\w']+|[.,!?;]", seq)
k = 5
inputs_id = tokenizer.encode(seq_splitted, return_tensors='pt')

for idx in range(inputs_id.shape[1]):
  input_seq = inputs_id.clone()
  input_seq[0][idx] = tokenizer.mask_token_id
  #input_seq = ' '.join(input_seq)
  print(input_seq)
  
  print(tokenizer.decode(inputs_id[0]))
  #for token in inputs_id[0].tolist():
  #  print(token, tokenizer.decode(token))
  # Use the sequence masked in this iteration, not a tensor from an earlier cell
  outputs = model(input_ids=input_seq)
  _, outputs = torch.topk(outputs[0], k, sorted=True)
  results = list(map(tokenizer.decode, outputs.squeeze()))
  print(len(results), len(inputs_id[0]))
  print(idx, results[idx])
  print()


In [None]:
import re
seq = "Eu estou gostando de brigadeiro."
seq_splitted = re.findall(r"[\w']+|[.,!?;]", seq)
k = 10
inputs_id = tokenizer.encode(seq_splitted, return_tensors='pt')
inputs_id = inputs_id.repeat(len(seq_splitted), 1)
mask_pos = []
for idx, tensor in enumerate(inputs_id):
  tensor[idx+1] = tokenizer.mask_token_id
  mask_pos.append(idx+1)
print(inputs_id)
model.eval()
outputs = model(input_ids=inputs_id)
_, outputs = torch.topk(outputs[0], k, sorted=True)

for idx in mask_pos:
  print(inputs_id[idx-1])
  results = list(map(tokenizer.decode, outputs[idx-1].squeeze()))[idx]
  print(results)


In [None]:
text = "Eu gosto de brigadeiro."
splitted = text.split()
for idx, word in enumerate(splitted):
  splitted[idx] = '[MASK]'
  new_text = ' '.join(splitted)
  print(new_text)
  input_ids = tokenizer.encode(new_text, return_tensors='pt')
  mask_pos = input_ids.tolist()[0].index(tokenizer.mask_token_id)
  k=10
  outputs = model(input_ids=input_ids)
  _, outputs = torch.topk(outputs[0], k, sorted=True)
  results = list(map(tokenizer.decode, outputs.squeeze()))
  print(results[mask_pos])
  splitted[idx] = word

In [None]:
text = "Eu gsto de brigadeiro."
splitted = text.split()
for idx, word in enumerate(splitted):
  splitted[idx] = '[MASK]'
  new_text = ' '.join(splitted)
  print(new_text)
  input_ids = tokenizer.encode(new_text, return_tensors='pt')
  mask_pos = input_ids.tolist()[0].index(tokenizer.mask_token_id)
  k=10
  outputs = model(input_ids=input_ids)
  _, outputs = torch.topk(outputs[0], k, sorted=True)
  results = list(map(tokenizer.decode, outputs.squeeze()))
  print(results[mask_pos])
  splitted[idx] = word

# Comparison

In [None]:
!pip3 install fuzzywuzzy[speedup] 

In [None]:
from fuzzywuzzy import fuzz, process, StringMatcher
fuzz.SequenceMatcher = StringMatcher.StringMatcher
#import difflib
#fuzz.SequenceMatcher = difflib.SequenceMatcher

# Correction needed
word = 'gsto'
pred = ['gosto', 'gostava', 'amo', 'sou', 'preciso', 'precisava', 'vou', 'iria', 'morro', 'ia']

metric_value_default = [fuzz.ratio(word, pred_word) for pred_word in pred]

print(metric_value_default)
print(process.extract(word, pred, limit=4))
print(process.extractOne(word, pred))

# No correction needed
word = 'Eu'
pred = ['Mas', 'E', 'Já', 'Não', 'Agora', 'Hoje', 'Ainda', 'Então', '[UNK]']

metric_value_default = [fuzz.ratio(word, pred_word) for pred_word in pred]

print(metric_value_default)
print(process.extract(word, pred, limit=4))
print(process.extractOne(word, pred))

# Correction needed
word = 'a'
pred = ['à', 'as', 'ai', 'á', 'ã', 'ia']

metric_value_default = [fuzz.ratio(word, pred_word) for pred_word in pred]

print(metric_value_default)

print(process.extract(word, pred, limit=4))
print(process.extractOne(word, pred))
print('scorer:', process.extractOne(word, pred, scorer=fuzz.ratio))


print(pred[metric_value_default.index(max(metric_value_default))])
matches_id = [i for i,x in enumerate(metric_value_default) if x==max(metric_value_default)]
print(matches_id)
matches = [pred[x] for x in matches_id]
print(matches)

# Model

In [None]:
class bert_pt_corrector(nn.Module):
  '''
  Bert inference model
  '''
  def __init__(self, k_top_predictions=10, levenshtein_threshold=85,
               vocab=ime_vocab):
    super().__init__()

    # Initialize the model and tokenizer
    self.tokenizer = BertTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
    self.bert = BertForMaskedLM.from_pretrained("neuralmind/bert-base-portuguese-cased")
    for param in self.bert.bert.parameters():
      param.requires_grad = False

    # Set number of predictions to look
    self.k = k_top_predictions

    # Set threshold for Levenshtein comparison
    self.threshold = levenshtein_threshold

    self.vocab = vocab
    
    #self.punctuation = ['.', ',', ':', ';', '?', '!']
  
  def forward(self, sequence):
    suggestion = []
    splitted = sequence.split()
    mask_id = self.tokenizer.mask_token_id
    for idx, word in enumerate(splitted):
      # Replace the word with one [MASK] per WordPiece token
      # (tokenize() adds no [CLS]/[SEP], unlike encode()).
      n_tokens = max(1, len(self.tokenizer.tokenize(word)))
      splitted[idx] = ' '.join(n_tokens * ['[MASK]'])
      new_text = ' '.join(splitted)
      splitted[idx] = word  # restore the original word
      print(new_text)

      input_ids = self.tokenizer.encode(new_text, return_tensors='pt')
      # Position of the first [MASK] in the encoded sequence
      mask_pos_init = input_ids.tolist()[0].index(mask_id)

      with torch.no_grad():
        outputs = self.bert(input_ids=input_ids)

      # Top-k token ids per position, shape (1, seq_len, k)
      _, outputs = torch.topk(outputs[0], self.k, sorted=True)
      results = list(map(self.tokenizer.decode, outputs.squeeze(0)))
      # Candidates for the first masked position only
      prediction = results[mask_pos_init].split()

      only_word = re.sub(r'[.,!?;]', '', word)
      if only_word in prediction:
        # The model predicts the original word: keep it
        suggestion.append(word)
      else:
        near_word = self.get_near_word(only_word, prediction)
        if near_word is None:
          # No candidate is close enough: keep the original word
          suggestion.append(word)
        else:
          # Swap in the suggestion, preserving attached punctuation
          word_split = re.findall(r"[\w']+|[.,!?;]", word)
          word_split = [near_word if part == only_word else part
                        for part in word_split]
          suggestion.append(''.join(word_split))

    suggestion = ' '.join(suggestion)
    return suggestion

  def get_near_word(self, word, prediction):
    #prediction = list(set(prediction).intersection(self.vocab))  # too slow
    # Nothing to compare against
    if not word or not prediction:
      return None
    # extractOne returns the best (candidate, score) pair; score ranges 0-100
    near_word, max_value = process.extractOne(word, prediction)
    if max_value < self.threshold:
      return None
    return near_word

In [None]:
# Testing model

sequence = "Capitu, apesar daqueles olhos que o diabo lhe deu... \
Você já reparou nos olhos dela? São assim de cigana oblíqua e dissimulada."

sequence = "Capitu, brigadeiro Pota."

spell_checker = bert_pt_corrector()
spell_checker.eval()
spell_checker(sequence)

In [None]:
# Testing model

sequence = "Capitu, apesar daqueles olhos q o diabo lhe deu... \
Vc ja reparou nos olhos dela? Sao assim de cigana obliqua e dissimulada."

spell_checker = bert_pt_corrector(k_top_predictions=50)
spell_checker(sequence)

In [None]:
# Testing model

sequence = "Capitu, apesar daqueles olhos q o diabo lhe deu... \
Vc ja reparou nos olhos dela? Sao assim de cigana obliqua e dissimulada."

spell_checker = bert_pt_corrector(k_top_predictions=500, levenshtein_threshold=50)
spell_checker(sequence)

In [None]:
# Testing model

sequence = "Capitu, apesar daqueles olhos q o diabo lhe deu... \
Vc ja reparou nos olhos dela? Sao assim de cigana obliqua e dissimulada."

spell_checker = bert_pt_corrector(k_top_predictions=1000, levenshtein_threshold=50)
spell_checker(sequence)

In [None]:
'Xq' in ime_vocab

In [None]:
s = '?en!tão..'
print(s.encode().decode('ascii', 'ignore'))
print(re.sub(r'[^A-Za-z0-9]+', '', s))

import unicodedata
unicodedata.normalize('NFD', s).encode('ascii', 'ignore').decode()

punctuation = ['.', ',', ':', ';', '?', '!']
print(re.sub(r"[.,!?;]", '', s))

?en!to..
ento
então


# Datasets

- REGRA in email
- ParaCrawl
- [Others](https://lionbridge.ai/datasets/best-portuguese-language-datasets-for-machine-learning/)

In [None]:
from google.colab import drive
drive.mount('/content/drive')



### BlogsetBR

In [None]:
# Download the BlogsetBR corpus (pass the URL directly; -i expects a file listing URLs)
!mkdir -p /content/blogsetBR
!time wget -nc --progress=dot:giga -P /content/blogsetBR \
    "http://www.inf.pucrs.br/linatural/blogs/blogset-br.csv.gz"

--2020-06-22 01:54:18--  http://www.inf.pucrs.br/linatural/blogs/blogset-br.csv.gz
Resolving www.inf.pucrs.br (www.inf.pucrs.br)... 104.18.21.134, 104.18.20.134, 2606:4700::6812:1586, ...
Connecting to www.inf.pucrs.br (www.inf.pucrs.br)|104.18.21.134|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://www.inf.pucrs.br/linatural/blogs/blogset-br.csv.gz [following]
--2020-06-22 01:54:18--  https://www.inf.pucrs.br/linatural/blogs/blogset-br.csv.gz
Connecting to www.inf.pucrs.br (www.inf.pucrs.br)|104.18.21.134|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5015788144 (4.7G) [application/x-gzip]
Saving to: ‘/content/blogsetBR/blogset-br.csv.gz’

     0K ........ ........ ........ ........  0% 7.44M 10m38s
 32768K ........ ........ ........ ........  1% 9.75M 9m19s
 65536K ........ ........ ........ ........  2% 9.75M 8m50s
 98304K ........ ........ ........ ........  2% 8.69M 8m49s
131072K ........ ........ ........

UnicodeDecodeError: ignored

In [None]:
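import pandas as pd

# `header` must be an int, a list of ints, or None; passing `header=True` raises a TypeError.
# header=None assumes the CSV has no header row (adjust if it does).
posts = pd.read_csv('/content/blogsetBR/blogset-br.csv.gz', compression='gzip',
                    nrows=20, header=None)
posts.head()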

### Brazilian Portuguese Literature Corpus

In [None]:
!pip3 install --quiet kaggle
!ls -a ~
!cp "/content/drive/My Drive/kaggle.json" ~/.kaggle
!ls -a ~/.kaggle
!kaggle datasets download --unzip \
  -d rtatman/brazilian-portuguese-literature-corpus

.   .bash_history  .cache   .gsutil   .jupyter	.keras	.node-gyp  .profile
..  .bashrc	   .config  .ipython  .kaggle	.local	.npm	   .wget-hsts
.  ..  kaggle.json
Downloading brazilian-portuguese-literature-corpus.zip to /content
 29% 5.00M/17.5M [00:00<00:00, 22.6MB/s]
100% 17.5M/17.5M [00:00<00:00, 58.2MB/s]


In [None]:
import os
path = "/content/Brazilian_Portugese_Corpus"
br_litteture_corpus = []
for root, dirs, files in os.walk(path):
  for file in files:
    if file.endswith(".txt"):
      path_file = os.path.join(root, file)
      with open(path_file, encoding="latin1") as txt_file:  # files are decoded as Latin-1
        br_litteture_corpus.append(txt_file.read().splitlines())

In [None]:
# Each line of the corpus files is treated as one sentence.
br_litteture = [seq.split() for filecorpus in br_litteture_corpus for seq in filecorpus]
print('number of sentences:', len(br_litteture))
print(br_litteture[0])
words = [word for seq in br_litteture for word in seq]
print(words[0])
print('number of words:', len(words))
non_norm_corpus_vocab = list(set(words))
print('number of unique words', len(non_norm_corpus_vocab))

print('\n...normalizing...\n')
# Lowercasing merges case variants and shrinks the vocabulary.
norm_br_litteture = [seq.lower().split() for filecorpus in br_litteture_corpus for seq in filecorpus]
print(norm_br_litteture[0])
print('number of sentences:', len(norm_br_litteture))
norm_words = [word for seq in norm_br_litteture for word in seq]
print(norm_words[0])
print('number of words:', len(norm_words))
norm_corpus_vocab = list(set(norm_words))
print('number of unique words', len(norm_corpus_vocab))

number of sentences: 886562
['Uma', 'Lágrima', 'de', 'Mulher']
Uma
number of words: 7594582
number of unique words 233317

...normalizing...

['uma', 'lágrima', 'de', 'mulher']
number of sentences: 886562
uma
number of words: 7594582
number of unique words 214783
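
Beyond counting unique tokens, a frequency table helps decide which words are reliable enough to keep in a vocabulary. The sketch below uses `collections.Counter` on a toy corpus standing in for `br_litteture_corpus`; the helper name `build_vocab` and the `min_count` threshold are assumptions made for illustration.

```python
from collections import Counter


def build_vocab(lines, lowercase=True, min_count=1):
    """Count whitespace-separated tokens and keep those seen at least min_count times."""
    counts = Counter()
    for line in lines:
        if lowercase:
            line = line.lower()
        counts.update(line.split())
    return {word: count for word, count in counts.items() if count >= min_count}


# Toy stand-in for the corpus lines read above.
toy_lines = ["Uma Lágrima de Mulher", "uma mulher de coragem"]
vocab = build_vocab(toy_lines, min_count=2)
print(sorted(vocab))
```

Raising `min_count` drops rare tokens, which in a literary corpus are often archaic spellings or OCR noise rather than useful vocabulary.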


## Dataset Class
Data management, plus a small test.
`__getitem__` is where the encoding correction is applied (currently commented out in the code).

In [None]:
class BrazilianLiterature(Dataset):
    """
    Loads data from preprocessed file and manages them.
    """
    VALID_MODES = ["train", "validation", "test"]
    TOKENIZER = T5Tokenizer.from_pretrained(hparams["model_name"],
                                            cache_dir=base_path)
    def __init__(self, mode: str, seq_len: int):
        """
        mode: One of train, validation or test 
        seq_len: limit to returned encoded tokens
        """
        super().__init__()
        assert mode in BEA2019.VALID_MODES

        self.mode = mode
        self.seq_len = seq_len

        file_name = os.path.join(base_path, f"{mode}.pkl")
        if not os.path.isfile(file_name):
            print("Pre-processed files not found, preparing data.")
            self.prepare_data()
        
        with open(file_name, "rb") as preprocessed_file:
            self.data = pickle.load(preprocessed_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i: int):
        """
        Unpacks line from data and applies T5 encoding if necessary.

        returns: input_ids, lm_labels, original, corrected
        """
        input, target, original, corrected = self.data[i]

        # From Leard's solution
        # target = unicodedata.normalize("NFD", target).encode("latin-1", "xmlcharrefreplace").decode("latin-1")

        eos_token = BEA2019.TOKENIZER.eos_token

        inputs = BEA2019.TOKENIZER.encode_plus(text=f"{input}",
                                               max_length=self.seq_len,
                                               pad_to_max_length=True,
                                               return_tensors="pt")

        lm_labels = BEA2019.TOKENIZER.encode_plus(text=f"{target}{eos_token}",
                                                  max_length=self.seq_len,
                                                  pad_to_max_length=True,
                                                  return_tensors="pt")
        return inputs["input_ids"].squeeze(), lm_labels["input_ids"].squeeze(), original, corrected

    def get_dataloader(self, batch_size: int, shuffle: bool):
        return DataLoader(self, batch_size=batch_size, shuffle=shuffle, 
                          num_workers=4)

    @staticmethod
    def load_text_tuples(path, member):
        """
        Load tuples from original files: text_original, text_input, text_target, text_corrected.
        """
        text_tuples = []
        with tarfile.open(path) as tar:
            f = tar.extractfile(member)
            for line in f:
                data = json.loads(line)
                if len(data["text"]) > 200:
                    continue
                # one list of edits per annotator
                for annotator_id, edits in data["edits"]:
                    # edit: [annotator_id, [[char_start_offset, char_end_offset, correction], ...]]
                    text_input = ""
                    text_target = ""
                    text_original = data["text"]
                    text_corrected = ""
                    prev_char_end_offset = 0

                    for idx, (char_start_offset, char_end_offset, correction) in enumerate(edits):
                        # a slice of unchanged original text 
                        text_unchanged = text_original[prev_char_end_offset:char_start_offset]

                        # a mask token
                        mask = f"<extra_id_{idx+1}>"

                        # the input for T5 model
                        text_input = f"{text_input} {text_unchanged} {mask}"

                        # the output for T5 model
                        if correction is None:
                            correction = ""
                        text_target = f"{text_target} {mask} {correction}"

                        text_corrected = f"{text_corrected} {text_unchanged}{correction}"

                        prev_char_end_offset = char_end_offset

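                    text_unchanged = text_original[prev_char_end_offset:]
                    # Closing sentinel: one past the last mask used in the loop.
                    # Using len(edits) also keeps this well-defined when there are no edits.
                    mask = f"<extra_id_{len(edits)+1}>"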

                    text_input = f"{text_input} {text_unchanged}".lstrip()
                    text_target = f"{text_target} {mask}".lstrip()
                    text_corrected = f"{text_corrected} {text_unchanged}".lstrip()
                    text_tuples.append((text_input, text_target, text_original, text_corrected))

        return text_tuples

    @staticmethod
    def prepare_data(train_size=10, val_size=3):
        """
        Performs everything needed to get the data ready.
        Addition of Eos token and encoding is performed in runtime.
        """
        # Look for the archive where wget saves it (base_path), not in the working directory.
        archive_path = os.path.join(base_path, "wi+locness_v2.1.bea19.tar.gz")
        if not os.path.isfile(archive_path):
            !wget -nc https://www.cl.cam.ac.uk/research/nl/bea2019st/data/wi+locness_v2.1.bea19.tar.gz -P "$base_path"
            # !wget -nc https://www.cl.cam.ac.uk/research/nl/bea2019st/data/ABCN.bea19.dev.orig -P "$base_path"
            # !wget -nc https://www.cl.cam.ac.uk/research/nl/bea2019st/data/ABCN.bea19.test.orig -P "$base_path"

        data = {}
        train_val_data = BEA2019.load_text_tuples(os.path.join(base_path, "wi+locness_v2.1.bea19.tar.gz"), "wi+locness/json/A.train.json")
        test_data = BEA2019.load_text_tuples(os.path.join(base_path, "wi+locness_v2.1.bea19.tar.gz"), "wi+locness/json/A.dev.json")

        random.shuffle(train_val_data)

        train_data = train_val_data[:train_size]
        val_data = train_val_data[train_size:train_size + val_size]

        for mode, data in zip(BEA2019.VALID_MODES, [train_data, val_data, test_data]):
            file_name = os.path.join(base_path, f"{mode}.pkl")
            with open(file_name, "wb") as pkl_file:
                pickle.dump(data, pkl_file)
            print(f"Pre-processed data saved as {file_name}.")


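# The class is declared as BrazilianLiterature, but the rest of the notebook
# (including its own methods) still refers to it as BEA2019, so an alias is needed.
BEA2019 = BrazilianLiterature

datasets = {m: BEA2019(mode=m, seq_len=hparams["seq_len"]) for m in BEA2019.VALID_MODES}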

# Testing datasets
for mode, dataset in datasets.items():
    print(f"\n{mode} dataset length: {len(dataset)}\n")
    print("Random sample")
    input_ids, lm_labels, original, corrected = random.choice(dataset)
    print("input_ids\n", input_ids, end="\n\n")
    print("lm_labels\n", lm_labels, end="\n\n")
    print("original\n", original, end="\n\n")
    print("corrected\n", corrected, end="\n\n")
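
To make the span-masking step in `load_text_tuples` easier to follow, here is a simplified, standalone sketch of the same idea on a toy sentence. It is not the notebook's exact routine: the function name `build_masked_example` is hypothetical, sentinels start at `<extra_id_0>` (the first sentinel in the T5 vocabulary) rather than `<extra_id_1>`, and no extra spaces are inserted around unchanged text.

```python
def build_masked_example(text, edits):
    """Turn character-level edits into a T5-style (input, target, corrected) triple.

    edits: list of (start, end, correction) tuples sorted by start offset;
    correction may be None, meaning the span is deleted.
    """
    input_parts, target_parts, corrected_parts = [], [], []
    prev_end = 0
    for idx, (start, end, correction) in enumerate(edits):
        correction = correction or ""
        unchanged = text[prev_end:start]
        mask = f"<extra_id_{idx}>"
        # The input keeps the unchanged text and hides the erroneous span.
        input_parts.append(unchanged + mask)
        # The target lists each sentinel followed by the text that replaces it.
        target_parts.append(f"{mask} {correction}".rstrip())
        corrected_parts.append(unchanged + correction)
        prev_end = end
    tail = text[prev_end:]
    input_parts.append(tail)
    # A final, unused sentinel closes the target sequence.
    target_parts.append(f"<extra_id_{len(edits)}>")
    corrected_parts.append(tail)
    return "".join(input_parts), " ".join(target_parts), "".join(corrected_parts)


text = "I has a apple"
edits = [(2, 5, "have"), (6, 7, "an")]
masked_input, target, corrected = build_masked_example(text, edits)
print(masked_input)
print(target)
print(corrected)
```

Each sentinel in the input marks one erroneous span, and the target pairs that sentinel with its replacement, so the model only has to generate the corrections rather than the whole sentence.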

## Dataloaders

Check that the dataloaders work correctly.

In [None]:
shuffle = {"train": True, "validation": False, "test": False}
debug_dataloaders = {mode: datasets[mode].get_dataloader(batch_size=hparams["batch_size"], 
                                                         shuffle=shuffle[mode])
                     for mode in BEA2019.VALID_MODES}

# Testing dataloaders
for mode, dataloader in debug_dataloaders.items():
    print("{} number of batches: {}".format(mode, len(dataloader)))
    batch = next(iter(dataloader))

## Lightning Module

The main PyTorch Lightning class is defined here.


In [None]:
class T5Corrector(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()

        self.hparams = hparams
        self.t5 = T5ForConditionalGeneration.from_pretrained(hparams.model_name,
                                                             cache_dir=hparams.base_path)
        self.tokenizer = BEA2019.TOKENIZER

    def forward(self, x):
        """
        Inspired by Lucas's solution.
        """
        input_ids, lm_labels, original, corrected = x

        if self.training:
            return self.t5(input_ids=input_ids,
                           lm_labels=lm_labels)[0]
        else:
            # Decode with the transformers library. Labels are only used for
            # the training loss, so generate() receives input_ids alone.
            if self.hparams.decode_mode == "greedy":
                return self.t5.generate(input_ids=input_ids), original, corrected
            elif self.hparams.decode_mode == "beam":
                # num_beams must be at least num_return_sequences
                return self.t5.generate(input_ids=input_ids,
                                        num_beams=self.hparams.nbeams,
                                        num_return_sequences=10,
                                        early_stopping=True), original, corrected
            elif self.hparams.decode_mode == "topk":
                return self.t5.generate(input_ids=input_ids,
                                        num_return_sequences=10,
                                        top_k=self.hparams.k,
                                        do_sample=True), original, corrected
            elif self.hparams.decode_mode == "nucleus":
                return self.t5.generate(input_ids=input_ids,
                                        top_k=0,
                                        top_p=self.hparams.p,
                                        do_sample=True), original, corrected
            else:
                raise ValueError(f"Unknown decode_mode: {self.hparams.decode_mode}")

    def training_step(self, batch, batch_idx):
        loss = self(batch)

        return {"loss": loss, "log": {"loss": loss}, "progress_bar": hardware_stats()}

    def validation_step(self, batch, batch_idx):
        pred_tokens, original, corrected = self(batch)

        # List of string predictions, one per item of mini_batch
        pred = [html.unescape(self.tokenizer.decode(pred)) for pred in pred_tokens]
    
        return {"pred": pred, "corrected": corrected, "progress_bar":  hardware_stats()}

    def test_step(self, batch, batch_idx):
        pred_tokens, original, corrected = self(batch)

        # List of string predictions, one per item of mini_batch
        pred = [html.unescape(self.tokenizer.decode(pred)) for pred in pred_tokens]

        return {"pred": pred, "corrected": corrected, "progress_bar":  hardware_stats()}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x["loss"] for x in outputs]).mean()

        return {"log": {"train_loss": avg_loss}} 

    def validation_epoch_end(self, outputs):
        """
        Inspired by Israel's solution.
        """
        trues = sum([list(x["corrected"]) for x in outputs], [])
        preds = sum([list(x["pred"]) for x in outputs], [])

        n = random.choice(range(len(trues)))
        print(f"\nTarget: {trues[n]}\nPrediction: {preds[n]}\n")

        # annotator = errant.load("en")
        # orig = annotator.parse(original[0])
        # cor = annotator.parse(pred)
        # print("orig", orig)
        # print("cor", cor)
        # annotator.annotate(orig, cor, lev=False, merging="rules")

        f_score = 0.5 #TODO
        val_dict = {"f_score": f_score}
    
        return {"f_score": f_score, "log": val_dict, "progress_bar": val_dict}

    def test_epoch_end(self, outputs):
        trues = sum([list(x["corrected"]) for x in outputs], [])
        preds = sum([list(x["pred"]) for x in outputs], [])

        n = random.choice(range(len(trues)))
        print(f"\nTarget: {trues[n]}\nPrediction: {preds[n]}\n")
        f_score = 0.5 # TODO
        test_dict = {"test_f_score": f_score}
    
        return {"test_f_score": f_score, "log": test_dict, "progress_bar": test_dict}

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.hparams.lr)    

    def train_dataloader(self):
        if self.hparams.overfit_pct > 0:
            logging.info("Disabling train shuffle due to overfit_pct.")
            shuffle = False
        else:
            shuffle = True
        dataset = BEA2019("train", seq_len=self.hparams.seq_len)
        return dataset.get_dataloader(batch_size=self.hparams.batch_size, shuffle=shuffle)

    def val_dataloader(self):
        dataset = BEA2019("validation", seq_len=self.hparams.test_seq_len)
        return dataset.get_dataloader(batch_size=self.hparams.batch_size, shuffle=False)

    def test_dataloader(self):
        dataset = BEA2019("test", seq_len=self.hparams.test_seq_len)
        return dataset.get_dataloader(batch_size=self.hparams.batch_size, shuffle=False)
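
The `f_score = 0.5` lines in `validation_epoch_end` and `test_epoch_end` are placeholders. As a rough sketch of what could replace them, the F0.5 score can be computed from true-positive, false-positive and false-negative edit counts. The official BEA 2019 scorer (ERRANT) extracts and aligns the edits itself; this sketch assumes the edits are already available as hashable tuples, and the names `f_beta` and `edit_counts` are hypothetical.

```python
def f_beta(tp, fp, fn, beta=0.5):
    """F-beta from edit counts; beta=0.5 weights precision twice as much as recall."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision == 0.0 and recall == 0.0:
        return 0.0
    b2 = beta ** 2
    return (1 + b2) * precision * recall / (b2 * precision + recall)


def edit_counts(pred_edits, gold_edits):
    """Compare two collections of edits, e.g. (start, end, correction) tuples."""
    pred, gold = set(pred_edits), set(gold_edits)
    tp = len(pred & gold)   # proposed edits that match the reference
    fp = len(pred - gold)   # proposed edits not in the reference
    fn = len(gold - pred)   # reference edits that were missed
    return tp, fp, fn


gold = [(2, 5, "have"), (6, 7, "an")]
pred = [(2, 5, "have"), (8, 13, "apples")]
tp, fp, fn = edit_counts(pred, gold)
print(f_beta(tp, fp, fn))
```

With β = 0.5, a system that proposes few but correct edits scores higher than one that proposes many edits of mixed quality, which suits a corrector that should avoid changing correct text.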

## Setup and Tests

In [None]:
hparams = {"name": experiment_name, "base_path": base_path,
           "model_name": model_name, "seq_len": 200, "test_seq_len": 300,
           "lr": learning_rate, "batch_size": batch_size, "batch_accum": accumulate_grad_batches,
           "max_epochs": 3,
           "overfit_pct": 0, "debug": 0,
           "decode_mode": decode_mode}


for key, parameter in hparams.items():
    print("{}: {}".format(key, parameter))

In [None]:
# Instantiate model
model = T5Corrector(Namespace(**hparams))

# Folder/path management, for logs and checkpoints
tensorboard_path = "logs"
experiment_name = hparams["name"]
model_folder = os.path.join(tensorboard_path, experiment_name)
os.makedirs(model_folder, exist_ok=True)
ckpt_path = os.path.join(model_folder, "-{epoch}")

# Callback initialization
checkpoint_callback = ModelCheckpoint(prefix=experiment_name, 
                                      filepath=ckpt_path, 
                                      mode="max")
logger = TensorBoardLogger(tensorboard_path, experiment_name)

# PL Trainer initialization
trainer = Trainer(gpus=1,
                  checkpoint_callback=checkpoint_callback, 
                  early_stop_callback=False,
                  logger=logger,
                  accumulate_grad_batches=hparams["batch_accum"],
                  max_epochs=hparams["max_epochs"], 
                  fast_dev_run=bool(hparams["debug"]), 
                  overfit_pct=hparams["overfit_pct"],
                  progress_bar_refresh_rate=10,
                  profiler=True)

In [None]:
trainer.test(model)