In [2]:
import torch
import torchaudio
import librosa

import transforms
import transforms_v2
import utils.local_fairseq as local_fairseq

from utils.plots import plot_spectrogram

# Profiling

In [37]:
from torch.profiler import profile, record_function, ProfilerActivity

### SpecAugment

In [38]:
# Load the clip with torchaudio, then compute a mel spectrogram with librosa
# from row 0 of the loaded waveform and hand it back to torch as a tensor.
torch_data, sr = torchaudio.load('audio_data/lex_6.wav')

librosa_mel = torch.tensor(
    librosa.feature.melspectrogram(
        y=torch_data[0].numpy(),
        sr=sr,
        n_fft=2048,
        win_length=1024,
    )
)

In [39]:
# Instance of the project-local SpecAugment that the next cell profiles.
# The meaning of each knob is defined in the local `transforms` module.
torch_spec = transforms.SpecAugment(
    time_warp_w=150,
    freq_mask_n=2,
    freq_mask_f=10,
    time_mask_n=3,
    time_mask_t=40,
    time_mask_p=1.0,
)

In [42]:
# Profile a single call of the project-local SpecAugment on CPU; the call is
# wrapped in a labelled range so it shows up as "mySpecAugment" in the table.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, \
        record_function("mySpecAugment"):
    augmented = torch_spec(librosa_mel)

# Show at most ten rows, ordered by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                mySpecAugment        41.42%       2.169ms       100.00%       5.236ms       5.236ms             1  
                  aten::polar         5.31%     278.000us        10.52%     551.000us     137.750us             4  
                    aten::sub         8.23%     431.000us         8.48%     444.000us      34.154us            13  
           aten::index_select         5.83%     305.000us         6.09%     319.000us      79.750us             4  
                  aten::angle         4.98%     261.000us         5.63%     295.000us      49.167us             6  
                    aten::abs         2.73%     143.000us         5.23% 

STAGE:2023-02-18 19:44:48 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:44:48 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:44:48 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [3]:
# Same clip as before, but the mel spectrogram is computed by torchaudio this time.
torch_data, sr = torchaudio.load('audio_data/lex_6.wav')

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr,
    n_fft=2048,
    win_length=1024,
)

# Transform row 0 of the waveform and prepend an axis of size 1.
torch_mel = mel_transform(torch_data[0]).unsqueeze(0)

In [4]:
# Display the spectrogram's shape; the leading axis of size 1 comes from the
# unsqueeze(0) applied when torch_mel was built.
torch_mel.shape

torch.Size([1, 128, 517])

In [19]:
# Profile torchaudio's TimeStretch applied to the torchaudio mel spectrogram.
# n_freq is read from the second-to-last axis so it matches the input's bin count.
# NOTE(review): torchaudio documents TimeStretch as operating on a complex-valued
# STFT spectrogram, while torch_mel is the output of MelSpectrogram — confirm that
# feeding it here is intentional before drawing conclusions from these timings.
n_freq = torch_mel.shape[-2]
stretch = torchaudio.transforms.TimeStretch(n_freq=n_freq)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("torchTimeStretch"):
        # 0.8 is supplied per call as the stretch rate.
        augmented = stretch(torch_mel, 0.8)
        
# Show at most ten rows, ordered by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         torchTimeStretch        22.91%       1.231ms       100.00%       5.373ms       5.373ms             1  
                aten::abs        10.68%     574.000us        20.98%       1.127ms     281.750us             4  
              aten::polar         7.15%     384.000us        14.22%     764.000us     382.000us             2  
                aten::pad         0.07%       4.000us        13.74%     738.000us     738.000us             1  
    aten::constant_pad_nd         0.34%      18.000us        13.66%     734.000us     734.000us             1  
                aten::mul         8.38%     450.000us         9.01%     484.000us     161.333us         

STAGE:2023-02-18 19:41:34 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:41:34 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:41:34 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [18]:
# Profile one application of torchaudio's FrequencyMasking to the torchaudio mel
# spectrogram, recorded under the "torchFreqMasking" label.
masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, \
        record_function("torchFreqMasking"):
    augmented = masking(torch_mel)

