
Training a Token Classification Model with HuggingFace Transformers

In this notebook, we train the `klue/bert-base` model on the **DP** (dependency parsing) dataset from **KLUE**.

After training, we will also look at how the model is used through a short example.

All source code is based on the [`huggingface-tutorial`](https://huggingface.co/course/chapter7/2).

First, install the libraries needed to run the notebook. `transformers` is required for training the model, and `datasets` for loading the training data. `scipy` and `scikit-learn` are also installed for evaluating model performance.

In [None]:
#https://towardsdatascience.com/how-to-create-and-train-a-multi-task-transformer-model-18c54a14624
#https://medium.com/@shahrukhx01/multi-task-learning-with-transformers-part-1-multi-prediction-heads-b7001cf014bf

In [1]:
!pip install  evaluate 
!pip install accelerate
# To run the training on TPU, you will need to uncomment the following line:
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
!apt install git-lfs

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting evaluate
  Downloading evaluate-0.4.0-py3-none-any.whl (81 kB)
Collecting multiprocess
  Downloading multiprocess-0.70.14-py39-none-any.whl (132 kB)
Collecting datasets>=2.0.0
  Downloading datasets-2.10.1-py3-none-any.whl (469 kB)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting dill
  Downloading dill-0.3.6-py3-none-any.whl (110 kB)
Collecting hug

In [2]:
!pip install -U transformers datasets scipy scikit-learn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.3-py3-none-any.whl (6.8 MB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Installing collected packages: tokenizers, transformers
Successfully installed tokenizers-0.13.2 transformers-4.27.3


In [None]:
#from huggingface_hub import notebook_login

#notebook_login()

## Training the dependency parsing model

Import all the libraries needed to run the notebook.

In [11]:
import random
import logging
from typing import Any, List, Optional, Tuple
from IPython.display import display, HTML

import numpy as np
import pandas as pd
import datasets
from datasets import load_dataset, load_metric, ClassLabel, Sequence
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, AutoModel, PreTrainedTokenizer


Record the information needed for training in variables.

This notebook uses the `klue/bert-base` model, but many other pretrained language models are available at https://huggingface.co/klue.

We set the task to `dp` and the batch size to 64.

In [4]:
model_checkpoint = "klue/bert-base"
batch_size = 64
task = "dp"

Now download the DP data from the KLUE dataset registered in the HuggingFace `datasets` library.

In [None]:
#['ynat', 'sts', 'nli', 'ner', 're', 'dp', 'mrc', 'wos']
raw_datasets = load_dataset("klue", task)




Inspecting the `datasets` object obtained after downloading or loading shows that it contains both training and validation data.

In [5]:
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logging.basicConfig()


In [None]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'index', 'word_form', 'lemma', 'pos', 'head', 'deprel'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['sentence', 'index', 'word_form', 'lemma', 'pos', 'head', 'deprel'],
        num_rows: 2000
    })
})

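Each example consists of a sentence together with word-level annotations: the index, word form, lemma, POS tag, head index, and dependency relation (`deprel`) of every word. The functions below list the dependency relation labels and POS tags used as classes.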

In [15]:

def get_dep_labels() -> List[str]:
    """
    label for dependency relations format:
    {structure}_(optional){function}
    """
    dep_labels = [
        "NP",
        "NP_AJT",
        "VP",
        "NP_SBJ",
        "VP_MOD",
        "NP_OBJ",
        "AP",
        "NP_CNJ",
        "NP_MOD",
        "VNP",
        "DP",
        "VP_AJT",
        "VNP_MOD",
        "NP_CMP",
        "VP_SBJ",
        "VP_CMP",
        "VP_OBJ",
        "VNP_CMP",
        "AP_MOD",
        "X_AJT",
        "VP_CNJ",
        "VNP_AJT",
        "IP",
        "X",
        "X_SBJ",
        "VNP_OBJ",
        "VNP_SBJ",
        "X_OBJ",
        "AP_AJT",
        "L",
        "X_MOD",
        "X_CNJ",
        "VNP_CNJ",
        "X_CMP",
        "AP_CMP",
        "AP_SBJ",
        "R",
        "NP_SVJ",
    ]
    return dep_labels


def get_pos_labels() -> List[str]:
    """label for part-of-speech tags"""

    return [
        "NNG",
        "NNP",
        "NNB",
        "NP",
        "NR",
        "VV",
        "VA",
        "VX",
        "VCP",
        "VCN",
        "MMA",
        "MMD",
        "MMN",
        "MAG",
        "MAJ",
        "JC",
        "IC",
        "JKS",
        "JKC",
        "JKG",
        "JKO",
        "JKB",
        "JKV",
        "JKQ",
        "JX",
        "EP",
        "EF",
        "EC",
        "ETN",
        "ETM",
        "XPN",
        "XSN",
        "XSV",
        "XSA",
        "XR",
        "SF",
        "SP",
        "SS",
        "SE",
        "SO",
        "SL",
        "SH",
        "SW",
        "SN",
        "NA",
    ]

Define the following visualization function to get an overall look at the dataset.

In [None]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    picks = []
    
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)

        # If an already-selected index is drawn, draw again
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)

        picks.append(pick)

    # Build a data frame from the randomly selected indices
    df = pd.DataFrame(dataset[picks])

    for column, typ in dataset.features.items():
        # Convert label class ids to strings
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])

    display(HTML(df.to_html()))
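
The retry loop in `show_random_elements` draws distinct indices one at a time. As a side note, the standard library's `random.sample` gives the same "no duplicates" guarantee in a single call. The sketch below is an illustration only: `pick_distinct_indices` and its `seed` parameter are hypothetical names that do not appear in the notebook, and a plain integer count stands in for the dataset length.

```python
import random


def pick_distinct_indices(n_items: int, num_examples: int, seed=None) -> list:
    """Return `num_examples` distinct indices in the range [0, n_items)."""
    if num_examples > n_items:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)
    # random.sample draws without replacement, so no retry loop is needed
    return rng.sample(range(n_items), num_examples)


picks = pick_distinct_indices(n_items=100, num_examples=10, seed=0)
print(picks)
```

The returned list could be passed to `dataset[picks]` in the same way as the `picks` list built by the loop.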

Let's use the function defined above to look at the training data.

Inspecting the data this way helps build intuition for what kinds of sentences and annotations fall under each label.


In [None]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,sentence,index,word_form,lemma,pos,head,deprel
0,현재 노조는 외환위기 극복을 위해 61세에서 58세로 단축된 정년을 공무원의 정년과 연동해 다시 연장하기로 단체협약을 4차례 맺었지만 이행되지 않고 있다고 주장하고 있다.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]","[현재, 노조는, 외환위기, 극복을, 위해, 61세에서, 58세로, 단축된, 정년을, 공무원의, 정년과, 연동해, 다시, 연장하기로, 단체협약을, 4차례, 맺었지만, 이행되지, 않고, 있다고, 주장하고, 있다.]","[현재, 노조 는, 외환 위기, 극복 을, 위하 여, 61 세 에서, 58 세 로, 단축 되 ㄴ, 정년 을, 공무원 의, 정년 과, 연동 하 여, 다시, 연장 하 기 로, 단체 협약 을, 4 차례, 맺 었 지만, 이행 되 지, 않 고, 있 다고, 주장 하 고, 있 다 .]","[MAG, NNG+JX, NNG+NNG, NNG+JKO, VV+EC, SN+NNB+JKB, SN+NNB+JKB, NNG+XSV+ETM, NNG+JKO, NNG+JKG, NNG+JKB, NNG+XSV+EC, MAG, NNG+XSV+ETN+JKB, NNG+NNG+JKO, SN+NNG, VV+EP+EC, NNG+XSV+EC, VX+EC, VX+EC, NNG+XSV+EC, VX+EF+SF]","[21, 21, 4, 5, 12, 7, 8, 9, 12, 11, 12, 14, 14, 17, 17, 17, 18, 19, 20, 21, 22, 0]","[AP, NP_SBJ, NP, NP_OBJ, VP, NP_AJT, NP_AJT, VP_MOD, NP_OBJ, NP_MOD, NP_AJT, VP, AP, VP_AJT, NP_OBJ, NP_AJT, VP, VP, VP, VP_CMP, VP, VP]"
1,영어가 익숙하지 않은 분들은 아예 이용이 어려울 수 있는 정도예요.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]","[영어가, 익숙하지, 않은, 분들은, 아예, 이용이, 어려울, 수, 있는, 정도예요.]","[영어 가, 익숙하 지, 않 은, 분 들 은, 아예, 이용 이, 어렵 ㄹ, 수, 있 는, 정도 이 에요 .]","[NNG+JKS, VA+EC, VX+ETM, NNB+XSN+JX, MAG, NNG+JKS, VA+ETM, NNB, VV+ETM, NNG+VCP+EF+SF]","[2, 3, 4, 7, 7, 7, 8, 9, 10, 0]","[NP_SBJ, VP, VP_MOD, NP_SBJ, AP, NP_SBJ, VP_MOD, NP_SBJ, VP_MOD, VNP]"
2,야외활동이 적은 1∼4월 월평균 20마리 정도의 동물이 버려진 것과 비교하면 2∼3배 늘어난 수치다.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]","[야외활동이, 적은, 1∼4월, 월평균, 20마리, 정도의, 동물이, 버려진, 것과, 비교하면, 2∼3배, 늘어난, 수치다.]","[야외 활동 이, 적 은, 1 ∼ 4 월, 월 평균, 20 마리, 정도 의, 동물 이, 버리 어 지 ㄴ, 것 과, 비교 하 면, 2 ∼ 3 배, 늘 어 나 ㄴ, 수치 이 다 .]","[NNG+NNG+JKS, VA+ETM, SN+SO+SN+NNB, NNG+NNG, SN+NNB, NNG+JKG, NNG+JKS, VV+EC+VX+ETM, NNB+JKB, NNG+XSV+EC, SN+SO+SN+NNG, VV+EC+VX+ETM, NNG+VCP+EF+SF]","[2, 3, 8, 5, 6, 7, 8, 9, 10, 13, 12, 13, 0]","[NP_SBJ, VP_MOD, NP_AJT, NP, NP, NP_MOD, NP_SBJ, VP_MOD, NP_AJT, VP, NP_SBJ, VP_MOD, VNP]"
3,아직 지진으로 인한 피해상황은 전해지지 않고 있다.,"[1, 2, 3, 4, 5, 6, 7]","[아직, 지진으로, 인한, 피해상황은, 전해지지, 않고, 있다.]","[아직, 지진 으로, 인하 ㄴ, 피해 상황 은, 전하 여 지 지, 않 고, 있 다 .]","[MAG, NNG+JKB, VV+ETM, NNG+NNG+JX, VV+EC+VX+EC, VX+EC, VX+EF+SF]","[5, 3, 4, 5, 6, 7, 0]","[AP, NP_AJT, VP_MOD, NP_SBJ, VP, VP, VP]"
4,공항공사 노동조합과 용산참사 유족 등이 한국공항공사 신임 사장으로 임명된 김석기 전 서울지방경찰청장의 출근을 저지하기 위해 사옥 앞에 서 있습니다.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]","[공항공사, 노동조합과, 용산참사, 유족, 등이, 한국공항공사, 신임, 사장으로, 임명된, 김석기, 전, 서울지방경찰청장의, 출근을, 저지하기, 위해, 사옥, 앞에, 서, 있습니다.]","[공항공사, 노동 조합 과, 용산 참사, 유족, 등 이, 한국공항공사, 신임, 사장 으로, 임명 되 ㄴ, 김석기, 전, 서울지방경찰청장 의, 출근 을, 저지 하 기, 위하 여, 사옥, 앞 에, 서 어, 있 습니다 .]","[NNP, NNG+NNG+JC, NNP+NNG, NNG, NNB+JKS, NNP, NNG, NNG+JKB, NNG+XSV+ETM, NNP, MMD, NNP+JKG, NNG+JKO, NNG+XSV+ETN, VV+EC, NNG, NNG+JKB, VV+EC, VX+EF+SF]","[2, 4, 4, 5, 15, 8, 8, 9, 12, 12, 12, 13, 14, 15, 18, 17, 18, 19, 0]","[NP, NP_CNJ, NP, NP, NP_SBJ, NP, NP, NP_AJT, VP_MOD, NP, DP, NP_MOD, NP_OBJ, VP_OBJ, VP, NP, NP_AJT, VP, VP]"
5,LG전자는 스마트폰 광고 모델로 '체조요정' 손연재 선수를 선정했다고 16일 밝혔다.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]","[LG전자는, 스마트폰, 광고, 모델로, '체조요정', 손연재, 선수를, 선정했다고, 16일, 밝혔다.]","[LG 전자 는, 스마트 폰, 광고, 모델 로, ' 체조 요정 ', 손연재, 선수 를, 선정 하 였 다고, 16 일, 밝히 었 다 .]","[SL+NNG+JX, NNG+NNG, NNG, NNG+JKB, SS+NNG+NNG+SS, NNP, NNG+JKO, NNG+XSV+EP+EC, SN+NNB, VV+EP+EF+SF]","[10, 4, 4, 8, 7, 7, 8, 10, 10, 0]","[NP_SBJ, NP, NP, NP_AJT, NP, NP, NP_OBJ, VP_AJT, NP_AJT, VP]"
6,일단 런던에서 이 정도 가격에 이런 숙소를 쓸 수 있다는게 너무 좋았어요.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]","[일단, 런던에서, 이, 정도, 가격에, 이런, 숙소를, 쓸, 수, 있다는게, 너무, 좋았어요.]","[일단, 런던 에서, 이, 정도, 가격 에, 이런, 숙소 를, 쓰 ㄹ, 수, 있 다는 것 이, 너무, 좋 았 어요 .]","[MAG, NNP+JKB, MMD, NNG, NNG+JKB, MMD, NNG+JKO, VV+ETM, NNB, VA+ETM+NNB+JKS, MAG, VA+EP+EF+SF]","[12, 8, 4, 5, 8, 7, 8, 9, 10, 12, 12, 0]","[AP, NP_AJT, DP, NP, NP_AJT, DP, NP_OBJ, VP_MOD, NP_SBJ, NP_SBJ, AP, VP]"
7,양측은 판문점 남측 지역 평화의 집에서 추석 계기 이산가족 상봉 행사 등을 논의하는 무박 2일의 적십자 실무접촉을 갖고 이런 내용이 포함된 2개항의 합의서를 채택했다.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]","[양측은, 판문점, 남측, 지역, 평화의, 집에서, 추석, 계기, 이산가족, 상봉, 행사, 등을, 논의하는, 무박, 2일의, 적십자, 실무접촉을, 갖고, 이런, 내용이, 포함된, 2개항의, 합의서를, 채택했다.]","[양측 은, 판문점, 남측, 지역, 평화 의, 집 에서, 추석, 계기, 이산가족, 상봉, 행사, 등 을, 논의 하 는, 무 박, 2 일 의, 적십자, 실무 접촉 을, 갖 고, 이런, 내용 이, 포함 되 ㄴ, 2 개 항 의, 합의서 를, 채택 하 였 다 .]","[NNG+JX, NNP, NNG, NNG, NNG+JKG, NNG+JKB, NNG, NNG, NNG, NNG, NNG, NNB+JKO, NNG+XSV+ETM, XPN+NNG, SN+NNB+JKG, NNG, NNG+NNG+JKO, VV+EC, MMD, NNG+JKS, NNG+XSV+ETM, SN+NNB+NNG+JKG, NNG+JKO, NNG+XSV+EP+EF+SF]","[18, 3, 4, 6, 6, 18, 8, 11, 10, 11, 12, 13, 17, 15, 17, 17, 18, 24, 20, 21, 23, 23, 24, 0]","[NP_SBJ, NP, NP, NP, NP_MOD, NP_AJT, NP, NP, NP, NP, NP, NP_OBJ, VP_MOD, NP, NP_MOD, NP, NP_OBJ, VP, DP, NP_SBJ, VP_MOD, NP_MOD, NP_OBJ, VP]"
8,"위치, 물 수압, 호스트 모두모두 아주 좋습니다.","[1, 2, 3, 4, 5, 6, 7]","[위치,, 물, 수압,, 호스트, 모두모두, 아주, 좋습니다.]","[위치 ,, 물, 수압 ,, 호스트, 모두 모두, 아주, 좋 습니다 .]","[NNG+SP, NNG, NNG+SP, NNG, MAG+MAG, MAG, VA+EF+SF]","[4, 3, 4, 7, 7, 7, 0]","[NP_CNJ, NP, NP_CNJ, NP_SBJ, AP, AP, VP]"
9,가까운 예로 지난달 현 정권 실세비리를 폭로한 이국철 SLS그룹 회장의 항소심을 심리하던 서울고법은 구속만기일에 임박하자 보석 심문을 통해 직권으로 풀어준 적이 있다.,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22]","[가까운, 예로, 지난달, 현, 정권, 실세비리를, 폭로한, 이국철, SLS그룹, 회장의, 항소심을, 심리하던, 서울고법은, 구속만기일에, 임박하자, 보석, 심문을, 통해, 직권으로, 풀어준, 적이, 있다.]","[가깝 ㄴ, 예 로, 지나 ㄴ 달, 현, 정권, 실세 비리 를, 폭로 하 ㄴ, 이국철, SLS 그룹, 회장 의, 항소심 을, 심리 하 던, 서울고법 은, 구속 만기일 에, 임박 하 자, 보석, 심문 을, 통하 여, 직권 으로, 풀 어 주 ㄴ, 적 이, 있 다 .]","[VA+ETM, NNG+JKB, VV+ETM+NNG, MMD, NNG, NNG+NNG+JKO, NNG+XSV+ETM, NNP, SL+NNG, NNG+JKG, NNG+JKO, NNG+XSV+ETM, NNP+JX, NNG+NNG+JKB, NNG+XSV+EC, NNG, NNG+JKO, VV+EC, NNG+JKB, VV+EC+VX+ETM, NNB+JKS, VV+EF+SF]","[2, 15, 7, 5, 6, 7, 10, 9, 10, 11, 12, 13, 15, 15, 18, 17, 18, 20, 20, 21, 22, 0]","[VP_MOD, NP_AJT, NP_AJT, DP, NP, NP_OBJ, VP_MOD, NP, NP, NP_MOD, NP_OBJ, VP_MOD, NP_SBJ, NP_AJT, VP, NP, NP_OBJ, VP, NP_AJT, VP_MOD, NP_SBJ, VP]"
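
To make the `head` and `deprel` columns easier to read, the sketch below turns one row of the sample output into (dependent, head word, relation) triples. It assumes, based on the rows shown above, that `index` is 1-based and that a `head` value of 0 marks the root of the sentence. `to_edges` is a hypothetical helper, not part of the notebook.

```python
# Values copied from row 3 of the sample output above.
word_form = ["아직", "지진으로", "인한", "피해상황은", "전해지지", "않고", "있다."]
head = [5, 3, 4, 5, 6, 7, 0]
deprel = ["AP", "NP_AJT", "VP_MOD", "NP_SBJ", "VP", "VP", "VP"]


def to_edges(words, heads, labels):
    """Return (dependent, head_word, relation) triples; head index 0 is treated as the root."""
    edges = []
    for word, head_idx, label in zip(words, heads, labels):
        # indices are 1-based in the sample, so head_idx - 1 addresses the head word
        head_word = "ROOT" if head_idx == 0 else words[head_idx - 1]
        edges.append((word, head_word, label))
    return edges


edges = to_edges(word_form, head, deprel)
for dependent, head_word, label in edges:
    print(f"{dependent} --{label}--> {head_word}")
```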


Next, set up the metric used to monitor model performance during training.

The `datasets` library provides a `load_metric` function for loading metrics that are already implemented.
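
The notebook does not show which metric is loaded for this task. For dependency parsing, two commonly reported numbers are the unlabeled attachment score (UAS, the share of words whose head is predicted correctly) and the labeled attachment score (LAS, which additionally requires the correct relation label). The sketch below computes both as plain accuracies in pure Python. It is an assumption about what one might monitor, not necessarily the scoring used by the KLUE benchmark, and `attachment_scores` is a hypothetical helper.

```python
def attachment_scores(gold_heads, gold_labels, pred_heads, pred_labels):
    """Return (UAS, LAS) as fractions of words, computed as simple accuracies."""
    if not (len(gold_heads) == len(gold_labels) == len(pred_heads) == len(pred_labels)):
        raise ValueError("all four sequences must have the same length")
    total = len(gold_heads)
    if total == 0:
        return 0.0, 0.0
    head_hits = 0
    labeled_hits = 0
    for gh, gl, ph, pl in zip(gold_heads, gold_labels, pred_heads, pred_labels):
        if gh == ph:
            head_hits += 1  # correct head counts toward UAS
            if gl == pl:
                labeled_hits += 1  # correct head and label counts toward LAS
    return head_hits / total, labeled_hits / total


gold_heads = [5, 3, 4, 5, 6, 7, 0]
gold_labels = ["AP", "NP_AJT", "VP_MOD", "NP_SBJ", "VP", "VP", "VP"]
# Made-up predictions: one wrong head (third word) and one wrong label (first word).
pred_heads = [5, 3, 5, 5, 6, 7, 0]
pred_labels = ["NP_AJT", "NP_AJT", "VP_MOD", "NP_SBJ", "VP", "VP", "VP"]

uas, las = attachment_scores(gold_heads, gold_labels, pred_heads, pred_labels)
print(f"UAS={uas:.3f} LAS={las:.3f}")
```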


In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True) 


In [7]:
import argparse
import logging
import os
from typing import Any, List, Optional, Tuple


In [16]:
pos_labels = get_pos_labels()
dep_labels = get_dep_labels()
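
Later in the notebook, `convert_examples_to_features` turns these label lists into label-to-index dictionaries and keeps only the last tag of a compound POS string such as `NNG+JX`. The sketch below illustrates both steps. To stay self-contained it uses small, hypothetical subsets of the labels rather than the full lists, so the indices differ from the ones used in training; `last_pos_tag` is also a hypothetical helper name.

```python
# Hypothetical small subsets; the notebook uses the full lists from
# get_pos_labels() and get_dep_labels().
pos_subset = ["NNG", "JX", "VV", "EC", "EF", "SF"]
dep_subset = ["NP", "NP_SBJ", "VP", "AP"]

# Same dictionary construction as in convert_examples_to_features
pos_label_map = {label: i for i, label in enumerate(pos_subset)}
dep_label_map = {label: i for i, label in enumerate(dep_subset)}


def last_pos_tag(pos: str) -> str:
    """Keep only the final tag of a '+'-joined POS string."""
    return pos.split("+")[-1]


tag = last_pos_tag("NNG+JX")
print(tag, pos_label_map[tag])
print("NP_SBJ", dep_label_map["NP_SBJ"])
```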

In [8]:
import torch

In [9]:
from torch.utils.data import DataLoader, TensorDataset


In [12]:

class KlueDPInputExample:
    """A single training/test example for Dependency Parsing in .conllu format
    Args:
        guid : Unique id for the example
        text : string. the original form of sentence
        sent_id : index of the sentence this token belongs to
        token_id : token id
        token : eojeol (space-delimited word unit)
        pos : POS tag(s)
        head : dependency head
        dep : dependency relation
    """

    def __init__(
        self, guid: str, text: str, sent_id: int, token_id: int, token: str, pos: str, head: str, dep: str
    ) -> None:
        self.guid = guid
        self.text = text
        self.sent_id = sent_id
        self.token_id = token_id
        self.token = token
        self.pos = pos
        self.head = head
        self.dep = dep


class KlueDPInputFeatures:
    """A single set of features of data. Property names are the same names as the corresponding inputs to a model.
    Args:
        input_ids: Indices of input sequence tokens in the vocabulary.
        attention_mask: Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``: Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded)
            tokens.
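        bpe_head_mask : Mask marking the first subword token of each eojeol
        bpe_tail_mask : Mask marking the last subword token of each eojeol
        head_ids : head ids for each eojeol, stored at its first subword token index
        dep_ids : dependency relations for each eojeol, stored at its first subword token index
        pos_ids : POS tags for each eojeol, stored at its first subword token index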
    """

    def __init__(
        self,
        guid: str,
        ids: List[int],
        mask: List[int],
        bpe_head_mask: List[int],
        bpe_tail_mask: List[int],
        head_ids: List[int],
        dep_ids: List[int],
        pos_ids: List[int],
    ) -> None:
        self.guid = guid
        self.input_ids = ids
        self.attention_mask = mask
        self.bpe_head_mask = bpe_head_mask
        self.bpe_tail_mask = bpe_tail_mask
        self.head_ids = head_ids
        self.dep_ids = dep_ids
        self.pos_ids = pos_ids


class KlueDPProcessor:

    origin_train_file_name = "klue-dp-v1.1_train.tsv"
    origin_dev_file_name = "klue-dp-v1.1_dev.tsv"
    origin_test_file_name = "klue-dp-v1.1_test.tsv"


    def __init__(self, max_seq_length: int, tokenizer: PreTrainedTokenizer) -> None:
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def _create_examples(self, file_path: str, dataset_type: str) -> List[KlueDPInputExample]:
        sent_id = -1
        examples = []
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line == "" or line == "\n" or line == "\t":
                    continue
                if line.startswith("#"):
                    parsed = line.strip().split("\t")
                    if len(parsed) != 2:  # metadata line about dataset
                        continue
                    else:
                        sent_id += 1
                        text = parsed[1].strip()
                        guid = parsed[0].replace("##", "").strip()
                else:
                    token_list = [token.replace("\n", "") for token in line.split("\t")] + ["-", "-"]
                    examples.append(
                        KlueDPInputExample(
                            guid=guid,
                            text=text,
                            sent_id=sent_id,
                            token_id=int(token_list[0]),
                            token=token_list[1],
                            pos=token_list[3],
                            head=token_list[4],
                            dep=token_list[5],
                        )
                    )
        return examples

    def convert_examples_to_features(
        self,
        examples: List[KlueDPInputExample],
        tokenizer: PreTrainedTokenizer,
        max_length: int,
        pos_label_list: List[str],
        dep_label_list: List[str],
    ) -> List[KlueDPInputFeatures]:

        pos_label_map = {label: i for i, label in enumerate(pos_label_list)}
        dep_label_map = {label: i for i, label in enumerate(dep_label_list)}

        SENT_ID = 0

        token_list: List[str] = []
        pos_list: List[str] = []
        head_list: List[int] = []
        dep_list: List[str] = []

        features = []
        for example in examples:
            if SENT_ID != example.sent_id:
                SENT_ID = example.sent_id
                encoded = tokenizer.encode_plus(
                    " ".join(token_list),
                    None,
                    add_special_tokens=True,
                    max_length=max_length,
                    truncation=True,
                    padding="max_length",
                )

                ids, mask = encoded["input_ids"], encoded["attention_mask"]

                bpe_head_mask = [0]
                bpe_tail_mask = [0]
                head_ids = [-100]
                dep_ids = [-100]
                pos_ids = [-100]  # --> CLS token

                for token, head, dep, pos in zip(token_list, head_list, dep_list, pos_list):
                    bpe_len = len(tokenizer.tokenize(token))
                    head_token_mask = [1] + [0] * (bpe_len - 1)
                    tail_token_mask = [0] * (bpe_len - 1) + [1]
                    bpe_head_mask.extend(head_token_mask)
                    bpe_tail_mask.extend(tail_token_mask)

                    head_mask = [head] + [-100] * (bpe_len - 1)
                    head_ids.extend(head_mask)
                    dep_mask = [dep_label_map[dep]] + [-100] * (bpe_len - 1)
                    dep_ids.extend(dep_mask)
                    pos_mask = [pos_label_map[pos]] + [-100] * (bpe_len - 1)
                    pos_ids.extend(pos_mask)

                bpe_head_mask.append(0)
                bpe_tail_mask.append(0)
                head_ids.append(-100)
                dep_ids.append(-100)
                pos_ids.append(-100)  # END token
                if len(bpe_head_mask) > max_length:
                    bpe_head_mask = bpe_head_mask[:max_length]
                    bpe_tail_mask = bpe_tail_mask[:max_length]
                    head_ids = head_ids[:max_length]
                    dep_ids = dep_ids[:max_length]
                    pos_ids = pos_ids[:max_length]

                else:
                    bpe_head_mask.extend([0] * (max_length - len(bpe_head_mask)))  # padding by max_len
                    bpe_tail_mask.extend([0] * (max_length - len(bpe_tail_mask)))  # padding by max_len
                    head_ids.extend([-100] * (max_length - len(head_ids)))  # padding by max_len
                    dep_ids.extend([-100] * (max_length - len(dep_ids)))  # padding by max_len
                    pos_ids.extend([-100] * (max_length - len(pos_ids)))

                feature = KlueDPInputFeatures(
                    guid=example.guid,
                    ids=ids,
                    mask=mask,
                    bpe_head_mask=bpe_head_mask,
                    bpe_tail_mask=bpe_tail_mask,
                    head_ids=head_ids,
                    dep_ids=dep_ids,
                    pos_ids=pos_ids,
                )
                features.append(feature)

                token_list = []
                pos_list = []
                head_list = []
                dep_list = []

            token_list.append(example.token)
            pos_list.append(example.pos.split("+")[-1])  # use only the last POS tag
            head_list.append(int(example.head))
            dep_list.append(example.dep)

        encoded = tokenizer.encode_plus(
            " ".join(token_list),
            None,
            add_special_tokens=True,
            max_length=max_length,
            truncation=True,
            padding="max_length",
        )

        ids, mask = encoded["input_ids"], encoded["attention_mask"]

        bpe_head_mask = [0]
        bpe_tail_mask = [0]
        head_ids = [-100]
        dep_ids = [-100]
        pos_ids = [-100]  # --> CLS token

        for token, head, dep, pos in zip(token_list, head_list, dep_list, pos_list):
            bpe_len = len(tokenizer.tokenize(token))
            head_token_mask = [1] + [0] * (bpe_len - 1)
            tail_token_mask = [0] * (bpe_len - 1) + [1]
            bpe_head_mask.extend(head_token_mask)
            bpe_tail_mask.extend(tail_token_mask)

            head_mask = [head] + [-100] * (bpe_len - 1)
            head_ids.extend(head_mask)
            dep_mask = [dep_label_map[dep]] + [-100] * (bpe_len - 1)
            dep_ids.extend(dep_mask)
            pos_mask = [pos_label_map[pos]] + [-100] * (bpe_len - 1)
            pos_ids.extend(pos_mask)

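        bpe_head_mask.append(0)
        bpe_tail_mask.append(0)
        head_ids.append(-100)
        dep_ids.append(-100)
        pos_ids.append(-100)  # END token
        if len(bpe_head_mask) > max_length:
            # truncate the last sentence too, mirroring the handling inside the loop
            bpe_head_mask = bpe_head_mask[:max_length]
            bpe_tail_mask = bpe_tail_mask[:max_length]
            head_ids = head_ids[:max_length]
            dep_ids = dep_ids[:max_length]
            pos_ids = pos_ids[:max_length]
        else:
            bpe_head_mask.extend([0] * (max_length - len(bpe_head_mask)))  # padding by max_len
            bpe_tail_mask.extend([0] * (max_length - len(bpe_tail_mask)))  # padding by max_len
            head_ids.extend([-100] * (max_length - len(head_ids)))  # padding by max_len
            dep_ids.extend([-100] * (max_length - len(dep_ids)))  # padding by max_len
            pos_ids.extend([-100] * (max_length - len(pos_ids)))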

        feature = KlueDPInputFeatures(
            guid=example.guid,
            ids=ids,
            mask=mask,
            bpe_head_mask=bpe_head_mask,
            bpe_tail_mask=bpe_tail_mask,
            head_ids=head_ids,
            dep_ids=dep_ids,
            pos_ids=pos_ids,
        )
        features.append(feature)

        for feature in features[:3]:
            logger.info("*** Example ***")
            logger.info("input_ids: %s" % feature.input_ids)
            logger.info("attention_mask: %s" % feature.attention_mask)
            logger.info("bpe_head_mask: %s" % feature.bpe_head_mask)
            logger.info("bpe_tail_mask: %s" % feature.bpe_tail_mask)
            logger.info("head_id: %s" % feature.head_ids)
            logger.info("dep_ids: %s" % feature.dep_ids)
            logger.info("pos_ids: %s" % feature.pos_ids)

        return features

    def _convert_features(self, examples: List[KlueDPInputExample]) -> List[KlueDPInputFeatures]:
        return self.convert_examples_to_features(
            examples,
            self.tokenizer,
            max_length=self.max_seq_length,
            dep_label_list=get_dep_labels(),
            pos_label_list=get_pos_labels(),
        )

    def _create_dataset(self, file_path: str, dataset_type: str) -> TensorDataset:
        examples = self._create_examples(file_path, dataset_type)
        features = self._convert_features(examples)

        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_bpe_head_mask = torch.tensor([f.bpe_head_mask for f in features], dtype=torch.long)
        all_bpe_tail_mask = torch.tensor([f.bpe_tail_mask for f in features], dtype=torch.long)
        all_head_ids = torch.tensor([f.head_ids for f in features], dtype=torch.long)
        all_dep_ids = torch.tensor([f.dep_ids for f in features], dtype=torch.long)
        all_pos_ids = torch.tensor([f.pos_ids for f in features], dtype=torch.long)

        return TensorDataset(
            all_input_ids,
            all_attention_mask,
            all_bpe_head_mask,
            all_bpe_tail_mask,
            all_head_ids,
            all_dep_ids,
            all_pos_ids,
        )

    def collate_fn(self, batch: List[Tuple]) -> Tuple[torch.Tensor, Any, Any, Any]:
        # 1. set args
        batch_size = len(batch)
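        # KlueDPProcessor defines no `hparams` in this notebook, so default to using POS tags
        no_pos = getattr(getattr(self, "hparams", None), "no_pos", False)
        pos_padding_idx = None if no_pos else len(get_pos_labels())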
        # 2. build inputs : input_ids, attention_mask, bpe_head_mask, bpe_tail_mask
        batch_input_ids = []
        batch_attention_masks = []
        batch_bpe_head_masks = []
        batch_bpe_tail_masks = []
        for batch_id in range(batch_size):
            (
                input_id,
                attention_mask,
                bpe_head_mask,
                bpe_tail_mask,
                _,
                _,
                _,
            ) = batch[batch_id]
            batch_input_ids.append(input_id)
            batch_attention_masks.append(attention_mask)
            batch_bpe_head_masks.append(bpe_head_mask)
            batch_bpe_tail_masks.append(bpe_tail_mask)
        # 2. build inputs : packing tensors
        # 나는 밥을 먹는다. => [CLS] 나 ##는 밥 ##을 먹 ##는 ##다 . [SEP]
        # input_id : [2, 717, 2259, 1127, 2069, 1059, 2259, 2062, 18, 3, 0, 0, ...]
        # bpe_head_mask : [0, 1, 0, 1, 0, 1, 0, 0, 0, 0, ...] (indicate word start (head) idx)
        input_ids = torch.stack(batch_input_ids)
        attention_masks = torch.stack(batch_attention_masks)
        bpe_head_masks = torch.stack(batch_bpe_head_masks)
        bpe_tail_masks = torch.stack(batch_bpe_tail_masks)
        # 3. token_to_words : set in-batch max_word_length
        max_word_length = max(torch.sum(bpe_head_masks, dim=1)).item()
        # 3. token_to_words : placeholders
        head_ids = torch.zeros(batch_size, max_word_length).long()
        type_ids = torch.zeros(batch_size, max_word_length).long()
        pos_ids = torch.zeros(batch_size, max_word_length + 1).long()
        mask_e = torch.zeros(batch_size, max_word_length + 1).long()
        # 3. token_to_words : head_ids, type_ids, pos_ids, mask_e, mask_d
        for batch_id in range(batch_size):
            (
                _,
                _,
                bpe_head_mask,
                _,
                token_head_ids,
                token_type_ids,
                token_pos_ids,
            ) = batch[batch_id]
            # head_id : [1, 3, 5] (prediction candidates)
            # token_head_ids : [-100, 3, -100, 3, -100, 0, -100, -100, -100, -100, ...] (ground truth head ids; -100 marks positions without a label)
            head_id = [i for i, token in enumerate(bpe_head_mask) if token == 1]
            word_length = len(head_id)
            head_id.extend([0] * (max_word_length - word_length))
            head_ids[batch_id] = token_head_ids[head_id]
            type_ids[batch_id] = token_type_ids[head_id]

            pos_ids[batch_id][0] = torch.tensor(pos_padding_idx)
            pos_ids[batch_id][1:] = token_pos_ids[head_id]
            pos_ids[batch_id][int(torch.sum(bpe_head_mask)) + 1 :] = torch.tensor(pos_padding_idx)
            mask_e[batch_id] = torch.LongTensor([1] * (word_length + 1) + [0] * (max_word_length - word_length))
        mask_d = mask_e[:, 1:]
        # 4. pack everything
        masks = (attention_masks, bpe_head_masks, bpe_tail_masks, mask_e, mask_d)
        ids = (head_ids, type_ids, pos_ids)

        return input_ids, masks, ids, max_word_length
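
The core of `convert_examples_to_features` is the alignment between words (eojeol) and subword tokens: each word's label is stored on its first subword, and every other position gets `-100`. The sketch below reproduces that bookkeeping for the sentence used in the `collate_fn` comments. It assumes a toy tokenizer that splits a word into single characters; `toy_tokenize` and `align` are hypothetical stand-ins, and the real tokenizer may split words differently.

```python
IGNORE_INDEX = -100  # same sentinel the notebook uses for positions without a label


def toy_tokenize(word: str) -> list:
    """Hypothetical stand-in for tokenizer.tokenize: one piece per character."""
    return [word[0]] + ["##" + ch for ch in word[1:]]


def align(words, heads):
    """Build bpe_head_mask, bpe_tail_mask and token-level head ids for one sentence."""
    head_mask, tail_mask, head_ids = [0], [0], [IGNORE_INDEX]  # [CLS] position
    for word, head in zip(words, heads):
        n = len(toy_tokenize(word))
        head_mask += [1] + [0] * (n - 1)  # 1 on the first subword of the word
        tail_mask += [0] * (n - 1) + [1]  # 1 on the last subword of the word
        head_ids += [head] + [IGNORE_INDEX] * (n - 1)  # label only on the first subword
    # [SEP] position
    head_mask.append(0)
    tail_mask.append(0)
    head_ids.append(IGNORE_INDEX)
    return head_mask, tail_mask, head_ids


words = ["나는", "밥을", "먹는다."]
heads = [3, 3, 0]  # made-up heads: both nouns attach to the verb, the verb is the root
head_mask, tail_mask, head_ids = align(words, heads)
print(head_mask)
print(tail_mask)
print(head_ids)
```

After this step the notebook pads or truncates each list to `max_length`, which the sketch leaves out.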


In [17]:
dataprocessor = KlueDPProcessor(128, tokenizer)
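
The first argument, 128, is `max_seq_length`. In `convert_examples_to_features`, every per-token list is cut to that length when it is too long and padded otherwise, with 0 for the masks and `-100` for the label lists. A minimal sketch of that rule follows; `pad_or_truncate` is a hypothetical helper, and it returns a new list instead of modifying the input in place as the notebook code does.

```python
def pad_or_truncate(values, max_length, pad_value):
    """Return a copy of `values` with exactly `max_length` items."""
    if len(values) > max_length:
        return values[:max_length]
    return values + [pad_value] * (max_length - len(values))


head_ids = [-100, 3, -100, 0, -100]
padded = pad_or_truncate(head_ids, max_length=8, pad_value=-100)
truncated = pad_or_truncate(head_ids, max_length=3, pad_value=-100)
print(padded)
print(truncated)
```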






In [18]:
train_dataset = dataprocessor._create_dataset('/content/klue-dp-v1.1_train.tsv','train')

INFO:root:*** Example ***
INFO:root:input_ids: [2, 4162, 4238, 2069, 1160, 2460, 14834, 6717, 7285, 6664, 27562, 6539, 25286, 2079, 6711, 15351, 8673, 2151, 4895, 16, 10879, 1283, 3781, 2069, 8631, 17807, 2371, 2062, 18, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
INFO:root:attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
INFO:root:bpe_head_mask: [0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 
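
For readers without the `.tsv` file at hand, the sketch below shows the layout that `_create_examples` appears to expect, inferred from its parsing code: a `## <guid><TAB><text>` header line per sentence, followed by tab-separated token rows in which columns 0, 1, 3, 4 and 5 hold the token id, word form, POS, head and dependency relation. The sample text, the guid `sample-guid-0` and the `read_sentences` helper are all hypothetical.

```python
import io

# Hypothetical two-token sample in the inferred layout.
sample = (
    "## sample-guid-0\t아직 있다.\n"
    "1\t아직\t아직\tMAG\t2\tAP\n"
    "2\t있다.\t있 다 .\tVX+EF+SF\t0\tVP\n"
)


def read_sentences(handle):
    """Group token rows under their sentence header."""
    sentences = []
    for raw in handle:
        line = raw.strip()
        if not line:
            continue
        if line.startswith("#"):
            parts = line.split("\t")
            if len(parts) != 2:  # dataset-level metadata, not a sentence header
                continue
            guid = parts[0].replace("##", "").strip()
            sentences.append({"guid": guid, "text": parts[1].strip(), "tokens": []})
        else:
            cols = line.split("\t")
            sentences[-1]["tokens"].append(
                {"id": int(cols[0]), "form": cols[1], "pos": cols[3], "head": int(cols[4]), "dep": cols[5]}
            )
    return sentences


sentences = read_sentences(io.StringIO(sample))
print(sentences[0]["guid"], len(sentences[0]["tokens"]))
```

Unlike the notebook's flat list of per-token examples, this version groups tokens by sentence, which can make a quick inspection of the file easier.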

In [19]:
validation_dataset = dataprocessor._create_dataset('/content/klue-dp-v1.1_dev.tsv','validation')

INFO:root:*** Example ***
INFO:root:input_ids: [2, 11, 47, 3360, 4889, 2195, 115, 13668, 2142, 2052, 27333, 2364, 2079, 8209, 2170, 4998, 2069, 12823, 2062, 18, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
INFO:root:attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
INFO:root:bpe_head_mask: [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 
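
The `collate_fn` defined earlier converts these token-level lists back to word-level labels: it collects the positions where `bpe_head_mask` is 1, pads that position list with 0 up to the longest sentence in the batch, and reads the labels at those positions. Position 0 is the `[CLS]` slot, which holds `-100`, so padded words end up with the ignore value. The sketch below shows the same gathering with plain Python lists instead of tensors; `gather_word_labels` is a hypothetical helper.

```python
def gather_word_labels(bpe_head_mask, token_labels, max_word_length, pad_position=0):
    """Pick the label stored at each word-start position, padded to max_word_length."""
    positions = [i for i, flag in enumerate(bpe_head_mask) if flag == 1]
    word_length = len(positions)
    # padding positions point at index 0 ([CLS]), whose label is the ignore value
    positions = positions + [pad_position] * (max_word_length - word_length)
    return [token_labels[p] for p in positions], word_length


bpe_head_mask = [0, 1, 0, 1, 0, 1, 0, 0, 0, 0]
token_head_ids = [-100, 3, -100, 3, -100, 0, -100, -100, -100, -100]
word_heads, n_words = gather_word_labels(bpe_head_mask, token_head_ids, max_word_length=5)
print(word_heads, n_words)
```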

In [None]:

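# Imports this cell needs (`overrides` is a separate pip package).
# BaseTransformer, Mode, BiAttention and BiLinear are not defined anywhere in this
# notebook; they must be defined or imported separately before this cell can run.
from typing import Dict, Union

import torch.nn as nn
import torch.nn.functional as F
from overrides import overrides
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class DPTransformer(BaseTransformer):

    mode = Mode.DependencyParsing

    def __init__(self, hparams: Union[Dict[str, Any], argparse.Namespace], metrics: dict = {}) -> None:
        if isinstance(hparams, dict):
            hparams = argparse.Namespace(**hparams)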

        super().__init__(
            hparams,
            num_labels=None,
            mode=self.mode,
            model_type=AutoModel,
            metrics=metrics,
        )

        self.hidden_size = hparams.hidden_size
        self.input_size = self.model.config.hidden_size
        self.arc_space = hparams.arc_space
        self.type_space = hparams.type_space

        self.n_pos_labels = len(get_pos_labels())
        self.n_dp_labels = len(get_dep_labels())

        if hparams.no_pos:
            self.pos_embedding = None
        else:
            self.pos_embedding = nn.Embedding(self.n_pos_labels + 1, hparams.pos_dim)

        enc_dim = self.input_size * 2
        if self.pos_embedding is not None:
            enc_dim += hparams.pos_dim

        self.encoder = nn.LSTM(
            enc_dim,
            self.hidden_size,
            hparams.encoder_layers,
            batch_first=True,
            dropout=0.33,
            bidirectional=True,
        )
        self.decoder = nn.LSTM(
            self.hidden_size, self.hidden_size, hparams.decoder_layers, batch_first=True, dropout=0.33
        )

        self.dropout = nn.Dropout2d(p=0.33)

        self.src_dense = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.hx_dense = nn.Linear(self.hidden_size * 2, self.hidden_size)

        self.arc_c = nn.Linear(self.hidden_size * 2, self.arc_space)
        self.type_c = nn.Linear(self.hidden_size * 2, self.type_space)
        self.arc_h = nn.Linear(self.hidden_size, self.arc_space)
        self.type_h = nn.Linear(self.hidden_size, self.type_space)

        self.attention = BiAttention(self.arc_space, self.arc_space, 1)
        self.bilinear = BiLinear(self.type_space, self.type_space, self.n_dp_labels)

    @overrides
    def forward(
        self,
        bpe_head_mask: torch.Tensor,
        bpe_tail_mask: torch.Tensor,
        pos_ids: torch.Tensor,
        head_ids: torch.Tensor,
        max_word_length: int,
        mask_e: torch.Tensor,
        mask_d: torch.Tensor,
        batch_index: torch.Tensor,
        is_training: bool = True,
        **inputs: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        outputs = self.model(**inputs)
        outputs = outputs[0]
        outputs, sent_len = self.resize_outputs(outputs, bpe_head_mask, bpe_tail_mask, max_word_length)

        if self.pos_embedding is not None:
            pos_outputs = self.pos_embedding(pos_ids)
            pos_outputs = self.dropout(pos_outputs)
            outputs = torch.cat([outputs, pos_outputs], dim=2)

        # encoder
        packed_outputs = pack_padded_sequence(outputs, sent_len, batch_first=True, enforce_sorted=False)
        encoder_outputs, hn = self.encoder(packed_outputs)
        encoder_outputs, outputs_len = pad_packed_sequence(encoder_outputs, batch_first=True)
        encoder_outputs = self.dropout(encoder_outputs.transpose(1, 2)).transpose(1, 2)  # apply dropout for last layer
        hn = self._transform_decoder_init_state(hn)

        # decoder
        src_encoding = F.elu(self.src_dense(encoder_outputs[:, 1:]))
        sent_len = [i - 1 for i in sent_len]
        packed_outputs = pack_padded_sequence(src_encoding, sent_len, batch_first=True, enforce_sorted=False)
        decoder_outputs, _ = self.decoder(packed_outputs, hn)
        decoder_outputs, outputs_len = pad_packed_sequence(decoder_outputs, batch_first=True)
        decoder_outputs = self.dropout(decoder_outputs.transpose(1, 2)).transpose(1, 2)  # apply dropout for last layer

        # compute output for arc and type
        arc_c = F.elu(self.arc_c(encoder_outputs))
        type_c = F.elu(self.type_c(encoder_outputs))

        arc_h = F.elu(self.arc_h(decoder_outputs))
        type_h = F.elu(self.type_h(decoder_outputs))

        out_arc = self.attention(arc_h, arc_c, mask_d=mask_d, mask_e=mask_e).squeeze(dim=1)

        # during validation, use the predicted heads instead of the gold head_ids
        if not is_training:
            head_ids = torch.argmax(out_arc, dim=2)

        type_c = type_c[batch_index, head_ids.data.t()].transpose(0, 1).contiguous()
        out_type = self.bilinear(type_h, type_c)

        return out_arc, out_type

    @overrides
    def training_step(self, batch: List[torch.Tensor], batch_idx: int) -> dict:
        input_ids, masks, ids, max_word_length = batch
        attention_mask, bpe_head_mask, bpe_tail_mask, mask_e, mask_d = masks
        head_ids, type_ids, pos_ids = ids
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        batch_size = head_ids.size()[0]
        batch_index = torch.arange(0, int(batch_size)).long()
        head_index = (
            torch.arange(0, max_word_length).view(max_word_length, 1).expand(max_word_length, batch_size).long()
        )

        # forward
        out_arc, out_type = self(
            bpe_head_mask, bpe_tail_mask, pos_ids, head_ids, max_word_length, mask_e, mask_d, batch_index, **inputs
        )

        # compute loss
        minus_inf = -1e8
        minus_mask_d = (1 - mask_d) * minus_inf
        minus_mask_e = (1 - mask_e) * minus_inf
        out_arc = out_arc + minus_mask_d.unsqueeze(2) + minus_mask_e.unsqueeze(1)

        loss_arc = F.log_softmax(out_arc, dim=2)
        loss_type = F.log_softmax(out_type, dim=2)

        loss_arc = loss_arc * mask_d.unsqueeze(2) * mask_e.unsqueeze(1)
        loss_type = loss_type * mask_d.unsqueeze(2)
        num = mask_d.sum()

        loss_arc = loss_arc[batch_index, head_index, head_ids.data.t()].transpose(0, 1)
        loss_type = loss_type[batch_index, head_index, type_ids.data.t()].transpose(0, 1)
        loss_arc = -loss_arc.sum() / num
        loss_type = -loss_type.sum() / num
        loss = loss_arc + loss_type

        self.log("train/loss_arc", loss_arc)
        self.log("train/loss_type", loss_type)
        self.log("train/loss", loss)

        return {"loss": loss}

    @overrides
    def validation_step(self, batch: List[torch.Tensor], batch_idx: int, data_type: str = "valid") -> dict:
        input_ids, masks, ids, max_word_length = batch
        attention_mask, bpe_head_mask, bpe_tail_mask, mask_e, mask_d = masks
        head_ids, type_ids, pos_ids = ids
        inputs = {"input_ids": input_ids, "attention_mask": attention_mask}

        batch_index = torch.arange(0, head_ids.size()[0]).long()

        out_arc, out_type = self(
            bpe_head_mask,
            bpe_tail_mask,
            pos_ids,
            head_ids,
            max_word_length,
            mask_e,
            mask_d,
            batch_index,
            is_training=False,
            **inputs,
        )

        # predict arc and its type
        heads = torch.argmax(out_arc, dim=2)
        types = torch.argmax(out_type, dim=2)

        preds = DPResult(heads, types)
        labels = DPResult(head_ids, type_ids)

        return {"preds": preds, "labels": labels}

    @overrides
    def validation_epoch_end(
        self, outputs: List[Dict[str, DPResult]], data_type: str = "valid", write_predictions: bool = False
    ) -> None:
        all_preds = []
        all_labels = []
        for output in outputs:
            all_preds.append(output["preds"])
            all_labels.append(output["labels"])

        if write_predictions is True:
            self.write_prediction_file(all_preds, all_labels)

        self._set_metrics_device()
        for k, metric in self.metrics.items():
            metric(all_preds, all_labels)
            self.log(f"{data_type}/{k}", metric, on_step=False, on_epoch=True, logger=True)

    def write_prediction_file(self, prs: List[DPResult], gts: List[DPResult]) -> None:
        """Write head, head type predictions and corresponding labels to json file. Each line indicates a word."""
        head_preds, type_preds, head_labels, type_labels = self._flatten_prediction_and_labels(prs, gts)
        save_path = self.output_dir.joinpath("transformers/pred")
        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)
        with open(os.path.join(save_path, f"pred-{self.step_count}.json"), "w", encoding="utf-8") as f:
            for h, t, hl, tl in zip(head_preds, type_preds, head_labels, type_labels):
                f.write(" ".join([str(h), str(t), str(hl), str(tl)]) + "\n")

    def _flatten_prediction_and_labels(
        self, preds: List[DPResult], labels: List[DPResult]
    ) -> Tuple[List, List, List, List]:
        """Convert prediction and labels to np.array and remove -1s."""
        head_pred_list = list()
        head_label_list = list()
        type_pred_list = list()
        type_label_list = list()
        for pred, label in zip(preds, labels):
            head_pred_list += pred.heads.cpu().flatten().tolist()
            head_label_list += label.heads.cpu().flatten().tolist()
            type_pred_list += pred.types.cpu().flatten().tolist()
            type_label_list += label.types.cpu().flatten().tolist()
        head_preds = np.array(head_pred_list)
        head_labels = np.array(head_label_list)
        type_preds = np.array(type_pred_list)
        type_labels = np.array(type_label_list)

        index = [i for i, label in enumerate(head_labels) if label == -1]
        head_preds = np.delete(head_preds, index)
        head_labels = np.delete(head_labels, index)
        index = [i for i, label in enumerate(type_labels) if label == -1]
        type_preds = np.delete(type_preds, index)
        type_labels = np.delete(type_labels, index)

        return (
            head_preds.tolist(),
            type_preds.tolist(),
            head_labels.tolist(),
            type_labels.tolist(),
        )

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("transformers")
        if not os.path.exists(save_path):
            os.makedirs(save_path, exist_ok=True)
        self.config.save_step = self.step_count
        torch.save(self.state_dict(), save_path.joinpath("dp-model.bin"))
        self.config.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser, root_dir: str) -> argparse.ArgumentParser:
        BaseTransformer.add_specific_args(parser, root_dir)
        parser.add_argument("--encoder_layers", default=1, type=int, help="Number of layers of encoder")
        parser.add_argument("--decoder_layers", default=1, type=int, help="Number of layers of decoder")
        parser.add_argument("--hidden_size", default=768, type=int, help="Number of hidden units in LSTM")
        parser.add_argument("--arc_space", default=512, type=int, help="Dimension of the arc (head) scoring space")
        parser.add_argument("--type_space", default=256, type=int, help="Dimension of the dependency type scoring space")
        parser.add_argument("--no_pos", action="store_true", help="Do not use pos feature in head layers")
        parser.add_argument("--pos_dim", default=256, type=int, help="Dimension of pos embedding")
        args = parser.parse_args()
        if not args.no_pos and args.pos_dim <= 0:
            parser.error("--pos_dim should be a positive integer when --no_pos is False.")
        return parser

    def resize_outputs(
        self, outputs: torch.Tensor, bpe_head_mask: torch.Tensor, bpe_tail_mask: torch.Tensor, max_word_length: int
    ) -> Tuple[torch.Tensor, List]:
        """Resize output of pre-trained transformers (bsz, max_token_length, hidden_dim) to word-level outputs (bsz, max_word_length, hidden_dim*2). """
        batch_size, input_size, hidden_size = outputs.size()
        word_outputs = torch.zeros(batch_size, max_word_length + 1, hidden_size * 2).to(outputs.device)
        sent_len = list()

        for batch_id in range(batch_size):
            head_ids = [i for i, token in enumerate(bpe_head_mask[batch_id]) if token == 1]
            tail_ids = [i for i, token in enumerate(bpe_tail_mask[batch_id]) if token == 1]
            assert len(head_ids) == len(tail_ids)

            word_outputs[batch_id][0] = torch.cat(
                (outputs[batch_id][0], outputs[batch_id][0])
            )  # replace root with [CLS]
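            for i, (head, tail) in enumerate(zip(head_ids, tail_ids)):
                word_outputs[batch_id][i + 1] = torch.cat((outputs[batch_id][head], outputs[batch_id][tail]))
            # number of words plus the root position; does not depend on the loop variable,
            # so it stays well-defined even when a sentence has no words
            sent_len.append(len(head_ids) + 1)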

        return word_outputs, sent_len

    def _transform_decoder_init_state(self, hn: torch.Tensor) -> torch.Tensor:
        hn, cn = hn
        cn = cn[-2:]  # take the last layer
        _, batch_size, hidden_size = cn.size()
        cn = cn.transpose(0, 1).contiguous()
        cn = cn.view(batch_size, 1, 2 * hidden_size).transpose(0, 1)
        cn = self.hx_dense(cn)
        if self.decoder.num_layers > 1:
            cn = torch.cat(
                [
                    cn,
                    # zero initial state for the remaining decoder layers (same dtype and device as cn)
                    cn.new_zeros((self.decoder.num_layers - 1, batch_size, hidden_size)),
                ],
                dim=0,
            )
        hn = torch.tanh(cn)
        hn = (hn, cn)
        return hn


class BiAttention(nn.Module):
    def __init__(  # type: ignore[no-untyped-def]
        self, input_size_encoder: int, input_size_decoder: int, num_labels: int, biaffine: bool = True, **kwargs
    ) -> None:
        super(BiAttention, self).__init__()
        self.input_size_encoder = input_size_encoder
        self.input_size_decoder = input_size_decoder
        self.num_labels = num_labels
        self.biaffine = biaffine

        self.W_e = Parameter(torch.Tensor(self.num_labels, self.input_size_encoder))
        self.W_d = Parameter(torch.Tensor(self.num_labels, self.input_size_decoder))
        self.b = Parameter(torch.Tensor(self.num_labels, 1, 1))
        if self.biaffine:
            self.U = Parameter(torch.Tensor(self.num_labels, self.input_size_decoder, self.input_size_encoder))
        else:
            self.register_parameter("U", None)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.W_e)
        nn.init.xavier_uniform_(self.W_d)
        nn.init.constant_(self.b, 0.0)
        if self.biaffine:
            nn.init.xavier_uniform_(self.U)

    def forward(
        self,
        input_d: torch.Tensor,
        input_e: torch.Tensor,
        mask_d: Optional[torch.Tensor] = None,
        mask_e: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert input_d.size(0) == input_e.size(0)
        batch, length_decoder, _ = input_d.size()
        _, length_encoder, _ = input_e.size()

        out_d = torch.matmul(self.W_d, input_d.transpose(1, 2)).unsqueeze(3)
        out_e = torch.matmul(self.W_e, input_e.transpose(1, 2)).unsqueeze(2)

        if self.biaffine:
            output = torch.matmul(input_d.unsqueeze(1), self.U)
            output = torch.matmul(output, input_e.unsqueeze(1).transpose(2, 3))
            output = output + out_d + out_e + self.b
        else:
            output = out_d + out_e + self.b

        if mask_d is not None:
            output = output * mask_d.unsqueeze(1).unsqueeze(3)
        if mask_e is not None:
            output = output * mask_e.unsqueeze(1).unsqueeze(2)

        return output


class BiLinear(nn.Module):
    def __init__(self, left_features: int, right_features: int, out_features: int):
        super(BiLinear, self).__init__()
        self.left_features = left_features
        self.right_features = right_features
        self.out_features = out_features

        self.U = Parameter(torch.Tensor(self.out_features, self.left_features, self.right_features))
        self.W_l = Parameter(torch.Tensor(self.out_features, self.left_features))
        self.W_r = Parameter(torch.Tensor(self.out_features, self.right_features))
        self.bias = Parameter(torch.Tensor(out_features))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.W_l)
        nn.init.xavier_uniform_(self.W_r)
        nn.init.constant_(self.bias, 0.0)
        nn.init.xavier_uniform_(self.U)

    def forward(self, input_left: torch.Tensor, input_right: torch.Tensor) -> torch.Tensor:
        left_size = input_left.size()
        right_size = input_right.size()
        assert left_size[:-1] == right_size[:-1], "batch size of left and right inputs mis-match: (%s, %s)" % (
            left_size[:-1],
            right_size[:-1],
        )
        batch = int(np.prod(left_size[:-1]))

        input_left = input_left.contiguous().view(batch, self.left_features)
        input_right = input_right.contiguous().view(batch, self.right_features)

        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left, self.W_l, None) + F.linear(input_right, self.W_r, None)
        return output.view(left_size[:-1] + (self.out_features,))
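
To make the scoring in `BiAttention` easier to follow, here is a minimal plain-Python sketch of the biaffine score for a single label: `d^T U e + W_d · d + W_e · e + b`, computed for every (decoder position, encoder position) pair. It is only an illustration with hypothetical function names and toy numbers. The class above computes the same quantity in batched form with `torch.matmul`.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def dot(a: Vector, b: Vector) -> float:
    """Inner product of two equally sized vectors."""
    return sum(x * y for x, y in zip(a, b))


def biaffine_score(d: Vector, e: Vector, U: Matrix, w_d: Vector, w_e: Vector, b: float) -> float:
    """Score one decoder vector d against one encoder vector e."""
    Ue = [dot(row, e) for row in U]  # U has shape (len(d), len(e))
    return dot(d, Ue) + dot(w_d, d) + dot(w_e, e) + b


def arc_scores(
    decoder_states: List[Vector], encoder_states: List[Vector], U: Matrix, w_d: Vector, w_e: Vector, b: float
) -> List[List[float]]:
    """Rows index dependents (decoder positions), columns index candidate heads (encoder positions)."""
    return [[biaffine_score(d, e, U, w_d, w_e, b) for e in encoder_states] for d in decoder_states]


# Toy example with 2-dimensional states
U = [[1.0, 0.0], [0.0, 1.0]]
w_d = [1.0, 1.0]
w_e = [0.5, 0.5]
scores = arc_scores([[1.0, 2.0]], [[3.0, 4.0], [0.0, 0.0]], U, w_d, w_e, b=1.0)
predicted_head = max(range(len(scores[0])), key=lambda j: scores[0][j])
```

Taking the argmax over each row, as `validation_step` does with `torch.argmax(out_arc, dim=2)`, picks the highest-scoring head for each word.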


In [None]:
pos_labels = get_pos_labels()
dep_labels = get_dep_labels()

In [None]:
!git clone https://github.com/KLUE-benchmark/KLUE-baseline.git

Cloning into 'KLUE-baseline'...
remote: Enumerating objects: 74, done.
remote: Counting objects: 100% (74/74), done.
remote: Compressing objects: 100% (64/64), done.
remote: Total 74 (delta 18), reused 63 (delta 9), pack-reused 0
Unpacking objects: 100% (74/74), 63.62 KiB | 651.00 KiB/s, done.


In [None]:
%cd /content/KLUE-baseline

/content/KLUE-baseline


In [None]:
!pip install pytorch-lightning


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install overrides dataclasses

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dataclasses
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Installing collected packages: dataclasses
Successfully installed dataclasses-0.6


In [None]:

!sh ./run_all.sh

sh: 0: Can't open ./run_all.sh
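
The message above means that `sh` could not open `run_all.sh` in the current working directory. One way to investigate is to list the shell scripts that actually exist in the clone before choosing one to run. The sketch below uses only the standard library; `find_shell_scripts` is a hypothetical helper, and which scripts (if any) it finds depends on the contents of the cloned repository.

```python
from pathlib import Path
from typing import List


def find_shell_scripts(root: str) -> List[str]:
    """Return the paths of all *.sh files below root, relative to root and sorted."""
    base = Path(root)
    return sorted(path.relative_to(base).as_posix() for path in base.rglob("*.sh"))


# Usage inside the notebook, from the repository root:
#     for script in find_shell_scripts("."):
#         print(script)
```

If the list is empty, the script is not part of the checkout, and the working directory or the repository contents should be checked before re-running the cell.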
