1 change: 1 addition & 0 deletions README.md
@@ -328,6 +328,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -63,6 +63,7 @@
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer)** (from ZhuiyiTechnology), released together with the paper [RoFormer: Enhanced Transformer with Rotary Position Embedding](https://arxiv.org/abs/2104.09864) by Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu.
1. **[SegFormer](https://huggingface.co/docs/transformers/model_doc/segformer)** (from NVIDIA) released with the paper [SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers](https://arxiv.org/abs/2105.15203) by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo.
1. **[Segment Anything](https://huggingface.co/docs/transformers/model_doc/sam)** (from Meta AI) released with the paper [Segment Anything](https://arxiv.org/pdf/2304.02643v1.pdf) by Alexander Kirillov, Eric Mintun, Nikhila Ravi, Hanzi Mao, Chloe Rolland, Laura Gustafson, Tete Xiao, Spencer Whitehead, Alex Berg, Wan-Yen Lo, Piotr Dollar, Ross Girshick.
1. **[SigLIP](https://huggingface.co/docs/transformers/main/model_doc/siglip)** (from Google AI) released with the paper [Sigmoid Loss for Language Image Pre-Training](https://arxiv.org/abs/2303.15343) by Xiaohua Zhai, Basil Mustafa, Alexander Kolesnikov, Lucas Beyer.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1. **[SqueezeBERT](https://huggingface.co/docs/transformers/model_doc/squeezebert)** (from Berkeley) released with the paper [SqueezeBERT: What can computer vision teach NLP about efficient neural networks?](https://arxiv.org/abs/2006.11316) by Forrest N. Iandola, Albert E. Shaw, Ravi Krishna, and Kurt W. Keutzer.
19 changes: 14 additions & 5 deletions scripts/supported_models.py
@@ -745,11 +745,20 @@
'distilroberta-base',
],
},
# 'sam': [
# 'facebook/sam-vit-base',
# 'facebook/sam-vit-large',
# 'facebook/sam-vit-huge',
# ],
'sam': {
# Mask generation
'mask-generation': [
# SAM
'facebook/sam-vit-base',
'facebook/sam-vit-large',
'facebook/sam-vit-huge',
'wanglab/medsam-vit-base',

# SlimSAM
'nielsr/slimsam-50-uniform',
'nielsr/slimsam-77-uniform',
],
},
'segformer': {
# Image segmentation
'image-segmentation': [
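The new `mask-generation` entries above cover the original SAM checkpoints, the medical-imaging fine-tune MedSAM, and the distilled SlimSAM variants. As a minimal sketch of how any of them would be used once converted to ONNX (the checkpoint name `Xenova/slimsam-77-uniform` is an assumption; substitute whichever converted export is actually available):

```javascript
import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';

// NOTE: checkpoint name is an assumption; any converted SAM/SlimSAM export should work here
const model = await SamModel.from_pretrained('Xenova/slimsam-77-uniform');
const processor = await AutoProcessor.from_pretrained('Xenova/slimsam-77-uniform');

const raw_image = await RawImage.read('https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png');
const inputs = await processor(raw_image, [[[450, 600]]]); // a single 2D point prompt

const outputs = await model(inputs);
const masks = await processor.post_process_masks(
    outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes
);
```

The smaller SlimSAM exports are intended as drop-in replacements for `Xenova/sam-vit-base` in the `SamModel` example further down.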
143 changes: 135 additions & 8 deletions src/models.js
@@ -95,6 +95,7 @@ const MODEL_TYPES = {
Seq2Seq: 2,
Vision2Seq: 3,
DecoderOnly: 4,
MaskGeneration: 5,
}
//////////////////////////////////////////////////

@@ -771,6 +772,13 @@ export class PreTrainedModel extends Callable {
getModelJSON(pretrained_model_name_or_path, 'generation_config.json', false, options),
]);

} else if (modelType === MODEL_TYPES.MaskGeneration) {
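// Mask-generation models (e.g. SAM) are split into two ONNX sub-models: a heavy vision
// encoder that is run once per image, and a lightweight prompt encoder + mask decoder
// that is run once per prompt. Both sessions are loaded here.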
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
constructSession(pretrained_model_name_or_path, 'vision_encoder', options),
constructSession(pretrained_model_name_or_path, 'prompt_encoder_mask_decoder', options),
]);

} else if (modelType === MODEL_TYPES.EncoderDecoder) {
info = await Promise.all([
AutoConfig.from_pretrained(pretrained_model_name_or_path, options),
@@ -4242,12 +4250,130 @@ export class YolosObjectDetectionOutput extends ModelOutput {

//////////////////////////////////////////////////
export class SamPreTrainedModel extends PreTrainedModel { }

/**
* Segment Anything Model (SAM) for generating segmentation masks, given an input image
* and optional 2D location and bounding boxes.
*
* **Example:** Perform mask generation w/ `Xenova/sam-vit-base`.
* ```javascript
* import { SamModel, AutoProcessor, RawImage } from '@xenova/transformers';
*
* const model = await SamModel.from_pretrained('Xenova/sam-vit-base');
* const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
*
* const img_url = 'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png';
* const raw_image = await RawImage.read(img_url);
* const input_points = [[[450, 600]]]; // 2D location of a window in the image
*
* const inputs = await processor(raw_image, input_points);
* const outputs = await model(inputs);
*
* const masks = await processor.post_process_masks(outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
* // [
* // Tensor {
* // dims: [ 1, 3, 1764, 2646 ],
* // type: 'bool',
* // data: Uint8Array(14002632) [ ... ],
* // size: 14002632
* // }
* // ]
* const scores = outputs.iou_scores;
* // Tensor {
* // dims: [ 1, 1, 3 ],
* // type: 'float32',
* // data: Float32Array(3) [
* // 0.8892380595207214,
* // 0.9311248064041138,
* // 0.983696699142456
* // ],
* // size: 3
* // }
* ```
*/
export class SamModel extends SamPreTrainedModel {
/**
* Creates a new instance of the `SamModel` class.
* @param {Object} config The configuration object specifying the hyperparameters and other model settings.
* @param {Object} vision_encoder The ONNX session containing the vision encoder model.
* @param {any} prompt_encoder_mask_decoder The ONNX session containing the prompt encoder and mask decoder model.
*/
constructor(config, vision_encoder, prompt_encoder_mask_decoder) {
super(config, vision_encoder);
this.prompt_encoder_mask_decoder = prompt_encoder_mask_decoder;
}

/**
* Compute image embeddings and positional image embeddings, given the pixel values of an image.
* @param {Object} model_inputs Object containing the model inputs.
* @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `SamProcessor`.
* @returns {Promise<{ image_embeddings: Tensor, image_positional_embeddings: Tensor }>} The image embeddings and positional image embeddings.
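*
* **Example:** Encode an image once and reuse the embeddings for several prompts (a minimal
* sketch; it assumes the `model`, `processor` and `raw_image` from the `SamModel` example above,
* and the second prompt coordinate is purely illustrative).
* ```javascript
* // Run the heavy vision encoder a single time
* const image_inputs = await processor(raw_image, [[[450, 600]]]);
* const embeddings = await model.get_image_embeddings(image_inputs);
*
* // For each prompt, only the lightweight prompt encoder + mask decoder is executed,
* // because `forward` skips the vision encoder when embeddings are already provided.
* for (const input_points of [[[[450, 600]]], [[[1350, 500]]]]) {
*     const inputs = await processor(raw_image, input_points);
*     const outputs = await model({ ...inputs, ...embeddings });
*     const masks = await processor.post_process_masks(
*         outputs.pred_masks, inputs.original_sizes, inputs.reshaped_input_sizes);
* }
* ```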
*/
async get_image_embeddings({ pixel_values }) {
// in:
// - pixel_values: tensor.float32[batch_size,3,1024,1024]
//
// out:
// - image_embeddings: tensor.float32[batch_size,256,64,64]
// - image_positional_embeddings: tensor.float32[batch_size,256,64,64]
return await encoderForward(this, { pixel_values })
}

/**
* @typedef {Object} SamModelInputs Object containing the model inputs.
* @property {Tensor} pixel_values Pixel values as a Tensor with shape `(batch_size, num_channels, height, width)`.
* These can be obtained using a `SamProcessor`.
* @property {Tensor} input_points Input 2D spatial points with shape `(batch_size, num_points, 2)`.
* This is used by the prompt encoder to encode the prompt.
* @property {Tensor} [input_labels] Input labels for the points, as a Tensor of shape `(batch_size, point_batch_size, num_points)`.
* This is used by the prompt encoder to encode the prompt. There are 4 types of labels:
* - `1`: the point is a point that contains the object of interest
* - `0`: the point is a point that does not contain the object of interest
* - `-1`: the point corresponds to the background
* - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
* @property {Tensor} [image_embeddings] Image embeddings used by the mask decoder.
* @property {Tensor} [image_positional_embeddings] Image positional embeddings used by the mask decoder.
*/

/**
* @param {SamModelInputs} model_inputs Object containing the model inputs.
* @returns {Promise<Object>} The output of the model.
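*
* **Example:** Pass explicit `input_labels` instead of relying on the all-positive default
* (a minimal sketch, continuing from the `SamModel` example above; it assumes `Tensor` is also
* imported from '@xenova/transformers', and that the label tensor follows the
* `(batch_size, point_batch_size, num_points)` shape described in `SamModelInputs`).
* ```javascript
* import { Tensor } from '@xenova/transformers';
*
* const inputs = await processor(raw_image, [[[450, 600]]]);
* // Label `1` marks the single prompt point as part of the object of interest
* const input_labels = new Tensor('int64', new BigInt64Array([1n]), [1, 1, 1]);
* const outputs = await model({ ...inputs, input_labels });
* ```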
*/
async forward(model_inputs) {
if (!model_inputs.image_embeddings || !model_inputs.image_positional_embeddings) {
// Compute the image embeddings if they are missing
model_inputs = {
...model_inputs,
...(await this.get_image_embeddings(model_inputs))
}
}

if (!model_inputs.input_labels) {
// Set default input labels if they are missing
const shape = model_inputs.input_points.dims.slice(0, -1);
const numElements = shape.reduce((a, b) => a * b, 1);
model_inputs.input_labels = new Tensor(
'int64',
new BigInt64Array(numElements).fill(1n),
shape
);
}

// Returns:
// - iou_scores: tensor.float32[batch_size,point_batch_size,3]
// - pred_masks: tensor.float32[batch_size,point_batch_size,3,256,256]
return await sessionRun(this.prompt_encoder_mask_decoder, {
input_points: model_inputs.input_points,
input_labels: model_inputs.input_labels,
image_embeddings: model_inputs.image_embeddings,
image_positional_embeddings: model_inputs.image_positional_embeddings,
});
}

/**
* Runs the model with the provided inputs
* @param {Object} model_inputs Model inputs
* @returns {Promise<SamImageSegmentationOutput>} Object containing segmentation outputs
*/
async _call(model_inputs) {
return new SamImageSegmentationOutput(await super._call(model_inputs));
@@ -5049,7 +5175,6 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([

['hifigan', ['SpeechT5HifiGan', SpeechT5HifiGan]],

['sam', ['SamModel', SamModel]], // TODO change to encoder-decoder when model is split correctly
]);

const MODEL_MAPPING_NAMES_ENCODER_DECODER = new Map([
@@ -5290,7 +5415,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.MaskGeneration],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
@@ -5329,7 +5454,9 @@ for (const [name, model, type] of CUSTOM_MAPPING) {
* let model = await AutoModel.from_pretrained('bert-base-uncased');
*/
export class AutoModel extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_MAPPING_NAMES_ENCODER_ONLY, MODEL_MAPPING_NAMES_ENCODER_DECODER, MODEL_MAPPING_NAMES_DECODER_ONLY];
/** @type {Map<string, Object>[]} */
// @ts-ignore
static MODEL_CLASS_MAPPINGS = MODEL_CLASS_TYPE_MAPPING.map(x => x[0]);
static BASE_IF_FAIL = true;
}

@@ -5493,7 +5620,7 @@ export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {


/**
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
* Helper class which is used to instantiate pretrained mask generation models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
*
* @example
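For reference, a minimal sketch of using the corrected mask-generation helper class (assuming it is exported as `AutoModelForMaskGeneration`, and reusing the converted `Xenova/sam-vit-base` checkpoint from the `SamModel` example):

```javascript
import { AutoModelForMaskGeneration, AutoProcessor } from '@xenova/transformers';

// Expected to resolve to `SamModel` through the mask-generation mapping
const model = await AutoModelForMaskGeneration.from_pretrained('Xenova/sam-vit-base');
const processor = await AutoProcessor.from_pretrained('Xenova/sam-vit-base');
```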