diff --git a/pkg/login1/login1.go b/pkg/login1/login1.go new file mode 100644 index 000000000..bc82bb73a --- /dev/null +++ b/pkg/login1/login1.go @@ -0,0 +1,61 @@ +// Package login1 is a small subset of github.com/coreos/go-systemd/v22/login1 package with +// ability to use shared D-Bus connection and with proper error handling for Reboot method call, which +// is not yet provided by the upstream. +package login1 + +import ( + "context" + "fmt" + + godbus "github.com/godbus/dbus/v5" +) + +const ( + // DBusDest is an object path used by systemd-logind. + DBusDest = "org.freedesktop.login1" + // DBusInterface is an systemd-logind intefrace name. + DBusInterface = "org.freedesktop.login1.Manager" + // DBusPath is a standard path to systemd-logind interface. + DBusPath = "/org/freedesktop/login1" + // DBusMethodNameReboot is a login1 manager interface method name responsible for rebooting. + DBusMethodNameReboot = "Reboot" +) + +// Client describes functionality of provided login1 client. +type Client interface { + Reboot(context.Context) error +} + +// Objector describes functionality required from given D-Bus connection. +type Objector interface { + Object(string, godbus.ObjectPath) godbus.BusObject +} + +// Caller describes required functionality from D-Bus object. +type Caller interface { + CallWithContext(ctx context.Context, method string, flags godbus.Flags, args ...interface{}) *godbus.Call +} + +type rebooter struct { + caller Caller +} + +// New creates new login1 client using given D-Bus connection. +func New(objector Objector) (Client, error) { + if objector == nil { + return nil, fmt.Errorf("no objector given") + } + + return &rebooter{ + caller: objector.Object(DBusDest, DBusPath), + }, nil +} + +// Reboot reboots machine on which it's called. +func (r *rebooter) Reboot(ctx context.Context) error { + if call := r.caller.CallWithContext(ctx, DBusInterface+"."+DBusMethodNameReboot, 0, false); call.Err != nil { + return fmt.Errorf("calling reboot: %w", call.Err) + } + + return nil +} diff --git a/pkg/login1/login1_test.go b/pkg/login1/login1_test.go new file mode 100644 index 000000000..5e0f97b7e --- /dev/null +++ b/pkg/login1/login1_test.go @@ -0,0 +1,164 @@ +package login1_test + +import ( + "context" + "errors" + "fmt" + "testing" + + godbus "github.com/godbus/dbus/v5" + + "github.com/flatcar-linux/flatcar-linux-update-operator/pkg/dbus" + "github.com/flatcar-linux/flatcar-linux-update-operator/pkg/login1" +) + +func Test_Creating_new_client(t *testing.T) { + t.Parallel() + + t.Run("connects_to_global_login1_path_and_interface", func(t *testing.T) { + t.Parallel() + + objectConstructorCalled := false + + connectionWithContextCheck := &dbus.MockConnection{ + ObjectF: func(dest string, path godbus.ObjectPath) godbus.BusObject { + objectConstructorCalled = true + + expectedDest := "org.freedesktop.login1" + + if dest != expectedDest { + t.Fatalf("Expected D-Bus destination %q, got %q", expectedDest, dest) + } + + expectedPath := godbus.ObjectPath("/org/freedesktop/login1") + + if path != expectedPath { + t.Fatalf("Expected D-Bus path %q, got %q", expectedPath, path) + } + + return nil + }, + } + + if _, err := login1.New(connectionWithContextCheck); err != nil { + t.Fatalf("Unexpected error creating client: %v", err) + } + + if !objectConstructorCalled { + t.Fatalf("Expected object constructor to be called") + } + }) + + t.Run("returns_error_when_no_objector_is_given", func(t *testing.T) { + t.Parallel() + + client, err := login1.New(nil) + if err == nil { + t.Fatalf("Expected error creating client with no connector") + } + + if client != nil { + t.Fatalf("Expected client to be nil when New returns error") + } + }) +} + +func Test_Rebooting(t *testing.T) { + t.Parallel() + + t.Run("calls_login1_reboot_method_on_manager_interface", func(t *testing.T) { + t.Parallel() + + rebootCalled := false + + connectionWithContextCheck := &dbus.MockConnection{ + ObjectF: func(string, godbus.ObjectPath) godbus.BusObject { + return &dbus.MockObject{ + CallWithContextF: func(ctx context.Context, method string, flags godbus.Flags, args ...interface{}) *godbus.Call { + rebootCalled = true + + expectedMethodName := "org.freedesktop.login1.Manager.Reboot" + + if method != expectedMethodName { + t.Fatalf("Expected method %q being called, got %q", expectedMethodName, method) + } + + return &godbus.Call{} + }, + } + }, + } + + client, err := login1.New(connectionWithContextCheck) + if err != nil { + t.Fatalf("Unexpected error creating client: %v", err) + } + + if err := client.Reboot(context.Background()); err != nil { + t.Fatalf("Unexpected error rebooting: %v", err) + } + + if !rebootCalled { + t.Fatalf("Expected reboot method call on given D-Bus connection") + } + }) + + t.Run("use_given_context_for_D-Bus_call", func(t *testing.T) { + t.Parallel() + + testKey := struct{}{} + expectedValue := "bar" + + ctx := context.WithValue(context.Background(), testKey, expectedValue) + + connectionWithContextCheck := &dbus.MockConnection{ + ObjectF: func(string, godbus.ObjectPath) godbus.BusObject { + return &dbus.MockObject{ + CallWithContextF: func(ctx context.Context, method string, flags godbus.Flags, args ...interface{}) *godbus.Call { + if val := ctx.Value(testKey); val != expectedValue { + t.Fatalf("Got unexpected context on call") + } + + return &godbus.Call{} + }, + } + }, + } + + client, err := login1.New(connectionWithContextCheck) + if err != nil { + t.Fatalf("Unexpected error creating client: %v", err) + } + + if err := client.Reboot(ctx); err != nil { + t.Fatalf("Unexpected error rebooting: %v", err) + } + }) + + t.Run("returns_error_when_D-Bus_call_fails", func(t *testing.T) { + t.Parallel() + + expectedError := fmt.Errorf("reboot error") + + connectionWithFailingObjectCall := &dbus.MockConnection{ + ObjectF: func(string, godbus.ObjectPath) godbus.BusObject { + return &dbus.MockObject{ + CallWithContextF: func(ctx context.Context, method string, flags godbus.Flags, args ...interface{}) *godbus.Call { + return &godbus.Call{ + Err: expectedError, + } + }, + } + }, + } + + client, err := login1.New(connectionWithFailingObjectCall) + if err != nil { + t.Fatalf("Unexpected error creating client: %v", err) + } + + if err := client.Reboot(context.Background()); !errors.Is(err, expectedError) { + t.Fatalf("Unexpected error rebooting: %v", err) + } + }) +}