In [1]:
from os import listdir
from random import shuffle
import time
import tqdm
import json
import hashlib
import pickle
import re
import pandas as pd
from collections import defaultdict
from os.path import exists
from IPython.display import clear_output

import gc
import nltk
from nltk import word_tokenize, sent_tokenize, pos_tag
from nltk.tokenize.treebank import TreebankWordDetokenizer

In [2]:
# All directory constants are plain string prefixes joined by concatenation,
# so each one ends with '/' (BASE_DIR too, when non-empty).
BASE_DIR = ''  # Working dir
DATA_DIR = f'{BASE_DIR}data/'
MODELS_DIR = f'{BASE_DIR}models/'
RESULTS_DIR = f'{BASE_DIR}results/'
# Previously f'{DATA_DIR}/', which yielded a doubled slash ("data//").
MODEL_DATA_DIR = DATA_DIR

In [3]:
import torch

# Use the first CUDA GPU when available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [4]:
from transformers import T5ForConditionalGeneration
from transformers import T5Tokenizer
import torch

# Task prefixes prepended to every input sentence: the plain prompt, and the
# stem used to build language-tagged prompts ("improve_english <lang>: ").
IMPROVE_TOKEN_MULTI = "improve_english"
IMPROVE_TOKEN = f"{IMPROVE_TOKEN_MULTI}: "

# torch.device handle for the configured DEVICE string.
device = torch.device(DEVICE)

# Load models

