Skip to content

Commit

Permalink
Add new parameters to embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
hauselin committed May 14, 2024
1 parent b9eec50 commit d1a9518
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 3 deletions.
3 changes: 3 additions & 0 deletions R/model_options.R
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ search_options <- function(query) {
#' validate_options(mirostat = 1, mirostat_eta = 0.2, invalid_opt = 1024)
validate_options <- function(...) {
opts <- list(...)
if (length(opts) == 0) {
return(TRUE)
}
opts_validity <- check_options(names(opts))
if (length(opts_validity$invalid_options > 0)) {
invalid <- opts_validity$invalid_options
Expand Down
17 changes: 15 additions & 2 deletions R/ollama.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,31 @@ normalize <- function(x) {
#' @param model A character string of the model name such as "llama3".
#' @param prompt A character string of the prompt that you want to get the vector embedding for.
#' @param normalize Normalize the vector to length 1. Default is TRUE.
#' @param keep_alive The time to keep the connection alive. Default is "5m" (5 minutes).
#' @param endpoint The endpoint to get the vector embedding. Default is "/api/embeddings".
#' @param ... Additional options to pass to the model.
#'
#' @return A numeric vector of the embedding.
#' @export
#'
#' @examplesIf test_connection()$status_code == 200
#' embeddings("nomic-embed-text:latest", "The quick brown fox jumps over the lazy dog.")
embeddings <- function(model, prompt, normalize = TRUE, endpoint = "/api/embeddings") {
embeddings <- function(model, prompt, normalize = TRUE, keep_alive = "5m", endpoint = "/api/embeddings", ...) {
req <- create_request(endpoint)
req <- httr2::req_method(req, "POST")

body_json <- list(model = model, prompt = prompt)
opts <- list(...)
if (length(opts) == 0) {
body_json <- list(model = model, prompt = prompt, keep_alive = keep_alive)
} else {
if (validate_options(...)) {
body_json <- list(model = model, prompt = prompt, keep_alive = keep_alive, options = opts)
} else {
stop("Invalid model options passed to ... argument. Please check the model options and try again.")
}
}

# body_json <- list(model = model, prompt = prompt, keep_alive = keep_alive)
req <- httr2::req_body_json(req, body_json)

tryCatch({
Expand Down
13 changes: 12 additions & 1 deletion man/embeddings.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d1a9518

Please sign in to comment.