In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!cd "/content/drive/My Drive/dstc11-track5/"

# Baseline/dataset.py

In [3]:
import json
import os
import logging
from collections import defaultdict
from itertools import chain

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
from tqdm import tqdm

logger = logging.getLogger(__name__)

SPECIAL_TOKENS = {
    "additional_special_tokens": ["<speaker1>", "<speaker2>", "<knowledge_sep>", "<knowledge_tag>"],
}

In [4]:
#/utils/data
def pad_ids(arrays, padding, max_length=-1):
    if max_length < 0:
        max_length = max(list(map(len, arrays)))

    arrays = [
        array + [padding] * (max_length - len(array))
        for array in arrays
    ]

    return arrays

In [5]:
#/utils/data
# sequences : 각 turn, max_length = 510으로 설정
def truncate_sequences(sequences, max_length): # generation.params : 60
    words_to_cut = sum(list(map(len, sequences))) - max_length
    if words_to_cut <= 0:
        return sequences

    while words_to_cut > len(sequences[0]):
        words_to_cut -= len(sequences[0])
        sequences = sequences[1:]

    sequences[0] = sequences[0][words_to_cut:]
    return sequences

In [6]:
#scripts/dataset_walker
class DatasetWalker(object):
    def __init__(self, dataset, dataroot, labels=False, labels_file=None, incl_knowledge=False):
        path = os.path.join(os.path.abspath(dataroot))
            
        if dataset not in ['train', 'val']:
            raise ValueError('Wrong dataset name: %s' % (dataset))

        logs_file = os.path.join(path, dataset, 'logs.json')
        with open(logs_file, 'r') as f:
            self.logs = json.load(f)

        self.labels = None

        if labels is True:
            if labels_file is None:
                labels_file = os.path.join(path, dataset, 'labels.json')

            with open(labels_file, 'r') as f:
                self.labels = json.load(f)

        self._incl_knowledge = incl_knowledge
        if self._incl_knowledge is True:
            # knowledge_reader 수정
            #self._knowledge = knowledge_reader(dataroot)
            self._knowledge = KnowledgeReader(dataroot)

    def __iter__(self):
        if self.labels is not None:
            for log, label in zip(self.logs, self.labels):
                if self._incl_knowledge is True and label['target'] is True:
                    for idx, snippet in enumerate(label['knowledge']):
                        domain = snippet['domain']
                        entity_id = snippet['entity_id']
                        doc_type = snippet['doc_type']
                        doc_id = snippet['doc_id']

                        if doc_type == 'review':
                            sent_id = snippet['sent_id']                            
                            sent = self._knowledge.get_review_sent(domain, entity_id, doc_id, sent_id)
                            label['knowledge'][idx]['sent'] = sent
                            
                        elif doc_type == 'faq':
                            doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id)
                            question = doc['question']
                            answer = doc['answer']

                            label['knowledge'][idx]['question'] = question
                            label['knowledge'][idx]['answer'] = answer
                
                yield(log, label)
        else:
            for log in self.logs:
                yield(log, None)

    def __len__(self, ):
        return len(self.logs)

In [7]:
#scripts/knowledge_reader
class KnowledgeReader(object):
    def __init__(self, dataroot, knowledge_file='knowledge.json'):
        path = os.path.join(os.path.abspath(dataroot))

        with open(os.path.join(path, knowledge_file), 'r') as f:
            self.knowledge = json.load(f)

    def get_domain_list(self):
        return list(self.knowledge.keys())

    def get_entity_list(self, domain):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name")

        entity_ids = []
        for entity_id in self.knowledge[domain].keys():
            entity_ids.append(int(entity_id))

        result = []
        for entity_id in sorted(entity_ids):
            entity_name = self.knowledge[domain][str(entity_id)]['name']
            result.append({'id': entity_id, 'name': entity_name})

        return result

    def get_entity_name(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        result = self.knowledge[domain][str(entity_id)]['name'] or None

        return result

    def get_faq_doc_ids(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)
        
        result = []

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_obj = self.knowledge[domain][str(entity_id)]
        for doc_id, doc_obj in entity_obj['faqs'].items():
            result.append(doc_id)

        return result

    def get_faq_doc(self, domain, entity_id, doc_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_name = self.get_entity_name(domain, entity_id)

        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['faqs']:
            raise ValueError("invalid doc id: %s" % str(doc_id))

        doc_obj = self.knowledge[domain][str(entity_id)]['faqs'][str(doc_id)]
        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'question': doc_obj['question'], 'answer': doc_obj['answer']}

        return result

    def get_review_doc_ids(self, domain, entity_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        result = []
        
        entity_obj = self.knowledge[domain][str(entity_id)]
        for doc_id, doc_obj in entity_obj['reviews'].items():
            result.append(doc_id)

        return result

    def get_review_doc(self, domain, entity_id, doc_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))

        entity_name = self.get_entity_name(domain, entity_id)

        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']:
            raise ValueError("invalid doc id: %s" % str(doc_id))
        
        doc_obj = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]
        
        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_name, 'doc_id': doc_id, 'sentences': doc_obj['sentences']}
        if 'traveler_type' in doc_obj:
            result['traveler_type'] = doc_obj['traveler_type']
        
        if 'dishes' in doc_obj:
            result['dishes'] = doc_obj['dishes']

        if 'drinks' in doc_obj:
            result['drinks'] = doc_obj['drinks']

        return result
    
    def get_review_sent(self, domain, entity_id, doc_id, sent_id):
        if domain not in self.get_domain_list():
            raise ValueError("invalid domain name: %s" % domain)

        if str(entity_id) not in self.knowledge[domain]:
            raise ValueError("invalid entity id: %s" % str(entity_id))
        
        if str(doc_id) not in self.knowledge[domain][str(entity_id)]['reviews']:
            raise ValueError("invalid doc id: %s" % str(doc_id))

        if str(sent_id) not in self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences']:
            raise ValueError("invalid sentence id: %s" % str(sent_id))

        result = self.knowledge[domain][str(entity_id)]['reviews'][str(doc_id)]['sentences'][str(sent_id)]

        return result

