In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cd "/content/drive/My Drive/dstc11-track5/"

In [3]:
!pip install nltk==3.6.6
!pip install numpy==1.22.0
!pip install rouge_score==0.1.2
!pip install scikit_learn==1.1.1
!pip install sentencepiece==0.1.96
!pip install strsimpy==0.2.1
!pip install summ_eval==0.892
!pip install tensorboard==2.9.0
!pip install tensorboardX==2.5
!pip install torch==1.13.1
!pip install tqdm==4.62.3
!pip install transformers==4.20.1
!python -m nltk.downloader 'punkt'
!python -m nltk.downloader 'wordnet'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting nltk==3.6.6
  Downloading nltk-3.6.6-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m20.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: nltk
  Attempting uninstall: nltk
    Found existing installation: nltk 3.8.1
    Uninstalling nltk-3.8.1:
      Successfully uninstalled nltk-3.8.1
Successfully installed nltk-3.6.6
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy==1.22.0
  Downloading numpy-1.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.8/16.8 MB[0m [31m70.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.22.4
    Uninstalling numpy-1.22.4:
     

In [4]:
import json
import os
import re
import logging
import random
import copy

from collections import defaultdict
from itertools import chain

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
from tqdm import tqdm

logger = logging.getLogger(__name__)

SPECIAL_TOKENS = {
    "additional_special_tokens": ["<speaker1>", "<speaker2>", "<knowledge_sep>", "<knowledge_tag>"],
}

# utils/data

In [5]:
RE_ART = re.compile(r'\b(a|an|the)\b')
RE_PUNC = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')

def remove_articles(_text):
    return RE_ART.sub(' ', _text)

def white_space_fix(_text):
    return ' '.join(_text.split())

def remove_punc(_text):
    return RE_PUNC.sub(' ', _text)  # convert punctuation to spaces

def lower(_text):
    return _text.lower()

def normalize(text):
    """Lower text and remove punctuation, articles and extra whitespace. """
    return white_space_fix(remove_articles(remove_punc(lower(text))))

def pad_ids(arrays, padding, max_length=-1):
    if max_length < 0:
        max_length = max(list(map(len, arrays)))

    arrays = [
        array + [padding] * (max_length - len(array))
        for array in arrays
    ]

    return arrays


def truncate_sequences(sequences, max_length):
    words_to_cut = sum(list(map(len, sequences))) - max_length
    if words_to_cut <= 0:
        return sequences

    while words_to_cut > len(sequences[0]):
        words_to_cut -= len(sequences[0])
        sequences = sequences[1:]

    sequences[0] = sequences[0][words_to_cut:]
    return sequences


In [6]:
def write_selection_preds(dataset_walker, output_file, data_infos, sorted_pred_ids, topk=None, all_preds=None):
    """ Write results of knowledge selection to output_file """
    assert topk is None or topk > 0, f"topK must be set as None or a positive integer, but it is set as {topk}"
    # Flatten the data_infos
    data_infos = [
        {
            "dialog_id": info["dialog_ids"][i],
            "candidate_keys": info["candidate_keys"][i],
        }
        for info in data_infos
        for i in range(len(info["dialog_ids"]))
    ]

    labels = [label for log, label in dataset_walker]
    new_labels = [{"target": False}] * len(dataset_walker)
    # Update the dialogs with selected knowledge
    for info, sorted_pred_id, all_pred in zip(data_infos, sorted_pred_ids, all_preds):
        dialog_id = info["dialog_id"]
        candidate_keys = info["candidate_keys"]

        snippets = []
        sorted_pred_id_iter = sorted_pred_id[:topk] if topk is not None else sorted_pred_id
        for pred_id in sorted_pred_id_iter:
            selected_cand = candidate_keys[pred_id]
            domain, entity_id, doc_id = selected_cand.split("__")
            snippet = {
                "domain": domain,
                "entity_id": "*" if entity_id == "*" else int(entity_id),
                "doc_type": 'review' if '-' in doc_id else 'faq',
                "doc_id": doc_id,
            }
            if snippet['doc_type'] == 'review':
                doc_id, sent_id = doc_id.split('-')
                snippet['doc_id'], snippet['sent_id'] = int(doc_id), int(sent_id)
            else:
                snippet['doc_id'] = int(snippet['doc_id'])
            snippets.append(snippet)

        assert len(candidate_keys) == len(all_pred)
        new_label = {"target": True, "knowledge": snippets}
        label = labels[dialog_id]
        if label is None:
            label = new_label
        else:
            label = label.copy()
            if "response_tokenized" in label:
                label.pop("response_tokenized")
            label.update(new_label)
        new_labels[dialog_id] = label

    if os.path.dirname(output_file) and not os.path.exists(os.path.dirname(output_file)):
        os.makedirs(os.path.dirname(output_file))

    with open(output_file, "w") as jsonfile:
        logger.info("Writing predictions to {}".format(output_file))
        json.dump(new_labels, jsonfile, indent=2)

# scripts

In [7]:
#scripts/dataset_walker
class DatasetWalker(object):
    def __init__(self, dataset, dataroot, labels=False, labels_file=None, incl_knowledge=False):
        path = os.path.join(os.path.abspath(dataroot))
            
        if dataset not in ['train', 'val']:
            raise ValueError('Wrong dataset name: %s' % (dataset))

        logs_file = os.path.join(path, dataset, 'logs.json')
        with open(logs_file, 'r') as f:
            self.logs = json.load(f)

        self.labels = None

        if labels is True:
            if labels_file is None:
                labels_file = os.path.join(path, dataset, 'labels.json')

            with open(labels_file, 'r') as f:
                self.labels = json.load(f)

        self._incl_knowledge = incl_knowledge
        if self._incl_knowledge is True:
            # knowledge_reader 수정
            #self._knowledge = knowledge_reader(dataroot)
            self._knowledge = KnowledgeReader(dataroot)

    def __iter__(self):
        if self.labels is not None:
            for log, label in zip(self.logs, self.labels):
                if self._incl_knowledge is True and label['target'] is True:
                    for idx, snippet in enumerate(label['knowledge']):
                        domain = snippet['domain']
                        entity_id = snippet['entity_id']
                        doc_type = snippet['doc_type']
                        doc_id = snippet['doc_id']

                        if doc_type == 'review':
                            sent_id = snippet['sent_id']                            
                            sent = self._knowledge.get_review_sent(domain, entity_id, doc_id, sent_id)
                            label['knowledge'][idx]['sent'] = sent
                            
                        elif doc_type == 'faq':
                            doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id)
                            question = doc['question']
                            answer = doc['answer']

                            label['knowledge'][idx]['question'] = question
                            label['knowledge'][idx]['answer'] = answer
                
                yield(log, label)
        else:
            for log in self.logs:
                yield(log, None)

    def __len__(self, ):
        return len(self.logs)

In [8]:
#scripts/knowledge_reader
class KnowledgeReader(object):
    def __init__(self, dataroot, knowledge_file='knowledge.json'):
        path = os.path.join(os.path.abspath(dataroot))

        with open(os.path.join(path, knowledge_file), 'r') as f:
            self.knowledge = json.load(f)

    def get_domain_list(self):
        return list(self.knowledge.keys())

    def get_entity_list(self, domain):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name")

        entity_ids = []
        for entity_id in self.knowledge[domain].keys():
            entity_ids.append(int(entity_id))

        result = []
        for entity_id in sorted(entity_ids):
            entity_name = self.knowledge[domain][str(entity_id)]['name']
            result.append({'id': entity_id, 'name': entity_name})

        return result

    def get_entity_name(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        result = self.knowledge[domain][str(entity_id)]['name'] or None

        return result

    def get_faq_doc_ids(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)
        
        result = []

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_obj = self.knowledge[domain][str(entity_id)]
        for doc_id, doc_obj in entity_obj['faqs'].items():
            result.append(doc_id)

        return result

    def get_faq_doc(self, domain, entity_id, doc_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_name = self.get_entity_name(domain, entity_id)

        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['faqs']:
            raise ValueError("invalid doc id: %s" % str(doc_id))

        doc_obj = self.knowledge[domain][str(entity_id)]['faqs'][str(doc_id)]
        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'question': doc_obj['question'], 'answer': doc_obj['answer']}

        return result

    def get_review_doc_ids(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        result = []
        
        entity_obj = self.knowledge[domain][str(entity_id)]
        for doc_id, doc_obj in entity_obj['reviews'].items():
            result.append(doc_id)

        return result

    def get_review_doc(self, domain, entity_id, doc_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_name = self.get_entity_name(domain, entity_id)

        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']:
            raise ValueError("invalid doc id: %s" % str(doc_id))
        
        doc_obj = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]
        
        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'sentences': doc_obj['sentences']}
        if 'traveler_type' in doc_obj:
            result['traveler_type'] = doc_obj['traveler_type']
        
        if 'dishes' in doc_obj:
            result['dishes'] = doc_obj['dishes']

        if 'drinks' in doc_obj:
            result['drinks'] = doc_obj['drinks']

        return result
    
    def get_review_sent(self, domain, entity_id, doc_id, sent_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))
        
        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']:
            raise ValueError("invalid doc id: %s" % str(doc_id))

        if str(sent_id) not in self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences']:
            raise ValueError("invalid sentence id: %s" % str(sent_id))

        result = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences'][str(sent_id)]

        return result

# dataset

In [9]:
task = "selection"
dataroot = '/content/drive/MyDrive/dstc11-track5/data'
negative_sample_method = 'oracle'
knowledge_file = 'knowledge.json'
debug = 0
knowledge_max_tokens = 256
history_max_tokens = 256 
history_max_utterances = 1000000
n_candidates = 2

class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, split_type, labels=True, labels_file=None):
        self.dataroot = dataroot
        self.tokenizer = tokenizer
        self.split_type = split_type
        self.task = task
        self.negative_sample_method = negative_sample_method

        self.cls = self.tokenizer.cls_token_id
        self.sep = self.tokenizer.sep_token_id
        self.bos = self.tokenizer.bos_token_id
        self.eos = self.tokenizer.eos_token_id
        self.pad = self.tokenizer.pad_token_id
        self.SPECIAL_TOKENS = SPECIAL_TOKENS

        self.speaker1, self.speaker2, self.knowledge_sep, self.knowledge_tag = self.tokenizer.convert_tokens_to_ids(
            self.SPECIAL_TOKENS["additional_special_tokens"]
        )
        self.knowledge_sep_token = self.SPECIAL_TOKENS["additional_special_tokens"][2]
        self.dataset_walker = DatasetWalker(split_type, labels=labels, dataroot=self.dataroot, labels_file=labels_file)
        self.dialogs = self._prepare_conversations()
        self.knowledge_reader = KnowledgeReader(self.dataroot, knowledge_file)
        self.snippets = self._prepare_knowledge()
        self._create_examples()

        self.debug = debug
        self.knowledge_max_tokens = knowledge_max_tokens
        self.history_max_utterances = history_max_utterances
        self.history_max_tokens = history_max_tokens
        self.n_candidates = n_candidates


    def _prepare_conversations(self):
        """ Tokenize and encode the dialog data """
        logger.info("Tokenize and encode the dialog data")
        tokenized_dialogs = []
        for i, (log, label) in enumerate(tqdm(self.dataset_walker, disable=False, desc='tokenizing...')):
            dialog = {}
            dialog["id"] = i
            dialog["log"] = log
            if label is not None:
                if "response" in label:
                    label["response_tokenized"] = self.tokenizer.convert_tokens_to_ids(
                        self.tokenizer.tokenize(label["response"])
                    )
            dialog["label"] = label
            tokenized_dialogs.append(dialog)
        return tokenized_dialogs

    def _prepare_knowledge(self):
        """ Tokenize and encode the knowledge snippets """
        self.knowledge_docs = self._get_snippet_list()

        tokenized_snippets = defaultdict(dict)
        for snippet_id, snippet in enumerate(self.knowledge_docs):
            key = "{}__{}__{}".format(snippet["domain"], str(snippet["entity_id"]) or "", snippet["doc_id"])
            knowledge = self._knowledge_to_string(snippet["doc"], name=snippet["entity_name"] or "")

            tokenized_knowledge = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(knowledge))
            tokenized_snippets[key]['token_ids'] = tokenized_knowledge[:256] # knowledge_max_tokens : 256
        return tokenized_snippets

    def _get_snippet_list(self):
        """ Get all knowledge snippets in the dataset """
        result = []
        for domain in self.knowledge_reader.get_domain_list():
            for entity_id in self.knowledge_reader.knowledge[domain].keys():
                for review_doc_id in self.knowledge_reader.get_review_doc_ids(domain, entity_id):
                    review_doc = self.knowledge_reader.get_review_doc(domain, entity_id, review_doc_id)
                    for review_sent_id, review_sent in review_doc['sentences'].items():
                        result.append(
                            {'domain': domain, 'entity_id': entity_id, 'entity_name': review_doc['entity_name'],
                             'doc_id': f"{review_doc_id}-{review_sent_id}",
                             'doc': {'body': review_sent}})
                for faq_doc_id in self.knowledge_reader.get_faq_doc_ids(domain, entity_id):
                    faq_doc = self.knowledge_reader.get_faq_doc(domain, entity_id, faq_doc_id)
                    result.append({'domain': domain, 'entity_id': entity_id, 'entity_name': faq_doc['entity_name'],
                                   'doc_id': faq_doc_id,
                                   'doc': {'body': f"{faq_doc['question']} {faq_doc['answer']}"}})
        return result

    def _knowledge_to_string(self, doc, name=""):
        """ Convert a knowledge snippet to a string """
        doc_body = f"{name.title()}: {doc['body']}"
        return doc_body

    def _create_examples(self):
        """ Creating examples for model training and evaluation """
        logger.info("Creating examples")
        self.examples = []
        token_len, truncated_len = [], []
        for dialog in tqdm(self.dialogs, disable=False, desc='creating examples'):
            #if self.debug > 0 and len(self.examples) >= self.debug:
            #    break
            dialog_id = dialog["id"]
            label = dialog["label"]

            dialog = dialog["log"]
            if label is None:
                # This will only happen when running knowledge-seeking turn detection on test data
                # So we create dummy target here
                label = {"target": False}

            target = label["target"]

            if not target and self.task != "detection":
                # we only care about non-knowledge-seeking turns in turn detection task
                continue

            # Turn Embedding 수정하기!!!
            history = [
                self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(turn["text"]))
                for turn in dialog
            ]
            token_len.append(len(history))
            
            gt_resp = label.get("response", "")
            tokenized_gt_resp = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(gt_resp))

            # apply history threshold at an utterance-level (a large value can be used to nullify its effect)
            truncated_history = history[-1000000:] #history_max_utterances : 1000000
            #**************************************
            #**************************************
            #**************************************

            # perform token-level truncation of history from the left 
            truncated_history = truncate_sequences(truncated_history, 256) #history_max_tokens : 512로 수정하였음
            truncated_len.append(len(truncated_history))

            if target:
                knowledge_keys = []
                knowledge_candidates = defaultdict(lambda: 0)
                used_knowledge = []
                knowledge_prefix_visited = set()

                if "knowledge" not in label:
                    raise ValueError("Please run entity matching before running knowledge selection")

                label_knowledge = label["knowledge"]

                for knowledge in label_knowledge:
                    if not (self.task == 'selection' and eval_only):
                        if knowledge['doc_type'] == 'review':
                            knowledge_key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}-{knowledge['sent_id']}"
                        else:
                            knowledge_key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}"

                    # find snippets with same entity as candidates
                    prefix = "{}__{}".format(knowledge["domain"], knowledge["entity_id"])
                    if prefix not in knowledge_prefix_visited:
                        knowledge_prefix_visited.add(prefix)
                        _knowledge_candidates = [
                            cand
                            for cand in self.snippets.keys()
                            if "__".join(cand.split("__")[:-1]) == prefix
                        ]

                        for _knowledge_cand_idx, _knowledge_cand in enumerate(_knowledge_candidates):
                            knowledge_candidates[_knowledge_cand] = 1
                    if self.split_type == "train" and self.negative_sample_method == "oracle":
                        # if there's not enough candidates during training, we just skip this example
                        if len(knowledge_candidates) < 2 or len(knowledge_candidates) <= len(label["knowledge"]): #n_candidates : 2
                            logger.info("Not enough candidates. Skip this example...")
                            continue

                    if not (self.task == 'selection' and eval_only):
                        used_knowledge.append(
                            self.snippets[knowledge_key]['token_ids'][:256]) # knowledge_max_tokens : 256
                        knowledge_keys.append(knowledge_key)
                knowledge_candidates = [k for k, v in knowledge_candidates.items()]

            else:
                knowledge_candidates = None
                used_knowledge = []
                knowledge_keys = []

            self.examples.append({
                "history": truncated_history,
                "knowledge": used_knowledge,
                "knowledge_keys": knowledge_keys,
                "candidates": knowledge_candidates,
                "response": tokenized_gt_resp,
                "response_text": gt_resp,
                "label": label,
                "knowledge_seeking": target,
                "dialog_id": dialog_id
            })
        print(max(token_len), len(token_len))
        print(token_len)

        print(max(truncated_len), len(truncated_len))
        print(truncated_len)

    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.examples)

