Skip to content

Commit

Permalink
feat: Implement hugging face definition (#138)
Browse files Browse the repository at this point in the history
Because

- Support Hugging face model definition

This commit

- Implement hugging face model definition flow
  • Loading branch information
EiffelFly committed Jul 11, 2022
1 parent 3d10302 commit 9415cbe
Show file tree
Hide file tree
Showing 12 changed files with 295 additions and 28 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
},
"dependencies": {
"@amplitude/analytics-browser": "^0.4.1",
"@instill-ai/design-system": "^0.0.85",
"@instill-ai/design-system": "^0.0.88",
"@types/json-schema": "^7.0.11",
"axios": "^0.27.2",
"clsx": "^1.1.1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ const ConfigureModelForm: FC<ConfigureModelFormProps> = ({
modalState.setModalIsOpen(false);
},
});
}, [model, amplitudeIsInit, router, deleteModel]);
}, [model, amplitudeIsInit, router, deleteModel, modalState]);

return (
<>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
ProgressMessageBoxState,
} from "@instill-ai/design-system";
import { TestModelInstanceSection } from "@/components/sections";
import { useModelDefinition } from "@/services/model";

export type ConfigureModelInstanceFormProps = {
modelInstance: ModelInstance;
Expand Down
120 changes: 109 additions & 11 deletions src/components/forms/model/CreateModelForm/CreateModelForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
import {
useCreateArtivcModel,
useCreateGithubModel,
useCreateHuggingFaceModel,
useCreateLocalModel,
useDeployModelInstance,
useModelDefinitions,
Expand All @@ -27,12 +28,14 @@ import { ModelDefinitionIcon, PrimaryButton } from "@/components/ui";
import {
CreateArtivcModelPayload,
CreateGithubModelPayload,
CreateHuggingFaceModelPayload,
CreateLocalModelPayload,
Model,
} from "@/lib/instill";
import { Nullable } from "@/types/general";
import { useAmplitudeCtx } from "context/AmplitudeContext";
import { sendAmplitudeData } from "@/lib/amplitude";
import { AxiosError } from "axios";

export type CreateModelFormValue = {
id: Nullable<string>;
Expand All @@ -43,6 +46,8 @@ export type CreateModelFormValue = {
description: Nullable<string>;
gcsBucketPath: Nullable<string>;
credentials: Nullable<string>;
huggingFaceRepo: Nullable<string>;
huggingFaceUrl: Nullable<string>;
};

const CreateNewModelFlow: FC = () => {
Expand Down Expand Up @@ -121,7 +126,15 @@ const CreateNewModelFlow: FC = () => {
return true;
}

if (!values.gcsBucketPath || !values.id) {
if (values.modelDefinition === "artivc") {
if (!values.gcsBucketPath || !values.id) {
return false;
}

return true;
}

if (!values.huggingFaceRepo) {
return false;
}

Expand All @@ -133,6 +146,7 @@ const CreateNewModelFlow: FC = () => {
const createGithubModel = useCreateGithubModel();
const createLocalModel = useCreateLocalModel();
const createArtivcModel = useCreateArtivcModel();
const createHuggingFaceModel = useCreateHuggingFaceModel();

const handelCreateModel = useCallback(
async (values: CreateModelFormValue) => {
Expand Down Expand Up @@ -176,12 +190,12 @@ const CreateNewModelFlow: FC = () => {
}
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
Expand Down Expand Up @@ -224,12 +238,12 @@ const CreateNewModelFlow: FC = () => {
}
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
Expand All @@ -241,7 +255,7 @@ const CreateNewModelFlow: FC = () => {
}
},
});
} else {
} else if (values.modelDefinition === "artivc") {
if (!values.gcsBucketPath) return;

const payload: CreateArtivcModelPayload = {
Expand Down Expand Up @@ -270,12 +284,12 @@ const CreateNewModelFlow: FC = () => {
}
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
Expand All @@ -287,9 +301,62 @@ const CreateNewModelFlow: FC = () => {
}
},
});
} else {
if (!values.huggingFaceRepo) return;

const payload: CreateHuggingFaceModelPayload = {
id: values.id,
model_definition: "model-definitions/huggingface",
desctiption: values.description ?? null,
configuration: {
repo_id: values.huggingFaceRepo,
},
};

createHuggingFaceModel.mutate(payload, {
onSuccess: (newModel) => {
setModelCreated(true);
setNewModel(newModel);
setCreateModelMessageBoxState(() => ({
activate: true,
status: "success",
description: null,
message: "Create succeeded",
}));
if (amplitudeIsInit) {
sendAmplitudeData("create_artivc_model", {
type: "critical_action",
});
}
},
onError: (error) => {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message:
"Something went wrong when create the HuggingFace model",
}));
}
},
});
}
},
[amplitudeIsInit, createArtivcModel, createGithubModel, createLocalModel]
[
amplitudeIsInit,
createArtivcModel,
createGithubModel,
createLocalModel,
createHuggingFaceModel,
]
);

// ###################################################################
Expand Down Expand Up @@ -377,12 +444,12 @@ const CreateNewModelFlow: FC = () => {
router.push("/models");
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setDeployModelInstanceMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setDeployModelInstanceMessageBoxState(() => ({
Expand Down Expand Up @@ -518,6 +585,37 @@ const CreateNewModelFlow: FC = () => {
/>
</>
) : null}
{values.modelDefinition === "huggingface" ? (
<>
<TextField
id="huggingFaceRepo"
name="huggingFaceRepo"
label="HuggingFace model ID"
additionalMessageOnLabel={null}
description="The name of a public HuggingFace model ID, e.g. `google/vit-base-patch16-224`."
value={values.huggingFaceRepo}
error={errors.huggingFaceRepo || null}
additionalOnChangeCb={null}
disabled={modelCreated ? true : false}
readOnly={false}
required={true}
placeholder=""
type="text"
autoComplete="off"
/>
<TextArea
id="description"
name="description"
label="Description"
description="Fill with a short description of your new model"
value={values.description}
error={errors.description || null}
disabled={modelCreated ? true : false}
enableCounter={true}
counterWordLimit={1023}
/>
</>
) : null}
<div className="flex flex-row">
<BasicProgressMessageBox
state={createModelMessageBoxState}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type NewModel = {
description: Nullable<string>;
gcsBucketPath: Nullable<string>;
credentials: Nullable<string>;
huggingFaceRepo: Nullable<string>;
};

type Model = {
Expand Down Expand Up @@ -106,6 +107,7 @@ const CreatePipelineDataSourceForm: FC<StepNumberState> = (props) => {
repo: null,
gcsBucketPath: null,
credentials: null,
huggingFaceRepo: null,
},
existing: {
id: null,
Expand Down
Loading

0 comments on commit 9415cbe

Please sign in to comment.