# DSTC9 불러오기

In [8]:
import csv

data_path = "/content/drive/My Drive/dstc11-track5/trans/"

with open(data_path + 'knowledge.csv', 'r', newline = '') as f :
  reader = csv.reader(f)
  for row in reader :
    dstc9_knowledge = row
print(dstc9_knowledge)

with open(data_path + 'history.csv', 'r', newline = '') as f :
  reader = csv.reader(f)
  for row in reader :
    dstc9_history = row
print(dstc9_history)

with open(data_path + 'response.csv', 'r', newline = '') as f :
  reader = csv.reader(f)
  for row in reader :
    dstc9_response = row
print(dstc9_response)

['The Orchard Garden Hotel does not serve breakfast.', 'The check-in time at Grant Hotel begins at 3 p.m.', "That's right. Fitzvilly has a gluten-free option.", 'Grant Hotel check-in time is from 3 p.m.', 'Yes, Fittsville has a gluten-free option.', 'For breakfast, guests can choose from a carte or continental option.', 'Off-road parking is available at Archway House.', 'Yes, the reservation is complete.', 'El Shaddai does not provide daily housekeeping services.', 'You can choose between aracart or continental for breakfast.', 'Breakfast is a light intercontinental meal served between 7:30 and 10:30', 'Smoking is not allowed at the Andrews Hotel.', 'El Shaddai does not provide daily housekeeping services.', 'City Stop Restaurant does not offer live music.', 'Extra beds are not available at the Layne Hotel.', 'Yes, you can use the high chair.', 'Bicycle parking provided', 'There is no live music in Tortellino', 'Yes, there is a lift in San Francisco, St. Regis.', 'Houses in India are a

In [9]:
task = "generation"
dataroot = 'data'
negative_sample_method = 'oracle'
knowledge_file = 'knowledge.json'
debug = 0
knowledge_max_tokens = 512
history_max_tokens = 512
history_max_utterances = 1000000
n_candidates = 2

class BaseDataset(torch.utils.data.Dataset):
    def __init__(self, tokenizer, split_type, labels=True, labels_file=None):
        self.dataroot = dataroot
        self.tokenizer = tokenizer
        self.split_type = split_type
        self.task = task
        self.negative_sample_method = negative_sample_method

        self.cls = self.tokenizer.cls_token_id
        self.sep = self.tokenizer.sep_token_id
        self.bos = self.tokenizer.bos_token_id
        self.eos = self.tokenizer.eos_token_id
        self.pad = self.tokenizer.pad_token_id
        self.SPECIAL_TOKENS = SPECIAL_TOKENS

        self.speaker1, self.speaker2, self.knowledge_sep, self.knowledge_tag = self.tokenizer.convert_tokens_to_ids(
            self.SPECIAL_TOKENS["additional_special_tokens"]
        )
        self.knowledge_sep_token = self.SPECIAL_TOKENS["additional_special_tokens"][2]
        self.dataset_walker = DatasetWalker(split_type, labels=labels, dataroot=self.dataroot, labels_file=labels_file)
        self.dialogs = self._prepare_conversations()
        self.knowledge_reader = KnowledgeReader(self.dataroot, knowledge_file)
        self.snippets = self._prepare_knowledge()
        self._create_examples()

        self.debug = debug
        self.knowledge_max_tokens = knowledge_max_tokens
        self.history_max_utterances = history_max_utterances
        self.history_max_tokens = history_max_tokens
        self.n_candidates = n_candidates


    def _prepare_conversations(self):
        """ Tokenize and encode the dialog data """
        logger.info("Tokenize and encode the dialog data")
        tokenized_dialogs = []
        for i, (log, label) in enumerate(tqdm(self.dataset_walker, disable=False, desc='tokenizing...')):
            dialog = {}
            dialog["id"] = i
            dialog["log"] = log
            if label is not None:
                if "response" in label:
                    label["response_tokenized"] = self.tokenizer.convert_tokens_to_ids(
                        self.tokenizer.tokenize(label["response"])
                    )
            dialog["label"] = label
            tokenized_dialogs.append(dialog)
        return tokenized_dialogs

    def _prepare_knowledge(self):
        """ Tokenize and encode the knowledge snippets """
        self.knowledge_docs = self._get_snippet_list()

        tokenized_snippets = defaultdict(dict)
        for snippet_id, snippet in enumerate(self.knowledge_docs):
            key = "{}__{}__{}".format(snippet["domain"], str(snippet["entity_id"]) or "", snippet["doc_id"])
            knowledge = self._knowledge_to_string(snippet["doc"], name=snippet["entity_name"] or "")

            tokenized_knowledge = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(knowledge))
            tokenized_snippets[key]['token_ids'] = tokenized_knowledge[:512] # knowledge_max_tokens : 256
        return tokenized_snippets

    def _get_snippet_list(self):
        """ Get all knowledge snippets in the dataset """
        result = []
        for domain in self.knowledge_reader.get_domain_list():
            for entity_id in self.knowledge_reader.knowledge[domain].keys():
                for review_doc_id in self.knowledge_reader.get_review_doc_ids(domain, entity_id):
                    review_doc = self.knowledge_reader.get_review_doc(domain, entity_id, review_doc_id)
                    for review_sent_id, review_sent in review_doc['sentences'].items():
                        result.append(
                            {'domain': domain, 'entity_id': entity_id, 'entity_name': review_doc['entity_name'],
                             'doc_id': f"{review_doc_id}-{review_sent_id}",
                             'doc': {'body': review_sent}})
                for faq_doc_id in self.knowledge_reader.get_faq_doc_ids(domain, entity_id):
                    faq_doc = self.knowledge_reader.get_faq_doc(domain, entity_id, faq_doc_id)
                    result.append({'domain': domain, 'entity_id': entity_id, 'entity_name': faq_doc['entity_name'],
                                   'doc_id': faq_doc_id,
                                   'doc': {'body': f"{faq_doc['question']} {faq_doc['answer']}"}})
        return result

    def _knowledge_to_string(self, doc, name=""):
        """ Convert a knowledge snippet to a string """
        doc_body = f"{name.title()}: {doc['body']}"
        return doc_body

    def _create_examples(self):
        """ Creating examples for model training and evaluation """
        logger.info("Creating examples")
        self.examples = []
        for dialog in tqdm(self.dialogs, disable=False, desc='creating examples'):
            #if self.debug > 0 and len(self.examples) >= self.debug:
            #    break
            dialog_id = dialog["id"]
            label = dialog["label"]

            dialog = dialog["log"]
            if label is None:
                # This will only happen when running knowledge-seeking turn detection on test data
                # So we create dummy target here
                label = {"target": False}

            target = label["target"]

            if not target and self.task != "detection":
                # we only care about non-knowledge-seeking turns in turn detection task
                continue

            # Turn Embedding 수정하기!!!
            history = [
                self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(turn["text"]))
                for turn in dialog
            ]
            
            gt_resp = label.get("response", "")
            tokenized_gt_resp = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(gt_resp))

            # apply history threshold at an utterance-level (a large value can be used to nullify its effect)
            truncated_history = history[-1000000:] #history_max_utterances : 1000000

            # perform token-level truncation of history from the left 
            truncated_history = truncate_sequences(truncated_history, 512) #history_max_tokens : 256

            if target:
                knowledge_keys = []
                knowledge_candidates = defaultdict(lambda: 0)
                used_knowledge = []
                knowledge_prefix_visited = set()

                if "knowledge" not in label:
                    raise ValueError("Please run entity matching before running knowledge selection")

                label_knowledge = label["knowledge"]

                for knowledge in label_knowledge:
                    if not (self.task == 'selection' and self.eval_only):
                        if knowledge['doc_type'] == 'review':
                            knowledge_key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}-{knowledge['sent_id']}"
                        else:
                            knowledge_key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}"

                    # find snippets with same entity as candidates
                    prefix = "{}__{}".format(knowledge["domain"], knowledge["entity_id"])
                    if prefix not in knowledge_prefix_visited:
                        knowledge_prefix_visited.add(prefix)
                        _knowledge_candidates = [
                            cand
                            for cand in self.snippets.keys()
                            if "__".join(cand.split("__")[:-1]) == prefix
                        ]

                        for _knowledge_cand_idx, _knowledge_cand in enumerate(_knowledge_candidates):
                            knowledge_candidates[_knowledge_cand] = 1
                    if self.split_type == "train" and self.negative_sample_method == "oracle":
                        # if there's not enough candidates during training, we just skip this example
                        if len(knowledge_candidates) < 2 or len(knowledge_candidates) <= len(label["knowledge"]): #n_candidates : 2
                            logger.info("Not enough candidates. Skip this example...")
                            continue

                    if not (self.task == 'selection' and self.eval_only):
                        used_knowledge.append(
                            self.snippets[knowledge_key]['token_ids'][:512]) # knowledge_max_tokens : 256
                        knowledge_keys.append(knowledge_key)
                knowledge_candidates = [k for k, v in knowledge_candidates.items()]

            else:
                knowledge_candidates = None
                used_knowledge = []
                knowledge_keys = []

            self.examples.append({
                "history": truncated_history,
                "knowledge": used_knowledge,
                "response": tokenized_gt_resp,
                "response_text": gt_resp,
                "dialog_id" : dialog_id
            })
        if self.split_type == 'train' : 
            for i in tqdm(range(len(dstc9_knowledge)), disable=False, desc='creating examples'):
                # Turn Embedding 수정하기!!!
                history = [
                    self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(turn))
                    for turn in dstc9_history[i]
                ]

                # apply history threshold at an utterance-level (a large value can be used to nullify its effect)
                truncated_history = history[-1000000:] #history_max_utterances : 1000000

                # perform token-level truncation of history from the left 
                truncated_history = truncate_sequences(truncated_history, 512) #history_max_tokens : 256
                
                gt_resp = dstc9_response[i]
                tokenized_gt_resp = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(gt_resp))

                tokenized_knowledge = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(dstc9_knowledge[i]))
                used_knowledge= [tokenized_knowledge[:512]]

                self.examples.append({
                    "history": truncated_history,
                    "knowledge": used_knowledge,
                    "response": tokenized_gt_resp,
                    "response_text": gt_resp,
                    "dialog_id" : i
                })


    def __getitem__(self, index):
        raise NotImplementedError

    def __len__(self):
        return len(self.examples)