In [5]:
class TokenizerWrapper:
    """Tokenizer adapter that shields out-of-vocabulary text from the model.

    Before encoding, every whitespace-separated token containing a character
    that is not a key of ``tokenizer.get_vocab()`` is replaced by a placeholder
    ``[KEEP<n>]``. The placeholder -> original text mapping is returned with
    the encoding and used by :meth:`replace_special_tokens` to restore the text
    after decoding. When ``config['keep_parentheses']`` is truthy, parenthesised
    spans are protected the same way.
    """

    def __init__(self, tokenizer, config=None):
        """Wrap ``tokenizer``.

        Args:
            tokenizer: a Hugging Face tokenizer (used via ``get_vocab``,
                ``encode``, ``decode`` and batch ``__call__``).
            config: optional options dict; only ``'keep_parentheses'`` is read.
                Defaults to a fresh empty dict per instance (no shared mutable
                default).
        """
        self.device = DEVICE
        self.tokenizer = tokenizer
        self.config = config if config is not None else {}
        # Keys view of the vocab dict: O(1) membership tests in in_vocab().
        self.vocab = tokenizer.get_vocab().keys()

    def encode_sentence(self, sent):
        """Protect OOV tokens in ``sent`` and encode it as a PyTorch tensor on ``self.device``.

        Returns:
            (input_ids, map_tokens): the encoded ids and the placeholder -> text mapping.
        """
        sent_special, map_tokens = self.add_special_tokens(sent)
        inputs_ids = self.tokenizer.encode(sent_special, return_tensors='pt').to(device=self.device)

        return inputs_ids, map_tokens

    def get_special_sentences(self, sentences):
        """Apply :meth:`add_special_tokens` to each sentence.

        Returns:
            (special_sentences, list_map_tokens): two lists aligned with ``sentences``.
        """
        list_map_tokens = []
        special_sentences = []
        for sent in sentences:
            sent_special, map_tokens = self.add_special_tokens(sent)
            special_sentences.append(sent_special)
            list_map_tokens.append(map_tokens)

        return special_sentences, list_map_tokens

    def encode_batch_inference(self, sentences):
        """Protect and batch-encode ``sentences`` as padded PyTorch tensors on ``self.device``.

        Returns:
            (encodings, map_tokens): tokenizer output moved to the device and one
            placeholder mapping per sentence.
        """
        special_sentences, map_tokens = self.get_special_sentences(sentences)
        encodings_input = self.tokenizer(special_sentences, truncation=True, padding='longest', return_tensors='pt').to(self.device)

        return encodings_input, map_tokens

    def decode_batch(self, outputs, map_tokens, original_sentences):
        """Decode each output and restore its placeholders.

        ``outputs``, ``map_tokens`` and ``original_sentences`` are zipped, so they
        must be aligned element-for-element (extra items are dropped by zip).
        """
        return [self.decode_sentence(output, map_token, original) for output, map_token, original in zip(outputs, map_tokens, original_sentences)]

    def encode_batch_training(self, sentences):
        """Protect and batch-encode ``sentences`` without tensor conversion (for dataset building).

        Returns:
            (encodings, map_tokens): tokenizer output and one placeholder mapping per sentence.
        """
        special_sentences, map_tokens = self.get_special_sentences(sentences)
        encodings_input = self.tokenizer(special_sentences, truncation=True, padding='longest')

        return encodings_input, map_tokens

    def decode_sentence(self, output, map_tokens, original=None):
        """Decode one output sequence and substitute placeholders back (see :meth:`replace_special_tokens`)."""
        sentence = self.tokenizer.decode(output, skip_special_tokens=True)
        return self.replace_special_tokens(sentence, map_tokens, original)

    def in_vocab(self, char):
        """Return True if ``char`` is a key of the tokenizer vocabulary."""
        return char in self.vocab

    def has_oov(self, sent):
        """Return True if any non-space character of ``sent`` is missing from the vocabulary."""
        # Previously called an undefined free function `in_vocab(char, vocab)`.
        return any(not self.in_vocab(char) for char in set(sent) if char != ' ')

    def add_special_tokens(self, sent):
        """Replace OOV tokens (and optionally parenthesised spans) with ``[KEEP<n>]`` placeholders.

        Placeholder numbering starts after the largest ``[N]`` marker already
        present in ``sent``, so new placeholders never reuse an existing number.

        Args:
            sent: input sentence.
        Returns:
            (updated_sent, keep_to_tokens): the rewritten sentence (tokens re-joined
            with single spaces) and a dict mapping each placeholder to the text
            it replaced.
        """
        # Numeric max: the previous lexicographic sort ranked "[9]" above "[10]".
        existing = [int(number) for number in re.findall(r'\[([0-9]+)\]', sent)]
        start_index = max(existing, default=0)
        keep_to_tokens = {}

        def convert_token(token):
            # A token is replaced as a whole as soon as one of its characters is
            # OOV; a trailing ',' or '.' stays outside the placeholder so the
            # punctuation remains visible to the model.
            nonlocal start_index
            if all(self.in_vocab(char) for char in token):
                return token
            start_index += 1
            mapped_token = f'[KEEP{start_index}]'
            last_char = token[-1]
            if last_char in (',', '.'):
                keep_to_tokens[mapped_token] = token[:-1]
                return mapped_token + last_char
            keep_to_tokens[mapped_token] = token
            return mapped_token

        updated_sent = ' '.join(convert_token(token) for token in sent.split())

        def keep_parenthesis(text):
            # Replace parenthesised spans left to right, one placeholder per
            # leftmost match (consecutive groups like "(a)(b)" form one match).
            nonlocal start_index
            pattern = re.compile(r'(\(.*?\))+')
            match = pattern.search(text)
            while match:
                start_index += 1
                mapped_token = f'[KEEP{start_index}]'
                keep_to_tokens[mapped_token] = match.group(0)
                text = text[:match.start()] + mapped_token + text[match.end():]
                match = pattern.search(text)
            return text

        if self.config.get('keep_parentheses'):
            updated_sent = keep_parenthesis(updated_sent)

        return updated_sent, keep_to_tokens

    def replace_special_tokens(self, sentence, keep_to_tokens, original):
        """Substitute placeholders in ``sentence`` back with the text they replaced.

        If a placeholder is missing from ``sentence`` (the model dropped or
        altered it) and ``original`` is truthy, ``original`` is returned
        unchanged instead of a partially restored sentence.
        """
        for key in keep_to_tokens:
            if key not in sentence and original:
                return original
            sentence = sentence.replace(key, keep_to_tokens[key])
        return sentence

In [6]:
import math

def chunks(lst, n):
    """Lazily yield consecutive slices of ``lst`` of length ``n``; the last may be shorter."""
    yield from (lst[offset:offset + n] for offset in range(0, len(lst), n))
        
