Skip to content

Commit

Permalink
Multimodal support (#1216)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Matt Apperson <mattapperson@Matts-MacBook-Pro.local>
  • Loading branch information
pdevine and Matt Apperson committed Dec 11, 2023
1 parent 7a1b37a commit 910e940
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 28 deletions.
47 changes: 30 additions & 17 deletions api/types.go
Expand Up @@ -31,15 +31,18 @@ func (e StatusError) Error() string {
}
}

// ImageData is the raw binary contents of an image attached to a
// GenerateRequest.
type ImageData []byte

type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Images []ImageData `json:"images,omitempty"`

Options map[string]interface{} `json:"options"`
}
Expand Down Expand Up @@ -148,11 +151,12 @@ type ShowRequest struct {
}

type ShowResponse struct {
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
License string `json:"license,omitempty"`
Modelfile string `json:"modelfile,omitempty"`
Parameters string `json:"parameters,omitempty"`
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
}

type CopyRequest struct {
Expand Down Expand Up @@ -188,10 +192,11 @@ type ListResponse struct {
}

type ModelResponse struct {
Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Name string `json:"name"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
}

type TokenResponse struct {
Expand All @@ -209,6 +214,14 @@ type GenerateResponse struct {
Metrics
}

// ModelDetails holds metadata about a model, returned as part of
// ShowResponse and ModelResponse.
type ModelDetails struct {
Format string `json:"format"`
Family string `json:"family"`
// Families lists every family the model belongs to; a "clip" entry
// marks the model as multimodal.
Families []string `json:"families"`
ParameterSize string `json:"parameter_size"`
QuantizationLevel string `json:"quantization_level"`
}

func (m *Metrics) Summary() {
if m.TotalDuration > 0 {
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
Expand Down
151 changes: 149 additions & 2 deletions cmd/cmd.go
Expand Up @@ -17,7 +17,9 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"runtime"
"slices"
"strings"
"syscall"
"time"
Expand All @@ -36,6 +38,8 @@ import (
"github.com/jmorganca/ollama/version"
)

// ImageData holds the raw bytes of an image read from disk on the CLI
// side; it is converted to api.ImageData before being sent in a request.
type ImageData []byte

func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename)
Expand Down Expand Up @@ -418,6 +422,7 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
Images: []ImageData{},
}

format, err := cmd.Flags().GetString("format")
Expand All @@ -427,7 +432,6 @@ func RunGenerate(cmd *cobra.Command, args []string) error {
opts.Format = format

prompts := args[1:]

// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
Expand Down Expand Up @@ -466,6 +470,7 @@ type generateOptions struct {
Format string
System string
Template string
Images []ImageData
Options map[string]interface{}
}

Expand Down Expand Up @@ -551,6 +556,10 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
return nil
}

images := make([]api.ImageData, 0)
for _, i := range opts.Images {
images = append(images, api.ImageData(i))
}
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Expand All @@ -559,6 +568,7 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
System: opts.System,
Template: opts.Template,
Options: opts.Options,
Images: images,
}

if err := client.Generate(ctx, &request, fn); err != nil {
Expand All @@ -585,7 +595,9 @@ func generate(cmd *cobra.Command, opts generateOptions) error {
latest.Summary()
}

cmd.SetContext(context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context))
ctx = context.WithValue(cmd.Context(), generateContextKey("context"), latest.Context)
cmd.SetContext(ctx)

return nil
}

Expand All @@ -598,11 +610,31 @@ const (
MultilineTemplate
)

// modelIsMultiModal reports whether the named model lists "clip" among
// its families, i.e. whether it accepts image input. Any failure to reach
// the server or look up the model is treated as "not multimodal".
func modelIsMultiModal(cmd *cobra.Command, name string) bool {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		fmt.Println("error: couldn't connect to ollama server")
		return false
	}

	resp, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
	if err != nil {
		return false
	}

	return slices.Contains(resp.Details.Families, "clip")
}

func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model)

// load the model
loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
Images: []ImageData{},
}
if err := generate(cmd, loadOpts); err != nil {
return err
Expand Down Expand Up @@ -902,6 +934,26 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {

if len(prompt) > 0 && multiline == MultilineNone {
opts.Prompt = prompt
if multiModal {
newPrompt, images, err := extractFileNames(prompt)
if err != nil {
return err
}
opts.Prompt = newPrompt

// reset the context if we find another image
if len(images) > 0 {
opts.Images = images
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
cmd.SetContext(ctx)
}
if len(opts.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.\n")
prompt = ""
continue
}
}
if err := generate(cmd, opts); err != nil {
return err
}
Expand All @@ -911,6 +963,57 @@ func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
}
}

// normalizeFilePath reverses shell-style backslash escaping in a file
// path (e.g. `my\ file.png` becomes `my file.png`).
//
// A strings.Replacer performs a single deterministic left-to-right pass.
// The previous implementation ranged over a map, whose iteration order is
// random in Go, so inputs containing overlapping escapes (an escaped
// backslash followed by an escaped space, for example) could decode
// differently from run to run. The `\\` pattern is listed first so a
// literal escaped backslash is consumed before the shorter patterns.
func normalizeFilePath(fp string) string {
	return strings.NewReplacer(
		"\\\\", "\\", // escaped backslash (must come first)
		"\\ ", " ", // escaped space
		"\\(", "(", // escaped left parenthesis
		"\\)", ")", // escaped right parenthesis
		"\\[", "[", // escaped left square bracket
		"\\]", "]", // escaped right square bracket
		"\\{", "{", // escaped left curly brace
		"\\}", "}", // escaped right curly brace
		"\\$", "$", // escaped dollar sign
		"\\&", "&", // escaped ampersand
		"\\;", ";", // escaped semicolon
		"\\'", "'", // escaped single quote
		"\\*", "*", // escaped asterisk
		"\\?", "?", // escaped question mark
	).Replace(fp)
}

// extractFileNames scans input for image file paths (jpg/jpeg/png/svg,
// starting with / or ./, with backslash-escaped characters allowed),
// loads each image, and strips the matched paths from the prompt text.
// It returns the cleaned prompt and the loaded images; paths that do not
// exist on disk are skipped, while any other load failure aborts with an
// error.
func extractFileNames(input string) (string, []ImageData, error) {
	// Match paths that begin with / or ./ and end in a supported image
	// extension; escaped spaces inside the path are tolerated.
	pattern := regexp.MustCompile(`(?:\./|/)[\S\\ ]+?\.(?i:jpg|jpeg|png|svg)\b`)

	var imgs []ImageData
	for _, match := range pattern.FindAllString(input, -1) {
		path := normalizeFilePath(match)
		data, err := getImageData(path)
		switch {
		case err == nil:
			fmt.Printf("Added image '%s'\n", path)
			input = strings.ReplaceAll(input, match, "")
			imgs = append(imgs, data)
		case os.IsNotExist(err):
			// Not an actual file on disk; leave the token in the prompt.
			continue
		default:
			fmt.Printf("Couldn't process image: %q\n", err)
			return "", imgs, err
		}
	}
	return input, imgs, nil
}

func RunServer(cmd *cobra.Command, _ []string) error {
host, port, err := net.SplitHostPort(os.Getenv("OLLAMA_HOST"))
if err != nil {
Expand All @@ -937,6 +1040,50 @@ func RunServer(cmd *cobra.Command, _ []string) error {
return server.Serve(ln, origins)
}

func getImageData(filePath string) ([]byte, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()

buf := make([]byte, 512)
_, err = file.Read(buf)
if err != nil {
return nil, err
}

contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}

info, err := file.Stat()
if err != nil {
return nil, err
}

// Check if the file size exceeds 100MB
var maxSize int64 = 100 * 1024 * 1024 // 100MB in bytes
if info.Size() > maxSize {
return nil, fmt.Errorf("file size exceeds maximum limit (100MB).")
}

buf = make([]byte, info.Size())
_, err = file.Seek(0, 0)
if err != nil {
return nil, err
}

_, err = io.ReadFull(file, buf)
if err != nil {
return nil, err
}

return buf, nil
}

func initializeKeypair() error {
home, err := os.UserHomeDir()
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions docs/modelfile.md
Expand Up @@ -150,6 +150,7 @@ PARAMETER <parameter> <parametervalue>
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
### TEMPLATE
`TEMPLATE` is the full prompt template to be passed into the model. It may include (optionally) a system prompt and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
Expand Down
16 changes: 16 additions & 0 deletions llm/llama.go
Expand Up @@ -223,8 +223,14 @@ type Running struct {
*StatusWriter // captures error messages from the llama runner process
}

// ImageData pairs an image's raw bytes with a numeric identifier; it is
// serialized into the "image_data" field of the llama runner's predict
// request.
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"` // assigned from the image's position in the request
}

// llama bundles the request options, any images registered for the
// session, and the state of the running llama process.
type llama struct {
api.Options
ImageData []ImageData
Running
}

Expand Down Expand Up @@ -547,6 +553,7 @@ const maxBufferSize = 512 * format.KiloByte
type PredictOpts struct {
Prompt string
Format string
Images []api.ImageData
CheckpointStart time.Time
CheckpointLoaded time.Time
}
Expand All @@ -564,6 +571,14 @@ type PredictResult struct {
}

func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
imageData := llm.ImageData
if len(predict.Images) > 0 {
for cnt, i := range predict.Images {
imageData = append(imageData, ImageData{Data: i, ID: cnt})
}
}
log.Printf("loaded %d images", len(imageData))

request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
Expand All @@ -585,6 +600,7 @@ func (llm *llama) Predict(ctx context.Context, predict PredictOpts, fn func(Pred
"penalize_nl": llm.PenalizeNewline,
"seed": llm.Seed,
"stop": llm.Stop,
"image_data": imageData,
}

if predict.Format == "json" {
Expand Down
3 changes: 3 additions & 0 deletions server/images.go
Expand Up @@ -46,6 +46,7 @@ type Model struct {
System string
License []string
Digest string
Size int64
Options map[string]interface{}
}

Expand Down Expand Up @@ -242,6 +243,7 @@ func GetModel(name string) (*Model, error) {
Digest: digest,
Template: "{{ .Prompt }}",
License: []string{},
Size: manifest.GetTotalSize(),
}

filename, err := GetBlobsPath(manifest.Config.Digest)
Expand Down Expand Up @@ -545,6 +547,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
}

// xxx - can this be removed?
if config.ModelType == "65B" {
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
config.ModelType = "70B"
Expand Down

0 comments on commit 910e940

Please sign in to comment.