In [1]:
# Mount Google Drive at /content/drive so the project data under
# /content/drive/MyDrive (see `dataroot` below) is readable. Colab-only.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [12]:
!cd "/content/drive/MyDrive/Knowledge-grounded Task-oriented Dialogue system/Knowledge-grounded-ToD"

In [91]:
import json
import os
import logging
from collections import defaultdict
from itertools import chain

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

import torch
from tqdm import tqdm

logger = logging.getLogger(__name__)

# Special tokens for the dialogue model (not referenced by the visible cells).
SPECIAL_TOKENS = {
    "additional_special_tokens": ["<speaker1>", "<speaker2>", "<knowledge_sep>", "<knowledge_tag>"],
}
# Full special-token list, derived from SPECIAL_TOKENS so the two definitions
# cannot drift apart.
SPECIAL_TOKENS_VALUES = ["<bos>", "<eos>", "<pad>"] + SPECIAL_TOKENS["additional_special_tokens"]

# Pipeline mode read by DSTC_BaseDataset: dialogues whose label is not a
# knowledge-seeking target are only kept when task == "detection".
task = "generation"

# Directory holding the train/val/test split folders and knowledge.json.
# Set the KGTOD_DATAROOT environment variable to run outside this Drive layout.
dataroot = os.environ.get(
    "KGTOD_DATAROOT",
    '/content/drive/MyDrive/Knowledge-grounded Task-oriented Dialogue system/Knowledge-grounded-ToD/data',
)

In [92]:
class DSTC_DatasetWalker(object):
    """Iterate over the (log, label) pairs of one DSTC split.

    Reads ``<dataroot>/<dataset>/logs.json`` and, when ``labels`` is truthy,
    the matching labels file. With ``incl_knowledge`` the knowledge snippets
    referenced by target labels are resolved through ``DSTC_KnowledgeReader``
    and their text is attached to the label dicts in place.

    Args:
        dataset: split name, one of ``'train'``, ``'val'``, ``'test'``.
        dataroot: directory containing the split folders and ``knowledge.json``.
        labels: whether to load labels alongside the logs.
        labels_file: explicit labels path; defaults to
            ``<dataroot>/<dataset>/labels.json``.
        incl_knowledge: attach knowledge text to target labels while iterating.

    Raises:
        ValueError: for an unknown split name, or when the labels file does
            not contain exactly one label per log.
    """

    def __init__(self, dataset, dataroot, labels=False, labels_file=None, incl_knowledge=False):
        if dataset not in ['train', 'val', 'test']:
            raise ValueError('Wrong dataset name: %s' % (dataset))

        path = os.path.abspath(dataroot)

        # Explicit UTF-8: the platform default encoding is not guaranteed to be.
        logs_file = os.path.join(path, dataset, 'logs.json')
        with open(logs_file, 'r', encoding='utf-8') as f:
            self.logs = json.load(f)

        self.labels = None
        if labels:
            if labels_file is None:
                labels_file = os.path.join(path, dataset, 'labels.json')
            with open(labels_file, 'r', encoding='utf-8') as f:
                self.labels = json.load(f)
            # __iter__ pairs logs and labels with zip(), which would silently
            # drop the tail of the longer list.
            if len(self.labels) != len(self.logs):
                raise ValueError('Label count (%d) does not match log count (%d)'
                                 % (len(self.labels), len(self.logs)))

        self._incl_knowledge = bool(incl_knowledge)
        if self._incl_knowledge:
            # Uses this notebook's DSTC_KnowledgeReader.
            self._knowledge = DSTC_KnowledgeReader(dataroot)

    def _attach_knowledge(self, label):
        """Write the text of every knowledge reference in ``label`` into it, in place.

        Review references get a ``sent`` entry; FAQ references get
        ``question`` and ``answer``. Other ``doc_type`` values are left as is.
        """
        for snippet in label['knowledge']:
            domain = snippet['domain']
            entity_id = snippet['entity_id']
            doc_type = snippet['doc_type']
            doc_id = snippet['doc_id']

            if doc_type == 'review':
                snippet['sent'] = self._knowledge.get_review_sent(domain, entity_id, doc_id, snippet['sent_id'])
            elif doc_type == 'faq':
                doc = self._knowledge.get_faq_doc(domain, entity_id, doc_id)
                snippet['question'] = doc['question']
                snippet['answer'] = doc['answer']

    def __iter__(self):
        """Yield ``(log, label)``; ``label`` is ``None`` when labels were not loaded."""
        if self.labels is None:
            for log in self.logs:
                yield (log, None)
            return

        for log, label in zip(self.logs, self.labels):
            if self._incl_knowledge and label['target']:
                self._attach_knowledge(label)
            yield (log, label)

    def __len__(self):
        """Number of dialogue logs in the split."""
        return len(self.logs)