In [10]:
class KnowledgeSelectionDataset(BaseDataset):
    def __init__(self, tokenizer, split_type, labels=True, labels_file=None):
        super(KnowledgeSelectionDataset, self).__init__(tokenizer, split_type, labels, labels_file)

        if self.negative_sample_method not in ["all", "mix", "oracle"]:
            # Negative sampling method for knowledge selection
            # all: use all knowledge snippets of all entities as candidates
            # oracle: use all knowledge snippets of oracle entities as candidates
            # mix: use oracle candidates & equally sized candidates sampled from other entities
            raise ValueError(
                "negative_sample_method must be all, mix, or oracle, got %s" % self.negative_sample_method)

    def _knowledge_to_string(self, doc, name=""):
        """ convert a knowlege snippet to a string """
        join_str = " %s " % self.knowledge_sep_token
        doc_body = doc['body']
        knowledge_string = join_str.join([name.title(), doc_body])
        return knowledge_string

    def __getitem__(self, index):
        example = self.examples[index]

        this_inst = {
            "dialog_id": example["dialog_id"],
            "input_ids": [],
            "token_type_ids": []
        }

        if self.split_type != "train":
            # if eval_all_snippets is set, we use all snippets as candidates with no sampling
            if self.eval_all_snippets:
                candidates = list(self.snippets.keys())
            else:
                candidates = example["candidates"]
        else:
            if self.build_input_from_segmentsnegative_sample_method == "all":
                candidates = list(self.snippets.keys())
            elif self.negative_sample_method == "mix":
                candidates = example["candidates"] + random.sample(list(self.snippets.keys()),
                                                                   k=len(example["candidates"]))
            elif self.negative_sample_method == "oracle":
                candidates = example["candidates"]
            else:  # although we have already checked for this, still adding this here to be sure
                raise ValueError(
                    "negative_sample_method must be all, mix, or oracle, got %s" % self.negative_sample_method)

        candidate_keys = candidates
        this_inst["candidate_keys"] = candidate_keys
        candidates = [self.snippets[cand_key]['token_ids'] for cand_key in candidates]

        if self.split_type == "train":
            candidates = self._shrink_label_cands(example["knowledge"], candidates)

        label_idx = [candidates.index(knowledge) for knowledge in example["knowledge"]]

        this_inst["label_idx"] = label_idx
        for cand in candidates:
            instance, _ = self.build_input_from_segments(
                cand,
                example["history"]
            )
            this_inst["input_ids"].append(instance["input_ids"])
            this_inst["token_type_ids"].append(instance["token_type_ids"])

        return this_inst

    def build_input_from_segments(self, knowledge, history):
        """ Build a sequence of input from 2 segments: knowledge and history"""
        instance = {}

        sequence = [[self.cls]] + history
        sequence_with_speaker = [
            [self.speaker1 if (len(sequence) - i) % 2 == 0 else self.speaker2] + s
            for i, s in enumerate(sequence[1:])
        ]
        sequence_with_speaker = list(chain(*sequence_with_speaker))

        sequence0 = [self.cls] + sequence_with_speaker + [self.sep]
        sequence1 = knowledge + [self.sep]

        if 'roberta' in str(type(self.tokenizer)):
            sequence0 += [self.sep]
        instance["input_ids"] = sequence0 + sequence1
        instance["token_type_ids"] = [0 for _ in sequence0] + [1 for _ in sequence1]
        return instance, sequence

    def _shrink_label_cands(self, label, candidates):
        """ remove positive knowledge snippets from the candidates """
        shrunk_label_cands = candidates.copy()
        for l in label:
            if l in shrunk_label_cands:
                shrunk_label_cands.remove(l)
        sample_size = min(len(label), len(shrunk_label_cands))
        shrunk_label_cands = random.sample(shrunk_label_cands, k=sample_size)

        shrunk_label_cands.extend(label)
        random.shuffle(shrunk_label_cands)
        return shrunk_label_cands

    def collate_fn(self, batch):
        input_ids = [ids for ins in batch for ids in ins["input_ids"]]
        token_type_ids = [ids for ins in batch for ids in ins["token_type_ids"]]
        label_idx = [1 if i in ins['label_idx'] else 0 for ins in batch for i in range(len(ins['input_ids']))]
        data_info = {
            "dialog_ids": [ins["dialog_id"] for ins in batch],
            "candidate_keys": [ins["candidate_keys"] for ins in batch]
        }

        input_ids = torch.tensor(pad_ids(input_ids, self.pad))
        attention_mask = 1 - (input_ids == self.pad).int()
        token_type_ids = torch.tensor(pad_ids(token_type_ids, 0))
        label_idx = torch.tensor(label_idx)
        return input_ids, token_type_ids, attention_mask, label_idx, data_info


