In [79]:
# Start from a fresh checkout of the project repository on every run.
# NOTE(review): the repo presumably provides the `KoCharELECTRA` package imported further down — confirm.
%cd /content
!rm -rf KoreanStandardPronunciation
!git clone https://github.com/dhkang01/KoreanStandardPronunciation.git
%cd KoreanStandardPronunciation


/content
Cloning into 'KoreanStandardPronunciation'...
remote: Enumerating objects: 46, done.[K
remote: Counting objects: 100% (46/46), done.[K
remote: Compressing objects: 100% (40/40), done.[K
remote: Total 46 (delta 23), reused 17 (delta 5), pack-reused 0 (from 0)[K
Receiving objects: 100% (46/46), 120.03 KiB | 1.07 MiB/s, done.
Resolving deltas: 100% (23/23), done.
/content/KoreanStandardPronunciation


In [80]:
from google.colab import drive
# Mount Google Drive so artifacts can be written outside the ephemeral Colab disk.
drive.mount('/content/drive')
# Target directory on Drive; later cells pass it to trainer.save_model / tokenizer.save_pretrained
# and read the checkpoint back from it.
save_dir = "/content/drive/MyDrive/models/kocharelectra-pron-lora-adapter"

import os
import json

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [81]:
!pip install -q "transformers>=4.38.0" "datasets>=2.18.0" "peft>=0.11.0" accelerate huggingface_hub evaluate

In [78]:
from huggingface_hub import notebook_login

# Opens the interactive Hugging Face Hub login widget.
# NOTE(review): presumably required for Hub access to the dataset/model used below — confirm.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [82]:
from datasets import load_dataset

# Hub dataset id; only the "train" split is loaded here and it is split manually in the next cell.
dataset_id = "dhkang01/KMA_dataset"
raw_ds = load_dataset(dataset_id, split="train")

# Last expression: display the dataset summary (features and row count).
raw_ds

Dataset({
    features: ['id', 'input', 'output'],
    num_rows: 447115
})

In [83]:
# First split: hold out 10% of the rows (fixed seed so the split is repeatable).
ds_train_tmp = raw_ds.train_test_split(test_size=0.1, seed=42)
train_ds = ds_train_tmp["train"]
tmp_ds   = ds_train_tmp["test"]

# Second split: cut the held-out part in half -> validation and test.
ds_val_test = tmp_ds.train_test_split(test_size=0.5, seed=42)
val_ds = ds_val_test["train"]
test_ds = ds_val_test["test"]

# train/val/test -> 90/5/5

# NOTE(review): the selects below shrink the splits to 1000 / 500 / 10 rows, so the
# 90/5/5 ratio no longer holds afterwards. Presumably a quick-run setting — confirm
# before a full training run.
train_ds = train_ds.select(range(1000))
val_ds = val_ds.select(range(500))
test_ds = test_ds.select(range(10))


Tokenizer 다운로드

In [84]:
from KoCharELECTRA.tokenization_kocharelectra import KoCharElectraTokenizer

# Checkpoint id; reused later when the token-classification model is loaded.
model_name = "monologg/kocharelectra-small-discriminator"

# Tokenizer class imported from the KoCharELECTRA package above.
tokenizer = KoCharElectraTokenizer.from_pretrained(model_name)
# Sanity check: the recorded output shows one token per character.
print(tokenizer.tokenize("가나다"))

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'ElectraTokenizer'. 
The class this function is called from is 'KoCharElectraTokenizer'.


['가', '나', '다']


output vocab

In [85]:
from collections import OrderedDict

# The pronunciation label space reuses the tokenizer vocabulary: every token is
# numbered by its position when iterating tokenizer.vocab (token -> id mapping).
token_list = list(tokenizer.vocab.keys())

pron2id = OrderedDict((tok, idx) for idx, tok in enumerate(token_list))

# Reverse lookup: label id -> token.
id2pron = dict(zip(pron2id.values(), pron2id.keys()))

len(pron2id), list(pron2id.items())[:10]

(11568,
 [('[PAD]', 0),
  ('[UNK]', 1),
  ('[CLS]', 2),
  ('[SEP]', 3),
  ('[MASK]', 4),
  (' ', 5),
  ('이', 6),
  ('다', 7),
  ('는', 8),
  ('에', 9)])

전처리 함수 정의 및 적용

복수 발음 허용 X

In [86]:
import numpy as np

max_length = 128  # fixed token length (including [CLS]/[SEP]); adjust as needed

def preprocess_example(example):
    """Tokenize one example and attach per-token pronunciation label ids.

    Args:
        example: mapping with
            - "input":  the source text,
            - "output": list of candidate lists, one list of pronunciation
              strings per character (List[List[str]]).

    Returns:
        The tokenizer encoding (padded/truncated to ``max_length``) with an extra
        "labels" list of the same length. A position holds the label id of the
        first pronunciation candidate, or -100 (ignored by the loss) for
        [CLS]/[SEP]/padding, for candidates whose id is below 6 (special tokens
        and the space token), and for positions with an empty candidate list.
        When the number of characters and the number of pronunciation entries
        differ, a warning is printed and every label is -100.

    Raises:
        KeyError: if a first candidate is not a key of ``pron2id``.
    """
    text = example["input"]
    pron = example["output"]  # List[List[str]]

    # Character-level tokenization; padded to a fixed length so no dynamic
    # padding is needed later.
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length",
        return_tensors=None,
    )

    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Layout assumed by the index arithmetic below: [CLS] at 0, characters at
    # 1..seq_len-2, [SEP] at seq_len-1.
    seq_len = sum(attention_mask)  # number of non-pad tokens
    n_chars = seq_len - 2          # excluding [CLS] and [SEP]

    # -100 is the default ignore_index.
    labels = np.full(len(input_ids), -100, dtype=np.int64)

    if len(pron) != n_chars:
        # Length mismatch (e.g. truncation): keep all labels ignored so the
        # sample does not contribute to the loss. Empty candidate lists are
        # skipped when building the debug string instead of raising IndexError.
        first_candidates = "".join(cands[0] for cands in pron if cands)
        print(f"Warning: pron len {len(pron)} != n_chars {n_chars} for text: {text}")
        print(f"in the case, pron: {first_candidates}")
        encoding["labels"] = labels.tolist()
        return encoding

    for i, cand_list in enumerate(pron):
        if not cand_list:
            continue
        label_id = pron2id[cand_list[0]]   # the first candidate is the gold label
        if label_id < 6:                   # special tokens and space are not used as targets
            continue

        token_pos = 1 + i                  # position 0 is [CLS]
        if token_pos < seq_len - 1:        # never write onto the trailing [SEP]
            labels[token_pos] = label_id

    encoding["labels"] = labels.tolist()
    return encoding

In [87]:
# Apply the preprocessing to every split; the raw columns are dropped so only
# the tokenizer outputs and "labels" remain. Each split removes its OWN columns.
train_tokenized = train_ds.map(
    preprocess_example,
    remove_columns=train_ds.column_names,
)

val_tokenized = val_ds.map(
    preprocess_example,
    remove_columns=val_ds.column_names,
)

test_tokenized = test_ds.map(
    preprocess_example,
    remove_columns=test_ds.column_names,
)

# too long seq is out.

# train_tokenized[0]


모델 로드, LoRA 적용

encoder에 LoRA적용
classifier에 LoRA적용X, 전부 trainable

char_embed 설정

char->vec(41dim)

emb wrapping

In [88]:
# The 14 basic consonants (double/tense consonants are NOT included here).
BASE_CONSONANTS = ["ㄱ","ㄴ","ㄷ","ㄹ","ㅁ","ㅂ","ㅅ","ㅇ","ㅈ","ㅊ","ㅋ","ㅌ","ㅍ","ㅎ"]
# Basic consonant -> position in BASE_CONSONANTS (used as a one-hot slot).
CONSONANT2IDX = {c: i for i, c in enumerate(BASE_CONSONANTS)}

# Double consonant -> the basic consonant it is derived from.
DOUBLE_CONSONANTS = {"ㄲ": "ㄱ", "ㄸ": "ㄷ", "ㅃ": "ㅂ", "ㅆ": "ㅅ", "ㅉ": "ㅈ"}

# Compound final consonants (codas) decomposed into their two component consonants.
CODA_COMPOUND = {
    "ㄳ": ("ㄱ", "ㅅ"),
    "ㄵ": ("ㄴ", "ㅈ"),
    "ㄶ": ("ㄴ", "ㅎ"),
    "ㄺ": ("ㄹ", "ㄱ"),
    "ㄻ": ("ㄹ", "ㅁ"),
    "ㄼ": ("ㄹ", "ㅂ"),
    "ㄽ": ("ㄹ", "ㅅ"),
    "ㄾ": ("ㄹ", "ㅌ"),
    "ㄿ": ("ㄹ", "ㅍ"),
    "ㅀ": ("ㄹ", "ㅎ"),
    "ㅄ": ("ㅂ", "ㅅ"),
}

In [89]:
# The 10 vowels treated as base (single) vowels.
BASE_VOWELS = ["ㅏ","ㅓ","ㅗ","ㅜ","ㅡ","ㅣ","ㅐ","ㅔ","ㅚ","ㅟ"]
# Base vowel -> position in BASE_VOWELS (used as a one-hot slot).
VOWEL2IDX = {v: i for i, v in enumerate(BASE_VOWELS)}

# Compound vowels expressed as a pair of base vowels,
# e.g. ㅑ = (ㅣ, ㅏ), ㅘ = (ㅗ, ㅏ), ㅢ = (ㅡ, ㅣ).
COMPOSED_VOWELS = {
    "ㅑ": ("ㅣ", "ㅏ"),
    "ㅒ": ("ㅣ", "ㅐ"),
    "ㅕ": ("ㅣ", "ㅓ"),
    "ㅖ": ("ㅣ", "ㅔ"),
    "ㅛ": ("ㅣ", "ㅗ"),
    "ㅠ": ("ㅣ", "ㅜ"),

    "ㅘ": ("ㅗ", "ㅏ"),
    "ㅙ": ("ㅗ", "ㅐ"),
    "ㅝ": ("ㅜ", "ㅓ"),
    "ㅞ": ("ㅜ", "ㅔ"),
    "ㅢ": ("ㅡ", "ㅣ")
}

In [90]:
# Jamo inventories in Unicode composition order for precomposed Hangul syllables.
ONSETS  = list("ㄱㄲㄴㄷㄸㄹㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ")
NUCLEI  = list("ㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ")
CODAS   = [""] + list("ㄱㄲㄳㄴㄵㄶㄷㄹㄺㄻㄼㄽㄾㄿㅀㅁㅂㅄㅅㅆㅇㅈㅊㅋㅌㅍㅎ")

def decompose(syllable):
    """Split one precomposed Hangul syllable into (onset, nucleus, coda).

    Args:
        syllable: a single character in the range '가' (U+AC00) .. '힣' (U+D7A3).

    Returns:
        Tuple ``(onset, nucleus, coda)`` of compatibility jamo strings;
        ``coda`` is ``None`` when the syllable has no final consonant.

    Raises:
        ValueError: if the character is outside the Hangul syllable block.
            (Without this check, characters just below U+AC00 would be decoded
            through negative indices into a wrong syllable.)
        TypeError: propagated from ``ord`` if ``syllable`` is not one character.
    """
    code = ord(syllable) - 0xAC00
    n_syllables = len(ONSETS) * len(NUCLEI) * len(CODAS)   # 19 * 21 * 28 = 11172
    if not 0 <= code < n_syllables:
        raise ValueError(f"Not a precomposed Hangul syllable: {syllable!r}")

    # code = onset * 588 + nucleus * 28 + coda   (588 = 21 * 28)
    onset_idx, remainder = divmod(code, 588)
    nucleus_idx, coda_idx = divmod(remainder, 28)

    onset = ONSETS[onset_idx]
    nucleus = NUCLEI[nucleus_idx]
    coda = CODAS[coda_idx] if coda_idx != 0 else None   # index 0 means "no coda"

    return onset, nucleus, coda

In [91]:
def consonant_base_and_double(jamo):
    """
    Map a consonant jamo to ``(base_consonant, is_double)``.

    Double consonants listed in DOUBLE_CONSONANTS are reduced to their base
    consonant with flag 1; anything else is returned unchanged with flag 0.
    Example: 'ㄲ' -> ('ㄱ', 1), 'ㄷ' -> ('ㄷ', 0)
    """
    base = DOUBLE_CONSONANTS.get(jamo)
    if base is None:
        return jamo, 0
    return base, 1

def get_onset_feature_15(onset_jamo):
    """
    Encode an onset consonant as a 15-D float32 vector:
    slots 0..13 are a one-hot over the 14 basic consonants, slot 14 is the
    double-consonant flag.

    Raises ValueError when the (reduced) consonant is not a basic consonant.
    """
    base, is_double = consonant_base_and_double(onset_jamo)
    slot = CONSONANT2IDX.get(base)
    if slot is None:
        raise ValueError(f"Unknown onset consonant: {onset_jamo}")

    feat = np.zeros(15, dtype=np.float32)
    feat[slot] = 1.0
    feat[14] = float(is_double)   # last dimension toggles "double consonant"
    return feat

def get_coda_feature_16(coda_jamo):
    """
    Encode a coda (final consonant) as a 16-D float32 vector:
    - slots 0..13: few-hot over the 14 basic consonants
      (a compound coda such as ㄳ sets the bits of both components),
    - slot 14: double-consonant flag,
    - slot 15: no-coda flag (set, with everything else 0, when coda_jamo is None).

    Raises ValueError when a component is not a basic consonant.
    """
    feat = np.zeros(16, dtype=np.float32)

    if coda_jamo is None:
        feat[15] = 1.0            # no final consonant
        return feat

    # A compound coda contributes its two components, a simple coda just itself.
    components = CODA_COMPOUND.get(coda_jamo, (coda_jamo,))
    for part in components:
        base, is_double = consonant_base_and_double(part)
        if base not in CONSONANT2IDX:
            raise ValueError(f"Unknown coda consonant: {part}")
        feat[CONSONANT2IDX[base]] = 1.0
        if is_double:
            feat[14] = 1.0        # double flag lives at index 14

    # The no-coda flag (index 15) stays 0 here.
    return feat


In [92]:
def _vowel_fewhot_single(jamo):
    """
    Return the 10-D float32 one-hot vector of a base vowel.
    Raises ValueError if the jamo is not one of BASE_VOWELS.
    """
    slot = VOWEL2IDX.get(jamo)
    if slot is None:
        raise ValueError(f"Unknown base vowel: {jamo}")
    onehot = np.zeros(10, dtype=np.float32)
    onehot[slot] = 1.0
    return onehot

def get_vowel_feature_10(nucleus_jamo):
    """
    Encode a nucleus vowel as a 10-D few-hot float32 vector.
    - base vowel: one-hot
    - compound vowel (COMPOSED_VOWELS): element-wise OR of its two components
    Any other jamo raises ValueError.
    """
    # Base vowels map directly to a one-hot vector.
    if nucleus_jamo in VOWEL2IDX:
        return _vowel_fewhot_single(nucleus_jamo)

    parts = COMPOSED_VOWELS.get(nucleus_jamo)
    if parts is None:
        # Further vowels would need an explicit mapping.
        raise ValueError(f"Unknown nucleus vowel: {nucleus_jamo}")

    first, second = (get_vowel_feature_10(part) for part in parts)
    # Sum then clip to [0, 1] == logical OR on 0/1 vectors.
    return np.clip(first + second, 0.0, 1.0)

In [93]:
def get_syllable_feature_41(syllable):
    """Return the 41-D feature of a Hangul syllable: onset (15) + nucleus (10) + coda (16)."""
    onset, nucleus, coda = decompose(syllable)
    segments = (
        get_onset_feature_15(onset),     # 15D
        get_vowel_feature_10(nucleus),   # 10D
        get_coda_feature_16(coda),       # 16D
    )
    return np.concatenate(segments)      # 41D


# Fixed pronunciation-feature table: one 41-D row per vocabulary entry.
char_embed = np.zeros((len(pron2id), 41), dtype=np.float32) # (11568, 41)
print(char_embed.shape)

for pron, idx in pron2id.items():
    # Only single precomposed Hangul syllables ('가'..'힣') receive features;
    # every other token keeps its all-zero row.
    is_hangul_syllable = len(pron) == 1 and ord('가') <= ord(pron) <= ord('힣')
    if is_hangul_syllable:
        char_embed[idx] = get_syllable_feature_41(pron)

(11568, 41)


In [94]:
import torch

# Convert the NumPy feature table into a torch tensor for use as a module buffer.
# NOTE(review): this rebinds `char_embed`, so the cell is not idempotent — re-running it
# without first re-running the table-building cell will presumably fail; confirm.
char_embed = torch.from_numpy(char_embed)

NewEmb:
```
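char  ->  OldEmb (128dim)            ->  embeddings_project  ->  256dim  --+
char  ->  char_embed (41dim, 고정)   ->  new_up (Dense)      ->  256dim  --+-->  emb (adapt=True일 때 두 벡터를 더함)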
```

41dim: 발음정보 고정 emb (학습하지 않음)
256dim: 학습가능, Dense layer (new_up: 41 → 256, 0으로 초기화)

In [106]:
import torch.nn as nn

class ElectraEmbeddingWithNew(nn.Module):
    """Embedding wrapper that adds a fixed pronunciation-feature branch.

    Two branches are computed for every token:
      * the original embedding module followed by the original projection,
      * a frozen per-token feature table (``char_embed``) followed by a
        trainable, zero-initialised linear layer (``new_up``).

    Args:
        electra_embeddings: the original embedding module (called with the
            same keyword arguments as this module's ``forward``).
        electra_embeddings_project: projection applied to the original
            embedding output. Its ``out_features`` (256 if the attribute is
            missing) is used as the output width of ``new_up`` so both
            branches can be added.
        char_embed: float tensor ``(vocab_size, pron_dim)``; registered as a
            persistent buffer, i.e. not trained.
        adapt: if truthy, the first returned tensor is the sum of both
            branches; otherwise it is the original branch only.
    """

    def __init__(self, electra_embeddings, electra_embeddings_project, char_embed, adapt):
        super().__init__()
        self.old = electra_embeddings                  # original embedding module
        self.proj = electra_embeddings_project
        self.pron_dim = char_embed.size(1)
        # Stored on the instance: forward() must not depend on a global named `adapt`.
        self.adapt = bool(adapt)

        # Fixed feature table (vocab_size x pron_dim); a buffer, so never trained.
        self.register_buffer("char_embed", char_embed, persistent=True)

        # pron_dim -> projected width; zero-initialised so the new branch
        # contributes nothing until it has been trained.
        out_dim = getattr(electra_embeddings_project, "out_features", 256)
        self.new_up = nn.Linear(self.pron_dim, out_dim)
        nn.init.zeros_(self.new_up.weight)
        nn.init.zeros_(self.new_up.bias)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
        """Return ``(embedding, new_branch)``.

        The signature mirrors the wrapped embedding module so the arguments
        can be forwarded unchanged.

        Returns:
            Tuple of two tensors: the embedding to feed to the encoder
            (original branch, plus the new branch when ``adapt`` is set) and
            the projected pronunciation branch on its own.

        Raises:
            ValueError: if ``input_ids`` is None — the feature table is
                indexed by token id, so ``inputs_embeds`` alone is not enough.
        """
        if input_ids is None:
            raise ValueError("input_ids is required to look up pronunciation features")

        # ---- original branch: embeddings, then projection ----
        old_emb = self.old(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        old_emb_proj = self.proj(old_emb)

        # ---- new branch: fixed features, then trainable projection ----
        new_emb_raw = self.char_embed[input_ids]     # (B, L, pron_dim)
        new_emb_proj = self.new_up(new_emb_raw)

        if not self.adapt:
            return old_emb_proj, new_emb_proj
        # Both branches share the projected width, so they can be added.
        return old_emb_proj + new_emb_proj, new_emb_proj


Concat wrapper 모듈 적용





In [113]:
class ElectraWithCharEmbedding(nn.Module):
    """
    Backbone wrapper around a PEFT-wrapped Electra model.

    - uses ``peft_model.base_model.electra`` as the backbone
    - replaces ``electra.embeddings`` with an ElectraEmbeddingWithNew instance
      (which also takes over the original ``embeddings_project``)
    - forward runs the embeddings and the encoder and returns a dict with:
      - "sequence_output": last hidden state of the encoder
      - "embedding_output": the second tensor returned by the embedding
        wrapper (the pronunciation-only branch)
    """
    def __init__(self, peft_model, char_embed, adapt=True):
        super().__init__()
        self.peft_model = peft_model              # PeftModelForTokenClassification

        # Electra model reached through the PEFT wrapper (LoRA included)
        electra = peft_model.base_model.electra

        # Replace the Electra embedding module with the wrapper
        new_emb = ElectraEmbeddingWithNew(
            electra.embeddings,
            electra.embeddings_project,
            char_embed,
            adapt=adapt
        )
        electra.embeddings = new_emb
        # The projection now lives inside the wrapper, so the original slot is cleared.
        electra.embeddings_project = None

        # Expose char_embed so it can be accessed from outside
        self.char_embed = new_emb.char_embed      # (vocab_size, 41)
        self.config = peft_model.base_model.config
        self.hidden_size = self.config.hidden_size

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs,
    ):
        # Extra keyword arguments are accepted but not used in this method.
        # 0) define electra
        electra = self.peft_model.base_model.electra

        # 1) input embedding: (embedding fed to the encoder, pronunciation-only branch)
        embedding_output, new_embedding_output = electra.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
        )  # (B, L, embed_dim=hidden_size)

        # 2) attention mask / head mask
        extended_mask = electra.get_extended_attention_mask(
            attention_mask,
            input_shape=input_ids.shape,
            device=input_ids.device,
        )

        head_mask = electra.get_head_mask(
            None, electra.config.num_hidden_layers
        )

        # 3) encoder output (LoRA applied)
        encoder_outputs = electra.encoder(
            embedding_output,
            attention_mask=extended_mask,
            head_mask=head_mask,
            output_attentions=electra.config.output_attentions,
            output_hidden_states=electra.config.output_hidden_states,
            return_dict=True,
        )

        sequence_output = encoder_outputs.last_hidden_state  # (B, L, hidden_size)

        # encoder_outputs could be returned as well if needed
        return {
            "sequence_output": sequence_output,
            "embedding_output": new_embedding_output, # new emb(only pron info)
            # "encoder_outputs": encoder_outputs,
        }


In [108]:
class PronunciationRNNCell(nn.Module):
    """
    Single-step cell applied per syllable.

    Inputs (N = number of rows, e.g. B * L tokens):
      - seq_state: (N, hidden_size_rnn)
      - ctx_vec:   (N, ctx_dim)
    Outputs:
      - new_seq_state: (N, hidden_size_rnn)   tanh-activated state
      - emb_pron:      (N, out_embed_dim)     raw values, no activation
      - is_end:        (N, 1)                 sigmoid of the last output column
    """
    def __init__(self, hidden_size_rnn: int = 256, ctx_dim: int = 192, out_embed_dim: int = 41):
        super().__init__()
        self.hidden_size_rnn = hidden_size_rnn
        self.ctx_dim = ctx_dim
        self.out_embed_dim = out_embed_dim

        # One linear layer produces state + pronunciation vector + end logit at once.
        in_features = hidden_size_rnn + ctx_dim
        out_features = hidden_size_rnn + out_embed_dim + 1
        self.fc = nn.Linear(in_features, out_features)

        self.sigmoid = nn.Sigmoid()

    def forward(self, seq_state, ctx_vec):
        """
        seq_state: (N, hidden_size_rnn)
        ctx_vec:   (N, ctx_dim)
        """
        fused = self.fc(torch.cat((seq_state, ctx_vec), dim=-1))

        # Column layout of `fused`: [state | pronunciation vector | end logit]
        state_end = self.hidden_size_rnn
        embed_end = state_end + self.out_embed_dim

        new_seq_state = torch.tanh(fused[:, :state_end])      # state always passes through tanh
        emb_pron = fused[:, state_end:embed_end]               # left un-activated
        is_end = self.sigmoid(fused[:, -1:].contiguous())      # end probability

        return new_seq_state, emb_pron, is_end


In [109]:
class PronunciationTaggerRNN(nn.Module):
    """
    Pronunciation tagger on top of an ElectraWithCharEmbedding backbone.

    Pipeline per token position t:
        1) encoder hidden state -> 23 dims (linear + GELU)
        2) concat with the fixed pronunciation feature of the input token
           -> syllable vector of size 23 + H_pron (64 for H_pron = 41)
        3) ctx_t = [prev ; cur ; next] syllable vectors
           * prev = t-1 (or t-2 if t-1 is a masked position)
           * next = t+1 (or t+2 if t+1 is a masked position)
        4) h0 = tanh(MLP(ctx_t))            (3*syll_dim -> 512 -> hidden)
        5) one call of the RNN cell (h0, ctx_t) -> predicted feature vector
        6) MSE loss between the prediction and char_embed[labels]

    Args:
        backbone: module exposing ``config``, ``hidden_size``, ``char_embed``
            (vocab_size, H_pron) and returning a dict with "sequence_output".
        num_labels: size of the pronunciation vocabulary (stored only).
        ignore_index: label value excluded from the loss.
        space_mask_id: optional token id that is always treated as a masked
            (non-pronounced) position.
        num_reserved_ids: when ``labels`` is not given, input ids below this
            value (special tokens and space: ids 0..5 in this vocabulary) are
            treated as masked positions. This mirrors the preprocessing rule
            that never creates targets for label ids below 6.
    """
    def __init__(
        self,
        backbone,              # ElectraWithCharEmbedding
        num_labels: int,
        ignore_index: int = -100,
        space_mask_id: int = None,  # optional: id of the space token
        num_reserved_ids: int = 6,
    ):
        super().__init__()
        self.backbone = backbone
        self.ignore_index = ignore_index
        self.num_labels = num_labels
        self.space_mask_id = space_mask_id
        self.num_reserved_ids = num_reserved_ids

        self.config = backbone.config
        self.hidden_size_rnn = backbone.hidden_size

        H_enc = self.hidden_size_rnn
        H_pron = backbone.char_embed.size(1)
        self.H_enc = H_enc
        self.H_pron = H_pron

        # Encoder hidden -> 23 dims
        self.hidden_shrink = nn.Linear(H_enc, 23)
        self.hidden_shrink_act = nn.GELU()

        # Syllable vector = 23 shrunk dims + pronunciation feature (64 when H_pron = 41).
        self.syll_dim = 23 + H_pron

        # h0 initialisation: 2-layer MLP, 3*syll_dim -> 512 -> H_enc, tanh at the end
        self.h0_fc1 = nn.Linear(3 * self.syll_dim, 512)
        self.h0_act1 = nn.GELU()
        self.h0_fc2 = nn.Linear(512, H_enc)
        self.h0_act2 = nn.Tanh()

        # RNN cell: H_enc state, 3*syll_dim context, H_pron output
        self.cell = PronunciationRNNCell(
            hidden_size_rnn=H_enc,
            ctx_dim=3 * self.syll_dim,
            out_embed_dim=H_pron,
        )

        # Plain reference to the backbone's feature table (vocab_size, H_pron).
        # It is moved to the right device explicitly wherever it is used.
        self.char_embed = backbone.char_embed

    def _build_mask_for_ctx(self, input_ids, attention_mask, labels=None):
        """
        Build the boolean mask of non-pronounced positions (True = masked),
        e.g. [CLS], [SEP] and spaces. Padding is never marked here; it is
        handled through ``attention_mask``.

        - with labels:    valid position whose label equals ``ignore_index``
        - without labels: valid position whose input id is below
          ``num_reserved_ids``, so that inference builds the same kind of
          context as training instead of masking every position
        - ``space_mask_id`` (if set) is always added to the mask

        NOTE: the two rules can still differ for positions whose label was
        ignored for other reasons (e.g. an empty candidate list).
        """
        valid = (attention_mask == 1)

        if labels is not None:
            mask_ctx = valid & (labels == self.ignore_index)
        else:
            mask_ctx = valid & (input_ids < self.num_reserved_ids)

        if self.space_mask_id is not None:
            mask_ctx = mask_ctx | (input_ids == self.space_mask_id)

        return mask_ctx

    def _build_prev_next_idx(self, mask_ctx, valid_mask):
        """
        Compute, for every position, the index of its previous and next
        neighbour, skipping ONE masked neighbour:

            prev = t-1 if t-1 is not masked else t-2   (clamped at 0)
            next = t+1 if t+1 is not masked else t+2   (clamped at L-1)

        This assumes masked positions never occur twice in a row (padding
        aside). Indices at padded positions (valid_mask False) are set to 0.

        Returns:
            (prev_idx, next_idx), both LongTensors of shape (B, L).
        """
        B, L = mask_ctx.shape
        device = mask_ctx.device

        idx = torch.arange(L, device=device)

        # --- prev ---
        prev1 = (idx - 1).clamp(min=0).unsqueeze(0).expand(B, -1)   # (B, L)
        prev2 = (idx - 2).clamp(min=0).unsqueeze(0).expand(B, -1)

        prev1_is_mask = mask_ctx.gather(1, prev1)
        prev_idx = torch.where(prev1_is_mask, prev2, prev1)
        prev_idx = prev_idx * valid_mask.long()      # padding -> index 0

        # --- next ---
        next1 = (idx + 1).clamp(max=L - 1).unsqueeze(0).expand(B, -1)
        next2 = (idx + 2).clamp(max=L - 1).unsqueeze(0).expand(B, -1)

        next1_is_mask = mask_ctx.gather(1, next1)
        next_idx = torch.where(next1_is_mask, next2, next1)
        next_idx = next_idx * valid_mask.long()

        return prev_idx, next_idx

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """
        Returns a dict with
            "loss":   scalar MSE loss (sum over feature dims, averaged over
                      labelled tokens) or None when ``labels`` is None
            "logits": predicted pronunciation features, (B, L, H_pron)
        """
        # 1) backbone
        backbone_outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        seq_out = backbone_outputs["sequence_output"]   # (B, L, H_enc)
        B, L, H_enc = seq_out.shape

        # 2) fixed pronunciation feature of each INPUT token: (B, L) -> (B, L, H_pron)
        char_embed = self.char_embed.to(input_ids.device)
        pron_feat = char_embed[input_ids]

        # 3) hidden -> 23, then syllable vector = [23 ; H_pron]
        hidden_23 = self.hidden_shrink_act(self.hidden_shrink(seq_out))  # (B, L, 23)
        syll_vec = torch.cat([hidden_23, pron_feat], dim=-1)             # (B, L, syll_dim)

        # 4) masks: valid_mask excludes padding, mask_ctx marks non-pronounced positions
        valid_mask = (attention_mask == 1)          # (B, L) bool
        mask_ctx = self._build_mask_for_ctx(input_ids, attention_mask, labels)

        # 5) neighbour indices
        prev_idx, next_idx = self._build_prev_next_idx(mask_ctx, valid_mask)

        # 6) gather prev / next syllable vectors
        D = self.syll_dim
        gather_prev = prev_idx.unsqueeze(-1).expand(-1, -1, D)
        gather_next = next_idx.unsqueeze(-1).expand(-1, -1, D)

        syll_prev = syll_vec.gather(1, gather_prev)   # (B, L, D)
        syll_next = syll_vec.gather(1, gather_next)   # (B, L, D)

        # Masked positions themselves get zero neighbours.
        syll_prev = syll_prev.masked_fill(mask_ctx.unsqueeze(-1), 0.0)
        syll_next = syll_next.masked_fill(mask_ctx.unsqueeze(-1), 0.0)

        # 7) ctx_t = [prev ; cur ; next]
        ctx = torch.cat([syll_prev, syll_vec, syll_next], dim=-1)  # (B, L, 3*D)

        # 8) h0 = tanh(MLP(ctx_t))
        ctx_flat = ctx.reshape(B * L, -1)            # (N, 3*D)
        h0 = self.h0_act1(self.h0_fc1(ctx_flat))     # (N, 512)
        h0 = self.h0_act2(self.h0_fc2(h0))           # (N, H_enc)

        # Zero the state at padded positions.
        valid_flat = valid_mask.reshape(B * L)
        h0 = h0 * valid_flat.unsqueeze(-1)

        # 9) a single RNN-cell step (the current design uses step = 1)
        seq_state, emb_pron_flat, is_end = self.cell(h0, ctx_flat)

        # 10) back to (B, L, H_pron)
        pred_embed = emb_pron_flat.view(B, L, self.H_pron)

        # 11) MSE loss against the feature vector of the gold label
        loss = None
        if labels is not None:
            with torch.no_grad():
                # Replace ignore_index by 0 only to make the lookup valid;
                # those positions are masked out of the loss below.
                labels_clamped = labels.masked_fill(labels == self.ignore_index, 0)
                char_embed = self.char_embed.to(labels_clamped.device)
                target_embed = char_embed[labels_clamped]      # (B, L, H_pron)

            mask = (labels != self.ignore_index).unsqueeze(-1)  # (B, L, 1)
            mse = ((pred_embed - target_embed) ** 2) * mask

            denom = mask.sum().clamp(min=1)   # number of labelled tokens
            loss = mse.sum() / denom

        # 12) output
        return {
            "loss": loss,
            "logits": pred_embed,      # (B, L, H_pron)
            # "is_end": is_end.view(B, L, 1)  # available if needed
        }

    def predict_chars(self, logits):
        """
        Decode predicted feature vectors to pronunciation-vocabulary ids.

        logits: (B, L, H_pron) predicted feature vectors
        return: (B, L) LongTensor, id of the nearest row of ``char_embed``

        The nearest row is chosen by Euclidean distance, which matches the
        MSE training objective:
            argmin ||p - e||^2  ==  argmax (2 p·e - ||e||^2)
        A plain dot product would favour rows with more active bits, so an
        exact prediction could decode to a different syllable whose few-hot
        feature is a superset of the target.
        """
        B, L, D = logits.shape
        pred_flat = logits.reshape(B * L, D)                         # (N, D)

        char_embed = self.char_embed.to(device=pred_flat.device, dtype=pred_flat.dtype)  # (V, D)

        row_sq_norm = char_embed.pow(2).sum(dim=-1)                  # (V,)
        scores = 2.0 * (pred_flat @ char_embed.T) - row_sq_norm      # (N, V)
        pred_ids = scores.argmax(dim=-1).view(B, L)                  # (B, L)

        return pred_ids


학습 진행

In [99]:
from transformers import DataCollatorForTokenClassification

# Collator handed to the Trainer below; it batches the tokenized examples together with their "labels".
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)


In [100]:
import evaluate
import numpy as np

accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Token-level accuracy for the Trainer.

    ``eval_pred`` unpacks to (logits, labels). The logits are decoded to
    vocabulary ids with ``pron_model_rnn.predict_chars`` and compared with the
    labels on every position whose label is not -100. Returns
    ``{"accuracy": 0.0}`` when no labelled position exists.
    """
    logits, labels = eval_pred
    pred_ids = pron_model_rnn.predict_chars(torch.tensor(logits))
    predictions = pred_ids.detach().cpu().numpy()

    # Keep only positions that carry a real label (drop ignore_index = -100).
    keep = labels != -100
    y_true, y_pred = labels[keep], predictions[keep]

    if len(y_true) == 0:
        return {"accuracy": 0.0}

    score = accuracy_metric.compute(predictions=y_pred, references=y_true)
    return {"accuracy": score["accuracy"]}


In [110]:
from transformers import AutoModelForTokenClassification

# One label per entry of the pronunciation vocabulary.
num_labels = len(pron2id)

# Token-classification model loaded from the same checkpoint as the tokenizer.
# The recorded output below shows that the classifier head is newly initialised.
base_model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
)


Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at monologg/kocharelectra-small-discriminator and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [111]:
from peft import LoraConfig, get_peft_model, TaskType

# LoRA setup: rank 8, alpha 16, dropout 0.1, no bias training.
lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["query", "key", "value", "dense"]  # module-name patterns for Electra's attention / feed-forward layers
)

# Wrap the base model with LoRA adapters and report how many parameters are trainable.
# NOTE(review): the tagger built later only calls the Electra encoder through this wrapper;
# whether the token-classification head counted here is ever used should be confirmed.
peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()

# print(model)

trainable params: 3,415,344 || all params: 17,445,216 || trainable%: 19.5775


In [114]:
# 1) Build the backbone from the LoRA-wrapped model created above.
#    (`peft_model` is the name defined by get_peft_model; no variable called `model` exists.)
backbone = ElectraWithCharEmbedding(peft_model=peft_model, char_embed=char_embed, adapt=False)

# 2) RNN-based tagger on top of the backbone
pron_model_rnn = PronunciationTaggerRNN(
    backbone=backbone,
    num_labels=num_labels,
    ignore_index=-100,
)

In [None]:
from transformers import TrainingArguments, Trainer

# Training configuration: cosine schedule with 10% warm-up, one epoch,
# evaluation and checkpointing once per epoch (at most two checkpoints kept).
# NOTE(review): `eval_strategy` is the newer spelling of this argument — confirm the
# installed transformers version accepts it.
training_args = TrainingArguments(
    output_dir="kocharelectra-pron-lora",
    learning_rate=1e-3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_strategy='steps',
    logging_steps=100,
    fp16=False,          # mixed precision is disabled here; enabling it may speed up training if the GPU supports it
    bf16=False,
    report_to="none",   # disable external loggers such as wandb
    eval_accumulation_steps=16,
    # prediction_loss_only=True,
)

# The custom tagger is trained directly; accuracy is computed by compute_metrics.
# NOTE(review): newer transformers releases presumably prefer `processing_class=` over
# `tokenizer=` — confirm against the installed version.
trainer = Trainer(
    model=pron_model_rnn,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)


In [None]:
trainer.train()


In [None]:
trainer.save_model(save_dir)   # save the model weights (the inference cell reads them back from model.safetensors)
tokenizer.save_pretrained(save_dir)

# Store the pronunciation vocabulary next to the weights
with open(os.path.join(save_dir, "pron_vocab.json"), "w", encoding="utf-8") as f:
    json.dump(pron2id, f, ensure_ascii=False, indent=2)


Inference

In [None]:
from safetensors.torch import load_file

# 1) Load the pronunciation vocabulary and build the reverse (id -> token) lookup
with open(os.path.join(save_dir, "pron_vocab.json"), encoding="utf-8") as f:
    pron2id = json.load(f)
id2pron = {v: k for k, v in pron2id.items()}

# 2) Load the saved weights into the already constructed model.
#    strict=False tolerates key differences, so report them instead of hiding them.
state = load_file(os.path.join(save_dir, "model.safetensors"))
load_result = pron_model_rnn.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("Warning: checkpoint and model do not match exactly")
    print("  missing keys   :", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)

# Use the GPU when one is available, otherwise stay on the CPU.
pron_model_rnn.to("cuda" if torch.cuda.is_available() else "cpu")

# 3) Load the tokenizer saved alongside the model
tokenizer = KoCharElectraTokenizer.from_pretrained(save_dir)

In [None]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

# A separate name is used for this collator so the `data_collator` defined for the
# Trainer is not replaced as a side effect of running the inference cell.
test_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")

test_loader = DataLoader(
    test_tokenized,
    batch_size=32,
    shuffle=False,
    collate_fn=test_collator,
)

device = next(pron_model_rnn.parameters()).device
pron_model_rnn.eval()

with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        # Call the module itself (not .forward) so registered hooks still run.
        # Labels are deliberately not passed: this is plain inference.
        output = pron_model_rnn(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch.get("token_type_ids", None),
        )
        pred_ids = pron_model_rnn.predict_chars(output["logits"])  # (B, L)

        # GPU -> CPU -> plain Python lists
        pred_rows = pred_ids.cpu().tolist()
        label_rows = batch["labels"].cpu().tolist()

        texts = tokenizer.batch_decode(
            batch["input_ids"].cpu(),
            skip_special_tokens=True
        )

        for text, seq_ids, label in zip(texts, pred_rows, label_rows):
            # Only positions that carry a gold label (label != -100) are shown.
            pred = [id2pron[p] for p, g in zip(seq_ids, label) if g != -100]
            gold = [id2pron[g] for g in label if g != -100]

            print("TEXT:", text)
            print("PRED:", "".join(pred))
            print("GOLD:", "".join(gold))
            print("-" * 60)

dummy model 평가

In [None]:
class DummyTextAsPronModel(nn.Module):
    """
    Baseline model that returns, for every input token, that token's own row
    of ``char_embed`` as the "logits". Decoding those logits therefore
    corresponds to predicting the input text itself as the pronunciation.
    """
    def __init__(self, char_embed: torch.Tensor):
        super().__init__()
        # (vocab_size, H_pron) lookup table, kept as a buffer
        self.register_buffer("char_embed", char_embed)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        """
        input_ids: (B, L) token ids
        labels:    (B, L) or None; only decides whether a loss is returned

        Returns {"loss": ..., "logits": ...} where logits is (B, L, H_pron)
        and loss is a constant 0.0 tensor when labels are given, else None.
        """
        # Token id -> its pronunciation feature row
        logits = self.char_embed[input_ids]

        # Placeholder loss: this model is never trained.
        loss = None if labels is None else torch.tensor(0.0, device=logits.device)

        return {"loss": loss, "logits": logits}


In [None]:
# Baseline model that echoes the input text as the pronunciation.
# NOTE(review): the device is hard-coded to "cuda" — confirm a GPU runtime is attached.
dummy_model = DummyTextAsPronModel(char_embed=char_embed).to("cuda")

from transformers import Trainer

# The Trainer is used only for its evaluation loop; the existing training_args are reused.
# NOTE(review): `data_collator` here is whatever that name currently refers to in the kernel;
# an earlier cell may have rebound it — confirm which collator is intended.
dummy_trainer = Trainer(
    model=dummy_model,
    args=training_args,      # training-specific settings are irrelevant because no training happens
    train_dataset=train_tokenized,   # not used for evaluation; passed to keep the call shape
    eval_dataset=val_tokenized,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# baseline: accuracy of "pronounce exactly as spelled"
dummy_results = dummy_trainer.evaluate()
print(dummy_results)


In [None]:
from google.colab import runtime
# NOTE(review): presumably releases the Colab runtime once all cells have finished — confirm.
runtime.unassign()
