In [1]:
# default_exp core

In [2]:
import nbdev.showdoc as literacy

In [3]:
#export
from speechsep.imports import *
from speechsep.utils import *
from speechsep.plot import *

# Core

This contains most of the basic functions and spectrogram class types. To visualize the spectrograms we will also include a special color map since this makes it easier to notice differences in audio intensities.

The most important things to remember are
- How to create an AudioItem from file, both mono and multi-channel.
- Creating a SpecImage and how the parameters influence the final result.
- Basic SpecImage Visualizer (more indepth explanation here***)

## Loading Data

In [4]:
#export
@delegates(load)
def load_audio(fn, **kwargs):
    return load(fn)

In [5]:
fn = Path("../data/AudioTest1.wav")
sig, sr = load_audio(fn)
display(Audio(sig, rate=sr))

In [6]:
#hide
test_eq(type(sig), np.ndarray)
test_eq(type(sr), int)

## AudioBase
The current base class for audio which is used for mono and multi-channel audio types.

In [7]:
#export
def ResampleSignal(sr_new):
    def _inner(sig, sr):
        '''Resample using faster polyphase technique and avoiding FFT computation. Taken from FastaiAudio by LimeAI'''
        if(sr == sr_new): return sig
        sr_gcd = math.gcd(sr, sr_new)
        resampled = resample_poly(sig, int(sr_new/sr_gcd), int(sr/sr_gcd), axis=-1)
        #resampled = resampled.astype(np.float32)
        return resampled
    return _inner

In [8]:
#export
class AudioBase():
    _show_args={}
    def __init__(self,sig,_sr,fn=None):
        store_attr(self, 'sig,_sr,fn')
        self.data = self.sig
    def __repr__(self): self.listen(); return f'{self.__str__()}'
    def __str__(self): return f'{self.fn}, {self.duration}secs at {self.sr} samples per second'
    def listen(self): display(Audio(self.sig, rate=self.sr))
    @property
    def sr(self): return self._sr
    @sr.setter
    def sr(self, new_sr):
        if self._sr != new_sr: self.sig = ResampleSignal(new_sr)(self.sig, self.sr)
        self._sr = new_sr
    @property
    def duration(self): return len(self.sig)/self.sr

### MonoAudios
Audios with only one channel. For now this is the only Audio type, if the file has more channels they will be averaged out into one.

In [9]:
#export
class AudioMono(AudioBase):
    _show_args={}
    @classmethod
    def create(cls, fn, sr=None):
        audio = cls(*load_audio(fn),fn)
        if sr: audio.sr = sr
        return audio
    load_file = create

In [10]:
aud1 = AudioMono.create(fn) #default file sample rate
aud2 = AudioMono.create(fn, sr=2205) #custom sample rate, could cause loss of quality

In [11]:
#hide
test_eq(type(aud1), AudioMono)
test_eq(aud1.sr, 22050)
test_eq(aud1.fn, fn)

test_eq(aud2.sr, 2205)
test_eq(type(aud2.sig), np.ndarray)
test_eq(type(aud2.sr), int)

In [12]:
#export
@patch_property
def duration(x:AudioMono):
    return len(x.sig)/x.sr

In [13]:
#hide
test_eq(type(aud1.duration), float)
test_eq(round(aud1.duration), 4)

aud1.sr = 48000

test_eq(aud1.sr, 48000)
test_eq(round(aud1.duration), 4)

## Spectrograms

### SpecImage
Gives the template for the rest of the Spectrogram classes. There will be transforms to add mel-bin and decibels

In [14]:
#export
class SpecImage():
    _show_args={}
    def __init__(self, data, sr, fn=None):
        store_attr(self, 'data, sr, fn')
        self._plt_params = {}
    @property
    def plt_params(self): return self._plt_params
    @plt_params.setter
    @delegates(plt.pcolormesh)
    def plt_params(self, **kwargs):
        self._plot = partial(plt.pcolormesh, **kwargs)
        self._plt_params = dict(**kwargs)

