In [1]:
import importlib
import cv2
import torch
import torchaudio
import librosa

import torch.utils.benchmark as benchmark
from torch.profiler import profile, record_function, ProfilerActivity
import cProfile
import pstats

import transforms
import transforms_v2
importlib.reload(transforms_v2)
import functional_v2
importlib.reload(functional_v2)
import utils.local_fairseq as local_fairseq

from utils.plots import plot_spectrogram

2023-02-21 09:54:30 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


# Profiling

In [44]:
torch_data, sr  = torchaudio.load('audio_data/lex_6.wav')
librosa_mel = librosa.feature.melspectrogram(y=torch_data[0].numpy(),
                                             n_fft=2048,
                                             win_length=1024,
                                             sr=sr)

## SpecAugment

SpecAugment first implementation

In [39]:
torch_spec = transforms.SpecAugment(
            time_warp_w = 150,
            freq_mask_n = 2,
            freq_mask_f = 10,
            time_mask_n = 3,
            time_mask_t = 40,
            time_mask_p = 1.0,
)

In [42]:
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("mySpecAugment"):
        augmented = torch_spec(librosa_mel)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                mySpecAugment        41.42%       2.169ms       100.00%       5.236ms       5.236ms             1  
                  aten::polar         5.31%     278.000us        10.52%     551.000us     137.750us             4  
                    aten::sub         8.23%     431.000us         8.48%     444.000us      34.154us            13  
           aten::index_select         5.83%     305.000us         6.09%     319.000us      79.750us             4  
                  aten::angle         4.98%     261.000us         5.63%     295.000us      49.167us             6  
                    aten::abs         2.73%     143.000us         5.23% 

STAGE:2023-02-18 19:44:48 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:44:48 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:44:48 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


## Torch Built-ins

In [152]:
torch_data, sr  = torchaudio.load('audio_data/lex_30.wav')
librosa_mel = librosa.feature.melspectrogram(y=torch_data[0].numpy(),
                                             n_fft=2048,
                                             win_length=1024,
                                             sr=sr)

tensor_mel = torch.tensor(librosa_mel)

### Time Stretch

In [154]:
n_freq = librosa_mel.shape[-2]
stretch = torchaudio.transforms.TimeStretch(n_freq=n_freq)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("torchTimeStretch"):
        augmented = stretch(tensor_mel.unsqueeze(0), 0.8)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         torchTimeStretch        38.99%       5.380ms       100.00%      13.797ms      13.797ms             1  
                aten::abs         8.85%       1.221ms        22.50%       3.105ms     776.250us             4  
              aten::polar         5.63%     777.000us        11.21%       1.546ms     773.000us             2  
       aten::index_select        10.19%       1.406ms        10.26%       1.416ms     708.000us             2  
              aten::angle         6.34%     875.000us         6.36%     877.000us     292.333us             3  
              aten::empty         5.36%     739.000us         5.36%     739.000us      73.900us         

STAGE:2023-02-21 09:21:18 43284:4990461 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-21 09:21:18 43284:4990461 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-21 09:21:18 43284:4990461 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


### Masking

In [147]:
masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("torchFreqMasking"):
        augmented = masking(librosa_mel)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
             torchFreqMasking        50.44%       4.019ms       100.00%       7.968ms       7.968ms             1  
            aten::masked_fill         7.66%     610.000us        27.84%       2.218ms       2.218ms             1  
                  aten::clone        14.75%       1.175ms        18.31%       1.459ms       1.459ms             1  
                aten::squeeze        11.63%     927.000us        11.73%     935.000us     467.500us             2  
                     aten::ge         3.95%     315.000us         4.04%     322.000us     161.000us             2  
                  aten::copy_         3.88%     309.000us         3.88% 

STAGE:2023-02-21 09:19:46 43284:4990461 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-21 09:19:46 43284:4990461 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-21 09:19:46 43284:4990461 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


## Fairseq

Fairseq's SpecAugment Implementation

In [25]:
fairseq_spec = \
    local_fairseq.SpecAugmentTransform(
        time_warp_w = 150,
        freq_mask_n = 2,
        freq_mask_f = 40,
        time_mask_n = 3,
        time_mask_t = 10,
        time_mask_p = 1.0,
)

In [26]:
librosa_mel = librosa.feature.melspectrogram(y=torch_data[0].numpy(),
                                             n_fft=2048,
                                             win_length=1024,
                                             sr=sr)

In [27]:
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("fairSpecAugment"):
        fairseq_spec(librosa_mel)

