Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement hugging face definition #138

Merged
merged 10 commits into from
Jul 11, 2022
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
},
"dependencies": {
"@amplitude/analytics-browser": "^0.4.1",
"@instill-ai/design-system": "^0.0.85",
"@instill-ai/design-system": "^0.0.88",
"@types/json-schema": "^7.0.11",
"axios": "^0.27.2",
"clsx": "^1.1.1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ const ConfigureModelForm: FC<ConfigureModelFormProps> = ({
modalState.setModalIsOpen(false);
},
});
}, [model, amplitudeIsInit, router, deleteModel]);
}, [model, amplitudeIsInit, router, deleteModel, modalState]);

return (
<>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import {
ProgressMessageBoxState,
} from "@instill-ai/design-system";
import { TestModelInstanceSection } from "@/components/sections";
import { useModelDefinition } from "@/services/model";

export type ConfigureModelInstanceFormProps = {
modelInstance: ModelInstance;
Expand Down
120 changes: 109 additions & 11 deletions src/components/forms/model/CreateModelForm/CreateModelForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
import {
useCreateArtivcModel,
useCreateGithubModel,
useCreateHuggingFaceModel,
useCreateLocalModel,
useDeployModelInstance,
useModelDefinitions,
Expand All @@ -27,12 +28,14 @@ import { ModelDefinitionIcon, PrimaryButton } from "@/components/ui";
import {
CreateArtivcModelPayload,
CreateGithubModelPayload,
CreateHuggingFaceModelPayload,
CreateLocalModelPayload,
Model,
} from "@/lib/instill";
import { Nullable } from "@/types/general";
import { useAmplitudeCtx } from "context/AmplitudeContext";
import { sendAmplitudeData } from "@/lib/amplitude";
import { AxiosError } from "axios";

export type CreateModelFormValue = {
id: Nullable<string>;
Expand All @@ -43,6 +46,8 @@ export type CreateModelFormValue = {
description: Nullable<string>;
gcsBucketPath: Nullable<string>;
credentials: Nullable<string>;
huggingFaceRepo: Nullable<string>;
huggingFaceUrl: Nullable<string>;
};

const CreateNewModelFlow: FC = () => {
Expand Down Expand Up @@ -121,7 +126,15 @@ const CreateNewModelFlow: FC = () => {
return true;
}

if (!values.gcsBucketPath || !values.id) {
if (values.modelDefinition === "artivc") {
if (!values.gcsBucketPath || !values.id) {
return false;
}

return true;
}

if (!values.huggingFaceRepo) {
return false;
}

Expand All @@ -133,6 +146,7 @@ const CreateNewModelFlow: FC = () => {
const createGithubModel = useCreateGithubModel();
const createLocalModel = useCreateLocalModel();
const createArtivcModel = useCreateArtivcModel();
const createHuggingFaceModel = useCreateHuggingFaceModel();

const handelCreateModel = useCallback(
async (values: CreateModelFormValue) => {
Expand Down Expand Up @@ -176,12 +190,12 @@ const CreateNewModelFlow: FC = () => {
}
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
Expand Down Expand Up @@ -224,12 +238,12 @@ const CreateNewModelFlow: FC = () => {
}
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
Expand All @@ -241,7 +255,7 @@ const CreateNewModelFlow: FC = () => {
}
},
});
} else {
} else if (values.modelDefinition === "artivc") {
if (!values.gcsBucketPath) return;

const payload: CreateArtivcModelPayload = {
Expand Down Expand Up @@ -270,12 +284,12 @@ const CreateNewModelFlow: FC = () => {
}
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
Expand All @@ -287,9 +301,62 @@ const CreateNewModelFlow: FC = () => {
}
},
});
} else {
if (!values.huggingFaceRepo) return;

const payload: CreateHuggingFaceModelPayload = {
id: values.id,
model_definition: "model-definitions/huggingface",
desctiption: values.description ?? null,
configuration: {
repo_id: values.huggingFaceRepo,
},
};

createHuggingFaceModel.mutate(payload, {
onSuccess: (newModel) => {
setModelCreated(true);
setNewModel(newModel);
setCreateModelMessageBoxState(() => ({
activate: true,
status: "success",
description: null,
message: "Create succeeded",
}));
if (amplitudeIsInit) {
sendAmplitudeData("create_artivc_model", {
type: "critical_action",
});
}
},
onError: (error) => {
if (error instanceof AxiosError) {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.response?.data.message ?? error.message,
}));
} else {
setCreateModelMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message:
"Something went wrong when create the HuggingFace model",
}));
}
},
});
}
},
[amplitudeIsInit, createArtivcModel, createGithubModel, createLocalModel]
[
amplitudeIsInit,
createArtivcModel,
createGithubModel,
createLocalModel,
createHuggingFaceModel,
]
);

// ###################################################################
Expand Down Expand Up @@ -377,12 +444,12 @@ const CreateNewModelFlow: FC = () => {
router.push("/models");
},
onError: (error) => {
if (error instanceof Error) {
if (error instanceof AxiosError) {
setDeployModelInstanceMessageBoxState(() => ({
activate: true,
status: "error",
description: null,
message: error.message,
message: error.response?.data.message ?? error.message,
}));
} else {
setDeployModelInstanceMessageBoxState(() => ({
Expand Down Expand Up @@ -518,6 +585,37 @@ const CreateNewModelFlow: FC = () => {
/>
</>
) : null}
{values.modelDefinition === "huggingface" ? (
<>
<TextField
id="huggingFaceRepo"
name="huggingFaceRepo"
label="HuggingFace model ID"
additionalMessageOnLabel={null}
description="The name of a public HuggingFace model ID, e.g. `google/vit-base-patch16-224`."
value={values.huggingFaceRepo}
error={errors.huggingFaceRepo || null}
additionalOnChangeCb={null}
disabled={modelCreated ? true : false}
readOnly={false}
required={true}
placeholder=""
type="text"
autoComplete="off"
/>
<TextArea
id="description"
name="description"
label="Description"
description="Fill with a short description of your new model"
value={values.description}
error={errors.description || null}
disabled={modelCreated ? true : false}
enableCounter={true}
counterWordLimit={1023}
/>
</>
) : null}
<div className="flex flex-row">
<BasicProgressMessageBox
state={createModelMessageBoxState}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type NewModel = {
description: Nullable<string>;
gcsBucketPath: Nullable<string>;
credentials: Nullable<string>;
huggingFaceRepo: Nullable<string>;
};

type Model = {
Expand Down Expand Up @@ -106,6 +107,7 @@ const CreatePipelineDataSourceForm: FC<StepNumberState> = (props) => {
repo: null,
gcsBucketPath: null,
credentials: null,
huggingFaceRepo: null,
},
existing: {
id: null,
Expand Down
Loading