In [11]:
def set_seed(num):
    random.seed(num)
    np.random.seed(num)
    torch.manual_seed(num)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(num)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Set seed
set_seed(42)

cuda:0


# utils/model

In [12]:
max_candidates_per_forward_train = 4
max_candidates_per_forward_eval = 16

def run_batch_selection_train(model, batch, **kwargs):
    """ Run batch knowledge selection during training time """
    candidates_per_forward = max_candidates_per_forward_train
    batch = tuple(input_tensor.to(device) for input_tensor in batch if isinstance(input_tensor, torch.Tensor))
    input_ids, token_type_ids, attention_mask, labels = batch
    for index in range(0, input_ids.size(0), candidates_per_forward):
        model_outputs = model(
            input_ids=input_ids[index:index + candidates_per_forward],
            token_type_ids=None if model.base_model_prefix in ['roberta'] else
                           token_type_ids[index:index + candidates_per_forward],
            attention_mask=attention_mask[index:index + candidates_per_forward],
            labels=labels[index:index + candidates_per_forward],
        )
        loss, logits = model_outputs[0], model_outputs[1]
        yield loss, logits, None


def run_batch_selection_eval(model, batch, **kwargs):
    """ Run batch knowledge selection during evaluation time """
    # return: loss, logits, labels
    candidates_per_forward = max_candidates_per_forward_eval
    batch = tuple(input_tensor.to(device) for input_tensor in batch if isinstance(input_tensor, torch.Tensor))
    input_ids, token_type_ids, attention_mask, labels = batch
    original_labels = copy.deepcopy(labels)

    all_logits = []
    eval_loss = 0
    for index in range(0, input_ids.size(0), candidates_per_forward):
        model_outputs = model(
            input_ids=input_ids[index:index + candidates_per_forward],
            token_type_ids=None if model.base_model_prefix in ['roberta'] else
                           token_type_ids[index:index + candidates_per_forward],
            attention_mask=attention_mask[index:index + candidates_per_forward],
            labels=labels[index:index + candidates_per_forward]
        )
        eval_loss += model_outputs.loss * len(input_ids[index:index + candidates_per_forward])
        logits = model_outputs.logits
        all_logits.append(logits.detach())
    all_logits = torch.cat(all_logits, dim=0)
    return eval_loss, all_logits, original_labels