In [93]:
#scripts/knowledge_reader
class DSTC_KnowledgeReader(object):
    """Read-only accessor for a DSTC knowledge file.

    The file is loaded once into ``self.knowledge``, a nested dict indexed as
    ``knowledge[domain][str(entity_id)]``. The methods below read each entity's
    ``name``, ``faqs`` and ``reviews`` entries from it. Every lookup raises
    ``ValueError`` for an unknown domain, entity, document or sentence id.

    Args:
        dataroot: directory containing the knowledge file.
        knowledge_file: file name inside ``dataroot``.
    """

    def __init__(self, dataroot, knowledge_file='knowledge.json'):
        path = os.path.abspath(dataroot)

        with open(os.path.join(path, knowledge_file), 'r', encoding='utf-8') as f:
            self.knowledge = json.load(f)

    # ----- validation helpers (shared by all public lookups) -----

    def _domain(self, domain):
        """Return the entity mapping of ``domain`` or raise ValueError."""
        # Dict membership instead of rebuilding the key list on every lookup.
        if domain not in self.knowledge:
            raise ValueError("invalid domain name: %s" % domain)
        return self.knowledge[domain]

    def _entity(self, domain, entity_id):
        """Return the entity dict for ``(domain, entity_id)`` or raise ValueError."""
        entities = self._domain(domain)
        if str(entity_id) not in entities:
            raise ValueError("invalid entity id: %s" % str(entity_id))
        return entities[str(entity_id)]

    @staticmethod
    def _doc(entity_obj, section, doc_id):
        """Return document ``doc_id`` from ``entity_obj[section]`` or raise ValueError."""
        docs = entity_obj[section]
        if str(doc_id) not in docs:
            raise ValueError("invalid doc id: %s" % str(doc_id))
        return docs[str(doc_id)]

    # ----- public API -----

    def get_domain_list(self):
        """Return the domain names in file order."""
        return list(self.knowledge.keys())

    def get_entity_list(self, domain):
        """Return ``[{'id': int, 'name': ...}]`` for ``domain``, sorted by numeric id.

        Entity ids are converted with ``int()``, so non-numeric ids raise ValueError.
        """
        entities = self._domain(domain)
        return [
            {'id': entity_id, 'name': entities[str(entity_id)]['name']}
            for entity_id in sorted(int(key) for key in entities)
        ]

    def get_entity_name(self, domain, entity_id):
        """Return the entity's name, or ``None`` when the stored name is falsy."""
        return self._entity(domain, entity_id)['name'] or None

    def get_faq_doc_ids(self, domain, entity_id):
        """Return the FAQ document ids (string keys) of an entity."""
        return list(self._entity(domain, entity_id)['faqs'].keys())

    def get_faq_doc(self, domain, entity_id, doc_id):
        """Return one FAQ document with its question, answer and owning entity."""
        entity_obj = self._entity(domain, entity_id)
        doc_obj = self._doc(entity_obj, 'faqs', doc_id)
        return {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_obj['name'] or None,
                'doc_id': doc_id, 'question': doc_obj['question'], 'answer': doc_obj['answer']}

    def get_review_doc_ids(self, domain, entity_id):
        """Return the review document ids (string keys) of an entity."""
        return list(self._entity(domain, entity_id)['reviews'].keys())

    def get_review_doc(self, domain, entity_id, doc_id):
        """Return one review document with its sentences and optional metadata.

        ``traveler_type``, ``dishes`` and ``drinks`` are copied only when the
        stored review has them.
        """
        entity_obj = self._entity(domain, entity_id)
        doc_obj = self._doc(entity_obj, 'reviews', doc_id)

        result = {'domain': domain, 'entity_id': entity_id, 'entity_name': entity_obj['name'] or None,
                  'doc_id': doc_id, 'sentences': doc_obj['sentences']}
        for optional_key in ('traveler_type', 'dishes', 'drinks'):
            if optional_key in doc_obj:
                result[optional_key] = doc_obj[optional_key]
        return result

    def get_review_sent(self, domain, entity_id, doc_id, sent_id):
        """Return the text of sentence ``sent_id`` of a review document."""
        entity_obj = self._entity(domain, entity_id)
        sentences = self._doc(entity_obj, 'reviews', doc_id)['sentences']
        if str(sent_id) not in sentences:
            raise ValueError("invalid sentence id: %s" % str(sent_id))
        return sentences[str(sent_id)]

