<a href="https://colab.research.google.com/github/leeks8888/00-study-web-framework/blob/main/nb/Gemma3_(4B).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
#-------------------------------------------------------------------------------
# --- 1. Setup Environment ---
#-------------------------------------------------------------------------------

# Install necessary libraries
!pip install -q transformers datasets accelerate peft bitsandbytes trl torch

import os
import json
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
from trl import SFTTrainer

# Set Hugging Face Token (Optional, but recommended for Gemma models)
# If the model is gated, you'll need to accept terms on its Hugging Face page
# and provide a token with 'read' access.
# from google.colab import userdata
# hf_token = userdata.get('HF_TOKEN') # Store your token as a Colab secret
# If not using Colab secrets, you can paste your token here, but it's less secure.
# hf_token = "YOUR_HUGGINGFACE_TOKEN"
# if 'hf_token' in locals():
#     from huggingface_hub import login
#     login(token=hf_token)

# Check for GPU
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
    device = "cuda"
else:
    print("GPU not available. Training will be very slow. Consider enabling GPU in Runtime > Change runtime type.")
    device = "cpu"

#-------------------------------------------------------------------------------
# --- 2. Configuration ---
#-------------------------------------------------------------------------------

# Model Configuration
# Replace with your desired Gemma model ID.
# Original Gemma: "google/gemma-7b-it" or "google/gemma-2b-it"
# Gemma 2: "google/gemma-2-9b-it" or "google/gemma-2-2b-it" (9B might be too large for free Colab tier)
# If you have a specific "Gemma 3" ID, use it here.
model_id = "google/gemma-2-2b-it" # Using 2B version for better Colab compatibility

# PEFT Configuration (LoRA)
lora_r = 16  # Rank of the update matrices
lora_alpha = 32 # Alpha parameter for LoRA scaling
lora_dropout = 0.05 # Dropout probability for LoRA layers
# Target modules for Gemma (may vary slightly by specific Gemma version, check model card)
# Common for Gemma:
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]

# Training Arguments
output_dir = "./gemma_finetuned_level_generator"
num_train_epochs = 3  # Adjust as needed
per_device_train_batch_size = 1 # Reduce if OOM errors, requires more gradient_accumulation_steps
gradient_accumulation_steps = 4 # Increase if batch size is small
learning_rate = 2e-4
logging_steps = 25
save_steps = 0 # Set to a positive integer if you want to save checkpoints during training
save_strategy = "epoch" # Save at the end of each epoch
max_seq_length = 1024 # Adjust based on your prompt/response lengths. Your JSONs can be long.
# Check average and max length of your formatted prompts to set this appropriately.

# Quantization Config
use_4bit = True
bnb_4bit_quant_type = "nf4"
bnb_4bit_compute_dtype = torch.bfloat16 # Use bfloat16 for Ampere GPUs and newer

#-------------------------------------------------------------------------------
# --- 3. Data Preparation ---
#-------------------------------------------------------------------------------

# Upload your sample.txt file to Colab
# You can do this by clicking the "Files" icon on the left sidebar, then "Upload".
# Make sure 'sample.txt' is in the root directory of your Colab environment.

data_file_path = "sample.txt"

if not os.path.exists(data_file_path):
    print(f"Error: {data_file_path} not found. Please upload it to your Colab environment.")
    # You might want to stop execution here if the file isn't present
    # import sys
    # sys.exit()
else:
    print(f"Found {data_file_path}. Proceeding with data loading.")

# Load and preprocess data
def load_and_format_data(filepath):
    prompts = []
    responses = []
    formatted_texts = []
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    # The 'prompt' field in your data is the actual input JSON string
                    # The 'response' field is the actual output JSON string
                    instruction = data['prompt']
                    output = data['response']

                    # Gemma instruction-following format (using the chat template)
                    # You might need to adjust this based on the specific Gemma model's recommended prompt format
                    # For -it (instruction-tuned) models, a direct instruction format is usually good.
                    # A common format is <start_of_turn>user\n{instruction}<end_of_turn>\n<start_of_turn>model\n{output}<end_of_turn>
                    # For SFTTrainer, we typically format it as a single string.
                    # Let's use a simplified instruction format:
                    # text = f"Instruction:\nGenerate a game level configuration based on the following parameters.\n\nParameters:\n{instruction}\n\nConfiguration:\n{output}"

                    # Or using the specific chat template tokens if known for the model version
                    # For gemma-2b-it and similar, the format is usually:
                    # User: <prompt text>
                    # Assistant: <response text>
                    # SFTTrainer can often handle this with a dataset_text_field approach
                    # A simple and effective format for SFTTrainer with Gemma often looks like:
                    # "<s>[INST] Instruction Text [/INST] Response Text </s>"
                    # Given your prompt is already structured, let's use it directly.
                    text = f"<s>[INST] {instruction} [/INST] {output} </s>"
                    formatted_texts.append({"text": text})
                except json.JSONDecodeError as e:
                    print(f"Skipping line due to JSON decode error: {e} - Line: {line.strip()}")
                except KeyError as e:
                    print(f"Skipping line due to missing key: {e} - Line: {line.strip()}")
        return formatted_texts
    except FileNotFoundError:
        print(f"Error: File '{filepath}' not found.")
        return []