# PolyEncoder Model


In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel, BertModel

class PolyEncoder(BertPreTrainedModel):
    # *input : tuple 형태, **kwargs : dictation 형태
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = kwargs['bert']
        self.poly_m = kwargs['poly_m']
        self.poly_code_embeddings = nn.Embedding(self.poly_m, config.hidden_size)
        # https://github.com/facebookresearch/ParlAI/blob/master/parlai/agents/transformer/polyencoder.py#L355
        torch.nn.init.normal_(self.poly_code_embeddings.weight, config.hidden_size ** -0.5)

    def dot_attention(self, q, k, v):
        # q: [bs, poly_m, dim] or [bs, res_cnt, dim]
        # k=v: [bs, length, dim] or [bs, poly_m, dim]
        attn_weights = torch.matmul(q, k.transpose(2, 1)) # [bs, poly_m, length]
        attn_weights = F.softmax(attn_weights, -1)
        output = torch.matmul(attn_weights, v) # [bs, poly_m, dim]
        return output

    def forward(self, context_input_ids, context_input_masks,
                            responses_input_ids, responses_input_masks, labels=None):
        # during training, only select the first response
        # we are using other instances in a batch as negative examples
        if labels is not None:
            responses_input_ids = responses_input_ids[:, 0, :].unsqueeze(1)
            responses_input_masks = responses_input_masks[:, 0, :].unsqueeze(1)
        batch_size, res_cnt, seq_length = responses_input_ids.shape # res_cnt is 1 during training

        # context encoder
        ctx_out = self.bert(context_input_ids, context_input_masks)[0]  # [bs, length, dim]
        poly_code_ids = torch.arange(self.poly_m, dtype=torch.long).to(context_input_ids.device)
        poly_code_ids = poly_code_ids.unsqueeze(0).expand(batch_size, self.poly_m)
        poly_codes = self.poly_code_embeddings(poly_code_ids) # [bs, poly_m, dim]
        embs = self.dot_attention(poly_codes, ctx_out, ctx_out) # [bs, poly_m, dim]

        # response encoder
        responses_input_ids = responses_input_ids.view(-1, seq_length)
        responses_input_masks = responses_input_masks.view(-1, seq_length)
        cand_emb = self.bert(responses_input_ids, responses_input_masks)[0][:,0,:] # [bs, dim]
        cand_emb = cand_emb.view(batch_size, res_cnt, -1) # [bs, res_cnt, dim]

        # merge
        if labels is not None:
            # we are recycling responses for faster training
            # we repeat responses for batch_size times to simulate test phase
            # so that every context is paired with batch_size responses
            cand_emb = cand_emb.permute(1, 0, 2) # [1, bs, dim]
            cand_emb = cand_emb.expand(batch_size, batch_size, cand_emb.shape[2]) # [bs, bs, dim]
            ctx_emb = self.dot_attention(cand_emb, embs, embs).squeeze() # [bs, bs, dim]
            dot_product = (ctx_emb*cand_emb).sum(-1) # [bs, bs]
            mask = torch.eye(batch_size).to(context_input_ids.device) # [bs, bs]
            loss = F.log_softmax(dot_product, dim=-1) * mask
            loss = (-loss.sum(dim=1)).mean()
            return loss
        else:
            ctx_emb = self.dot_attention(cand_emb, embs, embs) # [bs, res_cnt, dim]
            dot_product = (ctx_emb*cand_emb).sum(-1)
            return dot_product

