Skip to content

Commit

Permalink
Add new features to production (#19)
Browse files Browse the repository at this point in the history
* chore: add prettier and format

* chore: bump version

* [Feature] Remove necessity for no / at end of MODEL_DIR_PATH (#18)

* feat: add playground for trying out the package

* chore: add playground to .npmignore

* chore: remove playground

* feat: set up filter to remove '/' at end if it exists

* test: add test for model_dir_path with/without slash at end

* chore: changeset

* docs: update readme

* chore: format
  • Loading branch information
kevinanielsen authored Nov 28, 2023
1 parent b588449 commit 84c2bb7
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 240 deletions.
21 changes: 9 additions & 12 deletions .changeset/config.json
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
{
"$schema": "https://unpkg.com/@changesets/config@3.0.0/schema.json",
"changelog": [
"@changesets/changelog-github",
{ "repo": "kevinanielsen/tfjs-image-node" }
],
"commit": false,
"fixed": [],
"linked": [],
"access": "public",
"baseBranch": "main",
"updateInternalDependencies": "patch",
"ignore": []
"$schema": "https://unpkg.com/@changesets/config@3.0.0/schema.json",
"changelog": ["@changesets/changelog-github", { "repo": "kevinanielsen/tfjs-image-node" }],
"commit": false,
"fixed": [],
"linked": [],
"access": "public",
"baseBranch": "main",
"updateInternalDependencies": "patch",
"ignore": []
}
5 changes: 5 additions & 0 deletions .changeset/heavy-readers-fold.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"tfjs-image-node": patch
---

Add filter to remove necessity for no "/" at end of MODEL_DIR_PATH
2 changes: 1 addition & 1 deletion .prettierrc
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
"printWidth": 100,
"tabWidth": 2,
"useTabs": true
}
}
26 changes: 13 additions & 13 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,9 @@ pnpm add tfjs-image-node
tfjs-image-node has two different exports. One for the tfjs-node platform and one for the regular tfjs package - one package is preferred for operations in the node runtime, the other is preferred for regular javascript.

```typescript
// Using Node Platform
const classifyImage = require("tfjs-image-node/node");
const classifyImage = require("tfjs-image-node");
// or
import classifyImage from "tfjs-image-node/node";

// Using JS Platform
const classifyImage = require("tfjs-image-node/js");
// or
import classifyImage from "tfjs-image-node/js";
import classifyImage from "tfjs-image-node";
```

## Example
Expand All @@ -36,12 +30,18 @@ import classifyImage from "tfjs-image-node/js";
import classifyImage from "tfjs-image-node/node";

const model = "https://teachablemachine.withgoogle.com/models/jAIOHvmge";
const image =
"https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg";
const image = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg";

// With tfjs-node as the platform
(async () => {
const prediction = await classifyImage(model, image);
console.log(prediction[0]);
})();

// With classic tfjs as the platform
(async () => {
const prediction = await classifyImage(model, image);
console.log(prediction[0]);
const prediction = await classifyImage(model, image, "classic");
console.log(prediction[0]);
})();

