In [9]:
from ast import literal_eval
import numpy as np
import pandas as pd
import time
import os
import math
from collections import Counter

In [10]:
def _parse_quad_segment(segment):
    """Split one '[AT] .. [AC] .. [SP] .. [OT] ..' segment into its fields.

    The four markers may appear in any order; each field runs from the end
    of its marker to the start of the next marker (or the end of the string).

    Returns:
        tuple: (at, ac, sp, ot) with surrounding whitespace stripped.

    Raises:
        ValueError: if any of the four markers is missing from ``segment``.
    """
    markers = ("[AC]", "[SP]", "[AT]", "[OT]")
    # str.index raises ValueError for a missing marker, which the caller
    # treats as "cannot decode".
    positions = {marker: segment.index(marker) for marker in markers}
    ordered = sorted(markers, key=positions.get)

    fields = {}
    for rank, marker in enumerate(ordered):
        start = positions[marker] + len(marker)
        if rank + 1 < len(ordered):
            end = positions[ordered[rank + 1]]
        else:
            end = len(segment)
        fields[marker] = segment[start:end].strip()
    return fields["[AT]"], fields["[AC]"], fields["[SP]"], fields["[OT]"]


def extract_spans_para(seq, seq_type, task, dataset, data_type, self_con=False):
    """Extract quads either from one generated sequence or from the gold file.

    Args:
        seq: Generated text. Only used when ``seq_type == 'pred'``; the quads
            are expected after the phrase 'therefore, the quadruplets are:',
            separated by '[SSEP]'.
        seq_type: 'pred' to parse ``seq``; 'gold' to read the annotation file.
            Any other value yields an empty list.
        task, dataset, data_type: Only used for 'gold', to build the path
            ``original_data/{task}/{dataset}/{data_type}.txt`` (anything after
            the first '_' in ``dataset`` is dropped).
        self_con: When True, each predicted quad is returned as a tuple
            (hashable, so it can be counted); otherwise as a list.

    Returns:
        list: For 'pred', the distinct quads ``(at, ac, sp, ot)`` in order of
        first appearance; a segment that cannot be parsed, or a sequence
        without the trigger phrase, contributes an all-empty quad. For
        'gold', one ``literal_eval``-ed entry per line of the file.
    """
    quads = []
    if seq_type == 'pred':
        make_quad = tuple if self_con else list
        parts = seq.split('therefore, the quadruplets are:')
        if len(parts) > 1:
            segments = [q.strip() for q in parts[1].strip().split('[SSEP]')]
            for s in segments:
                # e.g. food quality is bad because pizza is over cooked.
                try:
                    at, ac, sp, ot = _parse_quad_segment(s)
                except ValueError:
                    try:
                        print(f'In {seq_type} seq, cannot decode: {s}')
                    except UnicodeEncodeError:
                        # The console could not encode the segment itself.
                        print(f'In {seq_type} seq, a string cannot be decoded')
                    at, ac, sp, ot = '', '', '', ''

                # Compare using the same container type that is stored:
                # a list never equals a tuple, so checking a list against
                # stored tuples would let duplicates through in self_con mode.
                quad = make_quad((at, ac, sp, ot))
                if quad not in quads:
                    quads.append(quad)
        else:
            quads.append(make_quad(('', '', '', '')))
    elif seq_type == 'gold':
        dataset = dataset.split('_')[0]
        with open(f'original_data/{task}/{dataset}/{data_type}.txt', 'r', encoding='UTF-8') as f:
            for line in f:
                line = line.lower()
                # Everything after '####' is a Python literal holding the quads.
                q = line.split('####')[1].strip()
                quads.append(literal_eval(q))
    return quads

In [11]:
def compute_f1_scores(pred_pt, gold_pt, data_type, verbose=True):
    """
    Compute precision, recall and F1 (as percentages) of predicted quads.

    ``pred_pt`` and ``gold_pt`` are parallel per-sample collections of
    already-processed quads. Counts are pooled over all samples before the
    ratios are taken. Counting (and the verbose summary line) only happens
    when ``data_type`` is "test"; for any other value every score is 0.
    """
    hits = 0
    total_gold = 0
    total_pred = 0

    if data_type == "test":
        for idx, predicted in enumerate(pred_pt):
            reference = gold_pt[idx]
            total_gold += len(reference)
            total_pred += len(predicted)
            # A prediction is a hit when it appears among this sample's gold quads.
            hits += sum(1 for quad in predicted if quad in reference)

        if verbose:
            print(
                f"number of gold spans: {total_gold}, predicted spans: {total_pred}, hit: {hits}"
            )

    if total_pred != 0:
        precision = float(hits) / float(total_pred)
    else:
        precision = 0

    if total_gold != 0:
        recall = float(hits) / float(total_gold)
    else:
        recall = 0

    # Guard the harmonic mean against a zero denominator.
    if precision == 0 and recall == 0:
        f1 = 0
    else:
        f1 = 2 * precision * recall / (precision + recall)

    return {
        'precision': precision * 100,
        'recall': recall * 100,
        'f1': f1 * 100,
    }