In [152]:
# Module-level frames kept for backward compatibility: later cells read the
# globals `dstc` (one row per query with its positive passages) and `document`
# (the retrieval corpus). They start empty; every DSTC_BaseDataset instance
# rebinds them to its own freshly built frames.
dstc = pd.DataFrame(columns=["query", "positive_passages"])
document = pd.DataFrame(columns=["doc_id", "text"])


class DSTC_BaseDataset(torch.utils.data.Dataset):
  """Base dataset over one DSTC split that also builds retrieval tables.

  On construction it loads dialogues (``DSTC_DatasetWalker``) and the knowledge
  base (``DSTC_KnowledgeReader``) from the module-level ``dataroot`` and builds:

  * ``self.examples``: one dict per kept dialogue, containing
    ``history`` (list of turn texts), ``doc_id`` (corpus ids of the gold
    snippets) and ``knowledge`` (gold snippet texts);
  * ``self.dstc``: DataFrame with ``query`` / ``positive_passages`` columns;
  * ``self.document``: DataFrame with ``doc_id`` / ``text`` columns, one row
    per knowledge snippet.

  Dialogues whose label is not a target are skipped unless the module-level
  ``task`` equals ``"detection"``. Subclasses implement ``__getitem__``.
  """

  def __init__(self, split_type, labels=True, labels_file=None):
    self.dataroot = dataroot
    self.split_type = split_type

    self.dataset_walker = DSTC_DatasetWalker(split_type, labels=labels, dataroot=self.dataroot, labels_file=labels_file)
    self.dialogs = self._prepare_conversations()
    self.knowledge_reader = DSTC_KnowledgeReader(self.dataroot, "knowledge.json")
    # `knowledge`: flat snippet list; `snippets`: snippet key -> rendered text.
    self.knowledge, self.snippets = self._prepare_knowledge()

    self._create_examples()

  # ----- loading -----

  def _prepare_conversations(self):
    """Return one ``{"id", "log", "label"}`` dict per dialogue, in walker order."""
    return [
      {"id": i, "log": log, "label": label}
      for i, (log, label) in enumerate(tqdm(self.dataset_walker))
    ]

  def _knowledge_to_string(self, doc, name=""):
    """Render a knowledge snippet as ``"<Title-cased name>: <body>"``."""
    return f"{name.title()}: {doc['body']}"

  def _prepare_knowledge(self):
    """Collect and render every knowledge snippet.

    Sets ``self.knowledge_docs`` and returns ``(self.knowledge_docs, snippets)``,
    where ``snippets`` maps each snippet key (see ``_snippet_key``) to its text.
    Works for an empty knowledge base (returns an empty list and dict).
    """
    self.knowledge_docs = self._get_snippet_list()
    snippets = {
      self._snippet_key(snippet): self._knowledge_to_string(snippet["doc"], name=snippet["entity_name"] or "")
      for snippet in self.knowledge_docs
    }
    return self.knowledge_docs, snippets

  def _get_snippet_list(self):
    """Return every knowledge snippet (used to define the retrieval corpus).

    Review documents are split into one snippet per sentence with doc id
    ``"<review_doc_id>-<sent_id>"``; each FAQ becomes one snippet whose body is
    ``"<question> <answer>"``. Order: per domain, per entity, reviews then FAQs.
    """
    result = []
    reader = self.knowledge_reader
    for domain in reader.get_domain_list():
      for entity_id in reader.knowledge[domain].keys():
        for review_doc_id in reader.get_review_doc_ids(domain, entity_id):
          review_doc = reader.get_review_doc(domain, entity_id, review_doc_id)
          for review_sent_id, review_sent in review_doc['sentences'].items():
            result.append(
              {'domain': domain, 'entity_id': entity_id, 'entity_name': review_doc['entity_name'],
               'doc_id': f"{review_doc_id}-{review_sent_id}",
               'doc': {'body': review_sent}})
        for faq_doc_id in reader.get_faq_doc_ids(domain, entity_id):
          faq_doc = reader.get_faq_doc(domain, entity_id, faq_doc_id)
          result.append({'domain': domain, 'entity_id': entity_id, 'entity_name': faq_doc['entity_name'],
                         'doc_id': faq_doc_id,
                         'doc': {'body': f"{faq_doc['question']} {faq_doc['answer']}"}})
    return result

  # ----- key helpers -----

  @staticmethod
  def _snippet_key(snippet):
    """Key of a snippet from ``_get_snippet_list``: ``"<domain>__<entity_id>__<doc_id>"``."""
    return "{}__{}__{}".format(snippet["domain"], str(snippet["entity_id"]) or "", snippet["doc_id"])

  @staticmethod
  def _label_snippet_key(knowledge):
    """Snippet key of a gold knowledge reference from a label (reviews add ``-<sent_id>``)."""
    key = f"{knowledge['domain']}__{knowledge['entity_id']}__{knowledge['doc_id']}"
    if knowledge['doc_type'] == 'review':
      key = f"{key}-{knowledge['sent_id']}"
    return key

  @staticmethod
  def _corpus_doc_id(key):
    """Corpus id of a snippet key: ``hotel`` -> ``0``, ``restaurant`` -> ``1``; other text unchanged."""
    return key.replace('hotel', '0').replace('restaurant', '1')

  # ----- example / table construction -----

  def _create_examples(self):
    """Build ``self.examples``, ``self.dstc`` and ``self.document``.

    Each kept dialogue contributes one example and one ``dstc`` row whose
    ``positive_passages`` pair every gold snippet's corpus id with its text.
    Frames are built once from row lists (not grown row by row), and the
    module-level ``dstc`` / ``document`` globals are rebound to them so later
    cells see this instance's data.
    """
    global dstc, document

    self.examples = []
    query_rows = []
    for dialog in tqdm(self.dialogs):
      label = dialog["label"] if dialog["label"] is not None else {"target": False}
      target = label["target"]
      if not target and task != "detection":
        continue

      # Recomputed for every kept dialogue so non-target (detection-mode)
      # examples never inherit the previous dialogue's history or snippets.
      history = [turn["text"] for turn in dialog["log"]]
      knowledge_keys = []
      used_knowledge = []
      if target:
        for knowledge in label["knowledge"]:
          snippet_key = self._label_snippet_key(knowledge)
          used_knowledge.append(self.snippets[snippet_key])
          knowledge_keys.append(self._corpus_doc_id(snippet_key))

      self.examples.append({
        "history": history,
        "doc_id": knowledge_keys,
        "knowledge": used_knowledge,
      })
      query_rows.append({
        "query": history,
        "positive_passages": [
          {"doc_id": doc_id, "text": text}
          for doc_id, text in zip(knowledge_keys, used_knowledge)
        ],
      })

    self.dstc = pd.DataFrame(query_rows, columns=["query", "positive_passages"])
    self.document = self._build_corpus()
    dstc = self.dstc
    document = self.document

  def _build_corpus(self):
    """Return the retrieval corpus: one ``doc_id`` / ``text`` row per snippet, in snippet order."""
    rows = []
    for snippet in self.knowledge_docs:
      key = self._snippet_key(snippet)
      rows.append({"doc_id": self._corpus_doc_id(key), "text": self.snippets[key]})
    return pd.DataFrame(rows, columns=["doc_id", "text"])

  # ----- model input -----

  def build_input_from_segments(self, knowledge, history, response=None, with_eos=True):
    """Build a model input from knowledge, dialogue history and a reply.

    Segments are concatenated as lists, so they are expected to be token-id
    lists (``history`` a list of such lists). Reads ``self.bos``, ``self.eos``,
    ``self.speaker1`` and ``self.speaker2``. NOTE(review): nothing in this
    class sets those attributes; presumably a tokenizer-aware subclass must.

    Args:
      knowledge: token ids of the knowledge segment.
      history: list of token-id lists, one per turn.
      response: token ids of the reply; ``None`` means an empty reply.
      with_eos: append ``self.eos`` after the reply.

    Returns:
      ``(instance, sequence)``. ``instance`` has ``input_ids``,
      ``token_type_ids``, ``mc_token_ids`` (index of the last token) and
      ``lm_labels``. ``sequence`` is the list of speaker-tagged segments.
    """
    if response is None:
      response = []
    instance = {}

    sequence = [[self.bos] + knowledge] + history + [response + ([self.eos] if with_eos else [])]
    # Prefix each non-knowledge segment with a speaker token: the last segment
    # (the reply) gets speaker1 and earlier turns alternate going backwards.
    sequence_with_speaker = [
      [self.speaker1 if (len(sequence) - i) % 2 == 0 else self.speaker2] + s
      for i, s in enumerate(sequence[1:])
    ]
    sequence = [sequence[0]] + sequence_with_speaker
    instance["input_ids"] = list(chain(*sequence))
    # Segment index parity decides the type id (knowledge segment -> speaker1).
    instance["token_type_ids"] = [self.speaker2 if i % 2 else self.speaker1 for i, s in enumerate(sequence) for _ in s]
    instance["mc_token_ids"] = len(instance["input_ids"]) - 1
    # Only the reply tokens (after its speaker token) are LM targets.
    instance["lm_labels"] = ([-100] * sum(len(s) for s in sequence[:-1])) + [-100] + sequence[-1][1:]

    return instance, sequence

  def __getitem__(self, index):
    raise NotImplementedError

  def __len__(self):
    """Number of examples kept for this split."""
    return len(self.examples)

