In [1]:
import json

import pandas as pd
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from swanlab.integration.huggingface import SwanLabCallback
from transformers import DataCollatorForSeq2Seq, Trainer
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
glm4_model_path = '/home/LLM_para/para_glm4'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载model和tokenizer
tokenizer = AutoTokenizer.from_pretrained(glm4_model_path, use_fast=False,
                                          trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(glm4_model_path,
                                             #quantization_config=bnb_config,
                                             device_map="auto", torch_dtype=torch.bfloat16,
#                                              attn_implementation="flash_attention_2",
                                             trust_remote_code=True)
model.config.use_cache = False


model.enable_input_require_grads()  #  开启梯度检查点
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
#结束标识符
print(tokenizer.eos_token)
tokenizer.encode('<|endoftext|>')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████| 10/10 [01:37<00:00,  9.79s/it]

<|endoftext|>





[151331, 151333, 151329]

In [3]:
def filter_descibe(sql_statement):
    modified_sql = re.sub(r'(\w+)\s+(INTEGER|int|NUMERIC|INT|REAL|TEXT|DATETIME|datetime|bigint\(?\d*\)?|CHAR\(?\d*\)?|char\(?\d*\)?|varchar2\(?\d*\)|VARCHAR\(?\d*\)|varchar\(?\d*\)|DECIMAL\(?\d*,\s*\d*\)?)', r'\1', sql_statement)
    return  modified_sql

In [4]:
from tqdm import tqdm
import copy
import re

# 获取REF表数据
tab_str = ""
count = 0
df = pd.read_csv("./mydataset_new/table_schema_Reference_cropped.csv", encoding="utf-8")
for index, row in tqdm(df.iterrows(), total=len(df)):
    table_group = row['Reference_group']
    table_group = filter_descibe(table_group)
    table_group = table_group.replace("\r","")
    table_group = table_group.replace("\n","")
    table_group = table_group.replace("  "," ")
    tab_str += table_group + "\n"
    
df = pd.read_csv("./mydataset_new/table_schema_noReference_cropped.csv", encoding="utf-8")
for index, row in tqdm(df.iterrows(), total=len(df)):
    table_group = row['noReference_group']
    table_group = filter_descibe(table_group)
    table_group = table_group.replace("\r", "")
    table_group = table_group.replace("\n", "")
    table_group = table_group.replace("  ", " ")
    tab_str += table_group + "\n"

schema_str = tab_str.strip()
check_token = tokenizer(f"{schema_str}", add_special_tokens=False)
print(len(check_token["input_ids"]))
max_token = 5200

100%|███████████████████████████████████████████████████████████████████████████████| 142/142 [00:00<00:00, 6399.18it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 30/30 [00:00<00:00, 13092.20it/s]

4900





In [5]:
def fill_dataset_template(schema_pad, question_pad, answer_pad):
    message = {
        "instruction": f" I want you to act as a relation extraction robot for a sample SQL table. You only need to return the tables related to the user's input question. Below are instructions describing the relationship between tables. Please write a response that appropriately completes the request. \n##instruction:{schema_pad}",
        "input": f"{question_pad}",
        "output": f"```Reference Table\n-- Tables: {answer_pad};\n```",
    }
    return message

In [6]:
df = pd.read_csv("./mydataset_new/full_finetuning_dataset_cropped.csv",
                 encoding="utf-8")
outer_index = 0
message_group = []
for i in range(len(df['db_id'])):
    question = df['question'][i]
    query = df['query'][i]
    correct_table = df['correct_tables'][i]
    dict_target_table_slice = {}
    message = fill_dataset_template(schema_str, question, correct_table)
    message_group.append(message)

with open("T2Q_GLM4_SFT_train_Table_myidea.jsonl", "w", encoding="utf-8") as file:
    for message in message_group:
        file.write(json.dumps(message, ensure_ascii=False) + "\n")

print(len(message_group))

1736


In [7]:
train_path = "T2Q_GLM4_SFT_train_Table_myidea.jsonl"
max_len = 0
count = 0
with open(train_path, "r") as file:
        for line in file:
            count+=1
            # 解析每一行的json数据
            example = json.loads(line)
            if count == 1:
                print(f"<|system|>\n {example['instruction']}.<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n")
            print(f"index:---------{count}")
            instruction = tokenizer(
        f"<|system|>\n {example['instruction']}.<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",
        add_special_tokens=False, )
            response = tokenizer(f"{example['output']}<|endoftext|>\n", add_special_tokens=False)
            input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
            attention_mask = (
            instruction["attention_mask"] + response["attention_mask"] + [1])
            labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
            if len(input_ids) > max_len:
                max_len = len(input_ids)
print("max_len:",max_len)

<|system|>
  I want you to act as a relation extraction robot for a sample SQL table. You only need to return the tables related to the user's input question. Below are instructions describing the relationship between tables. Please write a response that appropriately completes the request. 
##instruction:CREATE TABLE `Manufacturers` ( Code, Name, Headquarter, Founder, Revenue);
CREATE TABLE `Products` ( Code, Name, PriceManufacturer REFERENCES Manufacturers(Code));
CREATE TABLE `Student` ( StuID, LName, Fname, Age, Sex, Major, Advisor, city_code);
CREATE TABLE `Plays_Games` ( StuID REFERENCES Student(StuID), GameID REFERENCES Video_Games(GameID), Hours_Played);
CREATE TABLE `SportsInfo` ( StuID REFERENCES Student(StuID), SportName, HoursPerWeek, GamesPlayed, OnScholarship);
CREATE TABLE `actor` ( Actor_ID PRIMARY KEY, Name, Musical_ID REFERENCES actor(Actor_ID), Character, Duration, age);
CREATE TABLE `entrepreneur` ( Entrepreneur_ID PRIMARY KEY, People_ID REFERENCES people(People_ID)

index:---------16
index:---------17
index:---------18
index:---------19
index:---------20
index:---------21
index:---------22
index:---------23
index:---------24
index:---------25
index:---------26
index:---------27
index:---------28
index:---------29
index:---------30
index:---------31
index:---------32
index:---------33
index:---------34
index:---------35
index:---------36
index:---------37
index:---------38
index:---------39
index:---------40
index:---------41
index:---------42
index:---------43
index:---------44
index:---------45
index:---------46
index:---------47
index:---------48
index:---------49
index:---------50
index:---------51
index:---------52
index:---------53
index:---------54
index:---------55
index:---------56
index:---------57
index:---------58
index:---------59
index:---------60
index:---------61
index:---------62
index:---------63
index:---------64
index:---------65
index:---------66
index:---------67
index:---------68
index:---------69
index:---------70
index:----

index:---------461
index:---------462
index:---------463
index:---------464
index:---------465
index:---------466
index:---------467
index:---------468
index:---------469
index:---------470
index:---------471
index:---------472
index:---------473
index:---------474
index:---------475
index:---------476
index:---------477
index:---------478
index:---------479
index:---------480
index:---------481
index:---------482
index:---------483
index:---------484
index:---------485
index:---------486
index:---------487
index:---------488
index:---------489
index:---------490
index:---------491
index:---------492
index:---------493
index:---------494
index:---------495
index:---------496
index:---------497
index:---------498
index:---------499
index:---------500
index:---------501
index:---------502
index:---------503
index:---------504
index:---------505
index:---------506
index:---------507
index:---------508
index:---------509
index:---------510
index:---------511
index:---------512
index:------

index:---------895
index:---------896
index:---------897
index:---------898
index:---------899
index:---------900
index:---------901
index:---------902
index:---------903
index:---------904
index:---------905
index:---------906
index:---------907
index:---------908
index:---------909
index:---------910
index:---------911
index:---------912
index:---------913
index:---------914
index:---------915
index:---------916
index:---------917
index:---------918
index:---------919
index:---------920
index:---------921
index:---------922
index:---------923
index:---------924
index:---------925
index:---------926
index:---------927
index:---------928
index:---------929
index:---------930
index:---------931
index:---------932
index:---------933
index:---------934
index:---------935
index:---------936
index:---------937
index:---------938
index:---------939
index:---------940
index:---------941
index:---------942
index:---------943
index:---------944
index:---------945
index:---------946
index:------

index:---------1320
index:---------1321
index:---------1322
index:---------1323
index:---------1324
index:---------1325
index:---------1326
index:---------1327
index:---------1328
index:---------1329
index:---------1330
index:---------1331
index:---------1332
index:---------1333
index:---------1334
index:---------1335
index:---------1336
index:---------1337
index:---------1338
index:---------1339
index:---------1340
index:---------1341
index:---------1342
index:---------1343
index:---------1344
index:---------1345
index:---------1346
index:---------1347
index:---------1348
index:---------1349
index:---------1350
index:---------1351
index:---------1352
index:---------1353
index:---------1354
index:---------1355
index:---------1356
index:---------1357
index:---------1358
index:---------1359
index:---------1360
index:---------1361
index:---------1362
index:---------1363
index:---------1364
index:---------1365
index:---------1366
index:---------1367
index:---------1368
index:---------1369


In [8]:
def process_func(example):
    """
    将数据集进行预处理
    """
    MAX_LENGTH = max_token
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(
        f"<|system|>\n {example['instruction']}.<|endoftext|>\n<|user|>\n{example['input']}<|endoftext|>\n<|assistant|>\n",
        add_special_tokens=False, )

    response = tokenizer(f"{example['output']}", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = (
            instruction["attention_mask"] + response["attention_mask"] + [1]
    )
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

In [9]:
train_df = pd.read_json(train_path, lines=True)
train_ds = Dataset.from_pandas(train_df)
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)

Map: 100%|███████████████████████████████████████████████████████████████████| 1736/1736 [00:24<00:00, 71.64 examples/s]


In [10]:
lora_r = 64
lora_alpha = 32
lora_dropout = 0.1
output_dir = "./GLM4_SFT_Table_cropped_testGPU"
num_train_epochs = 6
bf16 = True
overwrite_output_dir = True
per_device_train_batch_size = 1
per_device_eval_batch_size = 16
gradient_accumulation_steps = 16
gradient_checkpointing = True
evaluation_strategy = "steps"
learning_rate = 5e-5
weight_decay = 0.01
lr_scheduler_type = "cosine"
warmup_ratio = 0.01
max_grad_norm = 0.3
group_by_length = True
auto_find_batch_size = False
save_steps = 50
logging_steps = 50
load_best_model_at_end= False
packing = False
save_total_limit=4
neftune_noise_alpha=5
# report_to="wandb"
max_seq_length = max_token

In [11]:
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    target_modules=[
       "query_key_value", "dense", "dense_h_to_4h", "activation_func", "dense_4h_to_h"
    ],
    task_type=TaskType.CAUSAL_LM,
)

swanlab_callback = SwanLabCallback(
    project="ENG_SFT_T2QSQL",
    experiment_name="GLM4-9B-2epo",
    description="使用智谱GLM4-9B-Chat模型微调MY_spider改数据集。",
    config={
        "model": "para_glm4",
        "dataset": "My_idea_T2Q_GLM4_SFT_train_table.jsonl",
    },
)

In [12]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    num_train_epochs=num_train_epochs,
    load_best_model_at_end=load_best_model_at_end,
    per_device_train_batch_size=per_device_train_batch_size,
#     evaluation_strategy=evaluation_strategy,
    max_grad_norm = max_grad_norm,
    auto_find_batch_size = auto_find_batch_size,
    save_total_limit = save_total_limit,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    bf16=bf16,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="none",
    neftune_noise_alpha= neftune_noise_alpha
)

In [13]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    num_train_epochs=num_train_epochs,
    load_best_model_at_end=load_best_model_at_end,
    per_device_train_batch_size=per_device_train_batch_size,
#     evaluation_strategy=evaluation_strategy,
    max_grad_norm = max_grad_norm,
    auto_find_batch_size = auto_find_batch_size,
    save_total_limit = save_total_limit,
    gradient_accumulation_steps=gradient_accumulation_steps,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    bf16=bf16,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="none",
    neftune_noise_alpha= neftune_noise_alpha
)

