diff --git a/.vscode/settings.json b/.vscode/settings.json index c405326..45ab0ed 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,4 @@ { - "go.formatTool": "goformat", "go.inferGopath": false, "go.autocompleteUnimportedPackages": true, "go.delveConfig": { diff --git a/dev/fmt b/dev/fmt index 977f880..69c24ee 100755 --- a/dev/fmt +++ b/dev/fmt @@ -2,8 +2,8 @@ cd "$(dirname "$0")" if [ "$(docker images -q neko_server_build 2> /dev/null)" == "" ]; then - echo "Image 'neko_server_build' not found. Run ./build first." - exit 1 + echo "Image 'neko_server_build' not found. Run ./build first." + exit 1 fi docker run -it --rm \ diff --git a/dev/go b/dev/go index 731d1d1..ae0acf1 100755 --- a/dev/go +++ b/dev/go @@ -2,8 +2,8 @@ cd "$(dirname "$0")" if [ "$(docker images -q neko_server_build 2> /dev/null)" == "" ]; then - echo "Image 'neko_server_build' not found. Run ./build first." - exit 1 + echo "Image 'neko_server_build' not found. Run ./build first." + exit 1 fi docker run -it \ diff --git a/dev/lint b/dev/lint index a7aa620..4124fda 100755 --- a/dev/lint +++ b/dev/lint @@ -2,8 +2,8 @@ cd "$(dirname "$0")" if [ "$(docker images -q neko_server_build 2> /dev/null)" == "" ]; then - echo "Image 'neko_server_build' not found. Run ./build first." - exit 1 + echo "Image 'neko_server_build' not found. Run ./build first." + exit 1 fi # diff --git a/dev/start b/dev/start index 514bd2f..210e61c 100755 --- a/dev/start +++ b/dev/start @@ -2,8 +2,8 @@ cd "$(dirname "$0")" if [ -z "$(docker images -q neko_server_app 2> /dev/null)" ]; then - echo "Image 'neko_server_app' not found. Running ./build first." - ./build + echo "Image 'neko_server_app' not found. Running ./build first." + ./build fi if [ -z $NEKO_PORT ]; then @@ -22,6 +22,10 @@ if [ -z $NEKO_NAT1TO1 ]; then fi done + if [ -z $NEKO_NAT1TO1 ]; then + NEKO_NAT1TO1=$(hostname -I 2>/dev/null | awk '{print $1}') + fi + if [ -z $NEKO_NAT1TO1 ]; then NEKO_NAT1TO1=$(hostname -i 2>/dev/null) fi diff --git a/internal/api/room/control.go b/internal/api/room/control.go index ddf0434..133af61 100644 --- a/internal/api/room/control.go +++ b/internal/api/room/control.go @@ -6,6 +6,8 @@ import ( "github.com/go-chi/chi" "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types/event" + "github.com/demodesk/neko/pkg/types/message" "github.com/demodesk/neko/pkg/utils" ) @@ -33,17 +35,26 @@ func (h *RoomHandler) controlStatus(w http.ResponseWriter, r *http.Request) erro } func (h *RoomHandler) controlRequest(w http.ResponseWriter, r *http.Request) error { - _, hasHost := h.sessions.GetHost() + session, _ := auth.GetSession(r) + host, hasHost := h.sessions.GetHost() if hasHost { - return utils.HttpUnprocessableEntity("there is already a host") + // TODO: Some throttling mechanism to prevent spamming. + + // let host know that someone wants to take control + host.Send( + event.CONTROL_REQUEST, + message.SessionID{ + ID: session.ID(), + }) + + return utils.HttpError(http.StatusAccepted, "control request sent") } - session, _ := auth.GetSession(r) if h.sessions.Settings().LockedControls && !session.Profile().IsAdmin { return utils.HttpForbidden("controls are locked") } - h.sessions.SetHost(session) + session.SetAsHost() return utils.HttpSuccess(w) } @@ -55,19 +66,20 @@ func (h *RoomHandler) controlRelease(w http.ResponseWriter, r *http.Request) err } h.desktop.ResetKeys() - h.sessions.ClearHost() + session.ClearHost() return utils.HttpSuccess(w) } func (h *RoomHandler) controlTake(w http.ResponseWriter, r *http.Request) error { session, _ := auth.GetSession(r) - h.sessions.SetHost(session) + session.SetAsHost() return utils.HttpSuccess(w) } func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) sessionId := chi.URLParam(r, "sessionId") target, ok := h.sessions.Get(sessionId) @@ -79,17 +91,18 @@ func (h *RoomHandler) controlGive(w http.ResponseWriter, r *http.Request) error return utils.HttpBadRequest("target session is not allowed to host") } - h.sessions.SetHost(target) + target.SetAsHostBy(session) return utils.HttpSuccess(w) } func (h *RoomHandler) controlReset(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) _, hasHost := h.sessions.GetHost() if hasHost { h.desktop.ResetKeys() - h.sessions.ClearHost() + session.ClearHost() } return utils.HttpSuccess(w) diff --git a/internal/api/room/handler.go b/internal/api/room/handler.go index ba33d91..8ca7aea 100644 --- a/internal/api/room/handler.go +++ b/internal/api/room/handler.go @@ -31,7 +31,7 @@ func New( } // generate fallback image for private mode when needed - sessions.OnSettingsChanged(func(new types.Settings, old types.Settings) { + sessions.OnSettingsChanged(func(session types.Session, new, old types.Settings) { if old.PrivateMode && !new.PrivateMode { log.Debug().Msg("clearing private mode fallback image") h.privateModeImage = nil diff --git a/internal/api/room/keyboard.go b/internal/api/room/keyboard.go index a8115c0..f3765f5 100644 --- a/internal/api/room/keyboard.go +++ b/internal/api/room/keyboard.go @@ -7,21 +7,13 @@ import ( "github.com/demodesk/neko/pkg/utils" ) -type KeyboardMapData struct { - types.KeyboardMap -} - -type KeyboardModifiersData struct { - types.KeyboardModifiers -} - func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) error { - data := &KeyboardMapData{} - if err := utils.HttpJsonRequest(w, r, data); err != nil { + keyboardMap := types.KeyboardMap{} + if err := utils.HttpJsonRequest(w, r, &keyboardMap); err != nil { return err } - err := h.desktop.SetKeyboardMap(data.KeyboardMap) + err := h.desktop.SetKeyboardMap(keyboardMap) if err != nil { return utils.HttpInternalServerError().WithInternalErr(err) } @@ -30,28 +22,26 @@ func (h *RoomHandler) keyboardMapSet(w http.ResponseWriter, r *http.Request) err } func (h *RoomHandler) keyboardMapGet(w http.ResponseWriter, r *http.Request) error { - data, err := h.desktop.GetKeyboardMap() + keyboardMap, err := h.desktop.GetKeyboardMap() if err != nil { return utils.HttpInternalServerError().WithInternalErr(err) } - return utils.HttpSuccess(w, KeyboardMapData{ - KeyboardMap: *data, - }) + return utils.HttpSuccess(w, keyboardMap) } func (h *RoomHandler) keyboardModifiersSet(w http.ResponseWriter, r *http.Request) error { - data := &KeyboardModifiersData{} - if err := utils.HttpJsonRequest(w, r, data); err != nil { + keyboardModifiers := types.KeyboardModifiers{} + if err := utils.HttpJsonRequest(w, r, &keyboardModifiers); err != nil { return err } - h.desktop.SetKeyboardModifiers(data.KeyboardModifiers) + h.desktop.SetKeyboardModifiers(keyboardModifiers) return utils.HttpSuccess(w) } func (h *RoomHandler) keyboardModifiersGet(w http.ResponseWriter, r *http.Request) error { - return utils.HttpSuccess(w, KeyboardModifiersData{ - KeyboardModifiers: h.desktop.GetKeyboardModifiers(), - }) + keyboardModifiers := h.desktop.GetKeyboardModifiers() + + return utils.HttpSuccess(w, keyboardModifiers) } diff --git a/internal/api/room/screen.go b/internal/api/room/screen.go index 174f347..12124bd 100644 --- a/internal/api/room/screen.go +++ b/internal/api/room/screen.go @@ -11,24 +11,16 @@ import ( "github.com/demodesk/neko/pkg/utils" ) -type ScreenConfigurationPayload struct { - Width int `json:"width"` - Height int `json:"height"` - Rate int16 `json:"rate"` -} - func (h *RoomHandler) screenConfiguration(w http.ResponseWriter, r *http.Request) error { - size := h.desktop.GetScreenSize() + screenSize := h.desktop.GetScreenSize() - return utils.HttpSuccess(w, ScreenConfigurationPayload{ - Width: size.Width, - Height: size.Height, - Rate: size.Rate, - }) + return utils.HttpSuccess(w, screenSize) } func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.Request) error { - data := &ScreenConfigurationPayload{} + auth, _ := auth.GetSession(r) + + data := &types.ScreenSize{} if err := utils.HttpJsonRequest(w, r, data); err != nil { return err } @@ -43,10 +35,9 @@ func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.R return utils.HttpUnprocessableEntity("cannot set screen size").WithInternalErr(err) } - h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSize{ - Width: size.Width, - Height: size.Height, - Rate: size.Rate, + h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSizeUpdate{ + ID: auth.ID(), + ScreenSize: size, }) return utils.HttpSuccess(w, data) @@ -56,16 +47,7 @@ func (h *RoomHandler) screenConfigurationChange(w http.ResponseWriter, r *http.R func (h *RoomHandler) screenConfigurationsList(w http.ResponseWriter, r *http.Request) error { configurations := h.desktop.ScreenConfigurations() - list := make([]ScreenConfigurationPayload, 0, len(configurations)) - for _, conf := range configurations { - list = append(list, ScreenConfigurationPayload{ - Width: conf.Width, - Height: conf.Height, - Rate: conf.Rate, - }) - } - - return utils.HttpSuccess(w, list) + return utils.HttpSuccess(w, configurations) } func (h *RoomHandler) screenShotGet(w http.ResponseWriter, r *http.Request) error { diff --git a/internal/api/room/settings.go b/internal/api/room/settings.go index 33e4155..c8e1816 100644 --- a/internal/api/room/settings.go +++ b/internal/api/room/settings.go @@ -1,8 +1,12 @@ package room import ( + "encoding/json" + "io" "net/http" + "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/utils" ) @@ -12,13 +16,23 @@ func (h *RoomHandler) settingsGet(w http.ResponseWriter, r *http.Request) error } func (h *RoomHandler) settingsSet(w http.ResponseWriter, r *http.Request) error { - settings := h.sessions.Settings() + session, _ := auth.GetSession(r) - if err := utils.HttpJsonRequest(w, r, &settings); err != nil { - return err + // We read the request body first and unmashal it inside the UpdateSettingsFunc + // to ensure atomicity of the operation. + body, err := io.ReadAll(r.Body) + if err != nil { + return utils.HttpBadRequest("unable to read request body").WithInternalErr(err) } - h.sessions.UpdateSettings(settings) + h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool { + err = json.Unmarshal(body, settings) + return err == nil + }) + + if err != nil { + return utils.HttpBadRequest("unable to parse provided data").WithInternalErr(err) + } return utils.HttpSuccess(w) } diff --git a/internal/api/router.go b/internal/api/router.go index dbe877a..575c08e 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -7,6 +7,7 @@ import ( "github.com/demodesk/neko/internal/api/members" "github.com/demodesk/neko/internal/api/room" + "github.com/demodesk/neko/internal/api/sessions" "github.com/demodesk/neko/pkg/auth" "github.com/demodesk/neko/pkg/types" "github.com/demodesk/neko/pkg/utils" @@ -45,7 +46,9 @@ func (api *ApiManagerCtx) Route(r types.Router) { r.Post("/logout", api.Logout) r.Get("/whoami", api.Whoami) - r.Get("/sessions", api.Sessions) + + sessionsHandler := sessions.New(api.sessions) + r.Route("/sessions", sessionsHandler.Route) membersHandler := members.New(api.members) r.Route("/members", membersHandler.Route) diff --git a/internal/api/session.go b/internal/api/session.go index e1991ad..5f15316 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -33,6 +33,8 @@ func (api *ApiManagerCtx) Login(w http.ResponseWriter, r *http.Request) error { return utils.HttpUnprocessableEntity("session already connected") } else if errors.Is(err, types.ErrMemberDoesNotExist) || errors.Is(err, types.ErrMemberInvalidPassword) { return utils.HttpUnauthorized().WithInternalErr(err) + } else if errors.Is(err, types.ErrSessionLoginsLocked) { + return utils.HttpForbidden("logins are locked").WithInternalErr(err) } else { return utils.HttpInternalServerError().WithInternalErr(err) } @@ -81,16 +83,3 @@ func (api *ApiManagerCtx) Whoami(w http.ResponseWriter, r *http.Request) error { State: session.State(), }) } - -func (api *ApiManagerCtx) Sessions(w http.ResponseWriter, r *http.Request) error { - sessions := []SessionDataPayload{} - for _, session := range api.sessions.List() { - sessions = append(sessions, SessionDataPayload{ - ID: session.ID(), - Profile: session.Profile(), - State: session.State(), - }) - } - - return utils.HttpSuccess(w, sessions) -} diff --git a/internal/api/sessions/controller.go b/internal/api/sessions/controller.go new file mode 100644 index 0000000..c4ce44c --- /dev/null +++ b/internal/api/sessions/controller.go @@ -0,0 +1,80 @@ +package sessions + +import ( + "errors" + "net/http" + + "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/utils" + "github.com/go-chi/chi" +) + +type SessionDataPayload struct { + ID string `json:"id"` + Profile types.MemberProfile `json:"profile"` + State types.SessionState `json:"state"` +} + +func (h *SessionsHandler) sessionsList(w http.ResponseWriter, r *http.Request) error { + sessions := []SessionDataPayload{} + for _, session := range h.sessions.List() { + sessions = append(sessions, SessionDataPayload{ + ID: session.ID(), + Profile: session.Profile(), + State: session.State(), + }) + } + + return utils.HttpSuccess(w, sessions) +} + +func (h *SessionsHandler) sessionsRead(w http.ResponseWriter, r *http.Request) error { + sessionId := chi.URLParam(r, "sessionId") + + session, ok := h.sessions.Get(sessionId) + if !ok { + return utils.HttpNotFound("session not found") + } + + return utils.HttpSuccess(w, SessionDataPayload{ + ID: session.ID(), + Profile: session.Profile(), + State: session.State(), + }) +} + +func (h *SessionsHandler) sessionsDelete(w http.ResponseWriter, r *http.Request) error { + session, _ := auth.GetSession(r) + + sessionId := chi.URLParam(r, "sessionId") + if sessionId == session.ID() { + return utils.HttpBadRequest("cannot delete own session") + } + + err := h.sessions.Delete(sessionId) + if err != nil { + if errors.Is(err, types.ErrSessionNotFound) { + return utils.HttpBadRequest("session not found") + } else { + return utils.HttpInternalServerError().WithInternalErr(err) + } + } + + return utils.HttpSuccess(w) +} + +func (h *SessionsHandler) sessionsDisconnect(w http.ResponseWriter, r *http.Request) error { + sessionId := chi.URLParam(r, "sessionId") + + err := h.sessions.Disconnect(sessionId) + if err != nil { + if errors.Is(err, types.ErrSessionNotFound) { + return utils.HttpBadRequest("session not found") + } else { + return utils.HttpInternalServerError().WithInternalErr(err) + } + } + + return utils.HttpSuccess(w) +} diff --git a/internal/api/sessions/handler.go b/internal/api/sessions/handler.go new file mode 100644 index 0000000..5f5b771 --- /dev/null +++ b/internal/api/sessions/handler.go @@ -0,0 +1,30 @@ +package sessions + +import ( + "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types" +) + +type SessionsHandler struct { + sessions types.SessionManager +} + +func New( + sessions types.SessionManager, +) *SessionsHandler { + // Init + + return &SessionsHandler{ + sessions: sessions, + } +} + +func (h *SessionsHandler) Route(r types.Router) { + r.Get("/", h.sessionsList) + + r.With(auth.AdminsOnly).Route("/{sessionId}", func(r types.Router) { + r.Get("/", h.sessionsRead) + r.Delete("/", h.sessionsDelete) + r.Post("/disconnect", h.sessionsDisconnect) + }) +} diff --git a/internal/config/session.go b/internal/config/session.go index 6ab4776..443c02d 100644 --- a/internal/config/session.go +++ b/internal/config/session.go @@ -11,7 +11,9 @@ type Session struct { File string PrivateMode bool + LockedLogins bool LockedControls bool + ControlProtection bool ImplicitHosting bool InactiveCursors bool MercifulReconnect bool @@ -34,11 +36,21 @@ func (Session) Init(cmd *cobra.Command) error { return err } + cmd.PersistentFlags().Bool("session.locked_logins", false, "whether logins should be locked for users initially") + if err := viper.BindPFlag("session.locked_logins", cmd.PersistentFlags().Lookup("session.locked_logins")); err != nil { + return err + } + cmd.PersistentFlags().Bool("session.locked_controls", false, "whether controls should be locked for users initially") if err := viper.BindPFlag("session.locked_controls", cmd.PersistentFlags().Lookup("session.locked_controls")); err != nil { return err } + cmd.PersistentFlags().Bool("session.control_protection", false, "users can gain control only if at least one admin is in the room") + if err := viper.BindPFlag("session.control_protection", cmd.PersistentFlags().Lookup("session.control_protection")); err != nil { + return err + } + cmd.PersistentFlags().Bool("session.implicit_hosting", true, "allow implicit control switching") if err := viper.BindPFlag("session.implicit_hosting", cmd.PersistentFlags().Lookup("session.implicit_hosting")); err != nil { return err @@ -87,7 +99,9 @@ func (s *Session) Set() { s.File = viper.GetString("session.file") s.PrivateMode = viper.GetBool("session.private_mode") + s.LockedLogins = viper.GetBool("session.locked_logins") s.LockedControls = viper.GetBool("session.locked_controls") + s.ControlProtection = viper.GetBool("session.control_protection") s.ImplicitHosting = viper.GetBool("session.implicit_hosting") s.InactiveCursors = viper.GetBool("session.inactive_cursors") s.MercifulReconnect = viper.GetBool("session.merciful_reconnect") diff --git a/internal/member/manager.go b/internal/member/manager.go index 3538657..ccf32b2 100644 --- a/internal/member/manager.go +++ b/internal/member/manager.go @@ -141,6 +141,10 @@ func (manager *MemberManagerCtx) Login(username string, password string) (types. return nil, "", err } + if !profile.IsAdmin && manager.sessions.Settings().LockedLogins { + return nil, "", types.ErrSessionLoginsLocked + } + session, ok := manager.sessions.Get(id) if ok { if session.State().IsConnected { diff --git a/internal/plugins/chat/config.go b/internal/plugins/chat/config.go new file mode 100644 index 0000000..dd24835 --- /dev/null +++ b/internal/plugins/chat/config.go @@ -0,0 +1,23 @@ +package chat + +import ( + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +type Config struct { + Enabled bool +} + +func (Config) Init(cmd *cobra.Command) error { + cmd.PersistentFlags().Bool("chat.enabled", true, "whether to enable chat plugin") + if err := viper.BindPFlag("chat.enabled", cmd.PersistentFlags().Lookup("chat.enabled")); err != nil { + return err + } + + return nil +} + +func (s *Config) Set() { + s.Enabled = viper.GetBool("chat.enabled") +} diff --git a/internal/plugins/chat/manager.go b/internal/plugins/chat/manager.go new file mode 100644 index 0000000..93e3c1b --- /dev/null +++ b/internal/plugins/chat/manager.go @@ -0,0 +1,162 @@ +package chat + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + + "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/utils" +) + +func NewManager( + sessions types.SessionManager, + config *Config, +) *Manager { + logger := log.With().Str("module", "chat").Logger() + + return &Manager{ + logger: logger, + config: config, + sessions: sessions, + } +} + +type Manager struct { + logger zerolog.Logger + config *Config + sessions types.SessionManager +} + +type Settings struct { + CanSend bool `json:"can_send" mapstructure:"can_send"` + CanReceive bool `json:"can_receive" mapstructure:"can_receive"` +} + +func (m *Manager) settingsForSession(session types.Session) (Settings, error) { + settings := Settings{ + CanSend: true, // defaults to true + CanReceive: true, // defaults to true + } + err := m.sessions.Settings().Plugins.Unmarshal(PluginName, &settings) + if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) { + return Settings{}, fmt.Errorf("unable to unmarshal %s plugin settings from global settings: %w", PluginName, err) + } + + profile := Settings{ + CanSend: true, // defaults to true + CanReceive: true, // defaults to true + } + + err = session.Profile().Plugins.Unmarshal(PluginName, &profile) + if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) { + return Settings{}, fmt.Errorf("unable to unmarshal %s plugin settings from profile: %w", PluginName, err) + } + + return Settings{ + CanSend: m.config.Enabled && (settings.CanSend || session.Profile().IsAdmin) && profile.CanSend, + CanReceive: m.config.Enabled && (settings.CanReceive || session.Profile().IsAdmin) && profile.CanReceive, + }, nil +} + +func (m *Manager) sendMessage(session types.Session, content Content) { + now := time.Now() + + // get all sessions that have chat enabled + var sessions []types.Session + m.sessions.Range(func(s types.Session) bool { + if settings, err := m.settingsForSession(s); err == nil && settings.CanReceive { + sessions = append(sessions, s) + } + // continue iteration over all sessions + return true + }) + + // send content to all sessions + for _, s := range sessions { + s.Send(CHAT_MESSAGE, Message{ + ID: session.ID(), + Created: now, + Content: content, + }) + } +} + +func (m *Manager) Start() error { + // send init message once a user connects + m.sessions.OnConnected(func(session types.Session) { + session.Send(CHAT_INIT, Init{ + Enabled: m.config.Enabled, + }) + }) + + return nil +} + +func (m *Manager) Shutdown() error { + return nil +} + +func (m *Manager) Route(r types.Router) { + r.With(auth.AdminsOnly).Post("/", m.sendMessageHandler) +} + +func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool { + switch msg.Event { + case CHAT_MESSAGE: + var content Content + if err := json.Unmarshal(msg.Payload, &content); err != nil { + m.logger.Error().Err(err).Msg("failed to unmarshal chat message") + // we processed the message, return true + return true + } + + settings, err := m.settingsForSession(session) + if err != nil { + m.logger.Error().Err(err).Msg("error checking chat permissions for this session") + // we processed the message, return true + return true + } + if !settings.CanSend { + m.logger.Warn().Msg("not allowed to send chat messages") + // we processed the message, return true + return true + } + + m.sendMessage(session, content) + return true + } + return false +} + +func (m *Manager) sendMessageHandler(w http.ResponseWriter, r *http.Request) error { + session, ok := auth.GetSession(r) + if !ok { + return utils.HttpUnauthorized("session not found") + } + + settings, err := m.settingsForSession(session) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error checking chat permissions for this session") + } + + if !settings.CanSend { + return utils.HttpForbidden("not allowed to send chat messages") + } + + content := Content{} + if err := utils.HttpJsonRequest(w, r, &content); err != nil { + return err + } + + m.sendMessage(session, content) + return utils.HttpSuccess(w) +} diff --git a/internal/plugins/chat/plugin.go b/internal/plugins/chat/plugin.go new file mode 100644 index 0000000..d4bc946 --- /dev/null +++ b/internal/plugins/chat/plugin.go @@ -0,0 +1,35 @@ +package chat + +import ( + "github.com/demodesk/neko/pkg/types" +) + +type Plugin struct { + config *Config + manager *Manager +} + +func NewPlugin() *Plugin { + return &Plugin{ + config: &Config{}, + } +} + +func (p *Plugin) Name() string { + return PluginName +} + +func (p *Plugin) Config() types.PluginConfig { + return p.config +} + +func (p *Plugin) Start(m types.PluginManagers) error { + p.manager = NewManager(m.SessionManager, p.config) + m.ApiManager.AddRouter("/chat", p.manager.Route) + m.WebSocketManager.AddHandler(p.manager.WebSocketHandler) + return p.manager.Start() +} + +func (p *Plugin) Shutdown() error { + return p.manager.Shutdown() +} diff --git a/internal/plugins/chat/types.go b/internal/plugins/chat/types.go new file mode 100644 index 0000000..33e9d11 --- /dev/null +++ b/internal/plugins/chat/types.go @@ -0,0 +1,24 @@ +package chat + +import "time" + +const PluginName = "chat" + +const ( + CHAT_INIT = "chat/init" + CHAT_MESSAGE = "chat/message" +) + +type Init struct { + Enabled bool `json:"enabled"` +} + +type Content struct { + Text string `json:"text"` +} + +type Message struct { + ID string `json:"id"` + Created time.Time `json:"created"` + Content Content `json:"content"` +} diff --git a/internal/plugins/filetransfer/config.go b/internal/plugins/filetransfer/config.go new file mode 100644 index 0000000..593d31f --- /dev/null +++ b/internal/plugins/filetransfer/config.go @@ -0,0 +1,41 @@ +package filetransfer + +import ( + "path/filepath" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +type Config struct { + Enabled bool + RootDir string + RefreshInterval time.Duration +} + +func (Config) Init(cmd *cobra.Command) error { + cmd.PersistentFlags().Bool("filetransfer.enabled", false, "whether file transfer is enabled") + if err := viper.BindPFlag("filetransfer.enabled", cmd.PersistentFlags().Lookup("filetransfer.enabled")); err != nil { + return err + } + + cmd.PersistentFlags().String("filetransfer.dir", "/home/neko/Downloads", "root directory for file transfer") + if err := viper.BindPFlag("filetransfer.dir", cmd.PersistentFlags().Lookup("filetransfer.dir")); err != nil { + return err + } + + cmd.PersistentFlags().Duration("filetransfer.refresh_interval", 30*time.Second, "interval to refresh file list") + if err := viper.BindPFlag("filetransfer.refresh_interval", cmd.PersistentFlags().Lookup("filetransfer.refresh_interval")); err != nil { + return err + } + + return nil +} + +func (s *Config) Set() { + s.Enabled = viper.GetBool("filetransfer.enabled") + rootDir := viper.GetString("filetransfer.dir") + s.RootDir = filepath.Clean(rootDir) + s.RefreshInterval = viper.GetDuration("filetransfer.refresh_interval") +} diff --git a/internal/plugins/filetransfer/manager.go b/internal/plugins/filetransfer/manager.go new file mode 100644 index 0000000..0a6603e --- /dev/null +++ b/internal/plugins/filetransfer/manager.go @@ -0,0 +1,332 @@ +package filetransfer + +import ( + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "regexp" + "sync" + "time" + + "github.com/demodesk/neko/pkg/auth" + "github.com/demodesk/neko/pkg/types" + "github.com/demodesk/neko/pkg/utils" + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" +) + +const MULTIPART_FORM_MAX_MEMORY = 32 << 20 + +func NewManager( + sessions types.SessionManager, + config *Config, +) *Manager { + logger := log.With().Str("module", "filetransfer").Logger() + + return &Manager{ + logger: logger, + config: config, + sessions: sessions, + shutdown: make(chan struct{}), + } +} + +type Manager struct { + logger zerolog.Logger + config *Config + sessions types.SessionManager + shutdown chan struct{} + mu sync.RWMutex + fileList []Item +} + +func (m *Manager) isEnabledForSession(session types.Session) (bool, error) { + settings := Settings{ + Enabled: true, // defaults to true + } + err := m.sessions.Settings().Plugins.Unmarshal(PluginName, &settings) + if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) { + return false, fmt.Errorf("unable to unmarshal %s plugin settings from global settings: %w", PluginName, err) + } + + profile := Settings{ + Enabled: true, // defaults to true + } + + err = session.Profile().Plugins.Unmarshal(PluginName, &profile) + if err != nil && !errors.Is(err, types.ErrPluginSettingsNotFound) { + return false, fmt.Errorf("unable to unmarshal %s plugin settings from profile: %w", PluginName, err) + } + + return m.config.Enabled && (settings.Enabled || session.Profile().IsAdmin) && profile.Enabled, nil +} + +func (m *Manager) refresh() (error, bool) { + // if file transfer is disabled, return immediately without refreshing + if !m.config.Enabled { + return nil, false + } + + files, err := ListFiles(m.config.RootDir) + if err != nil { + return err, false + } + + m.mu.Lock() + defer m.mu.Unlock() + + // check if file list has changed (todo: use hash instead of comparing all fields) + changed := false + if len(files) == len(m.fileList) { + for i, file := range files { + if file.Name != m.fileList[i].Name || file.Size != m.fileList[i].Size { + changed = true + break + } + } + } else { + changed = true + } + + m.fileList = files + return nil, changed +} + +func (m *Manager) broadcastUpdate() { + m.mu.RLock() + fileList := m.fileList + m.mu.RUnlock() + + m.sessions.Broadcast(FILETRANSFER_UPDATE, Message{ + Enabled: m.config.Enabled, + RootDir: m.config.RootDir, + Files: fileList, + }) +} + +func (m *Manager) sendUpdate(session types.Session) { + m.mu.RLock() + fileList := m.fileList + m.mu.RUnlock() + + session.Send(FILETRANSFER_UPDATE, Message{ + Enabled: m.config.Enabled, + RootDir: m.config.RootDir, + Files: fileList, + }) +} + +func (m *Manager) Start() error { + // send init message once a user connects + m.sessions.OnConnected(func(session types.Session) { + m.sendUpdate(session) + }) + + // if file transfer is disabled, return immediately without starting the watcher + if !m.config.Enabled { + return nil + } + + if _, err := os.Stat(m.config.RootDir); os.IsNotExist(err) { + err = os.Mkdir(m.config.RootDir, os.ModePerm) + m.logger.Err(err).Msg("creating file transfer directory") + } + + watcher, err := fsnotify.NewWatcher() + if err != nil { + return fmt.Errorf("unable to start file transfer dir watcher: %w", err) + } + + go func() { + defer watcher.Close() + + // periodically refresh file list + ticker := time.NewTicker(m.config.RefreshInterval) + defer ticker.Stop() + + for { + select { + case <-m.shutdown: + m.logger.Info().Msg("shutting down file transfer manager") + return + case <-ticker.C: + err, changed := m.refresh() + if err != nil { + m.logger.Err(err).Msg("unable to refresh file transfer list") + } + if changed { + m.broadcastUpdate() + } + case e, ok := <-watcher.Events: + if !ok { + m.logger.Info().Msg("file transfer dir watcher closed") + return + } + + if e.Has(fsnotify.Create) || e.Has(fsnotify.Remove) || e.Has(fsnotify.Rename) { + m.logger.Debug().Str("event", e.String()).Msg("file transfer dir watcher event") + + err, changed := m.refresh() + if err != nil { + m.logger.Err(err).Msg("unable to refresh file transfer list") + } + + if changed { + m.broadcastUpdate() + } + } + case err := <-watcher.Errors: + m.logger.Err(err).Msg("error in file transfer dir watcher") + } + } + }() + + if err := watcher.Add(m.config.RootDir); err != nil { + return fmt.Errorf("unable to watch file transfer dir: %w", err) + } + + // initial refresh + err, changed := m.refresh() + if err != nil { + return fmt.Errorf("unable to refresh file transfer list: %w", err) + } + if changed { + m.broadcastUpdate() + } + + return nil +} + +func (m *Manager) Shutdown() error { + close(m.shutdown) + return nil +} + +func (m *Manager) Route(r types.Router) { + r.With(auth.AdminsOnly).Get("/", m.downloadFileHandler) + r.With(auth.AdminsOnly).Post("/", m.uploadFileHandler) +} + +func (m *Manager) WebSocketHandler(session types.Session, msg types.WebSocketMessage) bool { + switch msg.Event { + case FILETRANSFER_UPDATE: + err, changed := m.refresh() + if err != nil { + m.logger.Err(err).Msg("unable to refresh file transfer list") + } + + if changed { + // broadcast update message to all clients + m.broadcastUpdate() + } else { + // send update message to this client only + m.sendUpdate(session) + } + return true + } + + // not handled by this plugin + return false +} + +func (m *Manager) downloadFileHandler(w http.ResponseWriter, r *http.Request) error { + session, ok := auth.GetSession(r) + if !ok { + return utils.HttpUnauthorized("session not found") + } + + enabled, err := m.isEnabledForSession(session) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error checking file transfer permissions") + } + + if !enabled { + return utils.HttpForbidden("file transfer is disabled") + } + + filename := r.URL.Query().Get("filename") + badChars, err := regexp.MatchString(`(?m)\.\.(?:\/|$)`, filename) + if filename == "" || badChars || err != nil { + return utils.HttpBadRequest(). + WithInternalErr(err). + Msg("bad filename") + } + + // ensure filename is clean and only contains the basename + filename = filepath.Clean(filename) + filename = filepath.Base(filename) + filePath := filepath.Join(m.config.RootDir, filename) + + http.ServeFile(w, r, filePath) + return nil +} + +func (m *Manager) uploadFileHandler(w http.ResponseWriter, r *http.Request) error { + session, ok := auth.GetSession(r) + if !ok { + return utils.HttpUnauthorized("session not found") + } + + enabled, err := m.isEnabledForSession(session) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error checking file transfer permissions") + } + + if !enabled { + return utils.HttpForbidden("file transfer is disabled") + } + + err = r.ParseMultipartForm(MULTIPART_FORM_MAX_MEMORY) + if err != nil || r.MultipartForm == nil { + return utils.HttpBadRequest(). + WithInternalErr(err). + Msg("error parsing form") + } + + defer func() { + err = r.MultipartForm.RemoveAll() + if err != nil { + m.logger.Warn().Err(err).Msg("failed to clean up multipart form") + } + }() + + for _, formheader := range r.MultipartForm.File["files"] { + // ensure filename is clean and only contains the basename + filename := filepath.Clean(formheader.Filename) + filename = filepath.Base(filename) + filePath := filepath.Join(m.config.RootDir, filename) + + formfile, err := formheader.Open() + if err != nil { + return utils.HttpBadRequest(). + WithInternalErr(err). + Msg("error opening formdata file") + } + defer formfile.Close() + + f, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error opening file for writing") + } + defer f.Close() + + _, err = io.Copy(f, formfile) + if err != nil { + return utils.HttpInternalServerError(). + WithInternalErr(err). + Msg("error writing file") + } + } + + return nil +} diff --git a/internal/plugins/filetransfer/plugin.go b/internal/plugins/filetransfer/plugin.go new file mode 100644 index 0000000..a98672a --- /dev/null +++ b/internal/plugins/filetransfer/plugin.go @@ -0,0 +1,35 @@ +package filetransfer + +import ( + "github.com/demodesk/neko/pkg/types" +) + +type Plugin struct { + config *Config + manager *Manager +} + +func NewPlugin() *Plugin { + return &Plugin{ + config: &Config{}, + } +} + +func (p *Plugin) Name() string { + return PluginName +} + +func (p *Plugin) Config() types.PluginConfig { + return p.config +} + +func (p *Plugin) Start(m types.PluginManagers) error { + p.manager = NewManager(m.SessionManager, p.config) + m.ApiManager.AddRouter("/filetransfer", p.manager.Route) + m.WebSocketManager.AddHandler(p.manager.WebSocketHandler) + return p.manager.Start() +} + +func (p *Plugin) Shutdown() error { + return p.manager.Shutdown() +} diff --git a/internal/plugins/filetransfer/types.go b/internal/plugins/filetransfer/types.go new file mode 100644 index 0000000..32748d4 --- /dev/null +++ b/internal/plugins/filetransfer/types.go @@ -0,0 +1,30 @@ +package filetransfer + +const PluginName = "filetransfer" + +type Settings struct { + Enabled bool `json:"enabled" mapstructure:"enabled"` +} + +const ( + FILETRANSFER_UPDATE = "filetransfer/update" +) + +type Message struct { + Enabled bool `json:"enabled"` + RootDir string `json:"root_dir"` + Files []Item `json:"files"` +} + +type ItemType string + +const ( + ItemTypeFile ItemType = "file" + ItemTypeDir ItemType = "dir" +) + +type Item struct { + Name string `json:"name"` + Type ItemType `json:"type"` + Size int64 `json:"size,omitempty"` +} diff --git a/internal/plugins/filetransfer/utils.go b/internal/plugins/filetransfer/utils.go new file mode 100644 index 0000000..c4c828a --- /dev/null +++ b/internal/plugins/filetransfer/utils.go @@ -0,0 +1,32 @@ +package filetransfer + +import "os" + +func ListFiles(path string) ([]Item, error) { + items, err := os.ReadDir(path) + if err != nil { + return nil, err + } + + out := make([]Item, len(items)) + for i, item := range items { + var itemType ItemType + var size int64 = 0 + if item.IsDir() { + itemType = ItemTypeDir + } else { + itemType = ItemTypeFile + info, err := item.Info() + if err == nil { + size = info.Size() + } + } + out[i] = Item{ + Name: item.Name(), + Type: itemType, + Size: size, + } + } + + return out, nil +} diff --git a/internal/plugins/manager.go b/internal/plugins/manager.go index 79d8cc2..24ada08 100644 --- a/internal/plugins/manager.go +++ b/internal/plugins/manager.go @@ -11,6 +11,8 @@ import ( "github.com/spf13/cobra" "github.com/demodesk/neko/internal/config" + "github.com/demodesk/neko/internal/plugins/chat" + "github.com/demodesk/neko/internal/plugins/filetransfer" "github.com/demodesk/neko/pkg/types" ) @@ -42,6 +44,10 @@ func New(config *config.Plugins) *ManagerCtx { manager.logger.Info().Msgf("loading finished, total %d plugins", manager.plugins.len()) } + // add built-in plugins + manager.plugins.addPlugin(filetransfer.NewPlugin()) + manager.plugins.addPlugin(chat.NewPlugin()) + return manager } diff --git a/internal/session/manager.go b/internal/session/manager.go index df85345..db6aaa8 100644 --- a/internal/session/manager.go +++ b/internal/session/manager.go @@ -20,7 +20,9 @@ func New(config *config.Session) *SessionManagerCtx { config: config, settings: types.Settings{ PrivateMode: config.PrivateMode, - LockedControls: config.LockedControls, + LockedLogins: config.LockedLogins, + LockedControls: config.LockedControls || config.ControlProtection, + ControlProtection: config.ControlProtection, ImplicitHosting: config.ImplicitHosting, InactiveCursors: config.InactiveCursors, MercifulReconnect: config.MercifulReconnect, @@ -120,10 +122,11 @@ func (manager *SessionManagerCtx) Update(id string, profile types.MemberProfile) return types.ErrSessionNotFound } + old := session.profile session.profile = profile manager.sessionsMu.Unlock() - manager.emmiter.Emit("profile_changed", session) + manager.emmiter.Emit("profile_changed", session, profile, old) manager.save() session.profileChanged() @@ -156,6 +159,26 @@ func (manager *SessionManagerCtx) Delete(id string) error { return nil } +func (manager *SessionManagerCtx) Disconnect(id string) error { + manager.sessionsMu.Lock() + session, ok := manager.sessions[id] + if !ok { + manager.sessionsMu.Unlock() + return types.ErrSessionNotFound + } + manager.sessionsMu.Unlock() + + if session.State().IsConnected { + session.DestroyWebSocketPeer("session disconnected") + } + + if session.State().IsWatching { + session.GetWebRTCPeer().Destroy() + } + + return nil +} + func (manager *SessionManagerCtx) Get(id string) (types.Session, bool) { manager.sessionsMu.Lock() defer manager.sessionsMu.Unlock() @@ -193,18 +216,29 @@ func (manager *SessionManagerCtx) List() []types.Session { return sessions } +func (manager *SessionManagerCtx) Range(f func(session types.Session) bool) { + manager.sessionsMu.Lock() + defer manager.sessionsMu.Unlock() + + for _, session := range manager.sessions { + if !f(session) { + return + } + } +} + // --- // host // --- -func (manager *SessionManagerCtx) SetHost(host types.Session) { +func (manager *SessionManagerCtx) setHost(session, host types.Session) { var hostId string if host != nil { hostId = host.ID() } manager.hostId.Store(hostId) - manager.emmiter.Emit("host_changed", host) + manager.emmiter.Emit("host_changed", session, host) } func (manager *SessionManagerCtx) GetHost() (types.Session, bool) { @@ -216,10 +250,6 @@ func (manager *SessionManagerCtx) GetHost() (types.Session, bool) { return manager.Get(hostId) } -func (manager *SessionManagerCtx) ClearHost() { - manager.SetHost(nil) -} - func (manager *SessionManagerCtx) isHost(host types.Session) bool { hostId, ok := manager.hostId.Load().(string) return ok && hostId == host.ID() @@ -332,9 +362,9 @@ func (manager *SessionManagerCtx) OnDisconnected(listener func(session types.Ses }) } -func (manager *SessionManagerCtx) OnProfileChanged(listener func(session types.Session)) { +func (manager *SessionManagerCtx) OnProfileChanged(listener func(session types.Session, new, old types.MemberProfile)) { manager.emmiter.On("profile_changed", func(payload ...any) { - listener(payload[0].(*SessionCtx)) + listener(payload[0].(*SessionCtx), payload[1].(types.MemberProfile), payload[2].(types.MemberProfile)) }) } @@ -344,19 +374,19 @@ func (manager *SessionManagerCtx) OnStateChanged(listener func(session types.Ses }) } -func (manager *SessionManagerCtx) OnHostChanged(listener func(session types.Session)) { +func (manager *SessionManagerCtx) OnHostChanged(listener func(session, host types.Session)) { manager.emmiter.On("host_changed", func(payload ...any) { - if payload[0] == nil { - listener(nil) + if payload[1] == nil { + listener(payload[0].(*SessionCtx), nil) } else { - listener(payload[0].(*SessionCtx)) + listener(payload[0].(*SessionCtx), payload[1].(*SessionCtx)) } }) } -func (manager *SessionManagerCtx) OnSettingsChanged(listener func(new types.Settings, old types.Settings)) { +func (manager *SessionManagerCtx) OnSettingsChanged(listener func(session types.Session, new, old types.Settings)) { manager.emmiter.On("settings_changed", func(payload ...any) { - listener(payload[0].(types.Settings), payload[1].(types.Settings)) + listener(payload[0].(types.Session), payload[1].(types.Settings), payload[2].(types.Settings)) }) } @@ -364,40 +394,68 @@ func (manager *SessionManagerCtx) OnSettingsChanged(listener func(new types.Sett // settings // --- -func (manager *SessionManagerCtx) UpdateSettings(new types.Settings) { +func (manager *SessionManagerCtx) UpdateSettingsFunc(session types.Session, f func(settings *types.Settings) bool) { manager.settingsMu.Lock() - old := manager.settings - manager.settings = new + new := manager.settings + if f(&new) { + old := manager.settings + manager.settings = new + manager.settingsMu.Unlock() + manager.updateSettings(session, new, old) + return + } manager.settingsMu.Unlock() +} +func (manager *SessionManagerCtx) updateSettings(session types.Session, new, old types.Settings) { // if private mode changed if old.PrivateMode != new.PrivateMode { // update webrtc paused state for all sessions - for _, session := range manager.List() { - enabled := session.PrivateModeEnabled() + for _, s := range manager.List() { + enabled := s.PrivateModeEnabled() // if session had control, it must release it - if enabled && session.IsHost() { - manager.ClearHost() + if enabled && s.IsHost() { + session.ClearHost() } // its webrtc connection will be paused or unpaused - if webrtcPeer := session.GetWebRTCPeer(); webrtcPeer != nil { + if webrtcPeer := s.GetWebRTCPeer(); webrtcPeer != nil { webrtcPeer.SetPaused(enabled) } } } + // if control protection changed and controls are not locked + if old.ControlProtection != new.ControlProtection && new.ControlProtection && !new.LockedControls { + // if there is no admin, lock controls + hasAdmin := false + manager.Range(func(session types.Session) bool { + if session.Profile().IsAdmin && session.State().IsConnected { + hasAdmin = true + return false + } + return true + }) + + if !hasAdmin { + manager.settingsMu.Lock() + manager.settings.LockedControls = true + new.LockedControls = true + manager.settingsMu.Unlock() + } + } + // if contols have been locked if old.LockedControls != new.LockedControls && new.LockedControls { // if the host is not admin, it must release controls host, hasHost := manager.GetHost() if hasHost && !host.Profile().IsAdmin { - manager.ClearHost() + session.ClearHost() } } - manager.emmiter.Emit("settings_changed", new, old) + manager.emmiter.Emit("settings_changed", session, new, old) } func (manager *SessionManagerCtx) Settings() types.Settings { diff --git a/internal/session/session.go b/internal/session/session.go index 716f111..2a858f6 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -43,7 +43,7 @@ func (session *SessionCtx) Profile() types.MemberProfile { func (session *SessionCtx) profileChanged() { if !session.profile.CanHost && session.IsHost() { - session.manager.ClearHost() + session.ClearHost() } if (!session.profile.CanConnect || !session.profile.CanLogin || !session.profile.CanWatch) && session.state.IsWatching { @@ -68,6 +68,18 @@ func (session *SessionCtx) IsHost() bool { return session.manager.isHost(session) } +func (session *SessionCtx) SetAsHost() { + session.manager.setHost(session, session) +} + +func (session *SessionCtx) SetAsHostBy(host types.Session) { + session.manager.setHost(session, host) +} + +func (session *SessionCtx) ClearHost() { + session.manager.setHost(session, nil) +} + func (session *SessionCtx) PrivateModeEnabled() bool { return session.manager.Settings().PrivateMode && !session.profile.IsAdmin } @@ -82,10 +94,8 @@ func (session *SessionCtx) SetCursor(cursor types.Cursor) { // websocket // --- -// // Connect WebSocket peer sets current peer and emits connected event. It also destroys the // previous peer, if there was one. If the peer is already set, it will be ignored. -// func (session *SessionCtx) ConnectWebSocketPeer(websocketPeer types.WebSocketPeer) { session.websocketMu.Lock() isCurrentPeer := websocketPeer == session.websocketPeer @@ -113,14 +123,12 @@ func (session *SessionCtx) ConnectWebSocketPeer(websocketPeer types.WebSocketPee } } -// // Disconnect WebSocket peer sets current peer to nil and emits disconnected event. It also // allows for a delayed disconnect. That means, the peer will not be disconnected immediately, // but after a delay. If the peer is connected again before the delay, the disconnect will be // cancelled. // // If the peer is not the current peer or the peer is nil, it will be ignored. -// func (session *SessionCtx) DisconnectWebSocketPeer(websocketPeer types.WebSocketPeer, delayed bool) { session.websocketMu.Lock() isCurrentPeer := websocketPeer == session.websocketPeer && websocketPeer != nil @@ -175,10 +183,8 @@ func (session *SessionCtx) DisconnectWebSocketPeer(websocketPeer types.WebSocket session.websocketMu.Unlock() } -// // Destroy WebSocket peer disconnects the peer and destroys it. It ensures that the peer is // disconnected immediately even though normal flow would be to disconnect it delayed. -// func (session *SessionCtx) DestroyWebSocketPeer(reason string) { session.websocketMu.Lock() peer := session.websocketPeer @@ -195,9 +201,7 @@ func (session *SessionCtx) DestroyWebSocketPeer(reason string) { peer.Destroy(reason) } -// // Send event to websocket peer. -// func (session *SessionCtx) Send(event string, payload any) { session.websocketMu.Lock() peer := session.websocketPeer @@ -212,9 +216,7 @@ func (session *SessionCtx) Send(event string, payload any) { // webrtc // --- -// // Set webrtc peer and destroy the old one, if there is old one. -// func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) { session.webrtcMu.Lock() session.webrtcPeer, webrtcPeer = webrtcPeer, session.webrtcPeer @@ -225,14 +227,12 @@ func (session *SessionCtx) SetWebRTCPeer(webrtcPeer types.WebRTCPeer) { } } -// // Set if current webrtc peer is connected or not. Since there might be lefover calls from // webrtc peer, that are not used anymore, we need to check if the webrtc peer is still the // same as the one we are setting the connected state for. // // If webrtc peer is disconnected, we don't expect it to be reconnected, so we set it to nil // and send a signal close to the client. New connection is expected to use a new webrtc peer. -// func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, connected bool) { session.webrtcMu.Lock() isCurrentPeer := webrtcPeer == session.webrtcPeer @@ -274,9 +274,7 @@ func (session *SessionCtx) SetWebRTCConnected(webrtcPeer types.WebRTCPeer, conne } } -// // Get current WebRTC peer. Nil if not connected. -// func (session *SessionCtx) GetWebRTCPeer() types.WebRTCPeer { session.webrtcMu.Lock() defer session.webrtcMu.Unlock() diff --git a/internal/websocket/handler/control.go b/internal/websocket/handler/control.go index e92add1..80e90c6 100644 --- a/internal/websocket/handler/control.go +++ b/internal/websocket/handler/control.go @@ -26,7 +26,7 @@ func (h *MessageHandlerCtx) controlRelease(session types.Session) error { } h.desktop.ResetKeys() - h.sessions.ClearHost() + session.ClearHost() return nil } @@ -44,23 +44,29 @@ func (h *MessageHandlerCtx) controlRequest(session types.Session) error { return ErrIsNotAllowedToHost } - if !h.sessions.Settings().ImplicitHosting { - // tell session if there is a host - if host, hasHost := h.sessions.GetHost(); hasHost { - session.Send( - event.CONTROL_HOST, - message.ControlHost{ - HasHost: true, - HostID: host.ID(), - }) + // if implicit hosting is enabled, set session as host without asking + if h.sessions.Settings().ImplicitHosting { + session.SetAsHost() + return nil + } - return ErrIsAlreadyHosted - } + // if there is no host, set session as host + host, hasHost := h.sessions.GetHost() + if !hasHost { + session.SetAsHost() + return nil } - h.sessions.SetHost(session) + // TODO: Some throttling mechanism to prevent spamming. - return nil + // let host know that someone wants to take control + host.Send( + event.CONTROL_REQUEST, + message.SessionID{ + ID: session.ID(), + }) + + return ErrIsAlreadyHosted } func (h *MessageHandlerCtx) controlMove(session types.Session, payload *message.ControlPos) error { diff --git a/internal/websocket/handler/screen.go b/internal/websocket/handler/screen.go index 8d88ac9..4419cd9 100644 --- a/internal/websocket/handler/screen.go +++ b/internal/websocket/handler/screen.go @@ -13,20 +13,14 @@ func (h *MessageHandlerCtx) screenSet(session types.Session, payload *message.Sc return errors.New("is not the admin") } - size, err := h.desktop.SetScreenSize(types.ScreenSize{ - Width: payload.Width, - Height: payload.Height, - Rate: payload.Rate, - }) - + size, err := h.desktop.SetScreenSize(payload.ScreenSize) if err != nil { return err } - h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSize{ - Width: size.Width, - Height: size.Height, - Rate: size.Rate, + h.sessions.Broadcast(event.SCREEN_UPDATED, message.ScreenSizeUpdate{ + ID: session.ID(), + ScreenSize: size, }) return nil } diff --git a/internal/websocket/handler/session.go b/internal/websocket/handler/session.go index 8d62102..321f634 100644 --- a/internal/websocket/handler/session.go +++ b/internal/websocket/handler/session.go @@ -37,6 +37,16 @@ func (h *MessageHandlerCtx) SessionConnected(session types.Session) error { if err := h.systemAdmin(session); err != nil { return err } + + // update settings in atomic way + h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool { + // if control protection & locked controls: unlock controls + if settings.LockedControls && settings.ControlProtection { + settings.LockedControls = false + return true // update settings + } + return false // do not update settings + }) } return h.SessionStateChanged(session) @@ -46,18 +56,39 @@ func (h *MessageHandlerCtx) SessionDisconnected(session types.Session) error { // clear host if exists if session.IsHost() { h.desktop.ResetKeys() - h.sessions.ClearHost() + session.ClearHost() + } + + if session.Profile().IsAdmin { + hasAdmin := false + h.sessions.Range(func(s types.Session) bool { + if s.Profile().IsAdmin && s.ID() != session.ID() && s.State().IsConnected { + hasAdmin = true + return false + } + return true + }) + + // update settings in atomic way + h.sessions.UpdateSettingsFunc(session, func(settings *types.Settings) bool { + // if control protection & not locked controls & no admin: lock controls + if !settings.LockedControls && settings.ControlProtection && !hasAdmin { + settings.LockedControls = true + return true // update settings + } + return false // do not update settings + }) } return h.SessionStateChanged(session) } -func (h *MessageHandlerCtx) SessionProfileChanged(session types.Session) error { +func (h *MessageHandlerCtx) SessionProfileChanged(session types.Session, new, old types.MemberProfile) error { h.sessions.Broadcast( event.SESSION_PROFILE, message.MemberProfile{ ID: session.ID(), - MemberProfile: session.Profile(), + MemberProfile: new, }) return nil diff --git a/internal/websocket/handler/system.go b/internal/websocket/handler/system.go index adf375d..3ec6048 100644 --- a/internal/websocket/handler/system.go +++ b/internal/websocket/handler/system.go @@ -22,13 +22,6 @@ func (h *MessageHandlerCtx) systemInit(session types.Session) error { HostID: hostID, } - size := h.desktop.GetScreenSize() - screenSize := message.ScreenSize{ - Width: size.Width, - Height: size.Height, - Rate: size.Rate, - } - sessions := map[string]message.SessionData{} for _, session := range h.sessions.List() { sessionId := session.ID() @@ -44,7 +37,7 @@ func (h *MessageHandlerCtx) systemInit(session types.Session) error { message.SystemInit{ SessionId: session.ID(), ControlHost: controlHost, - ScreenSize: screenSize, + ScreenSize: h.desktop.GetScreenSize(), Sessions: sessions, Settings: h.sessions.Settings(), TouchEvents: h.desktop.HasTouchSupport(), @@ -60,9 +53,9 @@ func (h *MessageHandlerCtx) systemInit(session types.Session) error { func (h *MessageHandlerCtx) systemAdmin(session types.Session) error { configurations := h.desktop.ScreenConfigurations() - list := make([]message.ScreenSize, 0, len(configurations)) + list := make([]types.ScreenSize, 0, len(configurations)) for _, conf := range configurations { - list = append(list, message.ScreenSize{ + list = append(list, types.ScreenSize{ Width: conf.Width, Height: conf.Height, Rate: conf.Rate, diff --git a/internal/websocket/manager.go b/internal/websocket/manager.go index 9443aa6..d272172 100644 --- a/internal/websocket/manager.go +++ b/internal/websocket/manager.go @@ -96,10 +96,12 @@ func (manager *WebSocketManagerCtx) Start() { Msg("session disconnected") }) - manager.sessions.OnProfileChanged(func(session types.Session) { - err := manager.handler.SessionProfileChanged(session) + manager.sessions.OnProfileChanged(func(session types.Session, new, old types.MemberProfile) { + err := manager.handler.SessionProfileChanged(session, new, old) manager.logger.Err(err). Str("session_id", session.ID()). + Interface("new", new). + Interface("old", old). Msg("session profile changed") }) @@ -110,24 +112,26 @@ func (manager *WebSocketManagerCtx) Start() { Msg("session state changed") }) - manager.sessions.OnHostChanged(func(session types.Session) { + manager.sessions.OnHostChanged(func(session, host types.Session) { payload := message.ControlHost{ - HasHost: session != nil, + ID: session.ID(), + HasHost: host != nil, } if payload.HasHost { - payload.HostID = session.ID() + payload.HostID = host.ID() } manager.sessions.Broadcast(event.CONTROL_HOST, payload) manager.logger.Info(). + Str("session_id", session.ID()). Bool("has_host", payload.HasHost). Str("host_id", payload.HostID). Msg("session host changed") }) - manager.sessions.OnSettingsChanged(func(new types.Settings, old types.Settings) { + manager.sessions.OnSettingsChanged(func(session types.Session, new, old types.Settings) { // start inactive cursors if new.InactiveCursors && !old.InactiveCursors { manager.startInactiveCursors() @@ -138,8 +142,13 @@ func (manager *WebSocketManagerCtx) Start() { manager.stopInactiveCursors() } - manager.sessions.Broadcast(event.SYSTEM_SETTINGS, new) + manager.sessions.Broadcast(event.SYSTEM_SETTINGS, message.SystemSettingsUpdate{ + ID: session.ID(), + Settings: new, + }) + manager.logger.Info(). + Str("session_id", session.ID()). Interface("new", new). Interface("old", old). Msg("settings changed") diff --git a/openapi.yaml b/openapi.yaml index 2fa9e78..b7ba442 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -13,8 +13,8 @@ servers: url: http://localhost:3000 tags: - - name: session - description: Session management. + - name: sessions + description: Sessions management. - name: room description: Room releated operations. - name: members @@ -61,13 +61,11 @@ paths: required: true # - # session + # current session # /api/login: post: - tags: - - session summary: login operationId: login security: [] @@ -90,8 +88,6 @@ paths: required: true /api/logout: post: - tags: - - session summary: logout operationId: logout responses: @@ -101,8 +97,6 @@ paths: $ref: '#/components/responses/Unauthorized' /api/whoami: get: - tags: - - session summary: whoami operationId: whoami responses: @@ -116,10 +110,15 @@ paths: $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' + + # + # sessions + # + /api/sessions: get: tags: - - session + - sessions summary: get sessions operationId: sessionsGet responses: @@ -135,6 +134,75 @@ paths: $ref: '#/components/responses/Unauthorized' '403': $ref: '#/components/responses/Forbidden' + /api/sessions/{sessionId}: + get: + tags: + - sessions + summary: get session + operationId: sessionGet + parameters: + - in: path + name: sessionId + description: session identifier + required: true + schema: + type: string + responses: + '200': + description: OK + content: + application/json: + schema: + $ref: '#/components/schemas/SessionData' + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + delete: + tags: + - sessions + summary: remove session + operationId: sessionRemove + parameters: + - in: path + name: sessionId + description: session identifier + required: true + schema: + type: string + responses: + '204': + description: OK + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' + /api/sessions/{sessionId}/disconnect: + post: + tags: + - sessions + summary: disconnect session + operationId: sessionDisconnect + parameters: + - in: path + name: sessionId + description: session identifier + required: true + schema: + type: string + responses: + '204': + description: OK + '401': + $ref: '#/components/responses/Unauthorized' + '403': + $ref: '#/components/responses/Forbidden' + '404': + $ref: '#/components/responses/NotFound' # # room @@ -1023,7 +1091,7 @@ components: type: integer # - # session + # sessions # SessionLogin: diff --git a/pkg/auth/auth_test.go b/pkg/auth/auth_test.go index fa9b89e..dd22d47 100644 --- a/pkg/auth/auth_test.go +++ b/pkg/auth/auth_test.go @@ -97,7 +97,7 @@ func TestHostsOnly(t *testing.T) { } // r2 is host - sessionManager.SetHost(session) + session.SetAsHost() r3, _, err := rWithSession(types.MemberProfile{CanHost: false}) if err != nil { @@ -224,9 +224,11 @@ func TestCanHostOnly(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - settings := sessionManager.Settings() - settings.PrivateMode = tt.privateMode - sessionManager.UpdateSettings(settings) + session, _ := GetSession(tt.r) + sessionManager.UpdateSettingsFunc(session, func(s *types.Settings) bool { + s.PrivateMode = tt.privateMode + return true + }) _, err := CanHostOnly(nil, tt.r) if (err != nil) != tt.wantErr { diff --git a/pkg/types/desktop.go b/pkg/types/desktop.go index b769909..0cc4cae 100644 --- a/pkg/types/desktop.go +++ b/pkg/types/desktop.go @@ -15,9 +15,9 @@ type CursorImage struct { } type ScreenSize struct { - Width int - Height int - Rate int16 + Width int `json:"width"` + Height int `json:"height"` + Rate int16 `json:"rate"` } func (s ScreenSize) String() string { diff --git a/pkg/types/member.go b/pkg/types/member.go index 4463f4d..e355806 100644 --- a/pkg/types/member.go +++ b/pkg/types/member.go @@ -23,7 +23,7 @@ type MemberProfile struct { CanSeeInactiveCursors bool `json:"can_see_inactive_cursors" mapstructure:"can_see_inactive_cursors"` // plugin scope - Plugins map[string]any `json:"plugins"` + Plugins PluginSettings `json:"plugins"` } type MemberProvider interface { diff --git a/pkg/types/message/messages.go b/pkg/types/message/messages.go index c2121de..a35c07d 100644 --- a/pkg/types/message/messages.go +++ b/pkg/types/message/messages.go @@ -17,7 +17,7 @@ type SystemWebRTC struct { type SystemInit struct { SessionId string `json:"session_id"` ControlHost ControlHost `json:"control_host"` - ScreenSize ScreenSize `json:"screen_size"` + ScreenSize types.ScreenSize `json:"screen_size"` Sessions map[string]SessionData `json:"sessions"` Settings types.Settings `json:"settings"` TouchEvents bool `json:"touch_events"` @@ -26,8 +26,8 @@ type SystemInit struct { } type SystemAdmin struct { - ScreenSizesList []ScreenSize `json:"screen_sizes_list"` - BroadcastStatus BroadcastStatus `json:"broadcast_status"` + ScreenSizesList []types.ScreenSize `json:"screen_sizes_list"` + BroadcastStatus BroadcastStatus `json:"broadcast_status"` } type SystemLogs = []SystemLog @@ -42,6 +42,11 @@ type SystemDisconnect struct { Message string `json:"message"` } +type SystemSettingsUpdate struct { + ID string `json:"id"` + types.Settings +} + ///////////////////////////// // Signal ///////////////////////////// @@ -111,6 +116,7 @@ type SessionCursors struct { ///////////////////////////// type ControlHost struct { + ID string `json:"id"` HasHost bool `json:"has_host"` HostID string `json:"host_id,omitempty"` } @@ -151,9 +157,12 @@ type ControlTouch struct { ///////////////////////////// type ScreenSize struct { - Width int `json:"width"` - Height int `json:"height"` - Rate int16 `json:"rate"` + types.ScreenSize +} + +type ScreenSizeUpdate struct { + ID string `json:"id"` + types.ScreenSize } ///////////////////////////// diff --git a/pkg/types/plugins.go b/pkg/types/plugins.go index fa57b8f..c0f271f 100644 --- a/pkg/types/plugins.go +++ b/pkg/types/plugins.go @@ -2,10 +2,17 @@ package types import ( "errors" + "fmt" + "strings" + "github.com/demodesk/neko/pkg/utils" "github.com/spf13/cobra" ) +var ( + ErrPluginSettingsNotFound = errors.New("plugin settings not found") +) + type Plugin interface { Name() string Config() PluginConfig @@ -61,3 +68,24 @@ func (p *PluginManagers) Validate() error { return nil } + +type PluginSettings map[string]any + +func (p PluginSettings) Unmarshal(name string, def any) error { + if p == nil { + return fmt.Errorf("%w: %s", ErrPluginSettingsNotFound, name) + } + // loop through the plugin settings and take only the one that starts with the name + // because the settings are stored in a map["plugin_name.setting_name"] = value + newMap := make(map[string]any) + for k, v := range p { + if strings.HasPrefix(k, name+".") { + newMap[strings.TrimPrefix(k, name+".")] = v + } + } + fmt.Printf("newMap: %+v\n", newMap) + if len(newMap) == 0 { + return fmt.Errorf("%w: %s", ErrPluginSettingsNotFound, name) + } + return utils.Decode(newMap, def) +} diff --git a/pkg/types/session.go b/pkg/types/session.go index e901ad7..4ec4085 100644 --- a/pkg/types/session.go +++ b/pkg/types/session.go @@ -11,6 +11,7 @@ var ( ErrSessionAlreadyExists = errors.New("session already exists") ErrSessionAlreadyConnected = errors.New("session is already connected") ErrSessionLoginDisabled = errors.New("session login disabled") + ErrSessionLoginsLocked = errors.New("session logins locked") ) type Cursor struct { @@ -40,13 +41,15 @@ type SessionState struct { type Settings struct { PrivateMode bool `json:"private_mode"` + LockedLogins bool `json:"locked_logins"` LockedControls bool `json:"locked_controls"` + ControlProtection bool `json:"control_protection"` ImplicitHosting bool `json:"implicit_hosting"` InactiveCursors bool `json:"inactive_cursors"` MercifulReconnect bool `json:"merciful_reconnect"` // plugin scope - Plugins map[string]any `json:"plugins"` + Plugins PluginSettings `json:"plugins"` } type Session interface { @@ -54,6 +57,9 @@ type Session interface { Profile() MemberProfile State() SessionState IsHost() bool + SetAsHost() + SetAsHostBy(session Session) + ClearHost() PrivateModeEnabled() bool // cursor @@ -75,13 +81,13 @@ type SessionManager interface { Create(id string, profile MemberProfile) (Session, string, error) Update(id string, profile MemberProfile) error Delete(id string) error + Disconnect(id string) error Get(id string) (Session, bool) GetByToken(token string) (Session, bool) List() []Session + Range(func(Session) bool) - SetHost(host Session) GetHost() (Session, bool) - ClearHost() SetCursor(cursor Cursor, session Session) PopCursors() map[Session][]Cursor @@ -94,12 +100,12 @@ type SessionManager interface { OnDeleted(listener func(session Session)) OnConnected(listener func(session Session)) OnDisconnected(listener func(session Session)) - OnProfileChanged(listener func(session Session)) + OnProfileChanged(listener func(session Session, new, old MemberProfile)) OnStateChanged(listener func(session Session)) - OnHostChanged(listener func(session Session)) - OnSettingsChanged(listener func(new Settings, old Settings)) + OnHostChanged(listener func(session, host Session)) + OnSettingsChanged(listener func(session Session, new, old Settings)) - UpdateSettings(Settings) + UpdateSettingsFunc(session Session, f func(settings *Settings) bool) Settings() Settings CookieEnabled() bool diff --git a/pkg/utils/json.go b/pkg/utils/deocde.go similarity index 81% rename from pkg/utils/json.go rename to pkg/utils/deocde.go index ff4883b..12aeaec 100644 --- a/pkg/utils/json.go +++ b/pkg/utils/deocde.go @@ -3,8 +3,14 @@ package utils import ( "encoding/json" "reflect" + + "github.com/mitchellh/mapstructure" ) +func Decode(input interface{}, output interface{}) error { + return mapstructure.Decode(input, output) +} + func Unmarshal(in any, raw []byte, callback func() error) error { if err := json.Unmarshal(raw, &in); err != nil { return err