This repository has been archived by the owner on Jan 26, 2022. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit ec48c66
Showing
23 changed files
with
2,659 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
package warden | ||
|
||
import ( | ||
"bufio" | ||
"code.google.com/p/goprotobuf/proto" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"strconv" | ||
"sync" | ||
) | ||
|
||
type Connection struct { | ||
conn net.Conn | ||
read *bufio.Reader | ||
writeLock sync.Mutex | ||
|
||
disconnected chan bool | ||
} | ||
|
||
type WardenError struct { | ||
Message string | ||
Data string | ||
Backtrace []string | ||
} | ||
|
||
func (e *WardenError) Error() string { | ||
return e.Message | ||
} | ||
|
||
func Connect(socket_path string) (*Connection, error) { | ||
conn, err := net.Dial("unix", socket_path) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &Connection{ | ||
conn: conn, | ||
read: bufio.NewReader(conn), | ||
}, nil | ||
} | ||
|
||
func (c *Connection) Close() { | ||
c.conn.Close() | ||
} | ||
|
||
func (c *Connection) Create() (*CreateResponse, error) { | ||
res, err := c.roundTrip(&CreateRequest{}, &CreateResponse{}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return res.(*CreateResponse), nil | ||
} | ||
|
||
func (c *Connection) Destroy(handle string) (*DestroyResponse, error) { | ||
res, err := c.roundTrip( | ||
&DestroyRequest{Handle: proto.String(handle)}, | ||
&DestroyResponse{}, | ||
) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return res.(*DestroyResponse), nil | ||
} | ||
|
||
func (c *Connection) Spawn(handle, script string) (*SpawnResponse, error) { | ||
res, err := c.roundTrip( | ||
&SpawnRequest{ | ||
Handle: proto.String(handle), | ||
Script: proto.String(script), | ||
}, | ||
&SpawnResponse{}, | ||
) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return res.(*SpawnResponse), nil | ||
} | ||
|
||
func (c *Connection) Run(handle, script string) (*RunResponse, error) { | ||
res, err := c.roundTrip( | ||
&RunRequest{ | ||
Handle: proto.String(handle), | ||
Script: proto.String(script), | ||
}, | ||
&RunResponse{}, | ||
) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return res.(*RunResponse), nil | ||
} | ||
|
||
func (c *Connection) Stream(handle string, jobId uint32) (chan *StreamResponse, error) { | ||
err := c.sendMessage( | ||
&StreamRequest{ | ||
Handle: proto.String(handle), | ||
JobId: proto.Uint32(jobId), | ||
}, | ||
) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
responses := make(chan *StreamResponse) | ||
|
||
go func() { | ||
for { | ||
resMsg, err := c.readResponse(&StreamResponse{}) | ||
if err != nil { | ||
close(responses) | ||
break | ||
} | ||
|
||
response := resMsg.(*StreamResponse) | ||
|
||
responses <- response | ||
|
||
if response.ExitStatus != nil { | ||
close(responses) | ||
break | ||
} | ||
} | ||
}() | ||
|
||
return responses, nil | ||
} | ||
|
||
func (c *Connection) roundTrip(request proto.Message, response proto.Message) (proto.Message, error) { | ||
err := c.sendMessage(request) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
resp, err := c.readResponse(response) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return resp, nil | ||
} | ||
|
||
func (c *Connection) sendMessage(req proto.Message) error { | ||
c.writeLock.Lock() | ||
defer c.writeLock.Unlock() | ||
|
||
request, err := proto.Marshal(req) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
msg := &Message{ | ||
Type: Message_Type(message2type(req)).Enum(), | ||
Payload: request, | ||
} | ||
|
||
data, err := proto.Marshal(msg) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
_, err = c.conn.Write( | ||
[]byte( | ||
fmt.Sprintf( | ||
"%d\r\n%s\r\n", | ||
len(data), | ||
data, | ||
), | ||
), | ||
) | ||
|
||
if err != nil { | ||
c.disconnected <- true | ||
return err | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func (c *Connection) readResponse(response proto.Message) (proto.Message, error) { | ||
payload, err := c.readPayload() | ||
if err != nil { | ||
fmt.Println("error reading payload: ", err) | ||
c.disconnected <- true | ||
return nil, err | ||
} | ||
|
||
message := &Message{} | ||
err = proto.Unmarshal(payload, message) | ||
if err != nil { | ||
println("failed to unmarshal message :(") | ||
return nil, err | ||
} | ||
|
||
// error response | ||
if message.GetType() == Message_Type(1) { | ||
errorResponse := &ErrorResponse{} | ||
err = proto.Unmarshal(message.Payload, errorResponse) | ||
if err != nil { | ||
return nil, errors.New("error unmarshalling error!") | ||
} | ||
|
||
return nil, &WardenError{ | ||
Message: errorResponse.GetMessage(), | ||
Data: errorResponse.GetData(), | ||
Backtrace: errorResponse.GetBacktrace(), | ||
} | ||
} | ||
|
||
if message.GetType() != Message_Type(message2type(response)) { | ||
fmt.Printf("expected %d, got %d\n", message2type(response), *message.Type) | ||
return nil, errors.New("response message type mismatch!") | ||
} | ||
|
||
err = proto.Unmarshal(message.GetPayload(), response) | ||
return response, err | ||
} | ||
|
||
func (c *Connection) readPayload() ([]byte, error) { | ||
msgHeader, err := c.read.ReadBytes('\n') | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
msgLen, err := strconv.ParseUint(string(msgHeader[0:len(msgHeader)-2]), 10, 0) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
payload, err := readNBytes(int(msgLen), c.read) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
_, err = readNBytes(2, c.read) // CRLN | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return payload, err | ||
} | ||
|
||
func message2type(msg proto.Message) int32 { | ||
switch msg.(type) { | ||
case *ErrorResponse: | ||
return 1 | ||
|
||
case *CreateRequest, *CreateResponse: | ||
return 11 | ||
case *StopRequest, *StopResponse: | ||
return 12 | ||
case *DestroyRequest, *DestroyResponse: | ||
return 13 | ||
case *InfoRequest, *InfoResponse: | ||
return 13 | ||
|
||
case *SpawnRequest, *SpawnResponse: | ||
return 21 | ||
case *LinkRequest, *LinkResponse: | ||
return 22 | ||
case *RunRequest, *RunResponse: | ||
return 23 | ||
case *StreamRequest, *StreamResponse: | ||
return 24 | ||
|
||
case *NetInRequest, *NetInResponse: | ||
return 31 | ||
case *NetOutRequest, *NetOutResponse: | ||
return 32 | ||
|
||
case *CopyInRequest, *CopyInResponse: | ||
return 41 | ||
case *CopyOutRequest, *CopyOutResponse: | ||
return 42 | ||
|
||
case *LimitMemoryRequest, *LimitMemoryResponse: | ||
return 51 | ||
case *LimitDiskRequest, *LimitDiskResponse: | ||
return 52 | ||
case *LimitBandwidthRequest, *LimitBandwidthResponse: | ||
return 53 | ||
|
||
case *PingRequest, *PingResponse: | ||
return 91 | ||
case *ListRequest, *ListResponse: | ||
return 92 | ||
case *EchoRequest, *EchoResponse: | ||
return 93 | ||
} | ||
|
||
fmt.Printf("wat?!?! %#v\n", msg) | ||
panic("unknown message type") | ||
} | ||
|
||
func readNBytes(payloadLen int, io *bufio.Reader) ([]byte, error) { | ||
payload := make([]byte, payloadLen) | ||
|
||
for readCount := 0; readCount < payloadLen; { | ||
n, err := io.Read(payload[readCount:]) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
readCount += n | ||
} | ||
|
||
return payload, nil | ||
} |
Oops, something went wrong.