refine code-generation example (#192)
* refine code-generation example

Signed-off-by: changwangss <chang1.wang@intel.com>

* remove code

Signed-off-by: changwangss <chang1.wang@intel.com>

* remove invalid code

* improve readme and line length

Signed-off-by: changwangss <chang1.wang@intel.com>

---------

Signed-off-by: changwangss <chang1.wang@intel.com>
Co-authored-by: Haihao Shen <haihao.shen@intel.com>
changwangss and hshen14 committed Aug 30, 2023
1 parent c2242e6 commit c569fd5
Showing 6 changed files with 73 additions and 22 deletions.
@@ -26,7 +26,7 @@ Required libraries.
```
pip install -r requirements.txt
```

-We use the local gpt_bigcode definition script `modeling_gpt_bigcode.py` in `run_generation.py`. A small change is needed for the trace to succeed.
+We use the gpt_bigcode definition script [modeling_gpt_bigcode.py](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/modeling/gpt_bigcode/modeling_gpt_bigcode.py) in `run_generation.py`. A small change is needed for the trace to succeed.
```diff
# Line 227 in modeling_gpt_bigcode.py on transformers 4.28.1
- query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
```
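For context, the tweak above matters because `torch.jit.trace` records one concrete execution of the forward pass, so the attention projection has to split in a trace-friendly way. A minimal sketch of the trace this enables, assuming a placeholder checkpoint and prompt (neither is taken from this example):

```python
# Minimal sketch, not part of this example; checkpoint and prompt are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder", torchscript=True)
model.eval()

example = tokenizer("def hello_world():", return_tensors="pt")
with torch.no_grad():
    # strict=False tolerates the dict/tuple outputs of the HF forward pass
    traced = torch.jit.trace(model, example["input_ids"], strict=False)
traced.save("traced_gpt_bigcode.pt")
```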
@@ -10,16 +10,13 @@
from datasets import load_dataset
from torch.nn.functional import pad
from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
+from transformers import AutoTokenizer, PretrainedConfig
import transformers
from optimum.utils import NormalizedConfigManager

import numpy as np
from itertools import chain

-from modeling_gpt_bigcode import GPTBigCodeForCausalLM
-transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTJForCausalLM = GPTBigCodeForCausalLM

parser = argparse.ArgumentParser()

# Main config
@@ -87,9 +84,23 @@
parser.add_argument("--top_p", default=0.95, type=float)
parser.add_argument("--top_k", default=0, type=int)
parser.add_argument("--do_sample", action="store_true")
parser.add_argument("--check_references", action="store_true")
parser.add_argument("--max_memory_per_gpu", type=str, default=None)
parser.add_argument(
"--modeltype",
default="causal",
help="AutoModel to use, it can be causal or seq2seq",
)
parser.add_argument(
"--limit_start",
type=int,
default=0,
help="Optional offset to start from when limiting the number of samples",
)
args = parser.parse_args()


+from intel_extension_for_transformers.transformers import AutoModelForCausalLM
user_model = AutoModelForCausalLM.from_pretrained(
    args.model,
    torchscript=True
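The truncated `from_pretrained(..., torchscript=True)` call above is the entry point the script builds on; a minimal sketch of driving such a model through the standard generation API (checkpoint, prompt, and generation arguments are placeholder assumptions):

```python
# Minimal sketch, not part of the diff; checkpoint and prompt are placeholders.
from transformers import AutoTokenizer
from intel_extension_for_transformers.transformers import AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bigcode/santacoder")
model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder", torchscript=True)

inputs = tokenizer("def fibonacci(n):", return_tensors="pt")
# Sampling settings mirror the script's argparse defaults (top_p=0.95, top_k=0).
outputs = model.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.95, top_k=0)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```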
@@ -20,7 +20,7 @@
```
cd intel-extension-for-pytorch
git submodule sync && git submodule update --init --recursive
python setup.py install
```
-We use the local GPTJ definition script `modeling_gptj.py` in `run_generation.py`. A small change is needed for the trace to succeed.
+We use the GPTJ definition script [modeling_gptj.py](https://github.com/intel/intel-extension-for-transformers/blob/main/intel_extension_for_transformers/transformers/modeling/gptj/modeling_gptj.py) in `run_generation.py`. A small change is needed for the trace to succeed.
```diff
# Line 602 in modeling_gptj.py on transformers 4.28.1

```
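The from-source IPEX build above typically feeds an `ipex.optimize` pass before tracing; a minimal sketch under assumed settings (the checkpoint and the bfloat16 choice are illustrative, not taken from this example):

```python
# Minimal sketch; the checkpoint and dtype are assumptions, not this example's settings.
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6b", torchscript=True)
model.eval()
# Fuses operators and prepares weights for efficient CPU inference.
model = ipex.optimize(model, dtype=torch.bfloat16)
```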
@@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

@@ -1,5 +1,23 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
# coding=utf-8
# Copyright 2023 The Bigcode team and HuggingFace Inc. team.
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
@@ -383,8 +401,9 @@ def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, (GPTBigCodeMLP, GPTBigCodeAttention)):
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
-# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
-# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+# > A modified initialization which accounts for the accumulation on the residual path with model depth.
+# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the
+# > # of residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
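As an aside, the scheme the reflowed comment describes can be sketched concretely; this is illustrative of the GPT-2-style convention (two residual additions per block), not this file's code:

```python
import math
import torch.nn as nn

# Illustrative sketch of the 1/sqrt(N) residual-path initialization described above.
def init_residual_projection(weight: nn.Parameter, initializer_range: float, n_layer: int) -> None:
    # Each transformer block writes to the residual stream twice (attention + MLP),
    # so N = 2 * n_layer residual layers in this convention.
    std = initializer_range / math.sqrt(2 * n_layer)
    nn.init.normal_(weight, mean=0.0, std=std)
```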
@@ -406,7 +425,8 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

-# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with GPT2->GPTBigCode
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._set_gradient_checkpointing with
+# GPT2->GPTBigCode
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPTBigCodeModel):
module.gradient_checkpointing = value
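For reference, `_set_gradient_checkpointing` is the hook reached by the standard transformers toggles; a minimal usage sketch (the checkpoint name is a placeholder):

```python
from transformers import AutoModelForCausalLM

# Minimal sketch: the public toggles below end up calling
# _set_gradient_checkpointing(module, value) on each eligible submodule.
model = AutoModelForCausalLM.from_pretrained("bigcode/santacoder")
model.gradient_checkpointing_enable()   # trade recompute for activation memory in training
model.gradient_checkpointing_disable()  # restore the normal forward pass
```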
@@ -477,15 +497,15 @@ def _set_gradient_checkpointing(self, module, value=False):
- 0 indicates the head is **masked**.
inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-model's internal embedding lookup matrix.
+Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+than the model's internal embedding lookup matrix.
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
`past_key_values`).
use_cache (`bool`, *optional*):
-If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-`past_key_values`).
+If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+(see `past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
@@ -717,8 +737,8 @@ def custom_forward(*inputs):

@add_start_docstrings(
"""
-The GPT_BIGCODE Model transformer with a language modeling head on top (linear layer with weights tied to the input
-embeddings).
+The GPT_BIGCODE Model transformer with a language modeling head on top
+(linear layer with weights tied to the input embeddings).
""",
GPT_BIGCODE_START_DOCSTRING,
)
@@ -868,10 +888,10 @@ def _reorder_cache(
models (e.g. GPT-1) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
-`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
-no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
-padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
-each row of the batch).
+`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row.
+If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess
+the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value
+in each row of the batch).
""",
GPT_BIGCODE_START_DOCSTRING,
)
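The last-token selection this docstring describes can be sketched as follows; shapes and names are illustrative assumptions, not this file's implementation:

```python
import torch

# Sketch of last-token pooling: pick the last non-padding position per row,
# or simply the final position when no pad_token_id is defined.
def pool_last_token(logits: torch.Tensor, input_ids: torch.Tensor, pad_token_id=None) -> torch.Tensor:
    # logits: (batch, seq_len, num_labels); input_ids: (batch, seq_len)
    if pad_token_id is None:
        return logits[:, -1]
    last = torch.ne(input_ids, pad_token_id).sum(dim=-1) - 1  # index of the last real token
    return logits[torch.arange(logits.size(0)), last]
```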
@@ -46,6 +46,9 @@
from .bloom.modeling_bloom import BloomForCausalLM
from .gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
from .opt.modeling_opt import OPTForCausalLM
+from .gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM
+# to use the modeling modifications based on transformers 4.30.2:
+transformers.models.gpt_bigcode.modeling_gpt_bigcode.GPTJForCausalLM = GPTBigCodeForCausalLM
# to use the modeling modifications based on transformers 4.28.1:
transformers.models.gptj.modeling_gptj.GPTJForCausalLM = GPTJForCausalLM
transformers.models.llama.modeling_llama.LlamaForCausalLM = LlamaForCausalLM
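The assignments above are plain attribute-level monkey patching: rebinding a class name on the upstream module so that later lookups through `transformers.models.*` resolve to the modified implementation. A minimal standalone sketch of the idea (toy names, not this package's classes):

```python
# Standalone sketch of attribute-level monkey patching (toy names).
import types

upstream = types.ModuleType("upstream")

class Original:
    def greet(self) -> str:
        return "original"

class Patched(Original):
    def greet(self) -> str:
        return "patched"

upstream.Original = Original   # what the library originally exposes
upstream.Original = Patched    # rebind; later lookups see the patch
print(upstream.Original().greet())  # -> "patched"
```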
@@ -118,4 +121,4 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
)

class AutoModelForCausalLM(_BaseAutoModelClass):
-    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
+    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
