Skip to content

Commit

Permalink
feat: Add support for new API types in SaveModelFromInputedUrlRequest
Browse files Browse the repository at this point in the history
  • Loading branch information
n4ze3m committed May 16, 2024
1 parent be8b848 commit ddd6590
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 124 deletions.
169 changes: 46 additions & 123 deletions server/src/handlers/api/v1/admin/model.handler.ts
Original file line number Diff line number Diff line change
@@ -1,131 +1,19 @@
// this the dedicated file for the model handler for the admin route
// why i do this? because i want to make the code more readable and easy to maintain
import { FastifyReply, FastifyRequest } from "fastify";
import {
FetchModelFromInputedUrlRequest,
SaveEmbeddingModelRequest,
SaveModelFromInputedUrlRequest,
ToogleModelRequest,
} from "./type";
import axios from "axios";
import { removeTrailingSlash } from "../../../../utils/url";
import { getSettings } from "../../../../utils/common";

const _getModelFromUrl = async (url: string, apiKey?: string) => {
try {
const response = await axios.get(`${url}/models`, {
headers: {
"HTTP-Referer":
process.env.LOCAL_REFER_URL || "https://dialoqbase.n4ze3m.com/",
"X-Title": process.env.LOCAL_TITLE || "Dialoqbase",
Authorization: apiKey && `Bearer ${apiKey}`,
},
});
return response.data;
} catch (error) {
console.error(error);
return null;
}
};

const _getOllamaModels = async (url: string) => {
try {
const response = await axios.get(`${url}/api/tags`);
const { models } = response.data as {
models: {
name: string;
}[];
};
return models.map((data) => {
return {
id: data.name,
object: data.name,
};
});
} catch (error) {
return null;
}
};

const _isReplicateModelExist = async (model_id: string, token: string) => {
try {
const url = "https://api.replicate.com/v1/models/";
const isVersionModel = model_id.split(":").length > 1;
if (!isVersionModel) {
const res = await axios.get(`${url}${model_id}`, {
headers: {
Authorization: `Token ${token}`,
},
});
const data = res.data;
return {
success: true,
message: "Model found",
name: data.name,
};
} else {
const [owner, model_name] = model_id.split("/");
const version = model_name.split(":")[1];
const res = await axios.get(
`${url}${owner}/${model_name.split(":")[0]}/versions/${version}`,
{
headers: {
Authorization: `Token ${token}`,
},
}
);

const data = res.data;

return {
success: true,
message: "Model found",
name: data.name,
};
}
} catch (error) {
if (axios.isAxiosError(error)) {
console.error(error.response?.data);
if (error.response?.status === 404) {
return {
success: false,
message: "Model not found",
name: undefined,
};
} else if (error.response?.status === 401) {
return {
success: false,
message: "Unauthorized",
name: undefined,
};
} else if (error.response?.status === 403) {
return {
success: false,
message: "Forbidden",
name: undefined,
};
} else if (error.response?.status === 500) {
return {
success: false,
message: "Internal Server Error",
name: undefined,
};
} else {
return {
success: false,
message: "Internal Server Error",
name: undefined,
};
}
} else {
return {
success: false,
message: "Internal Server Error",
name: undefined,
};
}
}
};
import {
getModelFromUrl,
getOllamaModels,
isReplicateModelExist,
isValidModel,
modelProvider,
} from "./utils";