def compare_improve(sentence, model, improve_token, num_outputs=1):
    """Print ``sentence`` and the model's rewrite(s) as a bullet list; return the rewrites."""
    print("BEFORE:", sentence)
    outputs = generate(model, improve_token + sentence, num_outputs)
    bullet_list = '\n- '.join(candidate.strip() for candidate in outputs)
    print("\nAFTER :\n-", bullet_list)

    return outputs
    
def generate(model, start, num_outputs=1, length=100, tokenizer=None):
    """Beam-search ``num_outputs`` outputs for the prompt ``start`` (5 beams).

    Args:
        model: seq2seq model exposing ``generate`` (e.g. T5ForConditionalGeneration).
        start: full prompt, task prefix included.
        num_outputs: number of sequences to return.
        length: ``max_length`` passed to ``model.generate``.
        tokenizer: raw tokenizer with ``encode``/``decode``. When omitted, a
            module-level ``tokenizer`` global is used (previous behaviour). The
            visible cells only create ``tokenizer`` as a local inside
            process_multi_models, so prefer passing it explicitly.
    Returns:
        list[str]: decoded outputs with special tokens skipped.
    Raises:
        NameError: no tokenizer was passed and no global ``tokenizer`` exists.
    """
    if tokenizer is None:
        tokenizer = globals().get('tokenizer')
        if tokenizer is None:
            raise NameError("generate() needs a tokenizer: pass tokenizer=... "
                            "or define a global named 'tokenizer'")

    inputs_ids = tokenizer.encode(start, return_tensors='pt').to(device=device)
    with torch.no_grad():
        outputs = model.generate(inputs_ids,
                                 max_length=length, num_beams=5,
                                 num_return_sequences=num_outputs,
                                 early_stopping=False)
    return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
            
