/
manifold.go
134 lines (118 loc) · 3.35 KB
/
manifold.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright 2018 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.
package auditconfigupdater
import (
"github.com/juju/errors"
"github.com/juju/worker/v3"
"github.com/juju/worker/v3/dependency"
jujuagent "github.com/juju/juju/agent"
"github.com/juju/juju/core/auditlog"
"github.com/juju/juju/worker/common"
workerstate "github.com/juju/juju/worker/state"
)
// ManifoldConfig describes the dependencies and constructor required to
// run an auditconfigupdater worker inside a dependency.Engine.
type ManifoldConfig struct {
	// AgentName is the name of the engine resource that is resolved
	// into a jujuagent.Agent when the manifold starts.
	AgentName string

	// StateName is the name of the engine resource that is resolved
	// into a workerstate.StateTracker when the manifold starts.
	StateName string

	// NewWorker builds the worker. The manifold calls it with the
	// config source, the initial audit log configuration and the
	// factory used to create audit log targets.
	NewWorker func(ConfigSource, auditlog.Config, AuditLogFactory) (worker.Worker, error)
}
// Validate checks that every field of the manifold configuration has
// been supplied. It returns a NotValid error naming the first missing
// field (AgentName, then StateName, then NewWorker), or nil when the
// configuration is complete.
func (config ManifoldConfig) Validate() error {
	switch {
	case config.AgentName == "":
		return errors.NotValidf("empty AgentName")
	case config.StateName == "":
		return errors.NotValidf("empty StateName")
	case config.NewWorker == nil:
		return errors.NotValidf("nil NewWorker")
	}
	return nil
}
// Manifold builds the dependency.Manifold that runs an
// auditconfigupdater. The manifold declares the agent and state
// resources named in config as its inputs, starts the worker via
// config.start and exposes its result through output.
func Manifold(config ManifoldConfig) dependency.Manifold {
	inputs := []string{config.AgentName, config.StateName}
	return dependency.Manifold{
		Inputs: inputs,
		Start:  config.start,
		Output: output,
	}
}
// start is the manifold's start function. It validates the
// configuration, resolves the agent and state tracker from the
// dependency context, reads the initial audit log configuration from
// the system state and hands everything to config.NewWorker.
//
// The returned worker is wrapped in a cleanup worker so that the state
// tracker is released once the worker is finished with. If start fails
// after the tracker has been acquired, the tracker is released before
// returning.
func (config ManifoldConfig) start(context dependency.Context) (_ worker.Worker, err error) {
	if err = config.Validate(); err != nil {
		return nil, errors.Trace(err)
	}

	var agent jujuagent.Agent
	if err = context.Get(config.AgentName, &agent); err != nil {
		return nil, errors.Trace(err)
	}

	var tracker workerstate.StateTracker
	if err = context.Get(config.StateName, &tracker); err != nil {
		return nil, errors.Trace(err)
	}

	pool, err := tracker.Use()
	if err != nil {
		return nil, errors.Trace(err)
	}
	// release gives back the reference taken by Use above. The error
	// from Done is deliberately discarded on both paths that call it.
	release := func() { _ = tracker.Done() }
	defer func() {
		// err is the named result, so this observes whatever error
		// the function is about to return.
		if err != nil {
			release()
		}
	}()

	logDir := agent.CurrentConfig().LogDir()

	systemState, err := pool.SystemState()
	if err != nil {
		return nil, errors.Trace(err)
	}

	// newLog creates a log file target under the agent's log
	// directory, sized according to the supplied configuration.
	newLog := func(cfg auditlog.Config) auditlog.AuditLog {
		return auditlog.NewLogFile(logDir, cfg.MaxSizeMB, cfg.MaxBackups)
	}

	initial, err := initialConfig(systemState)
	if err != nil {
		return nil, errors.Trace(err)
	}
	// A target is only created when auditing is switched on.
	if initial.Enabled {
		initial.Target = newLog(initial)
	}

	inner, err := config.NewWorker(systemState, initial, newLog)
	if err != nil {
		return nil, errors.Trace(err)
	}
	return common.NewCleanupWorker(inner, release), nil
}
// withCurrentConfig is the behaviour output requires of the worker it
// is given: a CurrentConfig method returning an auditlog.Config.
type withCurrentConfig interface {
	// CurrentConfig returns the worker's audit log configuration.
	CurrentConfig() auditlog.Config
}
// output is the manifold's output function. It stores the worker's
// CurrentConfig method in out, which must be a *func() auditlog.Config.
//
// If in is a *common.CleanupWorker, the worker it wraps is used
// instead. An error is returned when the resulting worker has no
// CurrentConfig method or when out has a different type.
func output(in worker.Worker, out interface{}) error {
	// Look through the cleanup wrapper that start puts around the worker.
	if cleanup, wrapped := in.(*common.CleanupWorker); wrapped {
		in = cleanup.Worker
	}

	source, ok := in.(withCurrentConfig)
	if !ok {
		return errors.Errorf("expected worker implementing CurrentConfig(), got %T", in)
	}

	switch target := out.(type) {
	case *func() auditlog.Config:
		*target = source.CurrentConfig
		return nil
	default:
		return errors.Errorf("out should be *func() auditlog.Config; got %T", out)
	}
}
// initialConfig reads the controller configuration from source and
// converts its audit log settings into an auditlog.Config. The Target
// field is left unset. Any error from reading the controller
// configuration is returned traced, together with a zero Config.
func initialConfig(source ConfigSource) (auditlog.Config, error) {
	controllerCfg, err := source.ControllerConfig()
	if err != nil {
		return auditlog.Config{}, errors.Trace(err)
	}

	var initial auditlog.Config
	initial.Enabled = controllerCfg.AuditingEnabled()
	initial.CaptureAPIArgs = controllerCfg.AuditLogCaptureArgs()
	initial.MaxSizeMB = controllerCfg.AuditLogMaxSizeMB()
	initial.MaxBackups = controllerCfg.AuditLogMaxBackups()
	initial.ExcludeMethods = controllerCfg.AuditLogExcludeMethods()
	return initial, nil
}