Skip to content

Commit

Permalink
chore: adding more tests against the reshape method
Browse files Browse the repository at this point in the history
  • Loading branch information
JasonShin committed Nov 29, 2018
1 parent 9b64d9f commit 1d3ced0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
24 changes: 21 additions & 3 deletions src/lib/utils/tensor.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
import { flattenDeep } from 'lodash';
import { TypeMatrix } from '../types/matrix.types';

/**
* Reshapes any size of array into a new shape
* Reshapes any size of array into a new shape. The code was copied from
* math.js, https://github.com/josdejong/mathjs/blob/5750a1845442946d236822505c607a522be23474/src/utils/array.js#L258
* in order to use specific method from Math.js instead of install an entire library.
*
* @example
* reshape([1, 2, 3, 4, 5, 6], [2, 3]); // [[1, 2, 3], [4, 5, 6]]
*
* @param array
* @param sizes
* @param array - Target array
* @param sizes - New array shape to resize into
* @ignore
*/
export function reshape(array, sizes): TypeMatrix<any> {
// If the reshaping is to single dimensional
if (Array.isArray(array) && sizes.length === 1) {
const deepFlat = flattenDeep(array);
if (deepFlat.length === sizes[0]) {
return deepFlat;
} else {
throw new TypeError(
`Target array shape [${
deepFlat.length
}] cannot be reshaped into ${sizes}`
);
}
}

// testing if there are enough elements for the requested shape
let tmpArray = array;
let tmpArray2;
Expand Down
18 changes: 14 additions & 4 deletions test/utils/tensor.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,20 @@
import * as tf from '@tensorflow/tfjs';
import { reshape } from '../../src/lib/utils/tensor';

describe('utils.dataSyncRaw', () => {
it('tzz', () => {
const x = tf.tensor2d([[1, 2, 3], [4, 5, 6]]);
describe('utils.reshape', () => {
it('should reshape an array of shape [1] into [2, 3]', () => {
const result = reshape([1, 2, 3, 4, 5, 6], [2, 3]);
console.info(result);
expect(result).toEqual([[1, 2, 3], [4, 5, 6]]);
});

it('should reshape an array of shape [2, 3] into [1]', () => {
// console.log(tf.tensor1d([1, 2, 3]).shape);
const result = reshape([[1, 2, 3], [4, 5, 6]], [6]);
expect(result).toEqual(result);
});

it('should reshape an array of shape [1] into [2, 3, 1]', () => {
const result = reshape([1, 2, 3, 4, 5, 6], [2, 3, 1]);
expect(result).toEqual([[[1], [2], [3]], [[4], [5], [6]]]);
});
});

0 comments on commit 1d3ced0

Please sign in to comment.