Skip to content

Commit be9ce74

Browse files
authored
feat(enterprise): audit logs for alpha and zero (#7295)
* starting work on audit logs
* making audit logs to work with alpha
* making audit logs to work with zero too
* adding endpoint to grpc audit
* adding skip method for grpc audits
* making zero and alpha logs from the start itself.
* fixing zero init audit process
* adding dgraph audit tool and logs encryption
* adding logwriter to handle encryption and everything
* adding test cases to check requests are getting logged into the audit logs
* adding interceptor ee version
* basic refactoring and log message truncate functionality
* making log writer performant using buffered writer
* gracefully closing all the go routines
* fixing oss build
* fixing failed test case
1 parent 1bfb25a commit be9ce74

36 files changed

+1605
-87
lines changed

dgraph/cmd/alpha/run.go

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ import (
3636
"time"
3737

3838
badgerpb "github.com/dgraph-io/badger/v3/pb"
39+
"github.com/dgraph-io/dgraph/ee/audit"
40+
3941
"github.com/dgraph-io/dgo/v200/protos/api"
4042
"github.com/dgraph-io/dgraph/edgraph"
4143
"github.com/dgraph-io/dgraph/ee/enc"
@@ -191,6 +193,14 @@ they form a Raft group and provide synchronous replication.
191193
`Cache percentages summing up to 100 for various caches (FORMAT:
192194
PostingListCache,PstoreBlockCache,PstoreIndexCache,WAL).`)
193195

196+
flag.String("audit", "",
197+
`Various audit options.
198+
dir=/path/to/audits to define the path where to store the audit logs.
199+
compress=true/false to enabled the compression of old audit logs (default behaviour is false).
200+
encrypt_file=enc/key/file enables the audit log encryption with the key path provided with the
201+
flag.
202+
Sample flag could look like --audit dir=aa;encrypt_file=/filepath;compress=true`)
203+
194204
// TLS configurations
195205
x.RegisterServerTLSFlags(flag)
196206
}
@@ -379,6 +389,7 @@ func serveGRPC(l net.Listener, tlsCfg *tls.Config, closer *z.Closer) {
379389
grpc.MaxSendMsgSize(x.GrpcMaxSize),
380390
grpc.MaxConcurrentStreams(1000),
381391
grpc.StatsHandler(&ocgrpc.ServerHandler{}),
392+
grpc.UnaryInterceptor(audit.AuditRequestGRPC),
382393
}
383394
if tlsCfg != nil {
384395
opt = append(opt, grpc.Creds(credentials.NewTLS(tlsCfg)))
@@ -417,15 +428,18 @@ func setupServer(closer *z.Closer) {
417428
log.Fatal(err)
418429
}
419430

420-
http.HandleFunc("/query", queryHandler)
421-
http.HandleFunc("/query/", queryHandler)
422-
http.HandleFunc("/mutate", mutationHandler)
423-
http.HandleFunc("/mutate/", mutationHandler)
424-
http.HandleFunc("/commit", commitHandler)
425-
http.HandleFunc("/alter", alterHandler)
426-
http.HandleFunc("/health", healthCheck)
427-
http.HandleFunc("/state", stateHandler)
428-
http.HandleFunc("/jemalloc", x.JemallocHandler)
431+
baseMux := http.NewServeMux()
432+
http.Handle("/", audit.AuditRequestHttp(baseMux))
433+
434+
baseMux.HandleFunc("/query", queryHandler)
435+
baseMux.HandleFunc("/query/", queryHandler)
436+
baseMux.HandleFunc("/mutate", mutationHandler)
437+
baseMux.HandleFunc("/mutate/", mutationHandler)
438+
baseMux.HandleFunc("/commit", commitHandler)
439+
baseMux.HandleFunc("/alter", alterHandler)
440+
baseMux.HandleFunc("/health", healthCheck)
441+
baseMux.HandleFunc("/state", stateHandler)
442+
baseMux.HandleFunc("/jemalloc", x.JemallocHandler)
429443

430444
// TODO: Figure out what this is for?
431445
http.HandleFunc("/debug/store", storeStatsHandler)
@@ -451,8 +465,9 @@ func setupServer(closer *z.Closer) {
451465
var gqlHealthStore *admin.GraphQLHealthStore
452466
// Do not use := notation here because adminServer is a global variable.
453467
mainServer, adminServer, gqlHealthStore = admin.NewServers(introspection, &globalEpoch, closer)
454-
http.Handle("/graphql", mainServer.HTTPHandler())
455-
http.HandleFunc("/probe/graphql", func(w http.ResponseWriter, r *http.Request) {
468+
baseMux.Handle("/graphql", mainServer.HTTPHandler())
469+
baseMux.HandleFunc("/probe/graphql", func(w http.ResponseWriter,
470+
r *http.Request) {
456471
healthStatus := gqlHealthStore.GetHealth()
457472
httpStatusCode := http.StatusOK
458473
if !healthStatus.Healthy {
@@ -463,18 +478,19 @@ func setupServer(closer *z.Closer) {
463478
x.Check2(w.Write([]byte(fmt.Sprintf(`{"status":"%s","schemaUpdateCounter":%d}`,
464479
healthStatus.StatusMsg, atomic.LoadUint64(&globalEpoch)))))
465480
})
466-
http.Handle("/admin", allowedMethodsHandler(allowedMethods{
481+
baseMux.Handle("/admin", allowedMethodsHandler(allowedMethods{
467482
http.MethodGet: true,
468483
http.MethodPost: true,
469484
http.MethodOptions: true,
470485
}, adminAuthHandler(adminServer.HTTPHandler())))
471486

472-
http.Handle("/admin/schema", adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter,
487+
baseMux.Handle("/admin/schema", adminAuthHandler(http.HandlerFunc(func(
488+
w http.ResponseWriter,
473489
r *http.Request) {
474490
adminSchemaHandler(w, r, adminServer)
475491
})))
476492

477-
http.Handle("/admin/schema/validate", http.HandlerFunc(func(w http.ResponseWriter,
493+
baseMux.HandleFunc("/admin/schema/validate", func(w http.ResponseWriter,
478494
r *http.Request) {
479495
schema := readRequest(w, r)
480496
w.Header().Set("Content-Type", "application/json")
@@ -489,26 +505,28 @@ func setupServer(closer *z.Closer) {
489505
w.WriteHeader(http.StatusBadRequest)
490506
errs := strings.Split(strings.TrimSpace(err.Error()), "\n")
491507
x.SetStatusWithErrors(w, x.ErrorInvalidRequest, errs)
492-
}))
508+
})
493509

494-
http.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.MethodGet: true},
510+
baseMux.Handle("/admin/shutdown", allowedMethodsHandler(allowedMethods{http.
511+
MethodGet: true},
495512
adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
496513
shutDownHandler(w, r, adminServer)
497514
}))))
498515

499-
http.Handle("/admin/draining", allowedMethodsHandler(allowedMethods{
516+
baseMux.Handle("/admin/draining", allowedMethodsHandler(allowedMethods{
500517
http.MethodPut: true,
501518
http.MethodPost: true,
502519
}, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
503520
drainingHandler(w, r, adminServer)
504521
}))))
505522

506-
http.Handle("/admin/export", allowedMethodsHandler(allowedMethods{http.MethodGet: true},
523+
baseMux.Handle("/admin/export", allowedMethodsHandler(
524+
allowedMethods{http.MethodGet: true},
507525
adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
508526
exportHandler(w, r, adminServer)
509527
}))))
510528

511-
http.Handle("/admin/config/cache_mb", allowedMethodsHandler(allowedMethods{
529+
baseMux.Handle("/admin/config/cache_mb", allowedMethodsHandler(allowedMethods{
512530
http.MethodGet: true,
513531
http.MethodPut: true,
514532
}, adminAuthHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -520,10 +538,10 @@ func setupServer(closer *z.Closer) {
520538
glog.Infof("Bringing up GraphQL HTTP admin API at %s/admin", addr)
521539

522540
// Add OpenCensus z-pages.
523-
zpages.Handle(http.DefaultServeMux, "/z")
541+
zpages.Handle(baseMux, "/z")
524542

525-
http.HandleFunc("/", homeHandler)
526-
http.HandleFunc("/ui/keywords", keywordHandler)
543+
baseMux.Handle("/", http.HandlerFunc(homeHandler))
544+
baseMux.Handle("/ui/keywords", http.HandlerFunc(keywordHandler))
527545

528546
// Initialize the servers.
529547
admin.ServerCloser.AddRunning(3)
@@ -585,6 +603,8 @@ func run() {
585603
walCache := (cachePercent[3] * (totalCache << 20)) / 100
586604

587605
ctype, clevel := x.ParseCompression(Alpha.Conf.GetString("badger.compression"))
606+
607+
conf := audit.GetAuditConf(Alpha.Conf.GetString("audit"))
588608
opts := worker.Options{
589609
PostingDir: Alpha.Conf.GetString("postings"),
590610
WALDir: Alpha.Conf.GetString("wal"),
@@ -597,6 +617,7 @@ func run() {
597617

598618
MutationsMode: worker.AllowMutations,
599619
AuthToken: Alpha.Conf.GetString("auth_token"),
620+
Audit: conf,
600621
}
601622

602623
secretFile := Alpha.Conf.GetString("acl_secret_file")
@@ -658,6 +679,8 @@ func run() {
658679
LudicrousConcurrency: Alpha.Conf.GetInt("ludicrous_concurrency"),
659680
TLSClientConfig: tlsClientConf,
660681
TLSServerConfig: tlsServerConf,
682+
HmacSecret: opts.HmacSecret,
683+
Audit: opts.Audit != nil,
661684
}
662685
x.WorkerConfig.Parse(Alpha.Conf)
663686

@@ -699,6 +722,9 @@ func run() {
699722

700723
worker.InitServerState()
701724

725+
// Audit is an enterprise feature.
726+
x.Check(audit.InitAuditorIfNecessary(opts.Audit, worker.EnterpriseEnabled))
727+
702728
if Alpha.Conf.GetBool("expose_trace") {
703729
// TODO: Remove this once we get rid of event logs.
704730
trace.AuthRequest = func(req *http.Request) (any, sensitive bool) {
@@ -792,6 +818,8 @@ func run() {
792818
adminCloser.SignalAndWait()
793819
glog.Infoln("adminCloser closed.")
794820

821+
audit.Close()
822+
795823
worker.State.Dispose()
796824
x.RemoveCidFile()
797825
glog.Info("worker.State disposed.")

dgraph/cmd/bulk/count_index.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ func (c *countIndexer) writeIndex(buf *z.Buffer) {
156156
encoder = codec.Encoder{BlockSize: 256, Alloc: alloc}
157157
pl.Reset()
158158

159-
// Flush out the buffer.
159+
// flush out the buffer.
160160
if outBuf.LenNoPadding() > 4<<20 {
161161
x.Check(c.writer.Write(outBuf))
162162
outBuf.Reset()

dgraph/cmd/bulk/reduce.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ func (r *reducer) writeTmpSplits(ci *countIndexer, wg *sync.WaitGroup) {
292292
}
293293

294294
for i := 0; i < len(kvs.Kv); i += maxSplitBatchLen {
295-
// Flush the write batch when the max batch length is reached to prevent the
295+
// flush the write batch when the max batch length is reached to prevent the
296296
// value log from growing over the allowed limit.
297297
if splitBatchLen >= maxSplitBatchLen {
298298
x.Check(ci.splitWriter.Flush())

dgraph/cmd/live/batch.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ import (
3434
"github.com/dgraph-io/badger/v3"
3535
"github.com/dgraph-io/dgo/v200"
3636
"github.com/dgraph-io/dgo/v200/protos/api"
37-
"github.com/dgraph-io/dgraph/dgraph/cmd/zero"
3837
"github.com/dgraph-io/dgraph/gql"
3938
"github.com/dgraph-io/dgraph/protos/pb"
4039
"github.com/dgraph-io/dgraph/tok"
@@ -132,7 +131,7 @@ func handleError(err error, isRetry bool) {
132131
dur := time.Duration(1+rand.Intn(10)) * time.Minute
133132
fmt.Printf("Server is overloaded. Will retry after %s.\n", dur.Round(time.Minute))
134133
time.Sleep(dur)
135-
case err != zero.ErrConflict && err != dgo.ErrAborted:
134+
case err != x.ErrConflict && err != dgo.ErrAborted:
136135
fmt.Printf("Error while mutating: %v s.Code %v\n", s.Message(), s.Code())
137136
}
138137
}

dgraph/cmd/root_ee.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ package cmd
1414

1515
import (
1616
acl "github.com/dgraph-io/dgraph/ee/acl"
17+
"github.com/dgraph-io/dgraph/ee/audit"
1718
"github.com/dgraph-io/dgraph/ee/backup"
1819
)
1920

@@ -24,5 +25,6 @@ func init() {
2425
&backup.LsBackup,
2526
&backup.ExportBackup,
2627
&acl.CmdAcl,
28+
&audit.CmdAudit,
2729
)
2830
}

dgraph/cmd/zero/license_ee.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020
"net/http"
2121
"time"
2222

23+
"github.com/dgraph-io/dgraph/ee/audit"
24+
2325
"github.com/dgraph-io/dgraph/protos/pb"
2426
"github.com/dgraph-io/dgraph/x"
2527
"github.com/dgraph-io/ristretto/z"
@@ -91,6 +93,7 @@ func (n *node) updateEnterpriseState(closer *z.Closer) {
9193
active := time.Now().UTC().Before(expiry)
9294
if !active {
9395
n.server.expireLicense()
96+
audit.Close()
9497
glog.Warningf("Your enterprise license has expired and enterprise features are " +
9598
"disabled. To continue using enterprise features, apply a valid license. To receive " +
9699
"a new license, contact us at https://dgraph.io/contact.")

dgraph/cmd/zero/oracle.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ func (o *Oracle) commit(src *api.TxnContext) error {
134134
defer o.Unlock()
135135

136136
if o.hasConflict(src) {
137-
return ErrConflict
137+
return x.ErrConflict
138138
}
139139
// We store src.Keys as string to ensure compatibility with all the various language clients we
140140
// have. But, really they are just uint64s encoded as strings. We use base 36 during creation of
@@ -310,9 +310,6 @@ func (o *Oracle) MaxPending() uint64 {
310310
return o.maxAssigned
311311
}
312312

313-
// ErrConflict is returned when commit couldn't succeed due to conflicts.
314-
var ErrConflict = errors.New("Transaction conflict")
315-
316313
// proposeTxn proposes a txn update, and then updates src to reflect the state
317314
// of the commit after proposal is run.
318315
func (s *Server) proposeTxn(ctx context.Context, src *api.TxnContext) error {

dgraph/cmd/zero/raft.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"time"
2929

3030
"github.com/dgraph-io/dgraph/conn"
31+
"github.com/dgraph-io/dgraph/ee/audit"
3132
"github.com/dgraph-io/dgraph/protos/pb"
3233
"github.com/dgraph-io/dgraph/x"
3334
"github.com/dgraph-io/ristretto/z"
@@ -404,6 +405,11 @@ func (n *node) applyProposal(e raftpb.Entry) (uint64, error) {
404405
// Check expiry and set enabled accordingly.
405406
expiry := time.Unix(state.License.ExpiryTs, 0).UTC()
406407
state.License.Enabled = time.Now().UTC().Before(expiry)
408+
if state.License.Enabled && opts.audit != nil {
409+
if err := audit.InitAuditor(opts.audit); err != nil {
410+
glog.Errorf("error while initializing audit logs %+v", err)
411+
}
412+
}
407413
}
408414
if p.Snapshot != nil {
409415
if err := n.applySnapshot(p.Snapshot); err != nil {

0 commit comments

Comments
 (0)