Skip to content

Commit

Permalink
Rewrite the protocol to be pure-JSON.
Browse files Browse the repository at this point in the history
Among other things, this lets use json.Encoder and Decoder directly, for
less copying and no arbitrary line-length limits.
  • Loading branch information
nelhage committed Apr 6, 2014
1 parent 5c0f4ba commit 02ede8f
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 119 deletions.
96 changes: 43 additions & 53 deletions client/client.go
@@ -1,15 +1,23 @@
package client

import (
"bufio"
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/nelhage/livegrep/jsonframe"
"net"
"strings"
)

var ops jsonframe.Marshaler

func init() {
ops.Register(new(Result))
ops.Register(new(ReplyError))
ops.Register(new(ServerInfo))
ops.Register(new(Stats))
ops.Register(new(Query))
}

type client struct {
conn net.Conn
queries chan *search
Expand Down Expand Up @@ -86,86 +94,68 @@ func (c *client) Close() {
func (c *client) loop() {
defer c.conn.Close()
defer close(c.errors)
scan := bufio.NewScanner(c.conn)
encoder := json.NewEncoder(c.conn)
decoder := json.NewDecoder(c.conn)

for {
if !scan.Scan() {
e := scan.Err()
if e == nil {
e = errors.New("connection closed unexpectedly")
}
c.errors <- e
return
}
if !bytes.HasPrefix(scan.Bytes(), []byte("READY ")) {
c.errors <- fmt.Errorf("Expected READY, got: %s", scan.Text())
return
}

info := &ServerInfo{}
if err := json.Unmarshal(scan.Bytes()[len("READY "):], &info); err != nil {
op, err := ops.Decode(decoder)
if err != nil {
c.errors <- err
return
}

select {
case c.ready <- info:
default:
if info, ok := op.(*ServerInfo); !ok {
c.errors <- fmt.Errorf("Expected op: '%s', got: %s",
new(ServerInfo).Opcode(), op.Opcode())
return
} else {
select {
case c.ready <- info:
default:
}
}

q, ok := <-c.queries
if !ok {
break
}
if e := encoder.Encode(q.query); e != nil {
if e := ops.Encode(encoder, q.query); e != nil {
q.errors <- e
close(q.errors)
close(q.results)
close(q.stats)
continue
}
done := false
for scan.Scan() {
line := scan.Text()
if strings.HasPrefix(line, "FATAL ") {
q.errors <- QueryError{q.query, strings.TrimPrefix(line, "FATAL ")}
done = true
ResultLoop:
for {
op, err = ops.Decode(decoder)
if err != nil {
break
} else if strings.HasPrefix(line, "DONE ") {
stats := &Stats{}
if e := json.Unmarshal(scan.Bytes()[len("DONE "):], stats); e != nil {
q.errors <- e
} else {
q.stats <- stats
}
}
switch concrete := op.(type) {
case *ReplyError:
q.errors <- QueryError{q.query, string(*concrete)}
done = true
break
} else {
r := &Result{}
if e := json.Unmarshal(scan.Bytes(), r); e != nil {
q.errors <- e
break
}
q.results <- r
break ResultLoop
case *Stats:
q.stats <- concrete
done = true
break ResultLoop
case *Result:
q.results <- concrete
}
}

if !done {
e := scan.Err()
if e == nil {
e = errors.New("connection closed unexpectedly")
}
q.errors <- e
if err != nil {
q.errors <- err
} else if !done {
q.errors <- errors.New("connection closed unexpectedly")
}

close(q.errors)
close(q.results)
close(q.stats)
}
if e := scan.Err(); e != nil {
c.errors <- e
}
}

func (s *search) Results() <-chan *Result {
Expand Down
104 changes: 56 additions & 48 deletions client/client_test.go
@@ -1,56 +1,63 @@
package client_test
package client

import (
"encoding/json"
"github.com/nelhage/livegrep/client"
"io"
. "launchpad.net/gocheck"
"launchpad.net/gocheck"
"net"
"strings"
"testing"
)

func Test(t *testing.T) { TestingT(t) }
func Test(t *testing.T) { gocheck.TestingT(t) }

type ClientSuite struct {
client client.Client
client Client
}

func (c *ClientSuite) TearDownTest(*C) {
func (c *ClientSuite) TearDownTest(*gocheck.C) {
if c.client != nil {
c.client.Close()
}
}

var _ = Suite(&ClientSuite{})
var _ = gocheck.Suite(&ClientSuite{})

var (
Equals = gocheck.Equals
IsNil = gocheck.IsNil
Not = gocheck.Not
)

type MockServer struct {
Info *client.ServerInfo
Results []*client.Result
Info *ServerInfo
Results []*Result
}

func (m *MockServer) handle(conn net.Conn) {
defer conn.Close()

encoder := json.NewEncoder(conn)
reader := json.NewDecoder(conn)
decoder := json.NewDecoder(conn)
for {
io.WriteString(conn, "READY ")
encoder.Encode(m.Info)
ops.Encode(encoder, m.Info)

var q client.Query
if err := reader.Decode(&q); err != nil {
if op, err := ops.Decode(decoder); err != nil {
if err == io.EOF {
return
}
panic(err.Error())
} else {
if op.(*Query) == nil {
panic("expected query")
}
}

for _, r := range m.Results {
encoder.Encode(r)
ops.Encode(encoder, r)
}
io.WriteString(conn, "DONE ")
encoder.Encode(&client.Stats{})

ops.Encode(encoder, &Stats{})
}
}

Expand All @@ -72,46 +79,46 @@ func runMockServer(handle func(net.Conn)) <-chan string {
return ready
}

func (s *ClientSuite) connect(c *C, addr string) {
func (s *ClientSuite) connect(c *gocheck.C, addr string) {
var err error
s.client, err = client.Dial("tcp", addr)
s.client, err = Dial("tcp", addr)
if err != nil {
c.Fatalf("connecting to %s: %s", addr, err.Error())
}
}

func (s *ClientSuite) TestQuery(c *C) {
func (s *ClientSuite) TestQuery(c *gocheck.C) {
s.connect(c, <-runMockServer((&MockServer{
Results: []*client.Result{
Results: []*Result{
{Line: "match line 1"},
},
}).handle))
search, err := s.client.Query(&client.Query{".", "", ""})
c.Assert(err, IsNil)
search, err := s.client.Query(&Query{".", "", ""})
c.Assert(err, gocheck.IsNil)
var n int
for r := range search.Results() {
n++
c.Assert(r.Line, Not(Equals), "")
c.Assert(r.Line, gocheck.Not(Equals), "")
}
c.Assert(n, Equals, 1)
st, e := search.Close()
c.Assert(e, IsNil)
c.Assert(st, Not(IsNil))
}

func (s *ClientSuite) TestTwoQueries(c *C) {
func (s *ClientSuite) TestTwoQueries(c *gocheck.C) {
s.connect(c, <-runMockServer((&MockServer{
Results: []*client.Result{
Results: []*Result{
{Line: "match line 1"},
},
}).handle))

search, err := s.client.Query(&client.Query{".", "", ""})
search, err := s.client.Query(&Query{".", "", ""})
c.Assert(err, IsNil)
_, err = search.Close()
c.Assert(err, IsNil)

search, err = s.client.Query(&client.Query{".", "", ""})
search, err = s.client.Query(&Query{".", "", ""})
c.Assert(err, IsNil)
n := 0
for _ = range search.Results() {
Expand All @@ -124,16 +131,16 @@ func (s *ClientSuite) TestTwoQueries(c *C) {
c.Assert(n, Not(Equals), 0)
}

func (s *ClientSuite) TestLongLine(c *C) {
func (s *ClientSuite) TestLongLine(c *gocheck.C) {
longLine := strings.Repeat("X", 1<<16)
s.connect(c, <-runMockServer((&MockServer{
Results: []*client.Result{
Results: []*Result{
{Line: longLine},
},
}).handle))
search, err := s.client.Query(&client.Query{".", "", ""})
search, err := s.client.Query(&Query{".", "", ""})
c.Assert(err, IsNil)
var rs []*client.Result
var rs []*Result
for r := range search.Results() {
rs = append(rs, r)
}
Expand All @@ -143,39 +150,40 @@ func (s *ClientSuite) TestLongLine(c *C) {
}

type MockServerQueryError struct {
Info *client.ServerInfo
Info *ServerInfo
Err string
}

func (m *MockServerQueryError) handle(conn net.Conn) {
defer conn.Close()
encoder := json.NewEncoder(conn)
reader := json.NewDecoder(conn)
decoder := json.NewDecoder(conn)
for {
io.WriteString(conn, "READY ")
encoder.Encode(m.Info)
ops.Encode(encoder, m.Info)

var q client.Query
if err := reader.Decode(&q); err != nil {
if op, err := ops.Decode(decoder); err != nil {
if err == io.EOF {
return
}
panic(err.Error())
} else {
if op.(*Query) == nil {
panic("expected query")
}
}

io.WriteString(conn, "FATAL ")
io.WriteString(conn, m.Err)
io.WriteString(conn, "\n")
re := ReplyError(m.Err)
ops.Encode(encoder, &re)
}
}

func (s *ClientSuite) TestBadRegex(c *C) {
func (s *ClientSuite) TestBadRegex(c *gocheck.C) {
errStr := "Invalid query: ("
s.connect(c, <-runMockServer((&MockServerQueryError{
Err: errStr,
}).handle))

search, err := s.client.Query(&client.Query{"(", "", ""})
search, err := s.client.Query(&Query{"(", "", ""})
c.Assert(err, IsNil)
for _ = range search.Results() {
c.Fatal("Got back a result from an erroneous query!")
Expand All @@ -185,7 +193,7 @@ func (s *ClientSuite) TestBadRegex(c *C) {
if e == nil {
c.Fatal("Didn't get back an error")
}
if q, ok := e.(client.QueryError); ok {
if q, ok := e.(QueryError); ok {
c.Assert(q.Query.Line, Equals, "(")
c.Assert(q.Err, Equals, errStr)
} else {
Expand All @@ -195,17 +203,17 @@ func (s *ClientSuite) TestBadRegex(c *C) {

func mockServerShutdown() <-chan string {
return runMockServer(func(conn net.Conn) {
conn.Write([]byte("READY {}\n"))
ops.Encode(json.NewEncoder(conn), &ServerInfo{})
conn.Close()
})
}

func (s *ClientSuite) TestShutdown(c *C) {
func (s *ClientSuite) TestShutdown(c *gocheck.C) {
addr := <-mockServerShutdown()

s.connect(c, addr)

search, err := s.client.Query(&client.Query{Line: "l"})
search, err := s.client.Query(&Query{Line: "l"})
c.Assert(err, IsNil)
c.Assert(search, Not(IsNil))

Expand All @@ -218,7 +226,7 @@ func (s *ClientSuite) TestShutdown(c *C) {
c.Assert(st, IsNil)
c.Assert(err, Not(IsNil))

search, err = s.client.Query(&client.Query{Line: "l"})
search, err = s.client.Query(&Query{Line: "l"})
c.Assert(err, Not(IsNil))
c.Assert(search, IsNil)
}

0 comments on commit 02ede8f

Please sign in to comment.