In [10]:
class ResponseGenerationDataset(BaseDataset):
    def __init__(self, tokenizer, split_type, labels=True, labels_file=None):
        super(ResponseGenerationDataset, self).__init__(tokenizer, split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        instance, _ = self.build_input_from_segments(
            example["knowledge"],
            example["history"],
            example["response"]
        )
        return instance

    def build_input_from_segments(self, knowledge, history, response):
        """ Build a sequence of input from 3 segments: knowledge, history and last reply """
        instance = {}
        knowledge = [[self.knowledge_sep] + k for k in knowledge]
        knowledge = [w for k in knowledge for w in k]

        # 3: special tokens; len(history): special speaker tokens
        entire_input_len = self.tokenizer.model_max_length - 3

        entire_knowledge_len, entire_history_len = len(knowledge), len(list(chain(*history)))
        max_history_len = int((entire_history_len * entire_input_len) / (entire_knowledge_len + entire_history_len))
        max_history_len = min(entire_history_len + len(history), max(max_history_len, 512))
        max_knowledge_len = entire_input_len - max_history_len  # - len(history)

        if max_knowledge_len < entire_knowledge_len:
            logger.warning(
                f"Knowledge too long! Have been truncated from {entire_knowledge_len} to {max_knowledge_len}")
            knowledge = knowledge[:max_knowledge_len]
        if max_history_len < entire_history_len:
            logger.warning(f"History too long! Have been truncated from {entire_history_len} to {max_history_len}")

        sequence = [knowledge] + history + [response]
        sequence_with_speaker = [
            [self.speaker1 if (len(sequence) - i) % 2 == 0 else self.speaker2] + s
            for i, s in enumerate(sequence[1:])
        ]  # speaker 2 (user)
        history = list(chain(*sequence_with_speaker[:-1]))[:max_history_len]
        sequence = [[self.bos]] + [sequence[0]] + [[self.knowledge_tag]] + [history] + [[self.eos]]
        instance["input_ids"] = list(chain(*sequence))
        instance["lm_labels"] = [self.bos] + sequence_with_speaker[-1] + [self.eos]
        return instance, sequence

    def collate_fn(self, batch):
        input_ids = [ins["input_ids"] for ins in batch]
        lm_labels = [ins["lm_labels"] for ins in batch]

        input_ids = torch.tensor(pad_ids(input_ids, self.pad))
        attention_mask = 1 - (input_ids == self.pad).int()
        lm_labels = torch.tensor(pad_ids(lm_labels, -100))

        return input_ids, attention_mask, lm_labels

In [11]:
class ResponseGenerationEvalDataset(ResponseGenerationDataset):
    def __init__(self, tokenizer, split_type, labels=True, labels_file=None):
        super(ResponseGenerationEvalDataset, self).__init__(tokenizer, split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        return example

    def collate_fn(self, batch):
        return batch

# Baseline/generate.py - library 불러오기

In [12]:
!pip install nltk==3.6.6
!pip install numpy==1.22.0
!pip install rouge_score==0.1.2
!pip install scikit_learn==1.1.1
!pip install sentencepiece==0.1.96
!pip install strsimpy==0.2.1
!pip install summ_eval==0.892
!pip install tensorboard==2.9.0
!pip install tensorboardX==2.5
!pip install torch==1.13.1
!pip install tqdm==4.62.3
!pip install transformers==4.20.1
!python -m nltk.downloader 'punkt'
!python -m nltk.downloader 'wordnet'

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting nltk==3.6.6
  Downloading nltk-3.6.6-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m31.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: nltk
  Attempting uninstall: nltk
    Found existing installation: nltk 3.8.1
    Uninstalling nltk-3.8.1:
      Successfully uninstalled nltk-3.8.1
Successfully installed nltk-3.6.6
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting numpy==1.22.0
  Downloading numpy-1.22.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.8/16.8 MB[0m [31m65.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.22.4
    Uninstalling numpy-1.22.4:
     

In [13]:
#utils/argument
def update_additional_params(params, args):
    """ Update params if they are provided in cli """
    if args.get("dataroot"):
        params["dataset_args"]["dataroot"] = args["dataroot"]

    if args.get("knowledge_file"):
        params["dataset_args"]["knowledge_file"] = args["knowledge_file"]

    if args.get("model_name_or_path"):
        params["model_name_or_path"] = args["model_name_or_path"]
    if args.get("task"):
        params["task"] = args["task"]

    if args.get("negative_sample_method", ""):
        params["dataset_args"]["negative_sample_method"] = args["negative_sample_method"]

    if args.get("eval_all_snippets", False):
        params["dataset_args"]["eval_all_snippets"] = args["eval_all_snippets"]

    if args.get("learning_rate", False):
        params["learning_rate"] = args["learning_rate"]

    for key in ["history_max_tokens", "knowledge_max_tokens"]:
        if args.get(key, -1) > -1:
            params["dataset_args"][key] = args[key]

In [14]:
#utils/model
def run_batch_generation_train(model, batch, **kwargs):
    """ Run batch generation during training time """
    batch = tuple(input_tensor.to(device) for input_tensor in batch[:4])
    input_ids, attention_mask, lm_labels = batch
    model_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=lm_labels)
    loss = model_outputs[0]
    lm_logits = model_outputs[1]
    yield loss, lm_logits, torch.tensor([])


def run_batch_generation_eval(model, batch, **kwargs):
    """ Run batch generation during evaluation time """
    batch = tuple(input_tensor.to(device) for input_tensor in batch[:4])
    input_ids, attention_mask, lm_labels = batch
    model_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=lm_labels)
    loss = model_outputs[0]
    lm_logits = model_outputs[1]
    return loss, lm_logits, torch.tensor([])


def run_batch_generation_sample(model, tokenizer, batch, dataset):
    """ Run batch generation during test time
        Responses are decoded using beam search + sampling
    """
    current_output = []

    example = batch[0]
    knowledge, history = example["knowledge"], example["history"]
    response_text = example["response_text"]
    dialog_id = example["dialog_id"]

    instance, sequence = dataset.build_input_from_segments(
        knowledge, history, current_output
    )

    input_ids = torch.tensor(instance["input_ids"], device=device).unsqueeze(0)
    current_output = model.generate(input_ids=input_ids, num_beams=5,
                                    min_length=5, max_length=60,
                                    do_sample=True, num_return_sequences=1)

    return current_output, response_text, dialog_id

In [15]:
import re

RE_ART = re.compile(r'\b(a|an|the)\b')
RE_PUNC = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def remove_articles(_text):
    return RE_ART.sub(' ', _text)


def white_space_fix(_text):
    return ' '.join(_text.split())


def remove_punc(_text):
    return RE_PUNC.sub(' ', _text)  # convert punctuation to spaces


def lower(_text):
    return _text.lower()


def normalize(text):
    """Lower text and remove punctuation, articles and extra whitespace. """
    return white_space_fix(remove_articles(remove_punc(lower(text))))


In [16]:
#utils/metrics
import numpy as np

from nltk import bigrams as get_bigrams
from nltk import trigrams as get_trigrams
from nltk import word_tokenize, ngrams
from collections import Counter

from rouge_score import rouge_scorer
from summ_eval.bleu_metric import BleuMetric
from summ_eval.meteor_metric import MeteorMetric

def get_fourgrams(sequence, **kwargs):
    """
    Return the 4-grams generated from a sequence of items, as an iterator.

    :param sequence: the source data to be converted into 4-grams
    :type sequence: sequence or iter
    :rtype: iter(tuple)
    """

    for item in ngrams(sequence, 4, **kwargs):
        yield item


class Metric:
    def __init__(self):
        self.is_single = True
        self.reset()

    def reset(self):
        pass

    def update(self, output):
        raise NotImplementedError()

    def compute(self):
        raise NotImplementedError()


class DataCacheMetric(Metric):
    def __init__(self):
        self.refs = []
        self.preds = []
        super(DataCacheMetric, self).__init__()

    def reset(self):
        self.refs = []
        self.preds = []

    def update(self, output):
        hypothesis, reference = output
        assert isinstance(hypothesis, str)
        assert isinstance(reference, str)
        self.preds.append(hypothesis)
        self.refs.append(reference)

    def compute(self):
        return len(self.preds)

    def name(self):
        return "Data Count"


class UnigramMetric(Metric):
    def __init__(self, metric):
        self._score = None
        self._count = None
        if metric.lower() not in ["recall", "precision"]:
            raise ValueError("mertic should be either 'recall' or 'precision', got %s" % metric)
        self.metric = metric.lower()
        super(UnigramMetric, self).__init__()

    def reset(self):
        self._score = 0
        self._count = 0
        super(UnigramMetric, self).reset()

    def update(self, output):
        # hypothesis and reference are assumed to be actual sequences of tokens
        hypothesis, reference = output

        hyp_tokens = normalize(hypothesis).split()
        ref_tokens = normalize(reference).split()

        common = Counter(ref_tokens) & Counter(hyp_tokens)
        num_same = sum(common.values())

        if num_same == 0:
            score = 0
        else:
            if self.metric == "precision":
                score = 1.0 * num_same / len(hyp_tokens)
            else:
                assert self.metric == "recall"
                score = 1.0 * num_same / len(ref_tokens)

        self._score += score
        self._count += 1

    def compute(self):
        if self._count == 0:
            raise ValueError("Unigram metrics must have at least one example before it can be computed!")
        return self._score / self._count

    def name(self):
        return "Unigram{:s}".format(self.metric.capitalize())


class NGramDiversity(Metric):
    def __init__(self, n=1):
        self._n = n
        self._diversity = None
        self._count = None

        if self._n not in [1, 2, 3, 4]:
            raise ValueError("NGramDiversity only supports n=1 (unigrams), n=2 (bigrams),"
                             "n=3 (trigrams) and n=4 (4-grams)!")

        self.ngram_func = {
            1: lambda x: x,
            2: get_bigrams,
            3: get_trigrams,
            4: get_fourgrams
        }[self._n]

        super(NGramDiversity, self).__init__()

    def reset(self):
        self._diversity = 0
        self._count = 0
        super(NGramDiversity, self).reset()

    def update(self, output):
        hypothesis, _ = output

        if hypothesis is None:
            diversity = 0
        else:
            diversity = 0
            output_tokens = word_tokenize(hypothesis)
            denominator = float(len(output_tokens))

            if denominator != 0.0:
                ngrams = set(list(self.ngram_func(output_tokens)))
                diversity = len(ngrams) / denominator

        self._diversity += diversity
        self._count += 1

    def compute(self):
        if self._count == 0:
            raise ValueError("NGramDiversity must consume at least one example before it can be computed!")
        return self._diversity / self._count

    def name(self):
        return "{:d}GramDiversity".format(self._n)


class CorpusNGramDiversity(Metric):
    def __init__(self, n=1):
        self._n = n

        self._ngrams = None
        self._token_count = None

        if self._n not in [1, 2, 3, 4]:
            raise ValueError("CorpusNGramDiversity only supports n=1 (unigrams), n=2 (bigrams),"
                             "n=3 (trigrams) and n=4 (4-grams)!")
        self.ngram_func = {
            1: lambda x: x,
            2: get_bigrams,
            3: get_trigrams,
            4: get_fourgrams
        }[self._n]

        super(CorpusNGramDiversity, self).__init__()

    def reset(self):
        self._ngrams = set()
        self._token_count = 0
        super(CorpusNGramDiversity, self).reset()

    def update(self, output):
        hypothesis, _ = output
        if isinstance(hypothesis, str) and hypothesis:
            output_tokens = word_tokenize(hypothesis)

            ngrams = list(self.ngram_func(output_tokens))
            self._ngrams.update(ngrams)
            self._token_count += len(output_tokens)

    def compute(self):
        if self._token_count == 0:
            raise ValueError("CorpusNGramDiversity must consume at least one example before it can be computed!")

        return len(self._ngrams) / self._token_count

    def name(self):
        return "Corpus{:d}GramDiversity".format(self._n)


class LENGTH(DataCacheMetric):
    def __init__(self):
        self._len = []
        super(LENGTH, self).__init__()

    def reset(self):
        self._len = []

    def update(self, output):
        hypothesis, _ = output
        self._len.append(len(hypothesis.split()))

    def compute(self):
        if len(self._len) == 0:
            raise ValueError("LENGTH must have at least one example before it can be computed!")
        return sum(self._len) / len(self._len)

    def name(self):
        return "LENGTH"


class BLEU(DataCacheMetric):
    def __init__(self):
        super(BLEU, self).__init__()

    def compute(self):
        if len(self.preds) == 0:
            raise ValueError("BLEU-1 must have at least one example before it can be computed!")

        metric = BleuMetric()
        score = metric.evaluate_batch(self.preds, self.refs)
        return score['bleu']

    def name(self):
        return "BLEU"


class METEOR(DataCacheMetric):
    def __init__(self):
        super(METEOR, self).__init__()

    def compute(self):
        if len(self.preds) == 0:
            raise ValueError("METEOR must have at least one example before it can be computed!")
            
        metric = MeteorMetric()
        score = metric.evaluate_batch(self.preds, self.refs)
        return score['meteor'] #* 100

    def name(self):
        return "METEOR"


class ROUGE(Metric):
    def __init__(self):
        self.rouge_type = ['rouge1', 'rouge2', 'rougeL', "rougeLsum"]
        self.scorer = rouge_scorer.RougeScorer(self.rouge_type, use_stemmer=True)
        self._rouge = None
        self._count = None
        super(ROUGE, self).__init__()
        self.is_single = False

    def reset(self):
        self._rouge = []
        self._count = 0
        super(ROUGE, self).reset()

    def update(self, output):
        hypothesis, reference = output
        rouge = self.scorer.score(reference, hypothesis)

        _rouge = [rouge[_rouge_type].fmeasure * 100 for _rouge_type in self.rouge_type]
        self._rouge.append(_rouge)
        self._count += 1

    def compute(self):
        if self._count == 0:
            raise ValueError("ROUGE-L must have at least one example before it can be computed!")
        return np.array(self._rouge).mean(axis=0).tolist()

    def name(self):
        return self.rouge_type


Downloading the meteor jar


In [17]:
#utils/data 
def write_generation_preds(dataset_walker, output_file, dialog_ids, responses):
    """ Write results of response generation to output_file """
    print('start write generation preds')
    labels = [label for log, label in dataset_walker]
    new_labels = [{"target": False}] * len(dataset_walker)
    # Update the dialogs with detection result
    for dialog_id, response in zip(dialog_ids, responses):
        label = labels[dialog_id]
        new_label = {"target": True, "response": response}

        if label is None:
            label = new_label
        else:
            label = label.copy()
            label.update(new_label)
            if "response_tokenized" in label:
                label.pop("response_tokenized")
        new_labels[dialog_id] = label

    print('new_labels')
    print(new_labels)

    if os.path.dirname(output_file) and not os.path.exists(os.path.dirname(output_file)):
        os.makedirs(os.path.dirname(output_file))

    with open(output_file, "w") as jsonfile:
        logger.info("Writing predictions to {}".format(output_file))
        json.dump(new_labels, jsonfile, indent=2)
    print('finish write_generation_preds')

# Baseline/generate.py - 실행

In [18]:
import argparse
import logging
import os
import random
import json

from typing import Dict
from argparse import Namespace

import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, PreTrainedModel

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

In [19]:
def set_seed(num):
    random.seed(num)
    np.random.seed(num)
    torch.cuda.manual_seed(num)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(num)

In [20]:
def generation_evaluate(eval_dataset, model, tokenizer, desc="") -> Dict:
    """ Generate responses and report the eval performance if references are available """
    os.makedirs(eval_output_dir, exist_ok=True)
    
    eval_batch_size = 4 #per_gpu_eval_batch_size:4

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=1,  # only support batch_size=1 for sampling right now
        collate_fn=eval_dataset.collate_fn
    )

    metrics = [
        DataCacheMetric(),
        UnigramMetric(metric="recall"),
        UnigramMetric(metric="precision"),
        NGramDiversity(n=1),
        NGramDiversity(n=2),
        NGramDiversity(n=3),
        NGramDiversity(n=4),
        CorpusNGramDiversity(n=1),
        CorpusNGramDiversity(n=2),
        CorpusNGramDiversity(n=3),
        CorpusNGramDiversity(n=4),
        BLEU(),
        ROUGE(),
        METEOR(),
    ]

    all_output_texts = []
    dialog_ids = []
    do_evaluate = False
    model.eval()

    run_batch_generation_func = run_batch_generation_sample

    for batch in tqdm(eval_dataloader, desc="Evaluating", disable=False):
        with torch.no_grad():
            sampled_output_ids, ground_truth, dialog_id = run_batch_generation_func(model, tokenizer, batch, eval_dataset)
            sampled_output_text = [tokenizer.decode(_sampled_output_ids, skip_special_tokens=True) for
                                   _sampled_output_ids in sampled_output_ids]
            if len(sampled_output_text) == 1:
                all_output_texts.append(sampled_output_text[0])
            else:
                all_output_texts.append(sampled_output_text)
            dialog_ids.append(dialog_id)
        if ground_truth.strip() != "":
            do_evaluate = True
            for metric in metrics:
                metric.update((sampled_output_text[0], ground_truth))

    # 안들어가짐!!!!????
    print(output_file)
    if output_file:
      write_generation_preds(eval_dataset.dataset_walker, output_file, dialog_ids, all_output_texts)

    print('after write generation preds in evaluate!')
    result = dict()
    if do_evaluate:
        output_eval_file = os.path.join(eval_output_dir, f"eval_results_{task}.txt")
        with open(output_eval_file, "a") as writer:
            logger.info("***** Eval results %s *****" % desc)
            writer.write("***** Eval results %s *****\n" % desc)
            for metric in metrics:
                name = metric.name()
                score = metric.compute()
                if name == "METEOR" :
                  print('here!! Meteor score', str(score))

                if metric.is_single:
                    result[name] = score
                    logger.info("  %s = %s", name, str(score))
                    writer.write("%s = %s\n" % (name, str(score)))
                    print('metric is single')
                    print(name, str(score))
                else:
                    for _name, _score in zip(name, score):
                        result[_name] = _score
                        logger.info("  %s = %s", _name, str(_score))
                        writer.write("%s = %s\n" % (_name, str(_score)))
                        print('else : metric is single')
                        print(_name, str(_score))

    return result

In [21]:
def evaluate(eval_dataset, model: PreTrainedModel, run_batch_fn, desc="") -> Dict:
    """ Model evaluation for knowledge seeking turn detection and knowledge selection
        Report evaluation results if gold labels are available
    """
    eval_output_dir = output_dir
    os.makedirs(eval_output_dir, exist_ok=True)

    # eval_batch_size for selection must be 1 to handle different number of candidates
    eval_batch_size = 4

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset,
        sampler=eval_sampler,
        batch_size=eval_batch_size,
        collate_fn=eval_dataset.collate_fn
    )

    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()
    data_infos = []
    all_preds = []
    all_labels = []
    
    epoch_iterator = tqdm(eval_dataloader, desc = "Evaluate", disable = False)
    for _, batch in enumerate(epoch_iterator):
        with torch.no_grad():
            loss, logits, labels = run_batch_fn(model, batch)
            eval_loss += loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps

    if not eval_only:
        return get_eval_performance(eval_output_dir, eval_loss, all_preds, all_labels, desc)


