<a href="https://colab.research.google.com/github/eujin99/gemma2-2b-math/blob/main/gemma2_math.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets torch

In [None]:
from datasets import load_dataset

# Load the training split of GSM8K; the "main" configuration is named explicitly
dataset = load_dataset(path="gsm8k", name="main", split="train")

# Sanity check: show the first record
print(dataset[0])


In [None]:
# Enter the Hugging Face Hub access token

from huggingface_hub import notebook_login

notebook_login()


In [None]:
# Load the tokenizer and the causal-LM model from the same base checkpoint

from transformers import AutoTokenizer, AutoModelForCausalLM

base_checkpoint = "google/gemma-2-2b"

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)
model = AutoModelForCausalLM.from_pretrained(base_checkpoint)

In [None]:
# Tokenize the dataset for causal-LM fine-tuning (inputs and labels)
def tokenize_function(examples):
    """Tokenize a batch of question/answer rows into causal-LM training features.

    A decoder-only model is trained to predict each token from the tokens
    before it, so ``labels`` must line up position-by-position with
    ``input_ids``. Each question is therefore joined with its answer into a
    single sequence, and the labels are a copy of that sequence's token ids.

    Args:
        examples: Batch mapping with "question" and "answer" lists of strings,
            as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output, padded or truncated to 256 tokens, with an added
        ``labels`` entry in which padding positions are -100 so that they are
        ignored by the loss.
    """
    # Terminate every sample with EOS (when the tokenizer defines one) so the
    # model learns where an answer ends.
    eos = tokenizer.eos_token or ""
    texts = [
        f"{question}\n{answer}{eos}"
        for question, answer in zip(examples["question"], examples["answer"])
    ]

    # 256 = the 128-token budget the question and the answer each had before.
    encoded = tokenizer(texts, padding="max_length", truncation=True, max_length=256)

    # Use the attention mask rather than the pad id to find padding: it works
    # regardless of padding side and of pad/eos sharing a token id.
    encoded["labels"] = [
        [token_id if keep else -100 for token_id, keep in zip(ids, mask)]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded

tokenized_dataset = dataset.map(tokenize_function, batched=True)


In [None]:
from transformers import Trainer, TrainingArguments

# Training configuration
training_args = TrainingArguments(
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # raised gradient-accumulation step count
    fp16=True,  # mixed-precision training
    # Output, logging and checkpointing
    output_dir="./results",
    logging_dir="./logs",
    save_steps=1000,
    save_total_limit=2,
)

# Wire the model, the arguments and the preprocessed dataset into a Trainer
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_dataset,
)

# Start fine-tuning
trainer.train()


In [None]:
# Save the model and the tokenizer side by side in one directory
save_dir = "./finetuned_gemma2b_math"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)


In [None]:
# Log in to Hugging Face before uploading the model
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Upload the model and its tokenizer to the same Hub repository
hub_repo = "chloestella/finetuned_gemma2b_math"
model.push_to_hub(hub_repo)
tokenizer.push_to_hub(hub_repo)

In [None]:
from transformers import pipeline

# Build a text-generation pipeline from the saved fine-tuned checkpoint
generator = pipeline(
    task="text-generation",
    model="./finetuned_gemma2b_math",
    tokenizer=tokenizer,
)

# Generate a sample problem and print the first returned sequence
generated_problem = generator("Generate a high school algebra problem.", max_length=100)
print(generated_problem[0]["generated_text"])


In [None]:
# Question-and-answer test: ask the user for an equation and print the model's solution
user_question = input("질문을 입력하세요: ")
answer = generator(
    f"Provide a step-by-step solution in plain text to solve the equation: {user_question}",
    # Budget for the answer itself; max_length would also count the prompt
    # tokens, so a long question would leave little or no room to generate.
    max_new_tokens=200,
    # Return only the continuation instead of echoing the instruction prompt.
    return_full_text=False,
)
print("답변: ", answer[0]['generated_text'])


In [None]:
from transformers import pipeline

# Build the text-generation pipeline around the in-memory model, pinned to the CPU
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    device="cpu",
)

# Problem generation, answer input and grading
import re

# Menu choice -> (grade level, topic) inserted into the generation prompt.
_LEVELS = {
    1: ("elementary", "simple arithmetic"),
    2: ("middle school", "algebra"),
    3: ("high school", "calculus"),
}

# A number with optional sign, thousands separators and decimals. The
# lookbehind keeps the "-" in "3-2" from being read as a sign.
_NUMBER_PATTERN = re.compile(r"(?<![\d.])-?\d[\d,]*(?:\.\d+)?")


def _generate_text(prompt, max_new_tokens):
    """Return only the text the model generates after *prompt*, stripped."""
    outputs = generator(
        prompt,
        # max_length would also count the prompt tokens; prompts here embed
        # previously generated text and can exceed a small max_length alone.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        # By default the pipeline echoes the prompt in 'generated_text'.
        return_full_text=False,
    )
    return outputs[0]['generated_text'].strip()


def _extract_final_answer(text):
    """Reduce a free-form answer to a value that can be compared for equality.

    Takes the part after the last "####" (the final-answer marker used in
    GSM8K reference answers) when present, then the last number found there,
    with thousands separators and trailing decimal zeros removed. If there is
    no number, the case-folded, stripped text is returned instead.
    """
    tail = text.rpartition("####")[2]
    numbers = _NUMBER_PATTERN.findall(tail)
    if not numbers:
        return tail.strip().casefold()
    value = numbers[-1].replace(",", "")
    if "." in value:
        value = value.rstrip("0").rstrip(".")
    return value


def generate_and_solve_math_problem():
    """Run one interactive round: generate a problem, grade the answer, explain.

    Asks the user for a difficulty level (1-3), has the model generate a
    problem and a hidden reference answer, reads the user's answer, reports
    whether it matches, and finally prints a step-by-step explanation.

    Returns:
        None. All results are printed. Returns early after printing a message
        when the level is not a number or not one of 1, 2, 3.
    """
    # Only the parse can legitimately fail with ValueError; keeping the try
    # this narrow stops generation errors from being reported as bad input.
    try:
        level = int(input("과정을 선택하세요 : 1. 초등 2. 중등 3. 고등\n"))
    except ValueError:
        print("숫자만 입력해주세요.")
        return

    selection = _LEVELS.get(level)
    if selection is None:
        print("잘못된 입력입니다. 숫자 1, 2, 3 중 하나를 선택하세요.")
        return
    grade_level, topic = selection

    # Generate the problem
    prompt = f"Generate a {grade_level} math problem in {topic}. The problem should be simple and require one-step calculation."
    print("문제를 생성중입니다...\n")
    problem = _generate_text(prompt, max_new_tokens=100)
    print("문제: ", problem)

    # Generate the reference answer (not shown to the user before answering)
    answer_prompt = f"Provide the answer for the problem: {problem}"
    correct_answer = _generate_text(answer_prompt, max_new_tokens=50)

    # Read the user's answer
    user_answer = input("답변을 입력하세요: ")

    # Grade on the final value rather than on the raw generated text; an
    # empty reference answer never counts as a match.
    expected = _extract_final_answer(correct_answer)
    if expected and _extract_final_answer(user_answer) == expected:
        print("정답입니다!")
    else:
        print(f"오답입니다. 정답은 {correct_answer}입니다.")

    # Provide the worked solution
    explanation_prompt = f"Explain the solution step by step for the problem: {problem}."
    explanation = _generate_text(explanation_prompt, max_new_tokens=200)
    print("풀이: ", explanation)

# Run one round of problem generation, grading and explanation
generate_and_solve_math_problem()
