v4.38: Gemma, Depth Anything, Stable LM; Static Cache, HF Quantizer, AQLM

@LysandreJik LysandreJik released this 21 Feb 13:40
· 637 commits to main since this release

New model additions

💎 Gemma 💎

Gemma is a new open-source language model series from Google AI that comes in 2B and 7B variants. The release includes both pre-trained and instruction fine-tuned versions, and you can use them via AutoModelForCausalLM, GemmaForCausalLM, or the pipeline interface!

Read more about it in the Gemma release blogpost: https://hf.co/blog/gemma

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto", torch_dtype=torch.float16)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
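
The pipeline interface mentioned above can be sketched as follows. This is a minimal sketch rather than an official snippet: the helper name generate_with_pipeline and the max_new_tokens value are assumptions made for illustration, and running it requires access to the Gemma checkpoint.

```python
def generate_with_pipeline(prompt, model_id="google/gemma-2b", max_new_tokens=64):
    """Generate text through the text-generation pipeline (hypothetical helper)."""
    # Imported lazily so the helper can be defined before the libraries are loaded.
    import torch
    from transformers import pipeline

    generator = pipeline(
        "text-generation",
        model=model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    # The pipeline returns a list with one dict per generated sequence.
    result = generator(prompt, max_new_tokens=max_new_tokens)
    return result[0]["generated_text"]
```

Calling generate_with_pipeline("Write me a poem about Machine Learning.") returns the prompt followed by the generated continuation.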

You can also combine the model with Flash Attention 2, SDPA, static cache, and the quantization API for further optimizations!

  • Flash Attention 2
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", device_map="auto", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
  • bitsandbytes-4bit
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", device_map="auto", load_in_4bit=True
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)
  • Static Cache
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", device_map="auto"
)

model.generation_config.cache_implementation = "static"

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids)

Depth Anything Model

The Depth Anything model was proposed in Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data by Lihe Yang, Bingyi Kang, Zilong Huang, Xiaogang Xu, Jiashi Feng, Hengshuang Zhao. Depth Anything is based on the DPT architecture, trained on ~62 million images, obtaining state-of-the-art results for both relative and absolute depth estimation.
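
A minimal usage sketch with the depth-estimation pipeline is shown below. The checkpoint name LiheYoung/depth-anything-small-hf and the helper name estimate_depth are assumptions that do not come from these notes, so swap in whichever Depth Anything checkpoint you actually use.

```python
def estimate_depth(image, model_id="LiheYoung/depth-anything-small-hf"):
    """Return a depth map image for a PIL image, a file path, or a URL (hypothetical helper)."""
    # Imported lazily so the helper can be defined before the library is loaded.
    from transformers import pipeline

    depth_estimator = pipeline(task="depth-estimation", model=model_id)
    # The pipeline output is a dict; "depth" holds the prediction as a PIL image.
    return depth_estimator(image)["depth"]
```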

Stable LM

StableLM 3B 4E1T was proposed in StableLM 3B 4E1T: Technical Report by Stability AI and is the first model in a series of multi-epoch pre-trained language models.

StableLM 3B 4E1T is a decoder-only base language model pre-trained on 1 trillion tokens of diverse English and code datasets for four epochs. The model architecture is transformer-based with partial Rotary Position Embeddings, SwiGLU activation, LayerNorm, etc.

The team also provides StableLM Zephyr 3B, an instruction fine-tuned version of the model that can be used for chat-based applications.
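
To illustrate the partial Rotary Position Embeddings mentioned above, here is a toy sketch in plain Python. It is not the library's implementation: the function name partial_rope is hypothetical, and the default partial_rotary_factor of 0.25 is an assumption about the model configuration. Only the first share of each head vector is rotated by a position-dependent angle, and the remaining dimensions pass through unchanged.

```python
import math


def partial_rope(vec, position, partial_rotary_factor=0.25, base=10000.0):
    """Rotate the first share of `vec` by position-dependent angles; keep the rest."""
    head_dim = len(vec)
    rotary_dim = int(head_dim * partial_rotary_factor)
    rotary_dim -= rotary_dim % 2  # rotations act on pairs, so keep the count even
    half = rotary_dim // 2
    out = list(vec)
    for i in range(half):
        # One frequency per pair, as in standard rotary embeddings.
        theta = position * base ** (-2.0 * i / rotary_dim)
        cos, sin = math.cos(theta), math.sin(theta)
        x1, x2 = vec[i], vec[i + half]
        out[i] = x1 * cos - x2 * sin
        out[i + half] = x1 * sin + x2 * cos
    return out


query = [1.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
rotated = partial_rope(query, position=3)
print(rotated[:2])  # rotated part
print(rotated[2:])  # untouched part, identical to query[2:]
```

With a head dimension of 8 and a factor of 0.25, only two dimensions carry positional information, which is the trade-off that partial rotary embeddings make.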

⚡️ Static cache was introduced in the following PRs ⚡️

A static past key-value cache allows LlamaForCausalLM's forward pass to be compiled using torch.compile!
This means that CUDA graphs can be used for inference, which speeds up the decoding step by 4x!
A forward pass of Llama2 7B takes around 10.5 ms to run with this on an A100, equivalent to TGI performance! ⚡️

⚠️ Support for generate is not included yet. This feature is experimental and subject to changes in subsequent releases.

from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
import torch
import os

# compilation triggers multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "true"

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    device_map="auto",
    torch_dtype=torch.float16
)

# set up the static cache in advance of using the model
model._setup_cache(StaticCache, max_batch_size=1, max_cache_len=128)

# trigger compilation!
compiled_model = torch.compile(model, mode="reduce-overhead", fullgraph=True)

# run the model as usual
input_text = "A few facts about the universe: "
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda").input_ids
model_outputs = compiled_model(input_ids)
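
Why does a static cache help compilation? CUDA graphs replay a recorded sequence of kernels, which works best when tensor shapes stay the same from one step to the next. The toy sketch below is plain Python, not the library's StaticCache, and every class and function name in it is hypothetical. It contrasts a preallocated cache that is written in place with a cache that grows at every decoding step.

```python
class ToyStaticCache:
    """Preallocated key buffer; each decoding step overwrites one slot in place."""

    def __init__(self, max_cache_len, head_dim):
        self.keys = [[0.0] * head_dim for _ in range(max_cache_len)]

    def update(self, position, key):
        self.keys[position] = list(key)  # in-place write, buffer size unchanged
        return self.keys


class ToyDynamicCache:
    """Growing key list; each decoding step appends, so the shape changes."""

    def __init__(self):
        self.keys = []

    def update(self, key):
        self.keys.append(list(key))
        return self.keys


def shape(rows):
    """Return (number of rows, row length) for a list of equal-length rows."""
    return (len(rows), len(rows[0]))


static_cache = ToyStaticCache(max_cache_len=128, head_dim=4)
dynamic_cache = ToyDynamicCache()

static_shapes, dynamic_shapes = [], []
for step in range(3):
    new_key = [float(step)] * 4
    static_shapes.append(shape(static_cache.update(step, new_key)))
    dynamic_shapes.append(shape(dynamic_cache.update(new_key)))

print(static_shapes)   # the same shape at every step
print(dynamic_shapes)  # a different shape at every step
```

The price of the fixed shape is that max_cache_len has to be chosen up front, as in the _setup_cache call above.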

Quantization

🧼 HF Quantizer 🧼

HfQuantizer makes it easy for quantization method researchers and developers to add inference and / or quantization support in 🤗 transformers. If you are interested in adding the support for new methods, please refer to this documentation page: https://huggingface.co/docs/transformers/main/en/hf_quantizer

⚡️AQLM ⚡️

AQLM is a new quantization method that enables 2-bit precision with no performance degradation. Check out this demo of how to run Mixtral in 2-bit on a free-tier Google Colab instance: https://huggingface.co/posts/ybelkada/434200761252287
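
A rough back-of-the-envelope estimate shows why 2-bit weights matter for a model of this size. The sketch assumes a total of about 46.7 billion parameters for Mixtral 8x7B and ignores codebooks, activations, and the KV cache, so treat the numbers as a lower bound rather than a measurement.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory needed for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Assumption: approximate total parameter count of Mixtral 8x7B.
MIXTRAL_PARAMS = 46.7e9

fp16_gb = weight_memory_gb(MIXTRAL_PARAMS, bits_per_param=16)
two_bit_gb = weight_memory_gb(MIXTRAL_PARAMS, bits_per_param=2)

print(f"float16 weights: {fp16_gb:.1f} GB")
print(f"2-bit weights:   {two_bit_gb:.1f} GB")
```

Under these assumptions the weights shrink by a factor of eight, from more than 90 GB to under 12 GB.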

🧼 Moving canonical repositories 🧼

The canonical repositories on the Hugging Face Hub (models that did not have an organization, like bert-base-cased) have been moved under organizations.

You can find the entire list of models moved here: https://huggingface.co/collections/julien-c/canonical-models-65ae66e29d5b422218567567

Redirection has been set up so that your code continues working even if you keep calling the previous paths. We still encourage you to update your code to use the new links, however, so that it is entirely future-proof.
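
As an illustration of the update, the sketch below maps a handful of old names to their organization-scoped ids. The helper new_repo_id and the MOVED_REPOS table are hypothetical and cover only a few checkpoints; the collection linked above remains the authoritative list.

```python
# A few of the moved checkpoints (not exhaustive).
MOVED_REPOS = {
    "bert-base-cased": "google-bert/bert-base-cased",
    "gpt2": "openai-community/gpt2",
    "roberta-base": "FacebookAI/roberta-base",
    "t5-small": "google-t5/t5-small",
    "distilbert-base-uncased": "distilbert/distilbert-base-uncased",
}


def new_repo_id(repo_id):
    """Return the organization-scoped id when known, otherwise the input unchanged."""
    return MOVED_REPOS.get(repo_id, repo_id)


print(new_repo_id("bert-base-cased"))
print(new_repo_id("google/gemma-2b"))  # already scoped, returned as-is
```

The returned id can be passed to from_pretrained in place of the old name.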

Flax Improvements 🚀

The Mistral model was added to the library in Flax.

TensorFlow Improvements 🚀

With Keras 3 becoming the standard version of Keras in TensorFlow 2.16, we've made some internal changes to maintain compatibility. We now have full compatibility with TF 2.16 as long as the tf-keras compatibility package is installed. We've also taken the opportunity to do some cleanup - in particular, the objects like BatchEncoding that are returned by our tokenizers and processors can now be directly passed to Keras methods like model.fit(), which should simplify a lot of code and eliminate a long-standing source of annoyances.
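
The simplification could look like the sketch below, where the tokenizer output goes straight into model.fit() without being converted to a dict first. The helper name, the checkpoint, and the two-label setup are assumptions made for illustration, and the snippet has not been run against every TensorFlow version.

```python
def fit_on_batch_encoding(texts, labels, model_id="google-bert/bert-base-cased"):
    """Fine-tune for one epoch, passing the BatchEncoding directly to Keras (hypothetical helper)."""
    # Imported lazily so the helper can be defined before the libraries are loaded.
    import numpy as np
    from transformers import AutoTokenizer, TFAutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = TFAutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

    # `batch` is a BatchEncoding; no dict(batch) conversion is needed anymore.
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="np")

    # No loss is passed: the model falls back to its built-in loss.
    model.compile(optimizer="adam")
    model.fit(batch, np.array(labels), epochs=1)
    return model
```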

Pre-Trained backbone weights 🚀

Pretrained backbones can now be loaded into a new model in which all other weights are randomly initialized. Note: validation checks are still in place when creating a config, so passing use_pretrained_backbone=True to the config constructor will raise an error. You can override this by setting config.use_pretrained_backbone = True after creating the config. However, this is not yet guaranteed to be fully backwards compatible.

from transformers import MaskFormerConfig, MaskFormerModel

config = MaskFormerConfig(
    use_pretrained_backbone=False,
    backbone="microsoft/resnet-18"
)
config.use_pretrained_backbone = True
# Both models have resnet-18 backbone weights and all other weights randomly
# initialized
model_1 = MaskFormerModel(config)
model_2 = MaskFormerModel(config)

A new helper function, load_backbone, loads a backbone either from the backbone's own model config (e.g. ResNetConfig) or from a model config that contains backbone information. This enables cleaner modeling files and cross-loading between timm and transformers backbones.

from transformers import ResNetConfig, MaskFormerConfig
from transformers.utils.backbone_utils import load_backbone

# Resnet defines the backbone model to load
config = ResNetConfig()
backbone = load_backbone(config)

# Maskformer config defines a model which uses a resnet backbone
config = MaskFormerConfig(use_timm_backbone=True, backbone="resnet18")
backbone = load_backbone(config)

config = MaskFormerConfig(backbone_config=ResNetConfig())
backbone = load_backbone(config)

The backbone documentation was also updated: API references were added, supported backbones are listed, examples were refreshed, and information was clarified and moved to better reflect usage.

Image Processor work 🚀

Bugfixes and improvements 🚀

Significant community contributions

The following contributors have made significant changes to the library over the last release:

  • @nakranivaibhav
    • Improved type hinting for all attention parameters (#28479)
    • Adds LlamaForQuestionAnswering class in modeling_llama.py along with AutoModel Support (#28777)
  • @khipp
    • Fix input data file extension in examples (#28741)
    • [Docs] Fix spelling and grammar mistakes (#28825)
    • [Docs] Update project names and links in awesome-transformers (#28878)
    • [Docs] Fix backticks in inline code and documentation links (#28875)
    • [Docs] Add missing language options and fix broken links (#28852)
    • [Docs] Fix placement of tilde character (#28913)
    • [Docs] Revert translation of '@slow' decorator (#28912)
    • [Docs] Fix broken links and syntax issues (#28918)
    • [i18n-de] Translate README.md to German (#28933)
    • [Docs] Add language identifiers to fenced code blocks (#28955)
    • [i18n-de] Translate CONTRIBUTING.md to German (#28954)
  • @ThibaultLengagne
    • Add French translation: french README.md (#28696)
  • @poedator
    • HfQuantizer class for quantization-related stuff in modeling_utils.py (#26610)
  • @kiansierra
  • @hackyon
    • Adding [T5/MT5/UMT5]ForTokenClassification (#28443)
    • Always initialize tied output_embeddings if it has a bias term (#28947)
    • Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948)
    • [Phi] Add support for sdpa (#29108)
  • @SangbumChoi
    • enable gradient checkpointing in DetaObjectDetection and add tests in Swin/Donut_Swin (#28615)
    • Add cuda_custom_kernel in DETA (#28989)
  • @rajveer43
    • Add models from deit (#28302)
  • @jon-tow