STAGE:2023-02-18 19:42:41 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:42:41 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:42:41 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [28]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
    fairSpecAugment       100.00%     656.000us       100.00%     656.000us     656.000us             1  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 656.000us



## WarpAxis

#### Warp Axis Torch Interpolate

In [155]:
tensor_mel = torch.tensor(librosa_mel)
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("warp_axis_interpolate_torch"):
        functional_v2.warp_axis_v2(tensor_mel, 1, 200)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                           Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
    warp_axis_interpolate_torch        67.64%       2.260ms       100.00%       3.341ms       3.341ms             1  
      aten::upsample_bilinear2d        26.25%     877.000us        26.70%     892.000us     446.000us             2  
                      aten::cat         3.86%     129.000us         4.22%     141.000us     141.000us             1  
                    aten::slice         0.63%      21.000us         0.90%      30.000us       5.000us             6  
                    aten::empty         0.42%      14.000us         0.42%      14.000us       0.875us            16  
                 aten::squeeze_         0.30%      10.00

STAGE:2023-02-21 09:21:32 43284:4990461 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-21 09:21:32 43284:4990461 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-21 09:21:32 43284:4990461 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [117]:
prof = cProfile.Profile()
prof.enable()
functional_v2.warp_axis_v2(tensor_mel, 1, 200)
prof.disable()