In [14]:
model = get_peft_model(model, peft_config)
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
#     eval_dataset= val_dataset,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    callbacks=[swanlab_callback],
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [15]:
trainer.train()

[1m[34mswanlab[0m[0m: Tracking run with swanlab version 0.3.13                                  
[1m[34mswanlab[0m[0m: Run data will be saved locally in [35m[1m/home/code/chat_SQL/main_verify/generate_cropped/swanlog/run-20240827_130927-a3b1799d[0m[0m
[1m[34mswanlab[0m[0m: 👋 Hi [1m[39mwinhong[0m[0m, welcome to swanlab!
[1m[34mswanlab[0m[0m: Syncing run [33mGLM4-9B-2epo_Aug27_13-09-27[0m to the cloud
[1m[34mswanlab[0m[0m: 🌟 Run `[1mswanlab watch -l /home/code/chat_SQL/main_verify/generate_cropped/swanlog[0m` to view SwanLab Experiment Dashboard locally
[1m[34mswanlab[0m[0m: 🏠 View project at [34m[4mhttps://swanlab.cn/@winhong/ENG_SFT_T2QSQL[0m[0m
[1m[34mswanlab[0m[0m: 🚀 View run at [34m[4mhttps://swanlab.cn/@winhong/ENG_SFT_T2QSQL/runs/l4i56m5m1g19qo17t49q5[0m[0m


Step,Training Loss


KeyboardInterrupt: 

In [None]:
output_dir = os.path.join("./", "final_checkpoint_GLM4_Table_SFT__testGPU")
trainer.model.save_pretrained(output_dir)