In [None]:
from transformers import BertModel, BertConfig, BertTokenizer, BertTokenizerFast
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

dataset_class, model_class, run_batch_fn_train, run_batch_fn_eval = KnowledgeSelectionDataset, AutoModelForSequenceClassification, run_batch_selection_train, run_batch_selection_eval

eval_only = False
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(SPECIAL_TOKENS)
tokenizer.model_max_length = min(1024, tokenizer.model_max_length)
print(tokenizer.model_max_length)

# Main

In [15]:
import argparse
import logging
import os
import random
import json

from typing import Dict, Tuple
from argparse import Namespace

import numpy as np
import torch
from sklearn.metrics import recall_score, precision_score, average_precision_score, classification_report, f1_score

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm, trange
from transformers import (
    AutoConfig,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
    BartForConditionalGeneration,
    AutoModelForSequenceClassification,
)

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

------------------------------------

In [None]:
# deverta-v3-base - max : 1024
dataset_class, model_class, run_batch_fn_train, run_batch_fn_eval = KnowledgeSelectionDataset, AutoModelForSequenceClassification, run_batch_selection_train, run_batch_selection_eval

eval_only = False
model_name = "microsoft/deberta-v3-base"

config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(SPECIAL_TOKENS)
tokenizer.model_max_length = min(1024, tokenizer.model_max_length)
print(tokenizer.model_max_length)

model = model_class.from_pretrained(model_name, config=config)
model.resize_token_embeddings(len(tokenizer))
model.to(device)

output_file = 'selection_result'
output_dir = '/content/drive/MyDrive/dstc11-track5/output'

train_dataset = KnowledgeSelectionDataset(tokenizer, split_type="train")
eval_dataset = KnowledgeSelectionDataset(tokenizer, split_type="val")  # main difference is during evaluation, val need to go through all snippets

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


1024


Some weights of the model checkpoint at microsoft/deberta-v3-base were not used when initializing DebertaV2ForSequenceClassification: ['mask_predictions.classifier.bias', 'lm_predictions.lm_head.LayerNorm.weight', 'mask_predictions.LayerNorm.weight', 'lm_predictions.lm_head.bias', 'mask_predictions.classifier.weight', 'mask_predictions.dense.bias', 'lm_predictions.lm_head.dense.weight', 'lm_predictions.lm_head.dense.bias', 'mask_predictions.dense.weight', 'mask_predictions.LayerNorm.bias', 'lm_predictions.lm_head.LayerNorm.bias']
- This IS expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2ForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a