if os.path.exists(data_file_path):
    formatted_data = load_and_format_data(data_file_path)
    if formatted_data:
        dataset = Dataset.from_list(formatted_data)
        print(f"\nLoaded and formatted {len(dataset)} samples.")
        print("Sample formatted entry:")
        print(dataset[0]['text'][:500] + "...") # Print first 500 chars of the first sample

        # Optional: Split into train/validation
        # dataset = dataset.train_test_split(test_size=0.1)
        # train_dataset = dataset["train"]
        # eval_dataset = dataset["test"]
        # For simplicity with small datasets, we can use the full dataset for training.
        train_dataset = dataset
        eval_dataset = None # Or set to a subset if you have a split
    else:
        print("No data loaded. Exiting.")
        # import sys
        # sys.exit()
else:
    print(f"{data_file_path} not found. Cannot proceed with training.")
    train_dataset = None # Ensure it's defined

#-------------------------------------------------------------------------------
# --- 4. Load Model and Tokenizer ---
#-------------------------------------------------------------------------------

if train_dataset: # Only proceed if data is loaded
    print(f"\nLoading model: {model_id}")

    compute_dtype = getattr(torch, "bfloat16")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=use_4bit,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_use_double_quant=False, # Optional
    )

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config if use_4bit else None,
        device_map="auto", # Automatically distributes model across available GPUs
        # token=hf_token # if using a token
    )
    model.config.use_cache = False # Recommended for fine-tuning
    model.config.pretraining_tp = 1 # Optional: Set to 1 if not using tensor parallelism

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token # Gemma typically uses eos_token for padding
    tokenizer.padding_side = "right" # Fixes weird overflow issue with fp16 training

    print("Model and tokenizer loaded.")
else:
    print("Skipping model loading as no training data is available.")

#-------------------------------------------------------------------------------
# --- 5. PEFT Setup (LoRA) ---
#-------------------------------------------------------------------------------
if train_dataset and 'model' in locals(): # Check if model was loaded
    print("\nSetting up PEFT (LoRA)...")
    # Prepare model for k-bit training
    if use_4bit: # Only if quantization is used
      model = prepare_model_for_kbit_training(model)

    peft_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=lora_target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    print("PEFT model setup complete.")
else:
    print("Skipping PEFT setup as model or data is not available.")

#-------------------------------------------------------------------------------
# --- 6. Training ---
#-------------------------------------------------------------------------------
if train_dataset and 'model' in locals(): # Check if model and data are ready
    print("\nConfiguring training arguments...")
    training_arguments = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        optim="paged_adamw_8bit" if use_4bit else "adamw_torch", # paged_adamw_8bit for 4-bit QLoRA
        learning_rate=learning_rate,
        logging_steps=logging_steps,
        save_strategy=save_strategy,
        save_steps=save_steps if save_strategy == "steps" else 0,
        fp16=True if device == "cuda" and not bnb_4bit_compute_dtype == "torch.bfloat16" else False, # Use fp16 if not using bfloat16
        bf16=True if device == "cuda" and bnb_4bit_compute_dtype == "torch.bfloat16" else False,    # Use bf16 if Ampere and compute_dtype is bfloat16
        max_grad_norm=0.3, # Gradient clipping
        warmup_ratio=0.03, # Warmup ratio
        lr_scheduler_type="constant", # Or "cosine", "linear"
        report_to="tensorboard", # You can use "wandb" if you have it set up
        # evaluation_strategy="steps" if eval_dataset else "no",
        # eval_steps=save_steps if eval_dataset else None, # Evaluate at the same frequency as saving
    )

    print("Initializing SFTTrainer...")
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset, # Pass eval_dataset here if you have one
        peft_config=peft_config if 'peft_config' in locals() else None, # Pass LoRA config if PEFT is used
        # dataset_text_field="text", # Key in dataset that contains the formatted text
        max_seq_length=max_seq_length,
        tokenizer=tokenizer,
        args=training_arguments,
        packing=False, # Set to True if your examples are short and you want to pack them
    )

    print("\nStarting training...")
    trainer.train()
    print("Training complete.")

    #-------------------------------------------------------------------------------
    # --- 7. Save Adapter ---
    #-------------------------------------------------------------------------------
    print("\nSaving fine-tuned LoRA adapter...")
    adapter_model_path = os.path.join(output_dir, "adapter_final")
    trainer.model.save_pretrained(adapter_model_path) # Saves only the LoRA adapter weights
    tokenizer.save_pretrained(adapter_model_path) # Save tokenizer with adapter for easy loading
    print(f"Adapter saved to {adapter_model_path}")