def get_eval_performance(eval_output_dir, eval_loss, all_preds, all_labels, desc):
    """ Get evaluation performance when the gold labels are available """
    if task == "generation":
        perplexity = torch.exp(torch.tensor(eval_loss))
        result = {"perplexity": perplexity, "loss": eval_loss, "val_measure": eval_loss}

    else:
        raise ValueError("task not in ['generation', 'selection', 'detection'], got %s" % task)

    logger.info(str(result))

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "a") as writer:
        logger.info("***** Eval results %s *****" % desc)
        writer.write("***** Eval results %s *****\n" % desc)
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result

In [22]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Set seed
set_seed(42)

cuda:0


In [23]:
#from transformers import LongformerModel, BertForQuestionAnswering
from transformers import LEDTokenizer, LEDForConditionalGeneration, AutoConfig

MAX_SEQUENCE_LENGTH = 2048 #4096

model_name = "allenai/led-base-16384"
#model_name = "hyesunyun/update-summarization-bart-large-longformer"

config = AutoConfig.from_pretrained(model_name)
tokenizer = LEDTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens(SPECIAL_TOKENS)
print(tokenizer.model_max_length)
tokenizer.model_max_length = min(2048, tokenizer.model_max_length)
print(tokenizer.model_max_length)

#model = LEDForConditionalGeneration.from_pretrained(model_name, ignore_mismatched_sizes = True, config = config)
model = torch.load("/content/drive/MyDrive/dstc11-track5/output2model.pt")

