In [None]:
import os 
import sys 
sys.path.append("../")
import re 
import json 
import numpy as np 
import pandas as pd 
from collections import defaultdict, Counter
import torch 
import datasets
import nltk  # Here to have a nice missing dependency error message early on
import datasets
from datasets import load_dataset

import evaluate
import transformers
from filelock import FileLock
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BertTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version

from tqdm import tqdm
from sacrebleu.metrics import BLEU, CHRF, TER
from bert_score import BERTScorer
from torch.utils.data import Dataset, DataLoader

from utils import recover_sentence
from model_multi_director import BartDASC
from director_data_collator import DataCollatorForMultiBinaryDirector, DataCollatorForDynamicMaskMultiBinaryDirector
from my_configs import BartMultiBinaryDecompDirectorConfig

In [None]:
# Fix PyTorch's global RNG seed; generation later in this notebook uses
# do_sample=True, so its output depends on this seed.
torch.manual_seed(seed=0)

In [None]:
# Path of the split to decode; the `datasets` loader type is derived from its extension.
test_file = "./datasets/dulemon/test.csv"
data_files = {"test": test_file}
# splitext only inspects the final path component, so a dot inside a directory
# name cannot be mistaken for the file extension.
extension = os.path.splitext(test_file)[1].lstrip(".")
# `use_auth_token` is deliberately not passed: None was its default value, and the
# argument is deprecated in recent `datasets` releases.
raw_datasets = load_dataset(extension, data_files=data_files)

In [None]:
# Each emotion name paired with its coarse polarity class
# (0 = neutral, 1 = positive, 2 = negative); the pairs are split into two
# parallel lists. Presumably the list position is the dataset's Emotion class
# id -- TODO confirm.
emotion_list, emotion2big_class = map(list, zip(
    ("anger", 2),
    ("disgust", 2),
    ("fear", 2),
    ("happiness", 1),
    ("like", 1),
    ("none", 0),
    ("sadness", 2),
    ("surprise", 0),
))

In [None]:
# Reference responses and the three control-label columns, read straight from
# the test CSV as NumPy arrays.
gt_df = pd.read_csv("./datasets/dulemon/test.csv")
gt_texts, gender_labels, emotion_labels, question_labels = (
    gt_df[column_name].values
    for column_name in ("target", "Gender", "Emotion", "Question")
)

