In [None]:
import re
import language_check # need to install it
import data_loader

def relex_utterance(utterance, mr, replace_name=False):
    """Fill the slot placeholders of a delexicalised utterance and tidy it up.

    Placeholders of the form ``&slot_val_<slot>&`` are replaced by the value
    of ``<slot>`` in the meaning representation, the first letter of every
    sentence is capitalised, and each ``-PERIOD-`` placeholder (together with
    the space in front of it) is turned into a real full stop.

    Args:
        utterance: Delexicalised utterance.
        mr: Meaning representation written as comma-separated ``slot[value]``
            pairs, e.g. ``'name[The Eagle], food[French]'``.
        replace_name: If true, the ``name`` placeholder is replaced by
            ``'It'`` instead of the actual name, to avoid repetitions when
            several partial utterances are merged.

    Returns:
        The relexicalised utterance (an empty utterance is returned as is).
    """
    period_token = '-PERIOD-'

    # Build the dictionary {slot: value of the slot} of the meaning representation.
    slots = {}
    for slot_value in mr.split(','):
        # Strip first: a trailing blank would otherwise be removed by the
        # `[:-1]` slice in place of the closing bracket.
        slot_value = slot_value.strip()
        sep_idx = slot_value.find('[')
        slot = slot_value[:sep_idx].strip()
        value = slot_value[sep_idx + 1:-1].strip()
        slots[slot] = value

    # Replace every value placeholder whose slot is present in the MR.
    for match in re.findall(r'&slot_val_.*?&', utterance):
        # The slot name is the last '_'-separated piece of the placeholder.
        slot = match.split('_')[-1].rstrip('&')
        if slot not in slots:
            continue

        # For more naturalness the name is replaced by 'It' on request, which
        # avoids repetitions in the final merged utterance.
        new_val = 'It' if slot == 'name' and replace_name else slots[slot]
        utterance = utterance.replace(match, new_val)

    # Capitalise the first letter of the utterance...
    if utterance:
        utterance = utterance[0].upper() + utterance[1:]

    # ...and the first letter following each sentence end ('-PERIOD-').
    sent_end = utterance.find(period_token)
    while sent_end >= 0:
        # The next sentence starts after the whole placeholder and the
        # whitespace that separates it from the following word.
        next_sent_beg = sent_end + len(period_token)
        while next_sent_beg < len(utterance) and utterance[next_sent_beg].isspace():
            next_sent_beg += 1

        if next_sent_beg < len(utterance):
            utterance = (utterance[:next_sent_beg]
                         + utterance[next_sent_beg].upper()
                         + utterance[next_sent_beg + 1:])

        sent_end = utterance.find(period_token, next_sent_beg)

    # Finally turn the period placeholders into real '.' characters.
    utterance = utterance.replace(' ' + period_token, '.')

    return utterance