def generate_batch(model, tokenizer, sentences, num_outputs=1, length=150, use_wrapper=True):
    """Beam-search (5 beams) rewrites for a batch of prompts.

    Args:
        model: seq2seq model exposing ``generate``.
        tokenizer: a ``TokenizerWrapper`` when ``use_wrapper`` is True (OOV text
            is protected and restored), otherwise a raw Hugging Face tokenizer.
        sentences: prompts, task prefix included.
        num_outputs: sequences returned per prompt.
        length: ``max_length`` passed to ``model.generate``.
        use_wrapper: see ``tokenizer``.
    Returns:
        list[str] with ``len(sentences) * num_outputs`` entries, the
        ``num_outputs`` results for each prompt stored consecutively.
    """
    if use_wrapper:
        encodings_input, map_tokens = tokenizer.encode_batch_inference(sentences)
    else:
        encodings_input = tokenizer(sentences, truncation=True, padding='longest', return_tensors='pt').to(device)

    with torch.no_grad():
        # Pass the attention mask explicitly: the batch is padded to the longest
        # prompt, and pad positions must not be attended to.
        outputs = model.generate(encodings_input['input_ids'],
                                 attention_mask=encodings_input.get('attention_mask'),
                                 max_length=length, num_beams=5,
                                 num_return_sequences=num_outputs,
                                 early_stopping=False)

    if not use_wrapper:
        return [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

    # generate() emits num_outputs consecutive rows per prompt; repeat each
    # prompt's placeholder map and original so decode_batch's zip stays aligned
    # (previously results were truncated/misattributed when num_outputs > 1).
    repeated_maps = [map_token for map_token in map_tokens for _ in range(num_outputs)]
    repeated_sentences = [sentence for sentence in sentences for _ in range(num_outputs)]
    return tokenizer.decode_batch(outputs, repeated_maps, repeated_sentences)


def process_sentences(model, sentences, token_start=IMPROVE_TOKEN):
    """Improve sentences one at a time and keep only those the model changed.

    Returns:
        list[dict]: ``{'orig': sentence, 'improved': [output]}`` per changed sentence.
    """
    changed = []
    for original in tqdm.tqdm(sentences):
        candidates = generate(model, token_start + original, 1)
        if candidates[0] == original:
            continue
        changed.append({'orig': original, 'improved': candidates})

    return changed

def process_sentences_batch(model, tokenizer, sentences, token_start=IMPROVE_TOKEN, batch_size=32, output=None):
    """Improve ``sentences`` in batches and collect those the model changed.

    Args:
        model: seq2seq model passed through to ``generate_batch``.
        tokenizer: ``TokenizerWrapper`` (``generate_batch`` runs with its
            default ``use_wrapper=True``).
        sentences: sentences without the task prefix.
        token_start: task prefix prepended to each sentence before generation.
        batch_size: sentences per ``generate_batch`` call.
        output: optional JSON path; the results collected so far are written
            there every 100 batches as a checkpoint.
    Returns:
        list[dict]: ``{'orig': sentence, 'improved': [output]}`` for every
        sentence whose output differs from the input.
    """
    parsed_sentences = []
    total = math.ceil(len(sentences) / batch_size)

    batches = tqdm.tqdm(chunks(sentences, batch_size), total=total)
    for batch_number, batch in enumerate(batches, start=1):
        outputs = generate_batch(model, tokenizer, [token_start + sent for sent in batch], 1)

        for original, improved in zip(batch, outputs):
            # When a [KEEPn] placeholder is lost, the wrapper falls back to the
            # input it was given -- the *prefixed* sentence. Map that back to the
            # original so it counts as "unchanged", not as a spurious improvement.
            if improved == token_start + original:
                improved = original
            if improved != original:
                parsed_sentences.append({'orig': original, 'improved': [improved]})

        # Periodic checkpoint so long runs can be inspected mid-way.
        if output and batch_number % 100 == 0:
            with open(output, 'w') as f:
                f.write(json.dumps(parsed_sentences, indent=4))

    return parsed_sentences

# Multilingual

In [10]:
import os

def process_files(files, model, tokenizer, processed_folder, limit, token_start=IMPROVE_TOKEN, suffix=None, input_dir=None, batch_size=32):
    """Improve the first ``limit`` sentences of each input file and save the changes.

    Each input ``DATA_DIR/<input_dir>/<file>`` holds a JSON array of sentence
    strings. Results are written to ``{processed_folder}{base}-{limit}.json``,
    or ``{processed_folder}{base}-{limit}-{suffix}.json`` when ``suffix`` is given.
    Files whose output already exists are skipped.

    Args:
        files: input file names.
        model, tokenizer: passed to ``process_sentences_batch``.
        processed_folder: output directory prefix (must end with '/').
        limit: maximum number of sentences taken from each file.
        token_start: task prefix for the prompts.
        suffix: optional tag appended to output file names.
        input_dir: sub-directory of DATA_DIR; None or '' reads DATA_DIR directly.
        batch_size: sentences per generation batch.
    """
    for file in files:
        print('Currrent file', file, 'token', token_start, 'limit', limit, 'suffix', suffix)

        base_file = file.split('.')[0]
        if suffix:
            output_file = f'{processed_folder}{base_file}-{limit}-{suffix}.json'
        else:
            output_file = f'{processed_folder}{base_file}-{limit}.json'

        # Check before loading so skipped inputs are not parsed for nothing.
        # NOTE(review): process_sentences_batch checkpoints into output_file every
        # 100 batches, so an interrupted run leaves a partial file that this
        # check then treats as finished -- delete it before re-running.
        if os.path.exists(output_file):
            print('Skipping already processed', file)
            continue

        # `input_dir or ''` avoids building "data/None/<file>" when input_dir is None.
        input_path = os.path.join(DATA_DIR, input_dir or '', file)
        # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
        with open(input_path, 'r', encoding='utf-8') as f:
            sentences_to_process = json.load(f)

        parsed_sentences = process_sentences_batch(model, tokenizer, sentences_to_process[:limit], token_start, batch_size=batch_size, output=output_file)

        with open(output_file, 'w') as f:
            f.write(json.dumps(parsed_sentences, indent=4))

# Process

In [13]:
# Per-checkpoint settings: the pretrained T5 tokenizer to load ('size') and the
# language codes used to prompt 'multi' models ('languages').
# NOTE(review): the 't5lg' checkpoint is paired with the 't5-base' tokenizer --
# confirm that is intended.
models_config = dict(
    (version, {'size': size, 'languages': ['pt', 'es']})
    for version, size in [
        ('t5sm-l1aware-multi-s260-v1', 't5-small'),
        ('t5lg-l1aware-multi-s260-v1', 't5-base'),
    ]
)


In [None]:
import pathlib

# TokenizerWrapper options: parenthesised spans are NOT protected, so the model
# may rewrite them.
TOKENIZER_CONFIG = dict(keep_parentheses=False)

# Sentences taken from each input file, and sentences per generation batch.
MAX_SENTENCES = 20_000
BATCH_SIZE = 16

# Fine-tuned checkpoints (directories under MODELS_DIR), processed in this order.
models_list = [
    't5sm-l1aware-multi-s260-v1',
    't5lg-l1aware-multi-s260-v1',
]

# Input file -> language code; 'multi' models only see the files whose
# language matches the prompt they are run with.
language_files = {
    'brace-v1.json': 'pt',
    'lace-v1.json': 'es',
}

# Every input file, in the same order as language_files.
to_process_files = list(language_files)

def process_multi_models(models_list, input_dir=''):
    """Run every checkpoint in ``models_list`` over the configured input files.

    For each model version: load it on DEVICE, load the T5 tokenizer named by
    ``models_config[version]['size']``, and write results under
    ``{RESULTS_DIR}{version}/processed/``. The third dash-separated field of
    the version name selects the prompts:

    * ``multi``: one pass per language in ``models_config[version]['languages']``,
      restricted to the files tagged with that language in ``language_files``,
      prompt ``"improve_english <lang>: "``, output suffix ``token-<lang>``;
    * ``all``: every file with the plain ``IMPROVE_TOKEN`` prompt;
    * anything else: treated as a language code, every file, prompt
      ``"improve_english <code>: "``.

    Processing errors are reported with a traceback and the loop continues
    (best effort), matching the original intent of the try/except blocks.

    Args:
        models_list: model directory names under MODELS_DIR.
        input_dir: sub-directory of DATA_DIR holding the input files; ''
            (default) reads them from DATA_DIR. It was previously referenced
            without ever being defined, so every job failed with NameError.
    """
    import traceback

    def _run_safely(files, model, tokenizer_wrapper, processed_folder, **kwargs):
        # Best effort: report the failure with its traceback, then move on.
        try:
            process_files(files, model, tokenizer_wrapper, processed_folder, MAX_SENTENCES,
                          input_dir=input_dir, batch_size=BATCH_SIZE, **kwargs)
        except Exception:
            traceback.print_exc()

    for model_version in models_list:
        # Clean up
        torch.cuda.empty_cache()

        print('Loading model', model_version)
        model = T5ForConditionalGeneration.from_pretrained(MODELS_DIR + model_version)
        model = (model.cpu() if DEVICE == 'cpu' else model.cuda(DEVICE)).eval()
        print('Model loaded', model_version)

        model_config = models_config[model_version]

        print('Using tokenizer', model_config['size'])
        tokenizer = T5Tokenizer.from_pretrained(model_config['size'])
        tokenizer_wrapper = TokenizerWrapper(tokenizer, TOKENIZER_CONFIG)

        processed_folder = f'{RESULTS_DIR}{model_version}/processed/'
        pathlib.Path(processed_folder).mkdir(parents=True, exist_ok=True)

        language = model_version.split('-')[2]
        if language == 'multi':
            for lang in model_config['languages']:
                # Only process files written in the prompted language.
                same_language_files = [file for file in to_process_files if language_files[file] == lang]
                _run_safely(same_language_files, model, tokenizer_wrapper, processed_folder,
                            token_start=f'{IMPROVE_TOKEN_MULTI} {lang}: ', suffix=f'token-{lang}')
        elif language == 'all':
            _run_safely(to_process_files, model, tokenizer_wrapper, processed_folder,
                        token_start=IMPROVE_TOKEN)
        else:
            # The tokenizer wrapper must be passed here too; omitting it shifted
            # every following positional argument of process_files.
            _run_safely(to_process_files, model, tokenizer_wrapper, processed_folder,
                        token_start=f'{IMPROVE_TOKEN_MULTI} {language}: ')

        # Release the model before loading the next one.
        del model, tokenizer, tokenizer_wrapper
        gc.collect()
        time.sleep(10)
        torch.cuda.empty_cache()


process_multi_models(models_list)