diff --git a/.gitignore b/.gitignore index b26e5ec..4c310f8 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,9 @@ .vscode/ .idea/ +# IDL +idl/ + # Binaries for programs and plugins bubblecopy output/ diff --git a/Makefile b/Makefile index 8c09dd1..52a0bdd 100644 --- a/Makefile +++ b/Makefile @@ -2,13 +2,15 @@ define DEBUG_SETTINGS { "interface": "lo0", - "service_port": "8080" + "service_port": "8080", + "replay_svr_addr": "127.0.0.1:6789" } endef export DEBUG_SETTINGS # todo: cross-compile for linux and windows to releases. build: clean + @echo "go build project..." @mkdir -p output @go build -o output/bubblecopy @@ -23,4 +25,14 @@ run-debug: build clean: @echo "clean output directory..." - @rm -rf output/ \ No newline at end of file + @rm -rf output/ + +update-idl: + @rm -rf ./idl + @echo "step1> fetching idl repo..." + @git clone --depth=1 https://github.com/bubble-diff/IDL.git idl + @rm -rf ./idl/.git + @rm ./idl/.gitignore + @echo "step2> compile idl..." + @protoc --go_out=. idl/*.proto + @go mod tidy diff --git a/buffer.go b/buffer.go index 68cbf0d..df42196 100644 --- a/buffer.go +++ b/buffer.go @@ -1,6 +1,8 @@ package main -import "io" +import ( + "io" +) type buffer struct { bytes chan []byte @@ -8,7 +10,9 @@ type buffer struct { } func NewBuffer() *buffer { - // todo: bytes可以做成带缓存的channel吗? 
+ // 这里,必须是无缓存的channel,因为channel是stream进行close的。 + // 如果带缓存,stream关掉channel后,consumer会消费失败。 + // todo: 将close交给consume去做?这样就可以带缓存了 return &buffer{bytes: make(chan []byte)} } diff --git a/config.go b/config.go index b85fe5b..5b48b8d 100644 --- a/config.go +++ b/config.go @@ -10,13 +10,15 @@ import ( type config struct { // Taskid Diff任务ID - Taskid int `json:"taskid"` + Taskid int64 `json:"taskid"` // Secret 访问对应id任务配置的密钥 Secret string `json:"secret"` // Device 网卡名称 Device string `json:"interface"` // Port 被测服务端口 Port string `json:"service_port"` + // ReplaySvrAddr bubblereplay服务地址 + ReplaySvrAddr string `json:"replay_svr_addr"` // DeviceIPv4 网卡ipv4地址 DeviceIPv4 string @@ -35,6 +37,10 @@ func (c *config) init() { logrus.Fatal(err) } + if configuration.ReplaySvrAddr == "" { + logrus.Fatal("bubblereplay server addr not set") + } + c.DeviceIPv4, err = getDeviceIpv4(c.Device) if err != nil { logrus.Error(err) diff --git a/const.go b/const.go index 6f68982..4974ee8 100644 --- a/const.go +++ b/const.go @@ -11,3 +11,7 @@ const ( UnknownType = "unknown" HttpType = "http" ) + +const ( + ApiAddRecord = "/record/add" +) \ No newline at end of file diff --git a/go.mod b/go.mod index ebbe31a..fbbe12f 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.17 require ( github.com/google/gopacket v1.1.19 github.com/sirupsen/logrus v1.8.1 + google.golang.org/protobuf v1.27.1 ) require golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27 // indirect diff --git a/go.sum b/go.sum index 964afff..2755b82 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= 
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -24,3 +27,8 @@ golang.org/x/sys v0.0.0-20220128215802-99c3d69c2c27/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= diff --git a/main.go b/main.go index f845de3..4b44878 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "os" + "os/signal" "time" "github.com/google/gopacket" @@ -61,14 +62,19 @@ func main() { streamPool := reassembly.NewStreamPool(streamFactory) assembler := reassembly.NewAssembler(streamPool) + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, os.Interrupt) + ticker := time.NewTicker(time.Second * 30) - defer ticker.Stop() - defer streamFactory.WaitConsumers() for { // todo: 等待Diff任务启动,若未启动,请勿进行抓包消耗CPU // your code here... 
+ done := false select { + case <-signalChan: + logrus.Info("Caught SIGINT: aborting") + done = true case <-ticker.C: // 停止监听30秒内无数据传输的连接 assembler.FlushCloseOlderThan(time.Now().Add(time.Second * -30)) @@ -79,5 +85,16 @@ func main() { assembler.Assemble(packet.NetworkLayer().NetworkFlow(), tcp) } } + if done { + break + } } + + ticker.Stop() + // Important! Please flush all connection before waiting consumers. + closed := assembler.FlushAll() + logrus.Debugf("Final flush: %d closed", closed) + + streamFactory.WaitConsumers() + logrus.Info("Bye~") } diff --git a/pb/replay.pb.go b/pb/replay.pb.go new file mode 100644 index 0000000..1657940 --- /dev/null +++ b/pb/replay.pb.go @@ -0,0 +1,306 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.27.1 +// protoc v3.19.4 +// source: idl/replay.proto + +package pb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. 
+ _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type Record struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + TaskId int64 `protobuf:"varint,1,opt,name=task_id,json=taskId,proto3" json:"task_id,omitempty"` + OldReq []byte `protobuf:"bytes,2,opt,name=old_req,json=oldReq,proto3" json:"old_req,omitempty"` + OldResp []byte `protobuf:"bytes,3,opt,name=old_resp,json=oldResp,proto3" json:"old_resp,omitempty"` + NewResp []byte `protobuf:"bytes,4,opt,name=new_resp,json=newResp,proto3" json:"new_resp,omitempty"` +} + +func (x *Record) Reset() { + *x = Record{} + if protoimpl.UnsafeEnabled { + mi := &file_idl_replay_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *Record) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Record) ProtoMessage() {} + +func (x *Record) ProtoReflect() protoreflect.Message { + mi := &file_idl_replay_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Record.ProtoReflect.Descriptor instead. 
+func (*Record) Descriptor() ([]byte, []int) { + return file_idl_replay_proto_rawDescGZIP(), []int{0} +} + +func (x *Record) GetTaskId() int64 { + if x != nil { + return x.TaskId + } + return 0 +} + +func (x *Record) GetOldReq() []byte { + if x != nil { + return x.OldReq + } + return nil +} + +func (x *Record) GetOldResp() []byte { + if x != nil { + return x.OldResp + } + return nil +} + +func (x *Record) GetNewResp() []byte { + if x != nil { + return x.NewResp + } + return nil +} + +type AddRecordReq struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Record *Record `protobuf:"bytes,1,opt,name=record,proto3" json:"record,omitempty"` +} + +func (x *AddRecordReq) Reset() { + *x = AddRecordReq{} + if protoimpl.UnsafeEnabled { + mi := &file_idl_replay_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AddRecordReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AddRecordReq) ProtoMessage() {} + +func (x *AddRecordReq) ProtoReflect() protoreflect.Message { + mi := &file_idl_replay_proto_msgTypes[1] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AddRecordReq.ProtoReflect.Descriptor instead. 
+func (*AddRecordReq) Descriptor() ([]byte, []int) { + return file_idl_replay_proto_rawDescGZIP(), []int{1} +} + +func (x *AddRecordReq) GetRecord() *Record { + if x != nil { + return x.Record + } + return nil +} + +type AddRecordResp struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Code int64 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` + Msg string `protobuf:"bytes,2,opt,name=msg,proto3" json:"msg,omitempty"` +} + +func (x *AddRecordResp) Reset() { + *x = AddRecordResp{} + if protoimpl.UnsafeEnabled { + mi := &file_idl_replay_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AddRecordResp) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AddRecordResp) ProtoMessage() {} + +func (x *AddRecordResp) ProtoReflect() protoreflect.Message { + mi := &file_idl_replay_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AddRecordResp.ProtoReflect.Descriptor instead. 
+func (*AddRecordResp) Descriptor() ([]byte, []int) { + return file_idl_replay_proto_rawDescGZIP(), []int{2} +} + +func (x *AddRecordResp) GetCode() int64 { + if x != nil { + return x.Code + } + return 0 +} + +func (x *AddRecordResp) GetMsg() string { + if x != nil { + return x.Msg + } + return "" +} + +var File_idl_replay_proto protoreflect.FileDescriptor + +var file_idl_replay_proto_rawDesc = []byte{ + 0x0a, 0x10, 0x69, 0x64, 0x6c, 0x2f, 0x72, 0x65, 0x70, 0x6c, 0x61, 0x79, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x12, 0x06, 0x72, 0x65, 0x70, 0x6c, 0x61, 0x79, 0x22, 0x70, 0x0a, 0x06, 0x52, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x17, 0x0a, + 0x07, 0x6f, 0x6c, 0x64, 0x5f, 0x72, 0x65, 0x71, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x06, + 0x6f, 0x6c, 0x64, 0x52, 0x65, 0x71, 0x12, 0x19, 0x0a, 0x08, 0x6f, 0x6c, 0x64, 0x5f, 0x72, 0x65, + 0x73, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6f, 0x6c, 0x64, 0x52, 0x65, 0x73, + 0x70, 0x12, 0x19, 0x0a, 0x08, 0x6e, 0x65, 0x77, 0x5f, 0x72, 0x65, 0x73, 0x70, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0c, 0x52, 0x07, 0x6e, 0x65, 0x77, 0x52, 0x65, 0x73, 0x70, 0x22, 0x36, 0x0a, 0x0c, + 0x41, 0x64, 0x64, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x65, 0x71, 0x12, 0x26, 0x0a, 0x06, + 0x72, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x72, + 0x65, 0x70, 0x6c, 0x61, 0x79, 0x2e, 0x52, 0x65, 0x63, 0x6f, 0x72, 0x64, 0x52, 0x06, 0x72, 0x65, + 0x63, 0x6f, 0x72, 0x64, 0x22, 0x35, 0x0a, 0x0d, 0x41, 0x64, 0x64, 0x52, 0x65, 0x63, 0x6f, 0x72, + 0x64, 0x52, 0x65, 0x73, 0x70, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, 0x67, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x42, 0x06, 0x5a, 0x04, 0x2e, + 0x2f, 0x70, 0x62, 0x62, 
0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_idl_replay_proto_rawDescOnce sync.Once + file_idl_replay_proto_rawDescData = file_idl_replay_proto_rawDesc +) + +func file_idl_replay_proto_rawDescGZIP() []byte { + file_idl_replay_proto_rawDescOnce.Do(func() { + file_idl_replay_proto_rawDescData = protoimpl.X.CompressGZIP(file_idl_replay_proto_rawDescData) + }) + return file_idl_replay_proto_rawDescData +} + +var file_idl_replay_proto_msgTypes = make([]protoimpl.MessageInfo, 3) +var file_idl_replay_proto_goTypes = []interface{}{ + (*Record)(nil), // 0: replay.Record + (*AddRecordReq)(nil), // 1: replay.AddRecordReq + (*AddRecordResp)(nil), // 2: replay.AddRecordResp +} +var file_idl_replay_proto_depIdxs = []int32{ + 0, // 0: replay.AddRecordReq.record:type_name -> replay.Record + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_idl_replay_proto_init() } +func file_idl_replay_proto_init() { + if File_idl_replay_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_idl_replay_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*Record); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idl_replay_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AddRecordReq); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_idl_replay_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AddRecordResp); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x 
struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_idl_replay_proto_rawDesc, + NumEnums: 0, + NumMessages: 3, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_idl_replay_proto_goTypes, + DependencyIndexes: file_idl_replay_proto_depIdxs, + MessageInfos: file_idl_replay_proto_msgTypes, + }.Build() + File_idl_replay_proto = out.File + file_idl_replay_proto_rawDesc = nil + file_idl_replay_proto_goTypes = nil + file_idl_replay_proto_depIdxs = nil +} diff --git a/stream.go b/stream.go index f2820b1..fff68a5 100644 --- a/stream.go +++ b/stream.go @@ -2,8 +2,10 @@ package main import ( "bufio" + "bytes" + "errors" + "fmt" "io" - "log" "net/http" "net/http/httputil" "sync" @@ -12,6 +14,9 @@ import ( "github.com/google/gopacket/layers" "github.com/google/gopacket/reassembly" "github.com/sirupsen/logrus" + "google.golang.org/protobuf/proto" + + "github.com/bubble-diff/bubblecopy/pb" ) type tcpStream struct { @@ -58,7 +63,8 @@ func (s *tcpStream) ReassembledSG(sg reassembly.ScatterGather, ac reassembly.Ass func (s *tcpStream) ReassemblyComplete(ac reassembly.AssemblerContext) bool { close(s.c2sBuf.bytes) close(s.s2cBuf.bytes) - return true + // false to receive last ack for avoiding New tcpStream. 
+	return false
 }
 
 // consume 消费两个缓存中的数据进行下一步处理
@@ -79,31 +85,81 @@ func handleHttp(c2s, s2c io.Reader) {
 		if err == io.EOF || err == io.ErrUnexpectedEOF {
 			break
 		} else if err != nil {
-			log.Println(err)
+			logrus.Error(err)
 			continue
 		}
 		resp, err := http.ReadResponse(s2cReader, nil)
 		if err == io.EOF || err == io.ErrUnexpectedEOF {
 			break
 		} else if err != nil {
-			log.Println(err)
+			logrus.Error(err)
 			continue
 		}
 
-		// todo: 我们目前只是将http req/resp以日志的形式打印下来
-		// 后期我们在这里需要添加过滤http流量,以及转发至replayer的能力
-		bytes, err := httputil.DumpRequest(req, true)
+		err = sendReqResp(req, resp)
 		if err != nil {
-			log.Println(err)
+			logrus.Errorf("send old req/resp failed, %s", err)
+		} else {
+			logrus.Info("send old req/resp ok")
 		}
-		req.Body.Close()
-		logrus.Debug(string(bytes))
 
-		body, err := io.ReadAll(resp.Body)
-		if err != nil {
-			log.Println(err)
-		}
+		req.Body.Close()
 		resp.Body.Close()
-		logrus.Debug(string(body))
 	}
 }
+
+// sendReqResp serializes a captured request/response pair and posts it to the
+// replay service's add-record API as a protobuf-encoded AddRecordReq.
+// Only the body of resp is forwarded as the old response. It returns an error
+// when serialization, the HTTP call or response decoding fails, or when the
+// service replies with a non-zero code.
+// todo: this function should be protocol-agnostic; it currently takes HTTP types.
+func sendReqResp(req *http.Request, resp *http.Response) (err error) {
+	// Serialize the request (including its body) and drain the response body.
+	rawReq, err := httputil.DumpRequest(req, true)
+	if err != nil {
+		return err
+	}
+	rawResp, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return err
+	}
+
+	// Build and encode the add-record message.
+	request := &pb.AddRecordReq{
+		Record: &pb.Record{
+			TaskId:  configuration.Taskid,
+			OldReq:  rawReq,
+			OldResp: rawResp,
+			NewResp: nil,
+		},
+	}
+	rawpb, err := proto.Marshal(request)
+	if err != nil {
+		return err
+	}
+
+	api := fmt.Sprintf("http://%s%s", configuration.ReplaySvrAddr, ApiAddRecord)
+	apiResp, err := http.Post(api, "application/octet-stream", bytes.NewReader(rawpb))
+	if err != nil {
+		// apiResp is normally nil here, so its Body must not be touched
+		// before this check.
+		return err
+	}
+	defer apiResp.Body.Close()
+
+	// Decode the service reply; a non-zero code carries the failure message.
+	var response pb.AddRecordResp
+	rawApiResp, err := io.ReadAll(apiResp.Body)
+	if err != nil {
+		return err
+	}
+	if err = proto.Unmarshal(rawApiResp, &response); err != nil {
+		return err
+	}
+	if response.Code != 0 {
+		return errors.New(response.Msg)
+	}
+
+	return nil
+}
diff --git a/stream_factory.go b/stream_factory.go
index a8b0864..44564b5 100644
--- a/stream_factory.go
+++ b/stream_factory.go
@@ -16,9 +16,9 @@ type tcpStreamFactory struct {
 func (f *tcpStreamFactory) New(netFlow, tcpFlow gopacket.Flow, tcp *layers.TCP, ac reassembly.AssemblerContext) reassembly.Stream {
 	s := &tcpStream{
 		factoryWg: &f.wg,
-		isDetect: false,
-		c2sBuf:   NewBuffer(),
-		s2cBuf:   NewBuffer(),
+		isDetect:  false,
+		c2sBuf:    NewBuffer(),
+		s2cBuf:    NewBuffer(),
 	}
 	return s
 }