In [12]:
def self_consistency(path, num_path, threshold, gold_quads):
    """Majority-vote the quads from several generated paths and score them.

    Args:
        path: Text file with one generated sequence per line. Each run of
            ``num_path`` consecutive lines belongs to the same input sample.
        num_path: Number of generated sequences per sample.
        threshold: Minimum number of paths that must produce a quad for it
            to be kept as a prediction.
        gold_quads: Per-sample gold quads, in the same sample order as the
            groups in ``path``.

    Returns:
        dict: The precision / recall / f1 dictionary produced by
        ``compute_f1_scores``.

    Raises:
        AssertionError: if the number of sample groups in ``path`` differs
            from ``len(gold_quads)``.
    """
    # Read with an explicit encoding so decoding does not depend on the
    # platform default (the gold annotations are read as UTF-8 as well).
    with open(path, 'r', encoding='UTF-8') as f:
        outputs = f.readlines()

    pred_quads = []
    for group_start in range(0, len(outputs), num_path):
        votes = Counter()
        for generated in outputs[group_start:group_start + num_path]:
            path_quads = extract_spans_para(
                generated, 'pred', 'asqp', 'rest16', 'test', self_con=True)
            # One vote per path per quad: dict.fromkeys drops repeats while
            # keeping first-seen order, so a path that repeats itself cannot
            # push a quad over the threshold on its own.
            votes.update(list(dict.fromkeys(path_quads)))

        pred_quads.append(
            [list(quad) for quad, count in votes.items() if count >= threshold])

    # Compute model performance
    print(len(pred_quads), len(gold_quads))
    assert len(pred_quads) == len(gold_quads), (
        f"{len(pred_quads)} predicted samples vs {len(gold_quads)} gold samples")

    scores = compute_f1_scores(pred_quads, gold_quads, 'test')

    return scores


In [13]:
# Evaluation settings: number of sampled paths per input and where the
# generated sequences / score log live.
num_return_sequences = 15
output_dir = 't5-base_asqp_rest16_top5_x16_test'
model_name_or_path = 't5-base'
rst_path = f'{output_dir}/{model_name_or_path}_path{num_return_sequences}_results.txt'


print("\n****** Self-Consistency Evaluation ******")

# Gold quads: one sample per line, with the Python-literal quad list after '####'.
# Read as UTF-8 explicitly so the result does not depend on the platform default.
gold_quads = []
with open('./original_data/asqp/rest16/test.txt', 'r', encoding='UTF-8') as f:
    for line in f:
        line = line.lower()
        q = line.split('####')[1].strip()
        qq = literal_eval(q)
        gold_quads.append(qq)

log_file_path = os.path.join(output_dir, f"sc_{num_return_sequences}_seqs_score_result.txt")
local_time = time.asctime(time.localtime(time.time()))

# A quad is kept when at least half of the paths (rounded up) produce it.
threshold = math.ceil(num_return_sequences / 2)

# Evaluate before touching the log so a failed run does not leave a header
# with no result underneath it.
scores = self_consistency(rst_path, num_return_sequences, threshold, gold_quads)
scores['precision'] = float(scores['precision'])
scores['recall'] = float(scores['recall'])
scores['f1'] = float(scores['f1'])
exp_results = "Precision: {:.2f} Recall: {:.2f} F1 = {:.2f}".format(scores['precision'], scores['recall'], scores['f1'])
log_str = f"threshold: {threshold} {exp_results}\n\n"
print(log_str)

with open(log_file_path, "a", encoding='UTF-8') as f:
    f.write("============================================================\n")
    f.write(f"{local_time} \nNum Sequences: {num_return_sequences}\n") # Order Top
    f.write("============================================================\n")
    f.write(log_str)

print()
print("****** Finish Self-Consistency Evaluation ******")


****** Self-Consistency Evaluation ******
In pred seq, cannot decode: 
In pred seq, cannot decode: [AT] food [CT] overpriced [AC] food prices [SP] negative
In pred seq, cannot decode: [SP] positive [AT] saag pane
In pred seq, cannot decode: [SP] positive [AT] samosas [OT] yummy [
In pred seq, cannot decode: [OT] generous [SP] positive [
In pred seq, cannot decode: [TA] relaxing [SP] positive [AT] null [AC] ambience general
In pred seq, cannot decode: [OT] categorized as $ $ $ $ $ $
In pred seq, cannot decode: [OT] great [AC] drinks quality [SP] positive [
In pred seq, cannot decode: [OT] perfect [SP] positive [AT] entre (AC] food style_options
In pred seq, cannot decode: [SP] positive [AT] place [OT] have it locked (AC] ambience general
In pred seq, cannot decode: [OT] good [SP] positive [AT] horse mackerel (AC] food quality
In pred seq, cannot decode: [TO] favorite [SP] positive [AT] spot [AC] restaurant general
In pred seq, cannot decode: [ATT] null [OT] sad [AC] restaurant general 