### Decibel Spectrograms

#### Decibelify
Turn spectrogram amplitude to decibel, is automatically called in `Spectify` with `decibel=True`. Decibel is the same as amplitude (intensity of each "pixel") in log-scale.

In [15]:
class Decibelify(Transform):
    def __init__(self): pass
    def encodes(self,spec): pass
    def decodes(self,spec): pass

### Mel-bin Spectrograms

#### Mel-binify
Transforms the frequency to mel-bin. Just like decibels, this transform also resembles human hearing better than linear frequencies do. Sadly making mel-bins also makes it dificult to reconstruct the audio since the phase and data loss is very high. Recommended for classification problems.

In [16]:
class Mel_Binify(Transform):
    def __init__(self): pass
    def encodes(self,spec): pass
    def decodes(self,spec): pass

## Masks

In [17]:
#export
class MaskBase():
    def __init__(self, data):
        store_attr(self, 'data')
    @property
    def shape(self):
        return self.data.shape
    @classmethod
    def create(cls, audios):
        self.adjust(audios)
        joined = join_audios(audios)
        return [cls(self.generate(joined, aud)) for aud in audios]
    def adjust(self, audios):
        pass
    def __mult__(self, spec):
        raise NotImplementedError('This function needs to be implemented before use')
    def generate(self, joined, aud):
        raise NotImplementedError('This function needs to be implemented before use')

### Binary Mask

In [18]:
#export
class MaskBinary(MaskBase):
    def __mult__(self, spec): pass
    def __generate__(self, joined, aud): pass

## Show Functions

In [19]:
#export
@typedispatch
def show_batch(x:(AudioBase, SpecImage, MaskBase), y, samples, ctxs=None, max_n=10, rows=None, cols=None, figsize=None, **kwargs):
    if ctxs is None: ctxs = get_grid(min(len(samples), max_n), rows=rows, cols=cols, figsize=figsize)
    ctxs = show_batch[object](x, y, samples, ctxs=ctxs, max_n=max_n, **kwargs)
    return ctxs

def pre_plot(o, cls, ax=None, pltsize=None, ctx=None):
    ax = ifnone(ax,ctx)
    if ax is None: _,ax = plt.subplots(figsize=pltsize)
    if isinstance(o, cls): o = o.data;
    elif not isinstance(o,np.ndarray): o=array(o)
    return ax, o

def post_plot(ax, title, x_label, y_label, axis=False):
    if title is not None: ax.set_title(title)
    if x_label is not None: ax.set_xlabel(x_label)
    if y_label is not None: ax.set_ylabel(y_label)
    if not axis: ax.axis('off')
    return ax

@patch
@delegates(Line2D)
def show(x:AudioBase, ctx=None, **kwargs): return show_audio(x, ctx=ctx, **merge(x._show_args, kwargs))

@delegates(plt.plot)
def show_audio(aud, ax=None, pltsize=None, title=None, ctx=None, x_label=None, y_label=None, axis=False, **kwargs):
    ax, aud = pre_plot(aud, AudioBase, ax, pltsize, ctx)
    ax.plot(aud, **kwargs)
    return post_plot(ax, title, x_label, y_label, axis)

@patch
@delegates(setup_graph)
def show(x:SpecImage, ctx=None, **kwargs): return show_spec(x, ctx=ctx, **merge(x._show_args, kwargs))

