Skip to content

Commit

Permalink
"move MPEG4-audio decoding into streamTrack"
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 committed Nov 2, 2022
1 parent d4779f2 commit bbe6494
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 126 deletions.
30 changes: 13 additions & 17 deletions internal/core/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,24 @@ import (
)

// data is the data unit routed across the server.
// it must contain one or more of the following:
// - a single RTP packet
// - a group of H264 NALUs (grouped by timestamp)
// - a single AAC AU
type data interface {
getTrackID() int
getRTPPacket() *rtp.Packet
getRTPPackets() []*rtp.Packet
getPTSEqualsDTS() bool
}

type dataGeneric struct {
trackID int
rtpPacket *rtp.Packet
rtpPackets []*rtp.Packet
ptsEqualsDTS bool
}

func (d *dataGeneric) getTrackID() int {
return d.trackID
}

func (d *dataGeneric) getRTPPacket() *rtp.Packet {
return d.rtpPacket
func (d *dataGeneric) getRTPPackets() []*rtp.Packet {
return d.rtpPackets
}

func (d *dataGeneric) getPTSEqualsDTS() bool {
Expand All @@ -37,7 +33,7 @@ func (d *dataGeneric) getPTSEqualsDTS() bool {

type dataH264 struct {
trackID int
rtpPacket *rtp.Packet
rtpPackets []*rtp.Packet
ptsEqualsDTS bool
pts time.Duration
nalus [][]byte
Expand All @@ -47,27 +43,27 @@ func (d *dataH264) getTrackID() int {
return d.trackID
}

func (d *dataH264) getRTPPacket() *rtp.Packet {
return d.rtpPacket
func (d *dataH264) getRTPPackets() []*rtp.Packet {
return d.rtpPackets
}

func (d *dataH264) getPTSEqualsDTS() bool {
return d.ptsEqualsDTS
}

type dataMPEG4Audio struct {
trackID int
rtpPacket *rtp.Packet
pts time.Duration
au []byte
trackID int
rtpPackets []*rtp.Packet
pts time.Duration
aus [][]byte
}

func (d *dataMPEG4Audio) getTrackID() int {
return d.trackID
}

func (d *dataMPEG4Audio) getRTPPacket() *rtp.Packet {
return d.rtpPacket
func (d *dataMPEG4Audio) getRTPPackets() []*rtp.Packet {
return d.rtpPackets
}

func (d *dataMPEG4Audio) getPTSEqualsDTS() bool {
Expand Down
21 changes: 5 additions & 16 deletions internal/core/hls_muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/mpeg4audio"
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtpmpeg4audio"
"github.com/gin-gonic/gin"

"github.com/aler9/rtsp-simple-server/internal/conf"
Expand Down Expand Up @@ -295,7 +294,6 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{})
videoTrackID := -1
var audioTrack *gortsplib.TrackMPEG4Audio
audioTrackID := -1
var aacDecoder *rtpmpeg4audio.Decoder

for i, track := range res.stream.tracks() {
switch tt := track.(type) {
Expand All @@ -314,13 +312,6 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{})

audioTrack = tt
audioTrackID = i
aacDecoder = &rtpmpeg4audio.Decoder{
SampleRate: tt.Config.SampleRate,
SizeLength: tt.SizeLength,
IndexLength: tt.IndexLength,
IndexDeltaLength: tt.IndexDeltaLength,
}
aacDecoder.Init()
}
}

Expand Down Expand Up @@ -390,18 +381,16 @@ func (m *hlsMuxer) runInner(innerCtx context.Context, innerReady chan struct{})
return fmt.Errorf("muxer error: %v", err)
}
} else if audioTrack != nil && data.getTrackID() == audioTrackID {
aus, pts, err := aacDecoder.Decode(data.getRTPPacket())
if err != nil {
if err != rtpmpeg4audio.ErrMorePacketsNeeded {
m.log(logger.Warn, "unable to decode audio track: %v", err)
}
tdata := data.(*dataMPEG4Audio)

if tdata.aus == nil {
continue
}

for i, au := range aus {
for i, au := range tdata.aus {
err = m.muxer.WriteAAC(
time.Now(),
pts+time.Duration(i)*mpeg4audio.SamplesPerAccessUnit*
tdata.pts+time.Duration(i)*mpeg4audio.SamplesPerAccessUnit*
time.Second/time.Duration(audioTrack.ClockRate()),
au)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/core/hls_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (s *hlsSource) run(ctx context.Context) error {
stream.writeData(&dataMPEG4Audio{
trackID: audioTrackID,
pts: pts,
au: au,
aus: [][]byte{au},
})
}

Expand Down
22 changes: 6 additions & 16 deletions internal/core/rtmp_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"github.com/aler9/gortsplib/pkg/h264"
"github.com/aler9/gortsplib/pkg/mpeg4audio"
"github.com/aler9/gortsplib/pkg/ringbuffer"
"github.com/aler9/gortsplib/pkg/rtpmpeg4audio"
"github.com/notedit/rtmp/format/flv/flvio"

"github.com/aler9/rtsp-simple-server/internal/conf"
Expand Down Expand Up @@ -258,7 +257,6 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error {
videoTrackID := -1
var audioTrack *gortsplib.TrackMPEG4Audio
audioTrackID := -1
var aacDecoder *rtpmpeg4audio.Decoder

for i, track := range res.stream.tracks() {
switch tt := track.(type) {
Expand All @@ -277,13 +275,6 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error {

audioTrack = tt
audioTrackID = i
aacDecoder = &rtpmpeg4audio.Decoder{
SampleRate: tt.Config.SampleRate,
SizeLength: tt.SizeLength,
IndexLength: tt.IndexLength,
IndexDeltaLength: tt.IndexDeltaLength,
}
aacDecoder.Init()
}
}

Expand Down Expand Up @@ -433,24 +424,23 @@ func (c *rtmpConn) runRead(ctx context.Context, u *url.URL) error {
return err
}
} else if audioTrack != nil && data.getTrackID() == audioTrackID {
aus, pts, err := aacDecoder.Decode(data.getRTPPacket())
if err != nil {
if err != rtpmpeg4audio.ErrMorePacketsNeeded {
c.log(logger.Warn, "unable to decode audio track: %v", err)
}
tdata := data.(*dataMPEG4Audio)

if tdata.aus == nil {
continue
}

if videoTrack != nil && !videoFirstIDRFound {
continue
}

pts := tdata.pts
pts -= videoStartDTS
if pts < 0 {
continue
}

for i, au := range aus {
for i, au := range tdata.aus {
c.nconn.SetWriteDeadline(time.Now().Add(time.Duration(c.writeTimeout)))
err := c.conn.WriteMessage(&message.MsgAudio{
ChunkStreamID: message.MsgAudioChunkStreamID,
Expand Down Expand Up @@ -614,7 +604,7 @@ func (c *rtmpConn) runPublish(ctx context.Context, u *url.URL) error {
rres.stream.writeData(&dataMPEG4Audio{
trackID: audioTrackID,
pts: tmsg.DTS,
au: tmsg.Payload,
aus: [][]byte{tmsg.Payload},
})
}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/core/rtmp_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func (s *rtmpSource) run(ctx context.Context) error {
res.stream.writeData(&dataMPEG4Audio{
trackID: audioTrackID,
pts: tmsg.DTS,
au: tmsg.Payload,
aus: [][]byte{tmsg.Payload},
})
}
}
Expand Down
17 changes: 13 additions & 4 deletions internal/core/rtsp_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base"
"github.com/pion/rtp"

"github.com/aler9/rtsp-simple-server/internal/conf"
"github.com/aler9/rtsp-simple-server/internal/externalcmd"
Expand Down Expand Up @@ -378,18 +379,26 @@ func (s *rtspSession) apiSourceDescribe() interface{} {

// onPacketRTP is called by rtspServer.
func (s *rtspSession) onPacketRTP(ctx *gortsplib.ServerHandlerOnPacketRTPCtx) {
if ctx.H264NALUs != nil {
switch s.announcedTracks[ctx.TrackID].(type) {
case *gortsplib.TrackH264:
s.stream.writeData(&dataH264{
trackID: ctx.TrackID,
rtpPacket: ctx.Packet,
rtpPackets: []*rtp.Packet{ctx.Packet},
ptsEqualsDTS: ctx.PTSEqualsDTS,
pts: ctx.H264PTS,
nalus: ctx.H264NALUs,
})
} else {

case *gortsplib.TrackMPEG4Audio:
s.stream.writeData(&dataMPEG4Audio{
trackID: ctx.TrackID,
rtpPackets: []*rtp.Packet{ctx.Packet},
})

default:
s.stream.writeData(&dataGeneric{
trackID: ctx.TrackID,
rtpPacket: ctx.Packet,
rtpPackets: []*rtp.Packet{ctx.Packet},
ptsEqualsDTS: ctx.PTSEqualsDTS,
})
}
Expand Down
17 changes: 13 additions & 4 deletions internal/core/rtsp_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (

"github.com/aler9/gortsplib"
"github.com/aler9/gortsplib/pkg/base"
"github.com/pion/rtp"

"github.com/aler9/gortsplib/pkg/url"
"github.com/aler9/rtsp-simple-server/internal/conf"
Expand Down Expand Up @@ -143,18 +144,26 @@ func (s *rtspSource) run(ctx context.Context) error {
}()

c.OnPacketRTP = func(ctx *gortsplib.ClientOnPacketRTPCtx) {
if ctx.H264NALUs != nil {
switch tracks[ctx.TrackID].(type) {
case *gortsplib.TrackH264:
res.stream.writeData(&dataH264{
trackID: ctx.TrackID,
rtpPacket: ctx.Packet,
rtpPackets: []*rtp.Packet{ctx.Packet},
ptsEqualsDTS: ctx.PTSEqualsDTS,
pts: ctx.H264PTS,
nalus: ctx.H264NALUs,
})
} else {

case *gortsplib.TrackMPEG4Audio:
res.stream.writeData(&dataMPEG4Audio{
trackID: ctx.TrackID,
rtpPackets: []*rtp.Packet{ctx.Packet},
})

default:
res.stream.writeData(&dataGeneric{
trackID: ctx.TrackID,
rtpPacket: ctx.Packet,
rtpPackets: []*rtp.Packet{ctx.Packet},
ptsEqualsDTS: ctx.PTSEqualsDTS,
})
}
Expand Down
20 changes: 13 additions & 7 deletions internal/core/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ func (m *streamNonRTSPReadersMap) writeData(data data) {
}
}

func (m *streamNonRTSPReadersMap) hasReaders() bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.ma) > 0
}

type stream struct {
nonRTSPReaders *streamNonRTSPReadersMap
rtspStream *gortsplib.ServerStream
Expand Down Expand Up @@ -91,13 +97,13 @@ func (s *stream) readerRemove(r reader) {
}

func (s *stream) writeData(data data) {
datas := s.streamTracks[data.getTrackID()].process(data)

for _, data := range datas {
// forward to RTSP readers
s.rtspStream.WritePacketRTP(data.getTrackID(), data.getRTPPacket(), data.getPTSEqualsDTS())
s.streamTracks[data.getTrackID()].onData(data, s.nonRTSPReaders.hasReaders())

// forward to non-RTSP readers
s.nonRTSPReaders.writeData(data)
// forward RTP packets to RTSP readers
for _, pkt := range data.getRTPPackets() {
s.rtspStream.WritePacketRTP(data.getTrackID(), pkt, data.getPTSEqualsDTS())
}

// forward data to non-RTSP readers
s.nonRTSPReaders.writeData(data)
}
2 changes: 1 addition & 1 deletion internal/core/streamtrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
)

type streamTrack interface {
process(data) []data
onData(data, bool)
}

func newStreamTrack(track gortsplib.Track, generateRTPPackets bool) (streamTrack, error) {
Expand Down
3 changes: 1 addition & 2 deletions internal/core/streamtrack_generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@ func newStreamTrackGeneric() *streamTrackGeneric {
return &streamTrackGeneric{}
}

func (t *streamTrackGeneric) process(dat data) []data {
return []data{dat}
func (t *streamTrackGeneric) onData(dat data, hasNonRTSPReaders bool) {
}

0 comments on commit bbe6494

Please sign in to comment.