45 14768
[5, 9, 7, 13, 7, 3, 5, 11, 5, 5, 21, 3, 9, 5, 7, 15, 7, 15, 11, 5, 11, 3, 3, 13, 13, 5, 7, 7, 11, 15, 5, 17, 3, 11, 7, 5, 3, 5, 3, 7, 7, 11, 5, 9, 11, 3, 11, 9, 13, 7, 19, 11, 7, 7, 5, 13, 3, 5, 13, 11, 9, 15, 5, 11, 13, 11, 7, 11, 3, 15, 3, 7, 11, 7, 17, 5, 5, 5, 9, 7, 7, 9, 3, 9, 9, 7, 5, 9, 7, 9, 9, 5, 7, 5, 3, 3, 5, 5, 21, 9, 13, 19, 7, 11, 5, 9, 9, 3, 9, 7, 3, 15, 7, 5, 7, 9, 3, 15, 13, 17, 7, 11, 7, 15, 19, 15, 9, 17, 5, 15, 11, 7, 5, 11, 9, 9, 5, 7, 19, 13, 5, 15, 3, 5, 7, 15, 15, 9, 11, 9, 3, 3, 7, 7, 9, 5, 3, 11, 7, 13, 3, 7, 5, 9, 17, 7, 3, 9, 5, 7, 7, 9, 13, 13, 5, 9, 7, 5, 7, 7, 13, 7, 3, 7, 11, 5, 11, 7, 5, 7, 9, 3, 13, 17, 5, 3, 9, 9, 11, 17, 9, 3, 7, 3, 17, 7, 13, 7, 5, 5, 9, 5, 3, 5, 5, 9, 3, 3, 19, 13, 7, 7, 5, 7, 5, 9, 3, 7, 5, 11, 3, 7, 3, 9, 7, 13, 19, 5, 19, 7, 3, 11, 11, 7, 3, 7, 7, 11, 3, 9, 11, 11, 7, 5, 9, 7, 5, 3, 7, 7, 3, 13, 3, 11, 11, 9, 3, 5, 11, 7, 13, 9, 9, 3, 5, 15, 15, 11, 9, 7, 7, 7, 9, 7, 9, 7, 19, 3, 13, 11, 3, 5, 5, 13, 5, 11, 3, 15, 3, 17

tokenizing...: 100%|██████████| 4173/4173 [00:00<00:00, 11396.38it/s]
creating examples: 100%|██████████| 4173/4173 [00:16<00:00, 248.42it/s]

27 2129
[5, 11, 5, 3, 7, 3, 7, 5, 11, 7, 7, 5, 11, 13, 11, 11, 3, 9, 7, 13, 11, 9, 9, 13, 15, 15, 15, 9, 3, 7, 9, 3, 7, 7, 17, 7, 3, 5, 7, 11, 19, 5, 9, 13, 3, 3, 15, 17, 11, 5, 7, 9, 13, 19, 7, 5, 11, 7, 13, 13, 9, 7, 5, 5, 3, 15, 11, 9, 19, 3, 11, 9, 9, 9, 11, 3, 13, 7, 3, 5, 11, 11, 15, 3, 13, 11, 11, 3, 3, 3, 5, 11, 11, 11, 3, 5, 7, 7, 9, 9, 9, 7, 11, 11, 3, 11, 13, 13, 3, 13, 7, 7, 7, 13, 3, 5, 11, 9, 5, 15, 21, 15, 17, 9, 13, 3, 9, 3, 9, 3, 9, 7, 9, 3, 3, 9, 9, 9, 7, 9, 11, 17, 7, 7, 5, 3, 5, 5, 7, 3, 11, 9, 9, 5, 5, 3, 13, 5, 11, 15, 11, 5, 7, 3, 3, 5, 13, 7, 7, 5, 13, 11, 3, 7, 7, 7, 7, 15, 3, 5, 15, 5, 7, 3, 13, 11, 5, 11, 9, 9, 13, 11, 15, 9, 13, 3, 9, 9, 3, 5, 5, 7, 15, 17, 17, 9, 9, 9, 13, 7, 11, 11, 7, 7, 5, 13, 15, 11, 9, 3, 5, 11, 7, 11, 3, 9, 5, 9, 3, 17, 23, 3, 5, 5, 5, 13, 3, 5, 5, 9, 5, 15, 17, 5, 11, 5, 15, 5, 11, 9, 15, 7, 15, 11, 5, 5, 13, 7, 9, 21, 5, 5, 11, 11, 11, 7, 9, 15, 5, 3, 7, 3, 5, 7, 11, 7, 7, 11, 9, 7, 7, 11, 5, 9, 5, 9, 3, 5, 11, 5, 5, 7, 5, 3, 11, 7,




In [None]:
def get_cls_report(y_true, y_pred):
    """ Get the report of precision, recall, and f1-score for a classification output """
    return {"precision": precision_score(y_true, y_pred, average=None, zero_division=0)[1],
            "recall": recall_score(y_true, y_pred, average=None, zero_division=0)[1],
            "f1-score": f1_score(y_true, y_pred, average=None, zero_division=0)[1]}

In [None]:
def get_eval_performance(eval_output_dir, eval_loss, all_preds, all_labels, desc):
    """ Get evaluation performance when the gold labels are available """
    if task == "selection":
        def get_eval_performance_selection(all_preds, all_labels, threshold):
            report_list, average_precison_list = [], []
            all_preds_micro, all_labels_micro = [], []
            for preds, labels in zip(all_preds, all_labels):
                preds = preds.reshape(-1)
                preds_labels = np.zeros_like(preds)
                preds_labels[preds > threshold] = 1
                average_precison_list.append(average_precision_score(labels, preds))
                all_preds_micro.extend(preds_labels)
                all_labels_micro.extend(labels)
                report_list.append(get_cls_report(labels, preds_labels))

            cls_report = get_cls_report(all_labels_micro, all_preds_micro)
            result = {"loss": eval_loss, "mean_ave_prec": np.mean(average_precison_list),
                      "micro_prec": cls_report['precision'],
                      "micro_recall": cls_report['recall'],
                      "micro_f1": cls_report['f1-score'],
                      "macro_prec": np.mean([report['precision'] for report in report_list]),
                      "macro_recall": np.mean([report['recall'] for report in report_list]),
                      "macro_f1": np.mean([report['f1-score'] for report in report_list]),
                      "val_measure": -1 * np.mean([report['f1-score'] for report in report_list]),
                      }
            return result

        best_result, best_threshold = {"val_measure": 0}, None
        for threshold in list(np.linspace(-5, 5, num=41)):
            result = get_eval_performance_selection(all_preds, all_labels, threshold)
            if result['val_measure'] < best_result['val_measure']:
                best_result, best_threshold = result, threshold
        best_result.update({"threshold": best_threshold})
        eval_threshold = best_threshold
        result = best_result

    else:
        raise ValueError("args.task not in ['generation', 'selection', 'detection'], got %s" % task)

    logger.info(str(result))

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "a") as writer:
        logger.info("***** Eval results %s *****" % desc)
        writer.write("***** Eval results %s *****\n" % desc)
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result


In [None]:
def evaluate(eval_dataset, model: PreTrainedModel, run_batch_fn, desc="") -> Dict:
    """ Model evaluation for knowledge seeking turn detection and knowledge selection
        Report evaluation results if gold labels are available
    """
    eval_output_dir = output_dir
    os.makedirs(eval_output_dir, exist_ok=True)

    # eval_batch_size for selection must be 1 to handle different number of candidates
    eval_batch_size = 1

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=eval_batch_size,
        collate_fn=eval_dataset.collate_fn
    )

    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    data_infos = []
    all_preds = []
    all_labels = []
    for batch in tqdm(eval_dataloader, desc="Evaluating", disable=False):
        with torch.no_grad():
            loss, logits, labels = run_batch_fn(model, batch)
            if task in ["selection", "detection"]:
                data_infos.append(batch[-1])
                all_preds.append((logits[:, 1] - logits[:, 0]).detach().cpu().numpy())
                all_labels.append(labels.detach().cpu().numpy())
            eval_loss += loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps

    if task == "selection":
        if output_file:
            eval_threshold = 0
            sorted_pred_ids = [np.argsort(logits.squeeze())[::-1][:(logits > eval_threshold).sum()] for logits in all_preds]
            write_selection_preds(eval_dataset.dataset_walker, output_file, data_infos, sorted_pred_ids,
                                  all_preds=all_preds)
    else:
        raise ValueError("args.task not in ['generation', 'selection', 'detection'], got %s" % task)

    if not eval_only:
        return get_eval_performance(eval_output_dir, eval_loss, all_preds, all_labels, desc)


