In [1]:
# Add the parent directory to the module search path.
# NOTE(review): presumably this is what makes `data_utils` importable below — confirm the repo layout.
import sys
sys.path.append('..')

In [2]:
import json

import torch
from tqdm import tqdm
from transformers import BertTokenizer

from data_utils import get_examples_from_dialogues, convert_state_dict, load_dataset
from data_utils import OntologyDSTFeature, DSTPreprocessor, _truncate_seq_pair

## Data Loading 

In [3]:
# All dataset artifacts live in one directory; keep that location in a single place.
data_dir = "/opt/ml/repo/taepd/input/data/train_dataset"
train_data_file = f"{data_dir}/train_dials.json"

# Context managers close the file handles deterministically, and the explicit
# UTF-8 encoding keeps the non-ASCII (Korean) ontology keys intact regardless
# of the platform's default locale.
with open(f"{data_dir}/slot_meta.json", encoding="utf-8") as f:
    slot_meta = json.load(f)
with open(f"{data_dir}/ontology.json", encoding="utf-8") as f:
    ontology = json.load(f)

# load_dataset returns the training split, the dev split and the dev labels.
train_data, dev_data, dev_labels = load_dataset(train_data_file)

In [4]:
# Both splits are converted with the same options, so state them once.
example_options = {"user_first": True, "dialogue_level": True}
train_examples = get_examples_from_dialogues(data=train_data, **example_options)
dev_examples = get_examples_from_dialogues(data=dev_data, **example_options)

100%|██████████| 6301/6301 [00:00<00:00, 9124.00it/s]
100%|██████████| 699/699 [00:00<00:00, 15943.24it/s]


In [5]:
# Number of entries in the training split returned by load_dataset.
len(train_data)

6301

In [6]:
# Tokenizer used to encode every utterance.
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
# Largest `dialogue` length found in the training split; passed to the
# preprocessor as max_turn_length further down.
max_turn = max(len(dialogue_entry['dialogue']) for dialogue_entry in train_data)

## TODO-1: SUMBT Preprocessor 정의 

Ontology-based DST model인 SUMBT의 InputFeature를 만들기 위한 Preprocessor를 정의해야 합니다. <br>

1. `_convert_example_to_feature` 함수의 빈칸을 메워 완성하세요.
2. `recover_state` 함수의 빈칸을 메워 완성하세요.

