@@ -12,6 +12,22 @@ output = query({
1212 "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
1313})` ;
1414
15+ export const snippetZeroShotImageClassification = ( model : ModelData ) : string =>
16+ `def query(data):
17+ with open(data["image_path"], "rb") as f:
18+ img = f.read()
19+ payload={
20+ "parameters": data["parameters"],
21+ "inputs": base64.b64encode(img).decode("utf-8")
22+ }
23+ response = requests.post(API_URL, headers=headers, json=payload)
24+ return response.json()
25+
26+ output = query({
27+ "image_path": ${ getModelInputSnippet ( model ) } ,
28+ "parameters": {"candidate_labels": ["cat", "dog", "llama"]},
29+ })` ;
30+
1531export const snippetBasic = ( model : ModelData ) : string =>
1632 `def query(payload):
1733 response = requests.post(API_URL, headers=headers, json=payload)
@@ -71,7 +87,7 @@ Audio(audio, rate=sampling_rate)`;
7187 }
7288} ;
7389export const pythonSnippets : Partial < Record < PipelineType , ( model : ModelData ) => string > > = {
74- // Same order as in js /src/lib/interfaces/Types .ts
90+ // Same order as in tasks /src/pipelines .ts
7591 "text-classification" : snippetBasic ,
7692 "token-classification" : snippetBasic ,
7793 "table-question-answering" : snippetBasic ,
@@ -92,9 +108,10 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) =>
92108 "audio-to-audio" : snippetFile ,
93109 "audio-classification" : snippetFile ,
94110 "image-classification" : snippetFile ,
95- "image-to-text" : snippetFile ,
96111 "object-detection" : snippetFile ,
97112 "image-segmentation" : snippetFile ,
113+ "image-to-text" : snippetFile ,
114+ "zero-shot-image-classification" : snippetZeroShotImageClassification ,
98115} ;
99116
100117export function getPythonInferenceSnippet ( model : ModelData , accessToken : string ) : string {
0 commit comments