### Utility functions

In [1]:
import json
from tqdm import tqdm

from utils.helper import *
from utils.labels_tags import *
from utils.correction_ratio import *

from utils.helper import nlp
from utils.labels_tags import label_map

### Metric computation

In [2]:
import nltk
import evaluate
import numpy as np

from bart_score import BARTScorer

from statistics import mean

# Metric backends are created once at module level and reused by every
# compute_metrics_* call in this cell.
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
# BARTScore with the facebook/bart-large-cnn checkpoint, pinned to the first CUDA
# device — NOTE(review): presumably this cell therefore requires a GPU; confirm.
bart_scorer = BARTScorer(device='cuda:0', checkpoint='facebook/bart-large-cnn')

def postprocess_text(preds, labels):
    """Strip surrounding whitespace and put each sentence on its own line.

    ROUGE-Lsum expects a newline after each sentence, so predictions and
    references are re-split with NLTK's sentence tokenizer.

    Args:
        preds: generated summaries.
        labels: reference summaries.

    Returns:
        A ``(preds, labels)`` tuple of new lists, element-wise aligned with the inputs.
    """
    def _one_sentence_per_line(text):
        return "\n".join(nltk.sent_tokenize(text.strip()))

    cleaned_preds = [_one_sentence_per_line(p) for p in preds]
    cleaned_labels = [_one_sentence_per_line(t) for t in labels]
    return cleaned_preds, cleaned_labels

def _summary_quality_metrics(preds, targets):
    """Score already post-processed predictions against their targets.

    Shared by ``compute_metrics_inference`` and ``compute_metrics_test`` so the two
    report identically computed BARTScore / ROUGE / BERTScore numbers.

    Returns:
        ``{"bart_score": {...}, "rouge": {...}, "bert-score": {...}}``. ROUGE and
        BERTScore values are scaled by 100 and rounded to 4 decimals; the BARTScore
        ``f_score`` is the arithmetic mean of its precision and recall.
    """
    # Precision passes targets as the source side and predictions as the target side;
    # recall swaps the two.
    bart_score_p = mean(bart_scorer.score(targets, preds, batch_size=4))
    bart_score_r = mean(bart_scorer.score(preds, targets, batch_size=4))
    bart_score_result = {
        "precision": bart_score_p,
        "recall": bart_score_r,
        "f_score": 0.5 * (bart_score_p + bart_score_r),
    }

    rouge_scores = rouge.compute(predictions=preds, references=targets, use_stemmer=True)
    rouge_result = {k: round(v * 100, 4) for k, v in rouge_scores.items()}

    bs_scores = bertscore.compute(predictions=preds, references=targets, lang="en")
    # Average each score list (presumably one entry per example); "hashcode" is not a score.
    bs_result = {k: round(np.mean(v) * 100, 4) for k, v in bs_scores.items() if k != "hashcode"}

    return {
        "bart_score": bart_score_result,
        "rouge": rouge_result,
        "bert-score": bs_result,
    }


def compute_metrics_inference(preds, targets):
    """Summary-quality metrics for predictions on data without distortion labels.

    Args:
        preds: generated / corrected summaries.
        targets: reference summaries, aligned with ``preds``.

    Returns:
        Dict with ``"bart_score"``, ``"rouge"`` and ``"bert-score"`` sub-dicts.
    """
    preds, targets = postprocess_text(preds, targets)
    return _summary_quality_metrics(preds, targets)


def compute_metrics_test(preds, targets, dataset):
    """Correction ratios plus summary-quality metrics for a labelled test set.

    Args:
        preds: corrected summaries, one per ``dataset`` example (indexed by position).
        targets: reference summaries, aligned with ``preds``.
        dataset: labelled examples handed to ``compute_correction_ratio``.

    Returns:
        Dict with ``"correction_ratio"`` (hard and soft), ``"bart_score"``, ``"rouge"``
        and ``"bert-score"``. A correction ratio whose denominator (the second value
        returned by ``compute_correction_ratio``) sums to zero is reported as NaN
        instead of raising ZeroDivisionError.
    """
    preds, targets = postprocess_text(preds, targets)

    num_cor_hard = num_dist_hard = 0
    num_cor_soft = num_dist_soft = 0
    for idx, d in enumerate(dataset):
        cor_hard, dist_hard = compute_correction_ratio(d, preds[idx])
        num_cor_hard += cor_hard
        num_dist_hard += dist_hard
        cor_soft, dist_soft = compute_correction_ratio(d, preds[idx], soft=True)
        num_cor_soft += cor_soft
        num_dist_soft += dist_soft

    corr_result = {
        "corr_score_hard": num_cor_hard / num_dist_hard if num_dist_hard else float("nan"),
        "corr_score_soft": num_cor_soft / num_dist_soft if num_dist_soft else float("nan"),
    }

    # Key order matches the saved JSON layout: correction ratio first, then quality metrics.
    return {"correction_ratio": corr_result, **_summary_quality_metrics(preds, targets)}

  from .autonotebook import tqdm as notebook_tqdm