In [7]:
class SUMBTPreprocessor(DSTPreprocessor):
    """Builds dialogue-level SUMBT input features and decodes predictions.

    Args:
        slot_meta: Sequence of slot names; its order defines the order of the
            per-turn label vector.
        src_tokenizer: Tokenizer used to encode utterances. It must provide
            ``encode`` plus the ``cls_token_id`` / ``sep_token_id`` /
            ``pad_token_id`` attributes.
        trg_tokenizer: Optional second tokenizer; falls back to
            ``src_tokenizer`` when omitted.
        ontology: Mapping from slot name to the list of its candidate values.
            A label id is the position of a value inside that list.
        max_seq_length: Number of token ids per turn after padding/truncation.
        max_turn_length: Number of turns per dialogue after padding/truncation.
    """

    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=64,
        max_turn_length=14,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.max_seq_length = max_seq_length
        self.max_turn_length = max_turn_length

    def _require_ontology(self):
        """Return the ontology given at construction time.

        Raises:
            ValueError: If the preprocessor was built without an ontology.
                Failing here avoids silently picking up an unrelated global
                variable that happens to be called ``ontology``.
        """
        if self.ontology is None:
            raise ValueError(
                "SUMBTPreprocessor needs `ontology` to map slot values to label ids."
            )
        return self.ontology

    def _convert_example_to_feature(self, example):
        """Convert one dialogue (a sequence of turn examples) into a feature.

        Every turn becomes ``[CLS] utt0 [SEP] utt1 [SEP]`` padded to
        ``max_seq_length``; the dialogue itself is truncated/padded to
        ``max_turn_length`` turns. Padding turns carry ``-1`` labels.

        Args:
            example: Sequence of turn examples of a single dialogue. Each turn
                exposes ``guid``, ``current_turn`` (two utterances) and
                ``label``.

        Returns:
            OntologyDSTFeature holding the padded ids, segment ids, the number
            of real (non-padding) turns and the per-turn label ids.

        Raises:
            ValueError: If no ontology was supplied, or (from ``list.index``)
                if a labelled value is missing from the slot's ontology list.
        """
        ontology = self._require_ontology()
        tokenizer = self.src_tokenizer
        # Dialogue id: the guid without its trailing "-<turn>" part,
        # e.g. 'snowy-hat-8324:관광_식당_11(-0)'.
        guid = example[0].guid.rsplit("-", 1)[0]
        turns = []
        token_types = []
        labels = []
        for turn in example[: self.max_turn_length]:
            assert len(turn.current_turn) == 2
            uttrs = [
                tokenizer.encode(uttr, add_special_tokens=False)
                for uttr in turn.current_turn
            ]

            # Reserve 3 positions for [CLS] and the two [SEP] tokens.
            # NOTE(review): relies on _truncate_seq_pair trimming the two
            # lists in place — confirm in data_utils.
            _truncate_seq_pair(uttrs[0], uttrs[1], self.max_seq_length - 3)
            tokens = (
                [tokenizer.cls_token_id]
                + uttrs[0]
                + [tokenizer.sep_token_id]
                + uttrs[1]
                + [tokenizer.sep_token_id]
            )
            # Segment 0 covers [CLS] + first utterance + [SEP]; segment 1
            # covers the second utterance + final [SEP].
            token_type = [0] * (len(uttrs[0]) + 2) + [1] * (len(uttrs[1]) + 1)
            gap = self.max_seq_length - len(tokens)
            if gap > 0:
                tokens.extend([tokenizer.pad_token_id] * gap)
                token_type.extend([0] * gap)
            turns.append(tokens)
            token_types.append(token_type)

            slot_dict = convert_state_dict(turn.label) if turn.label else {}
            # A slot that is absent from the turn's state is labelled "none".
            labels.append(
                [
                    ontology[slot_type].index(slot_dict.get(slot_type, "none"))
                    for slot_type in self.slot_meta
                ]
            )

        # Count the real turns before dummy turns are appended.
        num_turn = len(turns)
        for _ in range(self.max_turn_length - num_turn):
            turns.append([tokenizer.pad_token_id] * self.max_seq_length)
            # Segment ids of a padding turn are zeros, matching the in-turn
            # padding above, rather than a reused list of pad token ids.
            token_types.append([0] * self.max_seq_length)
            labels.append([-1] * len(self.slot_meta))

        return OntologyDSTFeature(
            guid=guid,
            input_ids=turns,
            segment_ids=token_types,
            num_turn=num_turn,
            target_ids=labels,
        )

    def convert_examples_to_features(self, examples):
        """Convert every dialogue in ``examples`` and return the features as a list."""
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, pred_slots, num_turn):
        """Turn predicted label ids back into "slot-value" strings.

        Args:
            pred_slots: Per-turn sequences of predicted value indices, one
                index per slot in ``slot_meta`` order.
            num_turn: Number of real turns; later (padding) turns are ignored.

        Returns:
            A list with exactly one entry per real turn. Each entry is the
            list of "slot-value" strings of that turn, with slots predicted
            as "none" left out.

        Raises:
            ValueError: If no ontology was supplied.
        """
        ontology = self._require_ontology()
        states = []
        for pred_slot in pred_slots[:num_turn]:
            state = []
            for slot, value_idx in zip(self.slot_meta, pred_slot):
                value = ontology[slot][value_idx]
                if value != 'none':
                    state.append(f'{slot}-{value}')
            # One state per turn: append after all slots have been decoded.
            states.append(state)
        return states

    def collate_fn(self, batch):
        """Stack a list of features into batch tensors.

        Returns:
            Tuple ``(input_ids, segment_ids, input_masks, target_ids,
            num_turns, guids)``. The first four are tensors; ``input_masks``
            is True wherever ``input_ids`` differs from the pad token id.
            ``num_turns`` and ``guids`` stay plain Python lists.
        """
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor([b.input_ids for b in batch])
        segment_ids = torch.LongTensor([b.segment_ids for b in batch])
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)
        target_ids = torch.LongTensor([b.target_ids for b in batch])
        num_turns = [b.num_turn for b in batch]
        return input_ids, segment_ids, input_masks, target_ids, num_turns, guids

