In [None]:
# 라이브러리 설치 및 import
!pip install -q transformers

In [None]:
!pip install -q torch

In [None]:
import torch
from torch.utils.data import DataLoader
# transformers.AdamW is deprecated upstream; prefer torch.optim.AdamW.
from transformers import AdamW
from transformers import BertTokenizer, BertForQuestionAnswering

In [None]:
# Initialize the tokenizer and model.
# Both must load the same checkpoint so token ids match the model's vocabulary,
# hence a single constant instead of a repeated string literal.
# The question-answering head of BertForQuestionAnswering is not part of this
# pretraining checkpoint, so it starts untrained until the fine-tuning cell runs.
MODEL_NAME = 'bert-base-multilingual-cased'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)

In [None]:
# KorQuAD 데이터셋 로드
from transformers import squad_convert_examples_to_features
from transformers.data.processors.squad import SquadV2Processor

In [None]:
# Load the KorQuAD training set (SQuAD-format JSON) and convert it to BERT features.
# get_train_examples(data_dir, filename) reads data_dir/filename. With filename=None
# it falls back to SquadV2Processor's default "train-v2.0.json", which is not the
# KorQuAD file name, so the file is named explicitly here.
TRAIN_DATA_DIR = "<path_to_train_dataset>"   # directory that contains the JSON file
TRAIN_FILENAME = "KorQuAD_v1.0_train.json"   # KorQuAD 1.0 release name; adjust if yours differs

# SquadV2Processor also reads SQuAD-1.x-style data: questions without an
# "is_impossible" field are treated as answerable.
processor = SquadV2Processor()
train_examples = processor.get_train_examples(TRAIN_DATA_DIR, filename=TRAIN_FILENAME)
train_features, train_dataset = squad_convert_examples_to_features(
    examples=train_examples,
    tokenizer=tokenizer,
    max_seq_length=512,    # maximum tokens per (question, context) window
    doc_stride=128,        # stride between sliding windows over long contexts
    max_query_length=64,   # questions longer than this are truncated
    is_training=True,      # also computes start/end answer positions per feature
    return_dataset="pt",   # second return value is a torch TensorDataset
    threads=1,
)

In [None]:
# Fine-tune the model on the KorQuAD training features.
epochs = 3
batch_size = 8          # reduce if the GPU runs out of memory at max_seq_length=512
learning_rate = 5e-5
log_every = 100         # print the loss every N steps instead of on every step
seed = 42

torch.manual_seed(seed)  # makes the DataLoader shuffle order reproducible

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# With is_training=True and return_dataset="pt", squad_convert_examples_to_features
# returns a TensorDataset whose items are TUPLES (not dicts), ordered as:
# (input_ids, attention_mask, token_type_ids, start_positions, end_positions,
#  cls_index, p_mask, is_impossible). A DataLoader is needed to form real batches.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the optimizer (the PyTorch AdamW, not the deprecated transformers.AdamW).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, attention_mask, token_type_ids, start_positions, end_positions = (
            tensor.to(device) for tensor in batch[:5]
        )
        optimizer.zero_grad()
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,  # BERT needs segment ids to tell question from context
            start_positions=start_positions,
            end_positions=end_positions,
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        if (step + 1) % log_every == 0 or step + 1 == len(train_loader):
            print(f"Epoch: {epoch+1}/{epochs} | Step: {step+1}/{len(train_loader)} | Loss: {loss.item()}")

In [None]:
# Run prediction with the trained model on a single example.
model.eval()
question = "한국의 수도는 어디인가요?"
context = "한국은 동아시아에 위치한 나라로, 수도는 서울입니다."

# Encode the (question, context) pair and place it on the same device as the
# model's weights (which may be a GPU after training).
inputs = tokenizer(question, context, return_tensors='pt').to(model.device)

# Inference only: disable autograd.
with torch.no_grad():
    outputs = model(**inputs)
# The QA output carries separate start/end scores per token; it has no `.logits`.
start_logits = outputs.start_logits
end_logits = outputs.end_logits

# Highest-scoring start and end token positions for the single example in the batch.
start_index = int(torch.argmax(start_logits[0]))
end_index = int(torch.argmax(end_logits[0]))

# decode() merges "##" WordPiece pieces back into words, unlike a plain ' '.join.
# An end position before the start gives an empty answer.
if end_index < start_index:
    answer = ""
else:
    answer_ids = inputs['input_ids'][0][start_index:end_index + 1].tolist()
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True)

In [None]:
# Show the question next to the predicted answer span.
print(f"Question: {question}")
print(f"Answer: {answer}")