In [45]:
import json
import re
from collections import Counter

from utils.data2seq import Dial2seq

In [46]:
# Build sequencers over the annotated Topical Chat and Daily Dialogue files.
# NOTE(review): the second argument (3) is presumably the number of
# utterances per generated sequence — confirm against utils.data2seq.Dial2seq.
topical_sequencer = Dial2seq('data/topical_chat_annotated_v3.json', 3)
daily_sequencer = Dial2seq('data/daily_dialogue_annotated_v3.json', 3)

In [47]:
# Size of each sequencer's loaded `.data` (Daily Dialogue first, then Topical Chat).
len(daily_sequencer.data), len(topical_sequencer.data)

(12376, 8628)

In [48]:
class SequencePreprocessor():
    """
    Preprocess dialogue sequences into dataset entries for the task.

    A sequence is a list of utterance dicts. All utterances but the last
    form the context; the first sentence of the last utterance is the
    prediction target. Utterances are read through their 'text', 'midas'
    and 'entities' keys, each of which holds one item per sentence.

    params:
    stoplist_labels: entity labels to ignore in the target sentence -
        entities with these labels are neither substituted into the target
        text nor kept in its entity list (default: ['misc', 'anaphor'])
    seq_validator: None, or an object with an ``is_valid(utterance)`` method
        (such as a Validator subclass); when set, only sequences whose final
        utterance it accepts are kept
    """

    def __init__(self,
                 stoplist_labels: 'list | None' = None,
                 seq_validator=None):
        # Build the default per instance: a list literal as the default value
        # would be a single object shared (and mutable) across all instances.
        if stoplist_labels is None:
            stoplist_labels = ['misc', 'anaphor']
        self.stoplist_labels = stoplist_labels
        self.seq_validator = seq_validator

    def transform(self, sequences: list) -> list:
        """
        Convert ``sequences`` into a list of dataset entries (dicts).

        Sequences whose final utterance is rejected by ``seq_validator``
        (when one is set) are skipped.
        """
        entries = []

        for seq in sequences:
            # only the final (target) utterance is validated
            if self.seq_validator is not None and not self.seq_validator.is_valid(seq[-1]):
                continue
            entries.append(self.__get_dict_entry(self.__shape_output(seq)))

        return entries

    def __shape_output(self, seq: list) -> list:
        """
        Reduce a sequence to the fields needed for the dataset.

        Returns one tuple ``(text, midas_labels, midas_vectors, entities)``
        per context utterance, followed by the target tuple
        ``(sentence, midas_label, entities)``. The target is built from the
        lower-cased first sentence of the last utterance. Every non-stoplisted
        entity mention in it is replaced by the entity's upper-cased label.
        """
        output = []

        # context: keep every sentence of every preceding utterance
        for ut in seq[:-1]:
            midas_labels, midas_vectors = self.__get_midas(ut['midas'])
            output.append((ut['text'], midas_labels, midas_vectors, ut['entities']))

        # target: only the first sentence of the last utterance
        target = seq[-1]
        midas_label = self.__get_midas(target['midas'])[0][0]
        sentence = target['text'][0].lower()
        # guard against a missing (None) entity annotation for the sentence
        entities = target['entities'][0] or []

        # Drop stoplisted labels, then sort longest mention first so that a
        # short mention contained in a longer one does not pre-empt it.
        entities = sorted(
            [e for e in entities if e['label'] not in self.stoplist_labels],
            key=lambda e: len(e['text']),
            reverse=True,
        )

        # Replace whole-token mentions only: plain str.replace would also hit
        # substrings of other words (entity 'us' inside 'music').
        for ent in entities:
            mention = ent['text'].lower()
            if not mention:
                # an empty pattern would match between every character
                continue
            label = ent['label'].upper()
            pattern = r'(?<!\w)' + re.escape(mention) + r'(?!\w)'
            # a callable replacement keeps the label literal (no backslash escapes)
            sentence = re.sub(pattern, lambda _match: label, sentence)

        output.append((sentence, midas_label, entities))

        return output

    def __get_dict_entry(self, seq: list) -> dict:
        """
        Build the JSON-serialisable dict for one shaped sequence, i.e. the
        list returned by ``__shape_output`` (context tuples, then target).
        """
        *context, target = seq
        return {
            'previous_text': [c[0] for c in context],
            'previous_midas': [c[1] for c in context],
            'midas_vectors': [c[2] for c in context],
            'previous_entities': [c[3] for c in context],
            'predict': {
                'text': target[0],
                'midas': target[1],
                'entities': target[2],
            },
        }

    def __get_midas(self, midas_labels: list) -> tuple:
        """
        Summarise the per-sentence MIDAS annotations of an utterance.

        ``midas_labels`` holds one ``{label: value}`` dict per sentence.
        Returns ``(labels, vectors)``, two lists aligned with the sentences:
        the label with the maximum value, and the list of all values in the
        dict's own key order.
        """
        labels = []
        vectors = []

        for per_sentence in midas_labels:
            labels.append(max(per_sentence, key=per_sentence.get))
            vectors.append(list(per_sentence.values()))

        return labels, vectors

