diff --git a/input/all/all.go b/input/all/all.go index f00c86b..8c5eeed 100644 --- a/input/all/all.go +++ b/input/all/all.go @@ -4,4 +4,5 @@ package all import ( _ "github.com/noriah/catnip/input/ffmpeg" _ "github.com/noriah/catnip/input/parec" + _ "github.com/noriah/catnip/input/pipewire" ) diff --git a/input/pipewire/dump.go b/input/pipewire/dump.go new file mode 100644 index 0000000..11b0378 --- /dev/null +++ b/input/pipewire/dump.go @@ -0,0 +1,140 @@ +package pipewire + +import ( + "context" + "encoding/json" + "os" + "os/exec" + + "github.com/pkg/errors" +) + +type pwObjects []pwObject + +func pwDump(ctx context.Context) (pwObjects, error) { + cmd := exec.CommandContext(ctx, "pw-dump") + cmd.Stderr = os.Stderr + + dumpOutput, err := cmd.Output() + if err != nil { + var execErr *exec.ExitError + if errors.As(err, &execErr) { + return nil, errors.Wrapf(err, "failed to run pw-dump: %s", execErr.Stderr) + } + return nil, errors.Wrap(err, "failed to run pw-dump") + } + + var dump pwObjects + if err := json.Unmarshal(dumpOutput, &dump); err != nil { + return nil, errors.Wrap(err, "failed to parse pw-dump output") + } + + return dump, nil +} + +// Filter filters for the devices that satisfies f. +func (d pwObjects) Filter(fns ...func(pwObject) bool) pwObjects { + filtered := make(pwObjects, 0, len(d)) +loop: + for _, device := range d { + for _, f := range fns { + if !f(device) { + continue loop + } + } + filtered = append(filtered, device) + } + return filtered +} + +// Find returns the first object that satisfies f. +func (d pwObjects) Find(f func(pwObject) bool) *pwObject { + for i, device := range d { + if f(device) { + return &d[i] + } + } + return nil +} + +// ResolvePorts returns all PipeWire port objects that belong to the given +// object. +func (d pwObjects) ResolvePorts(object *pwObject, dir pwPortDirection) pwObjects { + return d.Filter( + func(o pwObject) bool { return o.Type == pwInterfacePort }, + func(o pwObject) bool { + return o.Info.Props.NodeID == object.ID && o.Info.Props.PortDirection == dir + }, + ) +} + +type pwObjectID int64 + +type pwObjectType string + +const ( + pwInterfaceDevice pwObjectType = "PipeWire:Interface:Device" + pwInterfaceNode pwObjectType = "PipeWire:Interface:Node" + pwInterfacePort pwObjectType = "PipeWire:Interface:Port" + pwInterfaceLink pwObjectType = "PipeWire:Interface:Link" +) + +type pwObject struct { + ID pwObjectID `json:"id"` + Type pwObjectType `json:"type"` + Info struct { + Props pwInfoProps `json:"props"` + } `json:"info"` +} + +type pwInfoProps struct { + pwDeviceProps + pwNodeProps + pwPortProps + MediaClass string `json:"media.class"` + + JSON json.RawMessage `json:"-"` +} + +func (p *pwInfoProps) UnmarshalJSON(data []byte) error { + type Alias pwInfoProps + if err := json.Unmarshal(data, (*Alias)(p)); err != nil { + return err + } + p.JSON = append([]byte(nil), data...) + return nil +} + +type pwDeviceProps struct { + DeviceName string `json:"device.name"` +} + +// pwNodeProps is for Audio/Sink only. +type pwNodeProps struct { + NodeName string `json:"node.name"` + NodeNick string `json:"node.nick"` + NodeDescription string `json:"node.description"` +} + +// Constants for MediaClass. +const ( + pwAudioDevice string = "Audio/Device" + pwAudioSink string = "Audio/Sink" + pwStreamOutputAudio string = "Stream/Output/Audio" +) + +type pwPortDirection string + +const ( + pwPortIn = "in" + pwPortOut = "out" +) + +type pwPortProps struct { + PortID pwObjectID `json:"port.id"` + PortName string `json:"port.name"` + PortAlias string `json:"port.alias"` + PortDirection pwPortDirection `json:"port.direction"` + NodeID pwObjectID `json:"node.id"` + ObjectPath string `json:"object.path"` +} diff --git a/input/pipewire/link.go b/input/pipewire/link.go new file mode 100644 index 0000000..6848d7d --- /dev/null +++ b/input/pipewire/link.go @@ -0,0 +1,127 @@ +package pipewire + +import ( + "bufio" + "context" + "fmt" + "os" + "os/exec" + "strconv" + "strings" + + "github.com/pkg/errors" +) + +func pwLink(outPortID, inPortID pwObjectID) error { + cmd := exec.Command("pw-link", "-L", fmt.Sprint(outPortID), fmt.Sprint(inPortID)) + if err := cmd.Run(); err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) && exitErr.Stderr != nil { + return errors.Wrapf(err, "failed to run pw-link: %s", exitErr.Stderr) + } + return err + } + return nil +} + +type pwLinkObject struct { + ID pwObjectID + DeviceName string + PortName string // usually like {input,output}_{FL,FR} +} + +func pwLinkObjectParse(line string) (pwLinkObject, error) { + var obj pwLinkObject + + idStr, portStr, ok := strings.Cut(line, " ") + if !ok { + return obj, fmt.Errorf("failed to parse pw-link object %q", line) + } + + id, err := strconv.Atoi(idStr) + if err != nil { + return obj, errors.Wrapf(err, "failed to parse pw-link object id %q", idStr) + } + + name, port, ok := strings.Cut(portStr, ":") + if !ok { + return obj, fmt.Errorf("failed to parse pw-link port string %q", portStr) + } + + obj = pwLinkObject{ + ID: pwObjectID(id), + DeviceName: name, + PortName: port, + } + + return obj, nil +} + +type pwLinkType string + +const ( + pwLinkInputPorts pwLinkType = "i" + pwLinkOutputPorts pwLinkType = "o" +) + +type pwLinkEvent interface { + pwLinkEvent() +} + +type pwLinkAdd pwLinkObject +type pwLinkRemove pwLinkObject + +func (pwLinkAdd) pwLinkEvent() {} +func (pwLinkRemove) pwLinkEvent() {} + +func pwLinkMonitor(ctx context.Context, typ pwLinkType, ch chan<- pwLinkEvent) error { + cmd := exec.CommandContext(ctx, "pw-link", "-mI"+string(typ)) + cmd.Stderr = os.Stderr + + o, err := cmd.StdoutPipe() + if err != nil { + return errors.Wrap(err, "failed to get stdout pipe") + } + defer o.Close() + + if err := cmd.Start(); err != nil { + return errors.Wrap(err, "pw-link -m") + } + + scanner := bufio.NewScanner(o) + for scanner.Scan() { + line := scanner.Text() + if line == "" { + continue + } + + mark := line[0] + + line = strings.TrimSpace(line[1:]) + + obj, err := pwLinkObjectParse(line) + if err != nil { + continue + } + + var ev pwLinkEvent + switch mark { + case '=': + fallthrough + case '+': + ev = pwLinkAdd(obj) + case '-': + ev = pwLinkRemove(obj) + default: + continue + } + + select { + case <-ctx.Done(): + return ctx.Err() + case ch <- ev: + } + } + + return errors.Wrap(cmd.Wait(), "pw-link exited") +} diff --git a/input/pipewire/pipewire.go b/input/pipewire/pipewire.go new file mode 100644 index 0000000..f980c4c --- /dev/null +++ b/input/pipewire/pipewire.go @@ -0,0 +1,267 @@ +package pipewire + +import ( + "context" + "encoding/base64" + "encoding/binary" + "encoding/json" + "fmt" + "log" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/noriah/catnip/input" + "github.com/noriah/catnip/input/common/execread" + "github.com/pkg/errors" +) + +func init() { + input.RegisterBackend("pipewire", Backend{}) +} + +type Backend struct{} + +func (p Backend) Init() error { + return nil +} + +func (p Backend) Close() error { + return nil +} + +func (p Backend) Devices() ([]input.Device, error) { + pwObjs, err := pwDump(context.Background()) + if err != nil { + return nil, err + } + + pwSinks := pwObjs.Filter(func(o pwObject) bool { + return o.Type == pwInterfaceNode && + o.Info.Props.MediaClass == pwAudioSink || + o.Info.Props.MediaClass == pwStreamOutputAudio + }) + + devices := make([]input.Device, len(pwSinks)) + for i, device := range pwSinks { + devices[i] = AudioDevice{device.Info.Props.NodeName} + } + + return devices, nil +} + +func (p Backend) DefaultDevice() (input.Device, error) { + return AudioDevice{"auto"}, nil +} + +func (p Backend) Start(cfg input.SessionConfig) (input.Session, error) { + return NewSession(cfg) +} + +type AudioDevice struct { + name string +} + +func (d AudioDevice) String() string { + return d.name +} + +type catnipProps struct { + ApplicationName string `json:"application.name"` + CatnipID string `json:"catnip.id"` +} + +// Session is a PipeWire session. +type Session struct { + session execread.Session + props catnipProps + targetName string +} + +// NewSession creates a new PipeWire session. +func NewSession(cfg input.SessionConfig) (*Session, error) { + currentProps := catnipProps{ + ApplicationName: "catnip", + CatnipID: generateID(), + } + + propsJSON, err := json.Marshal(currentProps) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal props") + } + + dv, ok := cfg.Device.(AudioDevice) + if !ok { + return nil, fmt.Errorf("invalid device type %T", cfg.Device) + } + + target := "0" + if dv.name == "auto" { + target = dv.name + } + + args := []string{ + "pw-cat", + "--record", + "--format", "f32", + "--rate", fmt.Sprint(cfg.SampleRate), + "--latency", fmt.Sprint(cfg.SampleSize), + "--channels", fmt.Sprint(cfg.FrameSize), + "--target", target, // see .relink comment below + "--quality", "0", + "--media-category", "Capture", + "--media-role", "DSP", + "--properties", string(propsJSON), + "-", + } + + return &Session{ + session: *execread.NewSession(args, true, cfg), + props: currentProps, + targetName: dv.name, + }, nil +} + +// Start starts the session. It implements input.Session. +func (s *Session) Start(ctx context.Context, dst [][]input.Sample, kickChan chan bool, mu *sync.Mutex) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + errCh := make(chan error, 1) + setErr := func(err error) { + select { + case errCh <- err: + default: + } + } + + var wg sync.WaitGroup + defer wg.Wait() + + wg.Add(1) + go func() { + defer wg.Done() + setErr(s.session.Start(ctx, dst, kickChan, mu)) + }() + + // No relinking needed if we're not connecting to a specific device. + if s.targetName != "auto" { + wg.Add(1) + go func() { + defer wg.Done() + setErr(s.startRelinker(ctx)) + }() + } + + return <-errCh +} + +// We do a bit of tomfoolery here. Wireplumber actually is pretty incompetent at +// handling target.device, so our --target flag is pretty much useless. We have +// to do the node links ourselves. +// +// Relevant issues: +// +// - https://gitlab.freedesktop.org/pipewire/pipewire/-/issues/2731 +// - https://gitlab.freedesktop.org/pipewire/wireplumber/-/issues/358 +// +func (s *Session) startRelinker(ctx context.Context) error { + var catnipPorts map[string]pwObjectID + var err error + // Employ this awful hack to get the needed port IDs for our session. We + // won't rely on the pwLinkMonitor below, since it may appear out of order. + for i := 0; i < 20; i++ { + catnipPorts, err = findCatnipPorts(ctx, s.props) + if err == nil { + break + } + time.Sleep(100 * time.Millisecond) + } + if err != nil { + return errors.Wrap(err, "failed to find catnip's input ports") + } + + linkEvents := make(chan pwLinkEvent) + linkError := make(chan error, 1) + go func() { linkError <- pwLinkMonitor(ctx, pwLinkOutputPorts, linkEvents) }() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-linkError: + return err + case event := <-linkEvents: + switch event := event.(type) { + case pwLinkAdd: + if event.DeviceName == s.targetName { + catnipPort := "input_" + strings.TrimPrefix(event.PortName, "output_") + catnipPortID := catnipPorts[catnipPort] + targetPortID := event.ID + + // Link the catnip node to the device node. + if err := pwLink(targetPortID, catnipPortID); err != nil { + log.Printf( + "failed to link catnip port %d to device port %d: %v", + catnipPortID, targetPortID, err) + } + } + } + } + } +} + +func findCatnipPorts(ctx context.Context, ourProps catnipProps) (map[string]pwObjectID, error) { + objs, err := pwDump(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get pw-dump") + } + + // Find the catnip node. + nodeObj := objs.Find(func(obj pwObject) bool { + if obj.Type != pwInterfaceNode { + return false + } + var props catnipProps + err := json.Unmarshal(obj.Info.Props.JSON, &props) + return err == nil && props == ourProps + }) + if nodeObj == nil { + return nil, errors.New("failed to find catnip node in PipeWire") + } + + // Find all of catnip's ports. We want catnip's input ports. + portObjs := objs.ResolvePorts(nodeObj, pwPortIn) + if len(portObjs) == 0 { + return nil, errors.New("failed to find any catnip port in PipeWire") + } + + portMap := make(map[string]pwObjectID) + for _, obj := range portObjs { + portMap[obj.Info.Props.PortName] = obj.ID + } + + return portMap, nil +} + +var sessionCounter uint64 + +// generateID generates a unique ID for this session. +func generateID() string { + return fmt.Sprintf( + "%d@%s#%d", + os.Getpid(), + shortEpoch(), + atomic.AddUint64(&sessionCounter, 1), + ) +} + +// shortEpoch generates a small string that is unique to the current epoch. +func shortEpoch() string { + now := time.Now().Unix() + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(now)) + return base64.RawURLEncoding.EncodeToString(buf[:]) +}