Downloading builder script: 6.27kB [00:00, 5.34MB/s]
Downloading builder script: 7.95kB [00:00, 7.37MB/s]
Downloading: 899kB [00:00, 1.73MB/s]
Downloading: 456kB [00:00, 1.27MB/s]
Downloading: 1.36MB [00:00, 2.44MB/s]
Downloading: 1.58kB [00:00, 1.81MB/s]


### JointBERT arguments and inference helpers

In [3]:
import sys
import argparse

# NOTE(review): argv is reset presumably so argparse ignores the flags the notebook
# kernel was launched with and every option below takes its default.
sys.argv = [""]
parser = argparse.ArgumentParser()

# Optimisation / training hyper-parameters.
parser.add_argument("--seed", default=1234, type=int, help="random seed for initialization")
parser.add_argument("--train_batch_size", default=8, type=int, help="Batch size for training.")
parser.add_argument("--eval_batch_size", default=16, type=int, help="Batch size for evaluation.")
parser.add_argument("--max_seq_len", default=512, type=int,
                    help="The maximum total input sequence length after tokenization.")
parser.add_argument("--learning_rate", default=2e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--num_train_epochs", default=10.0, type=float,
                    help="Total number of training epochs to perform.")
parser.add_argument("--weight_decay", default=0.01, type=float, help="Weight decay if we apply some.")
parser.add_argument("--gradient_accumulation_steps", default=1, type=int,
                    help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--max_steps", default=-1, type=int,
                    help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--dropout_rate", default=0.1, type=float, help="Dropout for fully-connected layers")

# Logging / checkpointing cadence.
parser.add_argument("--logging_steps", default=500, type=int, help="Log every X updates steps.")
parser.add_argument("--save_steps", default=500, type=int, help="Save checkpoint every X updates steps.")

# Run-mode switches.
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument("--do_predict", action="store_true", help="Whether to run predict on the test set.")
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")

# Loss configuration; ignore_index is also the label value skipped when reading slot predictions.
parser.add_argument("--ignore_index", default=-100, type=int,
                    help="Specifies a target value that is ignored and does not contribute to the input gradient")
parser.add_argument("--slot_loss_coef", default=1.0, type=float, help="Coefficient for the slot loss.")

# CRF option
parser.add_argument("--use_crf", action="store_true", help="Whether to use CRF")
parser.add_argument("--slot_pad_label", default="PAD", type=str,
                    help="Pad token for slot label pad (to be ignore when calculate loss)")

# Model, data and label scheme.
parser.add_argument("--model_name_or_path", default="bert-large-uncased", type=str, help="Model for training")
parser.add_argument("--data_dir", default="./data/dialogsum/", type=str, help="Input data dir")
parser.add_argument("--label_type", default="labels_bio", type=str, help="Label type")

args = parser.parse_args()

In [4]:
import torch
import numpy as np

def align_labels_with_tokens(labels, word_ids, context_len):
    """Project word-level labels onto sub-token positions of a (context, draft) pair.

    Args:
        labels: one label per draft word, indexed by the values in ``word_ids``.
        word_ids: per-token word index (``None`` for special tokens).
        context_len: number of context tokens; the first ``context_len + 2`` positions
            (presumably [CLS] + context + [SEP]) are always ignored.

    Returns:
        A list with one entry per token: the word's label on the first sub-token of each
        new draft word, and -100 everywhere else.
    """
    aligned = []
    previous_word = None
    for position, word_id in enumerate(word_ids):
        if position < context_len + 2:
            aligned.append(-100)
            continue
        if word_id == previous_word:
            # Continuation sub-token of the previous word (or a repeated None).
            aligned.append(-100)
            continue
        previous_word = word_id
        aligned.append(-100 if word_id is None else labels[word_id])
    return aligned

def joint_model_inference(dialog, draft):
    """Run the JointBERT identifier on one dialogue / draft-summary pair.

    Args:
        dialog: dialogue as a list of words (encoded with ``is_split_into_words=True``).
        draft: draft summary as a list of words.

    Returns:
        ``(intent_pred, slot_tags)``: the argmax sentence-level ("intent") class for the
        pair, and one predicted slot tag string per labelled draft position (the first
        sub-token of each draft word that survives truncation).
    """
    tokenized_inputs = joint_tokenizer(
        dialog,
        draft,
        is_split_into_words=True,
        truncation=True,
        max_length=512,
    )
    # Placeholder labels: they only mark which positions to read predictions from.
    dummy_labels = [0] * len(draft)
    word_ids = tokenized_inputs.word_ids()
    # Count the dialogue tokens that actually survived truncation. Re-tokenizing the
    # full dialogue over-counts whenever the pair exceeds max_length, which pushes the
    # label mask into (or past) the draft and drops slot predictions.
    context_len = tokenized_inputs.sequence_ids().count(0)
    tokenized_inputs["labels"] = align_labels_with_tokens(dummy_labels, word_ids, context_len)
    tokenized_inputs["hallucination_label_ids"] = 0

    joint_model.eval()
    inputs = {
        "input_ids": torch.tensor([tokenized_inputs["input_ids"]], dtype=torch.long),
        "attention_mask": torch.tensor([tokenized_inputs["attention_mask"]], dtype=torch.long),
        "token_type_ids": torch.tensor([tokenized_inputs["token_type_ids"]], dtype=torch.long),
        "intent_label_ids": torch.tensor([tokenized_inputs["hallucination_label_ids"]], dtype=torch.long),
        "slot_labels_ids": torch.tensor([tokenized_inputs["labels"]], dtype=torch.long),
    }

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = joint_model(**inputs)
    _, (intent_logits, slot_logits) = outputs[:2]

    intent_preds = np.argmax(intent_logits.detach().cpu().numpy(), axis=1)
    slot_preds = np.argmax(slot_logits.detach().cpu().numpy(), axis=2)

    label_ids = inputs["slot_labels_ids"].numpy()
    slot_label_map = dict(enumerate(label_map[args.label_type]["tags"]))
    # Keep predictions only where a real (non-ignored) label was placed.
    slot_preds_list = [
        [
            slot_label_map[slot_preds[i][j]]
            for j in range(label_ids.shape[1])
            if label_ids[i, j] != args.ignore_index
        ]
        for i in range(label_ids.shape[0])
    ]

    return intent_preds[0], slot_preds_list[0]

### Corrector pipeline inference

In [5]:
def get_identifier_prediction(identifier, dialog, draft_tokens):
    """Predict a hallucination tag for every draft token with a BERT-based identifier.

    Args:
        identifier: identifier name; names containing ``"joint"`` use the JointBERT
            model, anything else the token-classification model.
        dialog: dialogue as a list of words.
        draft_tokens: draft summary as a list of words.

    Returns:
        ``(pred_labels, is_hallucinated)``: tag strings for the draft, and whether anything
        was flagged (any tag other than ``"O"``, or, for the joint model, a non-zero
        sentence-level prediction).
    """
    # Word indices reported by find_missing_numbers on the cased tokenization (first and
    # last positions dropped) are force-tagged "B-E" below.
    # NOTE(review): presumably these are draft words that yield no sub-tokens; confirm.
    cased_word_ids = cased_tokenizer(draft_tokens, is_split_into_words=True).word_ids()[1:-1]
    missing_nums = find_missing_numbers(cased_word_ids)

    if "joint" in identifier:
        sentence_pred, pred_labels = joint_model_inference(dialog, draft_tokens)
        for i in missing_nums or []:
            pred_labels.insert(i, "B-E")
        tagged_any = any(tag != "O" for tag in pred_labels)
        return pred_labels, bool(sentence_pred) or tagged_any

    encoded = tokenizer(dialog, draft_tokens, is_split_into_words=True, truncation=True, max_length=512, return_tensors="pt")
    with torch.no_grad():
        logits = token_classifier(**encoded).logits
    id2label = token_classifier.config.id2label
    token_tags = [id2label[t.item()] for t in torch.argmax(logits, dim=2)[0]]

    # The draft segment starts at the first token with token_type_id 1; the last position is dropped.
    draft_start = (encoded.token_type_ids[0] == 1).nonzero(as_tuple=True)[0][0]
    pred_map = map_predictions(token_tags[draft_start:-1], encoded.word_ids()[draft_start:-1])
    for i in missing_nums or []:
        pred_map.setdefault(i, set()).add("B-E")

    pred_labels = align_pred_labels(pred_map, draft_tokens)
    return pred_labels, any(tag != "O" for tag in pred_labels)

In [20]:
def _tokenize_until_stable(text):
    """Re-apply ``tokenize_summary`` to ``text`` until its output stops changing; return that output."""
    previous, current = text, tokenize_summary(text)
    while previous != current:
        previous, current = current, tokenize_summary(current)
    return current


def evaluate_corrector_pipeline(dataset, mode="full", identifier="joint", supervision="tag", model_type="proposed", iterative=False, max_iter=5, prompt=None, draft_summaries=None, label_type=None, output_dir=None):
    """Run identifier -> corrector over ``dataset`` and save predictions plus metrics.

    For each example an identifier decides whether the draft summary is hallucinated;
    flagged drafts are rewritten by the corrector (repeatedly, up to ``max_iter`` rounds,
    when ``iterative``). Corrected summaries are written one per line to
    ``{output_dir}/{output_filename}.txt`` and the metrics to the matching ``.json``.
    Generation runs on ``cuda:0``.

    Args:
        dataset: labelled examples (``"dialogues"``, ``"distorted_summaries"``,
            ``"ref_summaries"``, ``label_type`` and ``"tag_" + label_type`` fields), or,
            when ``draft_summaries`` is given, raw examples with ``"dialogue"``/``"summary"``.
        mode: ``"full"``, or ``"distort_only"`` to keep only examples with a non-zero
            label (labelled data only).
        identifier: hallucination identifier; ``"ideal"`` reads the gold tags.
        supervision: ``"tag"`` (draft augmented by ``add_tags_to_summary*``) or ``"list"``
            (word list + ``prompt``).
        model_type: ``"baseline"`` feeds only dialogue + draft to the corrector;
            ``"proposed"`` also feeds the identifier output.
        iterative: re-identify and re-correct the corrected summary.
        max_iter: maximum number of rounds per example (>= 1).
        prompt: instruction prefix, required for list supervision with the proposed model.
        draft_summaries: optional external drafts, one per example; switches to
            reference-only metrics.
        label_type: label field / scheme; defaults to the notebook-level ``label_type``.
        output_dir: directory for the outputs; defaults to the notebook-level ``corrector_dir``.

    Returns:
        The metrics dict that is also written to the ``.json`` file. Rates whose
        denominator is zero (e.g. TNR in ``distort_only`` mode, which keeps no clean
        examples) are reported as NaN instead of raising ZeroDivisionError.
    """
    assert mode in ["full", "distort_only"]
    assert identifier in ["joint", "joint_comb", "token_cls", "token_cls_comb", "bart", "ideal"]
    assert supervision in ["tag", "list"]
    assert model_type in ["baseline", "proposed"]
    # With max_iter < 1 no summary would be recorded and outputs would misalign.
    assert max_iter >= 1

    # Fall back to the notebook-level settings so existing call sites keep working.
    if label_type is None:
        label_type = globals()["label_type"]
    if output_dir is None:
        output_dir = globals()["corrector_dir"]

    if draft_summaries is not None:
        # External drafts carry no gold labels: no distort-only filtering, no gold
        # ("ideal") tags, and exactly one draft per example.
        assert mode == "full"
        assert identifier != "ideal"
        assert len(draft_summaries) == len(dataset)

    TP = TN = FP = FN = 0
    count_corrected = 0
    corrected_summaries = []
    label2id = label_map[label_type]["label2id"]

    if mode == "distort_only" and draft_summaries is None:
        dataset = [d for d in dataset if not all(l == 0 for l in d[label_type])]

    if draft_summaries is not None:
        targets = [d["summary"] for d in dataset]
    else:
        targets = [d["ref_summaries"] for d in dataset]

    for idx, d in tqdm(enumerate(dataset), total=len(dataset)):
        if draft_summaries is not None:
            dialog = d["dialogue"]
            dialog_tokens = [token.text for token in nlp(dialog)]
            draft = _tokenize_until_stable(draft_summaries[idx])
            draft_tokens = draft.split()
            gold_is_clean = None  # unused: external drafts have no gold labels
        else:
            dialog = " ".join(d["dialogues"])
            dialog_tokens = d["dialogues"]
            draft = " ".join(d["distorted_summaries"])
            draft_tokens = d["distorted_summaries"]
            gold_is_clean = all(l == 0 for l in d[label_type])

        iteration = 1
        while iteration <= max_iter:
            if identifier == "ideal":
                if iteration == 1:
                    pred_labels = d[f"tag_{label_type}"]
                    is_hallucinated = not all(l == "O" for l in pred_labels)
                else:
                    # Gold tags describe only the original draft; later rounds count as clean.
                    pred_labels = ["O"] * len(draft_tokens)
                    is_hallucinated = False
            elif identifier == "bart":
                pred_labels = ["O"] * len(draft_tokens)
                identifier_input = f"{dialog} </s></s> {draft}"
                input_ids = bart_identifier_tokenizer(identifier_input, max_length=1024, truncation=True, return_tensors='pt')['input_ids']
                summary_ids = bart_identifier.generate(
                    input_ids.to(torch.device("cuda:0")),
                    max_length=128,
                    min_length=5,
                    num_beams=6,
                    no_repeat_ngram_size=3)
                identifier_result = bart_identifier_tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
                if supervision == "tag":
                    tagged_summary = identifier_result
                    # An output containing both "<" and ">" is treated as flagging a span.
                    is_hallucinated = "<" in tagged_summary and ">" in tagged_summary
                elif supervision == "list":
                    word_list = identifier_result
                    # A list with no alphabetic character means nothing was flagged.
                    is_hallucinated = any(char.isalpha() for char in word_list)
            else:
                pred_labels, is_hallucinated = get_identifier_prediction(identifier, dialog_tokens, draft_tokens)

            if draft_summaries is None:
                assert len(pred_labels) == len(draft_tokens)

            uncorrected_summary = cased_tokenizer.decode(cased_tokenizer(draft)["input_ids"][1:-1])

            if not is_hallucinated:
                if iteration > 1:
                    corrected_summaries.append(uncorrected_summary)
                elif draft_summaries is None:
                    if gold_is_clean:
                        TN += 1
                    else:
                        FN += 1
                    corrected_summaries.append(uncorrected_summary)
                else:
                    # Echo the untouched draft without any line terminator left by the file reader,
                    # so the "\n".join below keeps one summary per line.
                    corrected_summaries.append(draft_summaries[idx].rstrip("\r\n"))
                break

            pred_labels = [label2id[label] for label in pred_labels]

            if model_type == "baseline":
                input_str = f"{dialog} </s></s> {draft}"
            elif supervision == "tag":
                if identifier == "bart":
                    corrector_input = tagged_summary
                elif "bio" in label_type:
                    corrector_input = add_tags_to_summary_bio(draft_tokens, pred_labels, label_type)
                else:
                    corrector_input = add_tags_to_summary(draft_tokens, pred_labels, label_type)
                input_str = f"{dialog} </s></s> {corrector_input}"
            else:
                if identifier != "bart":
                    word_list = get_hallucinated_word_list(pred_labels, draft_tokens, label_type)
                if not prompt:
                    raise ValueError("List Supervision need prompt")
                input_str = f"{prompt} </s></s> Word List: {word_list} </s></s> Draft Summary: {draft} </s></s> Dialogue Context: {dialog}"

            input_ids = corrector_tokenizer(input_str, max_length=1024, truncation=True, return_tensors='pt')['input_ids']
            summary_ids = corrector.generate(
                input_ids.to(torch.device("cuda:0")),
                max_length=128,
                min_length=5,
                num_beams=6,
                no_repeat_ngram_size=3)
            corrected_summary = corrector_tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
            is_same = compare_strings(corrected_summary, uncorrected_summary)

            # Identification outcome is scored on the first round only.
            if iteration == 1:
                if draft_summaries is not None:
                    if not is_same:
                        count_corrected += 1
                elif is_same:
                    if gold_is_clean:
                        TN += 1
                    else:
                        FN += 1
                elif gold_is_clean:
                    FP += 1
                else:
                    TP += 1

            if not iterative or iteration == max_iter or is_same:
                if iteration == 1 and is_same:
                    if draft_summaries is None:
                        corrected_summaries.append(uncorrected_summary)
                    else:
                        corrected_summaries.append(draft_summaries[idx].rstrip("\r\n"))
                else:
                    corrected_summaries.append(corrected_summary)
                break

            # Feed the correction back in as the next round's draft.
            draft = _tokenize_until_stable(corrected_summary.strip())
            draft_tokens = draft.split()
            iteration += 1

    output_filename = f"test_{identifier}_{mode}_{supervision}_predict"
    if iterative:
        output_filename += f"_iter{max_iter}"
    if draft_summaries is not None:
        output_filename = "inference_" + output_filename

    if draft_summaries is None:
        result = compute_metrics_test(corrected_summaries, targets, dataset)
        num_total = TP + TN + FP + FN
        num_positive = TP + FN
        num_negative = TN + FP
        tpr = TP / num_positive if num_positive else float("nan")
        tnr = TN / num_negative if num_negative else float("nan")
        result["pred_acc"] = (TP + TN) / num_total if num_total else float("nan")
        result["TPR"] = tpr
        result["TNR"] = tnr
        result["balanced_acc"] = (tpr + tnr) / 2
        result["TP"] = TP
        result["TN"] = TN
        result["FP"] = FP
        result["FN"] = FN
    else:
        result = compute_metrics_inference(corrected_summaries, targets)
        result["corrected_summary"] = count_corrected

    with open(f"{output_dir}/{output_filename}.txt", 'w', encoding="utf-8") as f:
        f.write("\n".join(corrected_summaries))

    with open(f"{output_dir}/{output_filename}.json", 'w', encoding="utf-8") as f:
        json.dump(result, f)

    return result

### Run Inference

In [21]:
from datasets import load_from_disk

data_dir = "./data/dialogsum"
dataset = load_from_disk(data_dir)

# One draft summary per line. Strip the line terminator (readlines() would keep it) so
# drafts fed to the models, and drafts echoed unchanged into the output file, carry no
# trailing newline.
with open("./model/dialogsum/vanilla/bart-base/test_predict.txt", encoding="utf-8") as f:
    draft_summaries = [line.rstrip("\n") for line in f]

In [22]:
from train_jointbert import JointBERT
from transformers import AutoModelForSeq2SeqLM, AutoModelForTokenClassification, BertConfig, AutoTokenizer

# Tag scheme for the identifiers; the JointBERT slot labels below are looked up via args.label_type.
label_type = "labels_sep_bio"
args.label_type = label_type
# NOTE(review): `model_type` is not read by the visible cells (the pipeline call passes its own model_type).
model_type = label_type

cased_checkpoint = "bert-base-cased"

# Checkpoint directories.
tokencls_dir = "./model/dialogsum/span-predictor/bert-large-uncased-x0.5-alpha0.5-lfp0.3-insertion-" + label_type
joint_dir = "./model/dialogsum/joint-predictor/bert-large-uncased-x0.5-alpha0.5-lfp0.3-insertion-" + label_type
bart_identifier_dir = "./model/dialogsum/identifier/bart-large-x0.5-alpha0.5-lfp0.3-insertion-identifier-sep"
corrector_dir = "./model/dialogsum/baseline/baseline-bart-large-x0.5-alpha0.5-lfp0.3-insertion-all"

args.model_dir = joint_dir

# JointBERT identifier: tokenizer from the base checkpoint, config and weights from joint_dir
# (the model is not moved to a GPU here).
joint_tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
joint_config = BertConfig.from_pretrained(args.model_dir)
joint_model = JointBERT.from_pretrained(
    args.model_dir,
    config=joint_config,
    args=args,
    intent_label_lst=["clean", "distorted"],
    slot_label_lst=label_map[args.label_type]["tags"],
)

# Cased tokenizer plus the token-classification identifier and its tokenizer.
cased_tokenizer = AutoTokenizer.from_pretrained(cased_checkpoint, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(tokencls_dir, use_fast=True)
token_classifier = AutoModelForTokenClassification.from_pretrained(tokencls_dir)

# Seq2seq identifier and corrector, both placed on the first CUDA device.
bart_identifier_tokenizer = AutoTokenizer.from_pretrained(bart_identifier_dir)
bart_identifier = AutoModelForSeq2SeqLM.from_pretrained(bart_identifier_dir).to("cuda:0")

corrector_tokenizer = AutoTokenizer.from_pretrained(corrector_dir)
corrector = AutoModelForSeq2SeqLM.from_pretrained(corrector_dir).to("cuda:0")

In [23]:
# Instruction prefix for list-style supervision; the baseline run below passes prompt=None.
list_prompt = "Given a dialogue context, a draft summary, a list of potential factually incorrect words/spans, produce a faithful final summary based on the dialogue context."

evaluate_corrector_pipeline(
    dataset["test"],
    mode="full",
    identifier="joint",
    supervision="tag",
    model_type="baseline",
    iterative=False,
    prompt=None,
    draft_summaries=draft_summaries,
)

261it [01:07,  4.04it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (995 > 512). Running this sequence through the model will result in indexing errors
1500it [06:32,  3.82it/s]


06/23/2023 22:19:40 - INFO - absl - Using default tokenizer.
