In [1]:
import os

# os.environ["WANDB_API_KEY"] = '+++++++++++'  # 将引号内的+替换成自己在wandb上的一串值
# os.environ["WANDB_MODE"] = "offline"  # 离线  （此行代码不用修改）

import json

import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
# from swanlab.integration.huggingface import SwanLabCallback
from transformers import DataCollatorForSeq2Seq, Trainer
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
glm4_model_path = '/home/model/para_glm4'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# 加载训练集和测试集
tokenizer = AutoTokenizer.from_pretrained(glm4_model_path, use_fast=False,
                                          trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(glm4_model_path,
                                             #quantization_config=bnb_config,
                                             device_map="auto", torch_dtype=torch.bfloat16,
#                                              attn_implementation="flash_attention_2",
                                             trust_remote_code=True)
model.config.use_cache = False

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 10/10 [01:27<00:00,  8.79s/it]


In [4]:
model.enable_input_require_grads()  #  开启梯度检查点
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
#结束标识符
print(tokenizer.eos_token)
tokenizer.encode('<|endoftext|>')

<|endoftext|>


[151331, 151333, 151329]

In [5]:
train_dataset_path = "wiki/wikisql_train.csv"  
df = pd.read_csv(train_dataset_path)
print(df)

                                               question  \
0     What is the number for the international with ...   
1     What position(s) does the player drafted #34 p...   
2     What is the average Gold award that has a Silv...   
3     What country came in third when there were 13 ...   
4     What is the number of bronze when the total is...   
...                                                 ...   
7330  What is the rank for the player with 5 wins an...   
7331  How many wins were there when draws were more ...   
7332  What is the point total for the season with 2 ...   
7333                 Which venue hosted a race in 1967?   
7334  What is the average pick with 85 overall in a ...   

                                           create_table  \
0     CREATE TABLE 2-17673820-2 (\n  "Year" TEXT,\n ...   
1     CREATE TABLE 1-14650373-1 (\n  "Pick__" TEXT,\...   
2     CREATE TABLE 2-12573588-9 (\n  "Rank" TEXT,\n ...   
3     CREATE TABLE 2-15526447-1 (\n  "Season" TEXT,\...

In [6]:
system_prompt = """
You are an intelligent SQL generation assistant.  
Your task is to generate a valid SQL query using the schema provided in <database>...</database>,  
based on the natural language request given in <question>...</question>.  

Return your answer strictly in the following format:  
<answer>YOUR_GENERATED_SQL</answer>  
"""
print(system_prompt)


You are an intelligent SQL generation assistant.  
Your task is to generate a valid SQL query using the schema provided in <database>...</database>,  
based on the natural language request given in <question>...</question>.  

Return your answer strictly in the following format:  
<answer>YOUR_GENERATED_SQL</answer>  



In [7]:
def process_func(example):
    """
    将数据集进行预处理  {example['input']}
    """
    MAX_LENGTH = 700 
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|system|>\n{system_prompt}<|endoftext|>\n<|user|><question>{example['question']}</question>\n<database>{example['create_table']}</database>\n<|endoftext|>\n<|assistant|>\n",
        add_special_tokens=False,
    )
    response = tokenizer(f"<answer>\n{example['sql']}</answer>", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
        instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [8]:
train_ds = Dataset.from_pandas(df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)

Map: 100%|█████████████████████████████████████████████████████████████████| 7335/7335 [00:06<00:00, 1077.86 examples/s]


In [9]:
print(train_dataset)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 7335
})


In [10]:
first_item = train_dataset[0]

# input_ids 解码
decoded_input = tokenizer.decode(first_item["input_ids"], skip_special_tokens=True)

# labels 解码（有的任务 labels 会把 pad 部分设成 -100，要先替换回 pad_token_id）
labels_ids = [
    token_id if token_id != -100 else tokenizer.pad_token_id
    for token_id in first_item["labels"]
]
decoded_labels = tokenizer.decode(labels_ids, skip_special_tokens=True)

print("=== Input ===")
print(decoded_input)

print("\n=== Label ===")
print(decoded_labels)

=== Input ===


You are an intelligent SQL generation assistant.  
Your task is to generate a valid SQL query using the schema provided in <database>...</database>,  
based on the natural language request given in <question>...</question>.  

Return your answer strictly in the following format:  
<answer>YOUR_GENERATED_SQL</answer>  

<question>What is the number for the international with 669 domestic earlier than 2005?</question>
<database>CREATE TABLE 2-17673820-2 (
  "Year" TEXT,
  "Domestic" TEXT,
  "International" TEXT,
  "Total" TEXT,
  "Change" TEXT
);</database>


<answer>
SELECT AVG("International") FROM 2-17673820-2 WHERE "Domestic" = 669 AND "Year" < 2005</answer>

=== Label ===
<answer>
SELECT AVG("International") FROM 2-17673820-2 WHERE "Domestic" = 669 AND "Year" < 2005</answer>


In [11]:
lora_r = 64
lora_alpha = 32
lora_dropout = 0.1
output_dir = "./A-GLM4"
num_train_epochs = 4
bf16 = True
overwrite_output_dir = True
per_device_train_batch_size = 2
# per_device_eval_batch_size = 2
gradient_accumulation_steps = 16
gradient_checkpointing = True
evaluation_strategy = "steps"
learning_rate = 5e-5
weight_decay = 0.01
lr_scheduler_type = "cosine"
warmup_ratio = 0.01
max_grad_norm = 0.3
group_by_length = True
auto_find_batch_size = False
save_steps = 40
logging_steps = 50
load_best_model_at_end= False
packing = False
save_total_limit=2
neftune_noise_alpha=5
# report_to="wandb"
max_seq_length = 700

In [12]:
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    target_modules=[
       "query_key_value", "dense", "dense_h_to_4h", "activation_func", "dense_4h_to_h"
    ],
    task_type=TaskType.CAUSAL_LM,
)

In [13]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    num_train_epochs=num_train_epochs,
    load_best_model_at_end=load_best_model_at_end,
    per_device_train_batch_size=per_device_train_batch_size,
#     evaluation_strategy=evaluation_strategy,
    max_grad_norm = max_grad_norm,
    auto_find_batch_size = auto_find_batch_size,
    save_total_limit = save_total_limit,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    bf16=bf16,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="none",
    neftune_noise_alpha= neftune_noise_alpha
)

In [14]:
model = get_peft_model(model, peft_config)
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
#     eval_dataset= val_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
#     callbacks=[swanlab_callback],
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [15]:
trainer.train()

Step,Training Loss
50,0.2525
100,0.044
150,0.0338
200,0.0331
250,0.0322
300,0.0257
350,0.0246
400,0.0223
450,0.0218
500,0.0169




TrainOutput(global_step=916, training_loss=0.034736334887133936, metrics={'train_runtime': 3295.1572, 'train_samples_per_second': 8.904, 'train_steps_per_second': 0.278, 'total_flos': 3.104837138270945e+17, 'train_loss': 0.034736334887133936, 'epoch': 3.995637949836423})

In [16]:
output_dir = os.path.join("./", "A-GLM4/final_adatper_GLM4")
trainer.model.save_pretrained(output_dir)



In [17]:
# def make_user_prompt(question, database_schema):
#     return f"<question>{question}</question>\n<database>{database_schema}</database>"

# def make_reponse(target_schema):
#     return f"<answer>\n{target_schema}</answer>"

In [15]:
# def make_conversation(question, database_schema,target_schema):
#     conversation = []
#     conversation0 = {}
#     conversation1 = {}
#     conversation2 = {}
#     user_prompt = make_user_prompt(question, database_schema)
#     assistance_prompt = make_reponse(target_schema)
#     conversation0["role"] = "system"
#     conversation0["content"] = system_prompt
#     conversation1["role"] = "user"
#     conversation1["content"] = user_prompt + "\n"
#     conversation2["role"] = "assistant"
#     conversation2["content"] = assistance_prompt
#     conversation.append(conversation0)
#     conversation.append(conversation1)
#     conversation.append(conversation2)
#     all_info = system_prompt + user_prompt + assistance_prompt
# #     print(conversation2["content"])
#     return conversation, all_info

In [16]:
# trans_data = {}
# message = []
# max_token = 0
# for index, row in df.iterrows():
#     conversation, all_info = make_conversation(row["question"],row["create_table"],row["sql"])
#     message.append(conversation)
#     token_len = len(tokenizer(all_info, add_special_tokens=False)["input_ids"])
# #     print(all_info)
#     if  token_len >  max_token:
#         max_token = token_len
# trans_data["messages"] = message

In [17]:
# trans_data = pd.DataFrame(trans_data)
# trans_data = Dataset.from_pandas(trans_data,split="train")
# for sample in trans_data:
#     print(sample)
#     break  # 仅打印第一个样本

{'messages': [{'content': '\nYou are an intelligent SQL generation assistant.  \nYour task is to generate a valid SQL query using the schema provided in <database>...</database>,  \nbased on the natural language request given in <question>...</question>.  \n\nReturn your answer strictly in the following format:  \n<answer>YOUR_GENERATED_SQL</answer>  \n', 'role': 'system'}, {'content': '<question>What is the number for the international with 669 domestic earlier than 2005?</question>\n<database>CREATE TABLE 2-17673820-2 (\n  "Year" TEXT,\n  "Domestic" TEXT,\n  "International" TEXT,\n  "Total" TEXT,\n  "Change" TEXT\n);</database>\n', 'role': 'user'}, {'content': '<answer>\nSELECT AVG("International") FROM 2-17673820-2 WHERE "Domestic" = 669 AND "Year" < 2005</answer>', 'role': 'assistant'}]}


In [18]:
# print(max_token)

582


In [None]:
# trans_data = {}
# message = []
# max_token = 0
# for index, row in df.iterrows():
#     conversation, all_info = make_conversation(row["question"],row["database_schema"],row["target_schema"])
#     message.append(conversation)
#     token_len = len(tokenizer(all_info, add_special_tokens=False)["input_ids"])
# #     print(all_info)
#     if  token_len >  max_token:
#         max_token = token_len
# trans_data["messages"] = message

In [5]:
#获取最大toekn数
# train_path = "/home/code/GLM4_Lora_NER-master/Gen_SQL_dataset/T2Q_GLM4_SFT_train_SQL.jsonl"
# max_len = 0
# with open(train_path, "r") as file:
#         for line in file:
#             # 解析每一行的json数据
#             example = json.loads(line)
#             instruction = tokenizer(
#         f"<|system|>\n Given the following SQL tables, your job is to generate the Sqlite SQL query given the user's question.<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",
#         add_special_tokens=False, )
#             response = tokenizer(f"{example['output']}", add_special_tokens=False)
#             input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
#             attention_mask = (
#             instruction["attention_mask"] + response["attention_mask"] + [1])
#             labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
#             if len(input_ids) > max_len:
#                 max_len = len(input_ids)
# print("max_len:",max_len)   
# #1810

max_len: 1912


In [6]:
#  和下面的max_len 作一个合并
#  max_seq_length = 2100
# def process_func(example):
#     """
#     将数据集进行预处理
#     """
#     MAX_LENGTH = 2100 
#     input_ids, attention_mask, labels = [], [], []
#     instruction = tokenizer(
#         f"<|system|>\nGiven the following SQL tables, your job is to determine the tables that the question is referring to.<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",
#         add_special_tokens=False,
#     )
#     response = tokenizer(f"{example['output']}", add_special_tokens=False)
#     input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
#     attention_mask = (
#         instruction["attention_mask"] + response["attention_mask"] + [1]
#     )
#     labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
#     if len(input_ids) > MAX_LENGTH:  # 做一个截断
#         input_ids = input_ids[:MAX_LENGTH]
#         attention_mask = attention_mask[:MAX_LENGTH]
#         labels = labels[:MAX_LENGTH]
#     return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [7]:
#  set dataset
# train_path = "/home/code/GLM4_Lora_NER-master/Gen_SQL_dataset/T2Q_GLM4_SFT_train_SQL.jsonl"
# val_path = "/home/code/GLM4_Lora_NER-master/Gen_SQL_dataset/T2Q_GLM4_SFT_val_SQL.jsonl"

# train_df = pd.read_json(train_path, lines=True)
# train_ds = Dataset.from_pandas(train_df)
# train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)


# val_df = pd.read_json(val_path, lines=True)
# val_ds = Dataset.from_pandas(val_df)
# val_dataset = val_ds.map(process_func, remove_columns=val_ds.column_names)

                                                                                                                                     

In [8]:
#  train_config
#  示例里 r 为 8
lora_r = 8
lora_alpha = 32
lora_dropout = 0.1
output_dir = "./SFT_SQL"
num_train_epochs = 2
bf16 = True
overwrite_output_dir = True
per_device_train_batch_size = 2
per_device_eval_batch_size = 2
gradient_accumulation_steps = 16
gradient_checkpointing = True
evaluation_strategy = "steps"
learning_rate = 5e-5
weight_decay = 0.01
lr_scheduler_type = "cosine"
warmup_ratio = 0.01
max_grad_norm = 0.3
group_by_length = True
auto_find_batch_size = False
save_steps = 50
logging_steps = 10
load_best_model_at_end= False
packing = False
save_total_limit=4
neftune_noise_alpha=5
# report_to="wandb"
max_seq_length = 2100

In [9]:
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    target_modules=[
       "query_key_value", "dense", "dense_h_to_4h", "activation_func", "dense_4h_to_h"
    ],
    task_type=TaskType.CAUSAL_LM,
)