In [None]:
# Directory holding the fine-tuned checkpoint together with its config and vocabulary.
model_path = "../train_dulemon_outputs/bart_dasc1"
# NOTE: `config` is loaded here but is not passed to BartDASC.from_pretrained below.
config = BartMultiBinaryDecompDirectorConfig.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
# Use the first GPU when CUDA is available; otherwise stay on CPU instead of
# failing on a machine without a CUDA device.
model = BartDASC.from_pretrained(model_path).to("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Settings read by preprocess_function and by the data collator.
input_column = "history"
target_column = "target"
input_truncation_side = "left"
max_source_length = 384
max_target_length = 128
padding = "max_length"
ignore_pad_token_for_loss = True
column_names = raw_datasets["test"].column_names
control_columns = ['Gender', 'Emotion', 'Question']
# Number of classes for each entry of `control_columns`, in the same order.
num_control_pre_aspect = [3, 8, 2]
num_controls = sum(num_control_pre_aspect)
fp16 = True
# Read the mapping inside a `with` block so the file handle is closed as soon as
# the JSON has been parsed.
with open("../aux_files/control_token_mapping_3aspect.json", encoding="utf-8") as mapping_file:
    class_control_mapping = json.load(mapping_file)
used_control_mappings = {}
extra_control_tokens = {"[speaker1]", "[speaker2]"}
for col in control_columns:
    assert col in class_control_mapping
    # Presumably each per-column mapping goes from control-token string to class
    # id, so the inverted mapping built below looks something like
    # {0: "", 1: "[positive]", 2: "[negative]"} -- TODO confirm against the JSON file.
    for k in class_control_mapping[col]:
        if k != "" and k.count("[") == 1:  # do not add composed tokens like "[-positive][-negative]"
            extra_control_tokens.add(k)
    mapping0 = {v: k for k, v in class_control_mapping[col].items()}
    used_control_mappings[col] = mapping0


In [None]:
def preprocess_function(examples):
    """Tokenize a batch of examples and attach a multi-hot control vector to each.

    Rows whose input or target is missing/empty are dropped. For every kept row
    a vector of length ``num_controls`` is built: each aspect in
    ``control_columns`` owns a contiguous slice (sizes given by
    ``num_control_pre_aspect``) and the row's label for that aspect sets exactly
    one position of that slice to 1.

    Args:
        examples: Batch dict mapping column names to equally long lists; it must
            contain ``input_column``, ``target_column`` and every column listed
            in ``control_columns``.

    Returns:
        The tokenizer output for the inputs, extended with ``"labels"`` (target
        token ids, padding replaced by -100 when configured) and ``"controls"``.

    Raises:
        ValueError: If a control label lies outside ``[0, num_control)`` for its
            aspect. Without this check such a label would silently set a bit
            that belongs to a different aspect.
    """
    inputs, targets, controls = [], [], []
    for i in range(len(examples[input_column])):
        # Skip pairs where at least one side is None or empty.
        if not (examples[input_column][i] and examples[target_column][i]):
            continue
        inputs.append(examples[input_column][i])
        targets.append(examples[target_column][i])
        curr_control = [0] * num_controls
        offset = 0  # start of the current aspect's slice inside curr_control
        for control_column, num_control in zip(control_columns, num_control_pre_aspect):
            control_label = int(examples[control_column][i])
            if not 0 <= control_label < num_control:
                raise ValueError(
                    f"{control_column} label {control_label} at batch row {i} "
                    f"is outside the valid range [0, {num_control})"
                )
            curr_control[offset + control_label] = 1
            offset += num_control
        controls.append(curr_control)

    # Inputs are truncated on the configured side (set to "left" in this notebook).
    tokenizer.truncation_side = input_truncation_side
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Targets are always truncated on the right.
    tokenizer.truncation_side = 'right'
    labels = tokenizer(targets, max_length=max_target_length, padding=padding, truncation=True)

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length" and ignore_pad_token_for_loss:
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    model_inputs["controls"] = controls
    return model_inputs


In [None]:
# Run preprocess_function over the test split in batches with 4 worker
# processes; the original columns are removed and the cache is bypassed.
predict_dataset = raw_datasets["test"].map(
    function=preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=column_names,
    load_from_cache_file=False,
    desc="Running tokenizer on prediction dataset",
)

In [None]:
# Collator setup: labels are padded with -100 when padding is to be ignored in
# the loss, otherwise with the tokenizer's own pad id.
if ignore_pad_token_for_loss:
    label_pad_token_id = -100
else:
    label_pad_token_id = tokenizer.pad_token_id
data_collator = DataCollatorForMultiBinaryDirector(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    # Pad to a multiple of 8 only when fp16 is enabled.
    pad_to_multiple_of=(None if not fp16 else 8),
)

In [None]:
# Batches of 32 examples assembled by `data_collator`; no shuffle argument is
# passed, so DataLoader's default ordering applies.
data_loader = DataLoader(
    dataset=predict_dataset,
    batch_size=32,
    collate_fn=data_collator,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Output folder inside the checkpoint directory. The trailing slash is kept
# because the file name is later appended to `out_dir` by string concatenation.
out_dir = os.path.join(model_path, "infer_sampling/")
os.makedirs(name=out_dir, exist_ok=True)

In [None]:
# NOTE(review): the value assigned is -1 while the original remark below talks
# about "+1 as weight" -- confirm how BartDASC interprets clf_infer_weight.
model.clf_infer_weight = -1  # no prior weighting, use +1 as weight
all_outputs = []
for batch in tqdm(data_loader):
    # Nucleus sampling (top_p=0.5), one sequence per input, at most 128 tokens.
    generated_ids = model.generate(
        batch["input_ids"].to(model.device),
        controls=batch["controls"].to(model.device),
        do_sample=True,
        max_length=128,
        top_p=0.5,
        num_return_sequences=1,
    )
    all_outputs += tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
# Post-process every decoded string with the project's recover_sentence helper.
all_outputs = [recover_sentence(text) for text in all_outputs]
# One generation per line, in data-loader order.
with open(os.path.join(out_dir, "topp_0.5.txt"), "w", encoding="utf-8") as f:
    f.write("\n".join(all_outputs))