diff --git a/broker/api/api-handler.go b/broker/api/api-handler.go index 3aeca7d9..c50c0c97 100644 --- a/broker/api/api-handler.go +++ b/broker/api/api-handler.go @@ -11,6 +11,7 @@ import ( "github.com/indexdata/crosslink/broker/adapter" "github.com/indexdata/crosslink/broker/service" + "github.com/indexdata/crosslink/broker/tenant" "github.com/indexdata/crosslink/directory" "github.com/google/uuid" @@ -34,36 +35,23 @@ var LIMIT_DEFAULT int32 = 10 var ARCHIVE_PROCESS_STARTED = "Archive process started" type ApiHandler struct { - limitDefault int32 - eventRepo events.EventRepo - illRepo ill_db.IllRepo - tenant common.Tenant + limitDefault int32 + eventRepo events.EventRepo + illRepo ill_db.IllRepo + tenantContext tenant.TenantContext } -func NewApiHandler(eventRepo events.EventRepo, illRepo ill_db.IllRepo, tenant common.Tenant, limitDefault int32) ApiHandler { +func NewApiHandler(eventRepo events.EventRepo, illRepo ill_db.IllRepo, tenantContext tenant.TenantContext, limitDefault int32) ApiHandler { return ApiHandler{ - eventRepo: eventRepo, - illRepo: illRepo, - tenant: tenant, - limitDefault: limitDefault, + eventRepo: eventRepo, + illRepo: illRepo, + tenantContext: tenantContext, + limitDefault: limitDefault, } } -func (a *ApiHandler) isOwner(trans *ill_db.IllTransaction, tenant *string, requesterSymbol *string) bool { - if tenant == nil && requesterSymbol != nil { - return trans.RequesterSymbol.String == *requesterSymbol - } - if !a.tenant.IsSpecified() { - return true - } - if tenant == nil { - return false - } - return trans.RequesterSymbol.String == a.tenant.GetSymbol(*tenant) -} - func (a *ApiHandler) getIllTranFromParams(ctx common.ExtendedContext, w http.ResponseWriter, - okapiTenant *string, requesterSymbol *string, requesterReqId *oapi.RequesterRequestId, + r *http.Request, requesterSymbol *string, requesterReqId *oapi.RequesterRequestId, illTransactionId *oapi.IllTransactionId) (*ill_db.IllTransaction, error) { var tran ill_db.IllTransaction var err error @@ -86,10 +74,21 @@ func (a *ApiHandler) getIllTranFromParams(ctx common.ExtendedContext, w http.Res addInternalError(ctx, w, err) return nil, err } - if !a.isOwner(&tran, okapiTenant, requesterSymbol) { - return nil, nil + tenant := a.tenantContext.WithRequest(ctx, r, requesterSymbol) + syms, err := tenant.GetSymbols() + if err != nil { + addBadRequestError(ctx, w, err) + return nil, err + } + if syms == nil { + return &tran, nil } - return &tran, nil + for _, s := range syms { + if s == tran.RequesterSymbol.String { + return &tran, nil + } + } + return nil, nil } func (a *ApiHandler) Get(w http.ResponseWriter, r *http.Request) { @@ -109,10 +108,10 @@ func (a *ApiHandler) GetEvents(w http.ResponseWriter, r *http.Request, params oa if params.IllTransactionId != nil { logParams["IllTransactionId"] = *params.IllTransactionId } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: logParams, }) - tran, err := a.getIllTranFromParams(ctx, w, params.XOkapiTenant, params.RequesterSymbol, + tran, err := a.getIllTranFromParams(ctx, w, r, params.RequesterSymbol, params.RequesterReqId, params.IllTransactionId) if err != nil { return @@ -138,7 +137,7 @@ func (a *ApiHandler) GetEvents(w http.ResponseWriter, r *http.Request, params oa } func (a *ApiHandler) GetIllTransactions(w http.ResponseWriter, r *http.Request, params oapi.GetIllTransactionsParams) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "GetIllTransactions"}, }) var resp oapi.IllTransactions @@ -156,7 +155,7 @@ func (a *ApiHandler) GetIllTransactions(w http.ResponseWriter, r *http.Request, } var fullCount int64 if params.RequesterReqId != nil { - tran, err := a.getIllTranFromParams(ctx, w, params.XOkapiTenant, params.RequesterSymbol, + tran, err := a.getIllTranFromParams(ctx, w, r, params.RequesterSymbol, params.RequesterReqId, nil) if err != nil { return @@ -165,49 +164,28 @@ func (a *ApiHandler) GetIllTransactions(w http.ResponseWriter, r *http.Request, fullCount = 1 resp.Items = append(resp.Items, toApiIllTransaction(r, *tran)) } - } else if a.tenant.IsSpecified() { - var symbol string - if params.XOkapiTenant != nil { - symbol = a.tenant.GetSymbol(*params.XOkapiTenant) - } else if params.RequesterSymbol != nil { - symbol = *params.RequesterSymbol - } - if symbol == "" { - writeJsonResponse(w, resp) - return - } - dbparams := ill_db.GetIllTransactionsByRequesterSymbolParams{ - Limit: limit, - Offset: offset, - RequesterSymbol: pgtype.Text{ - String: symbol, - Valid: true, - }, - } - var trans []ill_db.IllTransaction - var err error - trans, fullCount, err = a.illRepo.GetIllTransactionsByRequesterSymbol(ctx, dbparams, cql) - if err != nil { //DB error - addInternalError(ctx, w, err) - return - } - for _, t := range trans { - resp.Items = append(resp.Items, toApiIllTransaction(r, t)) - } } else { - dbparams := ill_db.ListIllTransactionsParams{ - Limit: limit, - Offset: offset, - } - var trans []ill_db.IllTransaction - var err error - trans, fullCount, err = a.illRepo.ListIllTransactions(ctx, dbparams, cql) - if err != nil { //DB error - addInternalError(ctx, w, err) + tenant := a.tenantContext.WithRequest(ctx, r, params.RequesterSymbol) + symbols, err := tenant.GetSymbols() + if err != nil { + addBadRequestError(ctx, w, err) return } - for _, t := range trans { - resp.Items = append(resp.Items, toApiIllTransaction(r, t)) + if symbols == nil || len(symbols) > 0 { + dbparams := ill_db.ListIllTransactionsParams{ + Limit: limit, + Offset: offset, + } + var trans []ill_db.IllTransaction + var err error + trans, fullCount, err = a.illRepo.ListIllTransactions(ctx, dbparams, cql, symbols) + if err != nil { //DB error + addInternalError(ctx, w, err) + return + } + for _, t := range trans { + resp.Items = append(resp.Items, toApiIllTransaction(r, t)) + } } } resp.About = CollectAboutData(fullCount, offset, limit, r) @@ -215,10 +193,10 @@ func (a *ApiHandler) GetIllTransactions(w http.ResponseWriter, r *http.Request, } func (a *ApiHandler) GetIllTransactionsId(w http.ResponseWriter, r *http.Request, id string, params oapi.GetIllTransactionsIdParams) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "GetIllTransactionsId", "id": id}, }) - tran, err := a.getIllTranFromParams(ctx, w, params.XOkapiTenant, params.RequesterSymbol, + tran, err := a.getIllTranFromParams(ctx, w, r, params.RequesterSymbol, nil, &id) if err != nil { return @@ -231,7 +209,7 @@ func (a *ApiHandler) GetIllTransactionsId(w http.ResponseWriter, r *http.Request } func (a *ApiHandler) DeleteIllTransactionsId(w http.ResponseWriter, r *http.Request, id string) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "DeleteIllTransactionsId", "id": id}, }) trans, err := a.illRepo.GetIllTransactionById(ctx, id) @@ -269,7 +247,7 @@ func (a *ApiHandler) returnHttpError(ctx common.ExtendedContext, w http.Response } func (a *ApiHandler) GetPeers(w http.ResponseWriter, r *http.Request, params oapi.GetPeersParams) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "GetPeers"}, }) dbparams := ill_db.ListPeersParams{ @@ -308,7 +286,7 @@ func (a *ApiHandler) GetPeers(w http.ResponseWriter, r *http.Request, params oap } func (a *ApiHandler) PostPeers(w http.ResponseWriter, r *http.Request) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "PostPeers"}, }) var newPeer oapi.Peer @@ -376,7 +354,7 @@ func (a *ApiHandler) PostPeers(w http.ResponseWriter, r *http.Request) { } func (a *ApiHandler) DeletePeersId(w http.ResponseWriter, r *http.Request, id string) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "DeletePeersSymbol", "id": id}, }) err := a.illRepo.WithTxFunc(ctx, func(repo ill_db.IllRepo) error { @@ -434,7 +412,7 @@ func (a *ApiHandler) DeletePeersId(w http.ResponseWriter, r *http.Request, id st } func (a *ApiHandler) GetPeersId(w http.ResponseWriter, r *http.Request, id string) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "GetPeersSymbol", "id": id}, }) peer, err := a.illRepo.GetPeerById(ctx, id) @@ -462,7 +440,7 @@ func (a *ApiHandler) GetPeersId(w http.ResponseWriter, r *http.Request, id strin } func (a *ApiHandler) PutPeersId(w http.ResponseWriter, r *http.Request, id string) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "PutPeersSymbol", "id": id}, }) peer, err := a.illRepo.GetPeerById(ctx, id) @@ -568,10 +546,10 @@ func (a *ApiHandler) GetLocatedSuppliers(w http.ResponseWriter, r *http.Request, if params.IllTransactionId != nil { logParams["IllTransactionId"] = *params.IllTransactionId } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: logParams, }) - tran, err := a.getIllTranFromParams(ctx, w, params.XOkapiTenant, params.RequesterSymbol, + tran, err := a.getIllTranFromParams(ctx, w, r, params.RequesterSymbol, params.RequesterReqId, params.IllTransactionId) if err != nil { return @@ -598,6 +576,7 @@ func (a *ApiHandler) GetLocatedSuppliers(w http.ResponseWriter, r *http.Request, func (a *ApiHandler) PostArchiveIllTransactions(w http.ResponseWriter, r *http.Request, params oapi.PostArchiveIllTransactionsParams) { logParams := map[string]string{"method": "PostArchiveIllTransactions", "ArchiveDelay": params.ArchiveDelay, "ArchiveStatus": params.ArchiveStatus} + // a background process so use background context instead of request context to avoid cancellation when request is finished ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ Other: logParams, }) diff --git a/broker/api/common.go b/broker/api/common.go index 2cacdacc..c07bfc36 100644 --- a/broker/api/common.go +++ b/broker/api/common.go @@ -1,13 +1,11 @@ package api import ( - "errors" "net/http" "net/url" "strconv" "strings" - "github.com/indexdata/crosslink/broker/common" "github.com/indexdata/crosslink/broker/oapi" ) @@ -134,23 +132,3 @@ func CollectAboutData(fullCount int64, offset int32, limit int32, r *http.Reques } return about } - -func GetSymbolForRequest(r *http.Request, tenantResolver common.Tenant, tenant *string, symbol *string) (string, error) { - if IsBrokerRequest(r) { - if tenantResolver.IsSpecified() { - if tenant == nil { - return "", errors.New("X-Okapi-Tenant must be specified") - } else { - return tenantResolver.GetSymbol(*tenant), nil - } - } else { - return "", errors.New("tenant mapping must be specified") - } - } else { - if symbol == nil || *symbol == "" { - return "", errors.New("symbol must be specified") - } else { - return *symbol, nil - } - } -} diff --git a/broker/api/common_test.go b/broker/api/common_test.go index 8f325dc8..c95ff56b 100644 --- a/broker/api/common_test.go +++ b/broker/api/common_test.go @@ -1,82 +1,12 @@ package api import ( - "net/http" "net/http/httptest" - "strings" "testing" - "github.com/indexdata/crosslink/broker/common" "github.com/stretchr/testify/assert" ) -func TestGetSymbolForRequest(t *testing.T) { - req, _ := http.NewRequest("GET", "/broker/patron_request", strings.NewReader("{")) - req.RequestURI = "/broker/patron_request" - tenant := "req" - resolved, err := GetSymbolForRequest(req, common.NewTenant("ISIL:{tenant}"), &tenant, nil) - assert.NoError(t, err) - assert.Equal(t, "ISIL:REQ", resolved) - - resolved, err = GetSymbolForRequest(req, common.NewTenant("ISIL:{tenant}"), nil, nil) - assert.Equal(t, "X-Okapi-Tenant must be specified", err.Error()) - assert.Equal(t, "", resolved) - - resolved, err = GetSymbolForRequest(req, common.NewTenant(""), &tenant, nil) - assert.Equal(t, "tenant mapping must be specified", err.Error()) - assert.Equal(t, "", resolved) -} - -func TestWithBrokerPrefix(t *testing.T) { - brokerReq, _ := http.NewRequest("GET", "/broker/patron_request", strings.NewReader("{")) - brokerReq.RequestURI = "/broker/patron_request" - assert.True(t, IsBrokerRequest(brokerReq)) - assert.Equal(t, "/broker/patron_requests/1", WithBrokerPrefix(brokerReq, "/patron_requests/1")) - assert.Equal(t, "/broker/", WithBrokerPrefix(brokerReq, "")) - assert.Equal(t, "/broker/", WithBrokerPrefix(brokerReq, "/")) - - regularReq, _ := http.NewRequest("GET", "/patron_request", strings.NewReader("{")) - regularReq.RequestURI = "/patron_request" - assert.False(t, IsBrokerRequest(regularReq)) - assert.Equal(t, "/patron_requests/1", WithBrokerPrefix(regularReq, "/patron_requests/1")) - assert.Equal(t, "/", WithBrokerPrefix(regularReq, "")) - assert.Equal(t, "/", WithBrokerPrefix(regularReq, "/")) -} - -func TestPathAndQuery(t *testing.T) { - assert.Equal(t, "/patron_requests/1/items", Path("patron_requests", "1", "items")) - assert.Equal(t, "/patron_requests/1/items", Path("/patron_requests/", "/1/", "/items/")) - assert.Equal(t, "/patron%20requests/a%2Fb%3Fx/items", Path("patron requests", "a/b?x", "items")) - - values := Query("symbol", "ISIL:REQ", "offset", "10", "dangling") - assert.Equal(t, "ISIL:REQ", values.Get("symbol")) - assert.Equal(t, "10", values.Get("offset")) - assert.Empty(t, values.Get("dangling")) -} - -func TestLink(t *testing.T) { - regularReq := httptest.NewRequest("GET", "https://example.org/patron_requests", nil) - regularReq.RequestURI = "/patron_requests" - link := Link(regularReq, Path("patron_requests", "1", "items"), Query("symbol", "ISIL:REQ", "q", "a b")) - assert.Equal(t, "https://example.org/patron_requests/1/items?q=a+b&symbol=ISIL%3AREQ", link) - - brokerReq := httptest.NewRequest("GET", "https://example.org/broker/patron_requests", nil) - brokerReq.RequestURI = "/broker/patron_requests" - brokerLink := Link(brokerReq, Path("patron_requests", "1", "items"), Query("symbol", "ISIL:REQ")) - assert.Equal(t, "https://example.org/broker/patron_requests/1/items?symbol=ISIL%3AREQ", brokerLink) -} - -func TestLinkRel(t *testing.T) { - req := httptest.NewRequest("GET", "https://example.org/patron_requests/1", nil) - req.RequestURI = "/patron_requests/1" - - currentLink := LinkRel(req, "", Query("symbol", "ISIL:REQ")) - assert.Equal(t, "https://example.org/patron_requests/1?symbol=ISIL%3AREQ", currentLink) - - relativeLink := LinkRel(req, "items", Query("symbol", "ISIL:REQ")) - assert.Equal(t, "https://example.org/patron_requests/1/items?symbol=ISIL%3AREQ", relativeLink) -} - func TestCollectAboutDataLastLink(t *testing.T) { reqOffset0 := httptest.NewRequest("GET", "http://localhost/ill_transactions?symbol=ISIL:DK-BIB1&offset=0", nil) reqOffset10 := httptest.NewRequest("GET", "http://localhost/ill_transactions?symbol=ISIL:DK-BIB1&offset=10", nil) diff --git a/broker/api/sse_broker.go b/broker/api/sse_broker.go index 7b084f10..e6cacad9 100644 --- a/broker/api/sse_broker.go +++ b/broker/api/sse_broker.go @@ -10,23 +10,24 @@ import ( "github.com/indexdata/crosslink/broker/events" pr_db "github.com/indexdata/crosslink/broker/patron_request/db" prservice "github.com/indexdata/crosslink/broker/patron_request/service" + "github.com/indexdata/crosslink/broker/tenant" "github.com/indexdata/crosslink/iso18626" ) type SseBroker struct { - input chan SseMessage - clients map[string]map[chan string]bool - mu sync.Mutex - ctx common.ExtendedContext - tenant common.Tenant + input chan SseMessage + clients map[string]map[chan string]bool + mu sync.Mutex + ctx common.ExtendedContext + tenantContext tenant.TenantContext } -func NewSseBroker(ctx common.ExtendedContext, tenant common.Tenant) (broker *SseBroker) { +func NewSseBroker(ctx common.ExtendedContext, tenantContext tenant.TenantContext) (broker *SseBroker) { broker = &SseBroker{ - input: make(chan SseMessage), - clients: make(map[string]map[chan string]bool), - ctx: ctx, - tenant: tenant, + input: make(chan SseMessage), + clients: make(map[string]map[chan string]bool), + ctx: ctx, + tenantContext: tenantContext, } // Start the single broadcaster goroutine @@ -67,24 +68,26 @@ func (b *SseBroker) removeClient(receiver string, clientChannel chan string) { // ServeHTTP implements the http.Handler interface for the SSE endpoint. func (b *SseBroker) ServeHTTP(w http.ResponseWriter, r *http.Request) { - clientChannel := make(chan string, 10) - tenant := r.Header.Get("X-Okapi-Tenant") - var symbol string - if b.tenant.IsSpecified() && tenant != "" { - symbol = b.tenant.GetSymbol(tenant) - } else { - symbol = r.URL.Query().Get("symbol") - } + logParams := map[string]string{"method": "ServeHTTP"} + ectx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + + suppliedSymbol := r.URL.Query().Get("symbol") + symbol, err := b.tenantContext.WithRequest(ectx, r, &suppliedSymbol).GetSymbol() + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } side := r.URL.Query().Get("side") - if side == "" || symbol == "" { - http.Error(w, "query parameter 'side' and 'symbol' must be specified", http.StatusBadRequest) + if side == "" { + http.Error(w, "query parameter 'side' must be specified", http.StatusBadRequest) return } if side != string(prservice.SideBorrowing) && side != string(prservice.SideLending) { http.Error(w, fmt.Sprintf("query parameter 'side' must be %s or %s", prservice.SideBorrowing, prservice.SideLending), http.StatusBadRequest) return } + clientChannel := make(chan string, 10) b.mu.Lock() receiver := side + symbol clients := b.clients[receiver] diff --git a/broker/app/app.go b/broker/app/app.go index 6f968742..74350d5e 100644 --- a/broker/app/app.go +++ b/broker/app/app.go @@ -17,6 +17,7 @@ import ( pr_db "github.com/indexdata/crosslink/broker/patron_request/db" "github.com/indexdata/crosslink/broker/patron_request/proapi" prservice "github.com/indexdata/crosslink/broker/patron_request/service" + "github.com/indexdata/crosslink/broker/tenant" "github.com/dustin/go-humanize" "github.com/indexdata/crosslink/broker/adapter" @@ -166,11 +167,12 @@ func Init(ctx context.Context) (Context, error) { iso18626Client := client.CreateIso18626Client(eventBus, illRepo, prMessageHandler, MAX_MESSAGE_SIZE, delay) supplierLocator := service.CreateSupplierLocator(eventBus, illRepo, dirAdapter, holdingsAdapter) workflowManager := service.CreateWorkflowManager(eventBus, illRepo, service.WorkflowConfig{}) - prApiHandler := prapi.NewPrApiHandler(prRepo, eventBus, eventRepo, common.NewTenant(TENANT_TO_SYMBOL), &iso18626Handler, API_PAGE_SIZE) + tenantContext := tenant.NewContext().WithIllRepo(illRepo).WithLookupAdapter(dirAdapter).WithTenantSymbol(TENANT_TO_SYMBOL) + prApiHandler := prapi.NewPrApiHandler(prRepo, eventBus, eventRepo, *tenantContext, &iso18626Handler, API_PAGE_SIZE) prApiHandler.SetAutoActionRunner(prActionService) prApiHandler.SetActionTaskProcessor(prActionService) - sseBroker := api.NewSseBroker(appCtx, common.NewTenant(TENANT_TO_SYMBOL)) + sseBroker := api.NewSseBroker(appCtx, *tenantContext) AddDefaultHandlers(eventBus, iso18626Client, supplierLocator, workflowManager, iso18626Handler, *prActionService, prApiHandler, sseBroker) err = StartEventBus(ctx, eventBus) @@ -209,11 +211,16 @@ func StartServer(ctx Context) error { _, _ = w.Write(oapi.OpenAPISpecYAML) }) - apiHandler := api.NewApiHandler(ctx.EventRepo, ctx.IllRepo, common.NewTenant(""), API_PAGE_SIZE) + tenantContext := tenant.NewContext().WithIllRepo(ctx.IllRepo).WithLookupAdapter(ctx.DirAdapter) + + apiHandler := api.NewApiHandler(ctx.EventRepo, ctx.IllRepo, *tenantContext, API_PAGE_SIZE) oapi.HandlerFromMux(&apiHandler, ServeMux) proapi.HandlerFromMux(&ctx.PrApiHandler, ServeMux) if TENANT_TO_SYMBOL != "" { - apiHandler := api.NewApiHandler(ctx.EventRepo, ctx.IllRepo, common.NewTenant(TENANT_TO_SYMBOL), API_PAGE_SIZE) + tenantContext = tenant.NewContext().WithIllRepo(ctx.IllRepo).WithLookupAdapter(ctx.DirAdapter).WithTenantSymbol(TENANT_TO_SYMBOL) + ServeMux.HandleFunc("/broker/sse/events", ctx.SseBroker.ServeHTTP) + + apiHandler := api.NewApiHandler(ctx.EventRepo, ctx.IllRepo, *tenantContext, API_PAGE_SIZE) oapi.HandlerFromMuxWithBaseURL(&apiHandler, ServeMux, "/broker") proapi.HandlerFromMuxWithBaseURL(&ctx.PrApiHandler, ServeMux, "/broker") } diff --git a/broker/common/tenant.go b/broker/common/tenant.go deleted file mode 100644 index 9e17acc9..00000000 --- a/broker/common/tenant.go +++ /dev/null @@ -1,21 +0,0 @@ -package common - -import ( - "strings" -) - -type Tenant struct { - mapping string -} - -func NewTenant(tenantSymbol string) Tenant { - return Tenant{mapping: tenantSymbol} -} - -func (t *Tenant) IsSpecified() bool { - return t.mapping != "" -} - -func (t *Tenant) GetSymbol(tenant string) string { - return strings.ReplaceAll(t.mapping, "{tenant}", strings.ToUpper(tenant)) -} diff --git a/broker/ill_db/ill_cql.go b/broker/ill_db/ill_cql.go index 3976cef4..aeaf1403 100644 --- a/broker/ill_db/ill_cql.go +++ b/broker/ill_db/ill_cql.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/indexdata/cql-go/cql" + "github.com/indexdata/cql-go/cqlbuilder" "github.com/indexdata/cql-go/pgcql" ) @@ -55,12 +56,34 @@ func handlePeersQuery(cqlString string, noBaseArgs int) (pgcql.Query, error) { } func (q *Queries) ListIllTransactionsCql(ctx context.Context, db DBTX, arg ListIllTransactionsParams, - cqlString *string) ([]ListIllTransactionsRow, error) { - if cqlString == nil { + cqlString *string, symbols []string) ([]ListIllTransactionsRow, error) { + var cql strings.Builder + for _, symbol := range symbols { + if cql.Len() > 0 { + cql.WriteString(" OR ") + } else { + cql.WriteString("(") + } + comp, err := cqlbuilder.NewQuery().Search("requester_symbol").Term(symbol).Build() + if err != nil { + return nil, fmt.Errorf("failed to build CQL query: %w", err) + } + cql.WriteString(comp.String()) + } + if cql.Len() > 0 { + cql.WriteString(")") + } + if cqlString != nil { + if cql.Len() > 0 { + cql.WriteString(" AND ") + } + cql.WriteString("(" + *cqlString + ")") + } + if cql.Len() == 0 { return q.ListIllTransactions(ctx, db, arg) } - noBaseArgs := 2 // weh have two base arguments: limit and offset - res, err := handleIllTransactionsQuery(*cqlString, noBaseArgs) + noBaseArgs := 2 // we have two base arguments: limit and offset + res, err := handleIllTransactionsQuery(cql.String(), noBaseArgs) if err != nil { return nil, err } @@ -111,64 +134,6 @@ func (q *Queries) ListIllTransactionsCql(ctx context.Context, db DBTX, arg ListI return items, nil } -func (q *Queries) GetIllTransactionsByRequesterSymbolCql(ctx context.Context, db DBTX, arg GetIllTransactionsByRequesterSymbolParams, - cqlString *string) ([]GetIllTransactionsByRequesterSymbolRow, error) { - if cqlString == nil { - return q.GetIllTransactionsByRequesterSymbol(ctx, db, arg) - } - noBaseArgs := 3 // we have three base arguments: requester_symbol, limit and offset - res, err := handleIllTransactionsQuery(*cqlString, noBaseArgs) - if err != nil { - return nil, err - } - whereClause := "" - if res.GetWhereClause() != "" { - whereClause = "AND (" + res.GetWhereClause() + ") " - } - orgSql := getIllTransactionsByRequesterSymbol - pos := strings.Index(orgSql, "ORDER BY") - if pos == -1 { - return nil, fmt.Errorf("CQL query must contain an ORDER BY clause") - } - sql := orgSql[:pos] + whereClause + orgSql[pos:] - sqlArguments := make([]interface{}, 0, noBaseArgs+len(res.GetQueryArguments())) - sqlArguments = append(sqlArguments, arg.RequesterSymbol, arg.Limit, arg.Offset) - sqlArguments = append(sqlArguments, res.GetQueryArguments()...) - - rows, err := db.Query(ctx, sql, sqlArguments...) - if err != nil { - return nil, err - } - defer rows.Close() - var items []GetIllTransactionsByRequesterSymbolRow - for rows.Next() { - var i GetIllTransactionsByRequesterSymbolRow - if err := rows.Scan( - &i.IllTransaction.ID, - &i.IllTransaction.Timestamp, - &i.IllTransaction.RequesterSymbol, - &i.IllTransaction.RequesterID, - &i.IllTransaction.LastRequesterAction, - &i.IllTransaction.PrevRequesterAction, - &i.IllTransaction.SupplierSymbol, - &i.IllTransaction.RequesterRequestID, - &i.IllTransaction.PrevRequesterRequestID, - &i.IllTransaction.SupplierRequestID, - &i.IllTransaction.LastSupplierStatus, - &i.IllTransaction.PrevSupplierStatus, - &i.IllTransaction.IllTransactionData, - &i.FullCount, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - func (q *Queries) ListPeersCql(ctx context.Context, db DBTX, arg ListPeersParams, cqlString *string) ([]ListPeersRow, error) { if cqlString == nil { return q.ListPeers(ctx, db, arg) diff --git a/broker/ill_db/illrepo.go b/broker/ill_db/illrepo.go index 797f7bc3..5f1e9b67 100644 --- a/broker/ill_db/illrepo.go +++ b/broker/ill_db/illrepo.go @@ -22,8 +22,7 @@ type IllRepo interface { GetIllTransactionByRequesterRequestIdForUpdate(ctx common.ExtendedContext, requesterRequestID pgtype.Text) (IllTransaction, error) GetIllTransactionById(ctx common.ExtendedContext, id string) (IllTransaction, error) GetIllTransactionByIdForUpdate(ctx common.ExtendedContext, id string) (IllTransaction, error) - ListIllTransactions(ctx common.ExtendedContext, params ListIllTransactionsParams, cql *string) ([]IllTransaction, int64, error) - GetIllTransactionsByRequesterSymbol(ctx common.ExtendedContext, params GetIllTransactionsByRequesterSymbolParams, cql *string) ([]IllTransaction, int64, error) + ListIllTransactions(ctx common.ExtendedContext, params ListIllTransactionsParams, cql *string, symbols []string) ([]IllTransaction, int64, error) DeleteIllTransaction(ctx common.ExtendedContext, id string) error SavePeer(ctx common.ExtendedContext, params SavePeerParams) (Peer, error) GetPeerById(ctx common.ExtendedContext, id string) (Peer, error) @@ -101,8 +100,8 @@ func (r *PgIllRepo) GetIllTransactionByIdForUpdate(ctx common.ExtendedContext, i return row.IllTransaction, err } -func (r *PgIllRepo) ListIllTransactions(ctx common.ExtendedContext, params ListIllTransactionsParams, cql *string) ([]IllTransaction, int64, error) { - rows, err := r.queries.ListIllTransactionsCql(ctx, r.GetConnOrTx(), params, cql) +func (r *PgIllRepo) ListIllTransactions(ctx common.ExtendedContext, params ListIllTransactionsParams, cql *string, symbols []string) ([]IllTransaction, int64, error) { + rows, err := r.queries.ListIllTransactionsCql(ctx, r.GetConnOrTx(), params, cql, symbols) var transactions []IllTransaction var fullCount int64 if err == nil { @@ -115,7 +114,7 @@ func (r *PgIllRepo) ListIllTransactions(ctx common.ExtendedContext, params ListI } else { params.Limit = 1 params.Offset = 0 - rows, err = r.queries.ListIllTransactionsCql(ctx, r.GetConnOrTx(), params, cql) + rows, err = r.queries.ListIllTransactionsCql(ctx, r.GetConnOrTx(), params, cql, symbols) if err == nil && len(rows) > 0 { fullCount = rows[0].FullCount } @@ -124,19 +123,6 @@ func (r *PgIllRepo) ListIllTransactions(ctx common.ExtendedContext, params ListI return transactions, fullCount, err } -func (r *PgIllRepo) GetIllTransactionsByRequesterSymbol(ctx common.ExtendedContext, params GetIllTransactionsByRequesterSymbolParams, cql *string) ([]IllTransaction, int64, error) { - rows, err := r.queries.GetIllTransactionsByRequesterSymbolCql(ctx, r.GetConnOrTx(), params, cql) - var transactions []IllTransaction - var fullCount int64 - if err == nil { - for _, r := range rows { - fullCount = r.FullCount - transactions = append(transactions, r.IllTransaction) - } - } - return transactions, fullCount, err -} - func (r *PgIllRepo) DeleteIllTransaction(ctx common.ExtendedContext, id string) error { return r.queries.DeleteIllTransaction(ctx, r.GetConnOrTx(), id) } diff --git a/broker/patron_request/api/api-handler.go b/broker/patron_request/api/api-handler.go index 13668e04..4b538b47 100644 --- a/broker/patron_request/api/api-handler.go +++ b/broker/patron_request/api/api-handler.go @@ -1,7 +1,6 @@ package prapi import ( - "context" "encoding/json" "errors" "fmt" @@ -21,6 +20,7 @@ import ( pr_db "github.com/indexdata/crosslink/broker/patron_request/db" "github.com/indexdata/crosslink/broker/patron_request/proapi" prservice "github.com/indexdata/crosslink/broker/patron_request/service" + "github.com/indexdata/crosslink/broker/tenant" "github.com/indexdata/crosslink/iso18626" "github.com/indexdata/go-utils/utils" "github.com/jackc/pgerrcode" @@ -45,19 +45,19 @@ type PatronRequestApiHandler struct { actionMappingService prservice.ActionMappingService autoActionRunner prservice.AutoActionRunner actionTaskProcessor ActionTaskProcessor - tenant common.Tenant + tenantContext tenant.TenantContext notificationSender prservice.PatronRequestNotificationService } func NewPrApiHandler(prRepo pr_db.PrRepo, eventBus events.EventBus, - eventRepo events.EventRepo, tenant common.Tenant, iso18626Handler handler.Iso18626HandlerInterface, limitDefault int32) PatronRequestApiHandler { + eventRepo events.EventRepo, tenantContext tenant.TenantContext, iso18626Handler handler.Iso18626HandlerInterface, limitDefault int32) PatronRequestApiHandler { return PatronRequestApiHandler{ limitDefault: limitDefault, prRepo: prRepo, eventBus: eventBus, eventRepo: eventRepo, actionMappingService: prservice.ActionMappingService{SMService: &prservice.StateModelService{}}, - tenant: tenant, + tenantContext: tenantContext, notificationSender: *prservice.CreatePatronRequestNotificationService(prRepo, eventBus, iso18626Handler), } } @@ -73,7 +73,7 @@ func (a *PatronRequestApiHandler) SetActionTaskProcessor(actionTaskProcessor Act func (a *PatronRequestApiHandler) GetStateModelModelsModel(w http.ResponseWriter, r *http.Request, model string, params proapi.GetStateModelModelsModelParams) { stateModel, err := a.actionMappingService.GetStateModel(model) if err != nil { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{ Other: map[string]string{"method": "GetStateModelModelsModel", "model": model}, }) addInternalError(ctx, w, err) @@ -91,18 +91,19 @@ func (a *PatronRequestApiHandler) GetStateModelCapabilities(w http.ResponseWrite } func (a *PatronRequestApiHandler) GetPatronRequests(w http.ResponseWriter, r *http.Request, params proapi.GetPatronRequestsParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "GetPatronRequests", "symbol": symbol} + logParams := map[string]string{"method": "GetPatronRequests"} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ - Other: logParams, - }) + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + limit := a.limitDefault if params.Limit != nil { limit = *params.Limit @@ -193,20 +194,21 @@ func addOwnerRestriction(queryBuilder *cqlbuilder.QueryBuilder, symbol string, s } func (a *PatronRequestApiHandler) PostPatronRequests(w http.ResponseWriter, r *http.Request, params proapi.PostPatronRequestsParams) { - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{ - Other: map[string]string{"method": "PostPatronRequests"}, - }) + logParams := map[string]string{"method": "PostPatronRequests"} + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) var newPr proapi.CreatePatronRequest err := json.NewDecoder(r.Body).Decode(&newPr) if err != nil { addBadRequestError(ctx, w, err) return } - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, newPr.RequesterSymbol) + symbol, err := a.tenantContext.WithRequest(ctx, r, newPr.RequesterSymbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) newPr.RequesterSymbol = &symbol creationTime := pgtype.Timestamp{Valid: true, Time: time.Now()} illRequest, requesterReqId, err := a.parseAndValidateIllRequest(ctx, &newPr, creationTime.Time) @@ -248,17 +250,18 @@ func (a *PatronRequestApiHandler) PostPatronRequests(w http.ResponseWriter, r *h } func (a *PatronRequestApiHandler) DeletePatronRequestsId(w http.ResponseWriter, r *http.Request, id string, params proapi.DeletePatronRequestsIdParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "DeletePatronRequestsId", "id": id, "symbol": symbol} + logParams := map[string]string{"method": "DeletePatronRequestsId", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) - + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) pr := a.getPatronRequestById(w, ctx, id, params.Side, symbol) if pr == nil { return @@ -300,17 +303,18 @@ func isOwner(pr pr_db.PatronRequest, symbol string) bool { } func (a *PatronRequestApiHandler) GetPatronRequestsId(w http.ResponseWriter, r *http.Request, id string, params proapi.GetPatronRequestsIdParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "GetPatronRequestsId", "id": id, "symbol": symbol} + logParams := map[string]string{"method": "GetPatronRequestsId", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) - + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) pr := a.getPatronRequestById(w, ctx, id, params.Side, symbol) if pr == nil { return @@ -319,17 +323,19 @@ func (a *PatronRequestApiHandler) GetPatronRequestsId(w http.ResponseWriter, r * } func (a *PatronRequestApiHandler) GetPatronRequestsIdActions(w http.ResponseWriter, r *http.Request, id string, params proapi.GetPatronRequestsIdActionsParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "GetPatronRequestsIdActions", "id": id, "symbol": symbol} + logParams := map[string]string{"method": "GetPatronRequestsIdActions", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) pr := a.getPatronRequestById(w, ctx, id, params.Side, symbol) if pr == nil { return @@ -344,17 +350,19 @@ func (a *PatronRequestApiHandler) GetPatronRequestsIdActions(w http.ResponseWrit } func (a *PatronRequestApiHandler) PostPatronRequestsIdAction(w http.ResponseWriter, r *http.Request, id string, params proapi.PostPatronRequestsIdActionParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "PostPatronRequestsIdAction", "id": id, "symbol": symbol} + logParams := map[string]string{"method": "PostPatronRequestsIdAction", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) pr := a.getPatronRequestById(w, ctx, id, params.Side, symbol) if pr == nil { return @@ -419,18 +427,19 @@ func (a *PatronRequestApiHandler) PostPatronRequestsIdAction(w http.ResponseWrit } func (a *PatronRequestApiHandler) GetPatronRequestsIdEvents(w http.ResponseWriter, r *http.Request, id string, params proapi.GetPatronRequestsIdEventsParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "GetPatronRequestsIdEvents", "id": id, "symbol": symbol} - + logParams := map[string]string{"method": "GetPatronRequestsIdEvents", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) pr := a.getPatronRequestById(w, ctx, id, params.Side, symbol) if pr == nil { return @@ -451,18 +460,19 @@ func (a *PatronRequestApiHandler) GetPatronRequestsIdEvents(w http.ResponseWrite } func (a *PatronRequestApiHandler) GetPatronRequestsIdItems(w http.ResponseWriter, r *http.Request, id string, params proapi.GetPatronRequestsIdItemsParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "GetPatronRequestsIdItems", "id": id, "symbol": symbol} - + logParams := map[string]string{"method": "GetPatronRequestsIdItems", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) pr := a.getPatronRequestById(w, ctx, id, params.Side, symbol) if pr == nil { return @@ -483,18 +493,19 @@ func (a *PatronRequestApiHandler) GetPatronRequestsIdItems(w http.ResponseWriter } func (a *PatronRequestApiHandler) GetPatronRequestsIdNotifications(w http.ResponseWriter, r *http.Request, id string, params proapi.GetPatronRequestsIdNotificationsParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "GetPatronRequestsIdNotifications", "id": id, "symbol": symbol} - + logParams := map[string]string{"method": "GetPatronRequestsIdNotifications", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) limit := a.limitDefault if params.Limit != nil { limit = *params.Limit @@ -533,19 +544,18 @@ func (a *PatronRequestApiHandler) GetPatronRequestsIdNotifications(w http.Respon } func (a *PatronRequestApiHandler) PostPatronRequestsIdNotifications(w http.ResponseWriter, r *http.Request, id string, params proapi.PostPatronRequestsIdNotificationsParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "PostPatronRequestsIdNotifications", "id": id, "symbol": symbol} - + logParams := map[string]string{"method": "PostPatronRequestsIdNotifications", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) - + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } - + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) if r.Body == nil { addBadRequestError(ctx, w, errors.New("body is required")) return @@ -591,19 +601,18 @@ func (a *PatronRequestApiHandler) PostPatronRequestsIdNotifications(w http.Respo } func (a *PatronRequestApiHandler) PutPatronRequestsIdNotificationsNotificationIdReceipt(w http.ResponseWriter, r *http.Request, id string, notificationId string, params proapi.PutPatronRequestsIdNotificationsNotificationIdReceiptParams) { - symbol, err := api.GetSymbolForRequest(r, a.tenant, params.XOkapiTenant, params.Symbol) - logParams := map[string]string{"method": "PutPatronRequestsIdNotificationsNotificationIdReceipt", "id": id, "symbol": symbol} - + logParams := map[string]string{"method": "PutPatronRequestsIdNotificationsNotificationIdReceipt", "id": id} if params.Side != nil { logParams["side"] = *params.Side } - ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{Other: logParams}) - + ctx := common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) + symbol, err := a.tenantContext.WithRequest(ctx, r, params.Symbol).GetSymbol() if err != nil { addBadRequestError(ctx, w, err) return } - + logParams["symbol"] = symbol + ctx = common.CreateExtCtxWithArgs(r.Context(), &common.LoggerArgs{Other: logParams}) if r.Body == nil { addBadRequestError(ctx, w, errors.New("body is required")) return diff --git a/broker/patron_request/api/api-handler_test.go b/broker/patron_request/api/api-handler_test.go index bc11ada2..3ce0184b 100644 --- a/broker/patron_request/api/api-handler_test.go +++ b/broker/patron_request/api/api-handler_test.go @@ -19,6 +19,7 @@ import ( pr_db "github.com/indexdata/crosslink/broker/patron_request/db" "github.com/indexdata/crosslink/broker/patron_request/proapi" prservice "github.com/indexdata/crosslink/broker/patron_request/service" + "github.com/indexdata/crosslink/broker/tenant" "github.com/indexdata/crosslink/broker/test/mocks" "github.com/indexdata/crosslink/iso18626" "github.com/jackc/pgx/v5" @@ -113,7 +114,7 @@ func TestToApiPatronRequestOmitsIllTransactionLinkWithoutRequesterReqID(t *testi } func TestGetPatronRequests(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() params := proapi.GetPatronRequestsParams{ @@ -126,7 +127,7 @@ func TestGetPatronRequests(t *testing.T) { } func TestGetPatronRequestsNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() params := proapi.GetPatronRequestsParams{ @@ -138,7 +139,7 @@ func TestGetPatronRequestsNoSymbol(t *testing.T) { } func TestGetPatronRequestsWithLimits(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() offset := proapi.Offset(10) @@ -158,7 +159,7 @@ func TestGetPatronRequestsWithLimits(t *testing.T) { func TestGetPatronRequestsWithRequesterReqId(t *testing.T) { repo := new(PrRepoCapture) - handler := NewPrApiHandler(repo, mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(repo, mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() requesterReqID := "req-123" @@ -177,7 +178,7 @@ func TestGetPatronRequestsWithRequesterReqId(t *testing.T) { } func TestPostPatronRequests(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) id := "1" toCreate := proapi.CreatePatronRequest{ Id: &id, @@ -196,7 +197,7 @@ func TestPostPatronRequests(t *testing.T) { } func TestPostPatronRequestsMissingSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) toCreate := proapi.PatronRequest{Id: "1"} jsonBytes, err := json.Marshal(toCreate) assert.NoError(t, err, "failed to marshal patron request") @@ -210,7 +211,7 @@ func TestPostPatronRequestsMissingSymbol(t *testing.T) { } func TestPostPatronRequestsInvalidJson(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", bytes.NewBuffer([]byte("a\": v\""))) rr := httptest.NewRecorder() tenant := proapi.Tenant("test-lib") @@ -220,7 +221,7 @@ func TestPostPatronRequestsInvalidJson(t *testing.T) { } func TestPostPatronRequestsInvalidIllRequestShape(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) jsonBytes := []byte(`{ "id":"1", "requesterSymbol":"` + symbol + `", @@ -236,7 +237,7 @@ func TestPostPatronRequestsInvalidIllRequestShape(t *testing.T) { } func TestDeletePatronRequestsIdNotFound(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.DeletePatronRequestsId(rr, req, "2", proapi.DeletePatronRequestsIdParams{Symbol: &symbol}) @@ -244,7 +245,7 @@ func TestDeletePatronRequestsIdNotFound(t *testing.T) { } func TestDeletePatronRequestsIdMissingSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.DeletePatronRequestsId(rr, req, "2", proapi.DeletePatronRequestsIdParams{}) @@ -253,7 +254,7 @@ func TestDeletePatronRequestsIdMissingSymbol(t *testing.T) { } func TestDeletePatronRequestsIdError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.DeletePatronRequestsId(rr, req, "1", proapi.DeletePatronRequestsIdParams{Symbol: &symbol}) @@ -262,7 +263,7 @@ func TestDeletePatronRequestsIdError(t *testing.T) { } func TestDeletePatronRequestsId(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.DeletePatronRequestsId(rr, req, "3", proapi.DeletePatronRequestsIdParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -271,7 +272,7 @@ func TestDeletePatronRequestsId(t *testing.T) { } func TestDeletePatronRequestsIdDeleted(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.DeletePatronRequestsId(rr, req, "4", proapi.DeletePatronRequestsIdParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -279,7 +280,7 @@ func TestDeletePatronRequestsIdDeleted(t *testing.T) { } func TestGetPatronRequestsIdMissingSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsId(rr, req, "2", proapi.GetPatronRequestsIdParams{}) @@ -288,7 +289,7 @@ func TestGetPatronRequestsIdMissingSymbol(t *testing.T) { } func TestGetPatronRequestsIdNotFound(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsId(rr, req, "2", proapi.GetPatronRequestsIdParams{Symbol: &symbol}) @@ -296,7 +297,7 @@ func TestGetPatronRequestsIdNotFound(t *testing.T) { } func TestGetPatronRequestsId(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsId(rr, req, "1", proapi.GetPatronRequestsIdParams{Symbol: &symbol}) @@ -305,7 +306,7 @@ func TestGetPatronRequestsId(t *testing.T) { } func TestGetPatronRequestsIdActions(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdActions(rr, req, "3", proapi.GetPatronRequestsIdActionsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -314,7 +315,7 @@ func TestGetPatronRequestsIdActions(t *testing.T) { } func TestGetPatronRequestsIdActionsNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdActions(rr, req, "3", proapi.GetPatronRequestsIdActionsParams{Side: &proapiBorrowingSide}) @@ -323,7 +324,7 @@ func TestGetPatronRequestsIdActionsNoSymbol(t *testing.T) { } func TestGetPatronRequestsIdActionsDbError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdActions(rr, req, "1", proapi.GetPatronRequestsIdActionsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -332,7 +333,7 @@ func TestGetPatronRequestsIdActionsDbError(t *testing.T) { } func TestGetPatronRequestsIdActionsNotFoundBecauseOfSide(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdActions(rr, req, "3", proapi.GetPatronRequestsIdActionsParams{Symbol: &symbol, Side: &proapiLendingSide}) @@ -341,7 +342,7 @@ func TestGetPatronRequestsIdActionsNotFoundBecauseOfSide(t *testing.T) { } func TestPostPatronRequestsIdActionNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.PostPatronRequestsIdAction(rr, req, "3", proapi.PostPatronRequestsIdActionParams{Side: &proapiBorrowingSide}) @@ -350,7 +351,7 @@ func TestPostPatronRequestsIdActionNoSymbol(t *testing.T) { } func TestPostPatronRequestsIdActionDbError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.PostPatronRequestsIdAction(rr, req, "1", proapi.PostPatronRequestsIdActionParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -359,7 +360,7 @@ func TestPostPatronRequestsIdActionDbError(t *testing.T) { } func TestPostPatronRequestsIdActionNotFoundBecauseOfSide(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.PostPatronRequestsIdAction(rr, req, "3", proapi.PostPatronRequestsIdActionParams{Symbol: &symbol, Side: &proapiLendingSide}) @@ -368,7 +369,7 @@ func TestPostPatronRequestsIdActionNotFoundBecauseOfSide(t *testing.T) { } func TestPostPatronRequestsIdActionErrorParsing(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", strings.NewReader("{")) rr := httptest.NewRecorder() handler.PostPatronRequestsIdAction(rr, req, "3", proapi.PostPatronRequestsIdActionParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -377,7 +378,7 @@ func TestPostPatronRequestsIdActionErrorParsing(t *testing.T) { } func TestGetPatronRequestsIdEventsNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdEvents(rr, req, "3", proapi.GetPatronRequestsIdEventsParams{Side: &proapiBorrowingSide}) @@ -386,7 +387,7 @@ func TestGetPatronRequestsIdEventsNoSymbol(t *testing.T) { } func TestGetPatronRequestsIdEventsDbError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdEvents(rr, req, "1", proapi.GetPatronRequestsIdEventsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -395,7 +396,7 @@ func TestGetPatronRequestsIdEventsDbError(t *testing.T) { } func TestGetPatronRequestsIdEventsNotFoundBecauseOfSide(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdEvents(rr, req, "3", proapi.GetPatronRequestsIdEventsParams{Symbol: &symbol, Side: &proapiLendingSide}) @@ -404,7 +405,7 @@ func TestGetPatronRequestsIdEventsNotFoundBecauseOfSide(t *testing.T) { } func TestGetPatronRequestsIdEventsErrorGettingEvents(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdEvents(rr, req, "3", proapi.GetPatronRequestsIdEventsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -413,7 +414,7 @@ func TestGetPatronRequestsIdEventsErrorGettingEvents(t *testing.T) { } func TestGetPatronRequestsIdNotificationsNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdNotifications(rr, req, "3", proapi.GetPatronRequestsIdNotificationsParams{Side: &proapiBorrowingSide}) @@ -422,7 +423,7 @@ func TestGetPatronRequestsIdNotificationsNoSymbol(t *testing.T) { } func TestGetPatronRequestsIdNotificationsDbError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdNotifications(rr, req, "1", proapi.GetPatronRequestsIdNotificationsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -431,7 +432,7 @@ func TestGetPatronRequestsIdNotificationsDbError(t *testing.T) { } func TestGetPatronRequestsIdNotificationsNotFoundBecauseOfSide(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdNotifications(rr, req, "3", proapi.GetPatronRequestsIdNotificationsParams{Symbol: &symbol, Side: &proapiLendingSide}) @@ -440,7 +441,7 @@ func TestGetPatronRequestsIdNotificationsNotFoundBecauseOfSide(t *testing.T) { } func TestGetPatronRequestsIdNotificationsErrorGettingEvents(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() handler.GetPatronRequestsIdNotifications(rr, req, "3", proapi.GetPatronRequestsIdNotificationsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -474,7 +475,7 @@ func TestGetPatronRequestsIdNotificationsWithKindFilter(t *testing.T) { }, fullCount: 1, } - handler := NewPrApiHandler(repo, mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(repo, mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("GET", "/", nil) rr := httptest.NewRecorder() @@ -509,7 +510,7 @@ func TestGetPatronRequestsIdNotificationsWithKindFilter(t *testing.T) { } func TestParseAndValidateIllRequestAndBuildDbPatronRequest(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) creationTime := time.Now() id := uuid.NewString() @@ -536,7 +537,7 @@ func TestParseAndValidateIllRequestAndBuildDbPatronRequest(t *testing.T) { } func TestParseAndValidateIllRequestInvalidRequesterSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) invalidSymbol := "REQ" @@ -546,7 +547,7 @@ func TestParseAndValidateIllRequestInvalidRequesterSymbol(t *testing.T) { } func TestParseAndValidateIllRequestInvalidBrokerSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) previousBrokerSymbol := brokerSymbol brokerSymbol = "BROKER" @@ -563,7 +564,7 @@ func TestParseAndValidateIllRequestInvalidBrokerSymbol(t *testing.T) { } func TestPostPatronRequestsIdNotificationsNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.PostPatronRequestsIdNotifications(rr, req, "3", proapi.PostPatronRequestsIdNotificationsParams{Side: &proapiBorrowingSide}) @@ -572,7 +573,7 @@ func TestPostPatronRequestsIdNotificationsNoSymbol(t *testing.T) { } func TestPostPatronRequestsIdNotificationsDbError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) body := "{\"note\": \"Say hello\"}" req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -582,7 +583,7 @@ func TestPostPatronRequestsIdNotificationsDbError(t *testing.T) { } func TestPostPatronRequestsIdNotificationsNotFoundBecauseOfSide(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) body := "{\"note\": \"Say hello\"}" req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -592,7 +593,7 @@ func TestPostPatronRequestsIdNotificationsNotFoundBecauseOfSide(t *testing.T) { } func TestPostPatronRequestsIdNotificationsErrorSavingNotification(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"note\": \"Say hello\"}" req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -602,7 +603,7 @@ func TestPostPatronRequestsIdNotificationsErrorSavingNotification(t *testing.T) } func TestPostPatronRequestsIdNotificationsErrorBecauseOfBodyMissing(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("POST", "/", nil) rr := httptest.NewRecorder() handler.PostPatronRequestsIdNotifications(rr, req, "3", proapi.PostPatronRequestsIdNotificationsParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -611,7 +612,7 @@ func TestPostPatronRequestsIdNotificationsErrorBecauseOfBodyMissing(t *testing.T } func TestPostPatronRequestsIdNotificationsErrorBecauseOfBody(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"note" req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -621,7 +622,7 @@ func TestPostPatronRequestsIdNotificationsErrorBecauseOfBody(t *testing.T) { } func TestPostPatronRequestsIdNotificationsErrorBecauseOfMissingNote(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{}" req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -631,7 +632,7 @@ func TestPostPatronRequestsIdNotificationsErrorBecauseOfMissingNote(t *testing.T } func TestPostPatronRequestsIdNotificationsErrorFailedSendOnlyLogged(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositorySuccess), common.NewTenant(""), new(MockIso18626Handler), 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositorySuccess), *tenant.NewContext(), new(MockIso18626Handler), 10) body := "{\"note\": \"Say hello\"}" req, _ := http.NewRequest("POST", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -641,7 +642,7 @@ func TestPostPatronRequestsIdNotificationsErrorFailedSendOnlyLogged(t *testing.T } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptNoSymbol(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("PUT", "/", nil) rr := httptest.NewRecorder() handler.PutPatronRequestsIdNotificationsNotificationIdReceipt(rr, req, "3", "n1", proapi.PutPatronRequestsIdNotificationsNotificationIdReceiptParams{Side: &proapiBorrowingSide}) @@ -650,7 +651,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptNoSymbol(t *testin } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptDbError(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) body := "{\"receipt\": \"SEEN\"}" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -660,7 +661,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptDbError(t *testing } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptNotFoundBecauseOfSide(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, mockEventRepo, *tenant.NewContext(), nil, 10) body := "{\"receipt\": \"SEEN\"}" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -670,7 +671,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptNotFoundBecauseOfS } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptErrorReadingNotification(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"receipt\": \"SEEN\"}" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -680,7 +681,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptErrorReadingNotifi } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptNotFound(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"receipt\": \"SEEN\"}" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -690,7 +691,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptNotFound(t *testin } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptErrorBecauseOfBodyMissing(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) req, _ := http.NewRequest("PUT", "/", nil) rr := httptest.NewRecorder() handler.PutPatronRequestsIdNotificationsNotificationIdReceipt(rr, req, "3", "n1", proapi.PutPatronRequestsIdNotificationsNotificationIdReceiptParams{Symbol: &symbol, Side: &proapiBorrowingSide}) @@ -699,7 +700,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptErrorBecauseOfBody } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptErrorBecauseOfBody(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"receipt" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -709,7 +710,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptErrorBecauseOfBody } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptPrDoesNotOwn(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"receipt\": \"SEEN\"}" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() @@ -719,7 +720,7 @@ func TestPutPatronRequestsIdNotificationsNotificationIdReceiptPrDoesNotOwn(t *te } func TestPutPatronRequestsIdNotificationsNotificationIdReceiptFailedToSave(t *testing.T) { - handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), common.NewTenant(""), nil, 10) + handler := NewPrApiHandler(new(PrRepoError), mockEventBus, new(mocks.MockEventRepositoryError), *tenant.NewContext(), nil, 10) body := "{\"receipt\": \"SEEN\"}" req, _ := http.NewRequest("PUT", "/", bytes.NewBufferString(body)) rr := httptest.NewRecorder() diff --git a/broker/sqlc/ill_query.sql b/broker/sqlc/ill_query.sql index d8755c39..7bb7d2ce 100644 --- a/broker/sqlc/ill_query.sql +++ b/broker/sqlc/ill_query.sql @@ -100,13 +100,6 @@ FROM ill_transaction ORDER BY timestamp LIMIT $1 OFFSET $2; --- name: GetIllTransactionsByRequesterSymbol :many -SELECT sqlc.embed(ill_transaction), COUNT(*) OVER () as full_count -FROM ill_transaction -WHERE requester_symbol = $1 -ORDER BY timestamp -LIMIT $2 OFFSET $3; - -- name: SaveIllTransaction :one INSERT INTO ill_transaction (id, timestamp, requester_symbol, requester_id, last_requester_action, prev_requester_action, supplier_symbol, requester_request_id, diff --git a/broker/tenant/tenant.go b/broker/tenant/tenant.go new file mode 100644 index 00000000..73135fbe --- /dev/null +++ b/broker/tenant/tenant.go @@ -0,0 +1,153 @@ +package tenant + +import ( + "errors" + "net/http" + "strings" + + "github.com/indexdata/crosslink/broker/adapter" + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/ill_db" +) + +type TenantContext struct { + illRepo ill_db.IllRepo + directoryLookupAdapter adapter.DirectoryLookupAdapter + tenantSymbolMap string +} + +func NewContext() *TenantContext { + return &TenantContext{} +} + +func (s *TenantContext) WithTenantSymbol(tenantSymbol string) *TenantContext { + s.tenantSymbolMap = tenantSymbol + return s +} + +func (s *TenantContext) WithIllRepo(illRepo ill_db.IllRepo) *TenantContext { + s.illRepo = illRepo + return s +} + +func (s *TenantContext) WithLookupAdapter(directoryLookupAdapter adapter.DirectoryLookupAdapter) *TenantContext { + s.directoryLookupAdapter = directoryLookupAdapter + return s +} + +func (s *TenantContext) isSpecified() bool { + return s.tenantSymbolMap != "" +} + +func (s *TenantContext) getSymbol(tenant string) string { + return strings.ReplaceAll(s.tenantSymbolMap, "{tenant}", strings.ToUpper(tenant)) +} + +type Tenant struct { + tenantContext *TenantContext + ctx common.ExtendedContext + okapiEndpoint bool + tenant string + symbol string +} + +func (s *TenantContext) WithRequest(ctx common.ExtendedContext, r *http.Request, symbol *string) *Tenant { + var pSymbol string + if symbol != nil { + pSymbol = *symbol + } + t := &Tenant{ + tenantContext: s, + tenant: r.Header.Get("X-Okapi-Tenant"), + ctx: ctx, + okapiEndpoint: strings.HasPrefix(r.URL.Path, "/broker/"), + symbol: pSymbol, + } + return t +} + +func (t *Tenant) GetSymbol() (string, error) { + var mainSymbol string + if t.okapiEndpoint { + if !t.tenantContext.isSpecified() { + return "", errors.New("tenant mapping must be specified") + } + if t.tenant == "" { + return "", errors.New("header X-Okapi-Tenant must be specified") + } + mainSymbol = t.tenantContext.getSymbol(t.tenant) + } else { + if t.symbol == "" { + return "", errors.New("symbol must be specified") + } + mainSymbol = t.symbol + } + if t.tenantContext.illRepo == nil { + return mainSymbol, nil + } + // if supplied symbol is the same as main symbol, we can skip the check against branch symbols, since it's valid + // we do not check even if only one symbol because GetCachedPeersBySymbols() creates a peer for the main symbol if it does not exist, + // so we would not be able to distinguish between "symbol does not exist" and "symbol exists but has no peers" + if t.symbol == "" || t.symbol == mainSymbol { + return mainSymbol, nil + } + peers, _, err := t.tenantContext.illRepo.GetCachedPeersBySymbols(t.ctx, []string{mainSymbol}, t.tenantContext.directoryLookupAdapter) + if err != nil { + return "", err + } + found := false + for _, peer := range peers { + branchSymbols, err := t.tenantContext.illRepo.GetBranchSymbolsByPeerId(t.ctx, peer.ID) + if err != nil { + return "", err + } + for _, branchSymbol := range branchSymbols { + if t.symbol == branchSymbol.SymbolValue { + found = true + } + } + } + if !found { + return "", errors.New("symbol does not match any branch symbols for tenant") + } + return t.symbol, nil +} + +// GetSymbols returns the main symbol for the tenant and all branch symbols of peers associated with that symbol. +// A nil slice means no symbol filtering should be applied. Otherwise, the returned slice contains at least the +// main symbol and may include associated branch symbols. +func (t *Tenant) GetSymbols() ([]string, error) { + var mainSymbol string + if t.okapiEndpoint { + if !t.tenantContext.isSpecified() { + return nil, errors.New("tenant mapping must be specified") + } + if t.tenant == "" { + return nil, errors.New("header X-Okapi-Tenant must be specified") + } + mainSymbol = t.tenantContext.getSymbol(t.tenant) + } else { + if t.symbol == "" { + return nil, nil + } + mainSymbol = t.symbol + } + allSyms := []string{mainSymbol} + if t.tenantContext.illRepo == nil { + return allSyms, nil + } + peers, _, err := t.tenantContext.illRepo.GetCachedPeersBySymbols(t.ctx, []string{mainSymbol}, t.tenantContext.directoryLookupAdapter) + if err != nil { + return nil, err + } + for _, peer := range peers { + branchSymbols, err := t.tenantContext.illRepo.GetBranchSymbolsByPeerId(t.ctx, peer.ID) + if err != nil { + return nil, err + } + for _, branchSymbol := range branchSymbols { + allSyms = append(allSyms, branchSymbol.SymbolValue) + } + } + return allSyms, nil +} diff --git a/broker/tenant/tenant_test.go b/broker/tenant/tenant_test.go new file mode 100644 index 00000000..1657ae88 --- /dev/null +++ b/broker/tenant/tenant_test.go @@ -0,0 +1,271 @@ +package tenant + +import ( + "context" + "net/http" + "net/url" + "testing" + + "github.com/indexdata/crosslink/broker/adapter" + "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/ill_db" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +type MockDirectoryLookupAdapter struct { + mock.Mock + adapter.DirectoryLookupAdapter +} + +type MockIllRepo struct { + mock.Mock + ill_db.IllRepo +} + +func (r *MockIllRepo) GetCachedPeersBySymbols(ctx common.ExtendedContext, symbols []string, directoryAdapter adapter.DirectoryLookupAdapter) ([]ill_db.Peer, string, error) { + args := r.Called(ctx, symbols, directoryAdapter) + return args.Get(0).([]ill_db.Peer), args.String(1), args.Error(2) +} + +func (r *MockIllRepo) GetBranchSymbolsByPeerId(ctx common.ExtendedContext, peerId string) ([]ill_db.BranchSymbol, error) { + args := r.Called(ctx, peerId) + return args.Get(0).([]ill_db.BranchSymbol), args.Error(1) +} + +func TestTenantNoSymbol(t *testing.T) { + tenantContext := NewContext() + assert.False(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + turl := &url.URL{Path: "/test"} + httpRequest := &http.Request{Header: header, URL: turl} + + tenant := tenantContext.WithRequest(ctx, httpRequest, nil) + _, err := tenant.GetSymbol() + assert.Error(t, err) + assert.Equal(t, "symbol must be specified", err.Error()) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Nil(t, symbols) +} + +func TestTenantWithSymbol(t *testing.T) { + tenantContext := NewContext() + assert.False(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + turl := &url.URL{Path: "/test"} + httpRequest := &http.Request{Header: header, URL: turl} + + symbol := "LIB" + tenant := tenantContext.WithRequest(ctx, httpRequest, &symbol) + outputSymbol, err := tenant.GetSymbol() + assert.NoError(t, err) + assert.Equal(t, "LIB", outputSymbol) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Equal(t, []string{"LIB"}, symbols) +} + +func TestTenantNoMapping(t *testing.T) { + tenantContext := NewContext() + assert.False(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + + tenant := tenantContext.WithRequest(ctx, httpRequest, nil) + _, err := tenant.GetSymbol() + assert.Error(t, err) + assert.Equal(t, "tenant mapping must be specified", err.Error()) + + _, err = tenant.GetSymbols() + assert.Error(t, err) + assert.Equal(t, "tenant mapping must be specified", err.Error()) +} + +func TestTenantMissingTenant(t *testing.T) { + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}") + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + + tenant := tenantContext.WithRequest(ctx, httpRequest, nil) + _, err := tenant.GetSymbol() + assert.Error(t, err) + assert.Equal(t, "header X-Okapi-Tenant must be specified", err.Error()) + + _, err = tenant.GetSymbols() + assert.Error(t, err) + assert.Equal(t, "header X-Okapi-Tenant must be specified", err.Error()) +} + +func TestTenantMapOK(t *testing.T) { + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}") + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + + tenant := tenantContext.WithRequest(ctx, httpRequest, nil) + outputSymbol, err := tenant.GetSymbol() + assert.NoError(t, err) + assert.Equal(t, "ISIL:DK-TENANT1", outputSymbol) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Equal(t, []string{"ISIL:DK-TENANT1"}, symbols) +} + +func TestTenantRepo1(t *testing.T) { + mockIllRepo := new(MockIllRepo) + mockIllRepo.On("GetCachedPeersBySymbols", mock.Anything, mock.Anything, mock.Anything).Return([]ill_db.Peer{}, "", nil) + + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}").WithIllRepo(mockIllRepo) + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + + tenant := tenantContext.WithRequest(ctx, httpRequest, nil) + outputSymbol, err := tenant.GetSymbol() + assert.NoError(t, err) + assert.Equal(t, "ISIL:DK-TENANT1", outputSymbol) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Equal(t, []string{"ISIL:DK-TENANT1"}, symbols) +} + +func TestTenantSymIdentical(t *testing.T) { + mockIllRepo := new(MockIllRepo) + mockIllRepo.On("GetCachedPeersBySymbols", mock.Anything, mock.Anything, mock.Anything).Return([]ill_db.Peer{}, "", nil) + + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}").WithIllRepo(mockIllRepo) + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + symbol := "ISIL:DK-TENANT1" + tenant := tenantContext.WithRequest(ctx, httpRequest, &symbol) + outputSymbol, err := tenant.GetSymbol() + assert.NoError(t, err) + assert.Equal(t, "ISIL:DK-TENANT1", outputSymbol) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Equal(t, []string{"ISIL:DK-TENANT1"}, symbols) +} + +func TestTenantNoBranchMatch(t *testing.T) { + mockIllRepo := new(MockIllRepo) + mockIllRepo.On("GetCachedPeersBySymbols", mock.Anything, mock.Anything, mock.Anything).Return([]ill_db.Peer{}, "", nil) + + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}").WithIllRepo(mockIllRepo) + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + symbol := "LIB" + tenant := tenantContext.WithRequest(ctx, httpRequest, &symbol) + _, err := tenant.GetSymbol() + assert.Error(t, err) + assert.Equal(t, "symbol does not match any branch symbols for tenant", err.Error()) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Equal(t, []string{"ISIL:DK-TENANT1"}, symbols) +} + +func TestTenantBranchMatch(t *testing.T) { + mockIllRepo := new(MockIllRepo) + mockIllRepo.On("GetCachedPeersBySymbols", mock.Anything, mock.Anything, mock.Anything).Return([]ill_db.Peer{{ID: "ISIL:DK-TENANT1"}}, "", nil) + mockIllRepo.On("GetBranchSymbolsByPeerId", mock.Anything, mock.Anything).Return([]ill_db.BranchSymbol{{SymbolValue: "ISIL:DK-DIKU"}, {SymbolValue: "ISIL:DK-LIB"}}, nil) + + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}").WithIllRepo(mockIllRepo) + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + symbol := "ISIL:DK-LIB" + tenant := tenantContext.WithRequest(ctx, httpRequest, &symbol) + outputSymbol, err := tenant.GetSymbol() + assert.NoError(t, err) + assert.Equal(t, "ISIL:DK-LIB", outputSymbol) + + symbols, err := tenant.GetSymbols() + assert.NoError(t, err) + assert.Equal(t, []string{"ISIL:DK-TENANT1", "ISIL:DK-DIKU", "ISIL:DK-LIB"}, symbols) +} + +func TestTenantRepoError1(t *testing.T) { + mockIllRepo := new(MockIllRepo) + mockIllRepo.On("GetCachedPeersBySymbols", mock.Anything, mock.Anything, mock.Anything).Return([]ill_db.Peer{}, "", assert.AnError) + + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}").WithIllRepo(mockIllRepo) + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + symbol := "ISIL:DK-LIB" + tenant := tenantContext.WithRequest(ctx, httpRequest, &symbol) + _, err := tenant.GetSymbol() + assert.Error(t, err) + assert.Equal(t, "assert.AnError general error for testing", err.Error()) + + _, err = tenant.GetSymbols() + assert.Error(t, err) + assert.Equal(t, "assert.AnError general error for testing", err.Error()) +} + +func TestTenantRepoError2(t *testing.T) { + mockIllRepo := new(MockIllRepo) + mockIllRepo.On("GetCachedPeersBySymbols", mock.Anything, mock.Anything, mock.Anything).Return([]ill_db.Peer{{ID: "ISIL:DK-TENANT1"}}, "", nil) + mockIllRepo.On("GetBranchSymbolsByPeerId", mock.Anything, mock.Anything).Return([]ill_db.BranchSymbol{}, assert.AnError) + + tenantContext := NewContext().WithTenantSymbol("ISIL:DK-{tenant}").WithIllRepo(mockIllRepo) + assert.True(t, tenantContext.isSpecified()) + + ctx := common.CreateExtCtxWithArgs(context.Background(), &common.LoggerArgs{}) + header := http.Header{} + header.Set("X-Okapi-Tenant", "tenant1") + turl := &url.URL{Path: "/broker/"} + httpRequest := &http.Request{Header: header, URL: turl} + symbol := "ISIL:DK-LIB" + tenant := tenantContext.WithRequest(ctx, httpRequest, &symbol) + _, err := tenant.GetSymbol() + assert.Error(t, err) + assert.Equal(t, "assert.AnError general error for testing", err.Error()) + + _, err = tenant.GetSymbols() + assert.Error(t, err) + assert.Equal(t, "assert.AnError general error for testing", err.Error()) +} diff --git a/broker/test/api/api-handler_test.go b/broker/test/api/api-handler_test.go index 4feb244a..ea3ac8da 100644 --- a/broker/test/api/api-handler_test.go +++ b/broker/test/api/api-handler_test.go @@ -18,6 +18,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/indexdata/crosslink/broker/common" + "github.com/indexdata/crosslink/broker/tenant" "github.com/indexdata/crosslink/broker/vcs" "github.com/indexdata/crosslink/iso18626" "github.com/jackc/pgx/v5/pgtype" @@ -44,7 +45,7 @@ var eventRepo events.EventRepo var sseBroker *api.SseBroker var mockIllRepoError = new(mocks.MockIllRepositoryError) var mockEventRepoError = new(mocks.MockEventRepositoryError) -var handlerMock = api.NewApiHandler(mockEventRepoError, mockIllRepoError, common.NewTenant(""), api.LIMIT_DEFAULT) +var handlerMock = api.NewApiHandler(mockEventRepoError, mockIllRepoError, *tenant.NewContext(), api.LIMIT_DEFAULT) func TestMain(m *testing.M) { app.TENANT_TO_SYMBOL = "ISIL:DK-{tenant}" @@ -202,7 +203,8 @@ func TestGetIllTransactions(t *testing.T) { prevLink := *resp.About.PrevLink assert.Contains(t, prevLink, "offset=0") - body = getResponseBody(t, "/broker/ill_transactions?requester_symbol="+url.QueryEscape("ISIL:DK-BIB1")) + body = httpGet(t, "/broker/ill_transactions?requester_symbol="+url.QueryEscape("ISIL:DK-BIB1"), "bib1", http.StatusOK) + resp.About.NextLink = nil resp.About.PrevLink = nil err = json.Unmarshal(body, &resp) @@ -221,12 +223,12 @@ func TestGetIllTransactions(t *testing.T) { assert.True(t, strings.HasPrefix(lastLink, getLocalhostWithPort()+"/broker/ill_transactions?")) assert.Contains(t, lastLink, "requester_symbol="+url.QueryEscape("ISIL:DK-BIB1")) assert.Contains(t, lastLink, "offset=10") - // we have estblished that the next link is correct, now we will check if it works - hres, err := http.Get(nextLink) // nolint:gosec - assert.NoError(t, err) - defer hres.Body.Close() - body, err = io.ReadAll(hres.Body) - assert.NoError(t, err) + + // we still need the tenant + body = httpGet(t, nextLink, "", http.StatusBadRequest) + assert.Contains(t, string(body), "header X-Okapi-Tenant must be specified") + // supply tenant now + body = httpGet(t, nextLink, "bib1", http.StatusOK) err = json.Unmarshal(body, &resp) assert.NoError(t, err) assert.NotNil(t, resp.About.PrevLink) @@ -324,9 +326,10 @@ func TestBrokerCRUD(t *testing.T) { httpGet(t, "/broker/ill_transactions/"+illId+"?requester_symbol="+url.QueryEscape("ISIL:DK-DIKU"), "diku", http.StatusOK) httpGet(t, "/broker/ill_transactions/"+illId+"?requester_symbol="+url.QueryEscape("ISIL:DK-DIKU"), "ruc", http.StatusNotFound) httpGet(t, "/broker/ill_transactions/"+illId, "ruc", http.StatusNotFound) - httpGet(t, "/broker/ill_transactions/"+illId, "", http.StatusNotFound) + httpGet(t, "/broker/ill_transactions/"+illId, "diku", http.StatusOK) + httpGet(t, "/broker/ill_transactions/"+illId, "", http.StatusBadRequest) - body = httpGet(t, "/broker/ill_transactions/"+illId+"?requester_symbol="+url.QueryEscape("ISIL:DK-DIKU"), "", http.StatusOK) + body = httpGet(t, "/broker/ill_transactions/"+illId+"?requester_symbol="+url.QueryEscape("ISIL:DK-DIKU"), "diku", http.StatusOK) err = json.Unmarshal(body, &tran) assert.NoError(t, err) assert.Equal(t, illId, tran.Id) @@ -339,7 +342,11 @@ func TestBrokerCRUD(t *testing.T) { assert.Equal(t, 0, len(httpGetTrans(t, "/broker/ill_transactions", "ruc", http.StatusOK))) - assert.Equal(t, 0, len(httpGetTrans(t, "/broker/ill_transactions", "", http.StatusOK))) + assert.Equal(t, 0, len(httpGetTrans(t, "/broker/ill_transactions", "magtic", http.StatusOK))) + + assert.Equal(t, 0, len(httpGetTrans(t, "/broker/ill_transactions", "", http.StatusBadRequest))) + + assert.True(t, len(httpGetTrans(t, "/ill_transactions", "", http.StatusOK)) >= 2) body = httpGet(t, "/broker/ill_transactions?requester_req_id="+url.QueryEscape(reqReqId), "diku", http.StatusOK) var resp oapi.IllTransactions @@ -383,7 +390,7 @@ func TestBrokerCRUD(t *testing.T) { assert.Len(t, events.Items, 1) assert.Equal(t, eventId, events.Items[0].Id) - body = httpGet(t, "/broker/events?requester_req_id="+url.QueryEscape(reqReqId)+"&requester_symbol="+url.QueryEscape("ISIL:DK-DIKU"), "", http.StatusOK) + body = httpGet(t, "/broker/events?requester_req_id="+url.QueryEscape(reqReqId)+"&requester_symbol="+url.QueryEscape("ISIL:DK-DIKU"), "diku", http.StatusOK) err = json.Unmarshal(body, &events) assert.NoError(t, err) assert.Len(t, events.Items, 1) @@ -846,7 +853,10 @@ func getResponseBody(t *testing.T, endpoint string) []byte { func httpRequest(t *testing.T, method string, uriPath string, reqbytes []byte, tenant string, expectStatus int) []byte { client := http.DefaultClient - hreq, err := http.NewRequest(method, getLocalhostWithPort()+uriPath, bytes.NewBuffer(reqbytes)) + if strings.HasPrefix(uriPath, "/") { + uriPath = getLocalhostWithPort() + uriPath + } + hreq, err := http.NewRequest(method, uriPath, bytes.NewBuffer(reqbytes)) assert.NoError(t, err) if tenant != "" { hreq.Header.Set("X-Okapi-Tenant", tenant) diff --git a/broker/test/api/sse_broker_test.go b/broker/test/api/sse_broker_test.go index 629b9e60..047720b0 100644 --- a/broker/test/api/sse_broker_test.go +++ b/broker/test/api/sse_broker_test.go @@ -101,16 +101,25 @@ func TestSseEndpointNoSide(t *testing.T) { bodyBytes, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) - assert.Equal(t, "query parameter 'side' and 'symbol' must be specified\n", string(bodyBytes)) + assert.Equal(t, "query parameter 'side' must be specified\n", string(bodyBytes)) } func TestSseEndpointNoSymbol(t *testing.T) { - resp, err := http.Get(getLocalhostWithPort() + "/sse/events?side=borrowing") + resp, err := http.Get(getLocalhostWithPort() + "/sse/events?side=borrowing&other=/broker/") assert.NoError(t, err) bodyBytes, err := io.ReadAll(resp.Body) assert.NoError(t, err) assert.Equal(t, 400, resp.StatusCode) - assert.Equal(t, "query parameter 'side' and 'symbol' must be specified\n", string(bodyBytes)) + assert.Equal(t, "symbol must be specified\n", string(bodyBytes)) +} + +func TestSseEndpointNoTenant(t *testing.T) { + resp, err := http.Get(getLocalhostWithPort() + "/broker/sse/events?side=borrowing") + assert.NoError(t, err) + bodyBytes, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, 400, resp.StatusCode) + assert.Equal(t, "header X-Okapi-Tenant must be specified\n", string(bodyBytes)) } func executeTask(t time.Time) { diff --git a/broker/test/mocks/mock_illrepo.go b/broker/test/mocks/mock_illrepo.go index 65b43d42..f441a5e8 100644 --- a/broker/test/mocks/mock_illrepo.go +++ b/broker/test/mocks/mock_illrepo.go @@ -101,13 +101,7 @@ func (r *MockIllRepositorySuccess) GetIllTransactionByRequesterRequestIdForUpdat }, nil } -func (r *MockIllRepositorySuccess) ListIllTransactions(ctx common.ExtendedContext, params ill_db.ListIllTransactionsParams, cql *string) ([]ill_db.IllTransaction, int64, error) { - return []ill_db.IllTransaction{{ - ID: "id", - }}, 0, nil -} - -func (r *MockIllRepositorySuccess) GetIllTransactionsByRequesterSymbol(ctx common.ExtendedContext, params ill_db.GetIllTransactionsByRequesterSymbolParams, cql *string) ([]ill_db.IllTransaction, int64, error) { +func (r *MockIllRepositorySuccess) ListIllTransactions(ctx common.ExtendedContext, params ill_db.ListIllTransactionsParams, cql *string, symbols []string) ([]ill_db.IllTransaction, int64, error) { return []ill_db.IllTransaction{{ ID: "id", }}, 0, nil @@ -263,11 +257,7 @@ func (r *MockIllRepositoryError) GetIllTransactionByRequesterRequestIdForUpdate( return ill_db.IllTransaction{}, errors.New("DB error") } -func (r *MockIllRepositoryError) ListIllTransactions(ctx common.ExtendedContext, params ill_db.ListIllTransactionsParams, cql *string) ([]ill_db.IllTransaction, int64, error) { - return []ill_db.IllTransaction{}, 0, errors.New("DB error") -} - -func (r *MockIllRepositoryError) GetIllTransactionsByRequesterSymbol(ctx common.ExtendedContext, params ill_db.GetIllTransactionsByRequesterSymbolParams, cql *string) ([]ill_db.IllTransaction, int64, error) { +func (r *MockIllRepositoryError) ListIllTransactions(ctx common.ExtendedContext, params ill_db.ListIllTransactionsParams, cql *string, symbols []string) ([]ill_db.IllTransaction, int64, error) { return []ill_db.IllTransaction{}, 0, errors.New("DB error") }