In [10]:
swanlab_callback = SwanLabCallback(
    project="GLM4-SFT_T2QSQL",
    experiment_name="GLM4-9B-Chat",
    description="使用智谱GLM4-9B-Chat模型在spider数据集上微调 生成SQL。",
    config={
        "model": "/home/LLM_para/para_glm4",
        "dataset": "T2Q_GLM4_SFT_train_SQL.jsonl",
    },
)

In [11]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    num_train_epochs=num_train_epochs,
    load_best_model_at_end=load_best_model_at_end,
    per_device_train_batch_size=per_device_train_batch_size,
    evaluation_strategy=evaluation_strategy,
    max_grad_norm = max_grad_norm,
    auto_find_batch_size = auto_find_batch_size,
    save_total_limit = save_total_limit,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    bf16=bf16,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="none",
    neftune_noise_alpha= neftune_noise_alpha
)



In [12]:
# response_template = "### Response:"
# collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
# collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)
# trainer = SFTTrainer(
#     model=model,
#     train_dataset=dataset['train'],
#     eval_dataset=dataset['validation'],
#     peft_config=peft_config,
#     data_collator=collator,
#     args=training_arguments,
#     max_seq_length=max_seq_length,
#     packing=packing
# )
model = get_peft_model(model, peft_config)
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset= val_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [13]:
trainer.train()

