Skip to content

Commit

Permalink
Merge pull request #14 from aberic/master
Browse files Browse the repository at this point in the history
new feature
  • Loading branch information
aberic committed May 13, 2019
2 parents 465d779 + a366c61 commit 466a6cf
Show file tree
Hide file tree
Showing 15 changed files with 164 additions and 98 deletions.
27 changes: 15 additions & 12 deletions bow/bow.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,27 @@ func (s *Bow) register(routeService *RouteService) {
}

// RunBow 开启路由
func RunBow(context *gin.Context, serviceName string) {
routeService, ok := instance.AllWay[serviceName]
if !ok {
err := fmt.Errorf("routeService not fount")
log.Shunt.Error(err.Error(), zap.String("serviceName", serviceName))
context.JSON(http.StatusOK, err.Error())
} else {
request.SyncPoolGetRequest().Call(context, context.Request.Method, routeService.OutRemote, routeService.OutURI)
}
func RunBow(context *gin.Context, serviceName string, filter func(context *gin.Context, result *response.Result) bool) {
RunBowCallback(context, serviceName, filter, nil)
}

// RunBowCallback 开启路由并处理降级
func RunBowCallback(context *gin.Context, serviceName string, f func() *response.Result) {
func RunBowCallback(context *gin.Context, serviceName string, filter func(context *gin.Context, result *response.Result) bool, f func() *response.Result) {
routeService, ok := instance.AllWay[serviceName]
result := response.Result{}
if !ok {
err := fmt.Errorf("service not fount")
err := fmt.Errorf("routeService not fount")
log.Shunt.Error(err.Error(), zap.String("serviceName", serviceName))
context.JSON(http.StatusOK, err.Error())
result.Fail(err.Error())
context.JSON(http.StatusOK, result)
return
}
if !filter(context, &result) {
context.JSON(http.StatusOK, result)
return
}
if nil == f {
request.SyncPoolGetRequest().Call(context, context.Request.Method, routeService.OutRemote, routeService.OutURI)
} else {
request.SyncPoolGetRequest().Callback(context, context.Request.Method, routeService.OutRemote, routeService.OutURI, f)
}
Expand Down
5 changes: 3 additions & 2 deletions bow/bow_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,18 @@
package bow

import (
"github.com/ennoo/rivet/trans/response"
"github.com/gin-gonic/gin"
)

// Route 网关服务路由
func Route(engine *gin.Engine) {
func Route(engine *gin.Engine, filter func(context *gin.Context, result *response.Result) bool) {
// 仓库相关路由设置
vRepo := engine.Group("/")
for x := range routeServices {
bowService := routeServices[x]
vRepo.Any(bowService.InURI, func(context *gin.Context) {
RunBow(context, bowService.Name)
RunBow(context, bowService.Name, filter)
})
}
}
8 changes: 7 additions & 1 deletion examples/bow1/bow1.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@ package main

import (
"github.com/ennoo/rivet/rivet"
"github.com/ennoo/rivet/trans/response"
"github.com/ennoo/rivet/utils/env"
"github.com/gin-gonic/gin"
"strings"
)

func main() {
rivet.Initialize(true, false, true, false)
rivet.Initialize(false, true, false)
rivet.UserBow(func(context *gin.Context, result *response.Result) bool {
result.Fail("test fail")
return false
})
rivet.Bow().AddService("test1", "hello1", "http://localhost:8081", "rivet/shunt")
rivet.Bow().AddService("test2", "hello2", "https://localhost:8092", "rivet/shunt")
rivet.ListenAndServe(&rivet.ListenServe{
Expand Down
7 changes: 6 additions & 1 deletion examples/bow2/bow2.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,17 @@ package main
import (
"github.com/ennoo/rivet/bow"
"github.com/ennoo/rivet/rivet"
"github.com/ennoo/rivet/trans/response"
"github.com/ennoo/rivet/utils/env"
"github.com/gin-gonic/gin"
"strings"
)

func main() {
rivet.Initialize(true, false, true, false)
rivet.Initialize(false, true, false)
rivet.UserBow(func(context *gin.Context, result *response.Result) bool {
return true
})
rivet.Bow().Add(
&bow.RouteService{
Name: "test1",
Expand Down
2 changes: 1 addition & 1 deletion examples/shunt1/shunt1.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
)

func main() {
rivet.Initialize(false, true, true, true)
rivet.Initialize(true, true, true)
//rivet.Log().Conf(&log.Config{
// FilePath: strings.Join([]string{"./logs/rivet.log"}, ""),
// Level: zapcore.DebugLevel,
Expand Down
4 changes: 3 additions & 1 deletion examples/trans1/trans1.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
)

func main() {
rivet.Initialize(false, true, false, false)
rivet.Initialize(true, false, false)
rivet.UseDiscovery(discovery.ComponentConsul, "127.0.0.1:8500", "test", "127.0.0.1", 8081)
rivet.ListenAndServe(&rivet.ListenServe{
Engine: rivet.SetupRouter(testRouter1),
Expand Down Expand Up @@ -64,6 +64,8 @@ func shunt1(context *gin.Context) {
result.SayFail(context, err.Error())
}
test.Name = "trans1"
context.Writer.Header().Add("trans1Token15", "trans1Test15")
context.SetCookie("trans1Token16", "trans1Test16", 10, "/", "localhost", false, true)
result.SaySuccess(context, test)
})
}
Expand Down
2 changes: 1 addition & 1 deletion examples/trans1/trans_tls1.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
)

func main() {
rivet.Initialize(false, true, false, false)
rivet.Initialize(true, false, false)
rivet.ListenAndServeTLS(&rivet.ListenServe{
Engine: rivet.SetupRouter(testRouterTLS1),
DefaultPort: "8091",
Expand Down
2 changes: 1 addition & 1 deletion examples/trans2/trans2.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

func main() {
rivet.Initialize(false, true, false, false)
rivet.Initialize(true, false, false)
rivet.UseDiscovery(discovery.ComponentConsul, "127.0.0.1:8500", "test", "127.0.0.1", 8082)
rivet.ListenAndServe(&rivet.ListenServe{
Engine: rivet.SetupRouter(testRouter2),
Expand Down
2 changes: 1 addition & 1 deletion examples/trans2/trans_tls2.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import (
)

func main() {
rivet.Initialize(false, true, false, false)
rivet.Initialize(true, false, false)

rivet.ListenAndServeTLS(&rivet.ListenServe{
Engine: rivet.SetupRouter(testRouterTLS2),
Expand Down
18 changes: 14 additions & 4 deletions rivet/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"github.com/ennoo/rivet/scheduled"
"github.com/ennoo/rivet/server"
"github.com/ennoo/rivet/trans/request"
"github.com/ennoo/rivet/trans/response"
"github.com/ennoo/rivet/utils/env"
"github.com/ennoo/rivet/utils/log"
"github.com/ennoo/rivet/utils/string"
Expand All @@ -37,6 +38,7 @@ var (
ud = false // 是否启用发现服务
cp string // 启用的发现服务组件类型
sn string // 注册到发现服务的服务名称(优先通过环境变量 SERVICE_NAME 获取)
f func(context *gin.Context, result *response.Result) bool
)

// ListenServe 启动监听端口服务对象
Expand All @@ -49,8 +51,10 @@ type ListenServe struct {
// keepAlive 指定保持活动网络连接的时间,如果为0,则不启用keep-alive,默认30s
KeepAlive time.Duration

// TLS 服务端私钥
CertFile string
KeyFile string
// TLS 服务端的数字证书
KeyFile string
}

// Initialize rivet 初始化方法,必须最先调用
Expand All @@ -62,13 +66,19 @@ type ListenServe struct {
// serverManager:是否开启外界服务管理功能
//
// loadBalance:是否开启负载均衡
func Initialize(bow bool, healthCheck bool, serverManager bool, loadBalance bool) {
func Initialize(healthCheck bool, serverManager bool, loadBalance bool) {
runtime.GOMAXPROCS(runtime.NumCPU())
route = bow
hc = healthCheck
sm = serverManager
request.LB = loadBalance
}

// UserBow 开启网关路由
//
// filter 自定义过滤方案
func UserBow(filter func(context *gin.Context, result *response.Result) bool) {
route = true
f = filter
}

// UseDiscovery 启用指定的发现服务
Expand Down Expand Up @@ -103,7 +113,7 @@ func UseDiscovery(component, url, serviceName, hostname string, port int) {
func SetupRouter(routes ...func(*gin.Engine)) *gin.Engine {
engine := gin.Default()
if route {
bow.Route(engine)
bow.Route(engine, f)
}
if hc {
Health(engine)
Expand Down
103 changes: 73 additions & 30 deletions trans/request/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ package request

import (
"encoding/json"
"errors"
"github.com/ennoo/rivet/server"
"github.com/ennoo/rivet/shunt"
"github.com/ennoo/rivet/trans/response"
"github.com/ennoo/rivet/utils/log"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"io/ioutil"
"net/http"
"net/url"
"strconv"
Expand Down Expand Up @@ -80,29 +82,30 @@ func (request *Request) RestJSON(method string, remote string, uri string, param
Param: param,
}
var body []byte
var resp *http.Response
var err error

switch method {
case http.MethodGet:
body, err = restJSONHandler.Get(DirectJSONRequest)
resp, err = restJSONHandler.Get(DirectJSONRequest)
case http.MethodHead:
body, err = restJSONHandler.Head(DirectJSONRequest)
resp, err = restJSONHandler.Head(DirectJSONRequest)
case http.MethodPost:
body, err = restJSONHandler.Post(DirectJSONRequest)
resp, err = restJSONHandler.Post(DirectJSONRequest)
case http.MethodPut:
body, err = restJSONHandler.Put(DirectJSONRequest)
resp, err = restJSONHandler.Put(DirectJSONRequest)
case http.MethodPatch:
body, err = restJSONHandler.Patch(DirectJSONRequest)
resp, err = restJSONHandler.Patch(DirectJSONRequest)
case http.MethodDelete:
body, err = restJSONHandler.Delete(DirectJSONRequest)
resp, err = restJSONHandler.Delete(DirectJSONRequest)
case http.MethodConnect:
body, err = restJSONHandler.Connect(DirectJSONRequest)
resp, err = restJSONHandler.Connect(DirectJSONRequest)
case http.MethodOptions:
body, err = restJSONHandler.Options(DirectJSONRequest)
resp, err = restJSONHandler.Options(DirectJSONRequest)
case http.MethodTrace:
body, err = restJSONHandler.Trace(DirectJSONRequest)
resp, err = restJSONHandler.Trace(DirectJSONRequest)
}
return body, err
return restDone(body, resp, err)
}

// RestTextByURL TEXT 请求
Expand Down Expand Up @@ -131,27 +134,41 @@ func (request *Request) RestText(method string, remote string, uri string, value
Values: values,
}
var body []byte
var resp *http.Response
var err error

switch method {
case http.MethodGet:
body, err = restTextHandler.Get(DirectTextRequest)
resp, err = restTextHandler.Get(DirectTextRequest)
case http.MethodHead:
body, err = restTextHandler.Head(DirectTextRequest)
resp, err = restTextHandler.Head(DirectTextRequest)
case http.MethodPost:
body, err = restTextHandler.Post(DirectTextRequest)
resp, err = restTextHandler.Post(DirectTextRequest)
case http.MethodPut:
body, err = restTextHandler.Put(DirectTextRequest)
resp, err = restTextHandler.Put(DirectTextRequest)
case http.MethodPatch:
body, err = restTextHandler.Patch(DirectTextRequest)
resp, err = restTextHandler.Patch(DirectTextRequest)
case http.MethodDelete:
body, err = restTextHandler.Delete(DirectTextRequest)
resp, err = restTextHandler.Delete(DirectTextRequest)
case http.MethodConnect:
body, err = restTextHandler.Connect(DirectTextRequest)
resp, err = restTextHandler.Connect(DirectTextRequest)
case http.MethodOptions:
body, err = restTextHandler.Options(DirectTextRequest)
resp, err = restTextHandler.Options(DirectTextRequest)
case http.MethodTrace:
body, err = restTextHandler.Trace(DirectTextRequest)
resp, err = restTextHandler.Trace(DirectTextRequest)
}
return restDone(body, resp, err)
}

func restDone(body []byte, resp *http.Response, err error) ([]byte, error) {
if err != nil {
return nil, err
}
if nil != resp {
defer resp.Body.Close()
body, err = ioutil.ReadAll(resp.Body)
} else {
err = errors.New("response is nil")
}
return body, err
}
Expand Down Expand Up @@ -265,35 +282,53 @@ func (request *Request) callReal(context *gin.Context, method string, remote str
Header: req.Header,
Cookies: cookies}}
var body []byte
var resp *http.Response
var err error

switch method {
case http.MethodGet:
body, err = restTransHandler.Get(TransCallbackRequest)
resp, err = restTransHandler.Get(TransCallbackRequest)
case http.MethodHead:
body, err = restTransHandler.Head(TransCallbackRequest)
resp, err = restTransHandler.Head(TransCallbackRequest)
case http.MethodPost:
body, err = restTransHandler.Post(TransCallbackRequest)
resp, err = restTransHandler.Post(TransCallbackRequest)
case http.MethodPut:
body, err = restTransHandler.Put(TransCallbackRequest)
resp, err = restTransHandler.Put(TransCallbackRequest)
case http.MethodPatch:
body, err = restTransHandler.Patch(TransCallbackRequest)
resp, err = restTransHandler.Patch(TransCallbackRequest)
case http.MethodDelete:
body, err = restTransHandler.Delete(TransCallbackRequest)
resp, err = restTransHandler.Delete(TransCallbackRequest)
case http.MethodConnect:
body, err = restTransHandler.Connect(TransCallbackRequest)
resp, err = restTransHandler.Connect(TransCallbackRequest)
case http.MethodOptions:
body, err = restTransHandler.Options(TransCallbackRequest)
resp, err = restTransHandler.Options(TransCallbackRequest)
case http.MethodTrace:
body, err = restTransHandler.Trace(TransCallbackRequest)
resp, err = restTransHandler.Trace(TransCallbackRequest)
}
request.callDone(context, body, resp, err, callback)
}

func (request *Request) callDone(context *gin.Context, body []byte, resp *http.Response, err error, callback func() *response.Result) {
if err != nil {
request.result.Fail(err.Error())
context.JSON(http.StatusOK, request.result)
return
}

if nil != resp {
defer resp.Body.Close()
bodyRead, err := ioutil.ReadAll(resp.Body)
done(context, resp, request, bodyRead, err, callback)
} else {
request.result.Fail("Response is nil")
context.JSON(http.StatusOK, request.result)
}
done(context, request, body, err, callback)
}

// done 请求转发处理结果
//
// 转发请求或降级回调
func done(context *gin.Context, request *Request, body []byte, err error, callback func() *response.Result) {
func done(context *gin.Context, resp *http.Response, request *Request, body []byte, err error, callback func() *response.Result) {
if err != nil {
request.result.Callback(callback, err)
} else {
Expand All @@ -303,6 +338,14 @@ func done(context *gin.Context, request *Request, body []byte, err error, callba
request.result.Fail(err.Error())
}
}

//for k := range resp.Header {
// context.Writer.Header().Add(k, resp.Header.Get(k))
//}
for index := range resp.Cookies() {
cookie := resp.Cookies()[index]
context.SetCookie(cookie.Name, cookie.Value, cookie.MaxAge, cookie.Path, cookie.Domain, cookie.Secure, cookie.HttpOnly)
}
context.JSON(http.StatusOK, request.result)
}

Expand Down

0 comments on commit 466a6cf

Please sign in to comment.