Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(diffusers): be consistent with pipelines, support also depthimg2img #926

Merged
merged 1 commit into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 14 additions & 21 deletions api/openai/image.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package openai

import (
"bufio"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"os"
"path"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -52,34 +51,28 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
}

src := ""
// retrieve the file data from the request
file, err := c.FormFile("src")
if err == nil {

f, err := file.Open()
if input.File != "" {
//base 64 decode the file and write it somewhere
// that we will cleanup
decoded, err := base64.StdEncoding.DecodeString(input.File)
if err != nil {
return err
}
defer f.Close()

dir, err := os.MkdirTemp("", "img2img")

// Create a temporary file
outputFile, err := os.CreateTemp(o.ImageDir, "b64")
if err != nil {
return err
}
defer os.RemoveAll(dir)

dst := filepath.Join(dir, path.Base(file.Filename))
dstFile, err := os.Create(dst)
// write the base64 result
writer := bufio.NewWriter(outputFile)
_, err = writer.Write(decoded)
if err != nil {
outputFile.Close()
return err
}

if _, err := io.Copy(dstFile, f); err != nil {
log.Debug().Msgf("Image file copying error %+v - %+v - err %+v", file.Filename, dst, err)
return err
}
src = dst
outputFile.Close()
src = outputFile.Name()
defer os.RemoveAll(src)
}

log.Debug().Msgf("Parameter Config: %+v", config)
Expand Down
18 changes: 15 additions & 3 deletions extra/grpc/autogptq/backend_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 66 additions & 0 deletions extra/grpc/autogptq/backend_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ def __init__(self, channel):
request_serializer=backend__pb2.TTSRequest.SerializeToString,
response_deserializer=backend__pb2.Result.FromString,
)
self.TokenizeString = channel.unary_unary(
'/backend.Backend/TokenizeString',
request_serializer=backend__pb2.PredictOptions.SerializeToString,
response_deserializer=backend__pb2.TokenizationResponse.FromString,
)
self.Status = channel.unary_unary(
'/backend.Backend/Status',
request_serializer=backend__pb2.HealthMessage.SerializeToString,
response_deserializer=backend__pb2.StatusResponse.FromString,
)


class BackendServicer(object):
Expand Down Expand Up @@ -107,6 +117,18 @@ def TTS(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def TokenizeString(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Status(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_BackendServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand Down Expand Up @@ -150,6 +172,16 @@ def add_BackendServicer_to_server(servicer, server):
request_deserializer=backend__pb2.TTSRequest.FromString,
response_serializer=backend__pb2.Result.SerializeToString,
),
'TokenizeString': grpc.unary_unary_rpc_method_handler(
servicer.TokenizeString,
request_deserializer=backend__pb2.PredictOptions.FromString,
response_serializer=backend__pb2.TokenizationResponse.SerializeToString,
),
'Status': grpc.unary_unary_rpc_method_handler(
servicer.Status,
request_deserializer=backend__pb2.HealthMessage.FromString,
response_serializer=backend__pb2.StatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'backend.Backend', rpc_method_handlers)
Expand Down Expand Up @@ -295,3 +327,37 @@ def TTS(request,
backend__pb2.Result.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def TokenizeString(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/TokenizeString',
backend__pb2.PredictOptions.SerializeToString,
backend__pb2.TokenizationResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Status(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/backend.Backend/Status',
backend__pb2.HealthMessage.SerializeToString,
backend__pb2.StatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)