## Convert_Examples_to_Features 

In [8]:
processor = SUMBTPreprocessor(
    slot_meta=slot_meta,
    src_tokenizer=tokenizer,
    ontology=ontology,  # predefined ontology
    max_seq_length=64,  # maximum token length of each turn
    max_turn_length=max_turn,  # maximum number of turns per dialogue
)
# Convert each split separately so the two feature lists stay independent.
train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

In [9]:
# Dialogue-level features: one feature per converted example in each split.
print(len(train_features), len(dev_features), sep="\n")

6301
699


In [14]:
# Dialogue-level id: the guid with everything after its last "-" removed.
# NOTE(review): this value is never shown — a cell only displays its last expression.
train_examples[0][0].guid.rsplit("-", 1)[0] 
# Full guid of the first element of the first dialogue-level example.
train_examples[0][0].guid 


DSTInputExample(guid='snowy-hat-8324:관광_식당_11-1', context_turns=['', '서울 중앙에 있는 박물관을 찾아주세요'], current_turn=['좋네요 거기 평점은 말해주셨구 전화번호가 어떻게되나요?', '안녕하세요. 문화역서울 284은 어떠신가요? 평점도 4점으로 방문객들에게 좋은 평가를 받고 있습니다.'], label=['관광-종류-박물관', '관광-지역-서울 중앙', '관광-이름-문화역서울 284'])

In [24]:
# Replays the per-turn tokenization of SUMBTPreprocessor._convert_example_to_feature
# on a single turn so the intermediate values can be inspected.
current_turn = train_examples[0][1].current_turn
uttrs = [tokenizer.encode(uttr, add_special_tokens=False) for uttr in current_turn]
# 64 is the per-turn length used for the preprocessor; 3 positions are
# reserved for [CLS] and the two [SEP] tokens added below.
_truncate_seq_pair(uttrs[0], uttrs[1], 64 - 3)
tokens = (
    [tokenizer.cls_token_id]
    + uttrs[0]
    + [tokenizer.sep_token_id]
    + uttrs[1]
    + [tokenizer.sep_token_id]
)

In [26]:
# Segment ids for `tokens`: 0 for [CLS] + first utterance + its [SEP],
# 1 for the second utterance + the final [SEP].
token_type = [0] * (len(uttrs[0]) + 2) + [1] * (len(uttrs[1]) + 1)

In [30]:
# Label strings attached to the second element of the first dialogue-level example.
train_examples[0][1].label

['관광-종류-박물관', '관광-지역-서울 중앙', '관광-이름-문화역서울 284']

In [33]:
# Peek at a single ontology entry: the slot name first, then its list of values.
for key, value in ontology.items():
    print(key, value, sep="\n")
    break

관광-경치 좋은
['none', 'dontcare', 'yes', 'no']


In [34]:
# Check which container type json.load produced for slot_meta.json.
type(slot_meta)

list

In [37]:
# Position of 'yes' inside this slot's value list; the preprocessor uses the
# same list.index lookup to turn a slot value into a label id.
ontology['관광-경치 좋은'].index('yes')

2