In [49]:
from abc import ABC, abstractmethod

class Validator(ABC):
    """Interface for filters applied to a single utterance dict."""

    @abstractmethod
    def is_valid(self, seq: dict) -> bool:
        """Return True if the utterance ``seq`` should be kept."""


class OneEntity(Validator):
    """
    Accept an utterance only if its first sentence has exactly one annotated
    entity and that entity's label is not in the stoplist.

    params:
    stoplist: entity labels that make an utterance invalid
        (default: ['misc', 'anaphor', 'film', 'song', 'literary_work'])
    """

    def __init__(self, stoplist: 'list | None' = None):
        # Build the default per instance instead of sharing one mutable list
        # literal across every OneEntity created without an argument.
        if stoplist is None:
            stoplist = ['misc', 'anaphor', 'film', 'song', 'literary_work']
        self.stoplist = stoplist

    def is_valid(self, seq: dict) -> bool:
        """
        Check that the first sentence of the utterance ``seq`` has exactly one
        annotated entity and that its label is not in the stoplist.

        A missing annotation counts as "no entities", so the utterance is
        rejected. That covers an empty 'entities' list and a None/empty first
        item.
        """
        per_sentence = seq['entities']
        first = per_sentence[0] if per_sentence else None
        if not first or len(first) != 1:
            return False

        return first[0]['label'] not in self.stoplist

## Daily Dialogue

In [50]:
# Produce the Daily Dialogue sequences consumed by SequencePreprocessor.transform below.
daily = daily_sequencer.transform()

In [51]:
len(daily)

61087

In [52]:
# No validator: every sequence is kept.
daily_preproc = SequencePreprocessor()

In [53]:
daily_dataset = daily_preproc.transform(daily)
len(daily_dataset)

61087

In [54]:
# Dump the unfiltered dataset.
with open('data/daily_dataset_v3.json', 'w', encoding='utf-8') as f:
    json.dump(daily_dataset, f, ensure_ascii=False, indent=2)

In [55]:
# Keep only sequences whose final utterance's first sentence has exactly one
# entity outside the OneEntity stoplist. Rebinds daily_dataset; the
# unfiltered version was already written to disk in the previous cell.
daily_preproc = SequencePreprocessor(seq_validator=OneEntity())
daily_dataset = daily_preproc.transform(daily)
len(daily_dataset)

3894

In [56]:
# Dump the single-entity subset.
with open('data/single_entity_daily_dataset_v3.json', 'w', encoding='utf-8') as f:
    json.dump(daily_dataset, f, ensure_ascii=False, indent=2)

## Topical Chat

In [57]:
# Produce the Topical Chat sequences consumed by SequencePreprocessor.transform below.
topical = topical_sequencer.transform()

In [58]:
len(topical)

162494

In [59]:
# No validator: every sequence is kept.
topical_preproc = SequencePreprocessor()

In [60]:
topical_dataset = topical_preproc.transform(topical)
len(topical_dataset)

162494

In [61]:
# Dump the unfiltered dataset.
with open('data/topical_dataset_v3.json', 'w', encoding='utf-8') as f:
    json.dump(topical_dataset, f, ensure_ascii=False, indent=2)

In [62]:
# Keep only sequences whose final utterance's first sentence has exactly one
# entity outside the OneEntity stoplist. Rebinds topical_dataset; the
# unfiltered version was already written to disk in the previous cell.
topical_preproc = SequencePreprocessor(seq_validator=OneEntity())
topical_dataset = topical_preproc.transform(topical)
len(topical_dataset)

7800

In [63]:
# Dump the single-entity subset.
with open('data/single_entity_topical_dataset_v3.json', 'w', encoding='utf-8') as f:
    json.dump(topical_dataset, f, ensure_ascii=False, indent=2)