In [153]:
class DSTC_ResponseGenerationDataset(DSTC_BaseDataset):
    """Dataset whose items are model inputs built by ``build_input_from_segments``.

    NOTE(review): examples currently hold raw turn/snippet text and no
    response, while ``build_input_from_segments`` concatenates lists and reads
    ``bos``/``eos``/``speaker1``/``speaker2`` attributes that nothing in this
    notebook sets. Items can only be built once a tokenization step provides
    those. TODO.
    """

    def __init__(self, split_type, labels=True, labels_file=None):
        super().__init__(split_type, labels, labels_file)

    def __getitem__(self, index):
        example = self.examples[index]
        # `response` is a required argument of build_input_from_segments;
        # omitting it made every call fail with TypeError. Examples carry no
        # response yet, so fall back to an empty one.
        instance, _ = self.build_input_from_segments(
            example["knowledge"],
            example["history"],
            example.get("response", []),
        )
        return instance

class DSTC_ResponseGenerationEvalDataset(DSTC_BaseDataset):
    """Evaluation dataset: items are the raw example dicts, batched unchanged."""

    def __init__(self, split_type, labels=True, labels_file=None):
        super().__init__(split_type, labels, labels_file)

    def __getitem__(self, index):
        return self.examples[index]

    def collate_fn(self, batch):
        # No tensorisation: hand the list of example dicts straight back.
        return batch

