To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the [installation instructions](https://docs.unsloth.ai/get-started/installing-+-updating) in our docs.

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News

**Read our [Gemma 3 blog](https://unsloth.ai/blog/gemma3) for what's new in Unsloth and our [Reasoning blog](https://unsloth.ai/blog/r1-reasoning) on how to train reasoning models.**

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [3]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth
    !pip install datasets
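
The install cell above picks its branch by looking for `COLAB_` in the names of the environment variables. A minimal sketch of that check as a reusable function follows. The name `running_in_colab` is hypothetical, and this variant is slightly stricter than the cell: it requires a variable name to start with `COLAB_` rather than matching the substring anywhere in the joined names.

```python
import os


def running_in_colab(environ=None):
    """Return True if any environment variable name starts with 'COLAB_'.

    `environ` defaults to os.environ; passing a dict makes the check easy to try out.
    """
    env = os.environ if environ is None else environ
    return any(key.startswith("COLAB_") for key in env)


print("Colab detected:", running_in_colab())
```

Passing a plain dict, for example `running_in_colab({"COLAB_GPU": "1"})`, lets you exercise both branches without being in Colab.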

### Unsloth

In [36]:
import json
from datasets import Dataset

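class LocalJsonDataset:
    """Build a Hugging Face Dataset with a single 'text' column from a local JSON file.

    The file is expected to hold a list of objects, each with 'input' and
    'output' keys. Each pair is rendered through `data_template`, and the
    tokenizer's EOS token is appended so the model learns where to stop.
    """

    def __init__(self, json_file, data_template, tokenizer, max_seq_length=2048):
        self.json_file = json_file
        self.data_template = data_template
        self.tokenizer = tokenizer
        # Stored for reference only; this class does not truncate the text itself.
        self.max_seq_length = max_seq_length
        self.dataset = self.load_dataset()

    def load_dataset(self):
        with open(self.json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Fill the template with each (input, output) pair and terminate with EOS.
        texts = []
        for item in data:
            text = self.data_template.format(item['input'], item['output']) + self.tokenizer.eos_token
            texts.append(text)

        dataset_dict = {
            'text': texts  # add a 'text' field so the dataset works with SFTTrainer
        }

        dataset = Dataset.from_dict(dataset_dict)
        return dataset

    def get_dataset(self):
        return self.dataset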


# The dataset itself is built in the next cell, after `tokenizer` and
# `max_seq_length` have been defined; instantiating it here would fail
# with a NameError when the notebook is run top to bottom.
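
The loader reads `item['input']` and `item['output']` from every record, so the JSON file has to be a list of objects with those two keys. The sketch below shows that shape with two made-up records and renders them through the notebook's prompt template. The records, the file name `sample_afsim_dataset.json`, the helper `build_texts`, and the `"<eos>"` placeholder are all assumptions for illustration; in the notebook the EOS string comes from `tokenizer.eos_token`.

```python
import json
import os
import tempfile

# The notebook's prompt template: the two {} slots take the task description and the script.
custom_prompt = """你是AFSIM的专家，请根据以下任务描述生成对应的AFSIM仿真脚本。
### 任务描述:
{}
### AFSIM脚本:
{}"""

# Hypothetical records, only to illustrate the expected schema.
sample_records = [
    {"input": "Task description 1", "output": "platform_type EXAMPLE_A WSF_PLATFORM\nend_platform_type"},
    {"input": "Task description 2", "output": "platform_type EXAMPLE_B WSF_PLATFORM\nend_platform_type"},
]

EOS_TOKEN = "<eos>"  # placeholder; the real value comes from tokenizer.eos_token


def build_texts(records, template, eos_token):
    """Render each record through the template and append the EOS marker."""
    return [template.format(r["input"], r["output"]) + eos_token for r in records]


# Round-trip through a JSON file to mirror what LocalJsonDataset.load_dataset reads.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample_afsim_dataset.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(sample_records, f, ensure_ascii=False)
    with open(path, "r", encoding="utf-8") as f:
        loaded = json.load(f)

texts = build_texts(loaded, custom_prompt, EOS_TOKEN)
print(texts[0])
```

Each resulting string is one training example: prompt, task description, script, then the EOS marker.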

In [41]:
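# Load the model and tokenizer
from unsloth import FastLanguageModel

max_seq_length = 2048  # maximum number of tokens per training example
dtype = None           # None lets Unsloth pick the dtype (float16 on a Tesla T4)
load_in_4bit = True    # load the 4-bit quantized weights to reduce memory use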

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/gemma-2-2b-it-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

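# Load and preprocess the AFSIM dataset.
# The prompt reads: "You are an AFSIM expert. Generate the corresponding AFSIM
# simulation script from the task description below." The two headers mean
# "### Task description:" and "### AFSIM script:".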
custom_prompt = """你是AFSIM的专家，请根据以下任务描述生成对应的AFSIM仿真脚本。
### 任务描述:
{}
### AFSIM脚本:
{}"""

custom_dataset = LocalJsonDataset(
    json_file='formatted_afsim_dataset.json',
    data_template=custom_prompt,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length
)

dataset = custom_dataset.get_dataset()

# Configure training
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=5,
        max_steps=60,
        learning_rate=2e-4,
        fp16=not is_bfloat16_supported(),
        bf16=is_bfloat16_supported(),
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="linear",
        seed=3407,
        output_dir="outputs_afsim"
    ),
)

print("训练开始")  # "Training started"
trainer.train()
print("训练完成")  # "Training finished"

# Save the LoRA adapters and the tokenizer
model.save_pretrained("lora_afsim_model")
tokenizer.save_pretrained("lora_afsim_model")
print("模型保存完成")  # "Model saved"
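
The training arguments above combine `warmup_steps=5`, `lr_scheduler_type="linear"`, `max_steps=60`, and `learning_rate=2e-4`. The sketch below approximates the resulting learning-rate curve, assuming the usual linear-warmup-then-linear-decay shape. The function `linear_lr` is a hypothetical helper written for this illustration, not a call into `transformers`.

```python
def linear_lr(step, base_lr=2e-4, warmup_steps=5, max_steps=60):
    """Learning rate at a given scheduler step: linear warmup, then linear decay to zero."""
    if step < warmup_steps:
        # Warmup: ramp from 0 up to base_lr over the first warmup_steps steps.
        return base_lr * step / max(1, warmup_steps)
    # Decay: fall linearly from base_lr at the end of warmup to 0 at max_steps.
    remaining = max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))
    return base_lr * remaining


for step in (0, 1, 5, 30, 60):
    print(step, linear_lr(step))
```

Under this assumption the rate peaks at the end of warmup and reaches zero at the final step, so raising `max_steps` also stretches the decay.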

==((====))==  Unsloth 2025.3.19: Fast Gemma2 patching. Transformers: 4.50.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/800 [00:00<?, ? examples/s]

训练开始


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 800 | Num Epochs = 1 | Total steps = 60
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 20,766,720/2,000,000,000 (1.04% trained)


Step,Training Loss
1,2.3876
2,2.4333
3,2.2558
4,2.2413
5,2.0128
6,1.6514
7,1.386
8,1.1549
9,0.9624
10,0.7544


训练完成
模型保存完成
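
The numbers in the training banner can be checked by hand. The sketch below recomputes the trainable-parameter count for rank-16 LoRA on the seven target modules, plus the effective batch size and the number of examples seen in 60 steps. The layer dimensions are assumptions taken from the published Gemma 2 2B configuration (hidden size 2304, MLP size 9216, 26 layers, 8 query heads and 4 key/value heads of dimension 256); they do not appear in this notebook.

```python
# Assumed Gemma 2 2B dimensions (not stated in the notebook).
HIDDEN = 2304
INTERMEDIATE = 9216
NUM_LAYERS = 26
Q_OUT = 8 * 256    # 8 query heads x head_dim 256
KV_OUT = 4 * 256   # 4 key/value heads x head_dim 256
RANK = 16          # r=16 from get_peft_model

# (in_features, out_features) for each LoRA target module in one layer.
shapes = {
    "q_proj": (HIDDEN, Q_OUT),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (Q_OUT, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    """A LoRA pair adds A (rank x in) and B (out x rank): rank * (in + out) weights."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in shapes.values())
trainable = per_layer * NUM_LAYERS

# Batch arithmetic from the TrainingArguments in the cell above.
effective_batch = 2 * 4 * 1           # per-device batch x grad accumulation x GPUs
examples_seen = 60 * effective_batch  # max_steps x effective batch

print("trainable LoRA parameters:", trainable)
print("effective batch size:", effective_batch)
print("examples seen in 60 steps:", examples_seen, "of 800")
```

With these dimensions the count comes to 20,766,720, the figure in the banner. The run also sees 480 of the 800 examples, so 60 steps cover less than one full pass over the data even though the banner reports `Num Epochs = 1`.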


Code for testing the trained model

In [43]:
# test_afsim_model.py

from unsloth import FastLanguageModel
import torch

# Maximum sequence length
max_seq_length = 2048
load_in_4bit = True

# 1. Load the fine-tuned model and tokenizer
print("加载模型中...")  # "Loading model..."
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "lora_afsim_model",  # path where the fine-tuned adapters were saved
    max_seq_length = max_seq_length,
    load_in_4bit = load_in_4bit,
)

# The saved checkpoint already contains the LoRA adapters, so this call is a
# no-op: Unsloth logs "Already have LoRA adapters!" and skips it.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_alpha=16,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=3407,
    use_rslora=False,
    loftq_config=None,
)

# Switch Unsloth to its inference mode before generating
FastLanguageModel.for_inference(model)
model.eval()
print("模型加载完成。")  # "Model loaded."

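# 2. Define the test task description
# "Two fighters take off from different airfields and fly in formation for
# training, keeping 1000 m relative spacing at an altitude of 8000 m."
test_task = """两架战斗机从不同机场起飞，在空中进行编队飞行训练，保持相对间距1000米，飞行高度8000米。"""

prompt = f"""你是AFSIM的专家，请根据以下任务描述生成对应的AFSIM仿真脚本。
### 任务描述:
{test_task}
### AFSIM脚本:"""

# 3. Encode the input
device = "cuda" if torch.cuda.is_available() else "cpu"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# No model.to(device) here: from_pretrained already placed the 4-bit model on the GPU.

# 4. Generate
with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )

# 5. Decode the output (the decoded text still includes the prompt)
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("\n=== 生成结果 ===")  # "Generation result"
print(result)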
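
Because `generate` returns the prompt tokens followed by the new tokens, the decoded `result` starts with the prompt itself. A minimal sketch for keeping only the generated script is shown below: it splits the decoded string on the `### AFSIM脚本:` header used in the prompt. The helper name `extract_script` and the sample string are hypothetical.

```python
SCRIPT_MARKER = "### AFSIM脚本:"  # the "AFSIM script" header from the prompt template


def extract_script(decoded, marker=SCRIPT_MARKER):
    """Return the text after the first marker, or the whole text if the marker is absent."""
    _, found, tail = decoded.partition(marker)
    return tail.strip() if found else decoded.strip()


# Hypothetical decoded output: prompt echo followed by a generated script.
sample = (
    "你是AFSIM的专家，请根据以下任务描述生成对应的AFSIM仿真脚本。\n"
    "### 任务描述:\n"
    "Example task\n"
    "### AFSIM脚本:\n"
    "platform_type EXAMPLE WSF_PLATFORM\n"
    "end_platform_type\n"
)

print(extract_script(sample))
```

An alternative, assuming the usual layout of the `generate` output, is to slice off the first `inputs["input_ids"].shape[1]` tokens of `outputs[0]` before decoding, which avoids depending on the header text.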


加载模型中...
==((====))==  Unsloth 2025.3.19: Fast Gemma2 patching. Transformers: 4.50.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth: Already have LoRA adapters! We shall skip this step.


模型加载完成。

=== 生成结果 ===
你是AFSIM的专家，请根据以下任务描述生成对应的AFSIM仿真脚本。
### 任务描述:
两架战斗机从不同机场起飞，在空中进行编队飞行训练，保持相对间距1000米，飞行高度8000米。
### AFSIM脚本:
platform_type FIGHTER_TYPE WSF_PLATFORM
  icon f16
  spatial_domain air
  add mover WSF_AIR_MOVER
  add sensor main_sensor
  add processor main_proc
end_platform_type

platform_sensor MAIN_SENSOR WSF_IRST_SENSOR
  on
  frame_time 5 sec
  maximum_range 120 km
  reports_location
end_platform_sensor

platform_processor MAIN_PROCESSOR WSF_TASK_PROCESSOR
  master_track_processor
end_platform_processor

platform_type ADV_PLATFORM WSF_PLATFORM
  icon e3
  spatial_domain air
  add mover WSF_AIR_MOVER
  add sensor main_sensor
  add processor main_proc
end_platform_type

platform_sensor MAIN_SENSOR ADV_SENSOR
  on
  frame_time 5 sec
  maximum_range 120 km
  reports_location
end_platform_sensor

platform_processor MAIN_PROCESSOR ADV_PROCESSOR
  master_track_processor
end_platform_processor

platform_type FIGHTER_SITE FIGHTER_TYPE
  side red
  position 33.97N 114.62E
  a