Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3779,11 +3779,7 @@ export class VitMattePreTrainedModel extends PreTrainedModel { }
* import { Tensor, cat } from '@xenova/transformers';
*
* // Visualize predicted alpha matte
* const imageTensor = new Tensor(
* 'uint8',
* new Uint8Array(image.data),
* [image.height, image.width, image.channels]
* ).transpose(2, 0, 1);
* const imageTensor = image.toTensor();
*
* // Convert float (0-1) alpha matte to uint8 (0-255)
* const alphaChannel = alphas
Expand Down
53 changes: 30 additions & 23 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ import {
min,
max,
softmax,
bankers_round,
} from './utils/maths.js';


import { Tensor, transpose, cat, interpolate, stack } from './utils/tensor.js';
import { Tensor, permute, cat, interpolate, stack } from './utils/tensor.js';

import { RawImage } from './utils/image.js';
import {
Expand Down Expand Up @@ -174,14 +175,15 @@ function validate_audio_inputs(audio, feature_extractor) {
* @private
*/
function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
let x = Math.round(val / multiple) * multiple;
const a = val / multiple;
let x = bankers_round(a) * multiple;

if (maxVal !== null && x > maxVal) {
x = Math.floor(val / multiple) * multiple;
x = Math.floor(a) * multiple;
}

if (x < minVal) {
x = Math.ceil(val / multiple) * multiple;
x = Math.ceil(a) * multiple;
}

return x;
Expand All @@ -195,8 +197,8 @@ function constraint_to_multiple_of(val, multiple, minVal = 0, maxVal = null) {
*/
function enforce_size_divisibility([width, height], divisor) {
return [
Math.floor(width / divisor) * divisor,
Math.floor(height / divisor) * divisor
Math.max(Math.floor(width / divisor), 1) * divisor,
Math.max(Math.floor(height / divisor), 1) * divisor
];
}

Expand Down Expand Up @@ -348,7 +350,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
/**
* Pad the image by a certain amount.
* @param {Float32Array} pixelData The pixel data to pad.
* @param {number[]} imgDims The dimensions of the image.
* @param {number[]} imgDims The dimensions of the image (height, width, channels).
* @param {{width:number; height:number}|number} padSize The dimensions of the padded image.
* @param {Object} options The options for padding.
* @param {'constant'|'symmetric'} [options.mode='constant'] The type of padding to add.
Expand All @@ -361,7 +363,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
center = false,
constant_values = 0,
} = {}) {
const [imageWidth, imageHeight, imageChannels] = imgDims;
const [imageHeight, imageWidth, imageChannels] = imgDims;

let paddedImageWidth, paddedImageHeight;
if (typeof padSize === 'number') {
Expand Down Expand Up @@ -513,8 +515,8 @@ export class ImageFeatureExtractor extends FeatureExtractor {
if (this.config.keep_aspect_ratio && this.config.ensure_multiple_of) {

// determine new height and width
let scale_height = size.height / srcHeight;
let scale_width = size.width / srcWidth;
let scale_height = newHeight / srcHeight;
let scale_width = newWidth / srcWidth;

// scale as little as possible
if (Math.abs(1 - scale_width) < Math.abs(1 - scale_height)) {
Expand Down Expand Up @@ -616,6 +618,9 @@ export class ImageFeatureExtractor extends FeatureExtractor {
/** @type {HeightWidth} */
const reshaped_input_size = [image.height, image.width];

// NOTE: All pixel-level manipulation (i.e., modifying `pixelData`)
// occurs with data in the hwc format (height, width, channels),
// to emulate the behavior of the original Python code (w/ numpy).
let pixelData = Float32Array.from(image.data);
let imgDims = [image.height, image.width, image.channels];

Expand Down Expand Up @@ -646,21 +651,23 @@ export class ImageFeatureExtractor extends FeatureExtractor {
}

// do padding after rescaling/normalizing
if (do_pad ?? (this.do_pad && this.pad_size)) {
const padded = this.pad_image(pixelData, [image.width, image.height, image.channels], this.pad_size);
[pixelData, imgDims] = padded; // Update pixel data and image dimensions
if (do_pad ?? this.do_pad) {
if (this.pad_size) {
const padded = this.pad_image(pixelData, [image.height, image.width, image.channels], this.pad_size);
[pixelData, imgDims] = padded; // Update pixel data and image dimensions
} else if (this.size_divisibility) {
const [paddedWidth, paddedHeight] = enforce_size_divisibility([imgDims[1], imgDims[0]], this.size_divisibility);
[pixelData, imgDims] = this.pad_image(pixelData, imgDims, { width: paddedWidth, height: paddedHeight });
}
}

// Create HWC tensor
const img = new Tensor('float32', pixelData, imgDims);

// convert to channel dimension format:
const transposed = transpose(img, [2, 0, 1]); // hwc -> chw
const pixel_values = new Tensor('float32', pixelData, imgDims)
.permute(2, 0, 1); // convert to channel dimension format (hwc -> chw)

return {
original_size: [srcHeight, srcWidth],
reshaped_input_size: reshaped_input_size,
pixel_values: transposed,
pixel_values: pixel_values,
}
}

Expand Down Expand Up @@ -760,9 +767,9 @@ export class SegformerFeatureExtractor extends ImageFeatureExtractor {
return toReturn;
}
}
export class DPTImageProcessor extends ImageFeatureExtractor { }
export class BitImageProcessor extends ImageFeatureExtractor { }
export class DPTFeatureExtractor extends ImageFeatureExtractor { }
export class DPTImageProcessor extends DPTFeatureExtractor { } // NOTE: extends DPTFeatureExtractor
export class BitImageProcessor extends ImageFeatureExtractor { }
export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { }
Expand Down Expand Up @@ -835,7 +842,7 @@ export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
export class DonutFeatureExtractor extends ImageFeatureExtractor {
pad_image(pixelData, imgDims, padSize, options = {}) {
const [imageWidth, imageHeight, imageChannels] = imgDims;
const [imageHeight, imageWidth, imageChannels] = imgDims;

let image_mean = this.image_mean;
if (!Array.isArray(this.image_mean)) {
Expand Down Expand Up @@ -1382,7 +1389,7 @@ export class Swin2SRImageProcessor extends ImageFeatureExtractor {
pad_image(pixelData, imgDims, padSize, options = {}) {
// NOTE: In this case, `padSize` represents the size of the sliding window for the local attention.
// In other words, the image is padded so that its width and height are multiples of `padSize`.
const [imageWidth, imageHeight, imageChannels] = imgDims;
const [imageHeight, imageWidth, imageChannels] = imgDims;

return super.pad_image(pixelData, imgDims, {
// NOTE: For Swin2SR models, the original python implementation adds padding even when the image's width/height is already
Expand Down
20 changes: 19 additions & 1 deletion src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import { getFile } from './hub.js';
import { env } from '../env.js';
import { Tensor } from './tensor.js';

// Will be empty (or not used) if running in browser or web-worker
import sharp from 'sharp';
Expand Down Expand Up @@ -166,7 +167,7 @@ export class RawImage {

/**
* Helper method to create a new Image from a tensor
* @param {import('./tensor.js').Tensor} tensor
* @param {Tensor} tensor
*/
static fromTensor(tensor, channel_format = 'CHW') {
if (tensor.dims.length !== 3) {
Expand Down Expand Up @@ -586,6 +587,23 @@ export class RawImage {
return await canvas.convertToBlob({ type, quality });
}

toTensor(channel_format = 'CHW') {
let tensor = new Tensor(
'uint8',
new Uint8Array(this.data),
[this.height, this.width, this.channels]
);

if (channel_format === 'HWC') {
// Do nothing
} else if (channel_format === 'CHW') { // hwc -> chw
tensor = tensor.permute(2, 0, 1);
} else {
throw new Error(`Unsupported channel format: ${channel_format}`);
}
return tensor;
}

toCanvas() {
if (!BROWSER_ENV) {
throw new Error('toCanvas() is only supported in browser environments.')
Expand Down
32 changes: 23 additions & 9 deletions src/utils/maths.js
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ export function interpolate_data(input, [in_channels, in_height, in_width], [out


/**
* Helper method to transpose a `AnyTypedArray` directly
* Helper method to permute a `AnyTypedArray` directly
* @template {AnyTypedArray} T
* @param {T} array
* @param {number[]} dims
* @param {number[]} axes
* @returns {[T, number[]]} The transposed array and the new shape.
* @returns {[T, number[]]} The permuted array and the new shape.
*/
export function transpose_data(array, dims, axes) {
// Calculate the new shape of the transposed array
export function permute_data(array, dims, axes) {
// Calculate the new shape of the permuted array
// and the stride of the original array
const shape = new Array(axes.length);
const stride = new Array(axes.length);
Expand All @@ -110,21 +110,21 @@ export function transpose_data(array, dims, axes) {
// Precompute inverse mapping of stride
const invStride = axes.map((_, i) => stride[axes.indexOf(i)]);

// Create the transposed array with the new shape
// Create the permuted array with the new shape
// @ts-ignore
const transposedData = new array.constructor(array.length);
const permutedData = new array.constructor(array.length);

// Transpose the original array to the new array
// Permute the original array to the new array
for (let i = 0; i < array.length; ++i) {
let newIndex = 0;
for (let j = dims.length - 1, k = i; j >= 0; --j) {
newIndex += (k % dims[j]) * invStride[j];
k = Math.floor(k / dims[j]);
}
transposedData[newIndex] = array[i];
permutedData[newIndex] = array[i];
}

return [transposedData, shape];
return [permutedData, shape];
}


Expand Down Expand Up @@ -952,3 +952,17 @@ export function round(num, decimals) {
const pow = Math.pow(10, decimals);
return Math.round(num * pow) / pow;
}

/**
* Helper function to round a number to the nearest integer, with ties rounded to the nearest even number.
* Also known as "bankers' rounding". This is the default rounding mode in python. For example:
* 1.5 rounds to 2 and 2.5 rounds to 2.
*
* @param {number} x The number to round
* @returns {number} The rounded number
*/
export function bankers_round(x) {
const r = Math.round(x);
const br = Math.abs(x) % 1 === 0.5 ? (r % 2 === 0 ? r : r - 1) : r;
return br;
}
32 changes: 17 additions & 15 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { ONNX } from '../backends/onnx.js';

import {
interpolate_data,
transpose_data
permute_data
} from './maths.js';


Expand Down Expand Up @@ -309,16 +309,18 @@ export class Tensor {
}

/**
* Return a transposed version of this Tensor, according to the provided dimensions.
* @param {...number} dims Dimensions to transpose.
* @returns {Tensor} The transposed tensor.
* Return a permuted version of this Tensor, according to the provided dimensions.
* @param {...number} dims Dimensions to permute.
* @returns {Tensor} The permuted tensor.
*/
transpose(...dims) {
return transpose(this, dims);
permute(...dims) {
return permute(this, dims);
}

// TODO: rename transpose to permute
// TODO: implement transpose
// TODO: implement transpose. For now (backwards compatibility), it's just an alias for permute()
transpose(...dims) {
return this.permute(...dims);
}

// TODO add .max() and .min() methods

Expand Down Expand Up @@ -680,14 +682,14 @@ function reshape(data, dimensions) {
}

/**
* Transposes a tensor according to the provided axes.
* @param {any} tensor The input tensor to transpose.
* @param {Array} axes The axes to transpose the tensor along.
* @returns {Tensor} The transposed tensor.
* Permutes a tensor according to the provided axes.
* @param {any} tensor The input tensor to permute.
* @param {Array} axes The axes to permute the tensor along.
* @returns {Tensor} The permuted tensor.
*/
export function transpose(tensor, axes) {
const [transposedData, shape] = transpose_data(tensor.data, tensor.dims, axes);
return new Tensor(tensor.type, transposedData, shape);
export function permute(tensor, axes) {
const [permutedData, shape] = permute_data(tensor.data, tensor.dims, axes);
return new Tensor(tensor.type, permutedData, shape);
}


Expand Down
15 changes: 14 additions & 1 deletion tests/maths.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import { compare } from './test_utils.js';

import { getFile } from '../src/utils/hub.js';
import { FFT, medianFilter } from '../src/utils/maths.js';
import { FFT, medianFilter, bankers_round } from '../src/utils/maths.js';


const fft = (arr, complex = false) => {
Expand All @@ -27,6 +27,19 @@ const fftTestsData = await (await getFile('./tests/data/fft_tests.json')).json()

describe('Mathematical operations', () => {

describe('bankers rounding', () => {
it('should round up to nearest even', () => {
expect(bankers_round(-0.5)).toBeCloseTo(0);
expect(bankers_round(1.5)).toBeCloseTo(2);
expect(bankers_round(19.5)).toBeCloseTo(20);
});
it('should round down to nearest even', () => {
expect(bankers_round(-1.5)).toBeCloseTo(-2);
expect(bankers_round(2.5)).toBeCloseTo(2);
expect(bankers_round(18.5)).toBeCloseTo(18);
});
});

describe('median filtering', () => {


Expand Down
Loading