# 数据处理模块



**目录：**
1. 标签类别收集

2. 单词词汇表收集

3. 训练样本读取

4. 样本转化为模型可读的特征

---


In [1]:
import argparse

import os
import copy
import json
import logging

import numpy as np
from copy import deepcopy
from collections import Counter
from collections import OrderedDict
from ordered_set import OrderedSet


import torch
from torch.utils.data import TensorDataset, RandomSampler, DataLoader
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)

## 1 训练数据形式回顾


### 四个数据集

<img src="./datasets.png"  width="200" height="300" align="left" />

### 数据集格式

<img src="./数据集格式.png"  width="300" height="300" align="left" />

## 2 词汇表定义

In [2]:
# OrderedSet
#    A mutable data structure that is a hybrid of a list and a set.
#    It remembers the order in which entries were added,
#    and every entry has an index number that can be looked up.

# Build a set from the characters of a string; repeated characters are kept once.
letters = OrderedSet('abracadabra')
print(letters)
print('r' in letters)       # membership test, like a set
print(letters.index('r'))   # entry -> index
print(letters[2])           # index -> entry, like a list

# A new entry is appended after the existing ones.
letters.add('x')
print(letters)

OrderedSet(['a', 'b', 'r', 'c', 'd'])
True
2
r
OrderedSet(['a', 'b', 'r', 'c', 'd', 'x'])


In [3]:
class Alphabet(object):
    """
    Storage and serialization of a set of string elements.

    Every distinct instance gets an integer index in first-seen order.
    A frequency counter of all added instances is kept as well; it is used
    as the fallback for unknown instances when no "<UNK>" sign is
    registered, and it is written out by ``save_content``.
    """

    def __init__(self, name, if_use_pad, if_use_unk):
        """ Create an alphabet, optionally pre-filled with special signs.

        :param name: name of the alphabet, used as file-name prefix on save.
        :param if_use_pad: if true, "<PAD>" is registered first (index 0).
        :param if_use_unk: if true, "<UNK>" is registered right after it.
        """

        self.__name = name
        self.__if_use_pad = if_use_pad
        self.__if_use_unk = if_use_unk

        # index -> instance (ordered and unique) and instance -> index.
        self.__index2instance = OrderedSet()
        self.__instance2index = OrderedDict()

        # Counter Object record the frequency
        # of element occurs in raw text.
        self.__counter = Counter()

        if if_use_pad:
            self.__sign_pad = "<PAD>"
            self.add_instance(self.__sign_pad)
        if if_use_unk:
            self.__sign_unk = "<UNK>"
            self.add_instance(self.__sign_unk)

    @property
    def name(self):
        """ Name given to the alphabet at construction time. """
        return self.__name

    def add_instance(self, instance, multi_intent=False):
        """ Add instances to alphabet.

        1, We support any (possibly nested) list or tuple which
        contains elements of str type, so several instances can
        be added in one call.

        2, We will count added instances that will influence
        the serialization of unknown instance.

        :param instance: is given instance or a list of it.
        :param multi_intent: if true, a string containing '#' is treated
                             as several labels joined by '#' (e.g.
                             "RateBook#SearchScreeningEvent") and every
                             part is added on its own.
        """

        if isinstance(instance, (list, tuple)):
            # Recurse so that nested containers are flattened.
            for element in instance:
                self.add_instance(element, multi_intent=multi_intent)
            return

        # We only support elements of str type.
        assert isinstance(instance, str)
        if multi_intent and '#' in instance:
            for element in instance.split('#'):
                self.add_instance(element, multi_intent=multi_intent)
            return

        # count the frequency of instances.
        self.__counter[instance] += 1

        if instance not in self.__index2instance:
            self.__instance2index[instance] = len(self.__index2instance)
            self.__index2instance.append(instance)

    def get_index(self, instance, multi_intent=False):
        """ Serialize given instance and return.

        For unknown words, the return index of alphabet
        depends on the ``if_use_unk`` flag given to the constructor:

            1, If True, then return the index of "<UNK>";
            2, If False, then return the index of the
            element that hold max frequency in training data.

        :param instance: is given instance or a list of it.
        :param multi_intent: if true, a string containing '#' is split on
                             '#' and a list with one index per part is
                             returned.
        :return: is the serialization of query instance; lists and tuples
                 are serialized element-wise into (nested) lists.
        :raises IndexError: for an unknown instance when no "<UNK>" is
                            used and nothing has been added yet.
        """

        # Recursive form, so that (nested) containers are supported.
        if isinstance(instance, (list, tuple)):
            return [self.get_index(elem, multi_intent=multi_intent) for elem in instance]

        assert isinstance(instance, str)
        if multi_intent and '#' in instance:
            return [self.get_index(element, multi_intent=multi_intent) for element in instance.split('#')]

        try:
            return self.__instance2index[instance]
        except KeyError:
            if self.__if_use_unk:
                return self.__instance2index[self.__sign_unk]
            else:
                # most_common(1) -> [(element, frequency)]
                max_freq_item = self.__counter.most_common(1)[0][0]
                return self.__instance2index[max_freq_item]

    def get_instance(self, index):
        """ Get corresponding instance of query index.

        if index is invalid, then throws exception.

        :param index: is query index, or a (nested) list/tuple of them.
        :return: is corresponding instance, or a list of instances.
        """

        # Accept tuples as well, consistent with add_instance / get_index.
        if isinstance(index, (list, tuple)):
            return [self.get_instance(elem) for elem in index]

        return self.__index2instance[index]

    def save_content(self, dir_path):
        """ Save the content of alphabet to files.

        There are two kinds of saved files:
            1, The first is a list file ("<name>_list.txt"), elements
            are sorted by the frequency of occurrence.

            2, The second is a dictionary file ("<name>_dict.txt"),
            elements are sorted by, and stored with, their index.

        :param dir_path: is the directory path to save object.
        """

        # makedirs also creates missing parent directories and does not
        # fail when the directory already exists.
        os.makedirs(dir_path, exist_ok=True)

        list_path = os.path.join(dir_path, self.__name + "_list.txt")
        with open(list_path, 'w', encoding="utf8") as fw:
            for element, frequency in self.__counter.most_common():
                fw.write(element + '\t' + str(frequency) + '\n')

        dict_path = os.path.join(dir_path, self.__name + "_dict.txt")
        with open(dict_path, 'w', encoding="utf8") as fw:
            for index, element in enumerate(self.__index2instance):
                fw.write(element + '\t' + str(index) + '\n')

    def __len__(self):
        return len(self.__index2instance)

    def __str__(self):
        return 'Alphabet {} contains about {} words: \n\t{}'.format(self.name, len(self), self.__index2instance)


