Skip to content

Commit

Permalink
Implement ZeroShotImageClassificationWidget (#322)
Browse files Browse the repository at this point in the history
* Implement `ZeroShotImageClassificationWidget`

* Fix outout parsing

* Add model example
  • Loading branch information
mishig25 committed Oct 4, 2022
1 parent d63e4cd commit 495282b
Show file tree
Hide file tree
Showing 3 changed files with 218 additions and 0 deletions.
2 changes: 2 additions & 0 deletions js/src/lib/components/InferenceWidget/InferenceWidget.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import TabularDataWidget from "./widgets/TabularDataWidget/TabularDataWidget.svelte";
import ReinforcementLearningWidget from "./widgets/ReinforcementLearningWidget/ReinforcementLearningWidget.svelte";
import ZeroShotClassificationWidget from "./widgets/ZeroShowClassificationWidget/ZeroShotClassificationWidget.svelte";
import ZeroShotImageClassificationWidget from "./widgets/ZeroShotImageClassificationWidget/ZeroShotImageClassificationWidget.svelte";
export let apiToken: WidgetProps["apiToken"] = undefined;
export let callApiOnMount = false;
Expand Down Expand Up @@ -70,6 +71,7 @@
"reinforcement-learning": ReinforcementLearningWidget,
"zero-shot-classification": ZeroShotClassificationWidget,
"document-question-answering": VisualQuestionAnsweringWidget,
"zero-shot-image-classification": ZeroShotImageClassificationWidget,
};
$: widgetComponent = WIDGET_COMPONENTS[model.pipeline_tag ?? ""];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
<script lang="ts">
import type { WidgetProps } from "../../shared/types";
import WidgetFileInput from "../../shared/WidgetFileInput/WidgetFileInput.svelte";
import WidgetDropzone from "../../shared/WidgetDropzone/WidgetDropzone.svelte";
import WidgetTextInput from "../../shared/WidgetTextInput/WidgetTextInput.svelte";
import WidgetSubmitBtn from "../../shared/WidgetSubmitBtn/WidgetSubmitBtn.svelte";
import WidgetWrapper from "../../shared/WidgetWrapper/WidgetWrapper.svelte";
import WidgetOutputChart from "../../shared/WidgetOutputChart/WidgetOutputChart.svelte";
import { addInferenceParameters, getResponse } from "../../shared/helpers";
export let apiToken: WidgetProps["apiToken"];
export let apiUrl: WidgetProps["apiUrl"];
export let model: WidgetProps["model"];
export let noTitle: WidgetProps["noTitle"];
export let includeCredentials: WidgetProps["includeCredentials"];
let candidateLabels = "";
let computeTime = "";
let error: string = "";
let isLoading = false;
let modelLoading = {
isLoading: false,
estimatedTime: 0,
};
let output: Array<{ label: string; score: number }> = [];
let outputJson: string;
let imgSrc = "";
let imageBase64 = "";
async function onSelectFile(file: File | Blob) {
imgSrc = URL.createObjectURL(file);
await updateImageBase64(file);
}
function updateImageBase64(file: File | Blob): Promise<void> {
return new Promise((resolve, reject) => {
let fileReader: FileReader = new FileReader();
fileReader.onload = async () => {
try {
const imageBase64WithPrefix: string = fileReader.result as string;
imageBase64 = imageBase64WithPrefix.split(",")[1]; // remove prefix
isLoading = false;
resolve();
} catch (err) {
reject(err);
}
};
fileReader.onerror = (e) => reject(e);
isLoading = true;
fileReader.readAsDataURL(file);
});
}
function isValidOutput(arg: any): arg is { label: string; score: number }[] {
return (
Array.isArray(arg) &&
arg.every(
(x) => typeof x.label === "string" && typeof x.score === "number"
)
);
}
function parseOutput(body: unknown): Array<{ label: string; score: number }> {
if (isValidOutput(body)) {
return body;
}
throw new TypeError(
"Invalid output: output must be of type <labels:Array; scores:Array>"
);
}
function previewInputSample(sample: Record<string, any>) {
candidateLabels = sample.candidate_labels;
imgSrc = sample.src;
}
async function applyInputSample(sample: Record<string, any>) {
candidateLabels = sample.candidate_labels;
imgSrc = sample.src;
const res = await fetch(imgSrc);
const blob = await res.blob();
await updateImageBase64(blob);
getOutput();
}
async function getOutput(withModelLoading = false) {
const trimmedCandidateLabels = candidateLabels.trim().split(",").join(",");
if (!trimmedCandidateLabels) {
error = "You need to input at least one label";
output = [];
outputJson = "";
return;
}
if (!imageBase64) {
error = "You need to upload an image";
output = [];
outputJson = "";
return;
}
const requestBody = {
image: imageBase64,
parameters: {
candidate_labels: trimmedCandidateLabels,
},
};
addInferenceParameters(requestBody, model);
isLoading = true;
const res = await getResponse(
apiUrl,
model.id,
requestBody,
apiToken,
parseOutput,
withModelLoading,
includeCredentials
);
isLoading = false;
// Reset values
computeTime = "";
error = "";
modelLoading = { isLoading: false, estimatedTime: 0 };
output = [];
outputJson = "";
if (res.status === "success") {
computeTime = res.computeTime;
output = res.output;
outputJson = res.outputJson;
} else if (res.status === "loading-model") {
modelLoading = {
isLoading: true,
estimatedTime: res.estimatedTime,
};
getOutput(true);
} else if (res.status === "error") {
error = res.error;
}
}
</script>

<WidgetWrapper
{apiUrl}
{applyInputSample}
{computeTime}
{error}
{isLoading}
{model}
{modelLoading}
{noTitle}
{outputJson}
{previewInputSample}
>
<svelte:fragment slot="top">
<form class="space-y-2">
<WidgetDropzone
classNames="no-hover:hidden"
{isLoading}
{imgSrc}
{onSelectFile}
onError={(e) => (error = e)}
>
{#if imgSrc}
<img
src={imgSrc}
class="pointer-events-none shadow mx-auto max-h-44"
alt=""
/>
{/if}
</WidgetDropzone>
<!-- Better UX for mobile/table through CSS breakpoints -->
{#if imgSrc}
{#if imgSrc}
<div
class="mb-2 flex justify-center bg-gray-50 dark:bg-gray-900 with-hover:hidden"
>
<img src={imgSrc} class="pointer-events-none max-h-44" alt="" />
</div>
{/if}
{/if}
<WidgetFileInput
accept="image/*"
classNames="mr-2 with-hover:hidden"
{isLoading}
label="Browse for image"
{onSelectFile}
/>
<WidgetTextInput
bind:value={candidateLabels}
label="Possible class names (comma-separated)"
placeholder="Possible class names..."
/>
<WidgetSubmitBtn
{isLoading}
onClick={() => {
getOutput();
}}
/>
</form>
</svelte:fragment>
<svelte:fragment slot="bottom">
{#if output.length}
<WidgetOutputChart classNames="pt-4" {output} />
{/if}
</svelte:fragment>
</WidgetWrapper>
4 changes: 4 additions & 0 deletions js/src/routes/index.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import type { ModelData } from "../lib/interfaces/Types";
const models: ModelData[] = [
{
id: "openai/clip-vit-base-patch16",
pipeline_tag: "zero-shot-image-classification",
},
{
id: "ydshieh/vit-gpt2-coco-en",
pipeline_tag: "image-to-text",
Expand Down

0 comments on commit 495282b

Please sign in to comment.