[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.3.10                                  
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/home/code/GLM4_Lora_NER-master/swanlog/run-20240615_073134-a3b1799d[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mwinhong[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mGLM4-9B-Chat_Jun15_07-31-34[0m to the cloud
[1m[34mswanlab[0m[0m: 🌟 Run `[1mswanlab watch -l /home/code/GLM4_Lora_NER-master/swanlog[0m` to view SwanLab Experiment Dashboard locally
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@winhong/GLM4-SFT_T2QSQL[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@winhong/GLM4-SFT_T2QSQL/runs/mmk3sxabfclfrqrwmxh6c[0m[0m


Step,Training Loss,Validation Loss
10,2.3018,0.779119
20,0.5217,0.487665
30,0.5235,0.454727
40,0.5233,0.463469
50,0.5896,0.460688
60,0.4699,0.424108
70,0.4213,0.423039
80,0.3843,0.416103
90,0.4011,0.421988
100,0.4345,0.418301


[33mswanlab[0m: Step 10 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 20 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 30 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 40 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 50 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 60 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 70 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 80 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 90 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 100 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 110 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 120 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 130 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 140 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 150 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 160 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 170 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 180 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 190 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 200 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 210 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 220 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 230 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 240 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 250 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 260 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 270 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 280 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 290 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 300 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 310 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 320 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 330 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 340 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 350 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 360 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 370 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 380 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 390 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 400 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 410 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 420 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 430 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 440 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 450 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 460 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 470 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 480 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 490 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 500 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 510 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 520 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 530 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 540 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 550 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 560 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 570 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 580 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 590 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 600 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 610 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 620 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 630 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 640 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 650 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 660 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 670 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 680 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 690 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 700 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 710 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 720 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 730 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 740 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 750 on key train/epoch already exists, ignored.




[33mswanlab[0m: Step 760 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 770 on key train/epoch already exists, ignored.
[33mswanlab[0m: Step 780 on key train/epoch already exists, ignored.


TrainOutput(global_step=786, training_loss=0.308908224712498, metrics={'train_runtime': 12813.9588, 'train_samples_per_second': 1.969, 'train_steps_per_second': 0.061, 'total_flos': 8.549260910788608e+17, 'train_loss': 0.308908224712498, 'epoch': 2.9900142653352355})

                                         instruction  \
0  Given the following SQL tables, your job is to...   
1  Given the following SQL tables, your job is to...   
2  Given the following SQL tables, your job is to...   
3  Given the following SQL tables, your job is to...   
4  Given the following SQL tables, your job is to...   
5  Given the following SQL tables, your job is to...   
6  Given the following SQL tables, your job is to...   
7  Given the following SQL tables, your job is to...   
8  Given the following SQL tables, your job is to...   
9  Given the following SQL tables, your job is to...   

                                               input  \
0  all tabels and samples:\nCREATE TABLE `stadium...   
1  all tabels and samples:\nCREATE TABLE `stadium...   
2  all tabels and samples:\nCREATE TABLE `stadium...   
3  all tabels and samples:\nCREATE TABLE `stadium...   
4  all tabels and samples:\nCREATE TABLE `stadium...   
5  all tabels and samples:\nCREATE TABLE `stadi

In [14]:
output_dir = os.path.join("./", "final_checkpoint_newS_SFT_SQL")
trainer.model.save_pretrained(output_dir)



In [15]:
def predict(messages, model, tokenizer):
#     device = "cuda"
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        max_new_tokens=512
    )
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    
    print(response)
     
    return response

In [16]:
# 用测试集的前10条，测试模型
test_df = val_df[:10]
print(test_df)
print(1)

In [17]:
# import swanlab
test_text_list = []
for index, row in test_df.iterrows():
    instruction = row['instruction']
    input_value = row['input']
    
    messages = [
        {"role": "system", "content": f"{instruction}"},
        {"role": "user", "content": f"{input_value}"}
    ]

    response = predict(messages, model, tokenizer)
#     print(response)
#     messages.append({"role": "assistant", "content": f"{response}"})
#     result_text = f"{messages[0]}\n\n{messages[1]}\n\n{messages[2]}"
#     test_text_list.append(swanlab.Text(result_text, caption=response))
    
# swanlab.log({"Prediction": test_text_list})
# swanlab.finish()