// expected output:
Expand All @@ -67,7 +67,7 @@ const image =
string
</td>
<td>
The URL to your AI model (currently only supports teachable machine URLs like "https://teachablemachine.withgoogle.com/models/{model_id}" <u><b>with no "/" at the end!</b></u>
The URL to your AI model (currently only supports teachable machine URLs like "https://teachablemachine.withgoogle.com/models/{model_id}".
</td>
</tr>
<tr>
Expand Down
147 changes: 73 additions & 74 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,90 +4,89 @@ const tfJs = require("@tensorflow/tfjs");
let tf: any;

interface IMetadata extends JSON {
labels: string[];
labels: string[];
}

type ResultType = {
label: string;
probability: string;
label: string;
probability: string;
};

type ClassifyImageType = (
MODEL_DIR_PATH: string,
IMAGE_FILE_PATH: string,
PLATFORM?: "node" | "classic",
MODEL_DIR_PATH: string,
IMAGE_FILE_PATH: string,
PLATFORM?: "node" | "classic"
) => Promise<ResultType[] | Error>;

const filterInputPath = (inputPath: string) => {
if (inputPath.endsWith("/")) {
return inputPath.slice(0, -1);
}
return inputPath;
};

const classifyImage: ClassifyImageType = async (
MODEL_DIR_PATH,
IMAGE_FILE_PATH,
PLATFORM = "node",
MODEL_DIR_PATH,
IMAGE_FILE_PATH,
PLATFORM = "node"
) => {
PLATFORM === "node" ? (tf = tfNode) : (tf = tfJs);

if (!MODEL_DIR_PATH || !IMAGE_FILE_PATH) {
return new Error("MISSING_PARAMETER");
}

const res = await fetch(`${MODEL_DIR_PATH}/metadata.json`);
if (res.status !== 200) {
return new Error("METADATA_NOT_FOUND");
}

const METADATA: IMetadata = await res.json();

if (METADATA["labels"].length === 0 || METADATA["labels"]! instanceof Array) {
return new Error("NO_METADATA_LABELS");
}

let labels: string[] = METADATA["labels"];

const model = await tf.loadLayersModel(`${MODEL_DIR_PATH}/model.json`);

const image = await Jimp.read(IMAGE_FILE_PATH);
image.cover(
224,
224,
Jimp.HORIZONTAL_ALIGN_CENTER | Jimp.VERTICAL_ALIGN_MIDDLE,
);

const NUM_OF_CHANNELS = 3;
let values = new Float32Array(224 * 224 * NUM_OF_CHANNELS);

let i = 0;
image.scan(
0,
0,
image.bitmap.width,
image.bitmap.height,
(x: number, y: number) => {
const pixel = Jimp.intToRGBA(image.getPixelColor(x, y));
pixel.r = pixel.r / 127.0 - 1;
pixel.g = pixel.g / 127.0 - 1;
pixel.b = pixel.b / 127.0 - 1;
pixel.a = pixel.a / 127.0 - 1;
values[i * NUM_OF_CHANNELS + 0] = pixel.r;
values[i * NUM_OF_CHANNELS + 1] = pixel.g;
values[i * NUM_OF_CHANNELS + 2] = pixel.b;
i++;
},
);

const outShape = [224, 224, NUM_OF_CHANNELS];
let img_tensor = tf.tensor3d(values, outShape, "float32");
img_tensor = img_tensor.expandDims(0);

const predictions = await model.predict(img_tensor).dataSync();

let result: ResultType[] = [];

for (let i = 0; i < predictions.length; i++) {
const label = labels[i];
const probability = predictions[i];
result.push({ label: label, probability: probability });
}

return result.sort((a, b) => Number(b.probability) - Number(a.probability));
PLATFORM === "node" ? (tf = tfNode) : (tf = tfJs);

if (!MODEL_DIR_PATH || !IMAGE_FILE_PATH) {
return new Error("MISSING_PARAMETER");
}

MODEL_DIR_PATH = filterInputPath(MODEL_DIR_PATH);

const res = await fetch(`${MODEL_DIR_PATH}/metadata.json`);
if (res.status !== 200) {
return new Error("METADATA_NOT_FOUND");
}

const METADATA: IMetadata = await res.json();

if (METADATA["labels"].length === 0 || METADATA["labels"]! instanceof Array) {
return new Error("NO_METADATA_LABELS");
}

let labels: string[] = METADATA["labels"];

const model = await tf.loadLayersModel(`${MODEL_DIR_PATH}/model.json`);

const image = await Jimp.read(IMAGE_FILE_PATH);
image.cover(224, 224, Jimp.HORIZONTAL_ALIGN_CENTER | Jimp.VERTICAL_ALIGN_MIDDLE);

const NUM_OF_CHANNELS = 3;
let values = new Float32Array(224 * 224 * NUM_OF_CHANNELS);

let i = 0;
image.scan(0, 0, image.bitmap.width, image.bitmap.height, (x: number, y: number) => {
const pixel = Jimp.intToRGBA(image.getPixelColor(x, y));
pixel.r = pixel.r / 127.0 - 1;
pixel.g = pixel.g / 127.0 - 1;
pixel.b = pixel.b / 127.0 - 1;
pixel.a = pixel.a / 127.0 - 1;
values[i * NUM_OF_CHANNELS + 0] = pixel.r;
values[i * NUM_OF_CHANNELS + 1] = pixel.g;
values[i * NUM_OF_CHANNELS + 2] = pixel.b;
i++;
});

const outShape = [224, 224, NUM_OF_CHANNELS];
let img_tensor = tf.tensor3d(values, outShape, "float32");
img_tensor = img_tensor.expandDims(0);

const predictions = await model.predict(img_tensor).dataSync();

let result: ResultType[] = [];

for (let i = 0; i < predictions.length; i++) {
const label = labels[i];
const probability = predictions[i];
result.push({ label: label, probability: probability });
}

return result.sort((a, b) => Number(b.probability) - Number(a.probability));
};

export default classifyImage;
135 changes: 73 additions & 62 deletions test/classifyImage.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,75 +2,86 @@ import { describe, it, expect } from "vitest";
import classifyImage from "../src";

const model = "https://teachablemachine.withgoogle.com/models/jAIOHvmge";
const imageHand =
"https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg";
const imageHand = "https://www.stgeorges.nhs.uk/wp-content/uploads/2014/03/hand-2.jpeg";
const imageNoHand =
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/50/Black_colour.jpg/640px-Black_colour.jpg";
"https://upload.wikimedia.org/wikipedia/commons/thumb/5/50/Black_colour.jpg/640px-Black_colour.jpg";
const imageHandJPEG = "./test/images/hand.jpeg";
const imageHandPNG = "./test/images/hand.png";

describe("classifyImage function - Node", async () => {
describe("Result returns", async () => {
it("returns hand when shown a picture of a hand", async () => {
const result = await classifyImage(model, imageHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});
describe("Result returns", async () => {
it("returns hand when shown a picture of a hand", async () => {
const result = await classifyImage(model, imageHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});

it("returns 'No hand' when shown a picture not including hand", async () => {
const result = await classifyImage(model, imageNoHand);
it("returns 'No hand' when shown a picture not including hand", async () => {
const result = await classifyImage(model, imageNoHand);

if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("No hand");
}
});
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("No hand");
}
});

it("returns a probability level", async () => {
const result = await classifyImage(model, imageNoHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].probability).not.toBe(null);
}
});
});
describe("Error boundries", async () => {
it("returns an error when missing a parameter", async () => {
//@ts-expect-error
const result = await classifyImage(imageNoHand);
it("returns a probability level", async () => {
const result = await classifyImage(model, imageNoHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].probability).not.toBe(null);
}
});

expect(result).toBeInstanceOf(Error);
});
});
describe("Image types", async () => {
it("returns a result on url image-input", async () => {
const result = await classifyImage(model, imageHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});
it("returns a result on JPEG image-input", async () => {
const result = await classifyImage(model, imageHandJPEG);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});
it("returns a result on PNG image-input", async () => {
const result = await classifyImage(model, imageHandPNG);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});
});
it("returns when MODEL_DIR_PATH ends with slash", async () => {
const result = await classifyImage(model + "/", imageHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});
});
describe("Error boundries", async () => {
it("returns an error when missing a parameter", async () => {
//@ts-expect-error
const result = await classifyImage(imageNoHand);

expect(result).toBeInstanceOf(Error);
});
});

describe("Image types", async () => {
it("returns a result on url image-input", async () => {
const result = await classifyImage(model, imageHand);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});

it("returns a result on JPEG image-input", async () => {
const result = await classifyImage(model, imageHandJPEG);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});

it("returns a result on PNG image-input", async () => {
const result = await classifyImage(model, imageHandPNG);
if (result instanceof Error) {
return new Error();
} else {
expect(result[0].label).toBe("Hand");
}
});
});
});
Loading

0 comments on commit 84c2bb7

Please sign in to comment.