In [None]:
def train(train_dataset, eval_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
          run_batch_fn_train, run_batch_fn_eval) -> Tuple[int, float]:
    """ Model training and evaluation """
    exp_name = ''
    log_dir = os.path.join("runs", exp_name) if exp_name else None
    tb_writer = SummaryWriter(log_dir)
    output_dir = log_dir

    train_batch_size = 4

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        # shuffle=True,
        sampler=train_sampler,
        batch_size=train_batch_size,
        collate_fn=train_dataset.collate_fn
    )

    gradient_accumulation_steps = 16
    num_train_epochs = 3
    learning_rate = 3e-5
    adam_epsilon = 1e-8
    warmup_steps = 500
    max_grad_norm = 1.0

    t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)
    if 0 < warmup_steps < 1:
        warmup_steps = int(warmup_steps * t_total)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )

    # Train!
    global_step = 0
    model.zero_grad()
    train_iterator = trange(
        0, int(num_train_epochs), desc="Epoch", disable=False
    )
    set_seed(42)  # for reproducibility
    val_loss = float('inf')

    for _ in train_iterator:
        local_steps = 0  # update step
        tr_loss = 0.0
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
        step = 0  # backward step
        total_log_loss = 0
        for _, batch in enumerate(epoch_iterator):
            model.train()
            for loss, _, _ in run_batch_fn_train(model, batch, global_step=global_step):
                step += 1

                total_log_loss += loss.item()

                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()

                if (step + 1) % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    local_steps += 1
                    epoch_iterator.set_postfix(Loss=tr_loss / local_steps)
                    total_log_loss = 0

        results = evaluate(eval_dataset, model, run_batch_fn_eval, desc=str(global_step))


        for key, value in results.items():
            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
        tb_writer.add_scalar("loss", tr_loss / local_steps, global_step)

        if results['val_measure'] < val_loss:
            logger.info(f"Find a smaller val loss measure {results['val_measure']}")
            val_loss = results['val_measure']
            # Save model checkpoint
            #save_model(output_dir, model, tokenizer)
        else:
            logger.info(f"The val loss measure {results['val_measure']} is larger than "
                        f"the smallest val loss {val_loss}, continue to train ... ")

    tb_writer.flush()
    tb_writer.close()

    return global_step, tr_loss / local_steps

In [None]:
# train 돌리는 코드 

# entity matching

In [None]:
import argparse
import os, json
import re
from difflib import SequenceMatcher as SM
from nltk.util import ngrams
from nltk.tokenize import word_tokenize
import string
from multiprocessing import Pool, cpu_count

from tqdm import tqdm

In [None]:
def fuzzy_extract(qs, ls, threshold):
    """ Match an entity name (qs) with an utterance (ls) using fuzzy matching """
    qs_length = len(qs.split())
    max_sim_val = 0
    max_sim_string = u""

    for ngram in ngrams(ls.split(), qs_length + int(.2 * qs_length)):
        ls_ngram = u" ".join(ngram)
        similarity = SM(None, ls_ngram, qs).ratio()
        if similarity > max_sim_val:
            max_sim_val = similarity
            max_sim_string = ls_ngram

    if max_sim_val > threshold:
        return max_sim_string, max_sim_val
    else:
        return None, max_sim_val


