In [None]:
import os
import torch
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from transformers import Trainer, TrainingArguments, TrainerCallback, TrainerState, TrainerControl, AutoTokenizer, AutoModelForMaskedLM


# ElectraTrainer definition
class ElectraTrainer(Trainer):
    """Trainer that jointly trains an ELECTRA generator (MLM) and discriminator (RTD).

    The discriminator is registered as the Trainer's ``model`` and the generator
    optimizer as the Trainer's ``optimizer``; both networks are nevertheless
    stepped manually inside :meth:`training_step`. The discriminator optimizer
    has no learning-rate scheduler attached (constant learning rate).

    Args:
        generator: Masked-LM model used to fill ``[MASK]`` positions.
        discriminator: Token-level classifier whose ``logits`` are scored with
            ``BCEWithLogitsLoss`` (1 = replaced token, 0 = original token).
        gen_optimizer: Optimizer over the generator parameters.
        disc_optimizer: Optimizer over the discriminator parameters.
        processing_class: Tokenizer; only ``mask_token_id`` is read here.
        warmup_steps: Number of initial ``training_step`` calls during which only
            the generator is trained.
        *args, **kwargs: Forwarded unchanged to ``transformers.Trainer``.
    """

    def __init__(
        self,
        generator,
        discriminator,
        gen_optimizer,
        disc_optimizer,
        processing_class,
        warmup_steps=0,
        *args,
        **kwargs
    ):
        super().__init__(*args, model=discriminator, optimizers=(gen_optimizer, None), **kwargs)
        self.generator = generator
        self.discriminator = discriminator
        self.gen_optimizer = gen_optimizer
        self.disc_optimizer = disc_optimizer
        self.processing_class = processing_class
        self.bce_loss_fn = BCEWithLogitsLoss()
        self.warmup_steps = warmup_steps
        self.global_step_counter = 0

    def training_step(self, model, inputs, loss_or_steps=None, **kwargs):
        """Run one generator update and (after warm-up) one discriminator update.

        Args:
            model: Ignored; ``self.generator`` / ``self.discriminator`` are used.
            inputs: Batch with ``input_ids`` (already MLM-corrupted by the
                collator), ``attention_mask`` and ``labels`` (-100 on
                positions that were not selected for masking).
            loss_or_steps: Accepted for compatibility with the Trainer call
                signature; unused.

        Returns:
            Detached scalar tensor ``gen_loss + disc_loss``. A zero tensor is
            returned (and nothing is updated) when the generator loss is
            NaN/Inf.
        """
        inputs = self._prepare_inputs(inputs)
        generator = self.generator
        discriminator = self.discriminator
        attention_mask = inputs["attention_mask"]
        masked_input_ids = inputs["input_ids"]
        labels = inputs["labels"]

        # from_pretrained() returns modules in eval mode and this override
        # bypasses the base training_step (which calls model.train()), so
        # dropout has to be enabled explicitly for both networks.
        generator.train()
        discriminator.train()

        # When training is resumed, continue counting from the restored step so
        # the generator-only warm-up is not repeated.
        if self.global_step_counter == 0:
            self.global_step_counter = int(self.state.global_step)
        self.global_step_counter += 1
        past_warmup = self.global_step_counter > self.warmup_steps

        # Generator training (MLM)
        self.gen_optimizer.zero_grad(set_to_none=True)
        gen_outputs = generator(input_ids=masked_input_ids, attention_mask=attention_mask, labels=labels)
        gen_loss = gen_outputs.loss

        if torch.isnan(gen_loss) or torch.isinf(gen_loss):
            return torch.tensor(0.0, device=gen_loss.device)

        # Build the corrupted ("fake") sentence before backward so the graph
        # does not have to be retained.
        with torch.no_grad():
            gen_predictions = gen_outputs.logits.argmax(dim=-1)
            mask = masked_input_ids == self.processing_class.mask_token_id
            fake_inputs = masked_input_ids.clone()
            fake_inputs[mask] = gen_predictions[mask]
            # Recover the uncorrupted tokens: the collator stores them in
            # `labels` wherever a position was selected for masking.
            original_ids = torch.where(labels != -100, labels, masked_input_ids)

        self.accelerator.backward(gen_loss)
        gen_loss = gen_loss.detach()

        # Discriminator training (RTD)
        disc_loss = torch.tensor(0.0, device=gen_loss.device)
        step_discriminator = False
        if past_warmup:
            self.disc_optimizer.zero_grad(set_to_none=True)
            # A token counts as "replaced" only if it differs from the original
            # token; a correct generator guess is labelled as real.
            disc_labels = (fake_inputs != original_ids).float()
            disc_outputs = discriminator(input_ids=fake_inputs, attention_mask=attention_mask)
            disc_logits = disc_outputs.logits.squeeze(-1)
            # Score real tokens only; padding positions would otherwise
            # dominate the mean with trivial "original" labels.
            active = attention_mask.bool()
            raw_disc_loss = self.bce_loss_fn(disc_logits[active], disc_labels[active])

            # A non-finite loss is logged as 0.0 and skipped (no backward/step).
            if not (torch.isnan(raw_disc_loss) or torch.isinf(raw_disc_loss)):
                self.accelerator.backward(raw_disc_loss)
                disc_loss = raw_disc_loss.detach()
                step_discriminator = True

        # Gradient clipping (skipped when disabled in TrainingArguments)
        max_grad_norm = self.args.max_grad_norm
        if max_grad_norm is not None and max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(
                list(generator.parameters()) + list(discriminator.parameters()),
                max_norm=max_grad_norm
            )

        # Optimizer step
        self.gen_optimizer.step()
        if step_discriminator:
            self.disc_optimizer.step()

        # The Trainer loop calls optimizer.step() on the generator optimizer and
        # lr_scheduler.step() again after this method returns. Clearing the
        # generator gradients makes that second optimizer step a no-op, and the
        # scheduler is deliberately NOT stepped here so it advances once per update.
        self.gen_optimizer.zero_grad(set_to_none=True)

        # Log and return the total loss
        total_loss = gen_loss + disc_loss
        self.log({
            "gen_loss": gen_loss.item(),
            "disc_loss": disc_loss.item(),
            "total_loss": total_loss.item()
        })
        return total_loss.detach()


