diff --git a/internal/native/grpc_client.go b/internal/native/grpc_client.go index 300a22848..85a3201bd 100644 --- a/internal/native/grpc_client.go +++ b/internal/native/grpc_client.go @@ -79,6 +79,18 @@ func NewGRPCClient(opts grpcClientOptions) (*GRPCClient, error) { // Start event stream go grpcClient.startEventStream() + // Start event handler to process events from the channel + go func() { + for { + select { + case event := <-grpcClient.eventCh: + grpcClient.handleEvent(event) + case <-grpcClient.eventDone: + return + } + } + }() + return grpcClient, nil } @@ -234,20 +246,6 @@ func (c *GRPCClient) handleEvent(event *pb.Event) { } } -// OnEvent registers an event handler -func (c *GRPCClient) OnEvent(eventType string, handler func(data interface{})) { - go func() { - for { - select { - case event := <-c.eventCh: - c.handleEvent(event) - case <-c.eventDone: - return - } - } - }() -} - // Close closes the gRPC client func (c *GRPCClient) Close() error { c.closeM.Lock() diff --git a/internal/native/grpc_server.go b/internal/native/grpc_server.go index 304203ced..9b54fb5b7 100644 --- a/internal/native/grpc_server.go +++ b/internal/native/grpc_server.go @@ -15,18 +15,20 @@ import ( // grpcServer wraps the Native instance and implements the gRPC service type grpcServer struct { pb.UnimplementedNativeServiceServer - native *Native - logger *zerolog.Logger - eventChs []chan *pb.Event - eventM sync.Mutex + native *Native + logger *zerolog.Logger + eventStreamChan chan *pb.Event + eventStreamMu sync.Mutex + eventStreamCtx context.Context + eventStreamCancel context.CancelFunc } // NewGRPCServer creates a new gRPC server for the native service func NewGRPCServer(n *Native, logger *zerolog.Logger) *grpcServer { s := &grpcServer{ - native: n, - logger: logger, - eventChs: make([]chan *pb.Event, 0), + native: n, + logger: logger, + eventStreamChan: make(chan *pb.Event, 100), } // Store original callbacks and wrap them to also broadcast events @@ -82,16 +84,7 @@ func NewGRPCServer(n *Native, logger *zerolog.Logger) *grpcServer { } func (s *grpcServer) broadcastEvent(event *pb.Event) { - s.eventM.Lock() - defer s.eventM.Unlock() - - for _, ch := range s.eventChs { - select { - case ch <- event: - default: - // Channel full, skip - } - } + s.eventStreamChan <- event } func (s *grpcServer) IsReady(ctx context.Context, req *pb.IsReadyRequest) (*pb.IsReadyResponse, error) { @@ -103,35 +96,49 @@ func (s *grpcServer) StreamEvents(req *pb.Empty, stream pb.NativeService_StreamE setProcTitle("connected") defer setProcTitle("waiting") - eventCh := make(chan *pb.Event, 100) + // Cancel previous stream if exists + s.eventStreamMu.Lock() + if s.eventStreamCancel != nil { + s.logger.Debug().Msg("cancelling previous StreamEvents call") + s.eventStreamCancel() + } - // Register this channel for events - s.eventM.Lock() - s.eventChs = append(s.eventChs, eventCh) - s.eventM.Unlock() + // Create a cancellable context for this stream + ctx, cancel := context.WithCancel(stream.Context()) + s.eventStreamCtx = ctx + s.eventStreamCancel = cancel + s.eventStreamMu.Unlock() - // Unregister on exit + // Clean up when this stream ends defer func() { - s.eventM.Lock() - defer s.eventM.Unlock() - for i, ch := range s.eventChs { - if ch == eventCh { - s.eventChs = append(s.eventChs[:i], s.eventChs[i+1:]...) - break - } + s.eventStreamMu.Lock() + defer s.eventStreamMu.Unlock() + if s.eventStreamCtx == ctx { + s.eventStreamCancel = nil + s.eventStreamCtx = nil } - close(eventCh) + cancel() }() // Stream events for { select { - case event := <-eventCh: + case event := <-s.eventStreamChan: + // Check if this stream is still the active one + s.eventStreamMu.Lock() + isActive := s.eventStreamCtx == ctx + s.eventStreamMu.Unlock() + + if !isActive { + s.logger.Debug().Msg("stream replaced by new call, exiting") + return context.Canceled + } + if err := stream.Send(event); err != nil { return err } - case <-stream.Context().Done(): - return stream.Context().Err() + case <-ctx.Done(): + return ctx.Err() } } } diff --git a/main.go b/main.go index a4d80fb74..88d2dec7c 100644 --- a/main.go +++ b/main.go @@ -70,9 +70,6 @@ func Main() { initOta() - initNative(systemVersionLocal, appVersionLocal) - initDisplay() - http.DefaultClient.Timeout = 1 * time.Minute // Initialize network