Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 99 additions & 10 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -6012,18 +6012,12 @@ export class SamModel extends SamPreTrainedModel {
...model_inputs,
...(await this.get_image_embeddings(model_inputs))
}
} else {
model_inputs = { ...model_inputs };
}

if (!model_inputs.input_labels && model_inputs.input_points) {
// Set default input labels if they are missing
const shape = model_inputs.input_points.dims.slice(0, -1);
const numElements = shape.reduce((a, b) => a * b, 1);
model_inputs.input_labels = new Tensor(
'int64',
new BigInt64Array(numElements).fill(1n),
shape
);
}
// Set default input labels if they are missing
model_inputs.input_labels ??= ones(model_inputs.input_points.dims.slice(0, -1));

const decoder_inputs = {
image_embeddings: model_inputs.image_embeddings,
Expand Down Expand Up @@ -6073,6 +6067,100 @@ export class SamImageSegmentationOutput extends ModelOutput {
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class Sam2ImageSegmentationOutput extends ModelOutput {
    /**
     * Output type for SAM2/EdgeTAM mask-generation models.
     * @param {Object} output The output of the model.
     * @param {Tensor} output.iou_scores Predicted IoU scores for each candidate mask.
     * @param {Tensor} output.pred_masks Predicted segmentation masks.
     * @param {Tensor} output.object_score_logits Logits for the object score, indicating if an object is present.
     */
    constructor({ iou_scores, pred_masks, object_score_logits }) {
        super();
        this.iou_scores = iou_scores;
        this.pred_masks = pred_masks;
        this.object_score_logits = object_score_logits;
    }
}

export class EdgeTamPreTrainedModel extends PreTrainedModel { }

/**
 * EdgeTAM for generating segmentation masks, given an input image
 * and optional 2D location and bounding boxes.
 */
export class EdgeTamModel extends EdgeTamPreTrainedModel {

    /**
     * Compute image embeddings and positional image embeddings, given the pixel values of an image.
     * @param {Object} model_inputs Object containing the model inputs.
     * @param {Tensor} model_inputs.pixel_values Pixel values obtained using a `Sam2Processor`.
     * @returns {Promise<Record<String, Tensor>>} The image embeddings.
     */
    async get_image_embeddings({ pixel_values }) {
        // Input:  pixel_values: float32[batch_size, 3, 1024, 1024]
        // Output: one embedding tensor per feature level, e.g.
        //   image_embeddings.0: float32[batch_size, 32, 256, 256]
        //   image_embeddings.1: float32[batch_size, 64, 128, 128]
        //   image_embeddings.2: float32[batch_size, 256, 64, 64]
        return await encoderForward(this, { pixel_values });
    }

    async forward(model_inputs) {
        // The vision encoder emits one embedding tensor per feature level.
        // @ts-expect-error ts(2339)
        const { num_feature_levels } = this.config.vision_config;
        const embeddingNames = Array.from(
            { length: num_feature_levels },
            (_, i) => `image_embeddings.${i}`,
        );

        // Work on a shallow copy so the caller's object is never mutated.
        const inputs = { ...model_inputs };

        // Run the image encoder only when at least one embedding level is missing.
        if (!embeddingNames.every((name) => inputs[name])) {
            Object.assign(inputs, await this.get_image_embeddings(inputs));
        }

        if (inputs.input_points) {
            if (inputs.input_boxes && inputs.input_boxes.dims[1] !== 1) {
                throw new Error('When both `input_points` and `input_boxes` are provided, the number of boxes per image must be 1.');
            }
            const pointDims = inputs.input_points.dims;
            // Default: a label of 1 per point, and an empty (zero-box) boxes tensor.
            inputs.input_labels ??= ones(pointDims.slice(0, -1));
            inputs.input_boxes ??= full([pointDims[0], 0, 4], 0.0);
        } else if (inputs.input_boxes) {
            // Boxes only: supply empty point/label tensors of the rank the decoder expects.
            const boxDims = inputs.input_boxes.dims;
            inputs.input_labels = full([boxDims[0], boxDims[1], 0], -1n);
            inputs.input_points = full([boxDims[0], 1, 0, 2], 0.0);
        } else {
            throw new Error('At least one of `input_points` or `input_boxes` must be provided.');
        }

        // Feed the decoder session only the inputs it declares.
        const decoderSession = this.sessions['prompt_encoder_mask_decoder'];
        const decoderInputs = pick(inputs, decoderSession.inputNames);

        // Returns:
        // - iou_scores: float32[batch_size, num_boxes_or_points, 3]
        // - pred_masks: float32[batch_size, num_boxes_or_points, 3, 256, 256]
        // - object_score_logits: float32[batch_size, num_boxes_or_points, 1]
        return await sessionRun(decoderSession, decoderInputs);
    }

    /**
     * Runs the model with the provided inputs
     * @param {Object} model_inputs Model inputs
     * @returns {Promise<Sam2ImageSegmentationOutput>} Object containing segmentation outputs
     */
    async _call(model_inputs) {
        return new Sam2ImageSegmentationOutput(await super._call(model_inputs));
    }
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// MarianMT models
Expand Down Expand Up @@ -8154,6 +8242,7 @@ const MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING_NAMES = new Map([

// Maps model-type identifiers to [class name, class] pairs for mask-generation models.
const MODEL_FOR_MASK_GENERATION_MAPPING_NAMES = new Map([
    ['sam', ['SamModel', SamModel]],
    ['edgetam', ['EdgeTamModel', EdgeTamModel]],
]);

const MODEL_FOR_CTC_MAPPING_NAMES = new Map([
Expand Down
1 change: 1 addition & 0 deletions src/models/image_processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export * from './pvt/image_processing_pvt.js'
export * from './qwen2_vl/image_processing_qwen2_vl.js'
export * from './rt_detr/image_processing_rt_detr.js'
export * from './sam/image_processing_sam.js'
export * from './sam2/image_processing_sam2.js';
export * from './segformer/image_processing_segformer.js'
export * from './siglip/image_processing_siglip.js'
export * from './smolvlm/image_processing_smolvlm.js'
Expand Down
1 change: 1 addition & 0 deletions src/models/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export * from './paligemma/processing_paligemma.js';
export * from './pyannote/processing_pyannote.js';
export * from './qwen2_vl/processing_qwen2_vl.js';
export * from './sam/processing_sam.js';
export * from './sam2/processing_sam2.js';
export * from './smolvlm/processing_smolvlm.js';
export * from './speecht5/processing_speecht5.js';
export * from './ultravox/processing_ultravox.js';
Expand Down
12 changes: 6 additions & 6 deletions src/models/sam/image_processing_sam.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ export class SamImageProcessor extends ImageProcessor {

// Reshape input points
for (let i = 0; i < input_points.length; ++i) { // batch_size
let originalImageSize = original_sizes[i];
let reshapedImageSize = reshaped_input_sizes[i];
const [originalHeight, originalWidth] = original_sizes[i];
const [reshapedHeight, reshapedWidth] = reshaped_input_sizes[i];

let resizeFactors = [
reshapedImageSize[0] / originalImageSize[0],
reshapedImageSize[1] / originalImageSize[1]
const resizeFactors = [
reshapedWidth / originalWidth,
reshapedHeight / originalHeight,
]

for (let j = 0; j < input_points[i].length; ++j) { // point_batch_size
Expand Down Expand Up @@ -170,7 +170,7 @@ export class SamImageProcessor extends ImageProcessor {

const output_masks = [];

pad_size = pad_size ?? this.pad_size;
pad_size = pad_size ?? this.pad_size ?? this.size;

/** @type {[number, number]} */
const target_image_size = [pad_size.height, pad_size.width];
Expand Down
2 changes: 2 additions & 0 deletions src/models/sam2/image_processing_sam2.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@

export { SamImageProcessor as Sam2ImageProcessor } from '../sam/image_processing_sam.js';
3 changes: 3 additions & 0 deletions src/models/sam2/processing_sam2.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import { SamProcessor } from "../sam/processing_sam.js";

// SAM2 processing inherits SAM's processor behavior unchanged.
// NOTE(review): the class is named `Sam2VideoProcessor` while the file is
// `processing_sam2.js` — confirm the intended exported name (vs `Sam2Processor`).
export class Sam2VideoProcessor extends SamProcessor { }
2 changes: 1 addition & 1 deletion tests/models/sam/test_modeling_sam.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export default () => {
const { pred_masks, iou_scores } = await model(inputs);

expect(pred_masks.dims).toEqual([1, 1, 3, 256, 256]);
expect(pred_masks.mean().item()).toBeCloseTo(-5.769824981689453, 3);
expect(pred_masks.mean().item()).toBeCloseTo(-5.764908313751221, 3);
expect(iou_scores.dims).toEqual([1, 1, 3]);
expect(iou_scores.tolist()).toBeCloseToNested([[[0.8583833575248718, 0.9773167967796326, 0.8511142730712891]]]);

Expand Down