In [None]:
!pip install protobuf==3.20.3 trl peft accelerate bitsandbytes unsloth

In [None]:
from unsloth import FastLanguageModel
import json
import torch
from datasets import Dataset
from transformers import TrainingArguments, AutoTokenizer
from trl import SFTTrainer

def interactive_test(model, tokenizer):
    """
    Run an interactive prompt loop against the fine-tuned model.

    Each prompt is wrapped as a single user message, rendered with the
    tokenizer's chat template, and passed to ``model.generate``. Only the
    newly generated tokens are decoded and printed.

    Args:
        model: Model exposing ``generate`` (and, optionally, ``device``).
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.

    Returns:
        True if the user typed 'save' (caller should export the model),
        False if the user typed 'cancel' or input ended (EOF).
    """
    print("\n--- Interactive Model Test ---")
    print("Enter a prompt to test the fine-tuned model.")
    print("Type 'save' to finish testing and save the GGUF model.")
    print("Type 'cancel' to exit without saving.")
    print("------------------------------------")

    # Follow the model's own device rather than assuming CUDA; fall back to
    # "cuda" when the object does not expose a device attribute.
    device = getattr(model, "device", "cuda")

    while True:
        try:
            user_input = input("\nPrompt: ")
        except EOFError:
            # No more input (e.g. non-interactive run): treat it as a cancel
            # instead of crashing with a traceback.
            print("\nCanceling save. Exiting script.")
            return False

        # Commands are matched case-insensitively, ignoring stray whitespace.
        command = user_input.strip().lower()
        if command == "save":
            print("\nProceeding to save the model...")
            return True  # Signal to continue and save
        if command == "cancel":
            print("\nCanceling save. Exiting script.")
            return False  # Signal to exit without saving
        if not command:
            # Nothing to send to the model; ask again.
            continue

        # Prepare the input for the model
        messages = [
            {"role": "user", "content": user_input},
        ]
        input_ids = tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)

        # Generate a response
        outputs = model.generate(
            input_ids=input_ids, max_new_tokens=256, use_cache=True
        )

        # The generated sequence starts with the prompt tokens. Slice them off
        # and decode only the continuation, rather than splitting the decoded
        # text on role markers (which skip_special_tokens may already have
        # removed, and which would also truncate answers that happen to
        # contain the word "assistant").
        prompt_length = input_ids.shape[-1]
        new_tokens = outputs[0][prompt_length:]
        clean_response = tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        print(f"Model: {clean_response}")

# Import Dataset
# Load the training records from data.json into `file`. Any failure is
# reported and then re-raised: every later step needs `file`, so continuing
# would only fail further down with a less helpful NameError.
try:
    # Explicit UTF-8 so non-ASCII records decode the same way regardless of
    # the platform's default locale encoding.
    with open("data.json", "r", encoding="utf-8") as f:
        file = json.load(f)
    if not isinstance(file, list) or not file:
        raise ValueError("data.json must contain a non-empty JSON array of records.")
except FileNotFoundError:
    print("Error: data.json not found. Please upload it to the Colab session.")
    raise
except Exception as e:
    print(f"An error occurred: {e}")
    raise
else:
    print("Successfully loaded data.json.")
    # Show the first record; indexing [1] would fail on a one-record file.
    print("Sample record:", file[0])


# Base checkpoint and sequence settings (reused further down by the trainer)
model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
max_seq_length = 2048
dtype = None  # Auto detection

# Gather the loader options in one place, then load model and tokenizer
load_options = {
    "model_name": model_name,
    "max_seq_length": max_seq_length,
    "dtype": dtype,
    "load_in_4bit": True,
}
model, tokenizer = FastLanguageModel.from_pretrained(**load_options)

# Build the training dataset straight from the records loaded from data.json
dataset = Dataset.from_list(file)

# Formatting prompts so they can be sent like this {"context": "今天天气不错，不冷不热。晚饭后，我们决定去公园散步。那里的风景很优美，空气也很新鲜。", "target_sentence": "真是个锻练身体的好地方。"}
def format_prompts(batch):
    """
    Render a batch of examples as Phi-3 chat-formatted training strings.

    The user turn is a JSON object with the example's ``context`` and
    ``target_sentence``; the assistant turn is the example's ``output``
    serialized as JSON. Non-ASCII text is kept verbatim (``ensure_ascii=False``).

    Each turn is closed with ``<|end|>`` so the training text has the same
    shape as the prompt produced by the tokenizer's Phi-3 chat template at
    inference time; ``<|endoftext|>`` terminates the whole sample.

    Args:
        batch: Mapping with equally long sequences under the keys
            ``"context"``, ``"target_sentence"`` and ``"output"``.

    Returns:
        A list with one formatted string per example (empty for an empty batch).
    """
    texts = []
    rows = zip(batch["context"], batch["target_sentence"], batch["output"])
    for context, target, output in rows:
        user_turn = json.dumps(
            {"context": context, "target_sentence": target},
            ensure_ascii=False,
        )
        assistant_turn = json.dumps(output, ensure_ascii=False)

        # User and assistant turns each end with <|end|>; without it the model
        # is trained on a layout it never sees when prompted via the chat template.
        texts.append(
            f"<|user|>\n{user_turn}<|end|>\n"
            f"<|assistant|>\n{assistant_turn}<|end|>\n"
            "<|endoftext|>"
        )

    return texts

# Attach LoRA adapters to the loaded model.
# The projection layers that receive adapters are listed first for readability.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=lora_target_modules,
    lora_alpha=128,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Training arguments
# Probe bf16 support once; exactly one of fp16/bf16 is enabled accordingly.
bf16_available = torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=10,
    num_train_epochs=3,
    learning_rate=2e-4,
    fp16=not bf16_available,
    bf16=bf16_available,
    logging_steps=25,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=3407,
    output_dir="outputs",
    save_strategy="epoch",
    save_total_limit=2,
    dataloader_pin_memory=False,  # Important for Colab
    report_to="none",
)

# Supervised fine-tuning trainer; examples are rendered to text by format_prompts.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=format_prompts,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=training_args,
)

# Train the model
# Runs the fine-tuning loop configured on `trainer`; the value returned by
# train() is kept in `trainer_stats` for later inspection.
print("\n--- Starting Model Training ---")
trainer_stats = trainer.train()
print("--- Model Training Finished ---")


# # Merge and save the 16-bit model
# merged_model_path = "merged_16bit_model"
# model.save_pretrained_merged(merged_model_path, tokenizer, save_method="merged_16bit")

# # Reload the merged model for testing
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name=merged_model_path,
#     dtype=dtype,
#     load_in_4bit=False,
# )

# Run interactive test
# Put the Unsloth model into inference mode first: the model has just been
# trained, and Unsloth documents this call as the step to take before
# calling generate().
FastLanguageModel.for_inference(model)
should_save = interactive_test(model, tokenizer)

# Save the final GGUF model if requested
# interactive_test returns True only when the user typed 'save'.
if should_save:
    model.save_pretrained_gguf(
        "gguf_model", tokenizer, quantization_method="q4_k_m"
    )
    print("\nGGUF model saved successfully in the 'gguf_model' file.")
    print("You can download it from the file browser on the left.")
else:
    print("\nExiting. The final GGUF model was not saved.")