Skip to content

Commit

Permalink
feat: add connect stream package to create server stream
Browse files Browse the repository at this point in the history
  • Loading branch information
pyadav committed Feb 1, 2024
1 parent e30ed0b commit 52532c8
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 14 deletions.
1 change: 1 addition & 0 deletions mobius/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ require (
connectrpc.com/vanguard v0.1.0
github.com/MakeNowJust/heredoc v1.0.0
github.com/go-playground/validator/v10 v10.17.0
github.com/google/go-cmp v0.6.0
github.com/mcuadros/go-defaults v1.2.0
github.com/missingstudio/studio/common v0.0.0-00010101000000-000000000000
github.com/missingstudio/studio/protos v0.0.0-00010101000000-000000000000
Expand Down
75 changes: 75 additions & 0 deletions mobius/internal/mock/mock_stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package mock

import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
)

type MockStream[T any] struct {
mu sync.Mutex
ctx context.Context
ch chan *T
closed bool
counter *uint32
Messages []*T
MessageMap map[string]int
}

func NewMockStream[T any](t *testing.T) *MockStream[T] {
t.Helper()
var counter uint32
return &MockStream[T]{
ctx: context.Background(),
ch: make(chan *T),
closed: false,
counter: &counter,
Messages: make([]*T, 0),
MessageMap: make(map[string]int),
}
}

func (m *MockStream[T]) Run() error {
for {
select {
case data, ok := <-m.ch:
if !ok {
return fmt.Errorf("stream closed")
}
atomic.AddUint32(m.counter, 1)
m.Messages = append(m.Messages, data)
case <-m.ctx.Done():
return m.ctx.Err()
}
}
}

func (m *MockStream[T]) GetChannel() chan *T {
return m.ch
}

func (m *MockStream[T]) Send(data *T) {
m.mu.Lock()
defer m.mu.Unlock()
if !m.closed {
m.ch <- data
}
}

func (m *MockStream[T]) Close() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.closed {
close(m.ch)
}
m.closed = true
}

// SetCounter sets the counter used to count the number of messages sent.
// Multiple streams can share the same counter to count the total number of
// messages sent across all streams.
func (m *MockStream[T]) SetCounter(counter *uint32) {
m.counter = counter
}
79 changes: 79 additions & 0 deletions mobius/internal/stream/stream.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package stream

import (
"context"
"fmt"
"sync"

"connectrpc.com/connect"
)

type StreamInterface[T any] interface {
Send(data *T)
Run() error
Close()
}

// Stream wraps a connect.ServerStream.
type Stream[T any] struct {
mu sync.Mutex
// stream is the underlying connect stream
// that does the actual transfer of data
// between the server and a client
stream *connect.ServerStream[T]
// context is the context of the stream
ctx context.Context
// The channel that we listen to for any
// new data that we need to send to the client.
ch chan *T
// closed is a flag that indicates whether
// the stream has been closed.
closed bool
}

// newStream creates a new stream.
func NewStream[T any](ctx context.Context, st *connect.ServerStream[T]) *Stream[T] {
return &Stream[T]{
stream: st,
ctx: ctx,
ch: make(chan *T),
}
}

// Close closes the stream.
func (s *Stream[T]) Close() {
s.mu.Lock()
defer s.mu.Unlock()
if !s.closed {
close(s.ch)
}
s.closed = true
}

// Run runs the stream.
// Run will block until the stream is closed.
func (s *Stream[T]) Run() error {
defer s.Close()
for {
select {
case <-s.ctx.Done():
return s.ctx.Err()
case data, ok := <-s.ch:
if !ok {
return connect.NewError(connect.CodeCanceled, fmt.Errorf("stream closed"))
}
if err := s.stream.Send(data); err != nil {
return err
}
}
}
}

// Send sends data to this stream's connected client.
func (s *Stream[T]) Send(data *T) {
s.mu.Lock()
defer s.mu.Unlock()
if !s.closed {
s.ch <- data
}
}
66 changes: 66 additions & 0 deletions mobius/internal/stream/stream_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package stream

import (
"fmt"
"sync"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/missingstudio/studio/backend/internal/mock"
)

type Data struct {
Msg string
}

var messages = []*Data{
{Msg: "Hello"},
{Msg: "World"},
{Msg: "Foo"},
{Msg: "Bar"},
{Msg: "Gandalf"},
{Msg: "Frodo"},
{Msg: "Bilbo"},
{Msg: "Radagast"},
{Msg: "Sauron"},
{Msg: "Gollum"},
}

func TestStream(t *testing.T) {
var counter uint32
stream := mock.NewMockStream[Data](t)
stream.SetCounter(&counter)
wg := sync.WaitGroup{}
wg.Add(1)

go func() {
defer wg.Done()
err := stream.Run()
t.Log(err)
}()

for _, data := range messages {
stream.Send(data)
}

stream.Close()
wg.Wait()

// A total of 10 messages should have been sent.
if counter != 10 {
fmt.Println(counter)
t.Errorf("expected 10, got %d", counter)
}

msgMsp := make(map[string]int)
for _, data := range stream.Messages {
msgMsp[data.Msg]++
}

if len(stream.Messages) != 10 {
t.Errorf("expected 10 messages, got %d", len(stream.Messages))
}
if diff := cmp.Diff(messages, stream.Messages); diff != "" {
t.Errorf("expected %v, got %v: %s", messages, stream.Messages, diff)
}
}
28 changes: 15 additions & 13 deletions protos/pkg/llm/service.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion protos/proto/llm/service.proto
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ service LLMService {
option (google.api.http).post = "/v1/chat/completions";
option (google.api.http).body = "*";
}
rpc StreamChatCompletions(CompletionRequest) returns (stream CompletionResponse) {}
rpc StreamChatCompletions(CompletionRequest) returns (stream CompletionResponse) {
option (google.api.http).post = "/v1/chat/completions:stream";
option (google.api.http).body = "*";
}
}

enum FinishReason {
Expand Down

0 comments on commit 52532c8

Please sign in to comment.