def check_substring_exist(qs, ls):
    """ Check if an entity mention (qs) is in an utterance (ls) """
    if len(qs.split()) > 1:
        return qs in ls
    else:
        return qs in word_tokenize(ls)


def entity_matching(entity, log):
    """ Match a single entity with a dialogue history """
    result = None
    max_fuzzy_score = 0

    for turn_id, obj in enumerate(log):
        flag = False

        utter = obj['text'].lower()
        if check_substring_exist(entity, utter):
            flag = True

        entity_names = all_entity_names_norm.get(entity, []) + [entity]
        for entity_name in entity_names:
            if check_substring_exist(entity_name, utter):
                flag = True

        # if substring exist, fuzzy_match_score = 1, otherwise fuzzy_match_score < 1
        for entity_name in entity_names:
            fuzzy_match_res, fuzzy_match_score = fuzzy_extract(entity_name, utter, 0.95)
            max_fuzzy_score = max(max_fuzzy_score, fuzzy_match_score)
            if fuzzy_match_res is not None:
                flag = True

        if flag is True:
            result = turn_id

    return result, max_fuzzy_score


def run_entity_matching(args):
    """ Run entity matching for a single instance """
    idx_, (log, label) = args
    if label['target'] is False:
        return None, None

    matching_res_ls = set()
    entity_scores = []
    for entity_tup in all_entity_names:
        entity_domain, entity_id, entity_name = entity_tup
        match_res, match_score = entity_matching(entity_name.lower(), log)
        entity_scores.append((entity_tup, match_score))
        if match_res is not None:
            matching_res_ls.add((entity_domain, entity_id, entity_name, match_res))
    matching_res_ls = sorted(list(matching_res_ls), key=lambda x: x[-1])

    result = []
    if len(matching_res_ls) > 0:
        latest_turn_w_entity = matching_res_ls[-1][-1]
        for entity_domain, entity_id, entity_name, turn_id in matching_res_ls:
            if turn_id == latest_turn_w_entity:
                result.append({'domain': entity_domain, 'entity_id': int(entity_id), 'entity_name': entity_name})
    else:
        entity_scores = sorted(entity_scores, key=lambda x: -x[1])
        entity_tup, match_score = entity_scores[0]
        entity_domain, entity_id, entity_name = entity_tup
        result.append({'domain': entity_domain, 'entity_id': int(entity_id), 'entity_name': entity_name})

    pred_entity_set = set([str(r['entity_id']) for r in result])
    return result, pred_entity_set

In [None]:
# read data
knowledge_file = os.path.join(dataroot, 'knowledge.json')
logs_file = os.path.join(dataroot + '/train', 'logs.json') # data_eval : test set에 대함
labels_file = "/content/drive/MyDrive/dstc11-track5/data/train/labels.json"

with open(logs_file, 'r') as f:
    logs = json.load(f)
with open(knowledge_file, 'r') as f:
    knowledges = json.load(f)
with open(labels_file, 'r') as f:
    labels = json.load(f)

# load entities and normalized entity mentions
all_entity_names = []
for domain, domain_dict in knowledges.items():
    if domain in ['train', 'taxi']:
        continue
    for doc_id, docs in domain_dict.items():
        all_entity_names.append((domain, doc_id, docs['name']))

norm_dict = "/content/drive/MyDrive/dstc11-track5/baseline/resources/entity_mapping.json"
with open(norm_dict, 'r') as fr:
    all_entity_names_norm = json.load(fr)

# match entities in parallel
results = []
pred_entity_sets = []
with Pool(processes=cpu_count()) as p:
    with tqdm(total=len(logs), desc='entity matching') as pbar:
        for result, pred_entity_set in p.imap(run_entity_matching, enumerate(zip(logs, labels))):
            if result is not None:
                results.append(result)
                pred_entity_sets.append(pred_entity_set)
            else:
                results.append(None)
            pbar.update()

# write the matched entities in output file
for label, result in zip(labels, results):
    if label['target'] is False:
        assert result is None
    else:
        label['knowledge'] = result

with open(output_dir + '/selection_train_em.json', 'w') as fw:
    json.dump(labels, fw, indent=4)

In [None]:
# read data
knowledge_file = os.path.join(dataroot, 'knowledge.json')
logs_file = os.path.join(dataroot + '/val', 'logs.json') # data_eval : test set에 대함
labels_file = "/content/drive/MyDrive/dstc11-track5/data/val/labels.json"

with open(logs_file, 'r') as f:
    logs = json.load(f)
with open(knowledge_file, 'r') as f:
    knowledges = json.load(f)
with open(labels_file, 'r') as f:
    labels = json.load(f)

# load entities and normalized entity mentions
all_entity_names = []
for domain, domain_dict in knowledges.items():
    if domain in ['train', 'taxi']:
        continue
    for doc_id, docs in domain_dict.items():
        all_entity_names.append((domain, doc_id, docs['name']))

norm_dict = "/content/drive/MyDrive/dstc11-track5/baseline/resources/entity_mapping.json"
with open(norm_dict, 'r') as fr:
    all_entity_names_norm = json.load(fr)

# match entities in parallel
results = []
pred_entity_sets = []
with Pool(processes=cpu_count()) as p:
    with tqdm(total=len(logs), desc='entity matching') as pbar:
        for result, pred_entity_set in p.imap(run_entity_matching, enumerate(zip(logs, labels))):
            if result is not None:
                results.append(result)
                pred_entity_sets.append(pred_entity_set)
            else:
                results.append(None)
            pbar.update()

# write the matched entities in output file
for label, result in zip(labels, results):
    if label['target'] is False:
        assert result is None
    else:
        label['knowledge'] = result

with open(output_dir + '/selection_val_em.json', 'w') as fw:
    json.dump(labels, fw, indent=4)

entity matching: 100%|██████████| 4173/4173 [41:52<00:00,  1.66it/s]