model.resize_token_embeddings(len(tokenizer))
model.to(device)

Downloading:   0%|          | 0.00/1.07k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

16384
2048


LEDForConditionalGeneration(
  (led): LEDModel(
    (shared): Embedding(50269, 768)
    (encoder): LEDEncoder(
      (embed_tokens): Embedding(50269, 768)
      (embed_positions): LEDLearnedPositionalEmbedding(16384, 768)
      (layers): ModuleList(
        (0-5): 6 x LEDEncoderLayer(
          (self_attn): LEDEncoderAttention(
            (longformer_self_attn): LEDEncoderSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
              (value_global): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): L

In [24]:
dataset_class, run_batch_fn_train, run_batch_fn_eval = ResponseGenerationDataset, run_batch_generation_train, run_batch_generation_eval

In [25]:
# load datasets and train the model
from signal import signal, SIGPIPE, SIG_DFL  
signal(SIGPIPE,SIG_DFL) 

dataroot = '/content/drive/MyDrive/dstc11-track5/data'
task = "generation"
negative_sample_method = 'oracle'
knowledge_file = 'knowledge.json'
debug = 0
knowledge_max_tokens = 1024
history_max_tokens = 1024

history_max_utterances = 1000000
n_candidates = 2

train_dataset = ResponseGenerationDataset(tokenizer, split_type="train")
eval_dataset = ResponseGenerationDataset(tokenizer, split_type="val")
test_dataset = ResponseGenerationEvalDataset(tokenizer, split_type="val")

tokenizing...: 100%|██████████| 28431/28431 [00:11<00:00, 2553.19it/s]
creating examples: 100%|██████████| 28431/28431 [01:33<00:00, 305.52it/s]
creating examples: 100%|██████████| 1185/1185 [00:35<00:00, 33.53it/s]
tokenizing...: 100%|██████████| 4173/4173 [00:00<00:00, 10354.50it/s]
creating examples: 100%|██████████| 4173/4173 [00:14<00:00, 281.19it/s]
tokenizing...: 100%|██████████| 4173/4173 [00:00<00:00, 10757.55it/s]
creating examples: 100%|██████████| 4173/4173 [00:15<00:00, 276.19it/s]


In [26]:
from sklearn.metrics import recall_score, precision_score, average_precision_score, classification_report, f1_score
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm, trange
from transformers import (
    AutoConfig,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
    AutoModelForSequenceClassification,
)
from typing import Dict, Tuple


eval_output_dir = '/content/drive/MyDrive/dstc11-track5/output'


def get_eval_performance(eval_output_dir, eval_loss, all_preds, all_labels, desc):
    """ Get evaluation performance when the gold labels are available """
    if task == "generation":
        perplexity = torch.exp(torch.tensor(eval_loss))
        result = {"perplexity": perplexity, "loss": eval_loss, "val_measure": eval_loss}
        print('get_eval_performance')
        print(result)
    else:
        raise ValueError("args.task not in ['generation', 'selection', 'detection'], got %s" % task)

    output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
    with open(output_eval_file, "a") as writer:
        writer.write("***** Eval results %s *****\n" % desc)
        for key in sorted(result.keys()):
            writer.write("%s = %s\n" % (key, str(result[key])))

    return result

In [27]:
def train(train_dataset, eval_dataset, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
          run_batch_fn_train, run_batch_fn_eval) -> Tuple[int, float]:
    
    """ Model training and evaluation """
    exp_name = ''
    log_dir = os.path.join("runs", exp_name) if exp_name else None
    tb_writer = SummaryWriter(log_dir)
    #output_dir = log_dir

    
    train_batch_size = 4 #per_gpu_train_batch_size : 4

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        # shuffle=True,
        sampler=train_sampler,
        batch_size=train_batch_size,
        collate_fn=train_dataset.collate_fn
    )

    #args setting
    gradient_accumulation_steps = 4
    num_train_epochs = 3
    learning_rate = 3e-5
    adam_epsilon = 1e-8
    warmup_steps = 0.2
    max_grad_norm = 1.0


    t_total = len(train_dataloader) // gradient_accumulation_steps * num_train_epochs
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)
    if 0 < warmup_steps < 1:
        warmup_steps = int(warmup_steps * t_total)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
    )

    # Train!
    global_step = 0
    model.zero_grad()
    train_iterator = trange(
        0, int(num_train_epochs), desc="Epoch", disable=False
    )
    set_seed(42)  # for reproducibility
    val_loss = float('inf')

    for epoch in train_iterator:
        local_steps = 0  # update step
        tr_loss = 0.0
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=False)
        step = 0  # backward step
        total_log_loss = 0
        for _, batch in enumerate(epoch_iterator):
            model.train()
            for loss, _, _ in run_batch_fn_train(model, batch, global_step=global_step):
                step += 1

                total_log_loss += loss.item()

                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                loss.backward()
                tr_loss += loss.item()

                if (step + 1) % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
                    local_steps += 1
                    epoch_iterator.set_postfix(Loss=tr_loss / local_steps)
                    total_log_loss = 0

        results = evaluate(eval_dataset, model, run_batch_fn_eval, desc=str(global_step))

        for key, value in results.items():
            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
        tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
        tb_writer.add_scalar("loss", tr_loss / local_steps, global_step)

        if results['val_measure'] < val_loss:
            logger.info(f"Find a smaller val loss measure {results['val_measure']}")
            val_loss = results['val_measure']
            # Save model checkpoint
            #save_model(output_dir, model, tokenizer)
            torch.save(model, output_dir + str(3 + epoch) + 'model.pt')
        else:
            logger.info(f"The val loss measure {results['val_measure']} is larger than "
                        f"the smallest val loss {val_loss}, continue to train ... ")

    tb_writer.flush()
    tb_writer.close()

    return global_step, tr_loss / local_steps

In [28]:
import gc
gc.collect()

0

In [None]:
len(train_dataset)

In [None]:
torch.cuda.empty_cache()
labels_file = None

eval_only = False
output_dir = '/content/drive/MyDrive/dstc11-track5/output'
output_file = 'generation_output'

global_step, tr_loss = train(train_dataset, eval_dataset, model, tokenizer, run_batch_fn_train, run_batch_fn_eval)

In [29]:
result = generation_evaluate(test_dataset, model, tokenizer, desc="test")


Evaluating:   3%|▎         | 56/2129 [00:33<20:49,  1.66it/s]


KeyboardInterrupt: ignored

.

In [None]:
output_dir = '/content/drive/MyDrive/dstc11-track5/output'
output_file = 'generation_output'
result = generation_evaluate(test_dataset, model, tokenizer, desc="test")

Evaluating: 100%|██████████| 2129/2129 [17:35<00:00,  2.02it/s]


generation_output
start write generation preds
new_labels
[{'target': False}, {'target': True, 'knowledge': [{'domain': 'hotel', 'entity_id': 20, 'doc_type': 'review', 'doc_id': 9, 'sent_id': 4}, {'domain': 'hotel', 'entity_id': 20, 'doc_type': 'review', 'doc_id': 6, 'sent_id': 4}, {'domain': 'hotel', 'entity_id': 20, 'doc_type': 'review', 'doc_id': 4, 'sent_id': 2}], 'response': 'According to the reviews I have on hand for that location, the opinions are mixed.  Half found the bathrooms pristine and top notch, while half found them not well cleaned.'}, {'target': True, 'knowledge': [{'domain': 'restaurant', 'entity_id': 19250, 'doc_type': 'review', 'doc_id': 0, 'sent_id': 3}, {'domain': 'restaurant', 'entity_id': 19250, 'doc_type': 'review', 'doc_id': 0, 'sent_id': 4}], 'response': 'The Maharajah Tandoori Restaurant has modern décor that transports you to India. Is there anything else I can help you with?'}, {'target': True, 'knowledge': [{'domain': 'hotel', 'entity_id': 7, 'doc_type'



---

