From 8a2c89f7d9b4930c1147ccd148ae9d661156a146 Mon Sep 17 00:00:00 2001
From: Joshua Lochner <26504141+xenova@users.noreply.github.com>
Date: Thu, 13 Nov 2025 17:03:52 -0500
Subject: [PATCH 1/4] Add support for EdgeTAM

---
 src/models.js                            | 105 ++++++++++++++++++++---
 src/models/image_processors.js           |   1 +
 src/models/processors.js                 |   1 +
 src/models/sam/image_processing_sam.js   |  12 +--
 src/models/sam2/image_processing_sam2.js |   2 +
 src/models/sam2/processing_sam2.js       |   3 +
 6 files changed, 108 insertions(+), 16 deletions(-)
 create mode 100644 src/models/sam2/image_processing_sam2.js
 create mode 100644 src/models/sam2/processing_sam2.js

diff --git a/src/models.js b/src/models.js
index 2f7fd569c..f3ccafa9d 100644
--- a/src/models.js
+++ b/src/models.js
@@ -6014,16 +6014,8 @@ export class SamModel extends SamPreTrainedModel {
             }
         }
 
-        if (!model_inputs.input_labels && model_inputs.input_points) {
-            // Set default input labels if they are missing
-            const shape = model_inputs.input_points.dims.slice(0, -1);
-            const numElements = shape.reduce((a, b) => a * b, 1);
-            model_inputs.input_labels = new Tensor(
-                'int64',
-                new BigInt64Array(numElements).fill(1n),
-                shape
-            );
-        }
+        // Set default input labels if they are missing
+        model_inputs.input_labels ??= ones(model_inputs.input_points.dims.slice(0, -1));
 
         const decoder_inputs = {
             image_embeddings: model_inputs.image_embeddings,
@@ -6073,6 +6065,98 @@ export class SamImageSegmentationOutput extends ModelOutput {
 }
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class Sam2ImageSegmentationOutput extends ModelOutput {
+    /**
+     * @param {Object} output The output of the model.
+     * @param {Tensor} output.iou_scores Predicted IoU scores for each mask.
+     * @param {Tensor} output.pred_masks Predicted masks.
+     * @param {Tensor} output.object_score_logits Logits for the object score, indicating if an object is present.
+     */
+    constructor({ iou_scores, pred_masks, object_score_logits }) {
+        super();
+        this.iou_scores = iou_scores;
+        this.pred_masks = pred_masks;
+        this.object_score_logits = object_score_logits;
+    }
+}
+
+export class EdgeTamPreTrainedModel extends PreTrainedModel { }
+
+/**
+ * EdgeTAM for generating segmentation masks, given an input image
+ * and optional 2D locations and bounding boxes.
+ */
+export class EdgeTamModel extends EdgeTamPreTrainedModel {
+
+    /**
+     * Compute image embeddings and positional image embeddings, given the pixel values of an image.
+     * @param {Object} model_inputs Object containing the model inputs.
+     * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `Sam2Processor`.
+     * @returns {Promise<Record<string, Tensor>>} The image embeddings.
+     */
+    async get_image_embeddings({ pixel_values }) {
+        // in:
+        //  - pixel_values: tensor.float32[batch_size,3,1024,1024]
+        //
+        // out:
+        //  - image_embeddings.0: tensor.float32[batch_size,32,256,256]
+        //  - image_embeddings.1: tensor.float32[batch_size,64,128,128]
+        //  - image_embeddings.2: tensor.float32[batch_size,256,64,64]
+        return await encoderForward(this, { pixel_values });
+    }
+
+    async forward(model_inputs) {
+        // @ts-expect-error ts(2339)
+        const { num_feature_levels } = this.config.vision_config;
+        const image_embeddings_name = Array.from({ length: num_feature_levels }, (_, i) => `image_embeddings.${i}`);
+
+        if (image_embeddings_name.some(name => !model_inputs[name]) || !model_inputs.image_positional_embeddings) {
+            // Compute the image embeddings if they are missing
+            model_inputs = {
+                ...model_inputs,
+                ...(await this.get_image_embeddings(model_inputs))
+            }
+        }
+
+        if (model_inputs.input_points) {
+            if (model_inputs.input_boxes && model_inputs.input_boxes.dims[1] !== 1) {
+                throw new Error('When both `input_points` and `input_boxes` are provided, the number of boxes per image must be 1.');
+            }
+            const shape = model_inputs.input_points.dims;
+            model_inputs.input_labels ??= ones(shape.slice(0, -1));
+            model_inputs.input_boxes ??= full([shape[0], 0, 4], 0.0);
+
+        } else if (model_inputs.input_boxes) { // only boxes
+            const shape = model_inputs.input_boxes.dims;
+            model_inputs.input_labels = full([shape[0], shape[1], 0], -1n);
+            model_inputs.input_points = full([shape[0], 1, 0, 2], 0.0);
+
+        } else {
+            throw new Error('At least one of `input_points` or `input_boxes` must be provided.');
+        }
+
+        const prompt_encoder_mask_decoder_session = this.sessions['prompt_encoder_mask_decoder'];
+        const decoder_inputs = pick(model_inputs, prompt_encoder_mask_decoder_session.inputNames);
+
+        // Returns:
+        //  - iou_scores: tensor.float32[batch_size,num_boxes_or_points,3]
+        //  - pred_masks: tensor.float32[batch_size,num_boxes_or_points,3,256,256]
+        //  - object_score_logits: tensor.float32[batch_size,num_boxes_or_points,1]
+        return await sessionRun(prompt_encoder_mask_decoder_session, decoder_inputs);
+    }
+
+    /**
+     * Runs the model with the provided inputs.
+     * @param {Object} model_inputs Model inputs
+     * @returns {Promise<Sam2ImageSegmentationOutput>} Object containing segmentation outputs
+     */
+    async _call(model_inputs) {
+        return new Sam2ImageSegmentationOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
 
 //////////////////////////////////////////////////
 // MarianMT models
@@ -8154,6 +8238,7 @@ const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([
 
 const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
     ['sam', ['SamModel', SamModel]],
+    ['edgetam', ['EdgeTamModel', EdgeTamModel]],
 ]);
 
 const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
diff --git a/src/models/image_processors.js b/src/models/image_processors.js
index 57c4b158a..3c9c494f4 100644
--- a/src/models/image_processors.js
+++ b/src/models/image_processors.js
@@ -31,6 +31,7 @@ export * from './pvt/image_processing_pvt.js'
 export * from './qwen2_vl/image_processing_qwen2_vl.js'
 export * from './rt_detr/image_processing_rt_detr.js'
 export * from './sam/image_processing_sam.js'
+export * from './sam2/image_processing_sam2.js';
 export * from './segformer/image_processing_segformer.js'
 export * from './siglip/image_processing_siglip.js'
 export * from './smolvlm/image_processing_smolvlm.js'
diff --git a/src/models/processors.js b/src/models/processors.js
index 32969a655..c45f9de36 100644
--- a/src/models/processors.js
+++ b/src/models/processors.js
@@ -13,6 +13,7 @@ export * from './paligemma/processing_paligemma.js';
 export * from './pyannote/processing_pyannote.js';
 export * from './qwen2_vl/processing_qwen2_vl.js';
 export * from './sam/processing_sam.js';
+export * from './sam2/processing_sam2.js';
 export * from './smolvlm/processing_smolvlm.js';
 export * from './speecht5/processing_speecht5.js';
 export * from './ultravox/processing_ultravox.js';
diff --git a/src/models/sam/image_processing_sam.js b/src/models/sam/image_processing_sam.js
index bd71e1f43..2a564937f 100644
--- a/src/models/sam/image_processing_sam.js
+++ b/src/models/sam/image_processing_sam.js
@@ -47,12 +47,12 @@ export class SamImageProcessor extends ImageProcessor {
 
         // Reshape input points
         for (let i = 0; i < input_points.length; ++i) { // batch_size
-            let originalImageSize = original_sizes[i];
-            let reshapedImageSize = reshaped_input_sizes[i];
+            const [originalHeight, originalWidth] = original_sizes[i];
+            const [reshapedHeight, reshapedWidth] = reshaped_input_sizes[i];
 
-            let resizeFactors = [
-                reshapedImageSize[0] / originalImageSize[0],
-                reshapedImageSize[1] / originalImageSize[1]
+            const resizeFactors = [
+                reshapedWidth / originalWidth,
+                reshapedHeight / originalHeight,
             ]
 
             for (let j = 0; j < input_points[i].length; ++j) { // point_batch_size
@@ -170,7 +170,7 @@ export class SamImageProcessor extends ImageProcessor {
 
         const output_masks = [];
 
-        pad_size = pad_size ?? this.pad_size;
+        pad_size = pad_size ?? this.pad_size ?? this.size;
 
         /** @type {[number, number]} */
         const target_image_size = [pad_size.height, pad_size.width];
diff --git a/src/models/sam2/image_processing_sam2.js b/src/models/sam2/image_processing_sam2.js
new file mode 100644
index 000000000..514ef91b7
--- /dev/null
+++ b/src/models/sam2/image_processing_sam2.js
@@ -0,0 +1,2 @@
+
+export { SamImageProcessor as Sam2ImageProcessor } from '../sam/image_processing_sam.js';
diff --git a/src/models/sam2/processing_sam2.js b/src/models/sam2/processing_sam2.js
new file mode 100644
index 000000000..82bb488c9
--- /dev/null
+++ b/src/models/sam2/processing_sam2.js
@@ -0,0 +1,3 @@
+import { SamProcessor } from "../sam/processing_sam.js";
+
+export class Sam2VideoProcessor extends SamProcessor { }
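
Taken together, this first patch wires EdgeTAM into the existing SAM plumbing: `Sam2ImageProcessor` is a re-export of `SamImageProcessor`, and `EdgeTamModel` pairs a vision-encoder session with a combined `prompt_encoder_mask_decoder` session. Below is a minimal usage sketch, not part of the patch: the checkpoint id and image URL are placeholders, while `from_pretrained`, `AutoProcessor`, `RawImage`, and `post_process_masks` mirror the library's existing SAM example.

```js
import { EdgeTamModel, AutoProcessor, RawImage } from '@huggingface/transformers';

// Placeholder checkpoint id -- substitute a real EdgeTAM ONNX export.
const model_id = 'onnx-community/EdgeTAM';
const model = await EdgeTamModel.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);

// Segment the object at pixel (340, 250) of the input image (placeholder URL).
const image = await RawImage.read('https://example.com/image.jpg');
const input_points = [[[340, 250]]];
const inputs = await processor(image, { input_points });

const { pred_masks, iou_scores } = await model(inputs);

// Upscale the low-resolution (256x256) masks back to the original image size.
const masks = await processor.post_process_masks(
    pred_masks,
    inputs.original_sizes,
    inputs.reshaped_input_sizes,
);
```
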
From e0117bd7565d22bebb83c953cd4a07547da1fe53 Mon Sep 17 00:00:00 2001
From: Joshua Lochner <26504141+xenova@users.noreply.github.com>
Date: Thu, 13 Nov 2025 17:28:24 -0500
Subject: [PATCH 2/4] Update test precision

---
 tests/models/sam/test_modeling_sam.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/models/sam/test_modeling_sam.js b/tests/models/sam/test_modeling_sam.js
index 6815eaad8..af30988f4 100644
--- a/tests/models/sam/test_modeling_sam.js
+++ b/tests/models/sam/test_modeling_sam.js
@@ -28,7 +28,7 @@ export default () => {
       const { pred_masks, iou_scores } = await model(inputs);
 
       expect(pred_masks.dims).toEqual([1, 1, 3, 256, 256]);
-      expect(pred_masks.mean().item()).toBeCloseTo(-5.769824981689453, 3);
+      expect(pred_masks.mean().item()).toBeCloseTo(-5.764908313751221, 3);
 
       expect(iou_scores.dims).toEqual([1, 1, 3]);
       expect(iou_scores.tolist()).toBeCloseToNested([[[0.8583833575248718, 0.9773167967796326, 0.8511142730712891]]]);
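
A note on the tolerance here: Jest's `toBeCloseTo(expected, digits)` passes when `|received - expected| < 0.5 * 10^-digits`, and the old and new reference means differ by about 4.9e-3, which exceeds the 5e-4 budget at three digits; hence the reference value had to be updated rather than the precision loosened. An illustrative sketch using the two values from the diff:

```js
// toBeCloseTo(expected, 3) passes when |received - expected| < 0.5e-3.
expect(-5.764908313751221).toBeCloseTo(-5.764908313751221, 3); // passes
expect(-5.769824981689453).toBeCloseTo(-5.764908313751221, 3); // fails: diff ~4.9e-3
```
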
From 1dd682bf86339c30c8c040d34b23785aba489da6 Mon Sep 17 00:00:00 2001
From: Joshua Lochner <26504141+xenova@users.noreply.github.com>
Date: Thu, 13 Nov 2025 18:43:53 -0500
Subject: [PATCH 3/4] EdgeTAM no longer requires image_positional_embeddings

---
 src/models.js | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/models.js b/src/models.js
index f3ccafa9d..17005def6 100644
--- a/src/models.js
+++ b/src/models.js
@@ -6111,7 +6111,7 @@ export class EdgeTamModel extends EdgeTamPreTrainedModel {
         const { num_feature_levels } = this.config.vision_config;
         const image_embeddings_name = Array.from({ length: num_feature_levels }, (_, i) => `image_embeddings.${i}`);
 
-        if (image_embeddings_name.some(name => !model_inputs[name]) || !model_inputs.image_positional_embeddings) {
+        if (image_embeddings_name.some(name => !model_inputs[name])) {
             // Compute the image embeddings if they are missing
             model_inputs = {
                 ...model_inputs,
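
With `image_positional_embeddings` dropped from the check, the decoder's only image-side inputs are the pyramid embeddings returned by `get_image_embeddings`, so a caller can encode an image once and prompt it repeatedly. The sketch below reuses the placeholder `model`, `processor`, and `image` from the earlier example and is an assumption-level illustration, not part of the patch:

```js
// Run the vision encoder once per image (the expensive step).
const image_inputs = await processor(image);
const image_embeddings = await model.get_image_embeddings(image_inputs);
// => { 'image_embeddings.0': Tensor, 'image_embeddings.1': Tensor, 'image_embeddings.2': Tensor }

// Then decode any number of prompts against the cached embeddings;
// only the prompt_encoder_mask_decoder session runs here.
const { input_points } = await processor(image, { input_points: [[[340, 250]]] });
const { pred_masks, iou_scores } = await model({ ...image_embeddings, input_points });
```
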
From f6ec1449473d1b9e822cee82941ec24dc6b50b28 Mon Sep 17 00:00:00 2001
From: Joshua Lochner <26504141+xenova@users.noreply.github.com>
Date: Thu, 13 Nov 2025 19:44:40 -0500
Subject: [PATCH 4/4] Make shallow copy of model_inputs before in-place changes

---
 src/models.js | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/models.js b/src/models.js
index 17005def6..7d31e2ed3 100644
--- a/src/models.js
+++ b/src/models.js
@@ -6012,6 +6012,8 @@ export class SamModel extends SamPreTrainedModel {
                 ...model_inputs,
                 ...(await this.get_image_embeddings(model_inputs))
             }
+        } else {
+            model_inputs = { ...model_inputs };
         }
 
         // Set default input labels if they are missing
@@ -6117,6 +6119,8 @@ export class EdgeTamModel extends EdgeTamPreTrainedModel {
                 ...model_inputs,
                 ...(await this.get_image_embeddings(model_inputs))
             }
+        } else {
+            model_inputs = { ...model_inputs };
         }
 
         if (model_inputs.input_points) {
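
For context on this last patch: `forward()` fills in prompt defaults by assigning onto `model_inputs`, so without a copy those writes mutated the object the caller passed in. An illustrative sketch (names reused from the earlier examples):

```js
const inputs = { ...image_embeddings, input_points };
await model(inputs);

// Before this patch: forward() mutated the object passed in, so `inputs`
// silently gained keys such as `input_labels` (and, for EdgeTAM, an empty
// `input_boxes` tensor) after the call.

// After this patch: forward() first rebinds `model_inputs = { ...model_inputs }`,
// so all default-filling happens on a private shallow copy and the caller's
// `inputs` object is left exactly as constructed.
```
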