In [None]:
## 3 数据加载与处理

In [4]:
class DatasetManager(object):
    """
    Read the train/dev/test files, build the word/slot/intent alphabets
    from the training file and keep both the raw and the serialized
    (index) form of every split.
    """

    def __init__(self, args):
        """
        :param args: object carrying the run configuration; the attributes
                     read by this class are data_dir, save_dir, num_epoch,
                     batch_size, learning_rate, l2_penalty and
                     slot_forcing_rate.
        """

        # Instantiate alphabet objects.
        self.__word_alphabet = Alphabet('word', if_use_pad=True, if_use_unk=True)
        self.__slot_alphabet = Alphabet('slot', if_use_pad=False, if_use_unk=False)
        self.__intent_alphabet = Alphabet('intent', if_use_pad=False, if_use_unk=False)

        # Record the raw text of dataset.
        self.__text_word_data = {}
        self.__text_slot_data = {}
        self.__text_intent_data = {}

        # Record the serialization of dataset.
        self.__digit_word_data = {}
        self.__digit_slot_data = {}
        self.__digit_intent_data = {}

        self.__args = args

    # The accessors below are exposed as read-only properties. Data and
    # alphabets are handed out as deep copies so that callers cannot
    # modify the internal state.
    @property
    def test_sentence(self):
        return deepcopy(self.__text_word_data['test'])

    @property
    def train_digit_word_data(self):
        return deepcopy(self.__digit_word_data['train'])

    @property
    def word_alphabet(self):
        return deepcopy(self.__word_alphabet)

    @property
    def slot_alphabet(self):
        return deepcopy(self.__slot_alphabet)

    @property
    def intent_alphabet(self):
        return deepcopy(self.__intent_alphabet)

    @property
    def num_epoch(self):
        return self.__args.num_epoch

    @property
    def batch_size(self):
        return self.__args.batch_size

    @property
    def learning_rate(self):
        return self.__args.learning_rate

    @property
    def l2_penalty(self):
        return self.__args.l2_penalty

    @property
    def save_dir(self):
        return self.__args.save_dir

    @property
    def slot_forcing_rate(self):
        return self.__args.slot_forcing_rate

    # Read the data.
    @staticmethod
    def __read_file(file_path):
        """
        Read data file of given path.

        A line with two whitespace-separated fields is a "word slot" pair
        of the current sentence.  A line with a single field is the intent
        label and closes the current sentence.  Every other line (blank,
        or with more than two fields) is ignored.  Tokens that are not
        followed by an intent line are dropped.

        :param file_path: path of data file.
        :return: list of sentence, list of slot and list of intent.
        """

        texts, slots, intents = [], [], []
        text, slot = [], []

        with open(file_path, 'r', encoding="utf8") as fr:
            # Iterate lazily instead of loading all lines at once.
            for line in fr:
                items = line.strip().split()

                if len(items) == 1:   # reached the intent-label line
                    texts.append(text)
                    slots.append(slot)
                    if "/" not in items[0]:
                        intents.append(items)
                    else:
                        # NOTE(review): diagnostic output kept for parity
                        # with the notebook; consider logger.debug instead.
                        print(items)
                        # Only the second '/'-separated part is kept.
                        new = items[0].split("/")
                        intents.append([new[1]])

                    # clear buffer lists.
                    text, slot = [], []

                elif len(items) == 2:
                    text.append(items[0].strip())
                    slot.append(items[1].strip())

        return texts, slots, intents

    def add_file(self, file_path, data_name, if_train_file):
        """
        Read one data file and register it under the given split name.

        :param file_path: path of the data file.
        :param data_name: key of the split, e.g. 'train', 'dev' or 'test'.
        :param if_train_file: if true, the alphabets are extended with the
                              content of the file and the slot / intent
                              labels are serialized as well; otherwise
                              only the words are serialized.
        """
        text, slot, intent = self.__read_file(file_path)

        if if_train_file:
            self.__word_alphabet.add_instance(text)
            self.__slot_alphabet.add_instance(slot)
            self.__intent_alphabet.add_instance(intent, multi_intent=True)

        # Record the raw text of dataset.
        self.__text_word_data[data_name] = text
        self.__text_slot_data[data_name] = slot
        self.__text_intent_data[data_name] = intent

        # Serialize raw text and stored it.
        self.__digit_word_data[data_name] = self.__word_alphabet.get_index(text)
        if if_train_file:
            self.__digit_slot_data[data_name] = self.__slot_alphabet.get_index(slot)
            self.__digit_intent_data[data_name] = self.__intent_alphabet.get_index(intent, multi_intent=True)

    def quick_build(self):
        """
        Convenient function to instantiate a dataset object.

        Loads train.txt, dev.txt and test.txt from args.data_dir and
        saves the three alphabets below args.save_dir/alphabet.
        """

        train_path = os.path.join(self.__args.data_dir, 'train.txt')
        dev_path = os.path.join(self.__args.data_dir, 'dev.txt')
        test_path = os.path.join(self.__args.data_dir, 'test.txt')

        # add_file: read the data and do the preliminary processing.
        self.add_file(train_path, 'train', if_train_file=True)
        self.add_file(dev_path, 'dev', if_train_file=False)
        self.add_file(test_path, 'test', if_train_file=False)

        # Make sure the save path exists; makedirs also creates missing
        # parent directories (os.mkdir would fail on a nested path).
        os.makedirs(self.save_dir, exist_ok=True)

        alphabet_dir = os.path.join(self.__args.save_dir, "alphabet")
        self.__word_alphabet.save_content(alphabet_dir)
        self.__slot_alphabet.save_content(alphabet_dir)
        self.__intent_alphabet.save_content(alphabet_dir)


