Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/source/pipelines.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,33 @@ let result = await transcriber('https://huggingface.co/datasets/Narsil/asr_dummy

## Pipeline options

### Loading

We offer a variety of options to control how models are loaded from the Hugging Face Hub (or locally).
By default, the *quantized* version of the model is used, which is smaller and faster, but usually less accurate.
To override this behaviour (i.e., use the unquantized model), you can use a custom `PretrainedOptions` object
as the third parameter to the `pipeline` function:

```javascript
// Allocate a pipeline for feature extraction, using the unquantized model
const pipe = await pipeline('feature-extraction', 'Xenova/all-MiniLM-L6-v2', {
quantized: false,
});
```

You can also specify which revision of the model to use by passing a `revision` parameter.
Since the Hugging Face Hub uses a git-based versioning system, you can use any valid git revision specifier (e.g., branch name or commit hash).

```javascript
let transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-tiny.en', {
revision: 'output_attentions',
});
```

For the full list of options, check out the [PretrainedOptions](/api/utils/hub#module_utils/hub..PretrainedOptions) documentation.


### Running
Many pipelines have additional options that you can specify. For example, when using a model that does multilingual translation, you can specify the source and target languages like this:

<!-- TODO: Replace 'Xenova/nllb-200-distilled-600M' with 'facebook/nllb-200-distilled-600M' -->
Expand Down
54 changes: 50 additions & 4 deletions src/pipelines.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
* ```javascript
* import { pipeline } from '@xenova/transformers';
*
* let pipeline = await pipeline('sentiment-analysis');
* let result = await pipeline('I love transformers!');
* let classifier = await pipeline('sentiment-analysis');
* let result = await classifier('I love transformers!');
* // [{'label': 'POSITIVE', 'score': 0.999817686}]
* ```
*
Expand Down Expand Up @@ -1312,6 +1312,26 @@ export class ZeroShotImageClassificationPipeline extends Pipeline {
/**
* Object detection pipeline using any `AutoModelForObjectDetection`.
* This pipeline predicts bounding boxes of objects and their classes.
*
* **Example:** Run object-detection with `facebook/detr-resnet-50`.
* ```javascript
* let img = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg';
*
* let detector = await pipeline('object-detection', 'Xenova/detr-resnet-50');
* let output = await detector(img, { threshold: 0.9 });
* // [{
* // "score": 0.9976370930671692,
* // "label": "remote",
* // "box": { "xmin": 31, "ymin": 68, "xmax": 190, "ymax": 118 }
* // },
* // ...
* // {
* // "score": 0.9984092116355896,
* // "label": "cat",
* // "box": { "xmin": 331, "ymin": 19, "xmax": 649, "ymax": 371 }
* // }]
* ```
*
* @extends Pipeline
*/
export class ObjectDetectionPipeline extends Pipeline {
Expand Down Expand Up @@ -1354,9 +1374,35 @@ export class ObjectDetectionPipeline extends Pipeline {

// Add labels
let id2label = this.model.config.id2label;
processed.forEach(x => x.labels = x.classes.map(y => id2label[y]));

return isBatched ? processed : processed[0];
// Format output
const result = processed.map(batch => {
return batch.boxes.map((box, i) => {
return {
score: batch.scores[i],
label: id2label[batch.classes[i]],
box: this._get_bounding_box(box, !percentage),
}
})
})

return isBatched ? result : result[0];
}

/**
 * Helper function to convert a bounding box given as the list [xmin, ymin, xmax, ymax]
 * into an object of the form { "xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax }.
 * @param {number[]} box The bounding box as a list in [xmin, ymin, xmax, ymax] order.
 * @param {boolean} asInteger Whether to truncate each coordinate to an integer
 *  (uses `| 0`, i.e. 32-bit truncation toward zero).
 * @returns {Object} The bounding box as an object with `xmin`, `ymin`, `xmax`, `ymax` keys.
 * @private
 */
_get_bounding_box(box, asInteger) {
    if (asInteger) {
        // `x | 0` truncates toward zero (fast integer cast for pixel coordinates)
        box = box.map(x => x | 0);
    }
    const [xmin, ymin, xmax, ymax] = box;

    return { xmin, ymin, xmax, ymax };
}
}

Expand Down
124 changes: 79 additions & 45 deletions tests/pipelines.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -1116,67 +1116,101 @@ describe('Pipelines', () => {
let detector = await pipeline('object-detection', m(models[0]));

// TODO add batched test cases when supported
let url = 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/savanna.jpg';
let urls = ['https://huggingface.co/datasets/mishig/sample_images/resolve/main/airport.jpg']
let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg';
let urls = ['https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/savanna.jpg']

// single + threshold
{
let output = await detector(url, {
threshold: 0.9,
});

// let expected = {
// "boxes": [
// [352.8210112452507, 247.36732184886932, 390.5271676182747, 318.09066116809845],
// [111.15852802991867, 235.34255504608154, 224.96717244386673, 325.21119117736816],
// [13.524770736694336, 146.81672930717468, 207.97560095787048, 278.6452639102936],
// [187.396682202816, 227.97491312026978, 313.05202156305313, 300.26460886001587],
// [201.60082161426544, 230.86223602294922, 312.1393972635269, 306.5505266189575],
// [365.85242718458176, 95.3144109249115, 526.5485098958015, 313.17670941352844]
// ],
// "classes": [24, 24, 25, 24, 24, 25],
// "scores": [0.9989480376243591, 0.9990893006324768, 0.9690554738044739, 0.9274907112121582, 0.9714975953102112, 0.9989491105079651],
// "labels": ["zebra", "zebra", "giraffe", "zebra", "zebra", "giraffe"]
// };

let num_classes = output.boxes.length;
expect(num_classes).toBeGreaterThan(1);
expect(output.classes.length).toEqual(num_classes);
expect(output.scores.length).toEqual(num_classes);
expect(output.labels.length).toEqual(num_classes);
// let expected = [
// {
// "score": 0.9977124929428101,
// "label": "remote",
// "box": { "xmin": 41, "ymin": 70, "xmax": 176, "ymax": 118 }
// },
// {
// "score": 0.9984639883041382,
// "label": "remote",
// "box": { "xmin": 332, "ymin": 73, "xmax": 369, "ymax": 188 }
// },
// {
// "score": 0.9964856505393982,
// "label": "couch",
// "box": { "xmin": 0, "ymin": 1, "xmax": 639, "ymax": 474 }
// },
// {
// "score": 0.9988334774971008,
// "label": "cat",
// "box": { "xmin": 11, "ymin": 51, "xmax": 314, "ymax": 472 }
// },
// {
// "score": 0.9982513785362244,
// "label": "cat",
// "box": { "xmin": 345, "ymin": 22, "xmax": 640, "ymax": 371 }
// }
// ]

expect(output.length).toBeGreaterThan(0);
for (let cls of output) {
expect(typeof cls.score).toBe('number');
expect(typeof cls.label).toBe('string');
for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
expect(typeof cls.box[key]).toBe('number');
}
}
}

// single + threshold + percentage
// batched + threshold + percentage
{
let output = await detector(urls, {
threshold: 0.9,
percentage: true
});
// let expected = [[
// {
// score: 0.9991137385368347,
// label: 'zebra',
// box: { xmin: 0.65165576338768, ymin: 0.685152679681778, xmax: 0.723189502954483, ymax: 0.8801506459712982 }
// },
// {
// score: 0.998811662197113,
// label: 'zebra',
// box: { xmin: 0.20797613263130188, ymin: 0.6543092578649521, xmax: 0.4147692620754242, ymax: 0.9040975719690323 }
// },
// {
// score: 0.9707837104797363,
// label: 'giraffe',
// box: { xmin: 0.02498096227645874, ymin: 0.40549489855766296, xmax: 0.38669759035110474, ymax: 0.7895723879337311 }
// },
// {
// score: 0.9984336495399475,
// label: 'zebra',
// box: { xmin: 0.3540637195110321, ymin: 0.6370827257633209, xmax: 0.5765090882778168, ymax: 0.8480959832668304 }
// },
// {
// score: 0.9986463785171509,
// label: 'giraffe',
// box: { xmin: 0.6763969212770462, ymin: 0.25748637318611145, xmax: 0.974339172244072, ymax: 0.8684568107128143 }
// }
// ]]

expect(output).toHaveLength(urls.length); // Same number of inputs as outputs

for (let i = 0; i < output.length; ++i) {
expect(output[i].length).toBeGreaterThan(0);
for (let cls of output[i]) {
expect(typeof cls.score).toBe('number');
expect(typeof cls.label).toBe('string');
for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
expect(typeof cls.box[key]).toBe('number');
}
}
}


// let expected = [{
// "boxes": [
// [0.7231650948524475, 0.32641804218292236, 0.981127917766571, 0.9918863773345947],
// [0.7529061436653137, 0.52558633685112, 0.8229959607124329, 0.6482008993625641],
// [0.5080368518829346, 0.5156279355287552, 0.5494132041931152, 0.5434067696332932],
// [0.33636586368083954, 0.5217841267585754, 0.3535611182451248, 0.6151944994926453],
// [0.42090220749378204, 0.4482414871454239, 0.5515891760587692, 0.5207531303167343],
// [0.1988394856452942, 0.41224047541618347, 0.45213085412979126, 0.5206181704998016],
// [0.5063001662492752, 0.5170856416225433, 0.5478668659925461, 0.54373899102211],
// [0.5734506398439407, 0.4508090913295746, 0.7049560993909836, 0.6252130568027496],
// ],
// "classes": [6, 1, 8, 1, 5, 5, 3, 6],
// "scores": [0.9970788359642029, 0.996989905834198, 0.9505048990249634, 0.9984546899795532, 0.9942372441291809, 0.9989550709724426, 0.938920259475708, 0.9992448091506958],
// "labels": ["bus", "person", "truck", "person", "airplane", "airplane", "car", "bus"]
// }];

expect(output).toHaveLength(urls.length);

let num_classes = output[0].boxes.length;
expect(num_classes).toBeGreaterThan(1);
expect(output[0].classes.length).toEqual(num_classes);
expect(output[0].scores.length).toEqual(num_classes);
expect(output[0].labels.length).toEqual(num_classes);
}

await detector.dispose();
Expand Down