Skip to content

Commit

Permalink
chore: moving the reshape method to ops/tensor_ops.ts; adding more ex…
Browse files Browse the repository at this point in the history
…ception tests
  • Loading branch information
JasonShin committed Nov 29, 2018
1 parent 1d3ced0 commit 8d9b585
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 69 deletions.
59 changes: 59 additions & 0 deletions src/lib/ops/tensor_ops.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import * as tf from '@tensorflow/tfjs';
import { flattenDeep } from 'lodash';
import { Type1DMatrix, Type2DMatrix, TypeMatrix } from '../types';

/**
Expand Down Expand Up @@ -84,3 +85,61 @@ export function validateMatrix2D(X: Type2DMatrix<any>): number[] {
}
return shape;
}

/**
* Reshapes any size of array into a new shape. The code was copied from
* math.js, https://github.com/josdejong/mathjs/blob/5750a1845442946d236822505c607a522be23474/src/utils/array.js#L258
* in order to use specific method from Math.js instead of install an entire library.
*
* @example
* reshape([1, 2, 3, 4, 5, 6], [2, 3]); // [[1, 2, 3], [4, 5, 6]]
*
* @param array - Target array
* @param sizes - New array shape to resize into
* @ignore
*/
export function reshape(
array: TypeMatrix<any>,
sizes: number[]
): TypeMatrix<any> {
// Initial validations
if (!Array.isArray(array)) {
throw new TypeError('The input array must be an array!');
}

if (!Array.isArray(sizes)) {
throw new TypeError('The sizes must be an array!');
}

const deepFlatArray = flattenDeep(array);
// If the reshaping is to single dimensional
if (sizes.length === 1 && deepFlatArray.length === sizes[0]) {
return deepFlatArray;
} else if (sizes.length === 1 && deepFlatArray.length !== sizes[0]) {
throw new TypeError(
`Target array shape [${
deepFlatArray.length
}] cannot be reshaped into ${sizes}`
);
}

// testing if there are enough elements for the requested shape
let tmpArray = deepFlatArray;
let tmpArray2;
// for each dimensions starting by the last one and ignoring the first one
for (let sizeIndex = sizes.length - 1; sizeIndex > 0; sizeIndex--) {
const size = sizes[sizeIndex];

tmpArray2 = [];

// aggregate the elements of the current tmpArray in elements of the requested size
const length = tmpArray.length / size;
for (let i = 0; i < length; i++) {
tmpArray2.push(tmpArray.slice(i * size, (i + 1) * size));
}
// set it as the new tmpArray for the next loop turn or for return
tmpArray = tmpArray2;
}

return tmpArray;
}
49 changes: 0 additions & 49 deletions src/lib/utils/tensor.ts

This file was deleted.

39 changes: 39 additions & 0 deletions test/ops/tensor_ops.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Iris } from '../../src/lib/datasets';
import {
inferShape,
reshape,
validateFitInputs,
validateMatrix1D,
validateMatrix2D
Expand Down Expand Up @@ -157,3 +158,41 @@ describe('ops', () => {
});
});
});

describe('utils.reshape', () => {
it('should reshape an array of shape [1] into [2, 3]', () => {
const result = reshape([1, 2, 3, 4, 5, 6], [2, 3]);
expect(result).toEqual([[1, 2, 3], [4, 5, 6]]);
});

it('should reshape an array of shape [2, 3] into [1]', () => {
// console.log(tf.tensor1d([1, 2, 3]).shape);
const result = reshape([[1, 2, 3], [4, 5, 6]], [6]);
expect(result).toEqual(result);
});

it('should reshape an array of shape [1] into [2, 3, 1]', () => {
const result = reshape([1, 2, 3, 4, 5, 6], [2, 3, 1]);
expect(result).toEqual([[[1], [2], [3]], [[4], [5], [6]]]);
});

it('should reshape an array of shape [2, 3] into [2, 3, 1]', () => {
const result = reshape([[1, 2, 3], [4, 5, 6]], [2, 3, 1]);
expect(result).toEqual([[[1], [2], [3]], [[4], [5], [6]]]);
});

it('should not reshape invalid inputs', () => {
expect(() => reshape(null, [1])).toThrow(
'The input array must be an array!'
);
expect(() => reshape([], [1])).toThrow(
'Target array shape [0] cannot be reshaped into 1'
);
expect(() => reshape([[1, 2, 3]], null)).toThrow(
'The sizes must be an array!'
);
expect(() => reshape([[1, 2, 3]], 1)).toThrow(
'The sizes must be an array!'
);
});
});
20 changes: 0 additions & 20 deletions test/utils/tensor.test.ts

This file was deleted.

0 comments on commit 8d9b585

Please sign in to comment.