In [7]:
# Build the arguments first.

# parser = argparse.ArgumentParser()

# In real use these would be passed on the command line; here they are simply assigned directly.
# parser.add_argument("--task", default=None, required=True, type=str, help="The name of the task to train")
# parser.add_argument("--data_dir", default="./data", type=str, help="The input data dir")
# parser.add_argument("--intent_label_file", default="intent_label.txt", type=str, help="Intent Label file")
# parser.add_argument("--slot_label_file", default="slot_label.txt", type=str, help="Slot Label file")

# args = parser.parse_args()

# Minimal stand-in for the parsed command-line arguments.
class Args():
    task =  None
    

args = Args()
# Only data_dir and save_dir are read by DatasetManager.quick_build().
args.data_dir = "../data/MixATIS_clean"
args.save_dir = "../save/MixATIS_clean"

# Instantiate a dataset object.
dataset = DatasetManager(args)
dataset.quick_build()
# dataset.show_summary()

# Number of test sentences, number of training samples, and the word
# indices of one training sample.
print(len(dataset.test_sentence))
print(len(dataset.train_digit_word_data))
print(dataset.train_digit_word_data[10])

828
13162
[39, 95, 42, 109, 110, 5, 11, 12, 92, 62, 53, 86, 56, 79, 27, 111, 66, 88, 112, 68, 50, 113, 9, 102, 44, 45, 114, 95, 9, 115, 24, 64]


