From 2b1c975903018e490baa72d990024d5263370e3b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Sat, 25 May 2024 21:59:51 +0000 Subject: [PATCH] Deployed 43a2b87 with MkDocs version: 1.6.0 --- .nojekyll | 0 404.html | 5915 +++++++ AvailableModels/index.html | 6212 +++++++ Bits/index.html | 6018 +++++++ CONTRIBUTING/index.html | 6142 +++++++ DataProcessing/index.html | 6052 +++++++ EasyAttentionExample/index.html | 6133 +++++++ EasyStateExample/index.html | 6169 +++++++ Falcon/index.html | 6131 +++++++ FineTuningExample/index.html | 6125 +++++++ Install/index.html | 6219 +++++++ JAXServer/index.html | 6518 ++++++++ Llama/index.html | 6238 +++++++ Llama2/index.html | 6208 +++++++ LoRA-TransferLearningExample/index.html | 6142 +++++++ Mistral/index.html | 6133 +++++++ MosaicMPT/index.html | 6094 +++++++ Parameter-Quantization/index.html | 6249 +++++++ PyTorchServer/index.html | 6018 +++++++ assets/_mkdocstrings.css | 119 + assets/images/favicon.png | Bin 0 -> 1870 bytes assets/javascripts/bundle.081f42fc.min.js | 29 + assets/javascripts/bundle.081f42fc.min.js.map | 7 + assets/javascripts/lunr/min/lunr.ar.min.js | 1 + assets/javascripts/lunr/min/lunr.da.min.js | 18 + assets/javascripts/lunr/min/lunr.de.min.js | 18 + assets/javascripts/lunr/min/lunr.du.min.js | 18 + assets/javascripts/lunr/min/lunr.el.min.js | 1 + assets/javascripts/lunr/min/lunr.es.min.js | 18 + assets/javascripts/lunr/min/lunr.fi.min.js | 18 + assets/javascripts/lunr/min/lunr.fr.min.js | 18 + assets/javascripts/lunr/min/lunr.he.min.js | 1 + assets/javascripts/lunr/min/lunr.hi.min.js | 1 + assets/javascripts/lunr/min/lunr.hu.min.js | 18 + assets/javascripts/lunr/min/lunr.hy.min.js | 1 + assets/javascripts/lunr/min/lunr.it.min.js | 18 + assets/javascripts/lunr/min/lunr.ja.min.js | 1 + assets/javascripts/lunr/min/lunr.jp.min.js | 1 + assets/javascripts/lunr/min/lunr.kn.min.js | 1 + assets/javascripts/lunr/min/lunr.ko.min.js | 1 + assets/javascripts/lunr/min/lunr.multi.min.js | 1 + assets/javascripts/lunr/min/lunr.nl.min.js | 18 + assets/javascripts/lunr/min/lunr.no.min.js | 18 + assets/javascripts/lunr/min/lunr.pt.min.js | 18 + assets/javascripts/lunr/min/lunr.ro.min.js | 18 + assets/javascripts/lunr/min/lunr.ru.min.js | 18 + assets/javascripts/lunr/min/lunr.sa.min.js | 1 + .../lunr/min/lunr.stemmer.support.min.js | 1 + assets/javascripts/lunr/min/lunr.sv.min.js | 18 + assets/javascripts/lunr/min/lunr.ta.min.js | 1 + assets/javascripts/lunr/min/lunr.te.min.js | 1 + assets/javascripts/lunr/min/lunr.th.min.js | 1 + assets/javascripts/lunr/min/lunr.tr.min.js | 18 + assets/javascripts/lunr/min/lunr.vi.min.js | 1 + assets/javascripts/lunr/min/lunr.zh.min.js | 1 + assets/javascripts/lunr/tinyseg.js | 206 + assets/javascripts/lunr/wordcut.js | 6708 ++++++++ .../workers/search.b8dbb3d2.min.js | 42 + .../workers/search.b8dbb3d2.min.js.map | 7 + assets/stylesheets/main.6543a935.min.css | 1 + assets/stylesheets/main.6543a935.min.css.map | 1 + assets/stylesheets/palette.06af60db.min.css | 1 + .../stylesheets/palette.06af60db.min.css.map | 1 + generated-cli-cli/index.html | 6024 +++++++ .../index.html | 6042 +++++++ generated-etils-auto_tx/index.html | 6602 ++++++++ generated-etils-configs/index.html | 6193 +++++++ generated-etils-easystate/index.html | 10665 ++++++++++++ generated-etils-errors/index.html | 6042 +++++++ generated-etils-etils/index.html | 6757 ++++++++ generated-eval-lm_eval/index.html | 6139 +++++++ .../index.html | 6048 +++++++ .../index.html | 6044 +++++++ 
generated-modules-_attentions-ring/index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6696 ++++++++ .../index.html | 10291 ++++++++++++ generated-modules-attention_module/index.html | 9008 +++++++++++ .../index.html | 8890 ++++++++++ .../index.html | 6754 ++++++++ .../index.html | 12560 +++++++++++++++ .../index.html | 6046 +++++++ .../index.html | 8968 +++++++++++ .../index.html | 7035 ++++++++ .../index.html | 9072 +++++++++++ .../index.html | 11175 +++++++++++++ .../index.html | 6044 +++++++ .../index.html | 7025 ++++++++ .../index.html | 8049 +++++++++ .../index.html | 6865 ++++++++ .../index.html | 7599 +++++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6046 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6762 ++++++++ .../index.html | 10882 +++++++++++++ .../index.html | 6602 ++++++++ .../index.html | 6108 +++++++ .../index.html | 7984 +++++++++ .../index.html | 13435 +++++++++++++++ .../index.html | 6513 ++++++++ .../index.html | 6441 ++++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 7372 +++++++++ .../index.html | 7780 +++++++++ .../index.html | 10813 +++++++++++++ .../index.html | 6610 ++++++++ .../index.html | 6441 ++++++++ .../index.html | 7906 +++++++++ .../index.html | 10250 ++++++++++++ .../index.html | 6796 ++++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6575 ++++++++ .../index.html | 10137 ++++++++++++ .../index.html | 8042 +++++++++ .../index.html | 6186 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 7579 +++++++++ .../index.html | 6332 ++++++++ .../index.html | 7431 +++++++++ .../index.html | 6400 ++++++++ .../index.html | 13438 ++++++++++++++++ .../index.html | 6872 ++++++++ .../index.html | 13385 +++++++++++++++ .../index.html | 7136 ++++++++ .../index.html | 6767 ++++++++ .../index.html | 12998 +++++++++++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6236 +++++++ .../index.html | 8279 ++++++++++ .../index.html | 6324 ++++++++ .../index.html | 6756 ++++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ generated-partitioning-partitioner/index.html | 6280 ++++++++ .../index.html | 6228 +++++++ .../index.html | 6454 ++++++++ .../index.html | 6044 +++++++ .../index.html | 7419 +++++++++ .../index.html | 6044 +++++++ .../index.html | 6259 +++++++ .../index.html | 6044 +++++++ .../index.html | 6268 +++++++ .../index.html | 7238 +++++++++ generated-serve-jax_serve/index.html | 12200 ++++++++++++++ .../index.html | 6580 ++++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6044 +++++++ .../index.html | 6832 ++++++++ generated-serve-serve_engine-serve/index.html | 9114 +++++++++++ generated-serve-torch_serve/index.html | 9394 +++++++++++ generated-serve-utils/index.html | 6910 ++++++++ generated-smi-smi/index.html | 6547 ++++++++ generated-trainer-base_trainer/index.html | 8532 ++++++++++ .../index.html | 8544 ++++++++++ .../index.html | 6633 ++++++++ .../index.html | 6044 +++++++ generated-trainer-dpo-dpo_trainer/index.html | 10603 ++++++++++++ .../index.html | 8389 ++++++++++ .../index.html | 6044 
+++++++ generated-trainer-dpo-utils/index.html | 6334 ++++++++ .../index.html | 7329 +++++++++ .../index.html | 6044 +++++++ .../index.html | 10305 ++++++++++++ generated-trainer-orpo-utils/index.html | 6028 +++++++ generated-trainer-sft-stf_trainer/index.html | 6860 ++++++++ generated-trainer-sft-utils/index.html | 6028 +++++++ .../index.html | 9843 +++++++++++ generated-trainer-utils/index.html | 7160 ++++++++ .../index.html | 6526 ++++++++ .../index.html | 6044 +++++++ .../index.html | 8396 ++++++++++ .../index.html | 7057 ++++++++ generated-transform-falcon/index.html | 6108 +++++++ generated-transform-llama/index.html | 6164 +++++++ generated-transform-mistral/index.html | 6164 +++++++ generated-transform-mpt/index.html | 6110 +++++++ generated-transform-utils/index.html | 6042 +++++++ generated-utils-checker/index.html | 6042 +++++++ generated-utils-prompters/index.html | 6715 ++++++++ generated-utils-tensor_utils/index.html | 6174 +++++++ generated-utils-utils/index.html | 7588 +++++++++ index.html | 6227 +++++++ objects.inv | Bin 0 -> 9707 bytes search/search_index.json | 1 + sitemap.xml | 3 + sitemap.xml.gz | Bin 0 -> 127 bytes 204 files changed, 1107584 insertions(+) create mode 100644 .nojekyll create mode 100644 404.html create mode 100644 AvailableModels/index.html create mode 100644 Bits/index.html create mode 100644 CONTRIBUTING/index.html create mode 100644 DataProcessing/index.html create mode 100644 EasyAttentionExample/index.html create mode 100644 EasyStateExample/index.html create mode 100644 Falcon/index.html create mode 100644 FineTuningExample/index.html create mode 100644 Install/index.html create mode 100644 JAXServer/index.html create mode 100644 Llama/index.html create mode 100644 Llama2/index.html create mode 100644 LoRA-TransferLearningExample/index.html create mode 100644 Mistral/index.html create mode 100644 MosaicMPT/index.html create mode 100644 Parameter-Quantization/index.html create mode 100644 PyTorchServer/index.html create mode 100644 assets/_mkdocstrings.css create mode 100644 assets/images/favicon.png create mode 100644 assets/javascripts/bundle.081f42fc.min.js create mode 100644 assets/javascripts/bundle.081f42fc.min.js.map create mode 100644 assets/javascripts/lunr/min/lunr.ar.min.js create mode 100644 assets/javascripts/lunr/min/lunr.da.min.js create mode 100644 assets/javascripts/lunr/min/lunr.de.min.js create mode 100644 assets/javascripts/lunr/min/lunr.du.min.js create mode 100644 assets/javascripts/lunr/min/lunr.el.min.js create mode 100644 assets/javascripts/lunr/min/lunr.es.min.js create mode 100644 assets/javascripts/lunr/min/lunr.fi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.fr.min.js create mode 100644 assets/javascripts/lunr/min/lunr.he.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hu.min.js create mode 100644 assets/javascripts/lunr/min/lunr.hy.min.js create mode 100644 assets/javascripts/lunr/min/lunr.it.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ja.min.js create mode 100644 assets/javascripts/lunr/min/lunr.jp.min.js create mode 100644 assets/javascripts/lunr/min/lunr.kn.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ko.min.js create mode 100644 assets/javascripts/lunr/min/lunr.multi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.nl.min.js create mode 100644 assets/javascripts/lunr/min/lunr.no.min.js create mode 100644 assets/javascripts/lunr/min/lunr.pt.min.js create mode 100644 
assets/javascripts/lunr/min/lunr.ro.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ru.min.js create mode 100644 assets/javascripts/lunr/min/lunr.sa.min.js create mode 100644 assets/javascripts/lunr/min/lunr.stemmer.support.min.js create mode 100644 assets/javascripts/lunr/min/lunr.sv.min.js create mode 100644 assets/javascripts/lunr/min/lunr.ta.min.js create mode 100644 assets/javascripts/lunr/min/lunr.te.min.js create mode 100644 assets/javascripts/lunr/min/lunr.th.min.js create mode 100644 assets/javascripts/lunr/min/lunr.tr.min.js create mode 100644 assets/javascripts/lunr/min/lunr.vi.min.js create mode 100644 assets/javascripts/lunr/min/lunr.zh.min.js create mode 100644 assets/javascripts/lunr/tinyseg.js create mode 100644 assets/javascripts/lunr/wordcut.js create mode 100644 assets/javascripts/workers/search.b8dbb3d2.min.js create mode 100644 assets/javascripts/workers/search.b8dbb3d2.min.js.map create mode 100644 assets/stylesheets/main.6543a935.min.css create mode 100644 assets/stylesheets/main.6543a935.min.css.map create mode 100644 assets/stylesheets/palette.06af60db.min.css create mode 100644 assets/stylesheets/palette.06af60db.min.css.map create mode 100644 generated-cli-cli/index.html create mode 100644 generated-data_preprocessing-_processor/index.html create mode 100644 generated-etils-auto_tx/index.html create mode 100644 generated-etils-configs/index.html create mode 100644 generated-etils-easystate/index.html create mode 100644 generated-etils-errors/index.html create mode 100644 generated-etils-etils/index.html create mode 100644 generated-eval-lm_eval/index.html create mode 100644 generated-modules-_attentions-blockwise_attn/index.html create mode 100644 generated-modules-_attentions-flash/index.html create mode 100644 generated-modules-_attentions-ring/index.html create mode 100644 generated-modules-_attentions-vanilla/index.html create mode 100644 generated-modules-arctic-arctic_configuration/index.html create mode 100644 generated-modules-arctic-modelling_arctic_flax/index.html create mode 100644 generated-modules-attention_module/index.html create mode 100644 generated-modules-auto_easydel_model/index.html create mode 100644 generated-modules-cohere-cohere_configuration/index.html create mode 100644 generated-modules-cohere-modelling_cohere_flax/index.html create mode 100644 generated-modules-dbrx-dbrx_configuration/index.html create mode 100644 generated-modules-dbrx-modelling_dbrx_flax/index.html create mode 100644 generated-modules-deepseek_v2-deepseek_configuration/index.html create mode 100644 generated-modules-deepseek_v2-modeling_deepseek_flax/index.html create mode 100644 generated-modules-easydel_modelling_utils/index.html create mode 100644 generated-modules-falcon-falcon_configuration/index.html create mode 100644 generated-modules-falcon-modelling_falcon_flax/index.html create mode 100644 generated-modules-flax_modelling_utils/index.html create mode 100644 generated-modules-gemma-gemma_configuration/index.html create mode 100644 generated-modules-gemma-modelling_gemma_flax/index.html create mode 100644 generated-modules-gpt2-gpt2_configuration/index.html create mode 100644 generated-modules-gpt2-modelling_gpt2_flax/index.html create mode 100644 generated-modules-gpt_j-gpt_j_configuration/index.html create mode 100644 generated-modules-gpt_j-modelling_gpt_j_flax/index.html create mode 100644 generated-modules-gpt_neo_x-gpt_neo_x_configuration/index.html create mode 100644 generated-modules-gpt_neo_x-modelling_gpt_neo_x_flax/index.html create 
mode 100644 generated-modules-grok_1-grok_1_configuration/index.html create mode 100644 generated-modules-grok_1-modelling_grok_1_flax/index.html create mode 100644 generated-modules-jetmoe-jetmoe_configuration/index.html create mode 100644 generated-modules-jetmoe-modelling_jetmoe_flax/index.html create mode 100644 generated-modules-llama-llama_configuration/index.html create mode 100644 generated-modules-llama-modelling_llama_flax/index.html create mode 100644 generated-modules-llama-modelling_vision_llama_flax/index.html create mode 100644 generated-modules-llama-vision_llama_configuration/index.html create mode 100644 generated-modules-lucid_transformer-lt_configuration/index.html create mode 100644 generated-modules-lucid_transformer-modelling_lt_flax/index.html create mode 100644 generated-modules-mamba-mamba_configuration/index.html create mode 100644 generated-modules-mamba-modelling_mamba_flax/index.html create mode 100644 generated-modules-mistral-mistral_configuration/index.html create mode 100644 generated-modules-mistral-modelling_mistral_flax/index.html create mode 100644 generated-modules-mistral-modelling_vision_mistral_flax/index.html create mode 100644 generated-modules-mistral-vision_mistral_configuration/index.html create mode 100644 generated-modules-mixtral-mixtral_configuration/index.html create mode 100644 generated-modules-mixtral-modelling_mixtral_flax/index.html create mode 100644 generated-modules-mosaic_mpt-modelling_mpt_flax/index.html create mode 100644 generated-modules-mosaic_mpt-mosaic_configuration/index.html create mode 100644 generated-modules-olmo-modelling_olmo_flax/index.html create mode 100644 generated-modules-olmo-olmo_configuration/index.html create mode 100644 generated-modules-openelm-modelling_openelm_flax/index.html create mode 100644 generated-modules-openelm-openelm_configuration/index.html create mode 100644 generated-modules-opt-modelling_opt_flax/index.html create mode 100644 generated-modules-opt-opt_configuration/index.html create mode 100644 generated-modules-palm-modelling_palm_flax/index.html create mode 100644 generated-modules-palm-palm_configuration/index.html create mode 100644 generated-modules-phi-modelling_phi_flax/index.html create mode 100644 generated-modules-phi-phi_configuration/index.html create mode 100644 generated-modules-phi3-modelling_phi3_flax/index.html create mode 100644 generated-modules-phi3-phi3_configuration/index.html create mode 100644 generated-modules-qwen1-modelling_qwen1_flax/index.html create mode 100644 generated-modules-qwen1-qwen1_configuration/index.html create mode 100644 generated-modules-qwen2-modelling_qwen_flax/index.html create mode 100644 generated-modules-qwen2-qwen_configuration/index.html create mode 100644 generated-modules-qwen2_moe-configuration_qwen2_moe/index.html create mode 100644 generated-modules-qwen2_moe-modeling_qwen2_moe_flax/index.html create mode 100644 generated-modules-roberta-modelling_roberta_flax/index.html create mode 100644 generated-modules-roberta-roberta_configuration/index.html create mode 100644 generated-modules-rwkv-modelling_rwkv_flax/index.html create mode 100644 generated-modules-rwkv-rwkv_configuration/index.html create mode 100644 generated-modules-stablelm-modelling_stablelm_flax/index.html create mode 100644 generated-modules-stablelm-stablelm_configuration/index.html create mode 100644 generated-modules-t5-modelling_t5_flax/index.html create mode 100644 generated-modules-t5-t5_configuration/index.html create mode 100644 
generated-modules-whisper-modelling_whisper_flax/index.html create mode 100644 generated-modules-whisper-whisper_configuration/index.html create mode 100644 generated-partitioning-partitioner/index.html create mode 100644 generated-reinforcement_learning-core/index.html create mode 100644 generated-reinforcement_learning-models-modelling_casual_language_rl/index.html create mode 100644 generated-reinforcement_learning-trainer-partitioner_config/index.html create mode 100644 generated-reinforcement_learning-trainer-ppo_config/index.html create mode 100644 generated-reinforcement_learning-trainer-ppo_trainer/index.html create mode 100644 generated-reinforcement_learning-trainer-training_configs/index.html create mode 100644 generated-reinforcement_learning-trainer-utils/index.html create mode 100644 generated-reinforcement_learning-utils-collectors/index.html create mode 100644 generated-serve-gradio_user_interface_base/index.html create mode 100644 generated-serve-jax_serve/index.html create mode 100644 generated-serve-prompters-base_prompter/index.html create mode 100644 generated-serve-prompters-cargo_prompter/index.html create mode 100644 generated-serve-prompters-chatml_prompter/index.html create mode 100644 generated-serve-prompters-gemma_prompter/index.html create mode 100644 generated-serve-prompters-guanaco_prompter/index.html create mode 100644 generated-serve-prompters-llama2_prompter/index.html create mode 100644 generated-serve-prompters-openchat_prompter/index.html create mode 100644 generated-serve-prompters-zephyr_prompter/index.html create mode 100644 generated-serve-serve_engine-client/index.html create mode 100644 generated-serve-serve_engine-configuration/index.html create mode 100644 generated-serve-serve_engine-serve/index.html create mode 100644 generated-serve-torch_serve/index.html create mode 100644 generated-serve-utils/index.html create mode 100644 generated-smi-smi/index.html create mode 100644 generated-trainer-base_trainer/index.html create mode 100644 generated-trainer-causal_language_model_trainer-causal_language_model_trainer/index.html create mode 100644 generated-trainer-causal_language_model_trainer-fwd_bwd_functions/index.html create mode 100644 generated-trainer-causal_language_model_trainer-modeling_output/index.html create mode 100644 generated-trainer-dpo-dpo_trainer/index.html create mode 100644 generated-trainer-dpo-fwd_bwd_functions/index.html create mode 100644 generated-trainer-dpo-modelling_output/index.html create mode 100644 generated-trainer-dpo-utils/index.html create mode 100644 generated-trainer-orpo-fwd_bwd_functions/index.html create mode 100644 generated-trainer-orpo-modelling_output/index.html create mode 100644 generated-trainer-orpo-orpo_trainer/index.html create mode 100644 generated-trainer-orpo-utils/index.html create mode 100644 generated-trainer-sft-stf_trainer/index.html create mode 100644 generated-trainer-sft-utils/index.html create mode 100644 generated-trainer-training_configurations/index.html create mode 100644 generated-trainer-utils/index.html create mode 100644 generated-trainer-vision_causal_language_model_trainer-fwd_bwd_functions/index.html create mode 100644 generated-trainer-vision_causal_language_model_trainer-modelling_output/index.html create mode 100644 generated-trainer-vision_causal_language_model_trainer-vision_causal_language_model_trainer/index.html create mode 100644 generated-transform-easydel_transform/index.html create mode 100644 generated-transform-falcon/index.html create mode 100644 
generated-transform-llama/index.html create mode 100644 generated-transform-mistral/index.html create mode 100644 generated-transform-mpt/index.html create mode 100644 generated-transform-utils/index.html create mode 100644 generated-utils-checker/index.html create mode 100644 generated-utils-prompters/index.html create mode 100644 generated-utils-tensor_utils/index.html create mode 100644 generated-utils-utils/index.html create mode 100644 index.html create mode 100644 objects.inv create mode 100644 search/search_index.json create mode 100644 sitemap.xml create mode 100644 sitemap.xml.gz diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 000000000..e69de29bb diff --git a/404.html b/404.html new file mode 100644 index 000000000..5e1c03503 --- /dev/null +++ b/404.html @@ -0,0 +1,5915 @@ + + + +
Model | Video Model | CausalLM | AttentionModule | Parameters Quantization | Operation Bit Quantization
---|---|---|---|---|---
Gptj | ❌ | ✅ | ✅ | ✅ | ✅
LucidTransformer | ❌ | ✅ | ✅ | ✅ | ✅
Mixtral | ✅ | ✅ | ✅ | ✅ | ✅
Opt | ❌ | ✅ | ✅ | ✅ | ✅
Qwen2Moe | ❌ | ✅ | ✅ | ✅ | ✅
Stablelm | ❌ | ✅ | ✅ | ✅ | ✅
Cohere | ❌ | ✅ | ✅ | ✅ | ✅
Arctic | ❌ | ✅ | ✅ | ✅ | ✅
OpenELM | ❌ | ✅ | ✅ | ✅ | ✅
Gemma | ❌ | ✅ | ✅ | ✅ | ✅
GptNeoX | ❌ | ✅ | ✅ | ✅ | ✅
Jetmoe | ❌ | ✅ | ✅ | ✅ | ✅
Mamba | ❌ | ✅ | ❌ | ✅ | ✅
MosaicMpt | ❌ | ✅ | ✅ | ✅ | ✅
Palm | ❌ | ✅ | ✅ | ✅ | ✅
Qwen1 | ❌ | ✅ | ✅ | ✅ | ✅
Roberta | ❌ | ✅ | ✅ | ✅ | ✅
T5 | ❌ | ✅ | ✅ | ✅ | ✅
Dbrx | ❌ | ✅ | ✅ | ✅ | ✅
Falcon | ❌ | ✅ | ✅ | ✅ | ✅
Gpt2 | ❌ | ✅ | ✅ | ✅ | ✅
Grok1 | ❌ | ✅ | ✅ | ✅ | ✅
Llama | ✅ | ✅ | ✅ | ✅ | ✅
Mistral | ✅ | ✅ | ✅ | ✅ | ✅
Olmo | ❌ | ✅ | ✅ | ✅ | ✅
Phi | ❌ | ✅ | ✅ | ✅ | ✅
Phi 3 | ❌ | ✅ | ✅ | ✅ | ✅
Qwen2 | ❌ | ✅ | ✅ | ✅ | ✅
Rwkv | ❌ | ✅ | ❌ | ✅ | ✅
Whisper | ❌ | ✅ | ✅ | ✅ | ✅
You can also tell me which model you want in a Flax/JAX version and I'll try my best to build it ;)
> More models might have been added to ~HEAD but not mentioned here.
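All of these are loaded through the same auto classes used throughout these docs. A minimal loading sketch (the repo id is a placeholder; use any checkpoint whose architecture appears in the table above):

```python
from easydel import AutoEasyDeLModelForCausalLM

# "REPO_ID" is a placeholder for a Hugging Face checkpoint whose
# architecture is listed in the table above.
model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "REPO_ID",
)
```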
In EasyDeL, bits work differently from Hugging Face: training a model in 8-bit is supported without any code changes; you just change the bits and that's all you have to do. You still have to pass the dtype and param_dtype, though, because unlike transformers and bitsandbytes, which store parameters in int8 and do operations in float16, bfloat16, or float32, we don't do it that way in JAX: parameters are still stored as float16, bfloat16, or float32, and the operations are done in lower bits such as 8, 6, or 4. You can still train your model this way and make it more accurate than bitsandbytes or PEFT fine-tuning.
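To make the storage-versus-operation distinction concrete, here is a minimal plain-JAX sketch of the idea (this is an illustration only, not EasyDeL's actual kernels):

```python
import jax
import jax.numpy as jnp

def quantize_int8(x):
    # symmetric per-tensor quantization to int8
    scale = jnp.max(jnp.abs(x)) / 127.0
    return jnp.round(x / scale).astype(jnp.int8), scale

def int8_matmul(a_bf16, b_bf16):
    # parameters and activations are stored in bfloat16 ...
    a_q, a_scale = quantize_int8(a_bf16.astype(jnp.float32))
    b_q, b_scale = quantize_int8(b_bf16.astype(jnp.float32))
    # ... but the matmul itself runs on 8-bit values (accumulated in int32)
    out = jnp.matmul(a_q.astype(jnp.int32), b_q.astype(jnp.int32))
    return (out * (a_scale * b_scale)).astype(jnp.bfloat16)

key = jax.random.PRNGKey(0)
w = jax.random.normal(key, (8, 8), dtype=jnp.bfloat16)  # stored weight: bfloat16
x = jax.random.normal(key, (2, 8), dtype=jnp.bfloat16)  # activation: bfloat16
print(int8_matmul(x, w))
```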
> Right now I'm looking to make EasyBITs in EasyDeL work on TPU-v3, because on low-end GPUs and older TPUs it might not work as well as it does on TPU-v4/5.
+
Thank you for considering contributing to EasyDeL! We welcome your input. To ensure a smooth collaboration, please review and adhere to the following guidelines.
To contribute to EasyDeL, follow these steps:

1. Fork the repository.
2. Create a new branch for your feature or bug fix.
3. Make your changes and commit them with clear and descriptive messages.
4. Push your changes to your branch in your forked repository.
5. Submit a pull request to the main EasyDeL repository, detailing the changes you've made and the problem it solves.
+Please adhere to the Apache Code of Conduct in all interactions related to EasyDeL.
+If you encounter a bug, please open an issue on the EasyDeL repository, providing a clear and detailed description of the issue, including steps to reproduce it.
+If you have ideas for enhancements, feel free to open an issue on the EasyDeL repository. Provide a clear and detailed description of your proposed enhancement.
+To set up EasyDeL for development, follow the instructions in the README.md file.
When submitting a pull request, please ensure the following:

- Your code follows the project's coding standards.
- Your commits are accompanied by clear and descriptive messages.
- Your pull request addresses a single issue or feature.
+By contributing to EasyDeL, you agree that your contributions will be licensed under the Apache License, Version 2.0.
+Thank you for your interest in contributing to EasyDeL! We appreciate your support.
In this example, you will see the data format required by EasyDeL to pre-train or fine-tune models.
+from datasets import load_dataset
+from easydel.data_preprocessing import DataProcessor, DataProcessorArguments
+from transformers import LlamaTokenizerFast
+
+
+def main():
+ tokenizer = LlamaTokenizerFast.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ dataset = load_dataset("erfanzar/orca-lite")
+ print(dataset)
+
+ # DatasetDict({
+ # train: Dataset({
+ # features: ['user', 'gpt', 'system', 'llama_2_prompt_style', 'prompt_length'],
+ # num_rows: 101397
+ # })
+ # })
+
+ processor_arguments = DataProcessorArguments(
+ max_position_embeddings=2048,
+ num_proc=6,
+ prompt_field='llama_2_prompt_style',
+
+ )
+
+ easydel_dataset = DataProcessor.process_data(
+ data=dataset['train'],
+ tokenizer=tokenizer,
+ arguments=processor_arguments,
+ field='train'
+ )
+ print(easydel_dataset)
+ # DatasetDict({
+ # train: Dataset({
+ # features: ['input_ids', 'attention_mask'],
+ # num_rows: 101397
+ # })
+ # })
+
+
+if __name__ == "__main__":
+ main()
+
Now you can pass this data to the trainer and train your model 😇.
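A rough sketch of that hand-off, assuming a model and its params were already loaded with AutoEasyDeLModelForCausalLM, and with most TrainArguments omitted (see the fine-tuning example later in these docs for a complete set):

```python
import flax
from easydel import TrainArguments, CausalLanguageModelTrainer

train_arguments = TrainArguments(
    model_class=type(model),                       # model loaded elsewhere (assumption)
    model_name="data_processing_example",
    num_train_epochs=1,
    max_length=processor_arguments.max_position_embeddings,
    # ... other arguments as in the fine-tuning example ...
)

trainer = CausalLanguageModelTrainer(
    train_arguments,
    easydel_dataset["train"],                      # the processed dataset from above
    checkpoint_path=None
)
output = trainer.train(flax.core.FrozenDict({"params": params}))
```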
AttentionModule
AttentionModule is an EasyDeL module that can perform the attention operation with different strategies, helping users achieve the best possible performance and numerical stability. Here are some of the strategies supported right now.
+import jax
+import flax.linen.attention as flt
+from fjformer import GenerateRNG
+from easydel.modules.attention_module import AttentionModule
+from easydel.modules.easydel_modelling_utils import EasyDeLPretrainedConfig
+from jax import numpy as jnp, random, lax
+import math
+
+rng_gen = GenerateRNG(seed=42)
+config = EasyDeLPretrainedConfig(
+ axis_dims=(1, -1, 1, 1),
+ axis_names=("dp", "fsdp", "tp", "sp"),
+ block_q=512,
+ block_k=512
+)
+
+BATCH_SIZE = len(jax.devices())
+NUM_ATTN_HEADS = 32
+CONTEXT_LENGTH = 8192
+HEAD_DIM = 256
+
+
+def make_fake_input_data(
+ batch_size: int,
+ num_attention_head: int,
+ context_length: int,
+ head_dim: int,
+):
+ q = random.normal(next(rng_gen), (batch_size, context_length, num_attention_head, head_dim), dtype=jnp.float32)
+ k = random.normal(next(rng_gen), (batch_size, context_length, num_attention_head, head_dim), dtype=jnp.float32)
+ v = random.normal(next(rng_gen), (batch_size, context_length, num_attention_head, head_dim), dtype=jnp.float32)
+
+ attention_mask = jnp.ones((batch_size, context_length))
+ causal_mask = flt.make_causal_mask(attention_mask)
+
+ cm_ = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
+ at_ = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), cm_.shape)
+ at_ = flt.combine_masks(at_, cm_)
+
+ attention_bias = lax.select(
+ at_ > 0,
+ jnp.full(at_.shape, 0.0).astype(jnp.float32),
+ jnp.full(at_.shape, jnp.finfo(jnp.float32).min).astype(jnp.float32),
+ )
+
+ return (
+ q, k, v, attention_mask, causal_mask, attention_bias
+ )
+
+
+q, k, v, attention_mask, causal_mask, attention_bias = make_fake_input_data(
+ BATCH_SIZE,
+ NUM_ATTN_HEADS,
+ CONTEXT_LENGTH,
+ HEAD_DIM
+)
+
+flash_attention = AttentionModule(
+
+ block_k_major=config.block_k_major,
+ block_b=config.block_b,
+ block_q=config.block_q,
+ block_k=config.block_k,
+ block_q_major_dkv=config.block_q_major_dkv,
+ block_k_major_dkv=config.block_k_major_dkv,
+ block_k_major_dq=config.block_k_major_dq,
+ block_k_dkv=config.block_k_dkv,
+ block_q_dkv=config.block_q_dkv,
+ block_q_dq=config.block_q_dq,
+ block_k_dq=config.block_k_dq,
+ num_attention_heads=NUM_ATTN_HEADS,
+ attention_dropout=0.0,
+ head_dims=HEAD_DIM,
+ attention_partition_spec=config.attention_partition_spec,
+ shard_attention_computation=config.shard_attention_computation,
+ precision=lax.Precision("fastest"),
+ force_float32_tpu=True,
+ attn_mechanism="flash",
+ dtype=jnp.float32,
+ bias_partition_spec=config.bias_partition_spec,
+ key_partition_spec=config.key_partition_spec,
+ query_partition_spec=config.query_partition_spec,
+ generation_query_partition_spec=config.generation_query_partition_spec,
+ generation_bias_partition_spec=config.generation_bias_partition_spec,
+ value_partition_spec=config.value_partition_spec,
+ scan_ring_attention=config.scan_ring_attention,
+ mesh=config.jax_mesh(),
+ sm_scale=1 / math.sqrt(q.shape[-1]),
+)
+
+normal_attention = AttentionModule(
+
+ block_k_major=config.block_k_major,
+ block_b=config.block_b,
+ block_q=config.block_q,
+ block_k=config.block_k,
+ block_q_major_dkv=config.block_q_major_dkv,
+ block_k_major_dkv=config.block_k_major_dkv,
+ block_k_major_dq=config.block_k_major_dq,
+ block_k_dkv=config.block_k_dkv,
+ block_q_dkv=config.block_q_dkv,
+ block_q_dq=config.block_q_dq,
+ block_k_dq=config.block_k_dq,
+ num_attention_heads=NUM_ATTN_HEADS,
+ attention_dropout=0.0,
+ head_dims=HEAD_DIM,
+ attention_partition_spec=config.attention_partition_spec,
+ shard_attention_computation=config.shard_attention_computation,
+ precision=lax.Precision("fastest"),
+ force_float32_tpu=True,
+ attn_mechanism="normal",
+ dtype=jnp.float32,
+ bias_partition_spec=config.bias_partition_spec,
+ key_partition_spec=config.key_partition_spec,
+ query_partition_spec=config.query_partition_spec,
+ generation_query_partition_spec=config.generation_query_partition_spec,
+ generation_bias_partition_spec=config.generation_bias_partition_spec,
+ value_partition_spec=config.value_partition_spec,
+ scan_ring_attention=config.scan_ring_attention,
+ mesh=config.jax_mesh(),
+ sm_scale=1 / math.sqrt(q.shape[-1]),
+)
+
+with config.jax_mesh():
+ flash_attn_out = flash_attention(
+ query_states=q,
+ key_states=k,
+ value_states=v,
+ bias=attention_bias,
+ key_value_sequence_length=CONTEXT_LENGTH,
+ query_sequence_length=CONTEXT_LENGTH
+ )
+ normal_attn_out = normal_attention(
+ query_states=q,
+ key_states=k,
+ value_states=v,
+ bias=attention_bias,
+ key_value_sequence_length=CONTEXT_LENGTH,
+ query_sequence_length=CONTEXT_LENGTH
+ )
+
+print(
+ flash_attn_out.attention_outputs[0, CONTEXT_LENGTH - 5, NUM_ATTN_HEADS - 1, HEAD_DIM - 10:]
+)
+# Array([-0.05915311, 0.0078501 , 0.03785717, 0.0134844 , 0.08464689,
+# 0.06667967, -0.02629154, -0.0180066 , -0.02972782, 0.02833381], dtype=float32)
+print(
+ normal_attn_out.attention_outputs[0, CONTEXT_LENGTH - 5, NUM_ATTN_HEADS - 1, HEAD_DIM - 10:]
+)
+
+# Array([-0.0590958 , 0.00796138, 0.03789062, 0.01350671, 0.08461153,
+# 0.06662725, -0.0262386 , -0.01806086, -0.0296791 , 0.02824247], dtype=float32)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
EasyDeLState is a cool feature in EasyDeL with a lot of options, such as
+storing Model Parameters
, Optimizer State, Model Config, Model Type, Optimizer and Scheduler Configs
Let's look at some examples of using EasyDeLState.
+Fine-tuning from a previous State or a new state
+from easydel import (
+ AutoEasyDeLConfig,
+ EasyDeLState
+)
+from transformers import AutoTokenizer
+from jax import numpy as jnp, lax
+import jax
+
+huggingface_model_repo_id = "REPO_ID"
+checkpoint_name = "CKPT_NAME"
+
+state = EasyDeLState.from_pretrained(
+ pretrained_model_name_or_path=huggingface_model_repo_id,
+ filename=checkpoint_name,
+ optimizer="adamw",
+ scheduler="none",
+ tx_init=None,
+ device=jax.devices('cpu')[0], # Offload Device
+ dtype=jnp.bfloat16,
+ param_dtype=jnp.bfloat16,
+ precision=lax.Precision("fastest"),
+ sharding_axis_dims=(1, -1, 1, 1),
+ sharding_axis_names=("dp", "fsdp", "tp", "sp"),
+ query_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
+ key_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
+ value_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
+ bias_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), None, None, None),
+ attention_partition_spec=jax.sharding.PartitionSpec(("dp", "fsdp"), "sp", "tp", None),
+ shard_attention_computation=False,
+ input_shape=(1, 1),
+ backend=None,
+ init_optimizer_state=False,
+ free_optimizer_state=True,
+ verbose=True,
+ state_shard_fns=None,
+)
+
+config = AutoEasyDeLConfig.from_pretrained(
+ huggingface_model_repo_id
+)
+
+tokenizer = AutoTokenizer.from_pretrained(
+ huggingface_model_repo_id,
+ trust_remote_code=True
+)
+
+max_length = config.max_position_embeddings
+
+configs_to_initialize_model_class = {
+ 'config': config,
+ 'dtype': jnp.bfloat16,
+ 'param_dtype': jnp.bfloat16,
+ 'input_shape': (8, 8)
+}
+
+EasyDeLState
also has .load_state()
and .save_state()
with some other usable options like .free_opt_state()
+which
frees the optimizer state, or .shard_params(),
which shards the parameters. You can read the docs to find out more about these
+options.
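A rough sketch of the save/load round trip (the file name and the exact signatures are assumptions here; check the EasyDeLState docs):

```python
# Sketch only: the file name is a placeholder and the exact signatures of
# save_state / load_state / free_opt_state should be checked in the docs.
state.save_state("easydel_state.ckpt")

restored_state = EasyDeLState.load_state(
    "easydel_state.ckpt",
    input_shape=(8, 2048),
)
restored_state = restored_state.free_opt_state()  # drop the optimizer state if you only need params
```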
Let's see how you can convert an EasyDeL Mistral model into a Hugging Face PyTorch Mistral model from a trained state:
+
+from transformers import MistralForCausalLM
+from easydel import (
+ AutoEasyDeLConfig,
+ EasyDeLState,
+ easystate_to_huggingface_model
+)
+import jax
+
+huggingface_model_repo_id = "REPO_ID"
+
+config = AutoEasyDeLConfig.from_pretrained(
+ huggingface_model_repo_id
+)
+with jax.default_device(jax.devices("cpu")[0]):
+ model = easystate_to_huggingface_model(
+ state=EasyDeLState.load_state(
+ "PATH_TO_CKPT",
+ input_shape=(8, 2048)
+ ), # You can Pass EasyDeLState here
+ base_huggingface_module=MistralForCausalLM,
+ config=config,
+ )
+
+model = model.half() # it's a huggingface model now
+
+EasyDeLState
is generally useful; you can use it everywhere in EasyDeL, for example for a stand-alone model
, serving, fine-tuning, and many other features. It's up to you to test how creative you can be 😇.
+Falcon Models
Falcon Models is a family of large language models (LLMs) developed by the Technology Innovation Institute (TII) in Abu Dhabi. The models are trained on a massive dataset of text and code and can be used for a variety of natural language processing tasks.
+The Falcon models are available under the Apache 2.0 license, which means that they can be freely used, modified, and +redistributed.
+Falcon-40B
+The Falcon-40B is the largest model in the Falcon family. It has 40 billion parameters, and is trained on a dataset of +500 billion words. The model is capable of state-of-the-art performance on a variety of NLP tasks.
+Falcon-7B
+The Falcon-7B is a smaller version of the Falcon-40B. It has 7 billion parameters, and is trained on a dataset of 100 +billion words. The model is still capable of achieving strong performance on NLP tasks, but it is more efficient to +train and deploy.
+Falcon-180B
+The Falcon-180B is the newest model in the Falcon family. It has 180 billion parameters, and is trained on a dataset of +2 trillion words. The model is the largest openly available LLM, and it is capable of achieving state-of-the-art +performance on a variety of NLP tasks.
+Use Cases
+The Falcon models can be used for a variety of tasks, including:
+Availability
+The Falcon models are available through the Hugging Face Hub. The models are also available in the TensorFlow Hub and +the PyTorch Hub ( and EasyDeL).
+Conclusion
+The Falcon models are a powerful family of LLMs that can be used for a variety of tasks. The models are open source and +available for free, making them a valuable resource for researchers and developers.
+import jax
+from easydel import AutoEasyDeLModelForCausalLM
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ 'tiiuae/falcon-7b',
+ # other kwargs
+)
+
Also keep in mind that the returned config
includes .get_partition_rules(fsdp=True)
from easydel.serve import JAXServer, JAXServerConfig
+from easydel import AutoEasyDeLModelForCausalLM
+from transformers import AutoTokenizer
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ 'tiiuae/falcon-7b',
+ # other kwargs
+)
+
+
+class FalconJaxServer(JAXServer):
+ ...
+ # You have to Custom this one yourself as you
+ # need read JaxServer Documents inorder to learn how
+
+
+server = FalconJaxServer.from_parameters(
+ params=params,
+ model=model,
+ config_model=model.config,
+ add_params_field=True,
+ tokenizer=AutoTokenizer.from_pretrained('tiiuae/falcon-7b'),
+ verbose=False,
+ do_memory_log=True,
+ server_config=JAXServerConfig()
+)
+
+server.fire() # Launch FastAPI functions
+
+shared_urls = server.launch(
+ share_chat=True,
+ share_inst=True
+)
+
+Done 😇 this method can be used for all the Falcon models
Using EasyDeL, fine-tuning LLMs (causal language models) is as easy as possible with JAX and Flax,
+and having the benefit of TPUs
for the best speed. Here's a simple piece of code to use in order to fine-tune your
own model.
Days have passed, and now using EasyDeL in JAX is much more similar to the HF/PyTorch style; now it's time to fine-tune our model.
+import jax.numpy
+from easydel import (
+ TrainArguments,
+ CausalLanguageModelTrainer,
+ AutoEasyDeLModelForCausalLM,
+ EasyDeLOptimizers,
+ EasyDeLSchedulers,
+ EasyDeLGradientCheckPointers
+)
+from datasets import load_dataset
+import flax
+from jax import numpy as jnp
+from transformers import AutoTokenizer
+
+huggingface_repo_id_or_path = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )
+
+max_length = 2048
+tokenizer = AutoTokenizer.from_pretrained(
+ huggingface_repo_id_or_path,
+ trust_remote_code=True
+)
+tokenizer.pad_token = tokenizer.eos_token
+
+model.config.add_basic_configurations(
+ attn_mechanism="flash", # Change to 'normal' if the model you are using
+ # don't support flash attention, or you don't want to apply flash attention for the model
+ block_b=1,
+ block_q=1024,
+ block_k=1024,
+ block_k_major=1024,
+)
+
+configs_to_initialize_model_class = {
+ "config": model.config,
+ "dtype": jnp.bfloat16,
+ "param_dtype": jnp.bfloat16,
+ "input_shape": (1, 1)
+}
+
+train_arguments = TrainArguments(
+ model_class=type(model),
+ model_name="my_first_model_to_train_using_easydel",
+ num_train_epochs=3,
+ configs_to_initialize_model_class=configs_to_initialize_model_class,
+ learning_rate=5e-5,
+ learning_rate_end=1e-6,
+ optimizer=EasyDeLOptimizers.ADAMW, # "adamw", "lion", "adafactor" are supported
+ scheduler=EasyDeLSchedulers.LINEAR,
+ # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear" are supported
+ weight_decay=0.01,
+ total_batch_size=64,
+ max_training_steps=None, # None to let trainer Decide
+ do_train=True,
+ do_eval=False, # it's optional but supported
+ backend="tpu", # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
+ max_length=max_length, # Note that you have to change this in the model config too
+ gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
+ sharding_array=(1, -1, 1, 1), # the way to shard model across gpu,cpu or TPUs using sharding array (1, -1, 1, 1)
+ # everything training will be in fully FSDP automatic and share data between devices
+ remove_ckpt_after_load=True,
+ gradient_accumulation_steps=8,
+ loss_re_mat="",
+ dtype=jnp.bfloat16
+)
+
+
+def ultra_chat_prompting_process(
+ data_chunk
+):
+ user_part = [
+ chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
+ ]
+ assistant_part = [
+ chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
+ ]
+
+ prompt = ""
+
+ for uc, ac in zip(user_part, assistant_part):
+ prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"
+
+ return {"prompt": prompt}
+
+
+tokenization_process = lambda data_chunk: tokenizer(
+ data_chunk["prompt"],
+ add_special_tokens=False,
+ max_length=max_length,
+ padding="max_length"
+)
+
+dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
+dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
+dataset_train = dataset_train.map(
+ tokenization_process,
+ num_proc=12,
+ remove_columns=dataset_train.column_names
+)
+
+# you can do the same for evaluation process dataset
+
+trainer = CausalLanguageModelTrainer(
+ train_arguments,
+ dataset_train,
+ checkpoint_path=None
+)
+
+output = trainer.train(flax.core.FrozenDict({"params": params}))
+print(f"Hey ! , here's where your model saved {output.checkpoint_path}")
+
+
+
+
+
+
+
+
+
+
+
+
+
+
EasyDeL uses FJFormer and JAX as its main dependencies for running the scripts, and there are some other things that may need to be installed, from Go (golang) to platform-specific JAX builds, but you can simply install EasyDeL via pip:
+pip install easydel
+
+JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit.
+you can install other version too but easydel required at least version of 0.4.16
+!pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -q
+
+pip install --upgrade pip
+# CUDA 12 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+pip install --upgrade pip
+# CUDA 11 installation
+# Note: wheels only available on linux.
+pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+sudo apt-get update && apt-get upgrade -y
+sudo apt-get install golang -y
+
+sudo pacman -Syyuu go
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ JAXServer
is one of offered utilities by EasyDeL, and it's help hosting using and doing process with LLMs
+and its also hackable, so you can override your own method in it and use it support both mid-level and high-level apis
+and also give you a Gradio Chat and Instruct Pre-build and ready to use page
transformers.FlaxPretrainedModel
as their Parent :)The config input is a dictionary that contains the following keys:
+port
: The port number that the server will listen on.2059
batch_size
: The batch size for training.1
max_sequence_length
: The maximum length of a sequence.2048
max_new_tokens
: The maximum number of new tokens generated by the model in a single step.2048
max_compile_tokens
: The maximum number of tokens that can be streamed to the model in a single batch.32
temperature
: The temperature parameter for sampling from the model's output distribution.0.1
top_p
: The top-p parameter for sampling from the model's output distribution.0.95
top_k
: The top-k parameter for sampling from the model's output distribution.50
mesh_axes_shape
: The shape of the mesh axes for distributed training.(1, -1, 1, 1)
host
: The host address for the server.'0.0.0.0'
dtype
: The data type for the model's parameters.'fp16'
mesh_axes_names
: The names of the mesh axes for distributed training.("dp", "fsdp", "tp", "sp")
logging
: Whether the model should log its training progress.:True
stream_tokens_for_gradio
: Whether the model should stream tokens to Gradio.True
use_prefix_tokenizer
: Whether the model should use a prefix tokenizer.True
pre_compile
: Whether the model should be pre-compiled.True
JAXServer
has format_chat
and format_instruct
funcs that you have to implement them to prompt your model
+def format_instruct(self, system: str, instruction: str) -> str:
+ """
+ Here you will get the system and instruction from user, and you can apply your prompting style
+ """
+ raise NotImplementedError()
+
+
+def format_chat(self, history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
+ """
+ Here you will get the system, prompt and history from user, and you can apply your prompting style
+ """
+ raise NotImplementedError()
+
+JAXServer
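For example, a minimal subclass might look like the sketch below; the prompt template is purely illustrative and not a format required by EasyDeL, and history is assumed to be a list of [user, assistant] pairs as in the Gradio examples further down this page:

```python
import typing
from easydel.serve import JAXServer


class MyJAXServer(JAXServer):
    def format_instruct(self, system: str, instruction: str) -> str:
        # illustrative template only
        return f"<|system|>\n{system}\n<|user|>\n{instruction}\n<|assistant|>\n"

    def format_chat(self, history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
        # history is assumed to hold [user, assistant] pairs (see the Gradio examples)
        text = f"<|system|>\n{system}\n" if system is not None else ""
        for user_turn, assistant_turn in history:
            text += f"<|user|>\n{user_turn}\n<|assistant|>\n{assistant_turn}\n"
        return text + f"<|user|>\n{prompt}\n<|assistant|>\n"
```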
Contains a method named .sample
and with using sample
method you can generate text from text
what does this do and how this works ? here's the inputs that sample
function takes in
def sample(self,
+ string,
+ *,
+ greedy: bool = False,
+ max_new_tokens: int = None,
+ **kwargs
+ ) -> [str, int]:
+ ...
+
+(String)
(Bool)
(Int)
(String)
(Int)
you can use this function outside the class like this
+for string, num_used_tokens in server.sample(
+ 'im a string',
+ greedy=False,
+ max_new_tokens=256 # or None to use Maximum numbers passed in Config
+):
+ print(f'\r{num_used_tokens}: {string}', end="")
+
+if you want to change gradio response functions you can override them like this
+this is the default gradio functions and this is how it looks :
+def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
+ string = self.chat_format(history=history, prompt=prompt, system=system)
+
+ if not self.config.stream_tokens_for_gradio:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ ...
+ history.append([prompt, response])
+ else:
+ history.append([prompt, ""])
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ history[-1][-1] = response
+ yield "", history
+ return "", history
+
And here's an example of changing that in order to use Llama models:
+def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
+ def prompt_llama2_model(message: str, chat_history,
+ system_prompt: str) -> str:
+
+ do_strip = False
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+ for user_input, response in chat_history:
+ user_input = user_input.strip() if do_strip else user_input
+ do_strip = True
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+ message = message.strip() if do_strip else message
+ texts.append(f'{message} [/INST]')
+ return "".join(texts)
+
+ string = prompt_llama2_model(
+ message=prompt,
+ chat_history=history or [],
+ system_prompt=system
+ )
+ if not self.config.stream_tokens_for_gradio:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ ...
+ history.append([prompt, response])
+ else:
+ history.append([prompt, ""])
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens
+ ):
+ history[-1][-1] = response
+ yield "", history
+
+ return "", history
+
+
As you can see, you can easily override the functions however you want and use them with some simple changes,
and you can also use their Gradio client
or use JAXServer
FastAPI
builtin methods
To override this API you have to implement forward_instruct however you want; the default implementation of this
+function is
def forward_instruct(self, data: InstructRequest):
+ if not self._funcs_generated:
+ return {
+ 'status': "down"
+ }
+
+ string = self.config.instruct_format.format(instruct=data.prompt, system=data.system)
+ response, used_tokens = [None] * 2
+ for response, used_tokens in self.sample(
+ string=string,
+ greedy=data.greedy,
+ max_new_tokens=None
+ ):
+ ...
+ self.number_of_served_request_until_last_up_time += 1
+ return {
+ 'input': f'{string}',
+ 'response': response,
+ 'tokens_used': used_tokens,
+ }
+
+class InstructRequest(BaseModel):
+ prompt: str
+ system: Optional[str] = None
+ temperature: Optional[float] = None
+ greedy: Optional[bool] = False
+
+requests
library in
+ python :import requests
+
+content = {
+ 'prompt': 'can you code a simple neural network in c++ for me',
+ 'system': 'You are an AI assistant generate short and useful response',
+ 'temperature': 0.1,
+ 'greedy': False
+}
+
+response = requests.post(
+ url='http://ip:port/instruct',
+ json=content
+).json()
+
+print(response['response'])
+# Response of model
+print(response['input'])
+# The input passed to the model
+
+
To override this API you have to implement forward_chat however you want; the default implementation of this function
+is
def forward_chat(self, data: ChatRequest):
+ if not self._funcs_generated:
+ return {
+ 'status': "down"
+ }
+
+ history = self.process_chat_history(data.history or [])
+ history += self.config.prompt_prefix_chat + data.prompt + self.config.prompt_postfix_chat
+
+ response, used_tokens = [None] * 2
+ for response, used_tokens in self.process(
+ string=history,
+ greedy=data.greedy,
+ max_new_tokens=None
+ ):
+ ...
+ self.number_of_served_request_until_last_up_time += 1
+ return {
+ 'input': f'{history}',
+ 'response': response,
+ 'tokens_used': used_tokens,
+ }
+
+class ChatRequest(BaseModel):
+ prompt: str
+ history: Union[List[List], None] = None
+ temperature: Optional[float] = None
+ greedy: Optional[bool] = False
+
+requests
library in
+ python :import requests
+
+content = {
+ 'prompt': 'can you code a simple neural network in c++ for me',
+ 'history': [
+ ['hello how are you', 'Hello\nthanks, im here to assist you you have any question that i could help you with']
+ ],
+ 'temperature': 0.1,
+ 'greedy': False
+}
+
+response = requests.post(
+ url='http://ip:port/chat',
+ json=content
+).json()
+
+print(response['response'])
+# Response of model
+print(response['input'])
+# The input passed to the model
+
+
+Simply by sending a get API to https://ip:port/status
you will receive base information about the server and
+how it being run, num cores in use, number of generated prompt , number of request and ...
Llama models are a family of large language models (LLMs) developed by Meta AI. They are trained on a massive dataset of +text and code, and they can be used for a variety of tasks, such as text generation, translation, summarization, +question answering, code generation, and natural language inference.
+Llama models are based on the Transformer architecture, which is a neural network architecture that has been shown to be +very effective for natural language processing tasks. The Transformer architecture uses self-attention to learn +long-range dependencies between words in a sentence.
+Llama models are trained on a massive dataset of text and code. The text dataset includes text from a variety of +sources, such as books, articles, and websites. The code dataset includes code from a variety of programming languages, +such as Python, Java, and C++.
+After being pre-trained on a massive dataset, Llama models can be fine-tuned for specific tasks. Fine-tuning involves +training the model on a smaller dataset of data that is relevant to the specific task.
+Llama models can be used for a variety of tasks, such as:
+* Text generation: Llama models can be used to generate text, such as poems, code, scripts, and musical pieces.
+* Translation: Llama models can be used to translate text from one language to another.
+* Summarization: Llama models can be used to summarize text.
+* Question answering: Llama models can be used to answer questions about text.
+* Code generation: Llama models can be used to generate code.
+* Natural language inference: Llama models can be used to determine the relationship between two sentences.
+
+Llama models are available for free for research and commercial use. They can be downloaded from the Hugging Face Hub.
+Llama models are still under development, and they have some limitations. For example, they can sometimes generate +incorrect or misleading text. They can also be biased, reflecting the biases that are present in the training data.
+Llama models are a promising new technology with the potential to be used for a variety of applications. Future work on +Llama models will focus on improving their accuracy, reducing their bias, and making them more robust to errors.
+Here is a table comparing the different sizes of Llama models:
+Model | +Parameters | +
---|---|
Llama 7B | +7 billion | +
Llama 13B | +13 billion | +
Llama 33B | +33 billion | +
Llama 65B | +65 billion | +
Llama 70B | +70 billion | +
from easydel import AutoEasyDeLModelForCausalLM
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ 'meta-llama/Llama-2-7b',
+ # other kwargs
+)
+
Also keep in mind that the returned config
includes .get_partition_rules(fsdp=True)
from easydel.serve import JAXServer, JAXServerConfig
+import jax
+from transformers import AutoTokenizer
+
+from easydel import AutoEasyDeLModelForCausalLM
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ 'meta-llama/Llama-2-7b',
+ # other kwargs
+)
+
+DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant and act as wanted"
+
+
+class Llama2JaxServer(JAXServer):
+ def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
+
+ system = None if system == "" else system
+ string = self.prompt_llama2_model(
+ message=prompt,
+ chat_history=history or [],
+ system_prompt=system or DEFAULT_SYSTEM_PROMPT
+ )
+ if not self.server_config.stream_tokens_for_gradio:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ ...
+ history.append([prompt, response])
+ else:
+ history.append([prompt, ""])
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens
+ ):
+ history[-1][-1] = response
+ yield "", history
+
+ return "", history
+
+ def sample_gradio_instruct(self, prompt, system, max_new_tokens, greedy):
+ string = self.prompt_llama2_model(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
+ if not self.server_config.stream_tokens_for_gradio:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ pass
+ else:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ stream=True
+ ):
+ yield "", response
+ return "", response
+
+ @staticmethod
+ def prompt_llama2_model(message: str, chat_history,
+ system_prompt: str) -> str:
+
+ do_strip = False
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+ for user_input, response in chat_history:
+ user_input = user_input.strip() if do_strip else user_input
+ do_strip = True
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+ message = message.strip() if do_strip else message
+ texts.append(f'{message} [/INST]')
+ return "".join(texts)
+
+
+server = Llama2JaxServer.from_parameters(
+ params=params,
+ model=model,
+ config_model=model.config,
+ add_params_field=True,
+ tokenizer=AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b'),
+ verbose=False,
+ do_memory_log=True,
+ server_config=JAXServerConfig()
+)
+
+server.fire() # Launch FastAPI functions
+
+shared_urls = server.launch(
+ share_chat=True,
+ share_inst=True
+)
+
+Done 😇 this method can be used for all the llama models
+ + + + + + + + + + + + + +Llama2 Models
+Llama2 Models is a family of pretrained and fine-tuned large language models (LLMs) developed by Meta AI. The models are +trained on a massive dataset of text and code, and can be used for a variety of tasks, including
+The Llama2 models are available under the Apache 2.0 license, which means that they can be freely used, modified, and +redistributed.
+Model Architecture
+The Llama2 models are based on the Transformer architecture, which is a neural network architecture that has been shown +to be very effective for NLP tasks. The models are trained using a technique called masked language modeling, which +involves predicting the missing words in a sequence of text.
+Model Sizes
+The Llama2 models come in a variety of sizes, ranging from 7 billion to 70 billion parameters. The larger models have +more capacity to learn complex patterns in language, but they are also more computationally expensive to train and +deploy.
+Fine-tuning
+The Llama2 models are pretrained on a massive dataset of text and code, but they can be further fine-tuned on a specific +task to improve their performance. Fine-tuning involves training the model on a dataset of labeled data for the specific +task.
+Use Cases
+The Llama2 models can be used for a variety of tasks, including text generation, summarization, question answering, and dialogue.
+Availability
+The Llama2 models are available through the Hugging Face Hub, and they are also supported in EasyDeL.
+Conclusion
+The Llama2 models are a powerful family of LLMs that can be used for a variety of tasks. The models are open source and +available for free, making them a valuable resource for researchers and developers.
+from easydel import AutoEasyDeLModelForCausalLM
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ 'meta-llama/Llama-2-7b',
+ # other kwargs
+)
+
+Also keep in mind that the returned config
includes a .get_partition_rules(fsdp=True) method that you can use for sharding.
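+For illustration, here is a minimal sketch of inspecting those partition rules; the exact structure of each rule may differ between EasyDeL versions, so treat the commented shape as an assumption.
+from easydel import AutoEasyDeLModelForCausalLM
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b')
+
+# The rules describe how parameters should be sharded (e.g. for FSDP);
+# they are typically handed to the trainer or server, as later examples do with custom_rule.
+partition_rules = model.config.get_partition_rules(True)
+for rule in partition_rules:
+    print(rule)  # assumed to be (parameter-name pattern, PartitionSpec) pairs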
from easydel.serve import JAXServer, JAXServerConfig
+import jax
+from transformers import AutoTokenizer
+
+from easydel import AutoEasyDeLModelForCausalLM
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ 'meta-llama/Llama-2-7b',
+ # other kwargs
+)
+
+DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant and act as wanted"
+
+
+class Llama2JaxServer(JAXServer):
+ def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
+
+ system = None if system == "" else system
+ string = self.prompt_llama2_model(
+ message=prompt,
+ chat_history=history or [],
+ system_prompt=system or DEFAULT_SYSTEM_PROMPT
+ )
+ if not self.server_config.stream_tokens_for_gradio:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ ...
+ history.append([prompt, response])
+ else:
+ history.append([prompt, ""])
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens
+ ):
+ history[-1][-1] = response
+ yield "", history
+
+ return "", history
+
+ def sample_gradio_instruct(self, prompt, system, max_new_tokens, greedy):
+ string = self.prompt_llama2_model(system_prompt=DEFAULT_SYSTEM_PROMPT, message=prompt, chat_history=[])
+ if not self.server_config.stream_tokens_for_gradio:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ ):
+ pass
+ else:
+ response = ""
+ for response, _ in self.sample(
+ string=string,
+ greedy=greedy,
+ max_new_tokens=max_new_tokens,
+ stream=True
+ ):
+ yield "", response
+ return "", response
+
+ @staticmethod
+ def prompt_llama2_model(message: str, chat_history,
+ system_prompt: str) -> str:
+
+ do_strip = False
+ texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
+ for user_input, response in chat_history:
+ user_input = user_input.strip() if do_strip else user_input
+ do_strip = True
+ texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
+ message = message.strip() if do_strip else message
+ texts.append(f'{message} [/INST]')
+ return "".join(texts)
+
+
+server = Llama2JaxServer.from_parameters(
+ params=params,
+ model=model,
+ config_model=model.config,
+ add_params_field=True,
+ tokenizer=AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b'),
+ verbose=False,
+ do_memory_log=True,
+ server_config=JAXServerConfig()
+)
+
+server.fire() # Launch FastAPI functions
+
+shared_urls = server.launch(
+ share_chat=True,
+ share_inst=True
+)
+
+Done 😇! This method can be used for all of the Llama2 models.
+In the case of using LoRA on EasyDeL models there are a few extra things that you might need to configure on your own, but a lot is handled by EasyDeL, so let's jump straight into an example of LoRA fine-tuning that uses EasyDeLXRapTure with a Mistral model and flash attention.
+from flax.core import FrozenDict
+from easydel import (
+ TrainArguments,
+ CausalLanguageModelTrainer,
+ AutoEasyDeLModelForCausalLM,
+ EasyDeLOptimizers,
+ EasyDeLSchedulers,
+ EasyDeLGradientCheckPointers,
+ EasyDeLXRapTureConfig
+)
+from datasets import load_dataset
+import flax
+from jax import numpy as jnp
+from transformers import AutoTokenizer
+
+huggingface_repo_id_or_path = "mistralai/Mistral-7B-Instruct-v0.1"
+
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(huggingface_repo_id_or_path, )
+
+max_length = 8196
+model_parameters = FrozenDict({"params": params})
+
+dtype = jnp.bfloat16
+param_dtype = jnp.bfloat16 # you can change that if you want
+
+tokenizer = AutoTokenizer.from_pretrained(
+ huggingface_repo_id_or_path,
+ trust_remote_code=True
+)
+
+model.config.add_basic_configurations(
+ attn_mechanism="flash", # Using FlashAttention
+ block_b=1,
+ block_q=1024,
+ block_k=1024,
+ block_k_major=1024,
+)
+
+tokenizer.pad_token = tokenizer.eos_token
+configs_to_initialize_model_class = {
+ "config": model.config,
+ "dtype": dtype,
+ "param_dtype": param_dtype,
+ "input_shape": (1, 1)
+}
+
+rapture = EasyDeLXRapTureConfig(
+ parameters=model_parameters,
+ lora_dim=64,
+ fully_fine_tune_parameters=["embed_tokens"], # Model layer to be fully fine tuned
+ lora_fine_tune_parameters=["q_proj", "v_proj", "k_proj", "o_proj"], # LoRA layer targets; you can set this to None
+ # for layer-only tuning or transfer learning
+ verbose=True
+)
+
+train_arguments = TrainArguments(
+ model_class=type(model),
+ model_name="EasyDeL-Lora-Example",
+ num_train_epochs=3,
+ configs_to_initialize_model_class=configs_to_initialize_model_class,
+ learning_rate=1e-4, # Using higher learning rate is recommended
+ learning_rate_end=8e-5,
+ optimizer=EasyDeLOptimizers.ADAMW, # "adamw", "lion", "adafactor" are supported
+ scheduler=EasyDeLSchedulers.LINEAR,
+ # "linear","cosine", "none" ,"warm_up_cosine" and "warm_up_linear" are supported
+ weight_decay=0.01,
+ total_batch_size=512,
+ max_training_steps=None, # None to let trainer Decide
+ do_train=True,
+ do_eval=False, # it's optional but supported
+ backend="tpu", # default backed is set to cpu, so you must define you want to use tpu cpu or gpu
+ max_length=max_length, # Note that you have to change this in the model config too
+ gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
+ sharding_array=(1, -1, 1, 1), # how to shard the model across GPUs, CPUs or TPUs; with the sharding array (1, -1, 1, 1)
+ # training will be fully FSDP and data will be shared between devices automatically
+ remove_ckpt_after_load=True,
+ gradient_accumulation_steps=1,
+ loss_re_mat="",
+ dtype=dtype,
+ param_dtype=param_dtype,
+ rapture_config=rapture,
+ merge_lora_rapture_parameters=True # merges the LoRA parameters into the original model parameters at the end of training;
+ # turning this off is still not supported and is not recommended
+)
+
+
+def ultra_chat_prompting_process(
+ data_chunk
+):
+ user_part = [
+ chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "user"
+ ]
+ assistant_part = [
+ chunk["content"] for chunk in data_chunk["messages"] if chunk["role"] == "assistant"
+ ]
+
+ prompt = ""
+
+ for uc, ac in zip(user_part, assistant_part):
+ prompt += f"<|user|>\n{uc}</s>\n<|assistant|>\n{ac}</s>\n"
+
+ return {"prompt": prompt}
+
+
+tokenization_process = lambda data_chunk: tokenizer(
+ data_chunk["prompt"],
+ add_special_tokens=False,
+ max_length=max_length,
+ padding="max_length"
+)
+
+dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
+dataset_train = dataset["train_gen"].map(ultra_chat_prompting_process, num_proc=12)
+dataset_train = dataset_train.map(
+ tokenization_process,
+ num_proc=12,
+ remove_columns=dataset_train.column_names
+)
+
+# you can do the same for evaluation process dataset
+
+trainer = CausalLanguageModelTrainer(
+ train_arguments,
+ dataset_train,
+ checkpoint_path=None
+)
+
+output = trainer.train() # you should no longer pass the parameters to trainer.train() when
+# you are using LoRA or transfer learning
+print(f"Hey! Here's where your model was saved: {output.checkpoint_path}")
+
+ Mistral LLM models
+Mistral AI is a French startup that develops large language models (LLMs). Mistral's first LLM, Mistral-7B-v0.1, was released in October 2023. It is a 7 billion parameter decoder-only LM with a number of architectural innovations, including sliding window attention, grouped query attention, and a byte-fallback BPE tokenizer. Mistral-7B-v0.1 achieves state-of-the-art performance among models of its size on a range of NLP benchmarks.
+Mistral-7B-v0.1 is released under the Apache 2.0 license and is freely available for download and use. The company is also working on developing larger and more powerful LLMs.
+Mistral's LLMs have been praised for their ability to generate creative and informative text, as well as their ability +to perform a wide range of NLP tasks, such as translation, question answering, and summarization. However, some concerns +have been raised about the potential for Mistral's LLMs to be used to generate harmful content, such as instructions on +how to make bombs or how to self-harm.
+Overall, Mistral AI is a promising startup in the field of LLM development. Its LLMs have the potential to be used in a +wide range of applications, such as customer service, education, and creative writing. However, it is important to be +aware of the potential risks associated with using LLMs, such as the risk of generating harmful content.
+Mistral LLM models
+Mistral LLM models are a set of large language models (LLMs) developed by Mistral AI, a French startup. Mistral's LLMs are trained on massive datasets of text and code, and can be used to perform a variety of NLP tasks, including text generation, translation, question answering, and summarization.
+Mistral-7B-v0.1 is the first LLM released by Mistral AI. It is a 7 billion parameter decoder-only LM with a number of architectural innovations, including sliding window attention, grouped query attention, and a byte-fallback BPE tokenizer. Mistral-7B-v0.1 achieves state-of-the-art performance among models of its size on a range of NLP benchmarks.
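+To make the sliding window attention idea more concrete, here is a minimal NumPy sketch of a causal attention mask restricted to a fixed window; it is only an illustration and is not Mistral's or EasyDeL's actual implementation.
+import numpy as np
+
+def sliding_window_causal_mask(seq_len: int, window: int) -> np.ndarray:
+    # Position i may attend to positions j with i - window < j <= i.
+    i = np.arange(seq_len)[:, None]
+    j = np.arange(seq_len)[None, :]
+    return (j <= i) & (j > i - window)
+
+print(sliding_window_causal_mask(seq_len=6, window=3).astype(int))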
+To use a Mistral LLM model, load the model and its tokenizer, then call the generate() method to generate text, translate languages, answer questions, or perform other NLP tasks.
+Mistral LLM models are still under development, but they have the potential to be used in a wide range of applications. If you are interested in using Mistral's LLMs, please visit the Mistral AI website: https://mistral.ai/ for more information.
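+As a quick illustration of that workflow (using the standard Hugging Face transformers API rather than anything Mistral- or EasyDeL-specific; note that this loads the full 7B model into memory), generation looks like this:
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+inputs = tokenizer("The three primary colors are", return_tensors="pt")
+output_ids = model.generate(**inputs, max_new_tokens=32)
+print(tokenizer.decode(output_ids[0], skip_special_tokens=True))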
+Using Mistral models is the same as using any of the other models in the EasyDeL collection, but let's take a look at how we can train or fine-tune a Mistral model.
+from easydel.trainer import TrainArguments, CausalLanguageModelTrainer
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from jax import numpy as jnp
+import flax
+import easydel
+from easydel import (
+ AutoEasyDeLModelForCausalLM,
+ EasyDeLOptimizers,
+ EasyDeLSchedulers,
+ EasyDeLGradientCheckPointers
+)
+
+model_huggingface_repo_id = 'mistralai/Mistral-7B-v0.1'
+dataset_train = load_dataset('<TOKENIZED_MISTRAL_DATASET_AT_HUGGINGFACE>')
+tokenizer = AutoTokenizer.from_pretrained(model_huggingface_repo_id, trust_remote_code=True)
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(model_huggingface_repo_id)
+config = model.config
+config.freq_max_position_embeddings = config.max_position_embeddings # 32768
+config.max_position_embeddings = 4096 # Let use context length of 4096 for training
+config.c_max_position_embeddings = config.max_position_embeddings
+
+max_sequence_length = config.max_position_embeddings
+
+train_args = TrainArguments(
+ model_class=easydel.FlaxMistralForCausalLM,
+ configs_to_initialize_model_class={
+ 'config': config,
+ 'dtype': jnp.bfloat16,
+ 'param_dtype': jnp.bfloat16,
+ 'input_shape': (1, 1)
+ },
+ custom_rule=config.get_partition_rules(True),
+ model_name='Test',
+ num_train_epochs=2,
+ learning_rate=4e-5,
+ learning_rate_end=5e-6,
+ optimizer=EasyDeLOptimizers.ADAMW,
+ scheduler=EasyDeLSchedulers.WARM_UP_COSINE,
+ weight_decay=0.01,
+ total_batch_size=32,
+ max_training_steps=None,
+ do_train=True,
+ do_eval=False,
+ backend='tpu',
+ max_sequence_length=max_sequence_length,
+ gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
+ sharding_array=(1, -1, 1, 1),
+ gradient_accumulation_steps=8,
+ remove_ckpt_after_load=True,
+ ids_to_pop_from_dataset=['token_type_ids'],
+ loss_re_mat="",
+ dtype=jnp.bfloat16
+)
+
+trainer = CausalLanguageModelTrainer(
+ train_args,
+ dataset_train['train'],
+ checkpoint_path=None
+)
+
+output = trainer.train(flax.core.FrozenDict({'params': params}))
+# And here is where EasyDeL goes brrrrrr and starts training
+
+ MosaicMPT Models
+MosaicMPT Models is a family of large language models (LLMs) developed by MosaicML. The models are trained on a massive dataset of text and code, and can be used for a variety of natural language tasks.
+The MosaicMPT models are available under the Apache 2.0 license, which means that they can be freely used, modified, and +redistributed.
+Model Architecture
+The MosaicMPT models are based on a decoder-only Transformer architecture, a neural network architecture that has been shown to be very effective for NLP tasks. The models are trained using causal language modeling, which involves predicting the next token in a sequence of text.
+Model Sizes
+The MosaicMPT models come in a variety of sizes, ranging from 7 billion to 30 billion parameters. The larger models have more capacity to learn complex patterns in language, but they are also more computationally expensive to train and deploy.
+MosaicPretrainedTransformer (MPT) Architecture
+The MosaicPretrainedTransformer (MPT) architecture is a modified transformer architecture that is optimized for efficient training and inference. Its changes include ALiBi attention biases in place of learned positional embeddings and performance-optimized attention implementations.
+Thanks to these modifications, MPT models can be trained with high throughput efficiency and stable convergence. MPT +models can also be served efficiently with both standard HuggingFace pipelines and NVIDIA's FasterTransformer.
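+As one concrete example of those modifications, MPT replaces learned positional embeddings with ALiBi attention biases. The following is a minimal NumPy sketch of how such per-head biases can be constructed; it is an illustration, not MosaicML's implementation, and the head count and sequence length are assumptions.
+import numpy as np
+
+def alibi_bias(num_heads: int, seq_len: int) -> np.ndarray:
+    # Per-head slopes form a geometric sequence: 2^(-8/num_heads), 2^(-16/num_heads), ...
+    slopes = 2.0 ** (-8.0 * np.arange(1, num_heads + 1) / num_heads)
+    # Relative distance between query position i and key position j (only j <= i matters;
+    # future positions are clipped to 0 because they are masked out anyway).
+    distance = np.arange(seq_len)[None, :] - np.arange(seq_len)[:, None]
+    distance = np.minimum(distance, 0)
+    return slopes[:, None, None] * distance[None, :, :]  # shape: (num_heads, q_len, k_len)
+
+bias = alibi_bias(num_heads=8, seq_len=16)  # added to the attention logits before softmax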
+Use Cases
+The MosaicMPT models can be used for a variety of tasks, including text and code generation, summarization, and question answering.
+Availability
+The MosaicMPT models are available through the Hugging Face Hub. The models are also available in the TensorFlow Hub, +the PyTorch Hub and EasyDeL.
+Conclusion
+The MosaicMPT models are a powerful family of LLMs that can be used for a variety of tasks. The models are open source +and available for free, making them a valuable resource for researchers and developers.
+Quantization, in the context of deep learning, is the process of constraining the number of bits used to represent the weights and biases of the model.
+Weights and biases are the numbers that the model learns during training and that we need for backpropagation.
+In 8-bit quantization, each weight or bias is represented using only 8 bits as opposed to the typical 32 bits used in +single-precision floating-point format (float32).
+The primary advantage of using 8-bit quantization is the reduction in model size and memory usage. Here's a simple +explanation:
+A float32 number takes up 32 bits of memory. An 8-bit quantized number takes up only 8 bits of memory. So, theoretically, you can fit 4 times more 8-bit quantized numbers into the same memory space as float32 numbers. This allows you to load larger models into GPU memory or use smaller GPUs that might not have been able to handle the model otherwise.
+The amount of memory used by an integer in a computer system is directly related to the number of bits used to represent +that integer.
+Memory Usage for 8-bit Integer +An 8-bit integer uses 8 bits of memory.
+Memory Usage for 32-bit Integer +A 32-bit integer uses 32 bits of memory.
+Conversion to Bytes +To convert these to bytes (since memory is often measured in bytes): an 8-bit integer takes 1 byte, while a 32-bit integer takes 4 bytes.
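+As a back-of-the-envelope check of those numbers, here is a tiny sketch; the 7-billion parameter count is only an illustrative assumption.
+params = 7_000_000_000               # assumed model size, e.g. a 7B-parameter LLM
+
+float32_bytes = params * 4           # 32 bits = 4 bytes per parameter
+int8_bytes = params * 1              # 8 bits = 1 byte per parameter
+
+print(f"float32: {float32_bytes / 1024**3:.1f} GiB")  # ~26.1 GiB
+print(f"int8:    {int8_bytes / 1024**3:.1f} GiB")     # ~6.5 GiB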
+In the case of serving models or using them with JAX,
the easiest and best way you can find
+is EasyDeL (you can explore more if you want); it gives you four ways to use models.
Let's assume we want to run a 7B model on only 12 GB of VRAM, so let's jump straight into coding.
+As an example, we will run Qwen/Qwen1.5-7B-Chat.
from jax import numpy as jnp
+from easydel import AutoEasyDeLModelForCausalLM, create_generate_function
+
+from transformers import AutoTokenizer, GenerationConfig
+
+import pickle
+import torch
+
+repo_id = "Qwen/Qwen1.5-7B-Chat"
+model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
+ repo_id,
+ sharding_axis_dims=(1, 1, 1, -1),
+ config_kwargs=dict(
+ gradient_checkpointing="",
+ use_scan_mlp=False, # Turn this on if you want to go beyond 32K sequence length.
+ shard_attention_computation=True,
+ use_sharded_kv_caching=True
+ ),
+ dtype=jnp.float16,
+ param_dtype=jnp.float16,
+ auto_shard_params=True,
+ load_in_8bit=True,
+ torch_dtype=torch.float16,
+ device_map="cpu" # this one will be passed to transformers.AutoModelForCausalLM
+)
+
+# params is now an 8-bit quantized pytree.
+
+tokenizer = AutoTokenizer.from_pretrained(repo_id)
+mesh = model.config.jax_mesh()
+
+gen_fn = create_generate_function(
+ model,
+ GenerationConfig(
+ do_sample=True,
+ max_new_tokens=512,
+ pad_token_id=tokenizer.pad_token_id,
+ bos_token_id=tokenizer.bos_token_id,
+ temperature=0.2,
+ top_p=0.95,
+ top_k=10,
+ num_beams=1
+ ),
+ {"params": params},
+ return_prediction_only=True
+)
+
+tokenizer.padding_side = "left"
+encoded = tokenizer.apply_chat_template(
+ [{"role": "user", "content": "generate an story about stars"}],
+ return_tensors="np",
+ return_dict=True,
+ max_length=512,
+ padding="max_length",
+ add_generation_prompt=True
+)
+
+rep = 1 # if you are using FSDP instead of sequence sharding, change this to your FSDP mesh dimension
+input_ids, attention_mask = encoded.input_ids.repeat(rep, 0), encoded.attention_mask.repeat(rep, 0)
+with mesh:
+ response = gen_fn(
+ {"params": params},
+ input_ids,
+ attention_mask
+ )
+
+ response_string = tokenizer.decode(response[0], skip_special_tokens=True)
+print(
+ f"Model Response:\n{response_string}"
+)
+
+# you want to save these quantized parameters for later?
+
+pickle.dump((model, params, tokenizer), open("EasyDeL-Qwen7B-Chat", "wb"))
+
+# And load that like this ;)
+
+(model, params, tokenizer) = pickle.load(open("EasyDeL-Qwen7B-Chat", "rb"))
+
+
+from jax import numpy as jnp
+from jax.sharding import PartitionSpec
+from easydel import JAXServer, JAXServerConfig
+
+import torch
+
+server_config = JAXServerConfig(
+ mesh_axes_shape=(1, 1, 1, -1),
+ generation_ps=PartitionSpec(("dp", "fsdp"), "sp"),
+ max_sequence_length=1024,
+ max_new_tokens=4096,
+ max_compile_tokens=128
+)
+
+server = JAXServer.from_torch_pretrained(
+ pretrained_model_name_or_path="Qwen/Qwen1.5-7B-Chat",
+ server_config=server_config,
+ sharding_axis_dims=(1, 1, 1, -1),
+ model_config_kwargs=dict(
+ gradient_checkpointing="",
+ use_scan_mlp=False,
+ shard_attention_computation=True,
+ use_sharded_kv_caching=True
+ ),
+ dtype=jnp.float16,
+ param_dtype=jnp.float16,
+ auto_shard_params=True,
+ load_in_8bit=True,
+ torch_dtype=torch.float16,
+ device_map="cpu" # this one will be passed to transformers.AutoModelForCausalLM
+)
+
+conversation = []
+while True:
+ conversation.append({"role": "user", "content": input("\n## User: ")})
+ printed_response_length = 0
+ print("\n## Assistant : ", end="")
+ response = ""
+ for response, used_tokens in server.sample(
+ server.tokenizer.apply_chat_template(
+ conversation,
+ tokenize=False
+ )
+ ):
+ print(response[printed_response_length:], end="")
+ printed_response_length = len(response)
+ conversation.append({"role": "assistant", "content": response})
+
+
+Or you can launch the pre-built Gradio interface to serve it:
+server.gradio_inference.launch()
+
+
+ PyTorchServer
PyTorchServer is one of the utilities offered by EasyDeL; it helps with hosting, using, and sampling from LLMs.
+It is also hackable, so you can override its methods with your own, and it supports both mid-level and high-level APIs.
+It also gives you pre-built, ready-to-use Gradio Chat and Instruct pages.
It works with models that have transformers.PreTrainedModel
as their parent :) Documents are on the way, amigos...