Skip to content

Commit 41ddbb0

Browse files
authored
Add API snippet zero-shot img classification (#363)
1 parent 008d2a0 commit 41ddbb0

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

packages/tasks/src/snippets/inputs.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ const inputsTextToAudio = () => `"liquid drum and bass, atmospheric synths, airy
7979

8080
const inputsAutomaticSpeechRecognition = () => `"sample1.flac"`;
8181

82+
const inputsZeroShotImageClassification = () => `"cats.jpg"`;
83+
8284
const modelInputSnippets: {
8385
[key in PipelineType]?: (model: ModelData) => string;
8486
} = {
@@ -105,6 +107,7 @@ const modelInputSnippets: {
105107
"token-classification": inputsTokenClassification,
106108
translation: inputsTranslation,
107109
"zero-shot-classification": inputsZeroShotClassification,
110+
"zero-shot-image-classification": inputsZeroShotImageClassification,
108111
};
109112

110113
// Use noWrap to put the whole snippet on a single line (removing new lines and tabulations)

packages/tasks/src/snippets/python.ts

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,22 @@ output = query({
1212
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
1313
})`;
1414

15+
export const snippetZeroShotImageClassification = (model: ModelData): string =>
16+
`def query(data):
17+
with open(data["image_path"], "rb") as f:
18+
img = f.read()
19+
payload={
20+
"parameters": data["parameters"],
21+
"inputs": base64.b64encode(img).decode("utf-8")
22+
}
23+
response = requests.post(API_URL, headers=headers, json=payload)
24+
return response.json()
25+
26+
output = query({
27+
"image_path": ${getModelInputSnippet(model)},
28+
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
29+
})`;
30+
1531
export const snippetBasic = (model: ModelData): string =>
1632
`def query(payload):
1733
response = requests.post(API_URL, headers=headers, json=payload)
@@ -71,7 +87,7 @@ Audio(audio, rate=sampling_rate)`;
7187
}
7288
};
7389
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) => string>> = {
74-
// Same order as in js/src/lib/interfaces/Types.ts
90+
// Same order as in tasks/src/pipelines.ts
7591
"text-classification": snippetBasic,
7692
"token-classification": snippetBasic,
7793
"table-question-answering": snippetBasic,
@@ -92,9 +108,10 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) =>
92108
"audio-to-audio": snippetFile,
93109
"audio-classification": snippetFile,
94110
"image-classification": snippetFile,
95-
"image-to-text": snippetFile,
96111
"object-detection": snippetFile,
97112
"image-segmentation": snippetFile,
113+
"image-to-text": snippetFile,
114+
"zero-shot-image-classification": snippetZeroShotImageClassification,
98115
};
99116

100117
export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {

0 commit comments

Comments
 (0)