stats = pstats.Stats(prof).strip_dirs().sort_stats("cumtime")
stats.print_stats(10) # top 10 rows

         79 function calls in 0.005 seconds

   Ordered by: cumulative time
   List reduced from 28 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.005    0.003 interactiveshell.py:3406(run_code)
        2    0.000    0.000    0.005    0.003 {built-in method builtins.exec}
        1    0.000    0.000    0.005    0.005 383942620.py:1(<module>)
        1    0.002    0.002    0.005    0.005 functional_v2.py:41(warp_axis_v2)
        2    0.001    0.000    0.002    0.001 functional.py:3772(interpolate)
        2    0.001    0.001    0.001    0.001 {built-in method torch._C._nn.upsample_bilinear2d}
        1    0.001    0.001    0.001    0.001 {built-in method torch.cat}
        2    0.000    0.000    0.000    0.000 {method 'randint' of 'numpy.random.mtrand.RandomState' objects}
        2    0.000    0.000    0.000    0.000 codeop.py:117(__call__)
        2    0.000    0.000    0.000    0.000 {method 'squ

<pstats.Stats at 0x29a12fd90>

#### Warp Axis CV2

In [114]:
prof = cProfile.Profile()
prof.enable()
functional_v2.warp_axis(librosa_mel, 1, 200)
prof.disable()

stats = pstats.Stats(prof).strip_dirs().sort_stats("cumtime")
stats.print_stats(10) # top 10 rows

         47 function calls in 0.001 seconds

   Ordered by: cumulative time
   List reduced from 24 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    0.000    0.000    0.001    0.001 interactiveshell.py:3406(run_code)
        2    0.000    0.000    0.001    0.001 {built-in method builtins.exec}
        1    0.000    0.000    0.001    0.001 functional_v2.py:4(warp_axis)
        2    0.001    0.000    0.001    0.000 {resize}
        1    0.000    0.000    0.000    0.000 <__array_function__ internals>:177(concatenate)
        1    0.000    0.000    0.000    0.000 {built-in method numpy.core._multiarray_umath.implement_array_function}
        2    0.000    0.000    0.000    0.000 codeop.py:117(__call__)
        2    0.000    0.000    0.000    0.000 {built-in method builtins.compile}
        2    0.000    0.000    0.000    0.000 {method 'randint' of 'numpy.random.mtrand.RandomState' objects}
        2    0.000    0.000    0.0

<pstats.Stats at 0x29a145420>

# Benchmark

In [2]:
torch_data, sr  = torchaudio.load('audio_data/lex_30.wav')
librosa_mel = librosa.feature.melspectrogram(y=torch_data[0].numpy(),
                                             n_fft=2048,
                                             win_length=1024,
                                             sr=sr)

### SpecAugment

In [3]:
torch_spec = transforms.SpecAugment(
            time_warp_w = 150,
            freq_mask_n = 2,
            freq_mask_f = 10,
            time_mask_n = 3,
            time_mask_t = 50,
            time_mask_p = 1.0,
)

torch_spec_v2 = transforms_v2.SpecAugment(
            warp_axis=1,
            warp_w = 50,
            freq_mask_num = 0,
            freq_mask_param = 10,
            freq_mask_p = 1.0,
            time_mask_num = 5,
            time_mask_param = 50,
            time_mask_p = 1.0,
)

fairseq_spec = \
    local_fairseq.SpecAugmentTransform(
        time_warp_w = 50,
        freq_mask_n = 2,
        freq_mask_f = 50,
        time_mask_n = 2,
        time_mask_t = 10,
        time_mask_p = 1.0,
)

In [None]:
t0 = benchmark.Timer(
    stmt='augmented = torch_spec(mel)',
    label='* SPECAUGMENT WITH TORCH TIMESTRETCH',
    globals={"mel": torch.tensor(librosa_mel), "torch_spec": torch_spec})

t1 = benchmark.Timer(
    stmt="augmented = torch_spec_v2(mel, 'cv2')",
    label='* SPECAUGMENT WITH CV2 RESIZE',
    globals={"mel": librosa_mel, "torch_spec_v2": torch_spec_v2})

t2 = benchmark.Timer(
    stmt="augmented = torch_spec_v2(mel, 'torch')",
    label='* SPECAUGMENT WITH TORCH INTERPOLATE',
    globals={"mel": torch.tensor(librosa_mel), "torch_spec_v2": torch_spec_v2})

t3 = benchmark.Timer(
    stmt='augmented = fairseq_spec(mel)',
    label='* SPECAUGMENT WITH FAIRSEQ',
    globals={"mel": librosa_mel, "fairseq_spec": fairseq_spec})


res0 = t0.timeit(500)
res1 = t1.timeit(500)
res2 = t2.timeit(500)
res3 = t3.timeit(500)

In [5]:
print(res0)
print(res1)
print(res2)
print(res3)

<torch.utils.benchmark.utils.common.Measurement object at 0x294196b30>
* SPECAUGMENT WITH TORCH TIMESTRETCH
  4.68 ms
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x294195d80>
* SPECAUGMENT WITH CV2 RESIZE
  292.06 us
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x294196cb0>
* SPECAUGMENT WITH TORCH INTERPOLATE
  492.15 us
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x294196500>
* SPECAUGMENT WITH FAIRSEQ
  355.52 us
  1 measurement, 500 runs , 1 thread


### Warp Axis

In [None]:
t0 = benchmark.Timer(
    stmt='augmented = warp_axis_torch(mel, 1, 200)',
    label='* WARP AXIS WITH TORCH INTERPOLATE',
    globals={"mel": torch.tensor(librosa_mel), "warp_axis_torch": functional_v2.warp_axis_torch})

t1 = benchmark.Timer(
    stmt='augmented = warp_axis_cv2(mel, 1, 200)',
    label='* WARP AXIS WITH CV2 RESIZE',
    globals={"mel": librosa_mel, "warp_axis_cv2": functional_v2.warp_axis_cv2})

res0 = t0.timeit(5000)
res1 = t1.timeit(5000)

In [7]:
print(res0)
print(res1)

<torch.utils.benchmark.utils.common.Measurement object at 0x150b03160>
* WARP AXIS WITH TORCH INTERPOLATE
  435.73 us
  1 measurement, 5000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x291f165c0>
* WARP AXIS WITH CV2 RESIZE
  283.40 us
  1 measurement, 5000 runs , 1 thread


### torch.nn.interpolate vs cv2.Resize

In [10]:
tensor_mel.shape

torch.Size([1, 1, 128, 2584])

In [None]:
new_sz = (800, 800)
t0 = benchmark.Timer(
    stmt='resize(mel, dsize=(800,800), interpolation=cv2.INTER_LINEAR)',
    setup="import cv2",
    label="* CV2 RESIZE",
    globals={"mel": librosa_mel, "resize": cv2.resize})


tensor_mel = torch.tensor(librosa_mel)
tensor_mel = tensor_mel[(None,)*2]
t1 = benchmark.Timer(
    stmt = "interpolate(mel, size=(800,800), mode='bilinear')",
    label="* TORCH INTERPOLATE",
    globals={"mel": tensor_mel, "interpolate": torch.nn.functional.interpolate})

res0 = t0.timeit(5000)
res1 = t1.timeit(5000)

In [9]:
print(res0)
print(res1)

<torch.utils.benchmark.utils.common.Measurement object at 0x105f13040>
* CV2 RESIZE
setup: import cv2
  173.95 us
  1 measurement, 5000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x294197700>
* TORCH INTERPOLATE
  373.30 us
  1 measurement, 5000 runs , 1 thread