export const getAllModelsHandler = async (
request: FastifyRequest,
Expand All @@ -143,7 +31,7 @@ export const getAllModelsHandler = async (
const settings = await getSettings(prisma);

const not_to_hide_providers = settings?.hideDefaultModels
? [ "Local", "local", "ollama", "transformer", "Transformer"]
? ["Local", "local", "ollama", "transformer", "Transformer"]
: undefined;
const allModels = await prisma.dialoqbaseModels.findMany({
where: {
Expand Down Expand Up @@ -180,7 +68,7 @@ export const fetchModelFromInputedUrlHandler = async (
}

if (api_type === "ollama") {
const models = await _getOllamaModels(removeTrailingSlash(ollama_url!));
const models = await getOllamaModels(removeTrailingSlash(ollama_url!));

if (!models) {
return reply.status(404).send({
Expand All @@ -193,7 +81,7 @@ export const fetchModelFromInputedUrlHandler = async (
data: models,
};
} else if (api_type === "openai") {
const model = await _getModelFromUrl(removeTrailingSlash(url!), api_key);
const model = await getModelFromUrl(removeTrailingSlash(url!), api_key);

if (!model) {
return reply.status(404).send({
Expand All @@ -205,6 +93,10 @@ export const fetchModelFromInputedUrlHandler = async (
return {
data: Array.isArray(model) ? model : model.data,
};
} else {
return reply.status(400).send({
message: "Invalid api type",
});
}
} catch (error) {
console.error(error);
Expand Down Expand Up @@ -232,7 +124,7 @@ export const saveModelFromInputedUrlHandler = async (
request.body;

if (api_type === "replicate") {
const isModelExist = await _isReplicateModelExist(model_id, api_key!);
const isModelExist = await isReplicateModelExist(model_id, api_key!);

if (!isModelExist.success) {
return reply.status(404).send({
Expand Down Expand Up @@ -269,6 +161,37 @@ export const saveModelFromInputedUrlHandler = async (
},
});

return {
message: "success",
};
} else if (
api_type === "openai-api" ||
api_type === "google" ||
api_type === "anthropic" ||
api_type === "groq"
) {
const provider = modelProvider[api_type];
console.log(provider, "provider");
const validModel = await isValidModel(model_id.trim(), provider);

if (!validModel) {
return reply.status(404).send({
message: `Model not found for the given model_id ${model_id}`,
});
}

let newModelId = model_id.trim() + `_dialoqbase_${new Date().getTime()}`;

await prisma.dialoqbaseModels.create({
data: {
name: name,
model_id: newModelId,
stream_available: true,
local_model: true,
model_provider: provider,
},
});

return {
message: "success",
};
Expand Down
10 changes: 9 additions & 1 deletion server/src/handlers/api/v1/admin/type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ export type SaveModelFromInputedUrlRequest = {
name: string;
stream_available: boolean;
api_key?: string;
api_type: "openai" | "ollama" | "replicate";
api_type:
| "openai"
| "ollama"
| "replicate"
| "openai-api"
| "google"
| "anthropic"
| "groq"
;
};
};

Expand Down
142 changes: 142 additions & 0 deletions server/src/handlers/api/v1/admin/utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import axios from "axios";
import { chatModelProvider } from "../../../../utils/models";

export const isReplicateModelExist = async (
model_id: string,
token: string
) => {
try {
const url = "https://api.replicate.com/v1/models/";
const isVersionModel = model_id.split(":").length > 1;
if (!isVersionModel) {
const res = await axios.get(`${url}${model_id}`, {
headers: {
Authorization: `Token ${token}`,
},
});
const data = res.data;
return {
success: true,
message: "Model found",
name: data.name,
};
} else {
const [owner, model_name] = model_id.split("/");
const version = model_name.split(":")[1];
const res = await axios.get(
`${url}${owner}/${model_name.split(":")[0]}/versions/${version}`,
{
headers: {
Authorization: `Token ${token}`,
},
}
);

const data = res.data;

return {
success: true,
message: "Model found",
name: data.name,
};
}
} catch (error) {
if (axios.isAxiosError(error)) {
console.error(error.response?.data);
if (error.response?.status === 404) {
return {
success: false,
message: "Model not found",
name: undefined,
};
} else if (error.response?.status === 401) {
return {
success: false,
message: "Unauthorized",
name: undefined,
};
} else if (error.response?.status === 403) {
return {
success: false,
message: "Forbidden",
name: undefined,
};
} else if (error.response?.status === 500) {
return {
success: false,
message: "Internal Server Error",
name: undefined,
};
} else {
return {
success: false,
message: "Internal Server Error",
name: undefined,
};
}
} else {
return {
success: false,
message: "Internal Server Error",
name: undefined,
};
}
}
};

export const getModelFromUrl = async (url: string, apiKey?: string) => {
try {
const response = await axios.get(`${url}/models`, {
headers: {
"HTTP-Referer":
process.env.LOCAL_REFER_URL || "https://dialoqbase.n4ze3m.com/",
"X-Title": process.env.LOCAL_TITLE || "Dialoqbase",
Authorization: apiKey && `Bearer ${apiKey}`,
},
});
return response.data;
} catch (error) {
console.error(error);
return null;
}
};

export const getOllamaModels = async (url: string) => {
try {
const response = await axios.get(`${url}/api/tags`);
const { models } = response.data as {
models: {
name: string;
}[];
};
return models.map((data) => {
return {
id: data.name,
object: data.name,
};
});
} catch (error) {
return null;
}
};

export const modelProvider = {
"openai-api": "OpenAI",
anthropic: "Anthropic",
google: "Google",
groq: "Groq",
};

export const isValidModel = async (
model_id: string,
model_provider: string
) => {
try {
const model = chatModelProvider(model_provider, model_id, 0.7, {});
const chat = await model.invoke("Hello");
return chat !== null;
} catch (error) {
console.error(error);
return false;
}
};

0 comments on commit ddd6590

Please sign in to comment.