Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] add mistral api support #20

Merged
merged 5 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Unlike [gp.nvim](https://github.com/Robitx/gp.nvim), [parrot.nvim](https://githu
+ [Anthropic API](https://www.anthropic.com/api)
+ [perplexity.ai API](https://blog.perplexity.ai/blog/introducing-pplx-api)
+ [OpenAI API](https://platform.openai.com/)
+ [Mistral API](https://docs.mistral.ai/api/)
+ Local and offline serving via [ollama](https://github.com/ollama/ollama)
- Custom agent definitions to determine specific prompt and API parameter combinations, similar to [GPTs](https://openai.com/index/introducing-gpts/)
- Flexible support for providing API credentials from various sources, such as environment variables, bash commands, and your favorite password manager CLI
Expand Down Expand Up @@ -84,6 +85,9 @@ Let the parrot fix your bugs.
anthropic = {
api_key = os.getenv "ANTHROPIC_API_KEY",
},
mistral = {
api_key = os.getenv "MISTRAL_API_KEY",
},
},
}
end,
Expand Down Expand Up @@ -203,16 +207,16 @@ require("parrot").setup {
CompleteFullContext = function(prt, params)
local template = [[
I have the following code from {{filename}}:

```{{filetype}}
{{filecontent}}
```

Please look at the following section specifically:
```{{filetype}}
{{selection}}
```

Please finish the code above carefully and logically.
Respond just with the snippet of code that should be inserted.
]]
Expand Down
106 changes: 106 additions & 0 deletions lua/parrot/agents.lua
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,56 @@ local anthropic_chat_agents = {
provider = "anthropic",
},
}
local mistral_chat_agents = {
{
name = "Codestral",
model = { model = "codestral-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Mistral-Tiny",
model = { model = "mistral-tiny", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Mistral-Small",
model = { model = "mistral-small-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Mistral-Medium",
model = { model = "mistral-medium-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Mistral-Large",
model = { model = "mistral-large-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Open-Mistral-7B",
model = { model = "open-mistral-7b", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Open-Mixtral-8x7B",
model = { model = "open-mixtral-8x7b", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
{
name = "Open-Mixtral-8x22B",
model = { model = "open-mixtral-8x22b", temperature = 1.5, top_p = 1 },
system_prompt = system_chat_prompt,
provider = "mistral",
},
}

local ollama_command_agents = {
{
Expand Down Expand Up @@ -253,6 +303,56 @@ local anthropic_command_agents = {
provider = "anthropic",
},
}
local mistral_command_agents = {
{
name = "Codestral",
model = { model = "codestral-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Mistral-Tiny",
model = { model = "mistral-tiny", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Mistral-Small",
model = { model = "mistral-small-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Mistral-Medium",
model = { model = "mistral-medium-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Mistral-Large",
model = { model = "mistral-large-latest", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Open-Mistral-7B",
model = { model = "open-mistral-7b", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Open-Mixtral-8x7B",
model = { model = "open-mixtral-8x7b", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
{
name = "Open-Mixtral-8x22B",
model = { model = "open-mixtral-8x22b", temperature = 1.5, top_p = 1 },
system_prompt = system_code_prompt,
provider = "mistral",
},
}

local M = {}

Expand All @@ -269,6 +369,9 @@ end
for _, agent in ipairs(anthropic_chat_agents) do
table.insert(M.chat_agents, agent)
end
for _, agent in ipairs(mistral_chat_agents) do
table.insert(M.chat_agents, agent)
end

M.command_agents = {}
for _, agent in ipairs(ollama_command_agents) do
Expand All @@ -283,5 +386,8 @@ end
for _, agent in ipairs(anthropic_command_agents) do
table.insert(M.command_agents, agent)
end
for _, agent in ipairs(mistral_command_agents) do
table.insert(M.command_agents, agent)
end

return M
10 changes: 10 additions & 0 deletions lua/parrot/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ local config = {
topic_prompt = "You only respond with 2 to 3 words to summarize the past conversation.",
topic_model = { model = "claude-3-sonnet-20240229", max_tokens = 32, system = topic_prompt },
},
mistral = {
api_key = "",
endpoint = "https://api.mistral.ai/v1/chat/completions",
topic_prompt = [[
Summarize the chat above and only provide a short headline of 2 to 3
words without any opening phrase like "Sure, here is the summary",
"Sure! Here's a short headline summarizing the chat" or anything similar.
]],
topic_model = "mistral-medium-latest",
},
},
-- prefix for all commands
cmd_prefix = "Prt",
Expand Down
1 change: 1 addition & 0 deletions lua/parrot/health.lua
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ function M.check()
check_provider(parrot, "openai")
check_provider(parrot, "ollama")
check_provider(parrot, "pplx")
check_provider(parrot, "mistral")
end

for _, name in ipairs({ "curl", "grep", "rg", "ln" }) do
Expand Down
2 changes: 1 addition & 1 deletion lua/parrot/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -1704,7 +1704,7 @@ M.Prompt = function(params, target, prompt, model, template, system_template, ag
sys_prompt = sys_prompt or ""
local prov = M.get_provider()
if prov.name ~= agent_provider then
M.logger.error("Missmatch of agent and current provider " .. prov.name .. " and " .. agent_provider)
M.logger.error("Mismatch of agent and current provider " .. prov.name .. " and " .. agent_provider)
return
end
messages = prov:add_system_prompt(messages, sys_prompt)
Expand Down
2 changes: 2 additions & 0 deletions lua/parrot/provider/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ local Ollama = require("parrot.provider.ollama")
local OpenAI = require("parrot.provider.openai")
local Anthropic = require("parrot.provider.anthropic")
local Perplexity = require("parrot.provider.perplexity")
local Mistral = require("parrot.provider.mistral")

local M = {
logger = require("parrot.logger"),
Expand All @@ -17,6 +18,7 @@ M.init_provider = function(prov_name, endpoint, api_key)
openai = OpenAI,
anthropic = Anthropic,
pplx = Perplexity,
mistral = Mistral,
}

local ProviderClass = providers[prov_name]
Expand Down
68 changes: 68 additions & 0 deletions lua/parrot/provider/mistral.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
local logger = require("parrot.logger")

local Mistral = {}
Mistral.__index = Mistral

local available_model_set = {
["codestral-latest"] = true,
["mistral-tiny"] = true,
["mistral-small-latest"] = true,
["mistral-medium-latest"] = true,
["mistral-large-latest"] = true,
["open-mistral-7b"] = true,
["open-mixtral-8x7b"] = true,
["open-mixtral-8x22b"] = true,
}

function Mistral:new(endpoint, api_key)
return setmetatable({
endpoint = endpoint,
api_key = api_key,
name = "mistral",
}, self)
end

function Mistral:curl_params()
return {
self.endpoint,
"-H",
"authorization: Bearer " .. self.api_key,
}
end

function Mistral:verify()
if type(self.api_key) == "table" then
logger.error("api_key is still an unresolved command: " .. vim.inspect(self.api_key))
return false
elseif self.api_key and string.match(self.api_key, "%S") then
return true
else
logger.error("Error with api key " .. self.name .. " " .. vim.inspect(self.api_key) .. " run :checkhealth parrot")
return false
end
end

function Mistral:preprocess_messages(messages)
return messages
end

function Mistral:add_system_prompt(messages, sys_prompt)
if sys_prompt ~= "" then
table.insert(messages, { role = "system", content = sys_prompt })
end
return messages
end

function Mistral:process(line)
if line:match("chat%.completion%.chunk") or line:match("chat%.completion") then
line = vim.json.decode(line)
return line.choices[1].delta.content
end
end

function Mistral:check(agent)
local model = type(agent.model) == "string" and agent.model or agent.model.model
return available_model_set[model]
end

return Mistral