## 4 转化为特征

In [None]:
class DatasetManager(object):
    """
    Read the train/dev/test files, build the word/slot/intent alphabets
    from the training file, keep both the raw and the serialized (index)
    form of every split and deliver them in batches.
    """

    def __init__(self, args):
        """
        :param args: object carrying the run configuration; the attributes
                     read by this class are data_dir, save_dir, num_epoch,
                     batch_size, learning_rate, l2_penalty and
                     slot_forcing_rate.
        """

        # Instantiate alphabet objects.
        self.__word_alphabet = Alphabet('word', if_use_pad=True, if_use_unk=True)
        self.__slot_alphabet = Alphabet('slot', if_use_pad=False, if_use_unk=False)
        self.__intent_alphabet = Alphabet('intent', if_use_pad=False, if_use_unk=False)

        # Record the raw text of dataset.
        self.__text_word_data = {}
        self.__text_slot_data = {}
        self.__text_intent_data = {}

        # Record the serialization of dataset.
        self.__digit_word_data = {}
        self.__digit_slot_data = {}
        self.__digit_intent_data = {}

        self.__args = args

    # The accessors below are exposed as read-only properties. Data and
    # alphabets are handed out as deep copies so that callers cannot
    # modify the internal state.
    @property
    def test_sentence(self):
        return deepcopy(self.__text_word_data['test'])

    @property
    def train_digit_word_data(self):
        return deepcopy(self.__digit_word_data['train'])

    @property
    def word_alphabet(self):
        return deepcopy(self.__word_alphabet)

    @property
    def slot_alphabet(self):
        return deepcopy(self.__slot_alphabet)

    @property
    def intent_alphabet(self):
        return deepcopy(self.__intent_alphabet)

    @property
    def num_epoch(self):
        return self.__args.num_epoch

    @property
    def batch_size(self):
        return self.__args.batch_size

    @property
    def learning_rate(self):
        return self.__args.learning_rate

    @property
    def l2_penalty(self):
        return self.__args.l2_penalty

    @property
    def save_dir(self):
        return self.__args.save_dir

    @property
    def slot_forcing_rate(self):
        return self.__args.slot_forcing_rate

    # Read the data.
    @staticmethod
    def __read_file(file_path):
        """
        Read data file of given path.

        A line with two whitespace-separated fields is a "word slot" pair
        of the current sentence.  A line with a single field is the intent
        label and closes the current sentence.  Every other line (blank,
        or with more than two fields) is ignored.  Tokens that are not
        followed by an intent line are dropped.

        :param file_path: path of data file.
        :return: list of sentence, list of slot and list of intent.
        """

        texts, slots, intents = [], [], []
        text, slot = [], []

        with open(file_path, 'r', encoding="utf8") as fr:
            # Iterate lazily instead of loading all lines at once.
            for line in fr:
                items = line.strip().split()

                if len(items) == 1:   # reached the intent-label line
                    texts.append(text)
                    slots.append(slot)
                    if "/" not in items[0]:
                        intents.append(items)
                    else:
                        # NOTE(review): diagnostic output kept for parity
                        # with the notebook; consider logger.debug instead.
                        print(items)
                        # Only the second '/'-separated part is kept.
                        new = items[0].split("/")
                        intents.append([new[1]])

                    # clear buffer lists.
                    text, slot = [], []

                elif len(items) == 2:
                    text.append(items[0].strip())
                    slot.append(items[1].strip())

        return texts, slots, intents

    def add_file(self, file_path, data_name, if_train_file):
        """
        Read one data file and register it under the given split name.

        :param file_path: path of the data file.
        :param data_name: key of the split, e.g. 'train', 'dev' or 'test'.
        :param if_train_file: if true, the alphabets are extended with the
                              content of the file and the slot / intent
                              labels are serialized as well; otherwise
                              only the words are serialized.
        """
        text, slot, intent = self.__read_file(file_path)

        if if_train_file:
            self.__word_alphabet.add_instance(text)
            self.__slot_alphabet.add_instance(slot)
            self.__intent_alphabet.add_instance(intent, multi_intent=True)

        # Record the raw text of dataset.
        self.__text_word_data[data_name] = text
        self.__text_slot_data[data_name] = slot
        self.__text_intent_data[data_name] = intent

        # Serialize raw text and stored it.
        self.__digit_word_data[data_name] = self.__word_alphabet.get_index(text)
        if if_train_file:
            self.__digit_slot_data[data_name] = self.__slot_alphabet.get_index(slot)
            self.__digit_intent_data[data_name] = self.__intent_alphabet.get_index(intent, multi_intent=True)

    def quick_build(self):
        """
        Convenient function to instantiate a dataset object.

        Loads train.txt, dev.txt and test.txt from args.data_dir and
        saves the three alphabets below args.save_dir/alphabet.
        """

        train_path = os.path.join(self.__args.data_dir, 'train.txt')
        dev_path = os.path.join(self.__args.data_dir, 'dev.txt')
        test_path = os.path.join(self.__args.data_dir, 'test.txt')

        # add_file: read the data and do the preliminary processing.
        self.add_file(train_path, 'train', if_train_file=True)
        self.add_file(dev_path, 'dev', if_train_file=False)
        self.add_file(test_path, 'test', if_train_file=False)

        # Make sure the save path exists; makedirs also creates missing
        # parent directories (os.mkdir would fail on a nested path).
        os.makedirs(self.save_dir, exist_ok=True)

        alphabet_dir = os.path.join(self.__args.save_dir, "alphabet")
        self.__word_alphabet.save_content(alphabet_dir)
        self.__slot_alphabet.save_content(alphabet_dir)
        self.__intent_alphabet.save_content(alphabet_dir)

    def batch_delivery(self, data_name, batch_size=None, is_digital=True, shuffle=True):
        """
        Build a DataLoader over one registered split.

        :param data_name: key of the split given to add_file.
        :param batch_size: samples per batch; defaults to args.batch_size.
        :param is_digital: if true deliver the serialized (index) data,
                           otherwise the raw text.  Serialized slot and
                           intent data are only stored for files added
                           with if_train_file=True, so other splits raise
                           KeyError when is_digital is true.
        :param shuffle: passed on to the DataLoader.
        :return: DataLoader whose batches are [texts, slots, intents].
        """
        if batch_size is None:
            batch_size = self.batch_size

        if is_digital:
            text = self.__digit_word_data[data_name]
            slot = self.__digit_slot_data[data_name]
            intent = self.__digit_intent_data[data_name]
        else:
            text = self.__text_word_data[data_name]
            slot = self.__text_slot_data[data_name]
            intent = self.__text_intent_data[data_name]
        # NOTE(review): TorchDataset is not defined in the visible part of
        # this file; it has to be provided elsewhere.
        dataset = TorchDataset(text, slot, intent)

        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=self.__collate_fn)

    @staticmethod
    def add_padding(texts, items=None, digital=True):
        """
        Sort a batch by decreasing length and pad it to the longest text.

        The inputs are not modified; copies are padded and returned.

        :param texts: list of sequences (lists).
        :param items: optional list of (sequences, required) pairs that
                      are reordered like texts; sequences whose required
                      flag is true are padded as well, by the same amount
                      as the corresponding text.
        :param digital: pad with 0 if true, with '<PAD>' otherwise.
        :return: (padded texts, original lengths) or, when items is given,
                 (padded texts, reordered items, original lengths).
        """
        len_list = [len(text) for text in texts]
        # default=0 keeps an empty batch from raising ValueError.
        max_len = max(len_list, default=0)

        # Indices of the samples ordered from longest to shortest.
        sorted_index = np.argsort(len_list)[::-1]

        pad_symbol = 0 if digital else '<PAD>'

        trans_texts, seq_lens, trans_items = [], [], None
        if items is not None:
            trans_items = [[] for _ in range(0, len(items))]

        for index in sorted_index:
            n_pad = max_len - len_list[index]

            seq_lens.append(len_list[index])
            padded_text = deepcopy(texts[index])
            padded_text.extend([pad_symbol] * n_pad)
            trans_texts.append(padded_text)

            # Every extra item follows the new sample order; padding is
            # applied only to those flagged as required.
            if items is not None:
                for item, (o_item, required) in zip(trans_items, items):
                    copied = deepcopy(o_item[index])
                    if required:
                        copied.extend([pad_symbol] * n_pad)
                    item.append(copied)

        if items is not None:
            return trans_texts, trans_items, seq_lens
        else:
            return trans_texts, seq_lens

    @staticmethod
    def __collate_fn(batch):
        """
        helper function to instantiate a DataLoader Object.

        Turns a list of samples into one list per field, e.g.
        [(t1, s1, i1), (t2, s2, i2)] -> [[t1, t2], [s1, s2], [i1, i2]].
        The number of fields is taken from the first sample.
        """

        n_entity = len(batch[0])
        return [[sample[jdx] for sample in batch] for jdx in range(0, n_entity)]