# Show at most ten rows, ordered by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
             torchFreqMasking        83.49%       1.644ms       100.00%       1.969ms       1.969ms             1  
            aten::masked_fill         0.56%      11.000us         6.96%     137.000us     137.000us             1  
                  aten::copy_         3.15%      62.000us         3.15%      62.000us       7.750us             8  
                  aten::clone         0.15%       3.000us         3.15%      62.000us      62.000us             1  
           aten::masked_fill_         3.10%      61.000us         3.10%      61.000us      61.000us             1  
                aten::__and__         0.10%       2.000us         1.88% 

STAGE:2023-02-18 19:41:27 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:41:27 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:41:27 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


### Torch Built-ins

In [20]:
# Rebuild the librosa mel spectrogram for the built-in transforms; unlike the
# SpecAugment section, an extra leading axis of size 1 is added here.
torch_data, sr = torchaudio.load('audio_data/lex_6.wav')

librosa_mel = torch.tensor(
    librosa.feature.melspectrogram(
        y=torch_data[0].numpy(),
        sr=sr,
        n_fft=2048,
        win_length=1024,
    )
).unsqueeze(0)

#### Time Stretch

In [22]:
# Profile torchaudio's TimeStretch applied to the librosa-derived mel spectrogram.
# n_freq is read from the second-to-last axis so it matches the input's bin count.
# NOTE(review): torchaudio documents TimeStretch as operating on a complex-valued
# STFT spectrogram, while librosa_mel is a mel spectrogram wrapped in a tensor —
# confirm that feeding it here is intentional before relying on these timings.
n_freq = librosa_mel.shape[-2]
stretch = torchaudio.transforms.TimeStretch(n_freq=n_freq)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    with record_function("torchTimeStretch"):
        # 0.8 is supplied per call as the stretch rate.
        augmented = stretch(librosa_mel, 0.8)

# Show at most ten rows, ordered by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
         torchTimeStretch        31.95%       1.136ms       100.00%       3.556ms       3.556ms             1  
              aten::polar        10.69%     380.000us        21.23%     755.000us     377.500us             2  
                aten::add         9.36%     333.000us         9.90%     352.000us     117.333us             3  
       aten::index_select         7.26%     258.000us         8.35%     297.000us     148.500us             2  
                aten::abs         3.74%     133.000us         7.28%     259.000us      64.750us             4  
                aten::sub         6.21%     221.000us         6.58%     234.000us      58.500us         

STAGE:2023-02-18 19:42:15 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:42:15 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:42:15 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


#### Frequency Masking

In [24]:
# Profile one application of torchaudio's FrequencyMasking to the librosa-derived
# mel spectrogram, recorded under the "torchFreqMasking" label.
masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, \
        record_function("torchFreqMasking"):
    augmented = masking(librosa_mel)

# Show at most ten rows, ordered by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
             torchFreqMasking        78.79%       1.805ms       100.00%       2.291ms       2.291ms             1  
            aten::masked_fill         0.96%      22.000us         8.03%     184.000us     184.000us             1  
           aten::masked_fill_         3.75%      86.000us         3.75%      86.000us      86.000us             1  
                  aten::copy_         3.45%      79.000us         3.45%      79.000us       9.875us             8  
                  aten::clone         0.22%       5.000us         3.01%      69.000us      69.000us             1  
                     aten::to         1.57%      36.000us         2.97% 

STAGE:2023-02-18 19:42:31 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:42:31 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:42:31 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


### Fairseq

In [25]:
# SpecAugment from the local fairseq copy, profiled in the cells that follow.
# The meaning of each knob is defined in `utils.local_fairseq`.
fairseq_spec = local_fairseq.SpecAugmentTransform(
    time_warp_w=150,
    freq_mask_n=2,
    freq_mask_f=40,
    time_mask_n=3,
    time_mask_t=10,
    time_mask_p=1.0,
)

