Skip to content
This repository has been archived by the owner on Jan 26, 2022. It is now read-only.

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
vito committed Jul 21, 2013
0 parents commit ec48c66
Show file tree
Hide file tree
Showing 23 changed files with 2,659 additions and 0 deletions.
316 changes: 316 additions & 0 deletions connection.go
@@ -0,0 +1,316 @@
package warden

import (
"bufio"
"code.google.com/p/goprotobuf/proto"
"errors"
"fmt"
"net"
"strconv"
"sync"
)

type Connection struct {
conn net.Conn
read *bufio.Reader
writeLock sync.Mutex

disconnected chan bool
}

type WardenError struct {
Message string
Data string
Backtrace []string
}

func (e *WardenError) Error() string {
return e.Message
}

func Connect(socket_path string) (*Connection, error) {
conn, err := net.Dial("unix", socket_path)
if err != nil {
return nil, err
}

return &Connection{
conn: conn,
read: bufio.NewReader(conn),
}, nil
}

func (c *Connection) Close() {
c.conn.Close()
}

func (c *Connection) Create() (*CreateResponse, error) {
res, err := c.roundTrip(&CreateRequest{}, &CreateResponse{})
if err != nil {
return nil, err
}

return res.(*CreateResponse), nil
}

func (c *Connection) Destroy(handle string) (*DestroyResponse, error) {
res, err := c.roundTrip(
&DestroyRequest{Handle: proto.String(handle)},
&DestroyResponse{},
)

if err != nil {
return nil, err
}

return res.(*DestroyResponse), nil
}

func (c *Connection) Spawn(handle, script string) (*SpawnResponse, error) {
res, err := c.roundTrip(
&SpawnRequest{
Handle: proto.String(handle),
Script: proto.String(script),
},
&SpawnResponse{},
)

if err != nil {
return nil, err
}

return res.(*SpawnResponse), nil
}

func (c *Connection) Run(handle, script string) (*RunResponse, error) {
res, err := c.roundTrip(
&RunRequest{
Handle: proto.String(handle),
Script: proto.String(script),
},
&RunResponse{},
)

if err != nil {
return nil, err
}

return res.(*RunResponse), nil
}

func (c *Connection) Stream(handle string, jobId uint32) (chan *StreamResponse, error) {
err := c.sendMessage(
&StreamRequest{
Handle: proto.String(handle),
JobId: proto.Uint32(jobId),
},
)

if err != nil {
return nil, err
}

responses := make(chan *StreamResponse)

go func() {
for {
resMsg, err := c.readResponse(&StreamResponse{})
if err != nil {
close(responses)
break
}

response := resMsg.(*StreamResponse)

responses <- response

if response.ExitStatus != nil {
close(responses)
break
}
}
}()

return responses, nil
}

func (c *Connection) roundTrip(request proto.Message, response proto.Message) (proto.Message, error) {
err := c.sendMessage(request)
if err != nil {
return nil, err
}

resp, err := c.readResponse(response)
if err != nil {
return nil, err
}

return resp, nil
}

func (c *Connection) sendMessage(req proto.Message) error {
c.writeLock.Lock()
defer c.writeLock.Unlock()

request, err := proto.Marshal(req)
if err != nil {
return err
}

msg := &Message{
Type: Message_Type(message2type(req)).Enum(),
Payload: request,
}

data, err := proto.Marshal(msg)
if err != nil {
return err
}

_, err = c.conn.Write(
[]byte(
fmt.Sprintf(
"%d\r\n%s\r\n",
len(data),
data,
),
),
)

if err != nil {
c.disconnected <- true
return err
}

return nil
}

func (c *Connection) readResponse(response proto.Message) (proto.Message, error) {
payload, err := c.readPayload()
if err != nil {
fmt.Println("error reading payload: ", err)
c.disconnected <- true
return nil, err
}

message := &Message{}
err = proto.Unmarshal(payload, message)
if err != nil {
println("failed to unmarshal message :(")
return nil, err
}

// error response
if message.GetType() == Message_Type(1) {
errorResponse := &ErrorResponse{}
err = proto.Unmarshal(message.Payload, errorResponse)
if err != nil {
return nil, errors.New("error unmarshalling error!")
}

return nil, &WardenError{
Message: errorResponse.GetMessage(),
Data: errorResponse.GetData(),
Backtrace: errorResponse.GetBacktrace(),
}
}

if message.GetType() != Message_Type(message2type(response)) {
fmt.Printf("expected %d, got %d\n", message2type(response), *message.Type)
return nil, errors.New("response message type mismatch!")
}

err = proto.Unmarshal(message.GetPayload(), response)
return response, err
}

func (c *Connection) readPayload() ([]byte, error) {
msgHeader, err := c.read.ReadBytes('\n')
if err != nil {
return nil, err
}

msgLen, err := strconv.ParseUint(string(msgHeader[0:len(msgHeader)-2]), 10, 0)
if err != nil {
return nil, err
}

payload, err := readNBytes(int(msgLen), c.read)
if err != nil {
return nil, err
}

_, err = readNBytes(2, c.read) // CRLN
if err != nil {
return nil, err
}

return payload, err
}

func message2type(msg proto.Message) int32 {
switch msg.(type) {
case *ErrorResponse:
return 1

case *CreateRequest, *CreateResponse:
return 11
case *StopRequest, *StopResponse:
return 12
case *DestroyRequest, *DestroyResponse:
return 13
case *InfoRequest, *InfoResponse:
return 13

case *SpawnRequest, *SpawnResponse:
return 21
case *LinkRequest, *LinkResponse:
return 22
case *RunRequest, *RunResponse:
return 23
case *StreamRequest, *StreamResponse:
return 24

case *NetInRequest, *NetInResponse:
return 31
case *NetOutRequest, *NetOutResponse:
return 32

case *CopyInRequest, *CopyInResponse:
return 41
case *CopyOutRequest, *CopyOutResponse:
return 42

case *LimitMemoryRequest, *LimitMemoryResponse:
return 51
case *LimitDiskRequest, *LimitDiskResponse:
return 52
case *LimitBandwidthRequest, *LimitBandwidthResponse:
return 53

case *PingRequest, *PingResponse:
return 91
case *ListRequest, *ListResponse:
return 92
case *EchoRequest, *EchoResponse:
return 93
}

fmt.Printf("wat?!?! %#v\n", msg)
panic("unknown message type")
}

func readNBytes(payloadLen int, io *bufio.Reader) ([]byte, error) {
payload := make([]byte, payloadLen)

for readCount := 0; readCount < payloadLen; {
n, err := io.Read(payload[readCount:])
if err != nil {
return nil, err
}

readCount += n
}

return payload, nil
}

0 comments on commit ec48c66

Please sign in to comment.