diff --git a/src/models.js b/src/models.js index 1a9b021f1..9c65ce5d8 100644 --- a/src/models.js +++ b/src/models.js @@ -3779,11 +3779,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { } * import { Tensor, cat } from '@xenova/transformers'; * * // Visualize predicted alpha matte - * const imageTensor = new Tensor( - * 'uint8', - * new Uint8Array(image.data), - * [image.height, image.width, image.channels] - * ).transpose(2, 0, 1); + * const imageTensor = image.toTensor(); * * // Convert float (0-1) alpha matte to uint8 (0-255) * const alphaChannel = alphas diff --git a/src/processors.js b/src/processors.js index 28b96f971..4713a6ae2 100644 --- a/src/processors.js +++ b/src/processors.js @@ -33,10 +33,11 @@ import { min, max, softmax, + bankers_round, } from './utils/maths.js'; -import { Tensor, transpose, cat, interpolate, stack } from './utils/tensor.js'; +import { Tensor, permute, cat, interpolate, stack } from './utils/tensor.js'; import { RawImage } from './utils/image.js'; import { @@ -174,14 +175,15 @@ function validate_audio_inputs(audio, feature_extractor) { * @private */ function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { - let x = Math.round(val / multiple) * multiple; + const a = val / multiple; + let x = bankers_round(a) * multiple; if (maxVal !== null && x > maxVal) { - x = Math.floor(val / multiple) * multiple; + x = Math.floor(a) * multiple; } if (x < minVal) { - x = Math.ceil(val / multiple) * multiple; + x = Math.ceil(a) * multiple; } return x; @@ -195,8 +197,8 @@ function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) { */ function enforce_size_divisibility([width, height], divisor) { return [ - Math.floor(width / divisor) * divisor, - Math.floor(height / divisor) * divisor + Math.max(Math.floor(width / divisor), 1) * divisor, + Math.max(Math.floor(height / divisor), 1) * divisor ]; } @@ -348,7 +350,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { /** * Pad the image by a certain amount. * @param {Float32Array} pixelData The pixel data to pad. - * @param {number[]} imgDims The dimensions of the image. + * @param {number[]} imgDims The dimensions of the image (height, width, channels). * @param {{width:number; height:number}|number} padSize The dimensions of the padded image. * @param {Object} options The options for padding. * @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add. @@ -361,7 +363,7 @@ export class ImageFeatureExtractor extends FeatureExtractor { center = false, constant_values = 0, } = {}) { - const [imageWidth, imageHeight, imageChannels] = imgDims; + const [imageHeight, imageWidth, imageChannels] = imgDims; let paddedImageWidth, paddedImageHeight; if (typeof padSize === 'number') { @@ -513,8 +515,8 @@ export class ImageFeatureExtractor extends FeatureExtractor { if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) { // determine new height and width - let scale_height = size.height / srcHeight; - let scale_width = size.width / srcWidth; + let scale_height = newHeight / srcHeight; + let scale_width = newWidth / srcWidth; // scale as little as possible if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) { @@ -616,6 +618,9 @@ export class ImageFeatureExtractor extends FeatureExtractor { /** @type {HeightWidth} */ const reshaped_input_size = [image.height, image.width]; + // NOTE: All pixel-level manipulation (i.e., modifying `pixelData`) + // occurs with data in the hwc format (height, width, channels), + // to emulate the behavior of the original Python code (w/ numpy). let pixelData = Float32Array.from(image.data); let imgDims = [image.height, image.width, image.channels]; @@ -646,21 +651,23 @@ export class ImageFeatureExtractor extends FeatureExtractor { } // do padding after rescaling/normalizing - if (do_pad ?? (this.do_pad && this.pad_size)) { - const padded = this.pad_image(pixelData, [image.width, image.height, image.channels], this.pad_size); - [pixelData, imgDims] = padded; // Update pixel data and image dimensions + if (do_pad ?? this.do_pad) { + if (this.pad_size) { + const padded = this.pad_image(pixelData, [image.height, image.width, image.channels], this.pad_size); + [pixelData, imgDims] = padded; // Update pixel data and image dimensions + } else if (this.size_divisibility) { + const [paddedWidth, paddedHeight] = enforce_size_divisibility([imgDims[1], imgDims[0]], this.size_divisibility); + [pixelData, imgDims] = this.pad_image(pixelData, imgDims, { width: paddedWidth, height: paddedHeight }); + } } - // Create HWC tensor - const img = new Tensor('float32', pixelData, imgDims); - - // convert to channel dimension format: - const transposed = transpose(img, [2, 0, 1]); // hwc -> chw + const pixel_values = new Tensor('float32', pixelData, imgDims) + .permute(2, 0, 1); // convert to channel dimension format (hwc -> chw) return { original_size: [srcHeight, srcWidth], reshaped_input_size: reshaped_input_size, - pixel_values: transposed, + pixel_values: pixel_values, } } @@ -760,9 +767,9 @@ export class SegformerFeatureExtractor extends ImageFeatureExtractor { return toReturn; } } -export class DPTImageProcessor extends ImageFeatureExtractor { } -export class BitImageProcessor extends ImageFeatureExtractor { } export class DPTFeatureExtractor extends ImageFeatureExtractor { } +export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor +export class BitImageProcessor extends ImageFeatureExtractor { } export class GLPNFeatureExtractor extends ImageFeatureExtractor { } export class CLIPFeatureExtractor extends ImageFeatureExtractor { } export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { } @@ -835,7 +842,7 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { } export class BeitFeatureExtractor extends ImageFeatureExtractor { } export class DonutFeatureExtractor extends ImageFeatureExtractor { pad_image(pixelData, imgDims, padSize, options = {}) { - const [imageWidth, imageHeight, imageChannels] = imgDims; + const [imageHeight, imageWidth, imageChannels] = imgDims; let image_mean = this.image_mean; if (!Array.isArray(this.image_mean)) { @@ -1382,7 +1389,7 @@ export class Swin2SRImageProcessor extends ImageFeatureExtractor { pad_image(pixelData, imgDims, padSize, options = {}) { // NOTE: In this case, `padSize` represents the size of the sliding window for the local attention. // In other words, the image is padded so that its width and height are multiples of `padSize`. - const [imageWidth, imageHeight, imageChannels] = imgDims; + const [imageHeight, imageWidth, imageChannels] = imgDims; return super.pad_image(pixelData, imgDims, { // NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already diff --git a/src/utils/image.js b/src/utils/image.js index 2d12cb876..1ee77d900 100644 --- a/src/utils/image.js +++ b/src/utils/image.js @@ -10,6 +10,7 @@ import { getFile } from './hub.js'; import { env } from '../env.js'; +import { Tensor } from './tensor.js'; // Will be empty (or not used) if running in browser or web-worker import sharp from 'sharp'; @@ -166,7 +167,7 @@ export class RawImage { /** * Helper method to create a new Image from a tensor - * @param {import('./tensor.js').Tensor} tensor + * @param {Tensor} tensor */ static fromTensor(tensor, channel_format = 'CHW') { if (tensor.dims.length !== 3) { @@ -586,6 +587,23 @@ export class RawImage { return await canvas.convertToBlob({ type, quality }); } + toTensor(channel_format = 'CHW') { + let tensor = new Tensor( + 'uint8', + new Uint8Array(this.data), + [this.height, this.width, this.channels] + ); + + if (channel_format === 'HWC') { + // Do nothing + } else if (channel_format === 'CHW') { // hwc -> chw + tensor = tensor.permute(2, 0, 1); + } else { + throw new Error(`Unsupported channel format: ${channel_format}`); + } + return tensor; + } + toCanvas() { if (!BROWSER_ENV) { throw new Error('toCanvas() is only supported in browser environments.') diff --git a/src/utils/maths.js b/src/utils/maths.js index 216def07e..264b69fc7 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.js @@ -88,15 +88,15 @@ export function interpolate_data(input, [in_channels, in_height, in_width], [out /** - * Helper method to transpose a `AnyTypedArray` directly + * Helper method to permute a `AnyTypedArray` directly * @template {AnyTypedArray} T * @param {T} array * @param {number[]} dims * @param {number[]} axes - * @returns {[T, number[]]} The transposed array and the new shape. + * @returns {[T, number[]]} The permuted array and the new shape. */ -export function transpose_data(array, dims, axes) { - // Calculate the new shape of the transposed array +export function permute_data(array, dims, axes) { + // Calculate the new shape of the permuted array // and the stride of the original array const shape = new Array(axes.length); const stride = new Array(axes.length); @@ -110,21 +110,21 @@ export function transpose_data(array, dims, axes) { // Precompute inverse mapping of stride const invStride = axes.map((_, i) => stride[axes.indexOf(i)]); - // Create the transposed array with the new shape + // Create the permuted array with the new shape // @ts-ignore - const transposedData = new array.constructor(array.length); + const permutedData = new array.constructor(array.length); - // Transpose the original array to the new array + // Permute the original array to the new array for (let i = 0; i < array.length; ++i) { let newIndex = 0; for (let j = dims.length - 1, k = i; j >= 0; --j) { newIndex += (k % dims[j]) * invStride[j]; k = Math.floor(k / dims[j]); } - transposedData[newIndex] = array[i]; + permutedData[newIndex] = array[i]; } - return [transposedData, shape]; + return [permutedData, shape]; } @@ -952,3 +952,17 @@ export function round(num, decimals) { const pow = Math.pow(10, decimals); return Math.round(num * pow) / pow; } + +/** + * Helper function to round a number to the nearest integer, with ties rounded to the nearest even number. + * Also known as "bankers' rounding". This is the default rounding mode in python. For example: + * 1.5 rounds to 2 and 2.5 rounds to 2. + * + * @param {number} x The number to round + * @returns {number} The rounded number + */ +export function bankers_round(x) { + const r = Math.round(x); + const br = Math.abs(x) % 1 === 0.5 ? (r % 2 === 0 ? r : r - 1) : r; + return br; +} diff --git a/src/utils/tensor.js b/src/utils/tensor.js index 819c2dbb6..ccdf781be 100644 --- a/src/utils/tensor.js +++ b/src/utils/tensor.js @@ -11,7 +11,7 @@ import { ONNX } from '../backends/onnx.js'; import { interpolate_data, - transpose_data + permute_data } from './maths.js'; @@ -309,16 +309,18 @@ export class Tensor { } /** - * Return a transposed version of this Tensor, according to the provided dimensions. - * @param {...number} dims Dimensions to transpose. - * @returns {Tensor} The transposed tensor. + * Return a permuted version of this Tensor, according to the provided dimensions. + * @param {...number} dims Dimensions to permute. + * @returns {Tensor} The permuted tensor. */ - transpose(...dims) { - return transpose(this, dims); + permute(...dims) { + return permute(this, dims); } - // TODO: rename transpose to permute - // TODO: implement transpose + // TODO: implement transpose. For now (backwards compatibility), it's just an alias for permute() + transpose(...dims) { + return this.permute(...dims); + } // TODO add .max() and .min() methods @@ -680,14 +682,14 @@ function reshape(data, dimensions) { } /** - * Transposes a tensor according to the provided axes. - * @param {any} tensor The input tensor to transpose. - * @param {Array} axes The axes to transpose the tensor along. - * @returns {Tensor} The transposed tensor. + * Permutes a tensor according to the provided axes. + * @param {any} tensor The input tensor to permute. + * @param {Array} axes The axes to permute the tensor along. + * @returns {Tensor} The permuted tensor. */ -export function transpose(tensor, axes) { - const [transposedData, shape] = transpose_data(tensor.data, tensor.dims, axes); - return new Tensor(tensor.type, transposedData, shape); +export function permute(tensor, axes) { + const [permutedData, shape] = permute_data(tensor.data, tensor.dims, axes); + return new Tensor(tensor.type, permutedData, shape); } diff --git a/tests/maths.test.js b/tests/maths.test.js index 9a7d3dc3c..788ae5b02 100644 --- a/tests/maths.test.js +++ b/tests/maths.test.js @@ -2,7 +2,7 @@ import { compare } from './test_utils.js'; import { getFile } from '../src/utils/hub.js'; -import { FFT, medianFilter } from '../src/utils/maths.js'; +import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js'; const fft = (arr, complex = false) => { @@ -27,6 +27,19 @@ const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json() describe('Mathematical operations', () => { + describe('bankers rounding', () => { + it('should round up to nearest even', () => { + expect(bankers_round(-0.5)).toBeCloseTo(0); + expect(bankers_round(1.5)).toBeCloseTo(2); + expect(bankers_round(19.5)).toBeCloseTo(20); + }); + it('should round down to nearest even', () => { + expect(bankers_round(-1.5)).toBeCloseTo(-2); + expect(bankers_round(2.5)).toBeCloseTo(2); + expect(bankers_round(18.5)).toBeCloseTo(18); + }); + }); + describe('median filtering', () => { diff --git a/tests/processors.test.js b/tests/processors.test.js index 0ebaec9ab..c9ab33982 100644 --- a/tests/processors.test.js +++ b/tests/processors.test.js @@ -50,7 +50,9 @@ describe('Processors', () => { const TEST_IMAGES = { pattern_3x3: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x3.png', + pattern_3x5: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/pattern_3x5.png', checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png', + checkerboard_64x32: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_64x32.png', receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png', tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg', paper: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/nougat_paper.png', @@ -369,6 +371,7 @@ describe('Processors', () => { // - tests custom overrides // - tests multiple inputs // - tests `size_divisibility` and no size (size_divisibility=32) + // - tests do_pad and `size_divisibility` it(MODELS.vitmatte, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.vitmatte)) @@ -391,6 +394,25 @@ describe('Processors', () => { compare(original_sizes, [[640, 960]]); compare(reshaped_input_sizes, [[640, 960]]); } + + + { + const image = await load_image(TEST_IMAGES.pattern_3x5); + const image2 = await load_image(TEST_IMAGES.pattern_3x5); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image, image2); + + compare(pixel_values.dims, [1, 4, 32, 32]); + expect(avg(pixel_values.data)).toBeCloseTo(-0.00867417361587286); + expect(pixel_values.data[0]).toBeCloseTo(-0.9921568632125854); + expect(pixel_values.data[1]).toBeCloseTo(-0.9686274528503418); + expect(pixel_values.data[5]).toBeCloseTo(0.0); + expect(pixel_values.data[32]).toBeCloseTo(-0.9215686321258545); + expect(pixel_values.data[33]).toBeCloseTo(-0.8980392217636108); + expect(pixel_values.data.at(-1)).toBeCloseTo(0.0); + + compare(original_sizes, [[5, 3]]); + compare(reshaped_input_sizes, [[5, 3]]); + } }, MAX_TEST_EXECUTION_TIME); // BitImageProcessor @@ -412,6 +434,7 @@ describe('Processors', () => { // DPTImageProcessor // - tests ensure_multiple_of // - tests keep_aspect_ratio + // - tests bankers rounding it(MODELS.dpt_2, async () => { const processor = await AutoProcessor.from_pretrained(m(MODELS.dpt_2)) @@ -425,6 +448,18 @@ describe('Processors', () => { compare(original_sizes, [[480, 640]]); compare(reshaped_input_sizes, [[518, 686]]); } + + { + const image = await load_image(TEST_IMAGES.checkerboard_64x32); + const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image); + + // NOTE: without bankers rounding, this would be [1, 3, 266, 518] + compare(pixel_values.dims, [1, 3, 252, 518]); + compare(avg(pixel_values.data), 0.2267402559518814); + + compare(original_sizes, [[32, 64]]); + compare(reshaped_input_sizes, [[252, 518]]); + } }, MAX_TEST_EXECUTION_TIME); // EfficientNetImageProcessor diff --git a/tests/tensor.test.js b/tests/tensor.test.js index de9ffac30..bc056b9c8 100644 --- a/tests/tensor.test.js +++ b/tests/tensor.test.js @@ -103,6 +103,65 @@ describe('Tensor operations', () => { }); }); + describe('permute', () => { + it('should permute', async () => { + const x = new Tensor( + 'float32', + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23], + [2, 3, 4], + ); + // Permute axes to (0, 1, 2) - No change + const permuted_1 = x.permute(0, 1, 2); + const target_1 = x; + compare(permuted_1, target_1, 1e-3); + + // Permute axes to (0, 2, 1) + const permuted_2 = x.permute(0, 2, 1); + const target_2 = new Tensor( + 'float32', + [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15, 19, 23], + [2, 4, 3], + ); + compare(permuted_2, target_2, 1e-3); + + // Permute axes to (1, 0, 2) + const permuted_3 = x.permute(1, 0, 2); + const target_3 = new Tensor( + 'float32', + [0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23], + [3, 2, 4], + ); + compare(permuted_3, target_3, 1e-3); + + // Permute axes to (1, 2, 0) + const permuted_4 = x.permute(1, 2, 0); + const target_4 = new Tensor( + 'float32', + [0, 12, 1, 13, 2, 14, 3, 15, 4, 16, 5, 17, 6, 18, 7, 19, 8, 20, 9, 21, 10, 22, 11, 23], + [3, 4, 2], + ); + compare(permuted_4, target_4, 1e-3); + + // Permute axes to (2, 0, 1) + const permuted_5 = x.permute(2, 0, 1); + const target_5 = new Tensor( + 'float32', + [0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23], + [4, 2, 3], + ); + compare(permuted_5, target_5, 1e-3); + + // Permute axes to (2, 1, 0) + const permuted_6 = x.permute(2, 1, 0); + const target_6 = new Tensor( + 'float32', + [0, 12, 4, 16, 8, 20, 1, 13, 5, 17, 9, 21, 2, 14, 6, 18, 10, 22, 3, 15, 7, 19, 11, 23], + [4, 3, 2], + ); + compare(permuted_6, target_6, 1e-3); + }); + }); + describe('mean', () => { it('should calculate mean', async () => { const t1 = new Tensor('float32', [1, 2, 3, 4, 5, 6], [2, 3, 1]);