else:
    print("Skipping training as model or data is not available.")


#-------------------------------------------------------------------------------
# --- 8. Inference (Example) ---
#-------------------------------------------------------------------------------
if os.path.exists(os.path.join(output_dir, "adapter_final")) and 'model_id' in locals(): # Check if adapter and model_id exist
    print("\n--- Testing Fine-tuned Model ---")

    # Load the base model again (quantized or full-precision)
    # Important: If you used quantization during training, load with same quantization for inference.
    base_model_for_inference = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config if use_4bit else None,
        device_map="auto",
        torch_dtype=getattr(torch, "bfloat16") if use_4bit else torch.float16, # Match training
        # token=hf_token # if using a token
    )

    # Load the LoRA adapter
    # Make sure the adapter_model_path is correct (where you saved the adapter)
    final_adapter_path = os.path.join(output_dir, "adapter_final")
    ft_model = PeftModel.from_pretrained(base_model_for_inference, final_adapter_path)
    ft_model = ft_model.eval() # Set to evaluation mode

    # Load tokenizer associated with the adapter
    ft_tokenizer = AutoTokenizer.from_pretrained(final_adapter_path)

    # Example prompt from your data structure (or a new one)
    # Use one of the prompts from your sample.txt for a realistic test
    # This should be JUST the JSON string that was in the "prompt" field of your sample.txt
    example_idx_to_test = 0 # Test with the first example's prompt
    if os.path.exists(data_file_path):
        with open(data_file_path, 'r', encoding='utf-8') as f:
            test_lines = f.readlines()
        if test_lines and len(test_lines) > example_idx_to_test:
            try:
                test_prompt_data = json.loads(test_lines[example_idx_to_test])['prompt']
                actual_response_data = json.loads(test_lines[example_idx_to_test])['response']
            except (json.JSONDecodeError, KeyError):
                print("Could not parse test prompt from file, using a placeholder.")
                test_prompt_data = '{"level_idx": 1, "fail_cnt": 0.036, "is_hard": false, " star1": "700", " star2": "6630", " star3": "9900", " move": "30", "color": 3, "R": 1, "G": 1, "B": 1, "O": 0, "PI": 0, "S": 0, "Y": 0, "PU": 0, "bubble": 70, "row": 7, "t_기본버블": 70, "t_스파크": 0, "t_강철": 0, "t_얼음": 0, "t_블랙홀": 0, "t_변신": 0, "t_유령": 0, "t_구름": 0, "t_나무": 0, "t_스위치블랙홀": 0, "t_마이너스플러스": 0, "t_듀오": 0, "t_물감": 0, "t_fixed": 0}'
                actual_response_data = "Unknown"

    else: # Fallback if sample.txt is not found during inference part
        test_prompt_data = '{"level_idx": 1, "fail_cnt": 0.036, "is_hard": false, " star1": "700", " star2": "6630", " star3": "9900", " move": "30", "color": 3, "R": 1, "G": 1, "B": 1, "O": 0, "PI": 0, "S": 0, "Y": 0, "PU": 0, "bubble": 70, "row": 7, "t_기본버블": 70, "t_스파크": 0, "t_강철": 0, "t_얼음": 0, "t_블랙홀": 0, "t_변신": 0, "t_유령": 0, "t_구름": 0, "t_나무": 0, "t_스위치블랙홀": 0, "t_마이너스플러스": 0, "t_듀오": 0, "t_물감": 0, "t_fixed": 0}'
        actual_response_data = "Unknown"

    # Format for inference - it must match the training format UP TO the response part.
    # Training format: "<s>[INST] {instruction} [/INST] {output} </s>"
    # Inference format: "<s>[INST] {instruction} [/INST]"
    inference_prompt = f"<s>[INST] {test_prompt_data} [/INST]"

    print(f"\nTest Prompt:\n{test_prompt_data}")
    if actual_response_data != "Unknown":
        print(f"\nActual Response (from data):\n{actual_response_data[:300]}...") # Show beginning of actual response

    input_ids = ft_tokenizer(inference_prompt, return_tensors="pt", truncation=True, max_length=max_seq_length).input_ids.to(device)

    # Generate response
    # Adjust max_new_tokens based on the expected length of your JSON responses
    # Your responses are quite long, so this might need to be 512 or more.
    print("\nGenerating response...")
    with torch.no_grad():
        outputs = ft_model.generate(
            input_ids=input_ids,
            max_new_tokens=768, # Increased for potentially long JSON output
            do_sample=True,     # Sample for more diverse outputs
            temperature=0.6,    # Lower for more deterministic, higher for more creative
            top_p=0.9,          # Nucleus sampling
            eos_token_id=ft_tokenizer.eos_token_id,
            pad_token_id=ft_tokenizer.pad_token_id # Ensure pad_token_id is set
        )

    generated_text = ft_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # The output will contain the prompt. We need to extract just the model's response part.
    # Response starts after "[/INST]"
    response_part = generated_text.split("[/INST]")[-1].strip()

    print(f"\nGenerated Response (JSON String):\n{response_part}")

    # Optional: Try to parse the generated JSON to see if it's valid
    try:
        parsed_generated_response = json.loads(response_part)
        print("\nSuccessfully parsed generated JSON response:")
        # print(json.dumps(parsed_generated_response, indent=2)) # Pretty print
    except json.JSONDecodeError as e:
        print(f"\nCould not parse generated response as JSON: {e}")

    # --- If you want to use a pipeline for simpler inference ---
    # print("\n--- Testing with pipeline ---")
    # pipe = pipeline(task="text-generation", model=ft_model, tokenizer=ft_tokenizer, device_map="auto")
    # result = pipe(inference_prompt, max_new_tokens=768, do_sample=True, temperature=0.6, top_p=0.9)
    # generated_text_pipe = result[0]['generated_text']
    # response_part_pipe = generated_text_pipe.split("[/INST]")[-1].strip()
    # print(f"\nGenerated Response (Pipeline):\n{response_part_pipe}")

