In [1]:
%cd /content
!rm -rf KoreanStandardPronunciation
!git clone https://github.com/dhkang01/KoreanStandardPronunciation.git
%cd KoreanStandardPronunciation


/content
Cloning into 'KoreanStandardPronunciation'...
remote: Enumerating objects: 40, done.[K
remote: Counting objects: 100% (40/40), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 40 (delta 19), reused 17 (delta 5), pack-reused 0 (from 0)[K
Receiving objects: 100% (40/40), 108.31 KiB | 4.71 MiB/s, done.
Resolving deltas: 100% (19/19), done.
/content/KoreanStandardPronunciation


In [2]:
from google.colab import drive
drive.mount('/content/drive')
save_dir = "/content/drive/MyDrive/models/kocharelectra-pron-lora-adapter"

import os
import json

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
!pip install -q "transformers>=4.38.0" "datasets>=2.18.0" "peft>=0.11.0" accelerate huggingface_hub evaluate

In [4]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [5]:
from datasets import load_dataset

dataset_id = "dhkang01/KMA_dataset"
raw_ds = load_dataset(dataset_id, split="train")

raw_ds

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Dataset({
    features: ['id', 'input', 'output'],
    num_rows: 447115
})

In [6]:
ds_train_tmp = raw_ds.train_test_split(test_size=0.1, seed=42)
train_ds = ds_train_tmp["train"]
tmp_ds   = ds_train_tmp["test"]

ds_val_test = tmp_ds.train_test_split(test_size=0.5, seed=42)
val_ds = ds_val_test["train"]
test_ds = ds_val_test["test"]

# train/val/test -> 90/5/5

train_ds = train_ds.select(range(10000))
val_ds = val_ds.select(range(100))
test_ds = test_ds.select(range(10))


Tokenizer 다운로드

In [7]:
from KoCharELECTRA.tokenization_kocharelectra import KoCharElectraTokenizer

model_name = "monologg/kocharelectra-small-discriminator"

tokenizer = KoCharElectraTokenizer.from_pretrained(model_name)
print(tokenizer.tokenize("가나다"))

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'ElectraTokenizer'. 
The class this function is called from is 'KoCharElectraTokenizer'.


['가', '나', '다']


output vocab

In [8]:
from collections import OrderedDict

# tokenizer.vocab은 OrderedDict(토큰 → ID)
token_list = list(tokenizer.vocab.keys())

pron2id = OrderedDict()
for idx, tok in enumerate(token_list):
    pron2id[tok] = idx

id2pron = {v: k for k, v in pron2id.items()}

len(pron2id), list(list(pron2id.items())[:10])

(11568,
 [('[PAD]', 0),
  ('[UNK]', 1),
  ('[CLS]', 2),
  ('[SEP]', 3),
  ('[MASK]', 4),
  (' ', 5),
  ('이', 6),
  ('다', 7),
  ('는', 8),
  ('에', 9)])

전처리 함수 정의 및 적용

복수 발음 허용 X

In [9]:
import numpy as np

max_length = 128  # 필요에 따라 조절

def preprocess_example(example):
    text = example["input"]
    pron = example["output"]  # List[List[str]]

    # KoCharElectra는 char 단위 토큰 + [CLS], [SEP]
    encoding = tokenizer(
        text,
        truncation=True,
        max_length=max_length,
        padding="max_length",  # DataCollator 써도 되지만 여기서는 고정 길이로
        return_tensors=None,
    )

    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    # Electra/KoCharElectra: 대체로 [CLS] + chars + [SEP]
    # => 실제 문자 수 = len(text)
    # => pron 길이와 len(text)가 맞는다고 가정 (안 맞는 샘플은 나중에 필터 가능)
    seq_len = sum(attention_mask)  # 실제 non-pad 길이
    # [CLS] at 0, [SEP] at seq_len-1, chars in 1..seq_len-2

    labels = np.full_like(input_ids, fill_value=-100)  # default ignore_index

    # 문자 수와 pron 길이 안 맞으면 그냥 전부 ignore(-100)로 두고 스킵되게 할 수도 있음
    # 여기선 일단 최소한으로만 체크
    n_chars = seq_len - 2  # CLS, SEP 제외

    if len(pron) != n_chars:
        # 불일치하는 경우: 전부 padding label로 두고, 나중에 이런 샘플 비율 보고 판단
        print(f"Warning: pron len {len(pron)} != n_chars {n_chars} for text: {text}")
        print(f"in the case, pron: {"".join([l[0] for l in pron])}")
        encoding["labels"] = labels.tolist()
        return encoding

    for i in range(len(pron)):
        cand_list = pron[i]
        if not cand_list:
            continue
        label_id = pron2id[cand_list[0]]               # 첫 후보를 gold label로 사용
        if label_id < 6:                               # 특수토큰, 띄어쓰기는 사용 X
            continue

        token_pos = 1 + i                      # 0: [CLS], 1.. : chars
        if token_pos < seq_len - 1:            # 마지막 [SEP] 전까지만
            labels[token_pos] = label_id

    encoding["labels"] = labels.tolist()
    return encoding

In [10]:
train_tokenized = train_ds.map(
    preprocess_example,
    remove_columns=train_ds.column_names,
)

val_tokenized = val_ds.map(
    preprocess_example,
    remove_columns=val_ds.column_names,
)

test_tokenized = test_ds.map(
    preprocess_example,
    remove_columns=val_ds.column_names,
)

# too long seq is out.

# train_tokenized[0]


모델 로드, LoRA 적용

encoder에 LoRA적용
classifier에 LoRA적용X, 전부 trainable

In [11]:
from transformers import AutoModelForTokenClassification

num_labels = len(pron2id)

base_model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
)


Some weights of ElectraForTokenClassification were not initialized from the model checkpoint at monologg/kocharelectra-small-discriminator and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["query", "key", "value", "dense"]  # Electra의 attention/FFN 모듈 이름 기준
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

print(model)

trainable params: 3,857,712 || all params: 17,887,584 || trainable%: 21.5664
PeftModelForTokenClassification(
  (base_model): LoraModel(
    (model): ElectraForTokenClassification(
      (electra): ElectraModel(
        (embeddings): ElectraEmbeddings(
          (word_embeddings): Embedding(11568, 128, padding_idx=0)
          (position_embeddings): Embedding(512, 128)
          (token_type_embeddings): Embedding(2, 128)
          (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (embeddings_project): Linear(in_features=128, out_features=256, bias=True)
        (encoder): ElectraEncoder(
          (layer): ModuleList(
            (0-11): 12 x ElectraLayer(
              (attention): ElectraAttention(
                (self): ElectraSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=256, out_features=256, bias=True)
                    (lora_dropo

char_embed 설정

char->vec(31dim)

emb wrapping

In [13]:
# 각 feature index:
# 0: stop, 1: fricative, 2: affricate, 3: nasal, 4: liquid,
# 5: labial, 6: coronal, 7: dorsal, 8: glottal,
# 9: tense, 10: aspirated

CONSONANT_FEATURES = {
    "ㄱ": [1,0,0,0,0,  0,0,1,0, 0,0],  # velar stop
    "ㄲ": [1,0,0,0,0,  0,0,1,0, 1,0],  # tense velar stop
    "ㄴ": [0,0,0,1,0,  0,1,0,0, 0,0],  # coronal nasal
    "ㄷ": [1,0,0,0,0,  0,1,0,0, 0,0],  # coronal stop
    "ㄸ": [1,0,0,0,0,  0,1,0,0, 1,0],  # tense coronal stop
    "ㄹ": [0,0,0,0,1,  0,1,0,0, 0,0],  # coronal liquid
    "ㅁ": [0,0,0,1,0,  1,0,0,0, 0,0],  # labial nasal
    "ㅂ": [1,0,0,0,0,  1,0,0,0, 0,0],  # labial stop
    "ㅃ": [1,0,0,0,0,  1,0,0,0, 1,0],  # tense labial stop
    "ㅅ": [0,1,0,0,0,  0,1,0,0, 0,0],  # coronal fricative
    "ㅆ": [0,1,0,0,0,  0,1,0,0, 1,0],  # tense coronal fricative
    "ㅇ": [0,0,0,1,0,  0,0,1,0, 0,0],  # velar nasal (종성 기준)
    "ㅈ": [0,0,1,0,0,  0,1,0,0, 0,0],  # coronal affricate
    "ㅉ": [0,0,1,0,0,  0,1,0,0, 1,0],  # tense coronal affricate
    "ㅊ": [0,0,1,0,0,  0,1,0,0, 0,1],  # aspirated coronal affricate
    "ㅋ": [1,0,0,0,0,  0,0,1,0, 0,1],  # aspirated velar stop
    "ㅌ": [1,0,0,0,0,  0,1,0,0, 0,1],  # aspirated coronal stop
    "ㅍ": [1,0,0,0,0,  1,0,0,0, 0,1],  # aspirated labial stop
    "ㅎ": [0,1,0,0,0,  0,0,0,1, 0,1],  # glottal fricative, aspirated
}

CODA_COMPOUND = {
    "ㄳ": ("ㄱ", "ㅅ"),
    "ㄵ": ("ㄴ", "ㅈ"),
    "ㄶ": ("ㄴ", "ㅎ"),
    "ㄺ": ("ㄹ", "ㄱ"),
    "ㄻ": ("ㄹ", "ㅁ"),
    "ㄼ": ("ㄹ", "ㅂ"),
    "ㄽ": ("ㄹ", "ㅅ"),
    "ㄾ": ("ㄹ", "ㅌ"),
    "ㄿ": ("ㄹ", "ㅍ"),
    "ㅀ": ("ㄹ", "ㅎ"),
    "ㅄ": ("ㅂ", "ㅅ"),
}


In [14]:
# vowel feature index:
# 0: high, 1: mid, 2: low,
# 3: front, 4: central, 5: back,
# 6: round, 7: tense, 8: diphthong

BASE_VOWEL_FEATURES = {
    "ㅏ": [0,0,1,  0,0,1,  0,1,0],  # low back unround, tense-ish
    "ㅓ": [0,1,0,  0,0,1,  0,0,0],  # mid back unround, lax-ish
    "ㅗ": [0,1,0,  0,0,1,  1,1,0],  # mid back round, tense-ish
    "ㅜ": [0,1,0,  0,0,1,  1,0,0],  # mid back round, lax-ish
    "ㅡ": [0,1,0,  0,1,0,  0,0,0],  # mid central unround
    "ㅣ": [1,0,0,  1,0,0,  0,1,0],  # high front unround, tense

    "ㅐ": [0,1,0,  1,0,0,  0,1,0],  # mid front unround, tense
    "ㅔ": [0,1,0,  1,0,0,  0,0,0],  # mid front unround, lax-ish

    # 필요하면 ㅚ, ㅟ, ㅢ도 base로 직접 정의 가능
    "ㅚ": [0,1,0,  0,0,1,  1,1,0],  # /we/ or /ø/ 계열
    "ㅟ": [1,0,0,  1,0,0,  1,1,0],  # /y/ 계열
    "ㅢ": [0,1,0,  0,1,0,  0,1,0],  # central-ish compound
}

COMPOSED_VOWELS = {
    "ㅑ": ("ㅣ", "ㅏ"),
    "ㅒ": ("ㅣ", "ㅐ"),
    "ㅕ": ("ㅣ", "ㅓ"),
    "ㅖ": ("ㅣ", "ㅔ"),
    "ㅛ": ("ㅣ", "ㅗ"),
    "ㅠ": ("ㅣ", "ㅜ"),

    "ㅘ": ("ㅗ", "ㅏ"),
    "ㅙ": ("ㅗ", "ㅐ"),
    "ㅝ": ("ㅜ", "ㅓ"),
    "ㅞ": ("ㅜ", "ㅔ"),
}

In [15]:
def get_consonant_feature(jamo):
    if jamo in CONSONANT_FEATURES:
        return np.array(CONSONANT_FEATURES[jamo], dtype=np.float32)

    if jamo in CODA_COMPOUND:
        a, b = CODA_COMPOUND[jamo]
        return (get_consonant_feature(a) + get_consonant_feature(b)) / 2

    raise ValueError(f"Unknown consonant: {jamo}")


def get_vowel_feature(jamo):
    if jamo in BASE_VOWEL_FEATURES:
        return np.array(BASE_VOWEL_FEATURES[jamo], dtype=np.float32)

    if jamo in COMPOSED_VOWELS:
        a, b = COMPOSED_VOWELS[jamo]
        return (get_vowel_feature(a) + get_vowel_feature(b)) / 2

    raise ValueError(f"Unknown vowel: {jamo}")

In [16]:
ONSETS  = list("ㄱㄲㄴㄷㄸㄹㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ")
NUCLEI  = list("ㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ")
CODAS   = [""] + list("ㄱㄲㄳㄴㄵㄶㄷㄹㄺㄻㄼㄽㄾㄿㅀㅁㅂㅄㅅㅆㅇㅈㅊㅋㅌㅍㅎ")

def decompose(syllable):
    code = ord(syllable) - 0xAC00
    onset_idx = code // 588
    nucleus_idx = (code % 588) // 28
    coda_idx = code % 28

    onset = ONSETS[onset_idx]
    nucleus = NUCLEI[nucleus_idx]
    coda = CODAS[coda_idx] if coda_idx != 0 else None

    return onset, nucleus, coda

In [17]:
def get_syllable_feature(syllable):
    onset, nucleus, coda = decompose(syllable)

    onset_feat  = get_consonant_feature(onset)     # 11D
    nucleus_feat = get_vowel_feature(nucleus)      # 9D

    if coda is None:
        coda_feat = np.zeros(11)                 # 종성 없음 → 11D zero
    else:
        coda_feat = get_consonant_feature(coda)  # 11D

    return np.concatenate([onset_feat, nucleus_feat, coda_feat])  # 31D


char_embed = np.zeros((len(pron2id), 31), dtype=np.float32) # (11568, 31)
print(char_embed.shape)

for pron, idx in pron2id.items():
    if len(pron) != 1:
        continue
    if ord(pron) < ord('가'):
        continue
    if ord('힣') < ord(pron):
        continue
    char_embed[idx] = get_syllable_feature(pron)

(11568, 31)


In [18]:
import torch

char_embed = torch.from_numpy(char_embed)

NewEmb:
```
char  ->  OldEmb  ->  emb
    + 31dim -> 128dim +
```

31dim: 발음정보 고정 emb
128dim: 학습가능, Dense layer

In [19]:
import torch.nn as nn

class ElectraEmbeddingWithNew(nn.Module):
    def __init__(self, electra_embeddings, electra_embeddings_project, char_embed):
        super().__init__()
        self.old = electra_embeddings                  # 기존 ElectraEmbeddings
        self.proj = electra_embeddings_project

        # new embedding (11568X31), 학습하지 않음
        self.register_buffer("char_embed", char_embed, persistent=True)

        # new embedding → 256 projection
        self.new_up = nn.Linear(31, 256)
        nn.init.zeros_(self.new_up.weight)
        nn.init.zeros_(self.new_up.bias)

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        inputs_embeds=None,
        past_key_values_length=0,
    ):
      # HF ElectraEmbeddings의 원래 forward와 동일한 시그니처로 맞추고,
      # 내부에서 self.old(...) 를 그대로 호출하는 방식이 더 안전함.

        # ---- Electra projection ----
        old_emb_128 = self.old(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        old_emb_256 = self.proj(old_emb_128)

        # ---- new embedding 31 차원 ----
        new_emb_31 = self.char_embed[input_ids]     # (B, L, 31)
        new_emb_256 = self.new_up(new_emb_31)

        # ---- 256 차원에서 add ----
        final_emb = old_emb_256 + new_emb_256
        return final_emb, new_emb_256


Concat wrapper 모듈 적용





In [20]:
class ElectraWithCharEmbedding(nn.Module):
    """
    - peft_model.base_model.electra 를 backbone으로 사용
    - electra.embeddings 를 ElectraEmbeddingWithNew 로 교체
    - forward에서는 encoder 통과 후:
      - sequence_output (encoder output)
      - embedding_output (입력 임베딩)
      을 함께 반환
    """
    def __init__(self, peft_model, char_embed):
        super().__init__()
        self.peft_model = peft_model              # PeftModelForTokenClassification

        # ElectraModel (LoRA 포함)
        electra = peft_model.base_model.electra

        # electra emb 교체
        new_emb = ElectraEmbeddingWithNew(
            electra.embeddings,
            electra.embeddings_project,
            char_embed
        )
        electra.embeddings = new_emb
        electra.embeddings_project = None

        # char_embed를 바깥에서 접근할 수 있게 노출
        self.char_embed = new_emb.char_embed      # (vocab_size, 31)
        self.config = peft_model.base_model.config
        self.hidden_size = self.config.hidden_size

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        **kwargs,
    ):
        # 0) define electra
        electra = self.peft_model.base_model.electra

        # 1) input embedding
        embedding_output, new_embedding_output = electra.embeddings(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
        )  # (B, L, embed_dim=hidden_size)

        # 2) attention mask / head mask
        extended_mask = electra.get_extended_attention_mask(
            attention_mask,
            input_shape=input_ids.shape,
            device=input_ids.device,
        )

        head_mask = electra.get_head_mask(
            None, electra.config.num_hidden_layers
        )

        # 3) encoder 출력 (LoRA 적용됨)
        encoder_outputs = electra.encoder(
            embedding_output,
            attention_mask=extended_mask,
            head_mask=head_mask,
            output_attentions=electra.config.output_attentions,
            output_hidden_states=electra.config.output_hidden_states,
            return_dict=True,
        )

        sequence_output = encoder_outputs.last_hidden_state  # (B, L, hidden_size)

        # 필요하면 encoder_outputs도 같이 넘길 수 있음
        return {
            "sequence_output": sequence_output,
            "embedding_output": new_embedding_output, # new emb(only pron info)
            # "encoder_outputs": encoder_outputs,
        }


In [21]:
class PronunciationRNNCell(nn.Module):
    """
    한 스텝에서:
      seq_state, emb_256 -> concat -> new_seq, emb_31, is_end_logit
    - seq_state: (B, H_enc) (Electra hidden)
    - emb_256:  (B, H_emb)  (Embedding output)
    - new_seq:  (B, H_enc)  (다음 step hidden)
    - emb_31:   (B, 31)     (예측 phoneme embedding)
    - is_end_logit: (B, 1)  (EOS 여부 logit, sigmoid로 확률)
    """
    def __init__(self, hidden_size_rnn: int, enc_dim: int, out_embed_dim: int = 31):
        super().__init__()
        self.hidden_size_rnn = hidden_size_rnn
        self.enc_dim = enc_dim
        self.out_embed_dim = out_embed_dim

        self.in_act = nn.GELU()

        self.fc = nn.Linear(hidden_size_rnn + enc_dim, hidden_size_rnn+out_embed_dim+1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, seq_state, emb_256):
        # seq_state: (B, H_enc)
        # emb_256:  (B, H_enc)
        x = torch.cat([seq_state, emb_256], dim=-1)          # (B, H_enc+H_enc)
        x = self.in_act(x)

        out = self.fc(x)
        new_seq = torch.tanh(out[:, :self.hidden_size_rnn])                     # (B, H_enc)
        out_dim = self.hidden_size_rnn+self.out_embed_dim
        emb_31 = self.sigmoid(out[:, self.hidden_size_rnn:self.hidden_size_rnn+self.out_embed_dim])                      # (B, 31)
        is_end_logit = self.sigmoid(out[:, out_dim:out_dim+1])              # (B, 1)

        return new_seq, emb_31, is_end_logit


In [22]:
class PronunciationTaggerRNN(nn.Module):
    """
    ElectraWithCharEmbedding을 backbone으로 사용하는 RNN 기반 발음 tagger.

    - backbone:
        input_ids, attention_mask, token_type_ids ->
        sequence_output (B, L, H), embedding_output (B, L, H)

    - 이 모델:
        fusion_t = sequence_output_t + embedding_output_t (각 토큰 위치)
        seq_state_0 = 0
        for t in 1..L:
            seq_state_t, emb_31_t, is_end_logit_t = RNNCell(seq_state_{t-1}, fusion_t)

        -> emb_31: (B, L, 31)
        -> is_end_logit: (B, L, 1)

    학습 시:
        - labels: (B, L), ignore_index=-100
        - target_embed = char_embed[labels]
        - emb_31와의 MSE를 token-wise로 계산
    """
    def __init__(
        self,
        backbone: ElectraWithCharEmbedding,
        num_labels: int,
        ignore_index: int = -100,
    ):
        super().__init__()
        self.backbone = backbone
        self.ignore_index = ignore_index
        self.num_labels = num_labels

        self.config = backbone.config
        self.hidden_size_rnn = backbone.hidden_size     # H_enc (e.g., 256)

        # RNN cell
        self.cell = PronunciationRNNCell(
            hidden_size_rnn=self.hidden_size_rnn,
            enc_dim=self.hidden_size_rnn,
            out_embed_dim=backbone.char_embed.size(1),  # 31
        )

        # char_embed (vocab_size, 31)
        self.char_embed = backbone.char_embed

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        labels=None,
        **kwargs,
    ):
        # 1) backbone 출력
        backbone_outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            **kwargs,
        )
        seq_out = backbone_outputs["sequence_output"]   # (B, L, H_enc)
        emb_out = backbone_outputs["embedding_output"]  # (B, L, H_enc)

        B, L, H_enc = seq_out.shape
        if seq_out.shape != emb_out.shape:
            raise ValueError(f"seq_out.shape({seq_out.shape}) != emb_out.shape({emb_out.shape})")

        seq_state = seq_out.reshape(B * L, H_enc)
        emb_state = emb_out.reshape(B * L, H_enc)

        pred_emb_list = []
        is_end_logit = False

        while True:
            # 원래는 2-3회 정도 실시
            seq_state, emb_31_t, is_end_logit = self.cell(seq_state, emb_state)
            pred_emb_list.append(emb_31_t.unsqueeze(1))        # (B*L, 1, 31)

            break # don't use end_logit, use only one output

        pred_embed = torch.cat(pred_emb_list, dim=1).reshape(B, L, 31)        # (B*L, 1, 31) -> (B, L, 31)

        # ===========================================================
        # 5) Loss 계산
        # ===========================================================
        loss = None
        if labels is not None:
            # labels: (B, L)
            with torch.no_grad():
                labels_clamped = labels.clone()
                labels_clamped = labels_clamped.masked_fill(
                    labels_clamped == self.ignore_index, 0
                )
                char_embed = self.char_embed.to(labels_clamped.device)
                labels_clamped = labels_clamped.to(char_embed.device)

                target_embed = char_embed[labels_clamped]  # (B, L, 31)

            mse = (pred_embed - target_embed) ** 2              # (B, L, 1)
            mask = (labels != self.ignore_index).unsqueeze(-1)  # (B, L, 1)
            mse = mse * mask

            denom = mask.sum().clamp(min=1) # count
            loss = mse.sum() / denom

        # ===========================================================
        # 6) 최종 return
        # ===========================================================
        return {
            "loss": loss,
            "logits": pred_embed,      # 예측된 31차원 phoneme embedding
            # "is_end_logits": is_end_logits,  # EOS 여부 (미사용 가능)
        }

    def predict_chars(self, logits):
        """
        logits: (B, L, D)
        return: pred_ids (B, L)  - 각 위치별로 '가','나' 같은 발음 vocab index
        """

        # 2) (B, L, 31) -> (B*L, 31)
        B, L, D = logits.shape
        pred_flat = logits.reshape(B * L, D)

        # 3) char_embed와 같은 device로 맞추기
        char_embed = self.char_embed.to(pred_flat.device)   # (V, 31)

        # 4) L2 distance 기준으로 가장 가까운 발음 index 찾기
        #    dist^2 = ||x - e||^2 = ||x||^2 + ||e||^2 - 2 x·e
        #    argmin dist^2 == argmax (x·e) (norm 비슷하다고 보면 dot-product로 충분)
        sims = pred_flat @ char_embed.T           # (B*L, V)
        pred_ids_flat = sims.argmax(dim=-1)       # (B*L,)
        pred_ids = pred_ids_flat.view(B, L)       # (B, L)

        return pred_ids

In [23]:
# 1) backbone 생성
backbone = ElectraWithCharEmbedding(peft_model=model, char_embed=char_embed)

# 2) RNN 기반 tagger
pron_model_rnn = PronunciationTaggerRNN(
    backbone=backbone,
    num_labels=num_labels,\
    ignore_index=-100,
)

학습 진행

In [24]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)


In [25]:
import evaluate
import numpy as np

accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    logits_torch = torch.tensor(logits)
    predictions_torch = pron_model_rnn.predict_chars(logits_torch)
    predictions = predictions_torch.detach().cpu().numpy()

    # ignore_index = -100 제거 후 accuracy 계산
    mask = labels != -100
    y_true = labels[mask]
    y_pred = predictions[mask]

    if len(y_true) == 0:
        return {"accuracy": 0.0}

    result = accuracy_metric.compute(predictions=y_pred, references=y_true)
    return {"accuracy": result["accuracy"]}


In [26]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="kocharelectra-pron-lora",
    learning_rate=5e-3,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    logging_steps=100,
    fp16=True,          # GPU가 지원하면 속도↑
    report_to="none",   # wandb 등 안 쓸 거면 none
    eval_accumulation_steps=16,
    # prediction_loss_only=True,
)

trainer = Trainer(
    model=pron_model_rnn,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=val_tokenized,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)


  trainer = Trainer(


In [None]:
trainer.train()




Epoch,Training Loss,Validation Loss,Accuracy
1,5.1316,0.983095,0.044807
2,0.4184,0.217869,0.049898
3,0.2272,0.078465,0.050916
4,0.0921,0.046363,0.051935
5,0.0697,0.037252,0.052953
6,0.0477,0.031129,0.052953
7,0.0414,0.029002,0.053971




In [None]:
trainer.save_model(save_dir)   # ✅ 모델 가중치 저장 (pytorch_model.bin)
tokenizer.save_pretrained(save_dir)

# 발음 vocab도 같이 저장
with open(os.path.join(save_dir, "pron_vocab.json"), "w", encoding="utf-8") as f:
    json.dump(pron2id, f, ensure_ascii=False, indent=2)


Inference

In [None]:
from safetensors.torch import load_file

# 1) pron_vocab 로드
with open(os.path.join(save_dir, "pron_vocab.json"), encoding="utf-8") as f:
    pron2id = json.load(f)
id2pron = {v: k for k, v in pron2id.items()}

# 3) safetensors 로드
state = load_file(os.path.join(save_dir, "model.safetensors"))
pron_model_rnn.load_state_dict(state, strict=False)
pron_model_rnn.to('cuda')

# 4) tokenizer 로드
tokenizer = KoCharElectraTokenizer.from_pretrained(save_dir)

In [None]:
from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer, return_tensors="pt")

test_loader = DataLoader(
    test_tokenized,
    batch_size=32,
    shuffle=False,
    collate_fn=data_collator,
)

device = next(pron_model_rnn.parameters()).device
pron_model_rnn.eval()

all_pred_ids = []

with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}

        output = pron_model_rnn.forward(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch.get("token_type_ids", None),
        )  # (B, L)
        pred_ids = pron_model_rnn.predict_chars(output["logits"])

        # GPU → CPU → numpy
        pred_ids = pred_ids.cpu().numpy()
        masks = batch["attention_mask"].cpu().numpy()
        input_ids = batch["input_ids"].cpu()
        labels = batch["labels"].cpu()

        texts = tokenizer.batch_decode(
            input_ids,
            skip_special_tokens=True
        )

        B = input_ids.size(0)

        for i in range(B):
            seq_ids = pred_ids[i]
            mask = masks[i]
            text = texts[i]
            label = labels[i]
            valid = (label != -100)

            pred = [
                id2pron[int(idx)]
                for idx, m in zip(seq_ids.tolist(), valid.tolist())
                if m
            ]

            gold = []
            for idx, m in zip(label.tolist(), valid.tolist()):
                if not m:
                    continue
                if idx == -100:
                    continue
                gold.append(id2pron[int(idx)])

            print("TEXT:", text)
            print("PRED:", "".join(pred))
            print("GOLD:", "".join(gold))
            print("-" * 60)

In [None]:
from google.colab import runtime
runtime.unassign()
