From 7952a2804abc43e3ce170c175ac697bc79d1005a Mon Sep 17 00:00:00 2001 From: llermaly Date: Mon, 30 Sep 2024 02:54:20 -0500 Subject: [PATCH 1/6] added notebook --- .../using-openelm-models/OpenELM/LICENSE | 47 + .../using-openelm-models/OpenELM/README.md | 216 ++++ .../OpenELM/generate_openelm.py | 239 +++++ .../using-openelm-models/OpenELM/modelfile | 1 + .../using-openelm-models.ipynb | 960 ++++++++++++++++++ 5 files changed, 1463 insertions(+) create mode 100644 supporting-blog-content/using-openelm-models/OpenELM/LICENSE create mode 100644 supporting-blog-content/using-openelm-models/OpenELM/README.md create mode 100644 supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py create mode 100644 supporting-blog-content/using-openelm-models/OpenELM/modelfile create mode 100644 supporting-blog-content/using-openelm-models/using-openelm-models.ipynb diff --git a/supporting-blog-content/using-openelm-models/OpenELM/LICENSE b/supporting-blog-content/using-openelm-models/OpenELM/LICENSE new file mode 100644 index 00000000..02fa0ad0 --- /dev/null +++ b/supporting-blog-content/using-openelm-models/OpenELM/LICENSE @@ -0,0 +1,47 @@ +Copyright (C) 2024 Apple Inc. All Rights Reserved. + +Disclaimer: IMPORTANT: This Apple software is supplied to you by Apple +Inc. ("Apple") in consideration of your agreement to the following +terms, and your use, installation, modification or redistribution of +this Apple software constitutes acceptance of these terms. If you do +not agree with these terms, please do not use, install, modify or +redistribute this Apple software. + +In consideration of your agreement to abide by the following terms, and +subject to these terms, Apple grants you a personal, non-exclusive +license, under Apple's copyrights in this original Apple software (the +"Apple Software"), to use, reproduce, modify and redistribute the Apple +Software, with or without modifications, in source and/or binary forms; +provided that if you redistribute the Apple Software in its entirety and +without modifications, you must retain this notice and the following +text and disclaimers in all such redistributions of the Apple Software. +Neither the name, trademarks, service marks or logos of Apple Inc. may +be used to endorse or promote products derived from the Apple Software +without specific prior written permission from Apple. Except as +expressly stated in this notice, no other rights or licenses, express or +implied, are granted by Apple herein, including but not limited to any +patent rights that may be infringed by your derivative works or by other +works in which the Apple Software may be incorporated. + +The Apple Software is provided by Apple on an "AS IS" basis. APPLE +MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION +THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS +FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND +OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. + +IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL +OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, +MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED +AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), +STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. + + +------------------------------------------------------------------------------- +SOFTWARE DISTRIBUTED IN THIS REPOSITORY: + +This software includes a number of subcomponents with separate +copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. +------------------------------------------------------------------------------- diff --git a/supporting-blog-content/using-openelm-models/OpenELM/README.md b/supporting-blog-content/using-openelm-models/OpenELM/README.md new file mode 100644 index 00000000..26a7f198 --- /dev/null +++ b/supporting-blog-content/using-openelm-models/OpenELM/README.md @@ -0,0 +1,216 @@ +--- +license: other +license_name: apple-sample-code-license +license_link: LICENSE +--- + +# OpenELM: An Efficient Language Model Family with Open Training and Inference Framework + +*Sachin Mehta, Mohammad Hossein Sekhavat, Qingqing Cao, Maxwell Horton, Yanzi Jin, Chenfan Sun, Iman Mirzadeh, Mahyar Najibi, Dmitry Belenko, Peter Zatloukal, Mohammad Rastegari* + +We introduce **OpenELM**, a family of **Open** **E**fficient **L**anguage **M**odels. OpenELM uses a layer-wise scaling strategy to efficiently allocate parameters within each layer of the transformer model, leading to enhanced accuracy. We pretrained OpenELM models using the [CoreNet](https://github.com/apple/corenet) library. We release both pretrained and instruction tuned models with 270M, 450M, 1.1B and 3B parameters. + +Our pre-training dataset contains RefinedWeb, deduplicated PILE, a subset of RedPajama, and a subset of Dolma v1.6, totaling approximately 1.8 trillion tokens. Please check license agreements and terms of these datasets before using them. + +See the list below for the details of each model: + +- [OpenELM-270M](https://huggingface.co/apple/OpenELM-270M) +- [OpenELM-450M](https://huggingface.co/apple/OpenELM-450M) +- [OpenELM-1_1B](https://huggingface.co/apple/OpenELM-1_1B) +- [OpenELM-3B](https://huggingface.co/apple/OpenELM-3B) +- [OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) +- [OpenELM-450M-Instruct](https://huggingface.co/apple/OpenELM-450M-Instruct) +- [OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) +- [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) + + +```python + +from transformers import AutoModelForCausalLM + +openelm_270m = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M", trust_remote_code=True) +openelm_450m = AutoModelForCausalLM.from_pretrained("apple/OpenELM-450M", trust_remote_code=True) +openelm_1b = AutoModelForCausalLM.from_pretrained("apple/OpenELM-1_1B", trust_remote_code=True) +openelm_3b = AutoModelForCausalLM.from_pretrained("apple/OpenELM-3B", trust_remote_code=True) + +openelm_270m_instruct = AutoModelForCausalLM.from_pretrained("apple/OpenELM-270M-Instruct", trust_remote_code=True) +openelm_450m_instruct = AutoModelForCausalLM.from_pretrained("apple/OpenELM-450M-Instruct", trust_remote_code=True) +openelm_1b_instruct = AutoModelForCausalLM.from_pretrained("apple/OpenELM-1_1B-Instruct", trust_remote_code=True) +openelm_3b_instruct = AutoModelForCausalLM.from_pretrained("apple/OpenELM-3B-Instruct", trust_remote_code=True) + +``` + +## Usage + +We have provided an example function to generate output from OpenELM models loaded via [HuggingFace Hub](https://huggingface.co/docs/hub/) in `generate_openelm.py`. + +You can try the model by running the following command: +``` +python generate_openelm.py --model [MODEL_NAME] --hf_access_token [HF_ACCESS_TOKEN] --prompt 'Once upon a time there was' --generate_kwargs repetition_penalty=1.2 +``` +Please refer to [this link](https://huggingface.co/docs/hub/security-tokens) to obtain your hugging face access token. + +Additional arguments to the hugging face generate function can be passed via `generate_kwargs`. As an example, to speedup the inference, you can try [lookup token speculative generation](https://huggingface.co/docs/transformers/generation_strategies) by passing the `prompt_lookup_num_tokens` argument as follows: +``` +python generate_openelm.py --model [MODEL_NAME] --hf_access_token [HF_ACCESS_TOKEN] --prompt 'Once upon a time there was' --generate_kwargs repetition_penalty=1.2 prompt_lookup_num_tokens=10 +``` +Alternatively, try model-wise speculative generation with an [assistive model](https://huggingface.co/blog/assisted-generation) by passing a smaller model through the `assistant_model` argument, for example: +``` +python generate_openelm.py --model [MODEL_NAME] --hf_access_token [HF_ACCESS_TOKEN] --prompt 'Once upon a time there was' --generate_kwargs repetition_penalty=1.2 --assistant_model [SMALLER_MODEL_NAME] +``` + + +## Main Results + +### Zero-Shot + +| **Model Size** | **ARC-c** | **ARC-e** | **BoolQ** | **HellaSwag** | **PIQA** | **SciQ** | **WinoGrande** | **Average** | +|-----------------------------------------------------------------------------|-----------|-----------|-----------|---------------|-----------|-----------|----------------|-------------| +| [OpenELM-270M](https://huggingface.co/apple/OpenELM-270M) | 26.45 | 45.08 | **53.98** | 46.71 | 69.75 | **84.70** | **53.91** | 54.37 | +| [OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) | **30.55** | **46.68** | 48.56 | **52.07** | **70.78** | 84.40 | 52.72 | **55.11** | +| [OpenELM-450M](https://huggingface.co/apple/OpenELM-450M) | 27.56 | 48.06 | 55.78 | 53.97 | 72.31 | 87.20 | 58.01 | 57.56 | +| [OpenELM-450M-Instruct](https://huggingface.co/apple/OpenELM-450M-Instruct) | **30.38** | **50.00** | **60.37** | **59.34** | **72.63** | **88.00** | **58.96** | **59.95** | +| [OpenELM-1_1B](https://huggingface.co/apple/OpenELM-1_1B) | 32.34 | **55.43** | 63.58 | 64.81 | **75.57** | **90.60** | 61.72 | 63.44 | +| [OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) | **37.97** | 52.23 | **70.00** | **71.20** | 75.03 | 89.30 | **62.75** | **65.50** | +| [OpenELM-3B](https://huggingface.co/apple/OpenELM-3B) | 35.58 | 59.89 | 67.40 | 72.44 | 78.24 | **92.70** | 65.51 | 67.39 | +| [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) | **39.42** | **61.74** | **68.17** | **76.36** | **79.00** | 92.50 | **66.85** | **69.15** | + +### LLM360 + +| **Model Size** | **ARC-c** | **HellaSwag** | **MMLU** | **TruthfulQA** | **WinoGrande** | **Average** | +|-----------------------------------------------------------------------------|-----------|---------------|-----------|----------------|----------------|-------------| +| [OpenELM-270M](https://huggingface.co/apple/OpenELM-270M) | 27.65 | 47.15 | 25.72 | **39.24** | **53.83** | 38.72 | +| [OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) | **32.51** | **51.58** | **26.70** | 38.72 | 53.20 | **40.54** | +| [OpenELM-450M](https://huggingface.co/apple/OpenELM-450M) | 30.20 | 53.86 | **26.01** | 40.18 | 57.22 | 41.50 | +| [OpenELM-450M-Instruct](https://huggingface.co/apple/OpenELM-450M-Instruct) | **33.53** | **59.31** | 25.41 | **40.48** | **58.33** | **43.41** | +| [OpenELM-1_1B](https://huggingface.co/apple/OpenELM-1_1B) | 36.69 | 65.71 | **27.05** | 36.98 | 63.22 | 45.93 | +| [OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) | **41.55** | **71.83** | 25.65 | **45.95** | **64.72** | **49.94** | +| [OpenELM-3B](https://huggingface.co/apple/OpenELM-3B) | 42.24 | 73.28 | **26.76** | 34.98 | 67.25 | 48.90 | +| [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) | **47.70** | **76.87** | 24.80 | **38.76** | **67.96** | **51.22** | + + +### OpenLLM Leaderboard + +| **Model Size** | **ARC-c** | **CrowS-Pairs** | **HellaSwag** | **MMLU** | **PIQA** | **RACE** | **TruthfulQA** | **WinoGrande** | **Average** | +|-----------------------------------------------------------------------------|-----------|-----------------|---------------|-----------|-----------|-----------|----------------|----------------|-------------| +| [OpenELM-270M](https://huggingface.co/apple/OpenELM-270M) | 27.65 | **66.79** | 47.15 | 25.72 | 69.75 | 30.91 | **39.24** | **53.83** | 45.13 | +| [OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) | **32.51** | 66.01 | **51.58** | **26.70** | **70.78** | 33.78 | 38.72 | 53.20 | **46.66** | +| [OpenELM-450M](https://huggingface.co/apple/OpenELM-450M) | 30.20 | **68.63** | 53.86 | **26.01** | 72.31 | 33.11 | 40.18 | 57.22 | 47.69 | +| [OpenELM-450M-Instruct](https://huggingface.co/apple/OpenELM-450M-Instruct) | **33.53** | 67.44 | **59.31** | 25.41 | **72.63** | **36.84** | **40.48** | **58.33** | **49.25** | +| [OpenELM-1_1B](https://huggingface.co/apple/OpenELM-1_1B) | 36.69 | **71.74** | 65.71 | **27.05** | **75.57** | 36.46 | 36.98 | 63.22 | 51.68 | +| [OpenELM-1_1B-Instruct](https://huggingface.co/apple/OpenELM-1_1B-Instruct) | **41.55** | 71.02 | **71.83** | 25.65 | 75.03 | **39.43** | **45.95** | **64.72** | **54.40** | +| [OpenELM-3B](https://huggingface.co/apple/OpenELM-3B) | 42.24 | **73.29** | 73.28 | **26.76** | 78.24 | **38.76** | 34.98 | 67.25 | 54.35 | +| [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) | **47.70** | 72.33 | **76.87** | 24.80 | **79.00** | 38.47 | **38.76** | **67.96** | **55.73** | + +See the technical report for more results and comparison. + +## Evaluation + +### Setup + +Install the following dependencies: + +```bash + +# install public lm-eval-harness + +harness_repo="public-lm-eval-harness" +git clone https://github.com/EleutherAI/lm-evaluation-harness ${harness_repo} +cd ${harness_repo} +# use main branch on 03-15-2024, SHA is dc90fec +git checkout dc90fec +pip install -e . +cd .. + +# 66d6242 is the main branch on 2024-04-01 +pip install datasets@git+https://github.com/huggingface/datasets.git@66d6242 +pip install tokenizers>=0.15.2 transformers>=4.38.2 sentencepiece>=0.2.0 + +``` + +### Evaluate OpenELM + +```bash + +# OpenELM-270M +hf_model=apple/OpenELM-270M + +# this flag is needed because lm-eval-harness set add_bos_token to False by default, but OpenELM uses LLaMA tokenizer which requires add_bos_token to be True +tokenizer=meta-llama/Llama-2-7b-hf +add_bos_token=True +batch_size=1 + +mkdir lm_eval_output + +shot=0 +task=arc_challenge,arc_easy,boolq,hellaswag,piqa,race,winogrande,sciq,truthfulqa_mc2 +lm_eval --model hf \ + --model_args pretrained=${hf_model},trust_remote_code=True,add_bos_token=${add_bos_token},tokenizer=${tokenizer} \ + --tasks ${task} \ + --device cuda:0 \ + --num_fewshot ${shot} \ + --output_path ./lm_eval_output/${hf_model//\//_}_${task//,/_}-${shot}shot \ + --batch_size ${batch_size} 2>&1 | tee ./lm_eval_output/eval-${hf_model//\//_}_${task//,/_}-${shot}shot.log + +shot=5 +task=mmlu,winogrande +lm_eval --model hf \ + --model_args pretrained=${hf_model},trust_remote_code=True,add_bos_token=${add_bos_token},tokenizer=${tokenizer} \ + --tasks ${task} \ + --device cuda:0 \ + --num_fewshot ${shot} \ + --output_path ./lm_eval_output/${hf_model//\//_}_${task//,/_}-${shot}shot \ + --batch_size ${batch_size} 2>&1 | tee ./lm_eval_output/eval-${hf_model//\//_}_${task//,/_}-${shot}shot.log + +shot=25 +task=arc_challenge,crows_pairs_english +lm_eval --model hf \ + --model_args pretrained=${hf_model},trust_remote_code=True,add_bos_token=${add_bos_token},tokenizer=${tokenizer} \ + --tasks ${task} \ + --device cuda:0 \ + --num_fewshot ${shot} \ + --output_path ./lm_eval_output/${hf_model//\//_}_${task//,/_}-${shot}shot \ + --batch_size ${batch_size} 2>&1 | tee ./lm_eval_output/eval-${hf_model//\//_}_${task//,/_}-${shot}shot.log + +shot=10 +task=hellaswag +lm_eval --model hf \ + --model_args pretrained=${hf_model},trust_remote_code=True,add_bos_token=${add_bos_token},tokenizer=${tokenizer} \ + --tasks ${task} \ + --device cuda:0 \ + --num_fewshot ${shot} \ + --output_path ./lm_eval_output/${hf_model//\//_}_${task//,/_}-${shot}shot \ + --batch_size ${batch_size} 2>&1 | tee ./lm_eval_output/eval-${hf_model//\//_}_${task//,/_}-${shot}shot.log + +``` + + +## Bias, Risks, and Limitations + +The release of OpenELM models aims to empower and enrich the open research community by providing access to state-of-the-art language models. Trained on publicly available datasets, these models are made available without any safety guarantees. Consequently, there exists the possibility of these models producing outputs that are inaccurate, harmful, biased, or objectionable in response to user prompts. Thus, it is imperative for users and developers to undertake thorough safety testing and implement appropriate filtering mechanisms tailored to their specific requirements. + +## Citation + +If you find our work useful, please cite: + +```BibTex +@article{mehtaOpenELMEfficientLanguage2024, + title = {{OpenELM}: {An} {Efficient} {Language} {Model} {Family} with {Open} {Training} and {Inference} {Framework}}, + shorttitle = {{OpenELM}}, + url = {https://arxiv.org/abs/2404.14619v1}, + language = {en}, + urldate = {2024-04-24}, + journal = {arXiv.org}, + author = {Mehta, Sachin and Sekhavat, Mohammad Hossein and Cao, Qingqing and Horton, Maxwell and Jin, Yanzi and Sun, Chenfan and Mirzadeh, Iman and Najibi, Mahyar and Belenko, Dmitry and Zatloukal, Peter and Rastegari, Mohammad}, + month = apr, + year = {2024}, +} + +@inproceedings{mehta2022cvnets, + author = {Mehta, Sachin and Abdolhosseini, Farzad and Rastegari, Mohammad}, + title = {CVNets: High Performance Library for Computer Vision}, + year = {2022}, + booktitle = {Proceedings of the 30th ACM International Conference on Multimedia}, + series = {MM '22} +} +``` diff --git a/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py b/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py new file mode 100644 index 00000000..12b167e2 --- /dev/null +++ b/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py @@ -0,0 +1,239 @@ +# +# For licensing see accompanying LICENSE file. +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +# + +"""Module to generate OpenELM output given a model and an input prompt.""" +import os +import logging +import time +import argparse +from typing import Optional, Union +import torch + +from transformers import AutoTokenizer, AutoModelForCausalLM + +def generate( + prompt: str, + model: Union[str, AutoModelForCausalLM], + hf_access_token: str = None, + tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf', + device: Optional[str] = None, + max_length: int = 1024, + assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None, + generate_kwargs: Optional[dict] = None, +) -> str: + """ Generates output given a prompt. + + Args: + prompt: The string prompt. + model: The LLM Model. If a string is passed, it should be the path to + the hf converted checkpoint. + hf_access_token: Hugging face access token. + tokenizer: Tokenizer instance. If model is set as a string path, + the tokenizer will be loaded from the checkpoint. + device: String representation of device to run the model on. If None + and cuda available it would be set to cuda:0 else cpu. + max_length: Maximum length of tokens, input prompt + generated tokens. + assistant_model: If set, this model will be used for + speculative generation. If a string is passed, it should be the + path to the hf converted checkpoint. + generate_kwargs: Extra kwargs passed to the hf generate function. + + Returns: + output_text: output generated as a string. + generation_time: generation time in seconds. + + Raises: + ValueError: If device is set to CUDA but no CUDA device is detected. + ValueError: If tokenizer is not set. + ValueError: If hf_access_token is not specified. + """ + if not device: + if torch.cuda.is_available() and torch.cuda.device_count(): + device = "cuda:0" + logging.warning( + 'inference device is not set, using cuda:0, %s', + torch.cuda.get_device_name(0) + ) + else: + device = 'cpu' + logging.warning( + ( + 'No CUDA device detected, using cpu, ' + 'expect slower speeds.' + ) + ) + + if 'cuda' in device and not torch.cuda.is_available(): + raise ValueError('CUDA device requested but no CUDA device detected.') + + if not tokenizer: + raise ValueError('Tokenizer is not set in the generate function.') + + if not hf_access_token: + raise ValueError(( + 'Hugging face access token needs to be specified. ' + 'Please refer to https://huggingface.co/docs/hub/security-tokens' + ' to obtain one.' + ) + ) + + if isinstance(model, str): + checkpoint_path = model + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=True + ) + model.to(device).eval() + if isinstance(tokenizer, str): + tokenizer = AutoTokenizer.from_pretrained( + tokenizer, + token=hf_access_token, + ) + + # Speculative mode + draft_model = None + if assistant_model: + draft_model = assistant_model + if isinstance(assistant_model, str): + draft_model = AutoModelForCausalLM.from_pretrained( + assistant_model, + trust_remote_code=True + ) + draft_model.to(device).eval() + + # Prepare the prompt + tokenized_prompt = tokenizer(prompt) + tokenized_prompt = torch.tensor( + tokenized_prompt['input_ids'], + device=device + ) + + tokenized_prompt = tokenized_prompt.unsqueeze(0) + + # Generate + stime = time.time() + output_ids = model.generate( + tokenized_prompt, + max_length=max_length, + pad_token_id=0, + assistant_model=draft_model, + **(generate_kwargs if generate_kwargs else {}), + ) + generation_time = time.time() - stime + + output_text = tokenizer.decode( + output_ids[0].tolist(), + skip_special_tokens=True + ) + + return output_text, generation_time + + +def openelm_generate_parser(): + """Argument Parser""" + + class KwargsParser(argparse.Action): + """Parser action class to parse kwargs of form key=value""" + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for val in values: + if '=' not in val: + raise ValueError( + ( + 'Argument parsing error, kwargs are expected in' + ' the form of key=value.' + ) + ) + kwarg_k, kwarg_v = val.split('=') + try: + converted_v = int(kwarg_v) + except ValueError: + try: + converted_v = float(kwarg_v) + except ValueError: + converted_v = kwarg_v + getattr(namespace, self.dest)[kwarg_k] = converted_v + + parser = argparse.ArgumentParser('OpenELM Generate Module') + parser.add_argument( + '--model', + dest='model', + help='Path to the hf converted model.', + required=True, + type=str, + ) + parser.add_argument( + '--hf_access_token', + dest='hf_access_token', + help='Hugging face access token, starting with "hf_".', + type=str, + ) + parser.add_argument( + '--prompt', + dest='prompt', + help='Prompt for LLM call.', + default='', + type=str, + ) + parser.add_argument( + '--device', + dest='device', + help='Device used for inference.', + type=str, + ) + parser.add_argument( + '--max_length', + dest='max_length', + help='Maximum length of tokens.', + default=256, + type=int, + ) + parser.add_argument( + '--assistant_model', + dest='assistant_model', + help=( + ( + 'If set, this is used as a draft model ' + 'for assisted speculative generation.' + ) + ), + type=str, + ) + parser.add_argument( + '--generate_kwargs', + dest='generate_kwargs', + help='Additional kwargs passed to the HF generate function.', + type=str, + nargs='*', + action=KwargsParser, + ) + return parser.parse_args() + + +if __name__ == '__main__': + args = openelm_generate_parser() + prompt = args.prompt + + output_text, genertaion_time = generate( + prompt=prompt, + model=args.model, + device=args.device, + max_length=args.max_length, + assistant_model=args.assistant_model, + generate_kwargs=args.generate_kwargs, + hf_access_token=args.hf_access_token, + ) + + print_txt = ( + f'\r\n{"=" * os.get_terminal_size().columns}\r\n' + '\033[1m Prompt + Generated Output\033[0m\r\n' + f'{"-" * os.get_terminal_size().columns}\r\n' + f'{output_text}\r\n' + f'{"-" * os.get_terminal_size().columns}\r\n' + '\r\nGeneration took' + f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m' + 'seconds.\r\n' + ) + print(print_txt) diff --git a/supporting-blog-content/using-openelm-models/OpenELM/modelfile b/supporting-blog-content/using-openelm-models/OpenELM/modelfile new file mode 100644 index 00000000..6d556901 --- /dev/null +++ b/supporting-blog-content/using-openelm-models/OpenELM/modelfile @@ -0,0 +1 @@ +FROM tomasmcm/openelm:3b-intruct-q5_K_M diff --git a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb new file mode 100644 index 00000000..aab785e8 --- /dev/null +++ b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb @@ -0,0 +1,960 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Y2RgpN7yLf9J" + }, + "source": [ + "# Using Elastic and OpenELM to Prototype Apple-Inspired AI\n", + "\n", + "This is the supporting material for [this blog post.](https://search-labs.elastic.co/search-labs/blog/using-openelm-models)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4y68v93nNXTH", + "outputId": "a2d6f13d-c0ec-40af-f623-a043fe3654e6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting elasticsearch\n", + " Downloading elasticsearch-8.15.1-py3-none-any.whl.metadata (8.7 kB)\n", + "Collecting elastic-transport<9,>=8.13 (from elasticsearch)\n", + " Downloading elastic_transport-8.15.0-py3-none-any.whl.metadata (3.6 kB)\n", + "Requirement already satisfied: urllib3<3,>=1.26.2 in /usr/local/lib/python3.10/dist-packages (from elastic-transport<9,>=8.13->elasticsearch) (2.2.3)\n", + "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from elastic-transport<9,>=8.13->elasticsearch) (2024.8.30)\n", + "Downloading elasticsearch-8.15.1-py3-none-any.whl (524 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m524.6/524.6 kB\u001b[0m \u001b[31m31.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading elastic_transport-8.15.0-py3-none-any.whl (64 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.4/64.4 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: elastic-transport, elasticsearch\n", + "Successfully installed elastic-transport-8.15.0 elasticsearch-8.15.1\n" + ] + } + ], + "source": [ + "%pip install elasticsearch" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "id": "IBWlj7OJLhmN" + }, + "outputs": [], + "source": [ + "from elasticsearch import Elasticsearch, helpers, exceptions, ConnectionTimeout\n", + "from getpass import getpass" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GEoWlRH2Ma9d", + "outputId": "42a9fcf1-5f4c-4bdb-8fd0-7e68bd24dc07" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Elastic Cloud ID: ··········\n", + "Elastic Api Key: ··········\n", + "Huggingface Token: ··········\n" + ] + } + ], + "source": [ + "# https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#finding-your-cloud-id\n", + "ELASTIC_CLOUD_ID = getpass(\"Elastic Cloud ID: \")\n", + "\n", + "# https://www.elastic.co/search-labs/tutorials/install-elasticsearch/elastic-cloud#creating-an-api-key\n", + "ELASTIC_API_KEY = getpass(\"Elastic Api Key: \")\n", + "\n", + "# https://huggingface.co/docs/hub/en/security-tokens\n", + "HUGGINGFACE_TOKEN = getpass(\"Huggingface Token: \")\n", + "\n", + "# https://huggingface.co/apple/OpenELM\n", + "MODEL = \"apple/OpenELM-3B-Instruct\"\n", + "\n", + "# Create the client instance\n", + "client = Elasticsearch(\n", + " # For local development\n", + " # hosts=[\"http://localhost:9200\"]\n", + " cloud_id=ELASTIC_CLOUD_ID,\n", + " api_key=ELASTIC_API_KEY,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vs0LMPgwLoFW" + }, + "source": [ + "## 2. Deploy the OpenELM Model\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "5bWSYVrOLo9D", + "outputId": "84460f5e-39ab-4cb2-ef52-351aa461cfa6" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'OpenELM'...\n", + "remote: Enumerating objects: 12, done.\u001b[K\n", + "remote: Counting objects: 100% (11/11), done.\u001b[K\n", + "remote: Compressing objects: 100% (11/11), done.\u001b[K\n", + "remote: Total 12 (delta 4), reused 0 (delta 0), pack-reused 1 (from 1)\u001b[K\n", + "Unpacking objects: 100% (12/12), 8.28 KiB | 2.07 MiB/s, done.\n" + ] + } + ], + "source": [ + "!git clone https://huggingface.co/apple/OpenELM" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "CDHtF8RaO62t" + }, + "outputs": [], + "source": [ + "prompt = \"Once upon a time there was\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fMPhVXXoPsXY" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GqwDnJSSLth8", + "outputId": "02c73b2b-1ea6-4371-a3e6-992e5f89b8a2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING:root:inference device is not set, using cuda:0, NVIDIA A100-SXM4-40GB\n", + "Loading checkpoint shards: 100% 2/2 [00:01<00:00, 1.25it/s]\n", + "2024-09-30 04:46:39.147179: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2024-09-30 04:46:39.163451: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-09-30 04:46:39.181587: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-09-30 04:46:39.186881: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2024-09-30 04:46:39.200600: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-09-30 04:46:40.313149: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n", + "\n", + "\n", + "\u001b[1m Prompt + Generated Output\u001b[0m\n", + "\n", + "Once upon a time there was a little girl named Rosie. Rosie loved to play dress-up, and her favorite costume was a princess dress. Her mother dressed Rosie in the princess dress every day, but Rosie wanted to wear it on special occasions, too.\n", + "\n", + "One day Rosie's mother told Rosie that Prince Charming would arrive at their house later that evening. Rosie couldn't wait! Prince Charming would surely ask Rosie to marry him, and she would wear her beautiful princess dress for their wedding. Rosie ran upstairs to get dressed.\n", + "\n", + "When Prince Charming arrived, Rosie wore her princess dress and tiara. Her mother helped her with her makeup, and Rosie practiced her curtsy. Prince Charming smiled and kissed Rosie's hand. Rosie hoped Prince Charming would ask her to marry him, but Prince Charming shook Rosie's hand instead. Prince Charming explained that Rosie's father had died when Rosie was very young, and Prince Charming wanted Rosie to choose her own husband. Rosie felt sad, but Prince Charming promised her\n", + "\n", + "\n", + "Generation took\u001b[1m\u001b[92m 9.71 \u001b[0mseconds.\n", + "\n" + ] + } + ], + "source": [ + "!python /content/OpenELM/generate_openelm.py --model '{MODEL}' --hf_access_token '{HUGGINGFACE_TOKEN}' --prompt '{prompt}' --generate_kwargs repetition_penalty=1.2 prompt_lookup_num_tokens=10" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iF4n3etCLwVs" + }, + "source": [ + "## 3. Index Data in Elasticsearch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 391 + }, + "id": "S71mSikTQWWT", + "outputId": "8d468e94-2e07-4271-f263-5cfe7b717aea" + }, + "outputs": [], + "source": [ + "try:\n", + " # client.options(request_timeout=5).inference.delete(inference_id=\"my-elser-model\")\n", + " client.options(request_timeout=5).inference.put(\n", + " task_type=\"sparse_embedding\",\n", + " inference_id=\"my-elser-model\",\n", + " body={\"service\": \"elser\", \"service_settings\": {\n", + " \"num_allocations\": 1, \"num_threads\": 1}},\n", + " )\n", + "except ConnectionTimeout:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bs8GUJb2L0vQ", + "outputId": "8394a345-4d40-4021-b8ed-af96497c0bfc" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'mobile-assistant'})" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Create the index\n", + "index_name = \"mobile-assistant\"\n", + "client.indices.delete(index=index_name, ignore_unavailable=True)\n", + "index_body = {\n", + " \"mappings\": {\n", + " \"properties\": {\n", + " \"title\": {\n", + " \"type\": \"text\",\n", + " \"analyzer\": \"english\"\n", + " },\n", + " \"description\": {\n", + " \"type\": \"text\",\n", + " \"analyzer\": \"english\",\n", + " \"copy_to\": \"semantic_field\"\n", + " },\n", + " \"semantic_field\": {\n", + " \"type\": \"semantic_text\",\n", + " \"inference_id\": \"my-elser-model\"\n", + " }\n", + " }\n", + " }\n", + "}\n", + "\n", + "client.indices.create(index=index_name, body=index_body)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "id": "UTsn9uYxL3oN" + }, + "outputs": [], + "source": [ + "documents = [\n", + " {\n", + " \"_index\": index_name,\n", + " \"_id\": \"email1\",\n", + " \"title\": \"Team Meeting Agenda\",\n", + " \"description\": \"Hello team, Let's discuss our project progress in tomorrow's meeting. Please prepare your updates. Best regards, Manager\"\n", + " },\n", + " {\n", + " \"_index\": index_name,\n", + " \"_id\": \"email2\",\n", + " \"title\": \"Client Proposal Draft\",\n", + " \"description\": \"Hi, I've attached the draft of our client proposal. Could you review it and provide feedback? Thanks, Colleague\"\n", + " },\n", + " {\n", + " \"_index\": index_name,\n", + " \"_id\": \"email3\",\n", + " \"title\": \"Weekly Newsletter\",\n", + " \"description\": \"This week in tech: AI advancements, new smartphone releases, and cybersecurity updates. Read more on our website!\"\n", + " },\n", + " {\n", + " \"_index\": index_name,\n", + " \"_id\": \"email4\",\n", + " \"title\": \"Urgent: Project Deadline Update\",\n", + " \"description\": \"Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\"\n", + " },\n", + " {\n", + " \"_index\": index_name,\n", + " \"_id\": \"email5\",\n", + " \"title\": \"Invitation: Company Summer Picnic\",\n", + " \"description\": \"Hello everyone, We're excited to announce our annual company summer picnic! It will be held on Saturday, July 15th, at Sunny Park. There will be food, games, and activities for all ages. Please RSVP by replying to this email with the number of guests you'll be bringing. We look forward to seeing you there! Best, HR Team\"\n", + " }\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nkQVCidRL7K1", + "outputId": "efdfd4a7-d292-4733-cc0e-0c948abb0b4e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Successfully indexed 5 documents\n" + ] + } + ], + "source": [ + "success, errors = helpers.bulk(client, documents, raise_on_error=False)\n", + "print(f\"Successfully indexed {success} documents\")\n", + "if errors:\n", + " print(\"Errors encountered during bulk indexing:\")\n", + " for error in errors:\n", + " print(error)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "a6hBZKpIL-7k" + }, + "source": [ + "## 4. Asking Questions\n" + ] + }, + { + "cell_type": "code", + "execution_count": 176, + "metadata": { + "id": "bJUv9I1DMBDr" + }, + "outputs": [], + "source": [ + "# https://github.com/riccardomusmeci/mlx-llm/blob/main/src/mlx_llm/prompt/openelm.py\n", + "def build_prompt(question, elasticsearch_documents):\n", + " docs_text = \"\\n\".join([\n", + " f\"Subject: {doc['title']}\\nBody: {doc['description']}\"\n", + " for doc in elasticsearch_documents\n", + " ])\n", + "\n", + " prompt = f\"\"\"\n", + " You are a helpful virtual assistant.\n", + " You must classify an email in one of the following categories:\n", + " ['SPAM', 'Marketing', 'Project']\n", + " Do not make up emails or email categories.\n", + " EMAIL:\n", + " {docs_text}\n", + " Category:\n", + " \"\"\"\n", + "\n", + " return prompt\n", + "\n", + "\n", + "def retrieve_documents(question):\n", + " search_body = {\n", + " \"size\": 1,\n", + " \"query\": {\n", + " \"semantic\": {\n", + " \"query\": question,\n", + " \"field\": \"semantic_field\"\n", + " }\n", + " }\n", + " }\n", + " response = client.search(index=index_name, body=search_body)\n", + " return [hit[\"_source\"] for hit in response[\"hits\"][\"hits\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": 177, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 122 + }, + "id": "zVyZj-txW8_A", + "outputId": "3dce8016-3b33-49a9-e8b8-cc3e208a44a7" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "\"\\n You are a helpful virtual assistant.\\n You must classify an email in one of the following categories:\\n ['SPAM', 'Marketing', 'Project']\\n Do not make up emails or email categories.\\n EMAIL:\\n Subject: Urgent: Project Deadline Update\\nBody: Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\\n Category:\\n \"" + ] + }, + "execution_count": 177, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "question = \"how is the project going?\"\n", + "documents = retrieve_documents(question)\n", + "prompt = build_prompt(question, documents)\n", + "prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 118, + "metadata": { + "id": "evIb9FeUZLeT" + }, + "outputs": [], + "source": [ + "from OpenELM.generate_openelm import generate" + ] + }, + { + "cell_type": "code", + "execution_count": 178, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "861d746d54784cc8ba0a1ba8d87b37c8", + "26292b1de03041ad9b00c4d9dc82db5e", + "a0fac1f5214b41a2b32507a5ad87c132", + "af8487814d8c4ba889068a40cc17e24a", + "2800c358468a4a12aa37baaf74de163e", + "5e058355f5b74f0b8ac00ea3c086a648", + "04c3698e16b54fa8b13110450b108ccd", + "bcac1d4b7adc4c5d924ba6a1429ee0e9", + "bb7fe61992a445949f54f7275a473271", + "5257db973c3a4c86a1349f0e7a8a6c79", + "ea4b6dedd2534a8baa0f992dd2877ef8" + ] + }, + "id": "iJRFRSTBZSoF", + "outputId": "32f50164-479c-47e7-c0a6-794060578c06" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:inference device is not set, using cuda:0, NVIDIA A100-SXM4-40GB\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "861d746d54784cc8ba0a1ba8d87b37c8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/2 [00:00\n", + " \"Project\"\n", + "

\n", + "\n", + "---\n", + "\n", + "### Task Recap\n", + "1. Implement a Python function `classify_email` that accepts an email as input and categorizes it according to the given rules.\n", + "2. Modify your solution from Lesson 1 so that it accepts a dictionary containing email categories instead of hardcoding them.\n", + "3. Test your function using Python's `doctest` module.\n", + "4. Submit your solution (including tests) as a Python file named `classify_emails.py`.\n", + "\n", + "---\n", + "\n", + "## Solution\n", + "\n", + "#### Python Function `classify_email`\n", + "\n", + "```python\n", + "import re\n", + "import pymemcpy\n", + "import doctest\n", + "import itertools\n", + "import collections\n", + "from typing import List, Dict, Tuple\n", + "\n", + "def classify_email(email: str, categories: Dict[str, int]) -> Tuple[bool, List[str]]:\n", + " \"\"\"Classify an email according to the given categories.\n", + "\n", + " Parameters:\n", + " email (str): Email string.\n", + " categories (dict[str, int]): Dictionary mapping email categories to integer weights.\n", + "\n", + " Returns:\n", + " bool: True if email was classified correctly, False otherwise.\n", + " List[str]: List of email categories assigned to the email.\n", + " \"\"\"\n", + "\n", + " email_pattern = re.compile(r'^(?:[a-zA-Z0-9._%+-]+|[^@\\s]+)@(?:[a-zA-Z0-9-]+\\.)+' + r'\\w+' + r'\\w*' + r'\\.\\w+' + r'\\w*$')\n", + "\n", + " email_match = email_pattern.fullmatch(email)\n", + " if email_match:\n", + " email_domain = email_match.groups()[-2]\n", + " email_domain_parts = email_domain.split('.')\n", + " email_domain_root = '.'.join(email_domain_parts[:-1])\n", + " email_domain_root_parts = email_domain_root.split('\\\\')\n", + " email_domain_root_path = '/'.join(email_domain_root_parts[:-1])\n", + "\n", + " email_host_pattern = r'\\b' + email_domain_root_path + r'\\b'\n", + " email_host_pattern += r'\\b' + '.' + r'\\w+' + r'\\w*$'\n", + " email_host_pattern += r'\\b' + ':' + r'\\d+' + r'\\b'\n", + "\n", + " email_host_match = email_host_pattern.fullmatch(email_domain)\n", + " if email_host_match:\n", + " email_host = email_host_match.groups()[-2]\n", + " email_host_domain_parts = email_host.split('.')\n", + " email_host_domain_root = '.'.join(email_host_domain_parts[:-1])\n", + " email_host_domain_root_parts = email_host_domain_root.split('\\\\')\n", + " email_host_domain_root_path = '/'.join(email_host_domain_root_parts[:-1])\n", + "\n", + " email_domain_root_path_regex = r'\\b' + email_domain_root_path + r'\\b'\n", + " email_domain_root_path_regex += r'\\b' + ':' + r'\\d+' + r'\\b'\n", + "\n", + " email_domain_root_\n" + ] + } + ], + "source": [ + "output_text, generation_time = generate(\n", + " prompt=prompt,\n", + " model=MODEL,\n", + " hf_access_token=HUGGINGFACE_TOKEN,\n", + " generate_kwargs={\"repetition_penalty\": 1.2, \"prompt_lookup_num_tokens\": 10}\n", + ")\n", + "print(\"-----GENERATION TIME-----\")\n", + "print(f'\\033[92m {round(generation_time, 2)} \\033[0m')\n", + "print(\"-----RESPONSE-----\")\n", + "print(output_text)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "04c3698e16b54fa8b13110450b108ccd": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "26292b1de03041ad9b00c4d9dc82db5e": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5e058355f5b74f0b8ac00ea3c086a648", + "placeholder": "​", + "style": "IPY_MODEL_04c3698e16b54fa8b13110450b108ccd", + "value": "Loading checkpoint shards: 100%" + } + }, + "2800c358468a4a12aa37baaf74de163e": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5257db973c3a4c86a1349f0e7a8a6c79": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5e058355f5b74f0b8ac00ea3c086a648": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "861d746d54784cc8ba0a1ba8d87b37c8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_26292b1de03041ad9b00c4d9dc82db5e", + "IPY_MODEL_a0fac1f5214b41a2b32507a5ad87c132", + "IPY_MODEL_af8487814d8c4ba889068a40cc17e24a" + ], + "layout": "IPY_MODEL_2800c358468a4a12aa37baaf74de163e" + } + }, + "a0fac1f5214b41a2b32507a5ad87c132": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_bcac1d4b7adc4c5d924ba6a1429ee0e9", + "max": 2, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_bb7fe61992a445949f54f7275a473271", + "value": 2 + } + }, + "af8487814d8c4ba889068a40cc17e24a": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_5257db973c3a4c86a1349f0e7a8a6c79", + "placeholder": "​", + "style": "IPY_MODEL_ea4b6dedd2534a8baa0f992dd2877ef8", + "value": " 2/2 [00:01<00:00,  1.11it/s]" + } + }, + "bb7fe61992a445949f54f7275a473271": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "bcac1d4b7adc4c5d924ba6a1429ee0e9": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ea4b6dedd2534a8baa0f992dd2877ef8": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "1.5.0", + "model_name": "DescriptionStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} From 46b86c4d1b17aa6b6c8cf0c82efa7c9ec2d08168 Mon Sep 17 00:00:00 2001 From: llermaly Date: Mon, 30 Sep 2024 02:55:43 -0500 Subject: [PATCH 2/6] removed unused file --- supporting-blog-content/using-openelm-models/OpenELM/modelfile | 1 - 1 file changed, 1 deletion(-) delete mode 100644 supporting-blog-content/using-openelm-models/OpenELM/modelfile diff --git a/supporting-blog-content/using-openelm-models/OpenELM/modelfile b/supporting-blog-content/using-openelm-models/OpenELM/modelfile deleted file mode 100644 index 6d556901..00000000 --- a/supporting-blog-content/using-openelm-models/OpenELM/modelfile +++ /dev/null @@ -1 +0,0 @@ -FROM tomasmcm/openelm:3b-intruct-q5_K_M From ee4bb264e490bbb34981049bb6de28c82fb7c8cc Mon Sep 17 00:00:00 2001 From: llermaly Date: Wed, 9 Oct 2024 21:07:44 -0500 Subject: [PATCH 3/6] black formatter --- .../using-openelm-models.ipynb | 50 +++++++++---------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb index aab785e8..fe8dfb6d 100644 --- a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb +++ b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb @@ -226,8 +226,10 @@ " client.options(request_timeout=5).inference.put(\n", " task_type=\"sparse_embedding\",\n", " inference_id=\"my-elser-model\",\n", - " body={\"service\": \"elser\", \"service_settings\": {\n", - " \"num_allocations\": 1, \"num_threads\": 1}},\n", + " body={\n", + " \"service\": \"elser\",\n", + " \"service_settings\": {\"num_allocations\": 1, \"num_threads\": 1},\n", + " },\n", " )\n", "except ConnectionTimeout:\n", " pass" @@ -262,19 +264,16 @@ "index_body = {\n", " \"mappings\": {\n", " \"properties\": {\n", - " \"title\": {\n", - " \"type\": \"text\",\n", - " \"analyzer\": \"english\"\n", - " },\n", + " \"title\": {\"type\": \"text\", \"analyzer\": \"english\"},\n", " \"description\": {\n", " \"type\": \"text\",\n", " \"analyzer\": \"english\",\n", - " \"copy_to\": \"semantic_field\"\n", + " \"copy_to\": \"semantic_field\",\n", " },\n", " \"semantic_field\": {\n", " \"type\": \"semantic_text\",\n", - " \"inference_id\": \"my-elser-model\"\n", - " }\n", + " \"inference_id\": \"my-elser-model\",\n", + " },\n", " }\n", " }\n", "}\n", @@ -295,32 +294,32 @@ " \"_index\": index_name,\n", " \"_id\": \"email1\",\n", " \"title\": \"Team Meeting Agenda\",\n", - " \"description\": \"Hello team, Let's discuss our project progress in tomorrow's meeting. Please prepare your updates. Best regards, Manager\"\n", + " \"description\": \"Hello team, Let's discuss our project progress in tomorrow's meeting. Please prepare your updates. Best regards, Manager\",\n", " },\n", " {\n", " \"_index\": index_name,\n", " \"_id\": \"email2\",\n", " \"title\": \"Client Proposal Draft\",\n", - " \"description\": \"Hi, I've attached the draft of our client proposal. Could you review it and provide feedback? Thanks, Colleague\"\n", + " \"description\": \"Hi, I've attached the draft of our client proposal. Could you review it and provide feedback? Thanks, Colleague\",\n", " },\n", " {\n", " \"_index\": index_name,\n", " \"_id\": \"email3\",\n", " \"title\": \"Weekly Newsletter\",\n", - " \"description\": \"This week in tech: AI advancements, new smartphone releases, and cybersecurity updates. Read more on our website!\"\n", + " \"description\": \"This week in tech: AI advancements, new smartphone releases, and cybersecurity updates. Read more on our website!\",\n", " },\n", " {\n", " \"_index\": index_name,\n", " \"_id\": \"email4\",\n", " \"title\": \"Urgent: Project Deadline Update\",\n", - " \"description\": \"Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\"\n", + " \"description\": \"Dear team, Due to recent developments, we need to move up our project deadline. The new submission date is next Friday. Please adjust your schedules accordingly and let me know if you foresee any issues. We'll discuss this in detail during our next team meeting. Best regards, Project Manager\",\n", " },\n", " {\n", " \"_index\": index_name,\n", " \"_id\": \"email5\",\n", " \"title\": \"Invitation: Company Summer Picnic\",\n", - " \"description\": \"Hello everyone, We're excited to announce our annual company summer picnic! It will be held on Saturday, July 15th, at Sunny Park. There will be food, games, and activities for all ages. Please RSVP by replying to this email with the number of guests you'll be bringing. We look forward to seeing you there! Best, HR Team\"\n", - " }\n", + " \"description\": \"Hello everyone, We're excited to announce our annual company summer picnic! It will be held on Saturday, July 15th, at Sunny Park. There will be food, games, and activities for all ages. Please RSVP by replying to this email with the number of guests you'll be bringing. We look forward to seeing you there! Best, HR Team\",\n", + " },\n", "]" ] }, @@ -371,10 +370,12 @@ "source": [ "# https://github.com/riccardomusmeci/mlx-llm/blob/main/src/mlx_llm/prompt/openelm.py\n", "def build_prompt(question, elasticsearch_documents):\n", - " docs_text = \"\\n\".join([\n", - " f\"Subject: {doc['title']}\\nBody: {doc['description']}\"\n", - " for doc in elasticsearch_documents\n", - " ])\n", + " docs_text = \"\\n\".join(\n", + " [\n", + " f\"Subject: {doc['title']}\\nBody: {doc['description']}\"\n", + " for doc in elasticsearch_documents\n", + " ]\n", + " )\n", "\n", " prompt = f\"\"\"\n", " You are a helpful virtual assistant.\n", @@ -392,12 +393,7 @@ "def retrieve_documents(question):\n", " search_body = {\n", " \"size\": 1,\n", - " \"query\": {\n", - " \"semantic\": {\n", - " \"query\": question,\n", - " \"field\": \"semantic_field\"\n", - " }\n", - " }\n", + " \"query\": {\"semantic\": {\"query\": question, \"field\": \"semantic_field\"}},\n", " }\n", " response = client.search(index=index_name, body=search_body)\n", " return [hit[\"_source\"] for hit in response[\"hits\"][\"hits\"]]" @@ -586,10 +582,10 @@ " prompt=prompt,\n", " model=MODEL,\n", " hf_access_token=HUGGINGFACE_TOKEN,\n", - " generate_kwargs={\"repetition_penalty\": 1.2, \"prompt_lookup_num_tokens\": 10}\n", + " generate_kwargs={\"repetition_penalty\": 1.2, \"prompt_lookup_num_tokens\": 10},\n", ")\n", "print(\"-----GENERATION TIME-----\")\n", - "print(f'\\033[92m {round(generation_time, 2)} \\033[0m')\n", + "print(f\"\\033[92m {round(generation_time, 2)} \\033[0m\")\n", "print(\"-----RESPONSE-----\")\n", "print(output_text)" ] From 9f3982753d9cf3c16073fa5f93efd7f56fec6eb1 Mon Sep 17 00:00:00 2001 From: llermaly Date: Wed, 9 Oct 2024 21:12:49 -0500 Subject: [PATCH 4/6] second attempt --- .../using-openelm-models/using-openelm-models.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb index fe8dfb6d..4bab37c7 100644 --- a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb +++ b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb @@ -582,7 +582,8 @@ " prompt=prompt,\n", " model=MODEL,\n", " hf_access_token=HUGGINGFACE_TOKEN,\n", - " generate_kwargs={\"repetition_penalty\": 1.2, \"prompt_lookup_num_tokens\": 10},\n", + " generate_kwargs={\"repetition_penalty\": 1.2,\n", + " \"prompt_lookup_num_tokens\": 10},\n", ")\n", "print(\"-----GENERATION TIME-----\")\n", "print(f\"\\033[92m {round(generation_time, 2)} \\033[0m\")\n", From de83267a6052cb061b000b475c66cdbd95eb0a0d Mon Sep 17 00:00:00 2001 From: llermaly Date: Wed, 9 Oct 2024 21:15:20 -0500 Subject: [PATCH 5/6] 3rd attempt --- .../using-openelm-models/using-openelm-models.ipynb | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb index 4bab37c7..fe8dfb6d 100644 --- a/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb +++ b/supporting-blog-content/using-openelm-models/using-openelm-models.ipynb @@ -582,8 +582,7 @@ " prompt=prompt,\n", " model=MODEL,\n", " hf_access_token=HUGGINGFACE_TOKEN,\n", - " generate_kwargs={\"repetition_penalty\": 1.2,\n", - " \"prompt_lookup_num_tokens\": 10},\n", + " generate_kwargs={\"repetition_penalty\": 1.2, \"prompt_lookup_num_tokens\": 10},\n", ")\n", "print(\"-----GENERATION TIME-----\")\n", "print(f\"\\033[92m {round(generation_time, 2)} \\033[0m\")\n", From 5d98a5691a872760cd837bac2c6c0159865e7354 Mon Sep 17 00:00:00 2001 From: llermaly Date: Mon, 28 Oct 2024 19:39:51 -0500 Subject: [PATCH 6/6] ran make pre-commit for formatting fixes --- .../OpenELM/generate_openelm.py | 124 ++++++++---------- 1 file changed, 58 insertions(+), 66 deletions(-) diff --git a/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py b/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py index 12b167e2..8dbeaa12 100644 --- a/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py +++ b/supporting-blog-content/using-openelm-models/OpenELM/generate_openelm.py @@ -13,17 +13,18 @@ from transformers import AutoTokenizer, AutoModelForCausalLM + def generate( prompt: str, model: Union[str, AutoModelForCausalLM], hf_access_token: str = None, - tokenizer: Union[str, AutoTokenizer] = 'meta-llama/Llama-2-7b-hf', + tokenizer: Union[str, AutoTokenizer] = "meta-llama/Llama-2-7b-hf", device: Optional[str] = None, max_length: int = 1024, assistant_model: Optional[Union[str, AutoModelForCausalLM]] = None, generate_kwargs: Optional[dict] = None, ) -> str: - """ Generates output given a prompt. + """Generates output given a prompt. Args: prompt: The string prompt. @@ -53,43 +54,40 @@ def generate( if torch.cuda.is_available() and torch.cuda.device_count(): device = "cuda:0" logging.warning( - 'inference device is not set, using cuda:0, %s', - torch.cuda.get_device_name(0) + "inference device is not set, using cuda:0, %s", + torch.cuda.get_device_name(0), ) else: - device = 'cpu' + device = "cpu" logging.warning( - ( - 'No CUDA device detected, using cpu, ' - 'expect slower speeds.' - ) + ("No CUDA device detected, using cpu, " "expect slower speeds.") ) - if 'cuda' in device and not torch.cuda.is_available(): - raise ValueError('CUDA device requested but no CUDA device detected.') + if "cuda" in device and not torch.cuda.is_available(): + raise ValueError("CUDA device requested but no CUDA device detected.") if not tokenizer: - raise ValueError('Tokenizer is not set in the generate function.') + raise ValueError("Tokenizer is not set in the generate function.") if not hf_access_token: - raise ValueError(( - 'Hugging face access token needs to be specified. ' - 'Please refer to https://huggingface.co/docs/hub/security-tokens' - ' to obtain one.' + raise ValueError( + ( + "Hugging face access token needs to be specified. " + "Please refer to https://huggingface.co/docs/hub/security-tokens" + " to obtain one." ) ) if isinstance(model, str): checkpoint_path = model model = AutoModelForCausalLM.from_pretrained( - checkpoint_path, - trust_remote_code=True + checkpoint_path, trust_remote_code=True ) model.to(device).eval() if isinstance(tokenizer, str): tokenizer = AutoTokenizer.from_pretrained( - tokenizer, - token=hf_access_token, + tokenizer, + token=hf_access_token, ) # Speculative mode @@ -98,17 +96,13 @@ def generate( draft_model = assistant_model if isinstance(assistant_model, str): draft_model = AutoModelForCausalLM.from_pretrained( - assistant_model, - trust_remote_code=True + assistant_model, trust_remote_code=True ) draft_model.to(device).eval() # Prepare the prompt tokenized_prompt = tokenizer(prompt) - tokenized_prompt = torch.tensor( - tokenized_prompt['input_ids'], - device=device - ) + tokenized_prompt = torch.tensor(tokenized_prompt["input_ids"], device=device) tokenized_prompt = tokenized_prompt.unsqueeze(0) @@ -123,10 +117,7 @@ def generate( ) generation_time = time.time() - stime - output_text = tokenizer.decode( - output_ids[0].tolist(), - skip_special_tokens=True - ) + output_text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True) return output_text, generation_time @@ -136,83 +127,84 @@ def openelm_generate_parser(): class KwargsParser(argparse.Action): """Parser action class to parse kwargs of form key=value""" + def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, dict()) for val in values: - if '=' not in val: + if "=" not in val: raise ValueError( ( - 'Argument parsing error, kwargs are expected in' - ' the form of key=value.' + "Argument parsing error, kwargs are expected in" + " the form of key=value." ) ) - kwarg_k, kwarg_v = val.split('=') + kwarg_k, kwarg_v = val.split("=") try: converted_v = int(kwarg_v) except ValueError: try: converted_v = float(kwarg_v) except ValueError: - converted_v = kwarg_v + converted_v = kwarg_v getattr(namespace, self.dest)[kwarg_k] = converted_v - parser = argparse.ArgumentParser('OpenELM Generate Module') + parser = argparse.ArgumentParser("OpenELM Generate Module") parser.add_argument( - '--model', - dest='model', - help='Path to the hf converted model.', + "--model", + dest="model", + help="Path to the hf converted model.", required=True, type=str, ) parser.add_argument( - '--hf_access_token', - dest='hf_access_token', + "--hf_access_token", + dest="hf_access_token", help='Hugging face access token, starting with "hf_".', type=str, ) parser.add_argument( - '--prompt', - dest='prompt', - help='Prompt for LLM call.', - default='', - type=str, + "--prompt", + dest="prompt", + help="Prompt for LLM call.", + default="", + type=str, ) parser.add_argument( - '--device', - dest='device', - help='Device used for inference.', + "--device", + dest="device", + help="Device used for inference.", type=str, ) parser.add_argument( - '--max_length', - dest='max_length', - help='Maximum length of tokens.', + "--max_length", + dest="max_length", + help="Maximum length of tokens.", default=256, type=int, ) parser.add_argument( - '--assistant_model', - dest='assistant_model', + "--assistant_model", + dest="assistant_model", help=( ( - 'If set, this is used as a draft model ' - 'for assisted speculative generation.' + "If set, this is used as a draft model " + "for assisted speculative generation." ) ), type=str, ) parser.add_argument( - '--generate_kwargs', - dest='generate_kwargs', - help='Additional kwargs passed to the HF generate function.', + "--generate_kwargs", + dest="generate_kwargs", + help="Additional kwargs passed to the HF generate function.", type=str, - nargs='*', + nargs="*", action=KwargsParser, ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = openelm_generate_parser() prompt = args.prompt @@ -228,12 +220,12 @@ def __call__(self, parser, namespace, values, option_string=None): print_txt = ( f'\r\n{"=" * os.get_terminal_size().columns}\r\n' - '\033[1m Prompt + Generated Output\033[0m\r\n' + "\033[1m Prompt + Generated Output\033[0m\r\n" f'{"-" * os.get_terminal_size().columns}\r\n' - f'{output_text}\r\n' + f"{output_text}\r\n" f'{"-" * os.get_terminal_size().columns}\r\n' - '\r\nGeneration took' - f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m' - 'seconds.\r\n' + "\r\nGeneration took" + f"\033[1m\033[92m {round(genertaion_time, 2)} \033[0m" + "seconds.\r\n" ) print(print_txt)