
Training a Dialogue State Tracking Model with HuggingFace Transformers

This notebook walks through an example of training the `klue/bert-base` model on the **WoS** (Wizard of Seoul) dialogue state tracking dataset from **KLUE**.


After training, a short example shows how the trained model can be used.

All source code is based on the [`huggingface-tutorial`](https://huggingface.co/course/chapter7/2).

First, install the libraries needed to run the notebook. `transformers` is required for model training and `datasets` for loading the training data. In addition, install `scipy` and `scikit-learn` for evaluating model performance.

In [None]:
#https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a14624
#https://medium.com/@shahrukhx01/multi-task-learning-with-transformers-part-1-multi-prediction-heads-b7001cf014bf

In [None]:
!pip install  evaluate 
!pip install accelerate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting accelerate
  Downloading accelerate-0.18.0-py3-none-any.whl (215 kB)
Installing collected packages: accelerate
Successfully installed accelerate-0.18.0
Reading package lists... Done
Building dependency tree       
Reading state information... Done
git-lfs is already the newest version (2.9.2-1).
0 upgraded, 0 newly installed, 0 to remove and 23 not upgraded.


In [None]:
!pip install -U transformers datasets scipy scikit-learn overrides

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
#from huggingface_hub import notebook_login

#notebook_login()

## Training the Dialogue State Tracking Model

Import all the libraries needed to run the notebook.

In [1]:
import random
import logging
from IPython.display import display, HTML

import numpy as np
import pandas as pd
import datasets
from datasets import load_dataset, load_metric, ClassLabel  # `Sequence` comes from `typing` below, which would shadow the `datasets` one anyway
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoModel, PreTrainedTokenizer

from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn
import torch.nn.functional as F

from overrides import overrides



import argparse
import json
import os


ModuleNotFoundError: ignored

Record the settings needed for training in variables.

This notebook uses the `klue/bert-base` model, but many more pretrained language models are listed at https://huggingface.co/klue.

The training task is set to `wos` and the batch size to 64.

In [None]:
model_checkpoint = "klue/bert-base"
batch_size = 64
task = "wos"

Next, download the WoS data from the KLUE datasets registered in the HuggingFace `datasets` library.

In [None]:
#['ynat', 'sts', 'nli', 'ner', 're', 'dp', 'mrc', 'wos']
raw_datasets = load_dataset("klue", task)



  0%|          | 0/2 [00:00<?, ?it/s]

Inspecting the `datasets` object returned after downloading or loading shows that it contains a training split and a validation split.

In [None]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig()


In [None]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['guid', 'domains', 'dialogue'],
        num_rows: 8000
    })
    validation: Dataset({
        features: ['guid', 'domains', 'dialogue'],
        num_rows: 1000
    })
})

In [None]:
raw_datasets['train']['dialogue'][0]

[{'role': 'user',
  'text': '서울 중앙에 있는 박물관을 찾아주세요',
  'state': ['관광-종류-박물관', '관광-지역-서울 중앙']},
 {'role': 'sys',
  'text': '안녕하세요. 문화역서울 284은 어떠신가요? 평점도 4점으로 방문객들에게 좋은 평가를 받고 있습니다.',
  'state': []},
 {'role': 'user',
  'text': '좋네요 거기 평점은 말해주셨구 전화번호가 어떻게되나요?',
  'state': ['관광-종류-박물관', '관광-지역-서울 중앙', '관광-이름-문화역서울 284']},
 {'role': 'sys', 'text': '전화번호는 983880764입니다. 더 필요하신 게 있으실까요?', 'state': []},
 {'role': 'user',
  'text': '네 관광지와 같은 지역의 한식당을 가고싶은데요 야외석이 있어야되요',
  'state': ['관광-종류-박물관',
   '관광-지역-서울 중앙',
   '관광-이름-문화역서울 284',
   '식당-지역-서울 중앙',
   '식당-종류-한식당',
   '식당-야외석 유무-yes']},
 {'role': 'sys', 'text': '생각하고 계신 가격대가 있으신가요?', 'state': []},
 {'role': 'user',
  'text': '음.. 저렴한 가격대에 있나요?',
  'state': ['관광-종류-박물관',
   '관광-지역-서울 중앙',
   '관광-이름-문화역서울 284',
   '식당-가격대-저렴',
   '식당-지역-서울 중앙',
   '식당-종류-한식당',
   '식당-야외석 유무-yes']},
 {'role': 'sys', 'text': '죄송하지만 저렴한 가격대에는 없으시네요.', 'state': []},
 {'role': 'user',
  'text': '그럼 비싼 가격대로 다시 찾아주세요',
  'state': ['관광-종류-박물관',
   '관광-지역-서울 중앙',
   '관광

In [None]:
raw_datasets['train']['domains'][:3]

[['관광', '식당'], ['관광'], ['식당', '관광', '지하철']]

In [None]:
# When a specific value appears (e.g. yes / no), it gets its own gate id
# Goal: at every turn, predict the gate id and classify the words in the value as the target
gating_id = []
gating2id = {"none": 0, "dontcare": 1, "ptr": 2, "yes": 3, "no": 4}
gating_id.append(gating2id.get("박물관", gating2id['ptr']))
print(gating_id)

[2]
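
The lookup above generalizes to every slot value: the reserved values (`none`, `dontcare`, `yes`, `no`) each get their own gate, and any other string falls back to `ptr`, meaning the value has to be generated token by token. A minimal sketch of that rule, where `value_to_gate` is a hypothetical helper name that does not appear in the original notebook:

```python
# Gate vocabulary copied from the cell above.
gating2id = {"none": 0, "dontcare": 1, "ptr": 2, "yes": 3, "no": 4}


def value_to_gate(value: str) -> int:
    """Return the gate id for a slot value (hypothetical helper)."""
    # Any value that is not itself a gate name must be generated, hence "ptr".
    return gating2id.get(value, gating2id["ptr"])


values = ["none", "dontcare", "yes", "no", "박물관", "서울 중앙"]
gates = [value_to_gate(v) for v in values]
print(gates)  # [0, 1, 3, 4, 2, 2]
```

Free-text values such as place names therefore all share gate id 2, and the decoder is responsible for producing their tokens.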


In [None]:
# Each example is stored per turn, keyed by dialogue id and turn id

# state = (dom, slot, value) or ([dom-slot : slot_meta or domain_slot], value)
# The task is to predict the value from the context and the current text
# The value is a sequence of words; it is turned into logits over the vocabulary using the word embedding matrix and predicted from those
raw_datasets['train']['dialogue'][:3] 

[[{'role': 'user',
   'text': '서울 중앙에 있는 박물관을 찾아주세요',
   'state': ['관광-종류-박물관', '관광-지역-서울 중앙']},
  {'role': 'sys',
   'text': '안녕하세요. 문화역서울 284은 어떠신가요? 평점도 4점으로 방문객들에게 좋은 평가를 받고 있습니다.',
   'state': []},
  {'role': 'user',
   'text': '좋네요 거기 평점은 말해주셨구 전화번호가 어떻게되나요?',
   'state': ['관광-종류-박물관', '관광-지역-서울 중앙', '관광-이름-문화역서울 284']},
  {'role': 'sys', 'text': '전화번호는 983880764입니다. 더 필요하신 게 있으실까요?', 'state': []},
  {'role': 'user',
   'text': '네 관광지와 같은 지역의 한식당을 가고싶은데요 야외석이 있어야되요',
   'state': ['관광-종류-박물관',
    '관광-지역-서울 중앙',
    '관광-이름-문화역서울 284',
    '식당-지역-서울 중앙',
    '식당-종류-한식당',
    '식당-야외석 유무-yes']},
  {'role': 'sys', 'text': '생각하고 계신 가격대가 있으신가요?', 'state': []},
  {'role': 'user',
   'text': '음.. 저렴한 가격대에 있나요?',
   'state': ['관광-종류-박물관',
    '관광-지역-서울 중앙',
    '관광-이름-문화역서울 284',
    '식당-가격대-저렴',
    '식당-지역-서울 중앙',
    '식당-종류-한식당',
    '식당-야외석 유무-yes']},
  {'role': 'sys', 'text': '죄송하지만 저렴한 가격대에는 없으시네요.', 'state': []},
  {'role': 'user',
   'text': '그럼 비싼 가격대로 다시 찾아주세요',
   'state': ['관광-종

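As shown above, each example is a dialogue: a list of turns, each carrying a `role`, a `text`, and a `state`. The `state` of a user turn lists the accumulated `domain-slot-value` strings, while system turns have an empty `state`.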
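
Each `state` entry in the output above is a single `domain-slot-value` string. A minimal sketch of turning one turn's state into a dictionary, assuming that only the first two hyphens act as separators so that a value containing hyphens stays intact; `parse_state` is a hypothetical name used only for illustration:

```python
from typing import Dict, List


def parse_state(state: List[str]) -> Dict[str, str]:
    """Map 'domain-slot' -> value for one turn (hypothetical helper)."""
    parsed = {}
    for item in state:
        # Split on the first two hyphens only, so values that
        # contain hyphens are not broken apart.
        domain, slot, value = item.split("-", 2)
        parsed[f"{domain}-{slot}"] = value
    return parsed


# A few entries taken from the sample dialogue shown above.
state = ["관광-종류-박물관", "관광-지역-서울 중앙", "식당-야외석 유무-yes"]
parsed = parse_state(state)
print(parsed["식당-야외석 유무"])  # yes
```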

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) 

In [None]:
import dataclasses
from dataclasses import dataclass


In [None]:
@dataclass
class WoSInputExample:
    guid: str
    context_turns: List[str]
    current_turn: List[str]
    label: Optional[List[str]] = None

    def to_dict(self) -> Dict[str, Any]:
        return dataclasses.asdict(self)

    def to_json_string(self) -> str:
        """Serializes this instance to a JSON string."""

        return json.dumps(self.to_dict(), indent=2) + "\n"

@dataclass
class WoSInputFeature:
    guid: str
    input_id: List[int]
    segment_id: List[int]
    gating_id: List[int]
    target_ids: Optional[Union[List[int], List[List[int]]]]
    label: Optional[List[str]] = None


In [None]:

class WoSDataset(Dataset):
    """Map-style dataset wrapping the list of converted features."""

    def __init__(self, features: List[WoSInputFeature]) -> None:
        self.features = features
        self.length = len(self.features)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> WoSInputFeature:
        return self.features[idx]

class WoSProcessor:

    origin_train_file_name = "wos-v1.1_train.json"
    origin_dev_file_name = "wos-v1.1_dev.json"
    origin_test_file_name = "wos-v1.1_test.json"


    def __init__(self, max_seq_len: int, ontology_path: str, tokenizer: PreTrainedTokenizer) -> None:

        self.tokenizer = tokenizer
        self.slot_meta: List[str] = []
        self.gating2id = {"none": 0, "dontcare": 1, "ptr": 2, "yes": 3, "no": 4}
        self.id2gating = {v: k for k, v in self.gating2id.items()}
        self.ontology_path = None  # the `ontology_path` argument is ignored for now; TODO: clean up the JSON file and use it
        self.max_seq_length = max_seq_len
        self.truncate = False 

    def _create_dataset(self, file_path: str, dataset_type: str) -> Dataset:

        # Read ontology file if exists and store the slots
        if self.ontology_path:
            _, self.slot_meta = self.build_slot_from_ontology(self.ontology_path)
            print(self.slot_meta)
        # Extract slots from a given dialogue and merge with ontology slots
        with open(file_path, "r", encoding="utf-8") as dial_file:
            dials = json.load(dial_file)
        slot_from_dials = self.build_slot_meta(dials)
        self.slot_meta = self.merge_slot_meta(slot_from_dials)

        examples = self._create_examples(file_path, dataset_type)
        features = self._convert_features(examples, dataset_type)

        """
        input_ids = torch.LongTensor(self.pad_ids([f.input_id for f in features], self.tokenizer.pad_token_id))
        segment_ids = torch.LongTensor(self.pad_ids([f.segment_id for f in features], self.tokenizer.pad_token_id))
        input_masks = input_ids.ne(self.tokenizer.pad_token_id)
        gating_ids = torch.LongTensor([f.gating_id for f in features])
        target_ids = self.pad_id_of_matrix(
            [torch.LongTensor(f.target_ids) for f in features], self.tokenizer.pad_token_id
        )
        return TensorDataset(input_ids, segment_ids, input_masks, gating_ids, target_ids, guids)
        """

        return WoSDataset(features)

    def _create_examples(self,  file_path: str, dataset_type: str) -> List[WoSInputExample]:
        examples = []
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)
            for dialogue in data:
                dialogue_examples = self.get_examples_from_dialogue(dialogue)
                examples.extend(dialogue_examples)
        return examples

    def _convert_features(self, examples: List[WoSInputExample], dataset_type: str) -> List[WoSInputFeature]:
        features = []
        for example in examples:
            feature = self._convert_example_to_feature(example, dataset_type)
            if feature:
                features.append(feature)

        for feature in features[:3]:
            logger.info("*** Example ***")
            logger.info("input_id: %s" % feature.input_id)
            logger.info("gating_id: %s" % feature.gating_id)
            logger.info("target_ids: %s" % feature.target_ids)
            logger.info("label: %s" % feature.label)
 

        return features


    def _convert_example_to_feature(self, example: WoSInputExample, dataset_type: str) -> Optional[WoSInputFeature]:
        dialogue_context = example.context_turns + [self.tokenizer.sep_token] + example.current_turn

        input_id = self.tokenizer.convert_tokens_to_ids(dialogue_context)
        len_input_id = len(input_id)
        if len_input_id > self.max_seq_length - 2:
            if dataset_type == "train" and not self.truncate:
                """Skip this training data which is longer than max_seq_length"""
                logger.info(
                    f"[{dataset_type}] Skip the context [{example.guid}] "
                    f"since the length of dialogue exceeds {self.max_seq_length - 2} < {len_input_id}"
                )
                return None  # type: ignore[return-value]
            else:
                input_id = input_id[len_input_id - (self.max_seq_length - 2) :]
                logger.info(
                    f"[{dataset_type}] Truncate the context [{example.guid}] "
                    f"since the length of dialogue exceeds {self.max_seq_length - 2} < {len_input_id}"
                )
        input_id = [self.tokenizer.cls_token_id] + input_id + [self.tokenizer.sep_token_id]
        segment_id = [0] * len(input_id)

        target_ids = []
        gating_id = []
        state = self.convert_state_dict(example.label)
        for slot in self.slot_meta:
            value = state.get(slot, "none")
            target_id = self.tokenizer.encode(value, add_special_tokens=False)
            len_target_id = len(target_id)
            if len_target_id > self.max_seq_length - 1:
                if dataset_type == "train" and not self.truncate:
                    """Skip this training data which is longer than max_seq_length"""
                    logger.info(
                        f"[{dataset_type}] Skip the slot [{value}] "
                        f"since the length of slot exceeds {self.max_seq_length - 1} < {len_target_id}"
                    )
                    return None  # type: ignore[return-value]
                else:
                    target_id = target_id[len_target_id - (self.max_seq_length - 1) :]
                    logger.info(
                        f"[{dataset_type}] Truncate the slot [{value}] "
                        f"since the length of slot exceeds {self.max_seq_length - 1} < {len_target_id}"
                    )
            target_id = target_id + [self.tokenizer.sep_token_id]
            target_ids.append(target_id)
            gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
        target_ids = self.pad_ids(target_ids, self.tokenizer.pad_token_id)

        return WoSInputFeature(example.guid, input_id, segment_id, gating_id, target_ids, example.label)

    @staticmethod
    def pad_ids(arrays: List[List[int]], pad_idx: int, max_length: int = -1) -> List[List[int]]:
        if max_length < 0:
            max_length = max(list(map(len, arrays)))

        arrays = [array + [pad_idx] * (max_length - len(array)) for array in arrays]
        return arrays

    @staticmethod
    def pad_id_of_matrix(arrays: List[torch.Tensor], pad_idx: int, max_length: int = -1, left: bool = False) -> torch.Tensor:
        # Right-pad every (n, length) matrix to a common width and stack them into (B, n, max_length).
        # NOTE: `left` is accepted but not used.
        if max_length < 0:
            max_length = max(array.size(-1) for array in arrays)

        new_arrays = []
        for array in arrays:
            n, length = array.size()
            pad = torch.full((n, max_length - length), pad_idx, dtype=torch.long)
            m = torch.cat([array, pad], -1)
            new_arrays.append(m.unsqueeze(0))

        return torch.cat(new_arrays, 0)

    def get_examples_from_dialogue(self, dialogue: Dict[str, List[Dict]]) -> List[WoSInputExample]:
        dialogue_id = dialogue["guid"]
        examples = []
        history: List[str] = []
        d_idx = 0
        for idx, turn in enumerate(dialogue["dialogue"]):
            if turn["role"] != "user":
                continue

            if idx:
                sys_utter = dialogue["dialogue"][idx - 1]["text"]
            else:
                sys_utter = ""

            sys_utter = self.tokenizer.tokenize(sys_utter)
            user_utter = self.tokenizer.tokenize(turn["text"])
            state = turn["state"]
            context = deepcopy(history)
            examples.append(
                WoSInputExample(
                    guid=f"{dialogue_id}-{d_idx}",
                    context_turns=context,
                    current_turn=sys_utter + [self.tokenizer.sep_token] + user_utter,
                    label=state,
                )
            )
            if history:
                history.append(self.tokenizer.sep_token)
            history.extend(sys_utter)
            history.append(self.tokenizer.sep_token)
            history.extend(user_utter)
            d_idx += 1
        return examples

    def merge_slot_meta(self, slot_from_dial: List[str]) -> List[str]:
        exist_slot_set = set(self.slot_meta)
        for slot in slot_from_dial:
            exist_slot_set.add(slot)
        return sorted(list(exist_slot_set))

    @staticmethod
    def build_slot_from_ontology(file_path: str) -> Tuple[List[str], List[str]]:
        """Read ontology file: expected format is `DOMAIN-SLOT`"""

        domains = []
        slots = []
        with open(file_path, "r", encoding="utf-8") as ontology_file:
            for line in ontology_file:
                line = line.strip()  # drop the trailing newline so slot names stay clean
                if not line:
                    continue
                domain_slot = line.split("-")
                assert len(domain_slot) == 2
                domains.append(domain_slot[0])
                slots.append(line)
        return domains, slots

    def build_slot_meta(self, data: List[Dict[str, List[dict]]]) -> List[str]:
        slot_meta = []
        for dialog in data:
            for turn in dialog["dialogue"]:
                if not turn.get("state"):
                    continue
                for dom_slot_value in turn["state"]:
                    domain_slot, _ = self.split_slot(dom_slot_value, get_domain_slot=True)
                    if domain_slot not in slot_meta:
                        slot_meta.append(domain_slot)
        return sorted(slot_meta)

    @staticmethod
    def split_slot(dom_slot_value: str, get_domain_slot: bool = False) -> Tuple[str, ...]:
        try:
            dom, slot, value = dom_slot_value.split("-")
        except ValueError:
            tempo = dom_slot_value.split("-")
            if len(tempo) < 2:
                return dom_slot_value, dom_slot_value, dom_slot_value
            dom, slot = tempo[0], tempo[1]
            value = dom_slot_value.replace("%s-%s-" % (dom, slot), "").strip()

        if get_domain_slot:
            return "%s-%s" % (dom, slot), value
        return dom, slot, value

    def recover_state(self, gate_list: List[int], gen_list: List[List[int]]) -> List[str]:
        assert len(gate_list) == len(self.slot_meta)
        assert len(gen_list) == len(self.slot_meta)

        recovered = []
        for slot, gate, value in zip(self.slot_meta, gate_list, gen_list):
            if self.id2gating[gate] == "none":
                continue
            elif self.id2gating[gate] == "dontcare":
                recovered.append("%s-%s" % (slot, "dontcare"))
                continue
            elif self.id2gating[gate] == "yes":
                recovered.append("%s-%s" % (slot, "yes"))
                continue
            elif self.id2gating[gate] == "no":
                recovered.append("%s-%s" % (slot, "no"))
                continue
            elif self.id2gating[gate] == "ptr":
                # Append a token until special tokens appear
                token_id_list = []
                for id_ in value:
                    if id_ in self.tokenizer.all_special_ids:
                        break
                    token_id_list.append(id_)
                value = self.tokenizer.decode(token_id_list, skip_special_tokens=True)
                # This is a basic post-processing for generative DST models based on wordpiece (using punctuation split)
                value = value.replace(" : ", ":").replace(" , ", ", ").replace("##", "")
            else:
                raise ValueError(f"{self.id2gating[gate]} is not supported. [none|dontcare|ptr|yes|no]")

            if value == "none":  # type: ignore[comparison-overlap]
                continue

            recovered.append("%s-%s" % (slot, value))
        return recovered

    def convert_state_dict(self, state: Sequence[str]) -> Dict[str, str]:
        dic = {}
        for slot in state:
            s, v = self.split_slot(slot, get_domain_slot=True)
            dic[s] = v
        return dic

    def collate_fn(
        self, batch: List[WoSInputFeature]
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[str], List[Optional[List[str]]]
    ]:
        input_ids = torch.LongTensor(self.pad_ids([b.input_id for b in batch], self.tokenizer.pad_token_id))
        segment_ids = torch.LongTensor(self.pad_ids([b.segment_id for b in batch], self.tokenizer.pad_token_type_id))
        input_masks = input_ids.ne(self.tokenizer.pad_token_id)

        gating_ids = torch.LongTensor([b.gating_id for b in batch])
        target_ids = self.pad_id_of_matrix([torch.LongTensor(b.target_ids) for b in batch], self.tokenizer.pad_token_id)
        guids = [b.guid for b in batch]
        labels = [b.label for b in batch]
        return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids, labels
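
The way `get_examples_from_dialogue` grows the dialogue history can be hard to follow inside the class. A minimal sketch of the same bookkeeping on a toy English dialogue, assuming whitespace tokenization and a literal `[SEP]` string in place of `tokenizer.tokenize` and `tokenizer.sep_token`; `build_turn_inputs` is a hypothetical name:

```python
from typing import Dict, List, Tuple

SEP = "[SEP]"  # assumed separator; the notebook uses tokenizer.sep_token


def build_turn_inputs(
    dialogue: List[Dict[str, str]],
) -> List[Tuple[List[str], List[str]]]:
    """Return (context_turns, current_turn) token lists, one pair per user turn."""
    pairs = []
    history: List[str] = []
    for idx, turn in enumerate(dialogue):
        if turn["role"] != "user":
            continue
        # System utterance right before this user turn (empty at the start).
        sys_utter = dialogue[idx - 1]["text"].split() if idx else []
        user_utter = turn["text"].split()
        pairs.append((list(history), sys_utter + [SEP] + user_utter))
        # Grow the history so that later turns see everything said so far.
        if history:
            history.append(SEP)
        history.extend(sys_utter + [SEP] + user_utter)
    return pairs


# Toy dialogue, not taken from WoS.
dialogue = [
    {"role": "user", "text": "find a museum"},
    {"role": "sys", "text": "how about A"},
    {"role": "user", "text": "phone number please"},
]
pairs = build_turn_inputs(dialogue)
for context, current in pairs:
    print(context, "|", current)
```

Only user turns produce examples; each one pairs the full history so far with the preceding system utterance and the current user utterance.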

In [None]:
dataprocessor = WoSProcessor(1024, '/content/ontology.json',tokenizer)


In [None]:
train_dataset = dataprocessor._create_dataset('/content/wos-v1.1_train.json','train')

Streaming output truncated to the last 5000 lines.
INFO:root:[train] Skip the context [wos-v1_train_01644-7] since the length of dialogue exceeds 126 < 238
INFO:root:[train] Skip the context [wos-v1_train_01644-8] since the length of dialogue exceeds 126 < 266
INFO:root:[train] Skip the context [wos-v1_train_01645-4] since the length of dialogue exceeds 126 < 146
INFO:root:[train] Skip the context [wos-v1_train_01645-5] since the length of dialogue exceeds 126 < 188
INFO:root:[train] Skip the context [wos-v1_train_01646-3] since the length of dialogue exceeds 126 < 138
INFO:root:[train] Skip the context [wos-v1_train_01646-4] since the length of dialogue exceeds 126 < 188
INFO:root:[train] Skip the context [wos-v1_train_01646-5] since the length of dialogue exceeds 126 < 205
INFO:root:[train] Skip the context [wos-v1_train_01646-6] since the length of dialogue exceeds 126 < 228
INFO:root:[train] Skip the context [wos-v1_train_01646-7] since the length of dialogue exceeds 126 < 256
INFO:

In [None]:
validation_dataset = dataprocessor._create_dataset('/content/wos-v1.1_dev.json','validation')

INFO:root:*** Example ***
INFO:root:input_id: [2, 3, 3, 11404, 4887, 4142, 2170, 3643, 22171, 2318, 1513, 13964, 18114, 2052, 2379, 16539, 16, 7862, 2138, 3922, 2223, 5971, 18, 3]
INFO:root:gating_id: [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
INFO:root:target_ids: [[19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [11404, 4887, 4142, 3], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], [19665, 10933, 3, 0], 
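
In the logged feature, `gating_id` has one entry per slot and `target_ids` has one row per slot, each row ending in the separator id and right-padded to the longest row. A minimal sketch of that padding step, reusing two rows from the log; `pad_rows` is a hypothetical stand-in for `WoSProcessor.pad_ids`, and the separator and padding ids (3 and 0) are read off the log rather than from the tokenizer:

```python
from typing import List

SEP_ID, PAD_ID = 3, 0  # ids as they appear in the log above


def pad_rows(rows: List[List[int]], pad_id: int) -> List[List[int]]:
    """Right-pad every row to the length of the longest row."""
    width = max(len(row) for row in rows)
    return [row + [pad_id] * (width - len(row)) for row in rows]


# One row per slot: value token ids followed by the separator id.
rows = [[19665, 10933, SEP_ID], [11404, 4887, 4142, SEP_ID]]
padded = pad_rows(rows, PAD_ID)
print(padded)  # [[19665, 10933, 3, 0], [11404, 4887, 4142, 3]]
```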

In [None]:
class SlotGenerator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        dropout: float,
        slot_meta: List[str],
        gating2id: Dict[str, int],
        pad_idx: int = 0,
        parallel_decoding: bool = True,
    ) -> None:
        super(SlotGenerator, self).__init__()
        self.hidden_size = hidden_size
        self.pad_idx = pad_idx
        self.slot_meta = slot_meta
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(vocab_size, self.hidden_size, padding_idx=pad_idx)  # shared with encoder

        self.gru = nn.GRU(self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True)

        # receive gate info from processor
        self.gating2id = gating2id  # {"none": 0, "dontcare": 1, "ptr": 2, "yes":3, "no": 4}
        self.num_gates = len(self.gating2id.keys())

        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, self.num_gates)

        self.slot_embed_idx: List[List[int]] = []
        self.parallel_decoding = parallel_decoding

    def set_slot_idx(self, slot_vocab_idx: List[List[int]]) -> None:
        whole = []
        max_length = max(map(len, slot_vocab_idx))
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
        self.slot_embed_idx = whole

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_output: torch.Tensor,
        hidden: torch.Tensor,
        input_masks: torch.Tensor,
        max_len: int,
        teacher: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)
        # slot_embedding
        slot_e = torch.sum(self.embed(slot), 1)  # J, d
        J = slot_e.size(0)

        if self.parallel_decoding:
            all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(input_ids.device)
            all_gate_outputs = torch.zeros(batch_size, J, self.num_gates).to(input_ids.device)

            w = slot_e.repeat(batch_size, 1).unsqueeze(1)
            hidden = hidden.repeat_interleave(J, dim=1)
            encoder_output = encoder_output.repeat_interleave(J, dim=0)
            input_ids = input_ids.repeat_interleave(J, dim=0)
            input_masks = input_masks.repeat_interleave(J, dim=0)
            num_decoding = 1

        else:
            # Separate Decoding
            all_point_outputs = torch.zeros(J, batch_size, max_len, self.vocab_size).to(input_ids.device)
            all_gate_outputs = torch.zeros(J, batch_size, self.num_gates).to(input_ids.device)
            num_decoding = J

        for j in range(num_decoding):

            if not self.parallel_decoding:
                w = slot_e[j].expand(batch_size, 1, self.hidden_size)

            for k in range(max_len):
                w = self.dropout(w)
                _, hidden = self.gru(w, hidden)  # 1,B,D

                # B,T,D * B,D,1 => B,T
                attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
                attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
                attn_history = F.softmax(attn_e, -1)  # B,T

                # B,D * D,V => B,V
                attn_v = torch.matmul(hidden.squeeze(0), self.embed.weight.transpose(0, 1))  # B,V
                attn_vocab = F.softmax(attn_v, -1)

                # B,1,T * B,T,D => B,1,D
                context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
                p_gen = self.sigmoid(self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1)))  # B,1
                p_gen = p_gen.squeeze(-1)

                p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
                p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
                # attn_vocab: linear mapping from the hidden state to the vocabulary
                # p_context_ptr: hidden --> T_enc; the attention weight at each encoder timestep is added to the
                #                vocabulary entry of the input_id at that position, which emphasizes input tokens
                #                related to the value so that values tied directly to the utterance text get predicted
                # p_gen controls the ratio between attn_vocab and p_context_ptr
                p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V 
                _, w_idx = p_final.max(-1)

                if teacher is not None:
                    if self.parallel_decoding:
                        w = self.embed(teacher[:, :, k]).reshape(batch_size * J, 1, -1)
                    else:
                        w = self.embed(teacher[:, j, k]).unsqueeze(1)
                else:
                    w = self.embed(w_idx).unsqueeze(1)  # B,1,D

                if k == 0:
                    gated_logit = self.w_gate(context.squeeze(1))  # B,5
                    if self.parallel_decoding:
                        all_gate_outputs = gated_logit.view(batch_size, J, self.num_gates)
                    else:
                        _, gated = gated_logit.max(1)  # maybe `-1` would be more clear
                        all_gate_outputs[j] = gated_logit

                if self.parallel_decoding:
                    all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)
                else:
                    all_point_outputs[j, :, k, :] = p_final

        if not self.parallel_decoding:
            all_point_outputs = all_point_outputs.transpose(0, 1)
            all_gate_outputs = all_gate_outputs.transpose(0, 1)

        return all_point_outputs, all_gate_outputs



NameError: ignored
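
The core of `SlotGenerator.forward` is the mixture `p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr`, where `p_context_ptr` is built with `scatter_add_`. A minimal sketch of that step in plain Python for a single decoding step and a single example, with toy numbers that are not taken from the model; `mix_distributions` is a hypothetical name:

```python
from typing import List


def mix_distributions(
    p_gen: float,
    attn_vocab: List[float],
    attn_history: List[float],
    input_ids: List[int],
) -> List[float]:
    """Blend the vocabulary distribution with copy probabilities from the input."""
    # Scatter-add: the attention mass on each input position is credited
    # to the vocabulary entry of the token found at that position.
    p_context_ptr = [0.0] * len(attn_vocab)
    for token_id, weight in zip(input_ids, attn_history):
        p_context_ptr[token_id] += weight
    return [
        p_gen * gen + (1.0 - p_gen) * copy
        for gen, copy in zip(attn_vocab, p_context_ptr)
    ]


# Toy setup: vocabulary of 4 tokens, input sequence of 3 tokens.
attn_vocab = [0.25, 0.25, 0.25, 0.25]  # softmax over the vocabulary
attn_history = [0.5, 0.25, 0.25]       # softmax over input positions
input_ids = [2, 0, 2]                  # token 2 occurs twice in the input
p_final = mix_distributions(0.5, attn_vocab, attn_history, input_ids)
print(p_final)  # [0.25, 0.125, 0.5, 0.125]
```

Token 2 ends up most likely because it collects attention from both of its occurrences in the input, which is how the decoder is nudged toward values that literally appear in the utterance.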

In [None]:
class DSTTransformer(BaseTransformer):

    mode = Mode.DialogueStateTracking
    REQUIRE_ADDITIONAL_POOLER_LAYER_BY_TYPE = ["electra"]

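    def __init__(self, hparams: Union[Dict[str, Any], argparse.Namespace], metrics: Optional[dict] = None) -> None:
        # Avoid a shared mutable default argument
        metrics = metrics if metrics is not None else {}
        if isinstance(hparams, dict):
            hparams = argparse.Namespace(**hparams)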

        super().__init__(
            hparams,
            num_labels=None,
            mode=self.mode,
            model_type=AutoModel,
            metrics=metrics,
        )
        self.processor = hparams.processor
        hparams.processor = None

        self.teacher_forcing = self.hparams.teacher_forcing
        self.parallel_decoding = self.hparams.parallel_decoding

        self.slot_meta = self.processor.slot_meta
        self.slot_vocab = [
            self.processor.tokenizer.encode(slot.replace("-", " "), add_special_tokens=False) for slot in self.slot_meta
        ]
        # refer the vars (encoder_config, encoder) in super class (BaseTransformer)
        self.encoder_config = self.config
        self.encoder = self.model

        if self._is_require_pooler_layer():
            from transformers.models.bert.modeling_bert import BertPooler

            self.encoder_pooler_layer = BertPooler(self.encoder_config)

        self.decoder = SlotGenerator(
            self.encoder_config.vocab_size,
            self.encoder_config.hidden_size,
            self.encoder_config.hidden_dropout_prob,
            self.slot_meta,
            self.processor.gating2id,
            self.processor.tokenizer.pad_token_id,
            self.parallel_decoding,
        )

        self.decoder.set_slot_idx(self.slot_vocab)
        self.tie_weight()

        self.loss_gen = masked_cross_entropy_for_value
        self.loss_gate = nn.CrossEntropyLoss()

        self.metrics = nn.ModuleDict(metrics)

    def tie_weight(self) -> None:
        """Share the embedding layer for both encoder and decoder"""
        self.decoder.embed.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids: torch.Tensor,
        token_type_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        max_len: int = 10,
        teacher: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        # TODO: Need to be refactored before code release
        outputs_dict = self.encoder(
            input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, return_dict=True
        )
        encoder_outputs = outputs_dict["last_hidden_state"]
        if "pooler_output" in outputs_dict.keys():
            pooler_output = outputs_dict["pooler_output"]
        else:
            pooler_output = self.encoder_pooler_layer(encoder_outputs)

        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids, encoder_outputs, pooler_output.unsqueeze(0), attention_mask, max_len, teacher
        )

        return all_point_outputs, all_gate_outputs

    def training_step(self, batch: Sequence[torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]:
        input_ids, segment_ids, input_masks, gating_ids, target_ids, _, _ = batch

        if self.teacher_forcing > 0.0 and random.random() < self.teacher_forcing:
            tf = target_ids
        else:
            tf = None

        all_point_outputs, all_gate_outputs = self(input_ids, segment_ids, input_masks, target_ids.size(-1), tf)
        loss_gen = self.loss_gen(
            all_point_outputs.contiguous(), target_ids.contiguous().view(-1), self.tokenizer.pad_token_id
        )
        loss_gate = self.loss_gate(
            all_gate_outputs.contiguous().view(-1, len(self.processor.gating2id.keys())),
            gating_ids.contiguous().view(-1),
        )
        loss = loss_gen + loss_gate

        self.log("train/loss", loss)
        self.log("train/loss_gen", loss_gen)
        self.log("train/loss_gate", loss_gate)

        return {"loss": loss}


In [None]:
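def masked_cross_entropy_for_value(logits: torch.Tensor, target: torch.Tensor, pad_idx: int = 0) -> torch.Tensor:
    # NOTE: despite the name, `logits` holds probabilities (the decoder's `p_final`), so the log is taken directly.
    mask = target.ne(pad_idx)  # True where the target is a real token
    logits_flat = logits.view(-1, logits.size(-1))  # (N, V)
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)  # (N, 1)
    # Negative log-likelihood of the gold token at every position
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()  # zero out padded positions
    # Average over non-padding positions only
    loss = losses.sum() / (mask.sum().float())
    return loss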
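
To make the masking concrete, here is a minimal sketch of the same computation in plain Python on toy numbers. It assumes, as the function above does, that the inputs are already probabilities rather than raw logits; `masked_nll` is a hypothetical name:

```python
import math
from typing import List


def masked_nll(probs: List[List[float]], targets: List[int], pad_idx: int = 0) -> float:
    """Average -log p(target) over positions whose target is not padding."""
    total, count = 0.0, 0
    for dist, target in zip(probs, targets):
        if target == pad_idx:
            continue  # padded positions contribute nothing to the loss
        total += -math.log(dist[target])
        count += 1
    return total / count


# Three decoding steps over a 3-token vocabulary; the last target is padding.
probs = [[0.1, 0.5, 0.4], [0.2, 0.3, 0.5], [0.9, 0.05, 0.05]]
targets = [1, 2, 0]
loss = masked_nll(probs, targets)
print(round(loss, 4))  # 0.6931
```

Both real targets have probability 0.5, so the loss equals `log(2)`; the confident prediction at the padded position has no effect on it.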