def merge_utterances(res, mrs, test_groups, nb_var):
    """Merge partial utterances of the same group into multi-sentence utterances.

    Consecutive entries of ``res`` that share the same group id are
    relexicalised with ``relex_utterance`` and joined with ``'. '``; a final
    ``'.'`` is appended to every merged utterance. The first sentence of a
    group keeps the real name, the following ones are relexicalised with
    ``replace_name=True`` so that the name is replaced by 'It'.

    Args:
        res: Partial (delexicalised) utterances.
        mrs: Meaning representations; the one used for a sentence is
            ``mrs[group // nb_var]``.
        test_groups: Group id of each entry of ``res``.
        nb_var: Divisor mapping a group id to its index in ``mrs``.

    Returns:
        The list of merged utterances, one per run of identical group ids
        (an empty list when ``res`` is empty).
    """
    final_utterances = []
    parts = []          # relexicalised sentences of the group being built
    prev_group = None

    for sent, cur_group in zip(res, test_groups):
        mr = mrs[cur_group // nb_var]

        if parts and cur_group == prev_group:
            # Same group: replace the name by 'It' to avoid repetitions.
            parts.append(relex_utterance(sent, mr, replace_name=True))
            continue

        # A new group starts: flush the previous one, if any.
        if parts:
            final_utterances.append('. '.join(parts) + '.')
        parts = [relex_utterance(sent, mr)]
        prev_group = cur_group

    # Flush the last group; nothing is emitted when there was no input at all.
    if parts:
        final_utterances.append('. '.join(parts) + '.')

    return final_utterances


def combo_print(small_pred, large_pred, num_permutes):
    """Interleave two prediction lists for side-by-side display.

    For the i-th entry of ``small_pred`` the result contains the
    ``num_permutes`` entries of ``large_pred`` starting at index
    ``i * num_permutes``, followed by the ``small_pred`` entry wrapped in the
    ANSI escape codes ``\\033[1m`` ... ``\\033[0m``.

    Returns:
        The interleaved list.
    """
    combined = []

    for idx, small in enumerate(small_pred):
        offset = idx * num_permutes
        # Index one by one (no slicing) so a too-short large_pred still raises.
        combined.extend(large_pred[offset + k] for k in range(num_permutes))
        combined.append('\033[1m' + small + '\033[0m')

    return combined


def depermute_input(mrs, sents, predictions, num_permutes):
    """Keep, for every chunk of permuted predictions, only the best-scoring one.

    ``predictions`` is walked in consecutive chunks of ``num_permutes``
    entries. Every entry of the chunk starting at index ``x`` is scored with
    ``score_output`` against ``mrs[x // num_permutes]`` and
    ``sents[x // num_permutes]``; the entry with the highest score (the first
    one on ties) is kept together with its MR and sentence.

    Returns:
        A tuple ``(new_mr, new_sent, new_pred)`` of three parallel lists.
    """
    # One checker instance is shared by all score_output calls.
    checker = language_check.LanguageTool('en-UK')
    best_mrs, best_sents, best_preds = [], [], []

    chunk_start = 0
    while chunk_start < len(predictions):
        # Score of every candidate of the current chunk, keyed by its index.
        scores = {
            cand: score_output(mrs[chunk_start // num_permutes],
                               sents[chunk_start // num_permutes],
                               predictions[cand],
                               checker)
            for cand in range(chunk_start, chunk_start + num_permutes)
        }

        winner = max(scores, key=scores.get)
        best_mrs.append(mrs[winner // num_permutes])
        best_sents.append(sents[winner // num_permutes])
        best_preds.append(predictions[winner])

        chunk_start += num_permutes

    return best_mrs, best_sents, best_preds


def correct(mrs, pred):
    """Return corrected versions of the predictions.

    Every prediction first goes through ``score_grammar`` in correction mode
    (with the meaning representation at the same index in ``mrs``) and the
    result then goes through ``score_known_errors`` in correction mode. Only
    the corrected texts are kept; the scores returned alongside them are
    discarded.

    Args:
        mrs: Meaning representations, indexed like ``pred``.
        pred: Predicted utterances.

    Returns:
        The list of corrected utterances, in the same order as ``pred``.
    """
    # One checker instance is shared by all predictions.
    tool = language_check.LanguageTool('en-UK')
    new_pred = []

    for x, p in enumerate(pred):
        # score_grammar is the grammar scorer defined in this module; with
        # correct=True it returns (score, corrected text).
        _, fixed = score_grammar(mrs[x], p, tool, True)
        _, fixed = score_known_errors(fixed, True)
        new_pred.append(fixed)

    return new_pred


def score_output(mr, sent, pred, tool=None):
    """Score a prediction: informativeness minus the error penalties.

    The result is ``score_info(mr, pred)`` minus ``score_grammar(mr, pred,
    tool)`` minus ``score_known_errors(pred)``. ``sent`` is accepted but not
    used.
    """
    informativeness = score_info(mr, pred)
    grammar_penalty = score_grammar(mr, pred, tool)
    known_errors_penalty = score_known_errors(pred)

    return informativeness - grammar_penalty - known_errors_penalty


def score_info(mr, pred):
    """Estimate informativeness as the share of slot values found in ``pred``.

    ``mr`` is split on commas; for each piece the text between the first
    ``'['`` and the last character is taken as the slot value. A value counts
    when it occurs, case-insensitively, as a substring of ``pred``.

    Returns:
        The number of values found divided by the number of pieces of ``mr``.
    """
    pieces = mr.split(',')
    haystack = pred.lower()

    found = 0
    for piece in pieces:
        bracket = piece.find('[')
        needle = piece[bracket + 1:-1].strip().lower()
        if needle in haystack:
            found += 1

    # Normalise by the number of slots.
    return found / len(pieces)


def score_grammar(mr, pred, tool=None, correct=False):
    """Score a prediction by the density of grammar/spelling issues.

    The prediction is first passed through ``data_loader.delex_data`` and the
    result is checked with the language tool. The score is the number of
    reported issues divided by the number of whitespace-separated tokens,
    capped at 1; a prediction without any token scores 0.

    Args:
        mr: Meaning representation of the prediction.
        pred: Predicted utterance.
        tool: Language checker to use; a new ``language_check.LanguageTool``
            is created when omitted.
        correct: If true, also apply the tool's corrections and relexicalise
            the result with ``relex_utterance``.

    Returns:
        The score, or the tuple ``(score, corrected utterance)`` when
        ``correct`` is true.
    """
    max_correction_passes = 5

    # NOTE(review): the rest of the function treats the returned value as a
    # single string (presumably the delexicalised utterance) - confirm against
    # data_loader.delex_data.
    pred = data_loader.delex_data([mr], [pred], update_data_source=True, specific_slots=None, split=False)

    if tool is None:
        tool = language_check.LanguageTool('en-UK')

    # Number of issues reported by the tool, relative to the utterance length.
    matches = tool.check(pred)
    num_tokens = len(pred.split())
    # Guard against a prediction without tokens (division by zero).
    score = min(len(matches) / num_tokens, 1) if num_tokens else 0.0

    if correct:
        # Re-apply the corrections until the text is stable, at most
        # max_correction_passes times.
        for _ in range(max_correction_passes):
            new_pred = tool.correct(pred)
            if new_pred == pred:
                break
            pred = new_pred
        return score, relex_utterance(pred, mr)

    return score


def score_known_errors(pred, correct=False):
    """Score a prediction by a known generation error: repeated digit tokens.

    The known error is a run of the same single-digit token repeated two or
    more times in a row (e.g. ``5 5 5 5 5``). The score is the number of
    tokens belonging to such runs divided by the total number of
    whitespace-separated tokens; a prediction without any token scores 0.

    Args:
        pred: Predicted utterance.
        correct: If true, also return the prediction with every such run
            collapsed to a single digit.

    Returns:
        The score, or the tuple ``(score, corrected utterance)`` when
        ``correct`` is true.
    """
    tokens = pred.split()

    # A run is one ASCII digit token followed by one or more copies of
    # itself. The look-arounds anchor both ends on token boundaries so digits
    # inside longer tokens (e.g. the '5' of '15') are never matched.
    run_pattern = re.compile(r'(?<!\S)([0-9])(?:\s+\1)+(?!\S)')

    repeated = sum(len(run.group(0).split()) for run in run_pattern.finditer(pred))
    # Guard against a prediction without tokens (division by zero).
    score = repeated / len(tokens) if tokens else 0.0

    if correct:
        # Collapse each run in place; the rest of the text, including its
        # spacing, is left untouched.
        return score, run_pattern.sub(r'\1', pred)

    return score