In [154]:
# Build the validation split; this also populates the module-level `dstc` and
# `document` frames inspected and exported by the following cells.
dstc_dataset = DSTC_ResponseGenerationDataset('val')

100%|██████████| 4173/4173 [00:00<00:00, 740891.91it/s]
100%|██████████| 4173/4173 [00:02<00:00, 1514.61it/s]


In [155]:
# Number of passages in the retrieval corpus.
document.shape[0]

10882

In [157]:
# Peek at the first corpus rows.
document.head(n=5)

Unnamed: 0,doc_id,text
0,0__0__0-0,A And B Guest House: I was really happy with m...
1,0__0__0-1,"A And B Guest House: I stayed on my own, and I..."
2,0__0__0-2,A And B Guest House: I also thought that my ro...
3,0__0__1-0,A And B Guest House: My husband was pleased to...
4,0__0__1-1,A And B Guest House: We thought it was a bit n...


In [135]:
# Full text of the first corpus passage (index label 0).
document.loc[0, "text"]

'A And B Guest House: I was really happy with my recent stay at A and B Guest House.'

In [158]:
# `json` is already imported in the setup cell; no re-import needed here.
# NOTE(review): this writes the query/positive-passage table `dstc`, not the
# passage corpus `document`, despite the file name. The relative path resolves
# against the kernel's current working directory, and pandas' default orient
# for DataFrame.to_json is "columns". Confirm all three are intended.
dstc.to_json('corpus.json')