In [26]:
# Recompute the mel spectrogram and keep it exactly as librosa returns it
# (no torch.tensor conversion in this cell). Relies on torch_data and sr
# already being defined by an earlier loading cell.
librosa_mel = librosa.feature.melspectrogram(
    y=torch_data[0].numpy(),
    sr=sr,
    n_fft=2048,
    win_length=1024,
)

In [27]:
# Profile a single call of the fairseq SpecAugment on CPU under the
# "fairSpecAugment" label; the transformed output itself is not kept.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof, \
        record_function("fairSpecAugment"):
    fairseq_spec(librosa_mel)

STAGE:2023-02-18 19:42:41 1707:3895196 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2023-02-18 19:42:41 1707:3895196 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2023-02-18 19:42:41 1707:3895196 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [28]:
# Print the table for the most recent profiler run held in `prof`:
# at most ten rows, ordered by total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
               Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
    fairSpecAugment       100.00%     656.000us       100.00%     656.000us     656.000us             1  
-------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 656.000us



# Benchmark

In [3]:
import torch.utils.benchmark as benchmark

In [4]:
# Benchmark input: a different clip (lex_30), with the mel spectrogram kept
# exactly as librosa returns it.
torch_data, sr = torchaudio.load('audio_data/lex_30.wav')

librosa_mel = librosa.feature.melspectrogram(
    y=torch_data[0].numpy(),
    sr=sr,
    n_fft=2048,
    win_length=1024,
)

In [5]:
# The three SpecAugment implementations that the benchmark compares.
# NOTE(review): the hyperparameters are not identical across the three objects
# (e.g. time_warp_w=150 vs warp_w=50 vs time_warp_w=50, and freq_mask_num=0 for
# the v2 variant), so the timings may not be strictly like-for-like — confirm
# that these differences are intentional.

# Original project-local implementation.
torch_spec = transforms.SpecAugment(
    time_warp_w=150,
    freq_mask_n=2,
    freq_mask_f=10,
    time_mask_n=3,
    time_mask_t=50,
    time_mask_p=1.0,
)

# Second revision of the project-local implementation.
torch_spec_v2 = transforms_v2.SpecAugment(
    warp_axis=1,
    warp_w=50,
    freq_mask_num=0,
    freq_mask_param=10,
    freq_mask_p=1.0,
    time_mask_num=5,
    time_mask_param=50,
    time_mask_p=1.0,
)

# Implementation from the local fairseq copy.
fairseq_spec = local_fairseq.SpecAugmentTransform(
    time_warp_w=50,
    freq_mask_n=2,
    freq_mask_f=50,
    time_mask_n=2,
    time_mask_t=10,
    time_mask_p=1.0,
)

In [6]:
# One Timer per implementation. The names used inside each statement are
# supplied through that Timer's `globals` mapping.
# NOTE(review): t0 is given a torch tensor, whereas t1 and t2 are given the raw
# librosa result — confirm each implementation is meant to receive that type.
t0 = benchmark.Timer(
    stmt='augmented = torch_spec(mel)',
    setup='',
    globals=dict(mel=torch.tensor(librosa_mel), torch_spec=torch_spec),
)

t1 = benchmark.Timer(
    stmt='augmented = torch_spec_v2(mel)',
    setup='',
    globals=dict(mel=librosa_mel, torch_spec_v2=torch_spec_v2),
)

t2 = benchmark.Timer(
    stmt='augmented = fairseq_spec(mel)',
    setup='',
    globals=dict(mel=librosa_mel, fairseq_spec=fairseq_spec),
)

# Time each statement for 500 runs, in the same order the timers were built.
res0 = t0.timeit(500)
res1 = t1.timeit(500)
res2 = t2.timeit(500)



In [7]:
# Print the three measurements in the order they were taken.
for measurement in (res0, res1, res2):
    print(measurement)

<torch.utils.benchmark.utils.common.Measurement object at 0x104456bf0>
augmented = torch_spec(mel)
  4.65 ms
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x104455090>
augmented = torch_spec_v2(mel)
  277.76 us
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x104457b80>
augmented = fairseq_spec(mel)
  438.51 us
  1 measurement, 500 runs , 1 thread
