From e65d2b217c1525a6e1a46ed10d5032862b643032 Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:25:41 +0800 Subject: [PATCH 1/8] only import LoadBalancer when needed to avoid errors --- colossalai/moe/layers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/colossalai/moe/layers.py b/colossalai/moe/layers.py index b768fb94a585..0ab556885d62 100644 --- a/colossalai/moe/layers.py +++ b/colossalai/moe/layers.py @@ -6,10 +6,8 @@ import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F - from colossalai.moe._operation import AllGather, AllToAll, HierarchicalAllToAll, MoeCombine, MoeDispatch, ReduceScatter from colossalai.moe.experts import MLPExperts -from colossalai.moe.load_balance import LoadBalancer from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.routers import MoeRouter, get_router_cls from colossalai.moe.utils import create_ep_hierarchical_group, get_noise_generator @@ -119,6 +117,7 @@ def __init__( # load balance self.enable_load_balance = enable_load_balance if self.enable_load_balance == True: + from colossalai.moe.load_balance import LoadBalancer self.load_balancer = LoadBalancer( experts=self.experts, gate=self.gate_weight, From a44a83097966a9b0633bfa5120cd1e501f185afe Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:27:13 +0800 Subject: [PATCH 2/8] Update README with inference demo and update requirements --- examples/language/openmoe/README.md | 31 +++++++++++++++++++++- examples/language/openmoe/requirements.txt | 9 ++++--- 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/examples/language/openmoe/README.md b/examples/language/openmoe/README.md index 45657f192024..46735c236a1a 100644 --- a/examples/language/openmoe/README.md +++ b/examples/language/openmoe/README.md @@ -45,8 +45,37 @@ cd apex git checkout 741bdf50825a97664db08574981962d66436d16a pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ --global-option="--cuda_ext" ``` +### 3. Inference +You can inference by the following code to try OpenMoE-8B-Chat model: +``` +from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM + +model_path = "OrionZheng/openmoe-8b-chat" +config = AutoConfig.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + device_map='auto' + ) +query = 'Question: How do I kill a process? Answer:' +prompt = f'''<> +You are a helpful, respectful and honest assistant. +<> + +[INST] {query} [/INST]''' + +inputs = tokenizer(prompt, return_tensors="pt").to('cuda') +sample = model.generate(**inputs, max_new_tokens=32) +print(tokenizer.decode(sample[0])) +``` + +We also provide a Colab [tutorial](https://colab.research.google.com/drive/1eIT1rtG7pORRQAYtQoMOAekUg7aZLDdn) demonstrating the jax checkpoint conversion and execution of PyTorch model inference. You can experiment with OpenMoE-8B-Chat on Colab directly by [this](https://colab.research.google.com/drive/1xIfIVafnlCP2XVICmRwkUFK3cwTJYjCY)(Note: both require Colab Pro). +- Running OpenMoE-8B requires ~49GB of memory in float32 or ~23GB in bfloat16. It can be executed on a Colab `CPU High-RAM` runtime or an `A100-40GB` runtime, both of which require Colab Pro.The float16 precision is not recommended because sometimes it will lead to performance degradation. +- Runing the OpenMoE-34B requries ~89GB of memory in bfloat16 or ~180GB in float32. To perform inference on multiple devices/offloading model weights to RAM, please refer to the script [here](inference_on_multi_devices.py). -### 3. Train +### 4. Train Yon can use colossalai run to launch single-node training: ```bash colossalai run --standalone --nproc_per_node YOUR_GPU_PER_NODE train.py --OTHER_CONFIGURATIONS diff --git a/examples/language/openmoe/requirements.txt b/examples/language/openmoe/requirements.txt index 6b9f807116df..a0f5c664105f 100644 --- a/examples/language/openmoe/requirements.txt +++ b/examples/language/openmoe/requirements.txt @@ -1,5 +1,8 @@ +accelerate==0.25.0 colossalai >= 0.3.3 torch >= 1.8.1 -transformers >= 4.20.0, <= 4.34.0 -sentencepiece -datasets +transformers==4.34.0 +sentencepiece==0.1.99 +datasets==2.14.7 +numpy==1.23.5 +flash_attn From fb8c5ec40a7a116b939e9f1844336f519f960f63 Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:29:00 +0800 Subject: [PATCH 3/8] update config files for different models --- .../openmoe/model/openmoe_34b_config.json | 54 ++++++++++++++++ .../openmoe/model/openmoe_8b_config.json | 62 ++++++++++++++----- .../openmoe/model/openmoe_base_config.json | 62 ++++++++++++++----- 3 files changed, 146 insertions(+), 32 deletions(-) create mode 100644 examples/language/openmoe/model/openmoe_34b_config.json diff --git a/examples/language/openmoe/model/openmoe_34b_config.json b/examples/language/openmoe/model/openmoe_34b_config.json new file mode 100644 index 000000000000..0a3822c685dd --- /dev/null +++ b/examples/language/openmoe/model/openmoe_34b_config.json @@ -0,0 +1,54 @@ +{ + "architectures": [ + "OpenMoeForCausalLM" + ], + "auto_map": { + "AutoModelForCausalLM": "modeling_openmoe.OpenMoeForCausalLM" + }, + "attention_bias": false, + "bos_token_id": 0, + "dropout_rate": 0.0, + "enable_comm_overlap": false, + "enable_hierarchical_alltoall": false, + "enable_kernel": false, + "enable_load_balance": false, + "eos_token_id": 1, + "expert_parallel": null, + "head_dim": 128, + "hidden_act": "swiglu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 12288, + "label_smoothing": 0.001, + "layer_norm_epsilon": 1e-06, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "load_balance_tolerance": 0.1, + "max_position_embeddings": 2048, + "mlp_gated": true, + "model_type": "llama", + "moe_layer_interval": 4, + "num_attention_heads": 24, + "num_experts": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 24, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "router_aux_loss_factor": 0.01, + "router_capacity_factor_eval": 2.0, + "router_capacity_factor_train": 1.25, + "router_drop_tks": true, + "router_min_capacity": 4, + "router_noisy_policy": null, + "router_topk": 2, + "router_z_loss_factor": 0.0001, + "tie_word_embeddings": false, + "torch_dtype": "float32", + "transformers_version": "4.34.0", + "use_cache": true, + "vocab_size": 256384, + "z_loss_factor": 0.01 +} diff --git a/examples/language/openmoe/model/openmoe_8b_config.json b/examples/language/openmoe/model/openmoe_8b_config.json index 248697c37d3c..15f05cc49080 100644 --- a/examples/language/openmoe/model/openmoe_8b_config.json +++ b/examples/language/openmoe/model/openmoe_8b_config.json @@ -2,23 +2,53 @@ "architectures": [ "OpenMoeForCausalLM" ], - "intermediate_size": 8192, - "hidden_size": 2048, - "num_hidden_layers": 24, - "head_dim": 128, - "num_attention_heads": 24, + "auto_map": { + "AutoModelForCausalLM": "modeling_openmoe.OpenMoeForCausalLM" + }, + "attention_bias": false, + "bos_token_id": 0, "dropout_rate": 0.0, - "layer_norm_epsilon": 1e-06, - "vocab_size": 256384, + "enable_comm_overlap": false, + "enable_hierarchical_alltoall": false, + "enable_kernel": false, + "enable_load_balance": false, + "eos_token_id": 1, + "expert_parallel": null, + "head_dim": 128, "hidden_act": "swiglu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 8192, + "label_smoothing": 0.001, + "layer_norm_epsilon": 1e-06, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "load_balance_tolerance": 0.1, + "max_position_embeddings": 2048, + "mlp_gated": true, + "model_type": "llama", + "moe_layer_interval": 6, + "num_attention_heads": 24, "num_experts": 32, - "topk": 2, - "capacity_factor_train": 1.25, - "capacity_factor_eval": 2.0, - "min_capacity": 4, - "noisy_policy": null, - "drop_tks": true, - "expert_parallel": null, - "gated": true, - "moe_layer_interval": 6 + "num_hidden_layers": 24, + "num_key_value_heads": 24, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "router_aux_loss_factor": 0.01, + "router_capacity_factor_eval": 2.0, + "router_capacity_factor_train": 1.25, + "router_drop_tks": true, + "router_min_capacity": 4, + "router_noisy_policy": null, + "router_topk": 2, + "router_z_loss_factor": 0.0001, + "tie_word_embeddings": false, + "torch_dtype": "float32", + "transformers_version": "4.34.0", + "use_cache": true, + "vocab_size": 256384, + "z_loss_factor": 0.01 } diff --git a/examples/language/openmoe/model/openmoe_base_config.json b/examples/language/openmoe/model/openmoe_base_config.json index 5a7c97bd1916..d33a76eb75d4 100644 --- a/examples/language/openmoe/model/openmoe_base_config.json +++ b/examples/language/openmoe/model/openmoe_base_config.json @@ -2,23 +2,53 @@ "architectures": [ "OpenMoeForCausalLM" ], - "intermediate_size": 2048, - "hidden_size": 768, - "num_hidden_layers": 12, - "head_dim": 64, - "num_attention_heads": 12, + "auto_map": { + "AutoModelForCausalLM": "modeling_openmoe.OpenMoeForCausalLM" + }, + "attention_bias": false, + "bos_token_id": 0, "dropout_rate": 0.0, - "layer_norm_epsilon": 1e-06, - "vocab_size": 256384, + "enable_comm_overlap": false, + "enable_hierarchical_alltoall": false, + "enable_kernel": false, + "enable_load_balance": false, + "eos_token_id": 1, + "expert_parallel": null, + "head_dim": 64, "hidden_act": "swiglu", + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 2048, + "label_smoothing": 0.001, + "layer_norm_epsilon": 1e-06, + "load_balance_beam_width": 8, + "load_balance_group_swap_factor": 0.4, + "load_balance_tolerance": 0.1, + "max_position_embeddings": 2048, + "mlp_gated": true, + "model_type": "llama", + "moe_layer_interval": 4, + "num_attention_heads": 12, "num_experts": 16, - "topk": 2, - "capacity_factor_train": 1.25, - "capacity_factor_eval": 2.0, - "min_capacity": 4, - "noisy_policy": null, - "drop_tks": true, - "expert_parallel": null, - "gated": true, - "moe_layer_interval": 4 + "num_hidden_layers": 12, + "num_key_value_heads": 12, + "pad_token_id": 0, + "pretraining_tp": 1, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 10000.0, + "router_aux_loss_factor": 0.01, + "router_capacity_factor_eval": 2.0, + "router_capacity_factor_train": 1.25, + "router_drop_tks": true, + "router_min_capacity": 4, + "router_noisy_policy": null, + "router_topk": 2, + "router_z_loss_factor": 0.0001, + "tie_word_embeddings": false, + "torch_dtype": "float32", + "transformers_version": "4.34.0", + "use_cache": true, + "vocab_size": 256384, + "z_loss_factor": 0.01 } From 3d000cc062b9b10c483a58b296d7ca39360c85b0 Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:31:17 +0800 Subject: [PATCH 4/8] update inference example --- examples/language/openmoe/infer.py | 57 ------ examples/language/openmoe/infer.sh | 1 - .../openmoe/inference_on_multi_devices.py | 186 ++++++++++++++++++ 3 files changed, 186 insertions(+), 58 deletions(-) delete mode 100644 examples/language/openmoe/infer.py delete mode 100644 examples/language/openmoe/infer.sh create mode 100644 examples/language/openmoe/inference_on_multi_devices.py diff --git a/examples/language/openmoe/infer.py b/examples/language/openmoe/infer.py deleted file mode 100644 index db90c6e34507..000000000000 --- a/examples/language/openmoe/infer.py +++ /dev/null @@ -1,57 +0,0 @@ -from argparse import ArgumentParser - -import torch -from model.modeling_openmoe import OpenMoeForCausalLM, set_openmoe_args -from transformers import T5Tokenizer -from transformers.models.llama import LlamaConfig - - -def parse_args(): - parser = ArgumentParser() - parser.add_argument("--model", default="base", type=str, help="model path", choices=["base", "8b", "test"]) - return parser.parse_args() - - -def inference(args): - tokenizer = T5Tokenizer.from_pretrained("google/umt5-small") - if args.model == "test": - config = LlamaConfig.from_pretrained("hpcaitech/openmoe-base") - set_openmoe_args(config, - num_experts=config.num_experts, - moe_layer_interval=config.moe_layer_interval, - enable_kernel=True) - model = OpenMoeForCausalLM(config) - else: - config = LlamaConfig.from_pretrained(f"hpcaitech/openmoe-{args.model}") - set_openmoe_args(config, - num_experts=config.num_experts, - moe_layer_interval=config.moe_layer_interval, - enable_kernel=False) - model = OpenMoeForCausalLM.from_pretrained(f"hpcaitech/openmoe-{args.model}", config=config) - model = model.eval().bfloat16() - model = model.to(torch.cuda.current_device()) - - input_str = """``` -y = list(map(int, ['1', 'hello', '2'])) -``` -What error does this program produce? -ValueError: invalid literal for int() with base 10: 'hello' - -``` -sum = 0 -for i in range(100): - sum += i -``` -What is the value of sum immediately after the 10th time line 3 is executed?""" - - # print("model config: ", model.config) - input_ids = tokenizer("" + input_str, return_tensors="pt", add_special_tokens=False) - input_ids = input_ids.input_ids.to(torch.cuda.current_device()) - generation_output = model.generate(input_ids, use_cache=True, do_sample=True, max_new_tokens=64) - out = tokenizer.decode(generation_output[0], skip_special_tokens=False) - print(f"output: \n{out}\n") - - -if __name__ == "__main__": - args = parse_args() - inference(args) diff --git a/examples/language/openmoe/infer.sh b/examples/language/openmoe/infer.sh deleted file mode 100644 index a578203eba84..000000000000 --- a/examples/language/openmoe/infer.sh +++ /dev/null @@ -1 +0,0 @@ -python infer.py --model "base" diff --git a/examples/language/openmoe/inference_on_multi_devices.py b/examples/language/openmoe/inference_on_multi_devices.py new file mode 100644 index 000000000000..6bf7847e60d9 --- /dev/null +++ b/examples/language/openmoe/inference_on_multi_devices.py @@ -0,0 +1,186 @@ +import torch +from accelerate import infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch +from transformers import AutoTokenizer, T5Tokenizer, AutoConfig, AutoModelForCausalLM, LogitsProcessorList, LogitsProcessor +from typing import List, Optional +from huggingface_hub import snapshot_download + + +class StopAfterEosTextGenerated(LogitsProcessor): + """Logits processor (to use with HuggingFace `generate()` method : + https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/ + text_generation#transformers.generation_utils.GenerationMixin). + + Sometimes our model output '▁' seperately as stopping signal(not '▁' as a whole), + which is unable to be captured by a single eos token and can cause a very long generation. + This logitsprocessor will force generation stop after ' '. + + Args: + base_len (int): Size of the given context. Used to know if this is + the first character to generate. + eos_token_id (int): ID of the EOS token. + """ + def __init__(self, base_len: int, eos_token_id: int): + super().__init__() + self.base_len = base_len + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if input_ids.size(1) > self.base_len: + forced_eos = torch.full((scores.size(1),), -float("inf")).to(scores.device) + forced_eos[self.eos_token_id] = 0 + # If the last three tokens of input_ids are the stop_token_ids, a eos will be forced to generate afterwards + stop_token_ids = torch.Tensor([15501, 281, 926]).to(scores.device) # ids for tokens '▁' + stop_sample_ids = torch.eq(input_ids[:, -len(stop_token_ids): ], stop_token_ids).all(dim=1) + scores[stop_sample_ids] = forced_eos + return scores + +def inference(model, tokenizer, input_strs, gen_kwargs, + add_special_tokens=True, split_special_tokens=False, output_only=True, verbose=False): + + model = model.eval() + + # Tokenization + inputs = tokenizer.batch_encode_plus(input_strs, + padding='longest', + add_special_tokens=add_special_tokens, + split_special_tokens=split_special_tokens, + return_tensors="pt") + input_ids = inputs.input_ids.to(model.device) + attention_mask = inputs.attention_mask.to(model.device) + base_len = inputs.input_ids.size(-1) + if verbose: + print("Input Tokens:\n", input_ids) + print("Num of Input Tokens: ", base_len) + print("Attention Mask:\n", attention_mask) + logits_processor = LogitsProcessorList([StopAfterEosTextGenerated(base_len, tokenizer.eos_token_id)]) + + output_ids = model.generate(input_ids=input_ids, + attention_mask=attention_mask, + bos_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.pad_token_id, + logits_processor=logits_processor, + **gen_kwargs) + if output_only: # Only preserve output tokens + output_ids = output_ids[:, input_ids.size(1):] + if verbose: + print("Generated Tokens:\n", output_ids) + output_txts = tokenizer.batch_decode(output_ids, + clean_up_tokenization_spaces=True, + skip_special_tokens=False) + return output_ids, output_txts + +def apply_llama_chat_template(tokenizer, input_strs, sys_prompt): + # Use LLaMA's Chat Template(A bit diffrent from original one at the beginning part, we may correct it to the standard llama prompt template later) + # input_strs = [('user_input', 'user'), ('AI_response', 'assistant'), ...] + tokenizer.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<>\\n' + message['content'] + '\\n<>\\n\\n' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + ' ' + eos_token }}{% endif %}{% endfor %}" + system_prompt = {'content': sys_prompt, 'role': 'system'} + chat = [system_prompt] + [{'content': input_str, 'role': role} for input_str, role in input_strs] + input_str = tokenizer.apply_chat_template(chat, + tokenize=False, + add_generation_prompt=True) + return input_str + +if __name__ == "__main__": + # @markdown 1. Path to the checkpoint repo + pytorch_checkpoint_path = "OrionZheng/openmoe-8b-chat"#@param {type:"string"} + #@markdown 2. (If any)Specify GPUs you want to use. + #@markdown + #@markdown - If single GPU memory is not enough, you can enter ids of multiple GPUs(seperated by comma). During inference, GPUs will be filed up sequentially. + available_gpu_ids_str = "0" # @param ["", "0", "0,1", "0,1,2"] {allow-input: true} + #@markdown - Specify available memory of each GPU + #@markdown - Leave some margin for data and activation. + #@markdown For example, we used 38GB GPU memory for an A100(40GB) + memory_per_gpu = "38GiB" # @param ["", "38GiB"] {allow-input: true} + #@markdown 3. Specify available CPU RAM + #@markdown + #@markdown - The Colab CPU High-RAM Runtime has 51GiB RAM + cpu_memory = '50GiB' #@param ["50GiB"] {allow-input: true} + # @markdown 3. Specify the model parameter's precision + + # @markdown - The CPU runtime only supports inference in float32 precision + + # @markdown - The `bfloat16` is only available on A100 Colab runtime + + # @markdown - Please use float32/bfloat16 for inference. We observed issues with the model output when running in float16, which may be due to underflow caused by our large vocabulary size. + model_dtype = 'bfloat16' #@param ["float32", "bfloat16"] + #@markdown (Not recommended, very slow) Offload model weights to CPU memory if GPU's is insufficient, then offload to disk if CPU memory is insufficient. + offload = False #@param {type:"boolean"} + + input_str = "What is the title of the last Harry Potter novel, published in 2007?" # @param [] {allow-input: true} + input_strs = [input_str] + gen_strategy = "greedy" #@param ["greedy", "top_p"] + #@markdown Please select the prompt template if chat model is being used. For raw language model, please leave this field blank. + prompt_template = "openmoe" #@param ["openmoe", ""] + max_new_tokens = 32 #@param {type:"slider", min:1, max:512, step:1} + debug_verbose = True #@param {type:"boolean"} + cache_dir = "./" + gen_kwargs = { + "greedy": {"do_sample": False, "num_beams": 1, "max_new_tokens": max_new_tokens}, # Greedy Search + "top_p": {"do_sample": True, "temperature": 0.5, "top_p": 0.8, "max_new_tokens": max_new_tokens}, # Top-p Sampling + } + + if torch.cuda.is_available(): + cuda_list = available_gpu_ids_str.split(',') + else: + available_gpu_ids_str, memory_per_gpu = "", "" + model_dtype = "float32" + cuda_list = [] + + no_split_module_classes = "OpenMoeDecoderLayer" + + # 1. Allocate Devices for Inference + available_memory = {int(cuda): memory_per_gpu for cuda in cuda_list} + available_memory['cpu'] = cpu_memory + print('Available Devices and Memory: ', available_memory) + + # 2. Load the Model (init with empty weight to save memory) + config = AutoConfig.from_pretrained(pytorch_checkpoint_path) + weights_location = snapshot_download(repo_id=pytorch_checkpoint_path, + cache_dir=cache_dir) + with init_empty_weights(): + model = AutoModelForCausalLM.from_config(config, + torch_dtype=eval(f'torch.{model_dtype}'), + trust_remote_code=True) + print('Model dtype: ', model.dtype) + device_map = infer_auto_device_map(model, + max_memory=available_memory, + no_split_module_classes=no_split_module_classes) + print('Inferred Device Map: \n', device_map) + if offload: + model = load_checkpoint_and_dispatch(model, weights_location, + device_map=device_map, + offload_folder="offload", + offload_state_dict=True, + dtype=eval(f'torch.{model_dtype}'), + no_split_module_classes=[no_split_module_classes]) + else: + model = load_checkpoint_and_dispatch(model, weights_location, + device_map=device_map, + dtype=eval(f'torch.{model_dtype}'), + no_split_module_classes=[no_split_module_classes]) + print('Fine-grained Device Map: \n', model.hf_device_map) + + + + # 3. Load the Tokenizer + tokenizer = AutoTokenizer.from_pretrained(pytorch_checkpoint_path, trust_remote_code=True) + + # 4. Inference + final_input_strs = [] + for input_str in input_strs: + if prompt_template == "openmoe": + SYS_LLAMA = "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature." + input_str = apply_llama_chat_template(tokenizer, + [(input_str, 'user')], + sys_prompt=SYS_LLAMA) + final_input_strs.append(input_str) + print("=========== The Actual Input =============") + [print(i) for i in final_input_strs] + + output_ids, output_txts = inference(model, tokenizer, final_input_strs, gen_kwargs[gen_strategy], + verbose=debug_verbose) + + print("============== Output Text ===============") + for output_txt in output_txts: + print(output_txt.split('')[0]) \ No newline at end of file From 64bc50b4b51647716c42850f482cc3b4bf5806ea Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:33:19 +0800 Subject: [PATCH 5/8] correct some bugs in modeling_openmoe.py --- .../openmoe/model/modeling_openmoe.py | 117 +++++++++++------- 1 file changed, 72 insertions(+), 45 deletions(-) diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index ec7644317903..a52198f59815 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -27,7 +27,8 @@ from torch import nn from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel -from transformers.models.llama.modeling_llama import LlamaConfig, LlamaRMSNorm +from transformers.models.llama.configuration_llama import LlamaConfig + from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, @@ -41,6 +42,8 @@ from colossalai.moe.manager import MOE_MANAGER from colossalai.moe.utils import get_activation, set_moe_args + + if HAS_TRITON: from colossalai.kernel.triton.llama_act_combine_kernel import LlamaActCombine @@ -157,37 +160,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) -def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0): - """Generate Sin/Cos for Rotary Embeddings. - - Args: - features: an integer - length: an integer - min_timescale: an optional float - max_timescale: an optional float - - Returns: - output_sin: a float32 Tensor with shape [length, features] - output_cos: a float32 Tensor with shape [length, features] - """ - fraction = torch.arange(0, features, 2, dtype=torch.float32).cuda() / features - timescale = min_timescale * (max_timescale / min_timescale) ** fraction - rotational_frequency = 1.0 / timescale - - sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32).cuda(), rotational_frequency) - - sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) - - return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) - - -def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): +def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): + # q: (bs, q_len, num_heads, head_dim) + # k: (bs, q_len [+past_kv_len], num_heads, head_dim) + # cos: (max_seq_len, head_dim) + # sin: (max_seq_len, head_dim) + # rotary_index: (bs, 1) # only used during decoding, when one query token is input at a time """Helper function to apply Rotary Embeddings.""" cos = cos.to(q.dtype) sin = sin.to(q.dtype) - if len(k.shape) == 3: - # for multi query attention + if len(k.shape) == 3: # for multi query attention k = k.unsqueeze(2) multiquery = True else: @@ -198,19 +181,18 @@ def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None): assert batch == kbatch, f"{batch} != {kbatch}" assert d == kd, f"{d} != {kd}" if decode and qlen == 1 and rotary_index is not None: - qcos = cos[rotary_index + 1, :] - qsin = sin[rotary_index + 1, :] - qcos = qcos.unsqueeze(2) - qsin = qsin.unsqueeze(2) - kcos, ksin = cos[:klen, :], sin[:klen, :] - kcos = kcos.unsqueeze(0).unsqueeze(2) - ksin = ksin.unsqueeze(0).unsqueeze(2) + qcos = cos[rotary_index, :] # (bs, 1, head_dim) + qsin = sin[rotary_index, :] # (bs, 1, head_dim) + qcos = qcos.unsqueeze(2) # (bs, q_len=1, 1, head_dim) # broadcast to all heads + qsin = qsin.unsqueeze(2) # (bs, q_len=1, 1, head_dim) else: - qcos, qsin = cos[:qlen, :], sin[:qlen, :] - qcos = qcos.unsqueeze(0).unsqueeze(2) + qcos, qsin = cos[:qlen, :], sin[:qlen, :] # (q_len, head_dim) + qcos = qcos.unsqueeze(0).unsqueeze(2) # (1, q_len, 1, head_dim) qsin = qsin.unsqueeze(0).unsqueeze(2) - kcos, ksin = qcos, qsin - + + kcos, ksin = cos[:klen, :], sin[:klen, :] # (k_len, head_dim) + kcos = kcos.unsqueeze(0).unsqueeze(2) # (1, k_len, 1, head_dim) # broadcast to the whole batch, broadcast to all heads + ksin = ksin.unsqueeze(0).unsqueeze(2) # (1, k_len, 1, head_dim) out_q = (q * qcos) + (rotate_half(q) * qsin) out_k = (k * kcos) + (rotate_half(k) * ksin) @@ -226,6 +208,21 @@ def rotate_half(x): x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) def SwiGLU(x): """Gated linear unit activation function. @@ -304,11 +301,37 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.sin, self.cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4) + self.generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4) + self.use_kernel = config.enable_kernel + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def generate_fixed_pos_embedding(self, features, length, min_timescale=1.0, max_timescale=10000.0): + """Generate Sin/Cos for Rotary Embeddings. + + Args: + features: an integer + length: an integer + min_timescale: an optional float + max_timescale: an optional float + + Returns: + output_sin: a float32 Tensor with shape [length, features] + output_cos: a float32 Tensor with shape [length, features] + """ + fraction = torch.arange(0, features, 2, dtype=torch.float32) / features + timescale = min_timescale * (max_timescale / min_timescale) ** fraction + rotational_frequency = 1.0 / timescale + + sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32), rotational_frequency) + + sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1) + + self.register_buffer('sin', torch.sin(sinusoid_inp), persistent=False) # persistent=False --> buffer won't appear in the state_dict + self.register_buffer('cos', torch.cos(sinusoid_inp), persistent=False) + def forward( self, hidden_states: torch.Tensor, @@ -317,7 +340,6 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - use_kernel: bool = True, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -373,8 +395,12 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - if HAS_FLASH_ATTN and use_kernel: - from flash_attn import flash_attn_func + if HAS_FLASH_ATTN and self.use_kernel: + # from flash_attn import flash_attn_func + # If we use `from flash_attn import flash_attn_func` directly, + # AutoModelForCausalLM.from_pretrained will treat flash_attn as a compulsory dependency and raise error if it cannot be found. + # Here is a workaround to avoid the error. + exec("from flash_attn import flash_attn_func") query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -431,7 +457,8 @@ def __init__(self, config: LlamaConfig, moe: bool): super().__init__() self.hidden_size = config.hidden_size self.moe = moe - self.self_attn = OpenMoeAttention(config=config) + self.self_attn = OpenMoeAttention(config=config) +# self.self_attn = LlamaAttention(config=config) # TODO: introduce LLaMA Positional Encoding self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) if self.moe: @@ -545,7 +572,7 @@ class OpenMoePreTrainedModel(PreTrainedModel): config_class = LlamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True - _no_split_modules = ["LlamaDecoderLayer"] + _no_split_modules = ["OpenMoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" def _init_weights(self, module): From 2c416de959e097ab3a3c88ce462edca38c074a36 Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:36:09 +0800 Subject: [PATCH 6/8] Save half of the memory used for converting checkpoints by using init_empty_weights and lazy_parameters --- .../openmoe/model/convert_openmoe_ckpt.py | 123 +++++++++--------- 1 file changed, 63 insertions(+), 60 deletions(-) diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.py b/examples/language/openmoe/model/convert_openmoe_ckpt.py index 20b1e780d8b3..f050d3db04eb 100644 --- a/examples/language/openmoe/model/convert_openmoe_ckpt.py +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.py @@ -27,17 +27,20 @@ --pytorch_dump_path=$HOME/t5_1_1_small_pt ``` """ - import argparse +from accelerate import init_empty_weights import collections - +from flax.core import lazy_init +import numpy as np import torch from flax import traverse_util from modeling_openmoe import OpenMoeForCausalLM from t5x import checkpoints +from t5x.checkpoint_importer import LazyAwaitableArray from transformers import LlamaConfig from transformers.utils import logging + logging.set_verbosity_info() @@ -122,10 +125,10 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: in layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm") k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention") new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm - new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T - new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T - new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T - new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T + new[f"model.layers.{i}.self_attn.k_proj.weight"] = np.transpose(k) + new[f"model.layers.{i}.self_attn.o_proj.weight"] = np.transpose(o) + new[f"model.layers.{i}.self_attn.q_proj.weight"] = np.transpose(q) + new[f"model.layers.{i}.self_attn.v_proj.weight"] = np.transpose(v) # Block i, layer 2 (MLP). layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm") @@ -134,7 +137,7 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: in if (i + 1) % moe_interval == 0: # moe gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi) - new[f"model.layers.{i}.mlp.gate_weight"] = gate.T + new[f"model.layers.{i}.mlp.gate_weight"] = np.transpose(gate) wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi) new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0] new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1] @@ -143,82 +146,82 @@ def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: in layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm") new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi) - new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T - new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T - new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T + new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = np.transpose(wi[0]) + new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = np.transpose(wi[1]) + new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = np.transpose(wo) else: wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi) - new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T - new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T - new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T + new[f"model.layers.{i}.mlp.gate_proj.weight"] = np.transpose(wi[0]) + new[f"model.layers.{i}.mlp.up_proj.weight"] = np.transpose(wi[1]) + new[f"model.layers.{i}.mlp.down_proj.weight"] = np.transpose(wo) new["model.norm.weight"] = old["decoder/decoder_norm/scale"] # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead) if "decoder/logits_dense/kernel" in old: - new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T + new["lm_head.weight"] = np.transpose(old["decoder/logits_dense/kernel"]) return new -def make_state_dict(converted_params): - """Prepares a state dict for the PyTorch model.""" - # Make a state dict with torch tensors. - state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) - +def load_t5x_weights_in_t5(config, t5x_checkpoint_path, dtype, lazy): + """get T5x converted params.""" + variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path, + restore_dtype=dtype, + lazy_parameters=lazy) + converted_params = convert_t5x_to_pytorch(variables, + num_layers=config.num_hidden_layers, + moe_interval=config.moe_layer_interval) + if lazy: + state_dict = collections.OrderedDict() + for k, v in converted_params.items(): + if isinstance(v, np.ndarray): + assert len(v.shape)==0 and isinstance(v.item(), LazyAwaitableArray) + state_dict[k] = torch.from_numpy(v.item().get().T.copy()) + else: + state_dict[k] = torch.from_numpy(v.get().copy()) + else: + state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()]) + # del converted_params return state_dict - -def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path): - """Replaces the params in model witht the T5X converted params.""" - variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path) - converted = convert_t5x_to_pytorch(variables, - num_layers=config.num_hidden_layers, - moe_interval=config.moe_layer_interval) - state_dict = make_state_dict(converted) - model.load_state_dict(state_dict, strict=True) - - -def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path): +def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path, + target_dtype='float32', lazy=False): """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint.""" - # Initialise PyTorch model config = LlamaConfig.from_json_file(config_file) - print(f"Building PyTorch model from configuration: {config}") # Non-v1.1 checkpoints could also use T5Model, but this works for all. # The v1.0 checkpoints will simply have an LM head that is the word embeddings. - model = OpenMoeForCausalLM(config) + print("Get state_dict from jax checkpoint") + state_dict = load_t5x_weights_in_t5(config, t5x_checkpoint_path, + dtype=target_dtype, lazy=lazy) - # Load weights from tf checkpoint - load_t5x_weights_in_t5(model, config, t5x_checkpoint_path) + print(f"Building PyTorch model from config and checkpoint: {config}") + with init_empty_weights(): + model = OpenMoeForCausalLM(config) + print('Empty Model Initialized.') + # assign=True: https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html + model.load_state_dict(state_dict, assign=True, strict=True) # Save pytorch-model - print(f"Save PyTorch model to {pytorch_dump_path}") model.save_pretrained(pytorch_dump_path) - - # Verify that we can load the checkpoint. - model.from_pretrained(pytorch_dump_path) + print(f"Save PyTorch Model Checkpoint to {pytorch_dump_path}") print("Done") + if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.") - # Required parameters - parser.add_argument("--t5x_checkpoint_path", - default=None, - type=str, - required=True, - help="Path to the T5X checkpoint.") - parser.add_argument( - "--config_file", - default=None, - type=str, - required=True, - help="The config json file corresponding to the pre-trained T5 model.\nThis specifies the model architecture.", - ) - parser.add_argument("--pytorch_dump_path", - default=None, - type=str, - required=True, - help="Path to the output PyTorch model.") + parser = argparse.ArgumentParser(description="Convert t5x checkpoint to PyTorch format.") + + parser.add_argument("--t5x_checkpoint_path", type=str, required=True, help="Path to the original t5x checkpoint") + parser.add_argument("--config_file", type=str, required=True, help="Path to the configuration file") + parser.add_argument("--pytorch_dump_path", type=str, required=True, help="Path for the output PyTorch Checkpoint") + parser.add_argument("--target_dtype", type=str, choices=['float32', 'float16'], required=True, help="Target data type for the PyTorch Checkpoint") + parser.add_argument("--lazy", action="store_true", help="Use lazy loading for the PyTorch model") + args = parser.parse_args() - convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path) + + convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, + args.config_file, + args.pytorch_dump_path, + args.target_dtype, + args.lazy) From 2eee01fce27b8407bb57a598d1c6141f6f75bf4f Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 03:51:23 +0800 Subject: [PATCH 7/8] update convert_openmoe_ckpt.sh and fix typos in modeling_openmoe.py --- .../openmoe/model/convert_openmoe_ckpt.sh | 7 +++- .../openmoe/model/modeling_openmoe.py | 33 +------------------ 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/examples/language/openmoe/model/convert_openmoe_ckpt.sh b/examples/language/openmoe/model/convert_openmoe_ckpt.sh index c0d53f562e40..b8390fcc910b 100644 --- a/examples/language/openmoe/model/convert_openmoe_ckpt.sh +++ b/examples/language/openmoe/model/convert_openmoe_ckpt.sh @@ -1 +1,6 @@ -python convert_openmoe_ckpt.py --t5x_checkpoint_path /path/to/t5x --config_file /path/to/config --pytorch_dump_path /path/to/save +python ColossalAI/examples/language/openmoe/model/convert_openmoe_ckpt.py \ +--t5x_checkpoint_path checkpoint_553000 \ +--config_file ColossalAI/examples/language/openmoe/model/openmoe_8b_config.json \ +--pytorch_dump_path openmoe_8b_chat_ckpt \ +--target_dtype float32 \ +--lazy \ No newline at end of file diff --git a/examples/language/openmoe/model/modeling_openmoe.py b/examples/language/openmoe/model/modeling_openmoe.py index a52198f59815..51e14836de94 100644 --- a/examples/language/openmoe/model/modeling_openmoe.py +++ b/examples/language/openmoe/model/modeling_openmoe.py @@ -78,7 +78,6 @@ def set_openmoe_args( """ MoE related arguments. It inserts the MoE arguments into the Llama config. - Args: config (LlamaConfig): Transformers Llama config. num_experts (int, optional): Number of experts. @@ -396,11 +395,7 @@ def forward( value_states = repeat_kv(value_states, self.num_key_value_groups) if HAS_FLASH_ATTN and self.use_kernel: - # from flash_attn import flash_attn_func - # If we use `from flash_attn import flash_attn_func` directly, - # AutoModelForCausalLM.from_pretrained will treat flash_attn as a compulsory dependency and raise error if it cannot be found. - # Here is a workaround to avoid the error. - exec("from flash_attn import flash_attn_func") + from flash_attn import flash_attn_func query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -551,11 +546,9 @@ def forward( This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) - This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. - Parameters: config ([`LlamaConfig`]): Model configuration class with all the parameters of the model. Initializing with a config file does not @@ -596,44 +589,33 @@ def _set_gradient_checkpointing(self, module, value=False): input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it. - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - [What are input IDs?](../glossary#input-ids) attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. - [What are attention masks?](../glossary#attention-mask) - Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. - If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). - If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy. - - 1 indicates the head is **not masked**, - 0 indicates the head is **masked**. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. - [What are position IDs?](../glossary#position-ids) past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. - Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`. @@ -662,7 +644,6 @@ def _set_gradient_checkpointing(self, module, value=False): class OpenMoeModel(OpenMoePreTrainedModel): """ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - Args: config: LlamaConfig """ @@ -897,20 +878,14 @@ def forward( Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: - Example: - ```python >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] @@ -1049,11 +1024,9 @@ def _calculate_router_loss(self, aux_loss: list = None, z_loss: list = None): def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: """Compute cross entropy and entropy for log probs and targets. - Args: logits: [batch, length, num_classes] float array. targets: categorical targets [batch, length] int array. - Returns: Tuple of scalar loss. """ @@ -1086,23 +1059,19 @@ def _calculate_loss(self, logits: torch.Tensor, targets: torch.Tensor) -> torch. class ZLossCrossEntropy(torch.autograd.Function): """Computes cross entropy loss with stable custom gradient. - Computes a stabilized-gradient version of: -jnp.sum(targets * nn.log_softmax(logits), axis=-1) - If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 will be added to the cross entropy loss (z = softmax normalization constant). The two uses of z_loss are: 1. To keep the logits from drifting too far from zero, which can cause unacceptable roundoff errors in bfloat16. 2. To encourage the logits to be normalized log-probabilities. - Args: logits: [batch, length, num_classes] float array. targets: categorical one-hot targets [batch, length, num_classes] float array. z_loss: coefficient for auxilliary z-loss loss term. - Returns: tuple with the total loss and the z_loss, both float arrays with shape [batch, length]. From 80eeb0602ba8c96a76be5e53ae04ed519bbb90cc Mon Sep 17 00:00:00 2001 From: "zheng_zian@u.nus.edu" Date: Sun, 14 Jan 2024 04:31:01 +0800 Subject: [PATCH 8/8] add cpu support to device.py --- colossalai/utils/device.py | 42 ++++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/colossalai/utils/device.py b/colossalai/utils/device.py index c70dbdaa5ee1..9792069fe3cb 100644 --- a/colossalai/utils/device.py +++ b/colossalai/utils/device.py @@ -45,12 +45,32 @@ def get_current_device() -> torch.device: def _dispatch_device_func(fn_name: str, *args, **kwargs): - if torch.cuda.is_available(): + if "device" in kwargs: # if device is specified, try to use the provided one + device = kwargs["device"] + del kwargs["device"] + if 'cuda' in device and torch.cuda.is_available(): + device = "cuda" + elif 'npu' in device and IS_NPU_AVAILABLE: + device = "npu" + else: + device = "cpu" + else: # if device is not specified, device will be automatically detected + if torch.cuda.is_available(): + device = "cuda" + elif IS_NPU_AVAILABLE: + device = "npu" + else: + device = "cpu" + + if device == "cuda": return getattr(torch.cuda, fn_name)(*args, **kwargs) - elif IS_NPU_AVAILABLE: + elif device == "npu": return getattr(torch.npu, fn_name)(*args, **kwargs) - else: - raise RuntimeError("No device available") + else: + try: + return getattr(torch, fn_name)(*args, **kwargs) + except AttributeError: + raise RuntimeError(f"Current device does not support the function: {fn_name}") # device semantics @@ -114,7 +134,12 @@ def utilization(device=None) -> int: def get_rng_state(device="cuda") -> torch.Tensor: - return _dispatch_device_func("get_rng_state", device) + if torch.cuda.is_available() and device=="cuda": + return _dispatch_device_func("get_rng_state", device="cuda") + elif IS_NPU_AVAILABLE and device=="npu": + return _dispatch_device_func("get_rng_state", device="npu") + else: + return _dispatch_device_func("get_rng_state", device="cpu") def get_rng_state_all() -> List[torch.Tensor]: @@ -122,7 +147,12 @@ def get_rng_state_all() -> List[torch.Tensor]: def set_rng_state(new_state: torch.ByteTensor, device="cuda") -> None: - return _dispatch_device_func("set_rng_state", new_state, device) + if torch.cuda.is_available() and device=="cuda": + return _dispatch_device_func("set_rng_state", new_state, device="cuda") + elif IS_NPU_AVAILABLE and device=="npu": + return _dispatch_device_func("set_rng_state", new_state, device="npu") + else: + return _dispatch_device_func("set_rng_state", new_state, device="cpu") def set_rng_state_all(new_states: List[torch.ByteTensor]) -> None: