Skip to content

Commit

Permalink
feat(lib): Column value mapping capability (#64)
Browse files Browse the repository at this point in the history
* feat(lib): Column value mapping capability

* test(lib): Tests for applyMappings

* chore(config): Ignore shrinkwrap.yaml

* fix(lib): Fix mappings implementation

Co-authored-by: Glitch (isair-tensorflow-load-csv) <none>
  • Loading branch information
isair committed Sep 15, 2020
1 parent 1ccb7df commit 7a03239
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ pnpm-debug.log*
dist
docs
coverage
shrinkwrap.yaml
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ const {
} = loadCsv('./data.csv', {
featureColumns: ['lat', 'lng', 'height'],
labelColumns: ['temperature'],
mappings: {
height: (ft) => ft * 0.3048, // feet to meters
temperature: (f) => (f - 32) / 1.8, // fahrenheit to celsius
}, // Map values based on which column they are in before they are loaded into tensors.
shuffle: true, // Pass true to shuffle with a fixed seed, or a string to use it as a seed for the shuffling.
splitTest: true, // Splits your data in half. You can also provide a certain row count for the test data.
prependOnes: true, // Prepends a column of 1s to your features and testFeatures tensors, useful for linear regression.
Expand Down
22 changes: 22 additions & 0 deletions src/applyMappings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import { CsvTable, CsvReadOptions } from './loadCsv.models';

const applyMappings = (
table: CsvTable,
mappings: NonNullable<CsvReadOptions['mappings']>
) => {
if (table.length < 2) {
return table;
}
const mappingsByIndex = table[0].map((columnName) => mappings[columnName]);
return table.map((row, index) =>
index === 0
? row
: row.map((value, columnIndex) =>
mappingsByIndex[columnIndex]
? mappingsByIndex[columnIndex](value)
: value
)
);
};

export default applyMappings;
7 changes: 7 additions & 0 deletions src/loadCsv.models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ export interface CsvReadOptions {
* Names of the columns to be included in the labels and testLabels tensors.
*/
labelColumns: string[];
/**
* Used for transforming values of entire columns. Key is column label, value is transformer function. Each value belonging to
* that column will be put through the transformer function and be overwritten with the return value of it.
*/
mappings?: {
[columnName: string]: (value: string | number) => string | number;
};
/**
* If true, shuffles all rows with a fixed seed, meaning that shuffling the same data will always result in the same shuffled data.
*
Expand Down
6 changes: 6 additions & 0 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import { shuffle } from 'shuffle-seed';
import { CsvReadOptions, CsvTable } from './loadCsv.models';
import filterColumns from './filterColumns';
import splitTestData from './splitTestData';
import applyMappings from './applyMappings';

const defaultShuffleSeed = 'mncv9340ur';

const loadCsv = (filename: string, options: CsvReadOptions) => {
const {
featureColumns,
labelColumns,
mappings = {},
shuffle: shouldShuffle = false,
splitTest = false,
prependOnes = false,
Expand Down Expand Up @@ -43,6 +45,10 @@ const loadCsv = (filename: string, options: CsvReadOptions) => {
tables.labels.shift();
tables.features.shift();

for (const key of Object.keys(tables)) {
tables[key] = applyMappings(tables[key], mappings);
}

if (shouldShuffle) {
const seed =
typeof shouldShuffle === 'string' ? shouldShuffle : defaultShuffleSeed;
Expand Down
38 changes: 38 additions & 0 deletions tests/applyMappings.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/* eslint-disable @typescript-eslint/ban-ts-comment */
import { CsvReadOptions } from '../src/loadCsv.models';
import applyMappings from '../src/applyMappings';

const data = [
['lat', 'lng', 'height', 'temperature'],
[0.234, 1.47, 849.7, 64.4],
[-293.2, 103.34, 715.2, 73.4],
];

const mappings: NonNullable<CsvReadOptions['mappings']> = {
height: (ft) => Number(ft) * 0.3048, // feet to meters
temperature: (f) => (Number(f) - 32) / 1.8, // fahrenheit to celsius
};

test('Applying mappings works correctly', () => {
const mappedData = applyMappings(data, mappings);
// @ts-ignore
expect(mappedData).toBeDeepCloseTo(
[
['lat', 'lng', 'height', 'temperature'],
[0.234, 1.47, 258.98856, 18],
[-293.2, 103.34, 217.99296, 23],
],
3
);
});

test('Applying mappings does not break with a table with just headers', () => {
const tableOnlyHeaders = [['lat', 'lng', 'height', 'temperature']];
const mappedData = applyMappings(tableOnlyHeaders, mappings);
expect(mappedData).toMatchObject(tableOnlyHeaders);
});

test('Applying mappings does not break with an empty table', () => {
const mappedData = applyMappings([], mappings);
expect(mappedData).toMatchObject([]);
});

0 comments on commit 7a03239

Please sign in to comment.