Multimodal support #1216

Merged · 11 commits · Dec 11, 2023
47 changes: 30 additions & 17 deletions api/types.go
@@ -31,15 +31,18 @@ func (e StatusError) Error() string {
 	}
 }

+type ImageData []byte
+
 type GenerateRequest struct {
-	Model    string `json:"model"`
-	Prompt   string `json:"prompt"`
-	System   string `json:"system"`
-	Template string `json:"template"`
-	Context  []int  `json:"context,omitempty"`
-	Stream   *bool  `json:"stream,omitempty"`
-	Raw      bool   `json:"raw,omitempty"`
-	Format   string `json:"format"`
+	Model    string      `json:"model"`
+	Prompt   string      `json:"prompt"`
+	System   string      `json:"system"`
+	Template string      `json:"template"`
+	Context  []int       `json:"context,omitempty"`
+	Stream   *bool       `json:"stream,omitempty"`
+	Raw      bool        `json:"raw,omitempty"`
+	Format   string      `json:"format"`
+	Images   []ImageData `json:"images,omitempty"`

 	Options map[string]interface{} `json:"options"`
 }
@@ -148,11 +151,12 @@ type ShowRequest struct {
 }

 type ShowResponse struct {
-	License    string `json:"license,omitempty"`
-	Modelfile  string `json:"modelfile,omitempty"`
-	Parameters string `json:"parameters,omitempty"`
-	Template   string `json:"template,omitempty"`
-	System     string `json:"system,omitempty"`
+	License    string       `json:"license,omitempty"`
+	Modelfile  string       `json:"modelfile,omitempty"`
+	Parameters string       `json:"parameters,omitempty"`
+	Template   string       `json:"template,omitempty"`
+	System     string       `json:"system,omitempty"`
+	Details    ModelDetails `json:"details,omitempty"`
 }

 type CopyRequest struct {
@@ -188,10 +192,11 @@ type ListResponse struct {
 }

 type ModelResponse struct {
-	Name       string    `json:"name"`
-	ModifiedAt time.Time `json:"modified_at"`
-	Size       int64     `json:"size"`
-	Digest     string    `json:"digest"`
+	Name       string       `json:"name"`
+	ModifiedAt time.Time    `json:"modified_at"`
+	Size       int64        `json:"size"`
+	Digest     string       `json:"digest"`
+	Details    ModelDetails `json:"details,omitempty"`
 }

 type TokenResponse struct {
@@ -209,6 +214,14 @@ type GenerateResponse struct {
 	Metrics
 }

+type ModelDetails struct {
+	Format            string   `json:"format"`
+	Family            string   `json:"family"`
+	Families          []string `json:"families"`
+	ParameterSize     string   `json:"parameter_size"`
+	QuantizationLevel string   `json:"quantization_level"`
+}
+
 func (m *Metrics) Summary() {
 	if m.TotalDuration > 0 {
 		fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
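Since the new ImageData type is a plain []byte, encoding/json marshals the images field as an array of base64 strings, so clients need no extra encoding step. A minimal client-side sketch of using the new field (model name, file path, and prompt are illustrative, not from this PR):

    package main

    import (
        "context"
        "log"
        "os"

        "github.com/jmorganca/ollama/api"
    )

    func main() {
        // read raw image bytes; encoding/json base64-encodes them on the wire
        img, err := os.ReadFile("photo.jpg")
        if err != nil {
            log.Fatal(err)
        }

        client, err := api.ClientFromEnvironment()
        if err != nil {
            log.Fatal(err)
        }

        req := api.GenerateRequest{
            Model:  "llava",
            Prompt: "what is in this image?",
            Images: []api.ImageData{img},
        }

        // stream the response to stdout as it arrives
        err = client.Generate(context.Background(), &req, func(resp api.GenerateResponse) error {
            os.Stdout.WriteString(resp.Response)
            return nil
        })
        if err != nil {
            log.Fatal(err)
        }
    }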
151 changes: 149 additions & 2 deletions cmd/cmd.go
@@ -17,7 +17,9 @@ import (
 	"os/exec"
 	"os/signal"
 	"path/filepath"
+	"regexp"
 	"runtime"
+	"slices"
 	"strings"
 	"syscall"
 	"time"
@@ -36,6 +38,8 @@ import (
 	"github.com/jmorganca/ollama/version"
 )

+type ImageData []byte
+
 func CreateHandler(cmd *cobra.Command, args []string) error {
 	filename, _ := cmd.Flags().GetString("file")
 	filename, err := filepath.Abs(filename)
@@ -418,6 +422,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 		Model:    args[0],
 		WordWrap: os.Getenv("TERM") == "xterm-256color",
 		Options:  map[string]interface{}{},
+		Images:   []ImageData{},
 	}

 	format, err := cmd.Flags().GetString("format")
@@ -427,7 +432,6 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
 	opts.Format = format

 	prompts := args[1:]
-
 	// prepend stdin to the prompt if provided
 	if !term.IsTerminal(int(os.Stdin.Fd())) {
 		in, err := io.ReadAll(os.Stdin)
@@ -466,6 +470,7 @@ type generateOptions struct {
 	Format   string
 	System   string
 	Template string
+	Images   []ImageData
 	Options  map[string]interface{}
 }

@@ -551,6 +556,10 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		return nil
 	}

+	images := make([]api.ImageData, 0)
+	for _, i := range opts.Images {
+		images = append(images, api.ImageData(i))
+	}
 	request := api.GenerateRequest{
 		Model:    opts.Model,
 		Prompt:   opts.Prompt,
@@ -559,6 +568,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		System:   opts.System,
 		Template: opts.Template,
 		Options:  opts.Options,
+		Images:   images,
 	}

 	if err := client.Generate(ctx, &request, fn); err != nil {
@@ -585,7 +595,9 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
 		latest.Summary()
 	}

-	cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
+	ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
+	cmd.SetContext(ctx)
+
 	return nil
 }

@@ -598,11 +610,31 @@ const (
 	MultilineTemplate
 )

+func modelIsMultiModal(cmd *cobra.Command, name string) bool {
+	// get model details
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		fmt.Println("error: couldn't connect to ollama server")
+		return false
+	}
+
+	req := api.ShowRequest{Name: name}
+	resp, err := client.Show(cmd.Context(), &req)
+	if err != nil {
+		return false
+	}
+
+	return slices.Contains(resp.Details.Families, "clip")
+}
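For context, the Details field consulted here is the new ShowResponse.Details from api/types.go. For a llava-style model, the /api/show payload might look roughly like this (values illustrative):

    {
      "details": {
        "format": "gguf",
        "family": "llama",
        "families": ["llama", "clip"],
        "parameter_size": "7B",
        "quantization_level": "Q4_0"
      }
    }

Any model whose families list includes "clip" is treated as multimodal.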

 func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
+	multiModal := modelIsMultiModal(cmd, opts.Model)
+
 	// load the model
 	loadOpts := generateOptions{
 		Model:  opts.Model,
 		Prompt: "",
+		Images: []ImageData{},
 	}
 	if err := generate(cmd, loadOpts); err != nil {
 		return err
@@ -902,6 +934,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 		if len(prompt) > 0 && multiline == MultilineNone {

Review comment: I think you can abstract these if conditions into a function so they're easy to test. Nice work, btw.
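A sketch of what that extraction might look like, assuming the surrounding names (the helper name and the MultilineState parameter type are guesses for illustration, not code from this PR):

    // shouldProcessImages reports whether a submitted prompt should be
    // scanned for image file paths before being sent to the model: a
    // non-empty, single-line prompt against a multimodal model.
    func shouldProcessImages(prompt string, multiline MultilineState, multiModal bool) bool {
        return len(prompt) > 0 && multiline == MultilineNone && multiModal
    }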

+			opts.Prompt = prompt
+			if multiModal {
+				newPrompt, images, err := extractFileNames(prompt)
+				if err != nil {
+					return err
+				}
+				opts.Prompt = newPrompt
+
+				// reset the context if we find another image
+				if len(images) > 0 {
+					opts.Images = images
+					ctx := cmd.Context()
+					ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
+					cmd.SetContext(ctx)
+				}
+				if len(opts.Images) == 0 {
Review comment (Contributor): I don't know if this is a good idea. At its core, a llava model is still a language model, and it's possible to interact with it as just a text completion or chat model.

Review comment (jmorganca, Member, Dec 6, 2023): Was thinking this too! It works very well sans-image, and modern LLMs seem to blend both together so the user can do either with a single model.

Review comment (igorschlum, Dec 7, 2023): I thought it would be transparent to the LLM, since the text from the file at that path would be transmitted to the LLM. I see in the code that the image is transmitted in Base64, which is why you don't want to use this syntax for XML, CSV, JSON, or txt files. For mp3, Base64 is also used.

I'm dreaming of one day being able to use Whisper with Ollama and provide an mp3 for conversion:
https://huggingface.co/openai/whisper-large-v3

For txt or json, I know I can type in the terminal:

    ollama run llama2 --verbose please translate in spanish "$(cat /Users/igor/song.txt)"

but then I'm obliged to return to the terminal to prompt, and that's not simple for non-Unix users.

When I use this command inside Ollama, I get an error:

    ollama run llama2
    >>> please translate in spanish "$(cat /Users/igor/song.txt)"

    The command $(cat /Users/igor/song.txt) is a Unix or Linux command that
    uses the cat command to display the contents of a file located at
    /Users/igor/song.txt.

It would be nice to be able to prompt:

    >>> please translate in spanish this text /Users/igor/song.txt

and drag and drop the txt, json, or csv file onto the terminal window. Usage would be much simpler, as in the video above for jpg.

Anyway, you're doing a super job, and I really enjoy Ollama as you build and improve it.

Review comment (Contributor Author): I'm extremely reluctant to take this out because the llava model is pretty much useless until an image is added. I get that you can get it to answer a question, but that feels like a degenerate use case for the current model. The kicker for me is that there is no indication to the user that they can even add an image.

Review comment: LLaVA is better than Llama2 at understanding the context of text and images and using that information to answer questions or generate text.

If my sample.txt file is:

    C'est l'histoire d'un petit chat qui chantait "miaou miaou" tout le temps.

and my prompt is:

    ollama run llama2 --verbose please translate in spanish "$(cat /Users/igor/sample.txt)" and translate it also in Italian

the answer will be:

    Spanish: "Es la historia de un pequeño gato que cantaba 'miau miau' todo el tiempo."

    Italian: "Questa è la storia di un gattino che cantava 'miao miao' sempre."

The text in the sample.txt file is sent to Ollama and is interpreted as part of the prompt, not as a file that must be processed. By separating the text from the prompt, LLaVA could focus on processing the text content independently, leading to more efficient and accurate responses. This approach could be particularly useful for tasks that require extensive text processing, such as translation or summarization.

fmt.Println("This model requires you to add a jpeg, png, or svg image.\n")
prompt = ""
continue
}
}
if err := generate(cmd, opts); err != nil {
return err
}
@@ -911,6 +963,57 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
 		}
 	}

+func normalizeFilePath(fp string) string {
+	// Define a map of escaped characters and their replacements
+	replacements := map[string]string{
+		"\\ ":  " ",  // Escaped space
+		"\\(":  "(",  // Escaped left parenthesis
+		"\\)":  ")",  // Escaped right parenthesis
+		"\\[":  "[",  // Escaped left square bracket
+		"\\]":  "]",  // Escaped right square bracket
+		"\\{":  "{",  // Escaped left curly brace
+		"\\}":  "}",  // Escaped right curly brace
+		"\\$":  "$",  // Escaped dollar sign
+		"\\&":  "&",  // Escaped ampersand
+		"\\;":  ";",  // Escaped semicolon
+		"\\'":  "'",  // Escaped single quote
+		"\\\\": "\\", // Escaped backslash
+		"\\*":  "*",  // Escaped asterisk
+		"\\?":  "?",  // Escaped question mark
+	}
+
+	for escaped, actual := range replacements {
+		fp = strings.ReplaceAll(fp, escaped, actual)
+	}
+	return fp
+}
+
+func extractFileNames(input string) (string, []ImageData, error) {
+	// Regex to match file paths starting with / or ./ and include escaped spaces (\ or %20)
+	// and followed by more characters and a file extension
+	regexPattern := `(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`
+	re := regexp.MustCompile(regexPattern)
+
+	filePaths := re.FindAllString(input, -1)
+	var imgs []ImageData
+
+	for _, fp := range filePaths {
+		nfp := normalizeFilePath(fp)
+		data, err := getImageData(nfp)
+		if err != nil {
+			if os.IsNotExist(err) {
+				continue
+			}
+			fmt.Printf("Couldn't process image: %q\n", err)
+			return "", imgs, err
+		}
+		fmt.Printf("Added image '%s'\n", nfp)
+		input = strings.ReplaceAll(input, fp, "")
+		imgs = append(imgs, data)
+	}
+	return input, imgs, nil
+}
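For instance, the pattern above pulls both paths out of a prompt like the following (a quick standalone check, not part of the PR):

    re := regexp.MustCompile(`(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`)
    // matches relative and absolute paths, case-insensitive extensions,
    // and paths with escaped spaces
    fmt.Println(re.FindAllString(`describe ./cat.jpg and /tmp/dog\ park.PNG`, -1))
    // output: [./cat.jpg /tmp/dog\ park.PNG]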

 func RunServer(cmd *cobra.Command, _ []string) error {
 	host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
 	if err != nil {
@@ -937,6 +1040,50 @@ func RunServer(cmd *cobra.Command, _ []string) error {
 	return server.Serve(ln, origins)
 }

+func getImageData(filePath string) ([]byte, error) {
+	file, err := os.Open(filePath)
+	if err != nil {
+		return nil, err
+	}
+	defer file.Close()
+
+	buf := make([]byte, 512)
Review comment (Contributor): bytes.Buffer is probably more appropriate. You can do something like this, where it just appends after the first 512 bytes:

    var b bytes.Buffer

    // read just enough to sniff the content type
    if _, err := io.CopyN(&b, file, 512); err != nil {
        // return err
    }

    contentType := http.DetectContentType(b.Bytes())
    if !slices.Contains(types, contentType) {
        // return err
    }

    // append the rest of the file
    if _, err := io.Copy(&b, file); err != nil {
        // return err
    }

Review comment (Contributor Author): We still have to stat the file and only read if the file is < 100MB. I feel like we're splitting hairs here.

+	_, err = file.Read(buf)
+	if err != nil {
+		return nil, err
+	}
+
+	contentType := http.DetectContentType(buf)
+	allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
Review comment (Contributor): Is this defined by the model? Can it recognize other image types, e.g. bmps, gifs?

Review comment (mxyng, Contributor, Dec 6, 2023): The list of possible values returned by DetectContentType is very short but contains some values not seen here.

Review comment (Contributor Author): Ideally there would be metadata from the model which told us the list of supported mimetypes. I made the list a few weeks ago from the llava documentation, but we can explore changing this in the future.
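One practical wrinkle when testing this list: DetectContentType sniffs against a fixed signature table, so a JPEG comes back as "image/jpeg" (never "image/jpg"), and an SVG that starts with an XML declaration comes back as "text/xml; charset=utf-8" rather than "image/svg+xml". A quick way to see what the sniffer reports for a given file (illustrative snippet, not part of the PR):

    f, err := os.Open("sample.svg") // any local file
    if err != nil {
        log.Fatal(err)
    }
    defer f.Close()

    buf := make([]byte, 512)
    n, _ := f.Read(buf)
    // DetectContentType considers at most the first 512 bytes
    fmt.Println(http.DetectContentType(buf[:n]))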

+	if !slices.Contains(allowedTypes, contentType) {
+		return nil, fmt.Errorf("invalid image type: %s", contentType)
+	}
+
+	info, err := file.Stat()
+	if err != nil {
+		return nil, err
+	}
+
+	// Check if the file size exceeds 100MB
+	var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
+	if info.Size() > maxSize {
+		return nil, fmt.Errorf("file size exceeds maximum limit (100MB).")
+	}
+
+	buf = make([]byte, info.Size())
+	_, err = file.Seek(0, 0)
+	if err != nil {
+		return nil, err
+	}
+
+	_, err = io.ReadFull(file, buf)
+	if err != nil {
+		return nil, err
+	}
+
+	return buf, nil
+}

 func initializeKeypair() error {
 	home, err := os.UserHomeDir()
 	if err != nil {
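Put together, the cmd changes give the REPL a drag-and-drop style flow: type a prompt containing an image path (escaped spaces allowed), and the path is stripped out and replaced by the image payload. An illustrative session (model name, path, and output are examples, not captured from the PR):

    $ ollama run llava
    >>> what is in this image? ./pics/cat\ photo.jpg
    Added image './pics/cat photo.jpg'
    The image shows a small cat sitting on a windowsill...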
1 change: 1 addition & 0 deletions docs/modelfile.md
@@ -150,6 +150,7 @@ PARAMETER <parameter> <parametervalue>
 | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |

Review comment (Member): Can be in a follow-up PR: we should add a section in api.md for this 😃

+| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |


 ### TEMPLATE

 `TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system prompt and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
16 changes: 16 additions & 0 deletions llm/llama.go
@@ -223,8 +223,14 @@ type Running struct {
 	*StatusWriter // captures error messages from the llama runner process
 }

+type ImageData struct {
+	Data []byte `json:"data"`
+	ID   int    `json:"id"`
+}
+
 type llama struct {
 	api.Options
+	ImageData []ImageData
 	Running
 }
@@ -547,6 +553,7 @@ const maxBufferSize = 512 * format.KiloByte
 type PredictOpts struct {
 	Prompt           string
 	Format           string
+	Images           []api.ImageData
 	CheckpointStart  time.Time
 	CheckpointLoaded time.Time
 }
@@ -564,6 +571,14 @@ type PredictResult struct {
 }

 func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
+	imageData := llm.ImageData
+	if len(predict.Images) > 0 {
+		for cnt, i := range predict.Images {
+			imageData = append(imageData, ImageData{Data: i, ID: cnt})
+		}
+	}
+	log.Printf("loaded %d images", len(imageData))
+
 	request := map[string]any{
 		"prompt": predict.Prompt,
 		"stream": true,
@@ -585,6 +600,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
 		"penalize_nl": llm.PenalizeNewline,
 		"seed":        llm.Seed,
 		"stop":        llm.Stop,
+		"image_data":  imageData,
 	}

 	if predict.Format == "json" {
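With that, the request body posted to the llama.cpp runner carries the prompt along with base64 image payloads keyed by ID, roughly like this (trimmed to the fields relevant here; values illustrative):

    {
      "prompt": "what is in this image?",
      "stream": true,
      "image_data": [
        {"data": "<base64-encoded image bytes>", "id": 0}
      ]
    }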
3 changes: 3 additions & 0 deletions server/images.go
@@ -46,6 +46,7 @@ type Model struct {
 	System   string
 	License  []string
 	Digest   string
+	Size     int64
 	Options  map[string]interface{}
 }

@@ -242,6 +243,7 @@ func GetModel(name string) (*Model, error) {
 		Digest:   digest,
 		Template: "{{ .Prompt }}",
 		License:  []string{},
+		Size:     manifest.GetTotalSize(),
 	}

 	filename, err := GetBlobsPath(manifest.Config.Digest)
@@ -545,6 +547,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command
 		}
 	}

+	// xxx - can this be removed?
 	if config.ModelType == "65B" {
 		if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
 			config.ModelType = "70B"