@delegates(plt.pcolormesh)
def show_spec(spec, ax=None, pltsize=None, title=None, ctx=None, x_label=None, y_label=None, axis=False, **kwargs):
    ax, spec = pre_plot(spec, SpecImage, ax, pltsize, ctx)
    ax.pcolormesh(np.abs(spec.data[:spec.data.shape[0]//2]), **kwargs)
    return post_plot(ax, title, x_label, y_label, axis)

@patch
@delegates(setup_graph)
def show(x:MaskBase, ctx=None, **kwargs): return show_mask(x, ctx=ctx, **merge(x._show_args, kwargs))

@delegates(plt.pcolormesh)
def show_mask(mask, ax=None, pltsize=None, title=None, ctx=None, x_label=None, y_label=None, axis=False, **kwargs):
    ax, mask = pre_plot(mask)
    ax.pcolormesh(maks, **kwargs)
    return post_plot(ax, title, x_label, y_label, axis)

In [20]:
#export
def hear_audio(aud, sr=48000, **kwargs):
    if isinstance(aud, AudioBase):  display(Audio(aud.sig, rate=aud.sr))
    else:                           display(Audio(aud, rate=sr))

## Transforms

### Spectify
Transform that turns AudioItem into a Spectrogram, it can take the parameters for decibel and mel_bin, which are the main transformations that are used. Standard problems will require decibels because it resembles human hearing. Mel-bins also achieve this but it requires us to loose large portion of the phase which reduces the intelligibility of the audio.

In [21]:
#export
class Spectify(Transform):
    def __init__(self, sample_rate=48000, fftsize=512, win_mult=2, overlap=0.5, decibel=False, mel_bin=False):
        store_attr(self, 'sample_rate, fftsize, win_mult, overlap, decibel, mel_bin')
    def encodes(self, audio:AudioMono):
        spec = stft(audio.sig, self.fftsize, self.win_mult, self.overlap)
        if self.decibel: pass #TODO Encode
        if self.mel_bin: pass #TODO Encode
        return SpecImage(spec, audio.sr, audio.fn)
    def decodes(self, spec:SpecImage):
        if self.mel_bin: pass #TODO Decode
        if self.decibel: pass #TODO Decode
        print(f"in decode {type(spec)}")
        if isinstance(spec, ArraySpecBase):
            audio = istft(spec.data, self.fftsize, self.win_mult, self.overlap)
            return AudioMono(audio, spec.sr, spec.fn)
        audio = istft(spec, self.fftsize, self.win_mult, self.overlap)
        return AudioMono(audio, self.sample_rate)

In [22]:
audio = AudioMono.load_file(fn)
Audio2Spec = Spectify()
spec = Audio2Spec(audio)

In [23]:
#hide
test_eq(type(spec), SpecImage)
test_eq(type(spec.data), np.ndarray)
test_eq(spec.fn, fn)
test_eq(spec.sr, 22050)

In [24]:
audio_r = Audio2Spec.decodes(spec)

in decode <class '__main__.SpecImage'>


NameError: name 'ArraySpecBase' is not defined

In [None]:
#hide
test_eq(type(audio_r), AudioMono)
test_eq(type(audio_r.sig), np.ndarray)
test_eq(audio_r.sr, 22050)
test_eq(audio_r.fn, fn)

#### Create Spec

In [None]:
#export
@patch_clsmthd
@delegates(to=Spectify)
def create(cls:SpecImage, fn, sr=None, **kwargs):
    #Open an `Audio` from path `fn`
    if isinstance(fn,(Path,str)): return cls.create(AudioMono.create(fn,sr))
    elif isinstance(fn,AudioMono): return Spectify(**kwargs)(fn)
    raise ValueError('fn must be AudioMono, Path or str')

In [None]:
spec = SpecImage.create(fn)

In [None]:
#hide
test_eq(type(spec), SpecImage)
test_eq(type(spec.data), np.ndarray)
test_eq(spec.sr, 22050)
test_eq(spec.fn, fn)

### BasicTransforms

In [None]:
#export
class Resample(Transform):
    def __init__(self, sr): self.sr = sr
    def encodes(self, x:AudioBase): x.sr = self.sr; return x

In [None]:
audio_resamp = Resample(sr//2)(AudioMono.create(fn))

test_eq(audio_resamp.sr, 11025)
test_eq(audio_resamp._sr, 11025)
audio_resamp

In [None]:
#export
class Clip(Transform):
    def __init__(self, time): self.time = time
    def encodes(self, x:AudioBase):
        new_sig_len = int(self.time*x.sr)
        diff = abs(len(x.sig) - new_sig_len)
        if len(x.sig) <= new_sig_len:
            x.sig = np.pad(x.sig, (0,diff), 'constant', constant_values=(0, 0))
        else:
            x.sig = x.sig[:new_sig_len]
        return x

In [None]:
audio_ext = Clip(5)(AudioMono.create(fn))

fn_long = Path("../data/AudioTest1_full.wav")
audio_clip = Clip(4)(AudioMono.create(fn_long))

In [None]:
#hide
test_eq(audio_ext.duration, 5.0)
test_eq(len(audio_ext.sig), 5.0*audio_ext.sr)
test_eq(audio_clip.duration, 4.0)
test_eq(len(audio_clip.sig), 4.0*audio_clip.sr)

### Convert to Tensor and Array

In [None]:
#export
class ArrayAudioBase(ArrayBase):
    _show_args = {}
    def show(self, **kwargs):
        return show_audio(self, ctx=ctx, **{**self._show_args, **kwargs})

class ArraySpecBase(ArrayBase):
    _show_args = {}
    def show(self, **kwargs):
        return show_spec(self, ctx=ctx, **{**self._show_args, **kwargs})

class ArrayMaskBase(ArrayBase):
    _show_args = {}
    def show(self, **kwargs):
        return show_mask(self, ctx=ctx, **{**self._show_args, **kwargs})

class TensorAudio(TensorBase): 
    _show_args = ArrayAudioBase._show_args
    def show(self, ctx=None, **kwargs):
        return show_audio(self, ctx=ctx, **{**self._show_args, **kwargs})
    
class TensorSpec(TensorBase): 
    _show_args = ArraySpecBase._show_args
    def show(self, ctx=None, **kwargs):
        return show_spec(self, ctx=ctx, **{**self._show_args, **kwargs})
    
class TensorMask(TensorBase): 
    _show_args = ArrayMaskBase._show_args
    def show(self, ctx=None, **kwargs):
        return show_mask(self, ctx=ctx, **{**self._show_args, **kwargs})

In [None]:
#export
AudioMono._tensor_cls = TensorAudio
SpecImage._tensor_cls = TensorSpec
MaskBase._tensor_cls = TensorMask
@ToTensor
def encodes(self, o:AudioBase): return o._tensor_cls(audio2tensor(o))
@ToTensor
def encodes(self, o:SpecImage): return o._tensor_cls(spec2tensor(o))
@ToTensor
def encodes(self, o:MaskBase):  return o._tensor_cls(mask2tensor(o))

def audio2tensor(aud:AudioBase): return TensorAudio(aud.sig)
def spec2tensor(spec:SpecImage): return TensorSpec(spec.data)
def mask2tensor(mask:MaskBase):  return TensorMask(mask.data)

In [None]:
#hide
test_eq(type(ToTensor()(aud1)), TensorAudio)

### Phase and Complex Tensor Managing

In [None]:
#export
class PhaseManager(Transform):
    def __init__(self, mthd="new_dim", cls=SpecImage):
        assert mthd in ['new_dim', 'remove', 'replace'], 'phase method must be either new_dim, remove or replace'
        store_attr(self, 'mthd, cls')
        
    def encodes(self, spec:SpecImage):
        if self.mthd == 'new_dim': return complex2real(spec)
        
    def decodes(self, spec:TensorSpec)->SpecImage:
        if self.mthd == 'new_dim':
            return SpecImage(complex2real_r(spec),48000)

        
def complex2real(spec):
    if np.iscomplexobj(spec.data):
        spec.data = np.concatenate((spec.data.real[..., np.newaxis], spec.data.imag[..., np.newaxis]), axis=-1)
        spec.data = spec.data.T
    return spec

def complex2real_r(data):
    data = data.numpy().T
    return data[..., 0] + data[..., 1]*1j