MCC-KD: Multi-CoT Consistent Knowledge Distillation

Hongzhan Chen1, Siyue Wu1, Xiaojun Quan1*, Rui Wang, Ming Yan2, Ji Zhang2
chenhzh59@mail2.sysu.edu.cn, wusy39@mail2.sysu.edu.cn, quanxj3@mail.sysu.edu.cn
1Sun Yat-sen University 2Alibaba Group
*Corresponding author

Framework Overview

MCC-KD is built on an efficient and easy-to-extend LLM training and inference framework. The project is developed with PyTorch and FairScale and employs a tensor (model) parallelism strategy.

  • Efficient Training
  • Efficient Inference

With 8xV100 32GB GPUs, a 7-billion-parameter model, and a maximum sequence length of 1024, the framework supports an inference batch size of up to 384 and a training batch size of up to 4.

Compared to the HuggingFace framework, this framework achieves more than a 2x speedup in both inference and training.
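For orientation, the sketch below shows how tensor (model) parallelism is typically set up with FairScale, in the style of the original LLaMA codebase. It is a minimal illustration, not this repository's actual code; the names setup_model_parallel and ParallelFeedForward are assumptions.

# A minimal sketch of FairScale tensor parallelism (illustrative only).
# Assumes launch via `torchrun --nproc_per_node N`.
import os
import torch
import torch.distributed as dist
import fairscale.nn.model_parallel.initialize as fs_init
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear

def setup_model_parallel() -> int:
    dist.init_process_group("nccl")
    world_size = int(os.environ["WORLD_SIZE"])
    fs_init.initialize_model_parallel(world_size)  # one model-parallel group over all ranks
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    return world_size

class ParallelFeedForward(torch.nn.Module):
    """Feed-forward block sharded across GPUs: the first projection keeps its
    output columns local to each rank, the second consumes the local shard
    and all-reduces the result."""
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False)
        self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.nn.functional.silu(self.w1(x)))

Sharding each block column-wise then row-wise requires only one all-reduce per forward pass, which helps make large inference batches feasible across 8 GPUs.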

Requirements

Library         Recommended version
python          3.8
torch           2.0.1
transformers    4.37.2
fire            0.5.0
fairscale       0.4.13
sentencepiece   0.1.97
safetensors     0.4.1

Current Supported Models

llama-1-7b
llama-1-13b
llama-1-33b
llama-2-7b
llama-2-13b
llama-2-70b
mistral-7b-instruct-v0.2
mixtral-8x7b-instruct-v0.1
qwen-7b
qwen-14b
qwen-72b

Teacher Rationales

We provide training sets, validation sets, test sets, and extracted raw teacher rationales for datasets including GSM8K, CSQA, SVAMP, and ASDiv in the data directory.
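For a quick look at the provided files, something like the sketch below should work. The field names ("instruction", "label") are an assumption inferred from the multi-CoT example later in this README, and individual files may be structured differently.

import json

# Field names here are assumptions, not a documented schema.
with open("data/GSM8K/train.json") as f:
    samples = json.load(f)

print(len(samples), "training samples")
example = samples[0]
print(example["instruction"])  # question text (assumed field name)
print(example["label"])        # gold answer (assumed field name)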

Checkpoint Downloading

The original LLaMA checkpoints can be downloaded from https://github.com/facebookresearch/llama and can be loaded directly into our framework.

In principle, the current model architecture is also compatible with the model weights available on Hugging Face, but the module names must be renamed before they can be loaded. The relevant renaming functions are provided in src/checkpoint.py; the modifications take only a little of your time.
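As an illustration, a renaming pass over a state dict might look like the sketch below. This is not the project's actual code: the mapping table is an assumption based on the usual Hugging Face-to-original-LLaMA naming differences, and the authoritative mapping lives in src/checkpoint.py.

import re
import torch

# Illustrative Hugging Face -> original-LLaMA key renaming (an assumption;
# see src/checkpoint.py for the real mapping).
HF_TO_LLAMA = {
    r"^model\.embed_tokens\.": "tok_embeddings.",
    r"^model\.layers\.(\d+)\.self_attn\.q_proj\.": r"layers.\1.attention.wq.",
    r"^model\.layers\.(\d+)\.self_attn\.k_proj\.": r"layers.\1.attention.wk.",
    r"^model\.layers\.(\d+)\.self_attn\.v_proj\.": r"layers.\1.attention.wv.",
    r"^model\.layers\.(\d+)\.self_attn\.o_proj\.": r"layers.\1.attention.wo.",
    r"^model\.layers\.(\d+)\.mlp\.gate_proj\.": r"layers.\1.feed_forward.w1.",
    r"^model\.layers\.(\d+)\.mlp\.down_proj\.": r"layers.\1.feed_forward.w2.",
    r"^model\.layers\.(\d+)\.mlp\.up_proj\.": r"layers.\1.feed_forward.w3.",
    r"^model\.layers\.(\d+)\.input_layernorm\.": r"layers.\1.attention_norm.",
    r"^model\.layers\.(\d+)\.post_attention_layernorm\.": r"layers.\1.ffn_norm.",
    r"^model\.norm\.": "norm.",
    r"^lm_head\.": "output.",
}

def rename_keys(state_dict: dict) -> dict:
    renamed = {}
    for key, tensor in state_dict.items():
        new_key = key
        for pattern, repl in HF_TO_LLAMA.items():
            new_key = re.sub(pattern, repl, new_key)
        renamed[new_key] = tensor
    return renamed

state_dict = torch.load("/path/to/your/checkpoint.bin", map_location="cpu")
torch.save(rename_keys(state_dict), "/path/to/save/renamed.bin")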

Getting Started

1. Checkpoint Splitting

To conduct model-parallel training and inference, the model checkpoint file must be split into several parts; for example, world_size=8 means the checkpoint must be split into 8 parts. Given a model parameter file /path/to/your/checkpoint.bin (other suffixes such as .pth and .safetensors are also supported; in fact, any file that stores the weights as a dictionary will work), run:

torchrun checkpoint_split.py \
--ckpt_file /path/to/your/checkpoint.bin \
--save_path /path/to/save/ \
--n 8

You should obtain the following checkpoint files:

/path/to/save/consolidated.00.pth
/path/to/save/consolidated.01.pth
/path/to/save/consolidated.02.pth
/path/to/save/consolidated.03.pth
/path/to/save/consolidated.04.pth
/path/to/save/consolidated.05.pth
/path/to/save/consolidated.06.pth
/path/to/save/consolidated.07.pth

2. Model Training

Taking LLaMA-1-7B as an example with lora_rank=128, run the following script to train the model on 8 GPUs (the current settings fit 8xV100 32GB GPUs):

torchrun --nproc_per_node 8 train.py \
--task GSM8K \
--ckpt_dir /path/to/your/ckpt/ \
--save_dir /path/to/save/ \
--train_file data/GSM8K/train.json \
--label_file data/GSM8K/test.json \
--model_type lora-llama-1-7b \
--max_batch_size 6 \
--lora_rank 128 \
--eval_batch_size 384 \
--epochs 24 \
--use_float16 True

If you don't want to use LoRA, change model_type to llama-1-7b and set lora_rank=-1. To use bfloat16 instead, replace --use_float16 True with --use_bfloat16 True. Float32 is used by default, i.e., when both --use_float16 and --use_bfloat16 are False.

3. MCC-KD Training

torchrun --nproc_per_node 8 train_mcc.py \
--task GSM8K \
--ckpt_dir /path/to/your/ckpt/ \
--save_dir /path/to/save/ \
--train_file data/GSM8K/train-multi-cots-preview.json \
--label_file data/GSM8K/test.json \
--model_type lora-llama-1-7b \
--max_batch_size 6 \
--lora_rank 128 \
--eval_batch_size 384 \
--epochs 24 \
--use_float16 True

MCC-KD requires diverse rationales that share a common answer span. Make sure to include an "indices" field recording the start and end indices of the common answer span (after tokenization). It should look something like the following:

[
  {
    "instruction": "...",
    "output": [
        "rationale1",
        "rationale2"
    ],
    "label": "...",
    "indices": [
        [
            23,
            28
        ],
        [
            42,
            47
        ]
    ]
  }  
]
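A hypothetical helper for producing the "indices" field could look like the sketch below. The tokenizer path, the matching strategy, and whether the end index is inclusive are all assumptions; consult the training code for the exact convention.

from sentencepiece import SentencePieceProcessor

# Hypothetical tokenizer path.
sp = SentencePieceProcessor(model_file="/path/to/tokenizer.model")

def find_answer_span(rationale: str, answer: str) -> list:
    tokens = sp.encode(rationale)
    # NOTE: tokenizing the answer standalone can differ from its in-context
    # tokenization (leading-space handling); a robust version should handle this.
    target = sp.encode(answer)
    for start in range(len(tokens) - len(target) + 1):
        if tokens[start:start + len(target)] == target:
            # End index treated as exclusive here (assumption).
            return [start, start + len(target)]
    raise ValueError("answer span not found after tokenization")

sample = {
    "instruction": "...",
    "output": ["rationale1 ... the answer is 42.", "rationale2 ... so we get 42."],
    "label": "42",
}
sample["indices"] = [find_answer_span(r, sample["label"]) for r in sample["output"]]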

We provide a preview JSON file at data/GSM8K/train-multi-cots-preview.json, which typically contains fewer samples than data/GSM8K/train.json due to correctness filtering.

Citation

@misc{chen2023mcckd,
      title={MCC-KD: Multi-CoT Consistent Knowledge Distillation}, 
      author={Hongzhan Chen and Siyue Wu and Xiaojun Quan and Rui Wang and Ming Yan and Ji Zhang},
      year={2023},
      eprint={2310.14747},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