else:
    print("Skipping inference as fine-tuned model adapter or model_id is not available.")

#-------------------------------------------------------------------------------
# --- 9. (Optional) Merge and Save Full Model ---
#-------------------------------------------------------------------------------
# If you want to merge the LoRA adapter with the base model to create a standalone
# fine-tuned model, you can do the following. This will require more disk space
# and memory to load.

# print("\n--- (Optional) Merging LoRA adapter with base model ---")
# if os.path.exists(os.path.join(output_dir, "adapter_final")) and 'model_id' in locals():
#     # Load base model without quantization for merging, or ensure it's compatible
#     # Merging works best if the base model is loaded in full or compatible precision (e.g., float16)
#     base_model_for_merging = AutoModelForCausalLM.from_pretrained(
#         model_id,
#         torch_dtype=torch.float16, # or torch.bfloat16
#         device_map="auto", # Use CPU or a single GPU to avoid device map issues during merge
#         # token=hf_token # if using a token
#     )
#
#     # Load PEFT model (adapter)
#     merged_model = PeftModel.from_pretrained(base_model_for_merging, os.path.join(output_dir, "adapter_final"))
#
#     # Merge the adapter into the base model
#     print("Merging model...")
#     merged_model = merged_model.merge_and_unload()
#     print("Merge complete.")
#
#     # Save the merged model
#     merged_model_path = os.path.join(output_dir, "merged_model_final")
#     merged_model.save_pretrained(merged_model_path)
#     tokenizer.save_pretrained(merged_model_path) # Also save tokenizer with it
#     print(f"Merged model saved to {merged_model_path}")
#
#     # You can then load this merged_model directly:
#     # from transformers import AutoModelForCausalLM, AutoTokenizer
#     # loaded_merged_model = AutoModelForCausalLM.from_pretrained(merged_model_path)
#     # loaded_tokenizer = AutoTokenizer.from_pretrained(merged_model_path)
# else:
#     print("Skipping model merging as adapter or base model_id is not available.")

GPU: Tesla T4 is available.
Found sample.txt. Proceeding with data loading.

Loaded and formatted 10 samples.
Sample formatted entry:
<s>[INST] {"level_idx": 1, "fail_cnt": 0.036, "is_hard": false, " star1": "700", " star2": "6630", " star3": "9900", " move": "30", "color": 3, "R": 1, "G": 1, "B": 1, "O": 0, "PI": 0, "S": 0, "Y": 0, "PU": 0, "bubble": 70, "row": 7, "t_기본버블": 70, "t_스파크": 0, "t_강철": 0, "t_얼음": 0, "t_블랙홀": 0, "t_변신": 0, "t_유령": 0, "t_구름": 0, "t_나무": 0, "t_스위치블랙홀": 0, "t_마이너스플러스": 0, "t_듀오": 0, "t_물감": 0, "t_fixed": 0}
 [/INST] {"author":"whlim","index":1,"bubbles":[{"x":0,"y":0,"c":1,"t":0},{"x":1,"y":0,"c":1,"t...

Loading model: google/gemma-2-2b-it


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model and tokenizer loaded.

Setting up PEFT (LoRA)...
trainable params: 20,766,720 || all params: 2,635,108,608 || trainable%: 0.7881
PEFT model setup complete.

Configuring training arguments...
Initializing SFTTrainer...


TypeError: SFTTrainer.__init__() got an unexpected keyword argument 'max_seq_length'