# Custom logging callback
class ElectraLoggingCallback(TrainerCallback):
    """Prints generator/discriminator losses and requests periodic checkpoints.

    Args:
        save_interval: Request a checkpoint every ``save_interval`` global steps
            (default 1000, the value that was previously hard-coded).
    """

    def __init__(self, save_interval=1000):
        self.save_interval = save_interval

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Print the losses from a log entry that carries ``gen_loss``/``disc_loss``.

        Those keys arrive in their own log call, separate from the Trainer's
        built-in ``loss`` entry, so the check is made on ``logs`` itself.
        Printing is limited to steps that are multiples of ``args.logging_steps``.
        """
        if not logs or 'gen_loss' not in logs or 'disc_loss' not in logs:
            return
        current_step = logs.get('step', state.global_step)
        # A falsy logging_steps (0/None) would make the modulo fail; print every entry then.
        if args.logging_steps and current_step % args.logging_steps != 0:
            return
        print(f"Step {current_step} | Generator Loss: {logs['gen_loss']:.4f} | Discriminator Loss: {logs['disc_loss']:.4f}")

    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Set the checkpoint-save flag every ``save_interval`` steps."""
        if self.save_interval and state.global_step > 0 and state.global_step % self.save_interval == 0:
            control.should_save = True


# Custom checkpoint callback (saves generator + discriminator together)
class ElectraCheckpointCallback(TrainerCallback):
    """Writes ``generator.pt`` and ``discriminator.pt`` into each checkpoint folder.

    Args:
        trainer: Optional ``ElectraTrainer`` whose models are saved. When omitted,
            the callback falls back to a module-level variable named ``trainer``
            at save time, which is how the notebook originally wired it.
    """

    def __init__(self, trainer=None):
        self._trainer = trainer

    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Save both state dicts under ``<output_dir>/checkpoint-<global_step>``."""
        active_trainer = self._trainer if self._trainer is not None else globals().get("trainer")
        if active_trainer is None:
            print("ElectraCheckpointCallback: no trainer available; generator/discriminator weights were not saved.")
            return
        if not isinstance(active_trainer, ElectraTrainer):
            return
        # Only the main process writes files, to avoid concurrent writes in multi-process runs.
        if not state.is_world_process_zero:
            return

        # Create the checkpoint folder
        output_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        os.makedirs(output_dir, exist_ok=True)

        # Save the generator
        torch.save(active_trainer.generator.state_dict(), os.path.join(output_dir, "generator.pt"))

        # Save the discriminator
        torch.save(active_trainer.discriminator.state_dict(), os.path.join(output_dir, "discriminator.pt"))

In [None]:
import pathlib

# Project root: two directory levels above the current working directory.
path = str(pathlib.Path.cwd().parent.parent)
print(path)

In [None]:
from transformers import ElectraForMaskedLM, ElectraForPreTraining, AutoTokenizer
from torch.optim import AdamW

# Initialize the generator and discriminator from the public ELECTRA-small weights
generator = ElectraForMaskedLM.from_pretrained("google/electra-small-generator")
discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

# Use the Korean tokenizer stored under the project root
# NOTE(review): the tokenizer comes from a local directory while the model weights
# come from google/electra-small-*; presumably the embedding sizes match
# len(tokenizer) — confirm.
tokenizer = AutoTokenizer.from_pretrained(f"{path}/tokenizer")

# Optimizer setup (separate optimizers for generator and discriminator)
discriminator_lr = 5e-5
generator_lr = discriminator_lr * 0.5  # the generator uses a smaller learning rate

gen_optimizer = AdamW(generator.parameters(), lr=generator_lr)
disc_optimizer = AdamW(discriminator.parameters(), lr=discriminator_lr)

In [None]:
import os
import torch

# Step number of the most recently saved checkpoint
# Re_ELECTRA_checkpoint -> checkpoint-{step}
step = 135000
last_checkpoint_path = f"{path}/model/ReELECTRA/pretrained/checkpoints/checkpoint-{step}"

print(f"체크포인트 {last_checkpoint_path} 에서 상태를 로드합니다.")

try:
    gen_checkpoint_path = os.path.join(last_checkpoint_path, "generator.pt")
    disc_checkpoint_path = os.path.join(last_checkpoint_path, "discriminator.pt")

    # Read both files before touching either model so a missing file cannot
    # leave the pair half-loaded. map_location="cpu" lets checkpoints written
    # on a GPU machine load on a CPU-only one; load_state_dict copies the
    # values onto whatever device the model already lives on.
    gen_state = torch.load(gen_checkpoint_path, map_location="cpu")
    disc_state = torch.load(disc_checkpoint_path, map_location="cpu")

    # Load generator and discriminator weights
    generator.load_state_dict(gen_state)
    print(f"Loaded Generator from {gen_checkpoint_path}")

    discriminator.load_state_dict(disc_state)
    print(f"Loaded Discriminator from {disc_checkpoint_path}")

except FileNotFoundError as e:
    # Fall back to the weights the models were initialized with.
    print(f"체크포인트 파일 로드 실패: {e}")
    print("→ 초기 모델부터 DAPT 학습을 시작합니다.")
    last_checkpoint_path = None

In [None]:
# Convert the raw text file into a dataset object

from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling

# Path to the text file
file_path = f"{path}/data/processed/model/dapt_preprocessed.txt"

# Loaded with the "text" builder; later cells read the "train" split and its "text" column.
datasets = load_dataset("text", data_files=file_path)

# print(datasets.keys())

# Select only the first 5000 rows of the train data
# datasets['train'] = datasets['train'].select(range(5000))

# type: DatasetDict, which groups Dataset objects in a dict-like container
# print(type(datasets))
# Check all keys
# print(datasets.keys())
# Print the value at index 10 for the given key
# print(datasets['train'][10])

# tokenize_function(): tokenizes a batch of raw text lines
def tokenize_function(examples):
    """Tokenize ``examples["text"]`` to fixed-length (128) padded sequences.

    The tokenizer is loaded from ``{path}/tokenizer`` inside the function on
    every call rather than taken from the notebook's global ``tokenizer``.
    """
    from transformers import AutoTokenizer

    batch_tokenizer = AutoTokenizer.from_pretrained(f"{path}/tokenizer")
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    return batch_tokenizer(examples["text"], **encode_options)

# Tokenize the datasets
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=2, # number of worker processes
    remove_columns=["text"],
)

# MLM collator configured with a 15% masking probability; it supplies the
# `labels` field consumed by the trainer and the evaluation cell.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

In [None]:
from datasets import DatasetDict

# Split into 80% train / 20% validation.
# - A fixed seed keeps the split identical across kernel restarts, so a resumed
#   run does not train on rows that were validation data in the previous run.
# - The guard makes the cell idempotent: re-running it does not split the
#   already-reduced train set a second time.
if "validation" not in tokenized_datasets:
    split_dataset = tokenized_datasets["train"].train_test_split(test_size=0.2, seed=42)
    tokenized_datasets = DatasetDict({
        "train": split_dataset["train"],
        "validation": split_dataset["test"]
    })

In [None]:
# TrainingArguments

import torch
from transformers import TrainingArguments

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Per-device hyperparameters, in the order:
# (num_train_epochs, per_device_train_batch_size,
#  gradient_accumulation_steps, dataloader_num_workers)
_DEVICE_PROFILES = {
    "cpu": (1, 2, 8, 0),
    "cuda": (2, 16, 1, 2),
}

if device in _DEVICE_PROFILES:
    (
        num_train_epochs,
        per_device_train_batch_size,
        gradient_accumulation_steps,
        dataloader_num_workers,
    ) = _DEVICE_PROFILES[device]

training_args = TrainingArguments(
    # Where checkpoints and logs are written
    output_dir=f"{path}/model/ReELECTRA/DAPT/checkpoints",
    # Allow overwriting an existing output directory
    overwrite_output_dir=True,

    # Device-dependent schedule selected above
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    dataloader_num_workers=dataloader_num_workers,

    max_grad_norm=1.0,

    logging_steps=50,

    # Checkpoint every 1000 steps
    save_strategy="steps",
    save_steps=1000,

    # No external reporting and no Hub upload
    report_to="none",
    push_to_hub=False,
    hub_model_id=None,
    hub_token=None,
)

In [None]:
# Move both networks to the selected device before building the trainer
generator.to(device)
discriminator.to(device)

# Trainer setup
# NOTE: the variable name `trainer` matters — ElectraCheckpointCallback.on_save
# looks up a global with this name to reach the generator and discriminator.

trainer = ElectraTrainer(
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    processing_class=tokenizer,
    warmup_steps=1000,  # generator-only training for the first 1000 training_step calls
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    args=training_args,
    data_collator=data_collator,
    callbacks=[ElectraLoggingCallback(), ElectraCheckpointCallback()]
)

In [None]:
# Train the model

trainer.train()

In [None]:
# Save the model
# NOTE(review): ElectraTrainer registers the discriminator as the Trainer's model,
# so this presumably persists only the discriminator; generator weights are
# written by ElectraCheckpointCallback at checkpoint time — confirm this is intended.

trainer.save_model(f"{path}/model/ReELECTRA/DAPT")

In [None]:
# Model evaluation

from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

# DataLoader for evaluation (the MLM collator applies fresh random masking per batch)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=16,
    shuffle=False,
    collate_fn=data_collator
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Make sure both networks sit on the evaluation device and have dropout disabled
generator.to(device)
discriminator.to(device)
generator.eval()
discriminator.eval()

gen_correct, gen_total = 0, 0
disc_correct, disc_total = 0, 0

with torch.no_grad():
    for batch in tqdm(eval_dataloader, desc="평가 진행 중"):
        # Move the tensors to the evaluation device
        inputs = {k: v.to(device) for k, v in batch.items() if k in ["input_ids", "attention_mask", "labels"]}
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        labels = inputs["labels"]

        # Generator evaluation (MLM accuracy on the positions selected for masking)
        gen_logits = generator(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        ).logits
        selected = labels != -100  # -100 marks positions ignored by the MLM loss
        gen_preds = gen_logits.argmax(dim=-1)

        gen_correct += (gen_preds[selected] == labels[selected]).sum().item()
        gen_total += selected.sum().item()

        # Discriminator evaluation (RTD accuracy)
        # Build the fake sentence from the generator predictions
        fake_inputs = input_ids.clone()
        fake_inputs[selected] = gen_preds[selected]

        # A token is "replaced" only if it differs from the ORIGINAL token, which
        # the collator keeps in `labels`; `input_ids` already contains [MASK] /
        # random tokens and cannot serve as the reference.
        original_ids = torch.where(selected, labels, input_ids)
        disc_labels = (fake_inputs != original_ids).float()

        disc_outputs = discriminator(input_ids=fake_inputs, attention_mask=attention_mask)
        disc_probs = torch.sigmoid(disc_outputs.logits.squeeze(-1))
        disc_preds = (disc_probs > 0.5).float()

        # Count real tokens only; padding would inflate the accuracy
        active = attention_mask.bool()
        disc_correct += (disc_preds[active] == disc_labels[active]).sum().item()
        disc_total += active.sum().item()

# Compute accuracies (guarding against an empty validation set)
gen_accuracy = gen_correct / gen_total if gen_total else 0.0
disc_accuracy = disc_correct / disc_total if disc_total else 0.0

print(f"\n✅ Generator MLM 정확도: {gen_accuracy*100:.2f}%")
print(f"✅ Discriminator RTD 정확도: {disc_accuracy*100:.2f}%")


In [None]:
masked_texts = [
    "[MASK]이 느려요."
]

# Inference only: disable dropout once, outside the loop
generator.eval()
discriminator.eval()

for text in masked_texts:

    # Tokenize
    inputs = tokenizer(text, return_tensors="pt").to(device)
    mask_token_index = (inputs["input_ids"] == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]

    # Nothing to predict when the text has no [MASK] token
    if mask_token_index.numel() == 0:
        print(f"\n원문: {text}")
        print("[MASK] 토큰이 없습니다.")
        continue

    with torch.no_grad():
        # Generator prediction for every [MASK] position
        gen_logits = generator(**inputs).logits
        pred_token_id = gen_logits[0, mask_token_index].argmax(dim=-1)

        # Discriminator verdict on the sentence with the predictions filled in
        fake_inputs = inputs["input_ids"].clone()
        fake_inputs[0, mask_token_index] = pred_token_id

        disc_outputs = discriminator(input_ids=fake_inputs, attention_mask=inputs["attention_mask"])
        disc_probs = torch.sigmoid(disc_outputs.logits.squeeze(-1))
        is_fake = disc_probs[0, mask_token_index] > 0.5

    pred_token = tokenizer.decode(pred_token_id)
    # One verdict per [MASK]; with a single mask this prints exactly "FAKE" or "REAL"
    verdict = ", ".join("FAKE" if flag else "REAL" for flag in is_fake.tolist())

    # Print the results
    print(f"\n원문: {text}")
    print(f"Generator 예측: {pred_token}")
    print(f"Discriminator 판단: {verdict}")

checkpoint를 사용하여 중간부터 학습하기

In [None]:
from transformers import ElectraForMaskedLM, ElectraForPreTraining, AutoTokenizer
from torch.optim import AdamW

# Re-initialize the generator and discriminator for the resume workflow
# (same setup as the earlier initialization cell)
generator = ElectraForMaskedLM.from_pretrained("google/electra-small-generator")
discriminator = ElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

# Use the Korean tokenizer stored under the project root
tokenizer = AutoTokenizer.from_pretrained(f"{path}/tokenizer")

# Optimizer setup (separate optimizers for generator and discriminator)
discriminator_lr = 5e-5
generator_lr = discriminator_lr * 0.5  # the generator uses a smaller learning rate

gen_optimizer = AdamW(generator.parameters(), lr=generator_lr)
disc_optimizer = AdamW(discriminator.parameters(), lr=discriminator_lr)

In [None]:
import os
import re
import torch

# Directory in which the DAPT run writes its checkpoint-{step} folders
checkpoint_root = f"{path}/model/ReELECTRA/DAPT/checkpoints"

# Step number of the checkpoint to resume from.
# Leave as None to pick the highest-numbered checkpoint-{step} folder
# automatically, or set an int to force a specific one.
step = None

if step is None and os.path.isdir(checkpoint_root):
    saved_steps = [
        int(match.group(1))
        for match in (re.fullmatch(r"checkpoint-(\d+)", name) for name in os.listdir(checkpoint_root))
        if match
    ]
    step = max(saved_steps, default=None)

if step is None:
    # No checkpoint available: trainer.train(resume_from_checkpoint=None) starts fresh
    print(f"{checkpoint_root} 에서 체크포인트를 찾지 못했습니다.")
    print("→ 초기 모델부터 DAPT 학습을 시작합니다.")
    last_checkpoint_path = None
else:
    last_checkpoint_path = f"{checkpoint_root}/checkpoint-{step}"

    print(f"체크포인트 {last_checkpoint_path} 에서 상태를 로드합니다.")

    try:
        gen_checkpoint_path = os.path.join(last_checkpoint_path, "generator.pt")
        disc_checkpoint_path = os.path.join(last_checkpoint_path, "discriminator.pt")

        # Read both files before touching either model so a missing file cannot
        # leave the pair half-loaded; map_location="cpu" allows GPU-written
        # checkpoints to load on a CPU-only machine.
        gen_state = torch.load(gen_checkpoint_path, map_location="cpu")
        disc_state = torch.load(disc_checkpoint_path, map_location="cpu")

        # Load generator and discriminator weights
        generator.load_state_dict(gen_state)
        print(f"Loaded Generator from {gen_checkpoint_path}")

        discriminator.load_state_dict(disc_state)
        print(f"Loaded Discriminator from {disc_checkpoint_path}")

        # Only report whether an optimizer state exists; it is not loaded here.
        optim_checkpoint_path = os.path.join(last_checkpoint_path, "optimizer.pt")

        if os.path.exists(optim_checkpoint_path):
            print(f"Found optimizer state at {optim_checkpoint_path}. Trainer will handle full restore.")
        else:
            # Without optimizer.pt the Trainer may start from a freshly initialized optimizer.
            print("optimizer.pt 파일이 없습니다. 옵티마이저 상태가 초기화될 수 있습니다.")

    except FileNotFoundError as e:
        print(f"체크포인트 파일 로드 실패: {e}")
        print("→ 초기 모델부터 DAPT 학습을 시작합니다.")
        last_checkpoint_path = None

In [None]:
# Move both networks to the GPU when one is available (they are created on the CPU)
if device == "cuda":
    generator.to(device)
    discriminator.to(device)

# Trainer setup
# NOTE: the variable name `trainer` matters — ElectraCheckpointCallback.on_save
# looks up a global with this name to reach the generator and discriminator.

trainer = ElectraTrainer(
    generator=generator,
    discriminator=discriminator,
    gen_optimizer=gen_optimizer,
    disc_optimizer=disc_optimizer,
    processing_class=tokenizer,
    warmup_steps=1000,  # generator-only training for the first 1000 training_step calls
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    args=training_args,
    data_collator=data_collator,
    callbacks=[ElectraLoggingCallback(), ElectraCheckpointCallback()]
)

In [None]:
# Resume training
# last_checkpoint_path is None when the checkpoint files could not be loaded,
# in which case no checkpoint is passed to train().

trainer.train(resume_from_checkpoint=last_checkpoint_path)

In [None]:
# Save the model
# NOTE(review): ElectraTrainer registers the discriminator as the Trainer's model,
# so this presumably persists only the discriminator — confirm this is intended.

trainer.save_model(f"{path}/model/ReELECTRA/DAPT")