In [1]:
import torch
import torch.nn as nn
import torchaudio
from IPython.display import Audio, display

from speechbrain.pretrained import EncoderDecoderASR
# from espnet2.bin.asr_inference import Speech2Text

from audio_augmentations import *
from speechbrain.lobes.augment import TimeDomainSpecAugment

from jiwer import wer

from data import load_dataset
from main import collect_params

[NeMo W 2022-09-20 22:13:08 optimizers:55] Apex was not found. Using the lamb or fused_adam optimizer will error out.
    


In [4]:
def play_audio(waveform, sample_rate):
    """Render a waveform tensor as an inline IPython audio player.

    Args:
        waveform: torch tensor shaped (num_channels, num_frames) with 1 or 2
            channels, or a 1-D (num_frames,) tensor that is treated as mono.
            It may live on any device and may require grad; it is detached and
            copied to the CPU before being converted to numpy.
        sample_rate: sampling rate passed as ``rate`` to ``IPython.display.Audio``.

    Raises:
        ValueError: if the waveform is not 1-D/2-D or has more than 2 channels.
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad.
    waveform = waveform.detach().cpu().numpy()
    if waveform.ndim == 1:
        # Promote a bare signal to a single-channel (1, num_frames) array.
        waveform = waveform[None, :]
    if waveform.ndim != 2:
        raise ValueError(f"Expected a 1-D or 2-D waveform, got shape {waveform.shape}.")

    num_channels = waveform.shape[0]
    if num_channels == 1:
        display(Audio(waveform[0], rate=sample_rate))
    elif num_channels == 2:
        display(Audio((waveform[0], waveform[1]), rate=sample_rate))
    else:
        raise ValueError("Waveform with more than 2 channels are not supported.")

# Reference utterance used by the listening / augmentation-preview cells.
audio_file = "/home/server08/hdd0/changhun_workspace/LibriSpeech/test-other/367/130732/367-130732-0000.flac"
num_augmented_samples = 1  # in this cell, only referenced by the commented-out ComposeMany variant below

SAMPLE_RATE = 16000


def _heavy_time_domain_aug():
    """Return a fresh TimeDomainSpecAugment with every stage enabled (probability 1)
    and large frequency-drop / chunk-drop counts."""
    return TimeDomainSpecAugment(
        perturb_prob=1,
        drop_freq_prob=1,
        drop_chunk_prob=1,
        speeds=[80, 100, 120],
        drop_freq_count_low=10,
        drop_freq_count_high=20,
        drop_chunk_count_low=10,
        drop_chunk_count_high=20,
        drop_chunk_length_low=500,
        drop_chunk_length_high=1000,
        drop_chunk_noise_factor=0,
    )


augmentation_list = [
    PolarityInversion(),
    Noise(min_snr=0.1, max_snr=1),
    Gain(),
    Reverb(sample_rate=SAMPLE_RATE),
    Delay(sample_rate=SAMPLE_RATE),
    HighLowPass(sample_rate=SAMPLE_RATE),
    PitchShift(n_samples=SAMPLE_RATE * 5, sample_rate=SAMPLE_RATE, pitch_cents_min=-1400, pitch_cents_max=1400),
    # The same heavy configuration is listed twice, as two separate module instances.
    _heavy_time_domain_aug(),
    _heavy_time_domain_aug(),
]
# augmentation_list = [ComposeMany(transforms=[transform], num_augmented_samples=num_augmented_samples) for transform in transforms]

In [24]:
# Load the reference utterance and listen to the untouched signal before augmenting it.
wav, sr = torchaudio.load(audio_file)
play_audio(waveform=wav, sample_rate=sr)

In [25]:
# Listen to every entry of `augmentation_list` applied to the reference utterance `wav`.
for augmentation in augmentation_list:
    if isinstance(augmentation, TimeDomainSpecAugment):
        # speechbrain treats the leading axis as the batch and needs one relative
        # length (1.0 = full length) per row, not a hard-coded single entry.
        aug_wavs = augmentation(wav, lengths=torch.ones(wav.shape[0]))
    else:
        aug_wavs = augmentation(wav)
    # Each output row is one signal; play it as a mono clip.
    for aug_wav in aug_wavs:
        play_audio(aug_wav.unsqueeze(0), sr)

S2SRNNBeamSearchLM(
  (emb): Embedding(
    (Embedding): Embedding(1000, 128)
  )
  (dec): AttentionalRNNDecoder(
    (proj): Linear(in_features=2048, out_features=1024, bias=True)
    (attn): LocationAwareAttention(
      (mlp_enc): Linear(in_features=512, out_features=1024, bias=True)
      (mlp_dec): Linear(in_features=1024, out_features=1024, bias=True)
      (mlp_attn): Linear(in_features=1024, out_features=1, bias=False)
      (conv_loc): Conv1d(1, 10, kernel_size=(201,), stride=(1,), padding=(100,), bias=False)
      (mlp_loc): Linear(in_features=10, out_features=1024, bias=True)
      (mlp_out): Linear(in_features=512, out_features=1024, bias=True)
      (softmax): Softmax(dim=-1)
    )
    (drop): Dropout(p=0.15, inplace=False)
    (rnn): GRUCell(
      (rnn_cells): ModuleList(
        (0): GRUCell(1152, 1024)
      )
      (dropout_layers): ModuleList()
    )
  )
  (fc): Linear(
    (w): Linear(in_features=1024, out_features=1000, bias=True)
  )
  (softmax): LogSoftmax(dim=-1)
  (lm): RNNLM(
    (embedding): Embedding(
      (Embedding): Embedding(1000, 128)
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (rnn): LSTM(
      (rnn): LSTM(128, 2048, num_layers=2, batch_first=True)
    )
    (dnn): Sequential(
      (linear): Linear(
        (w): Linear(in_features=2048, out_features=512, bias=True)
      )
      (norm): LayerNorm(
        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (act): LeakyReLU(negative_slope=0.01)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (out): Linear(
      (w): Linear(in_features=512, out_features=1000, bias=True)
    )
  )
  (log_softmax): Softmax(
    (act): LogSoftmax(dim=-1)
  )
)

In [5]:
# --- Evaluation data ---
dataset_name = "chime"
dataset_dir = "/home/server08/hdd0/changhun_workspace/CHiME3"
# NOTE(review): LibriSpeech-style split name paired with dataset_name="chime" —
# confirm this is what load_dataset expects for CHiME-3.
split = ["test-other"]
batch_size = 1
extra_noise = 0

# --- Test-time adaptation hyper-parameters ---
steps = 10  # adaptation steps taken per batch
lr = 2e-6   # learning rate

dataset = load_dataset(split, dataset_name, dataset_dir, batch_size, extra_noise)

Read text: 100%|██████████| 2310/2310 [00:00<00:00, 22480.96it/s]

[INFO]    There are 2310 samples.





In [29]:
# augmentation test: transcribe each utterance clean and under every entry of
# `augmentation_list`, print per-utterance WER, then corpus-level WER per condition.
model = EncoderDecoderASR.from_hparams("speechbrain/asr-crdnn-rnnlm-librispeech", run_opts={"device" : "cuda"})
model.eval()  # nothing in this cell switches the model back to train mode


def _apply_augmentation(augmentation, wavs):
    """Apply one augmentation to a waveform batch (one utterance per row).

    speechbrain's TimeDomainSpecAugment additionally takes a relative-length tensor
    with one entry per row; the audio_augmentations transforms take the waveform only.
    """
    if isinstance(augmentation, TimeDomainSpecAugment):
        return augmentation(wavs, lengths=torch.ones(wavs.shape[0]))
    return augmentation(wavs)


gt_texts = []
# Index 0 collects clean-audio hypotheses; index i + 1 collects those for augmentation_list[i].
transcriptions_list = [[] for _ in range(1 + len(augmentation_list))]

for batch in dataset:
    lens, wavs, texts, files = batch
    wavs = torch.tensor(wavs)
    gt_texts += texts
    print(f"ground truth text: {texts}")

    # speechbrain's `wav_lens` are lengths relative to the longest utterance
    # (1.0 = full length), not sample counts; every row here is used in full.
    ori_transcription, _ = model.transcribe_batch(wavs, wav_lens=torch.ones(wavs.shape[0]))
    transcriptions_list[0] += ori_transcription
    ori_wer = wer(list(texts), list(ori_transcription))
    print(f"original text : {ori_transcription}")
    print(f"original WER: {ori_wer}")

    for i, augmentation in enumerate(augmentation_list):
        aug_wavs = _apply_augmentation(augmentation, wavs)
        print(augmentation)
        print(f"aug_wavs.shape : {aug_wavs.shape}")
        aug_transcription, _ = model.transcribe_batch(aug_wavs, wav_lens=torch.ones(aug_wavs.shape[0]))
        transcriptions_list[i + 1] += aug_transcription
        aug_wer = wer(list(texts), list(aug_transcription))
        print(f"augmentation text : {aug_transcription}")
        print(f"aug WER: {aug_wer}")
    print("\n\n\n")

# Corpus-level WER per condition, labelled so the numbers can be told apart.
condition_names = ["original"] + [f"{i}:{type(aug).__name__}" for i, aug in enumerate(augmentation_list)]
for name, transcriptions in zip(condition_names, transcriptions_list):
    print(f"{name} WER: {wer(list(gt_texts), list(transcriptions))}")

ground truth text: ('THE LABOR DEPARTMENT SAID NON FARM PAYROLL EMPLOYMENT INCREASED A ROBUST THREE HUNDRED THIRTY SEVEN THOUSAND LAST MONTH AFTER A REVISED THREE HUNDRED NINETEEN THOUSAND GAIN THE MONTH BEFORE',)
original text : ['HE LABOURED FRANKLY AND ON THE BROAD NARROW SYMPATHY']
original WER: 1.0
PolarityInversion()
aug_wavs.shape : torch.Size([1, 154242])
augmentaiton text : ['HE LABOURED FRANKLY AND ON THE BROAD NARROW SYMPATHY']
aug WER: 1.0
Noise()
aug_wavs.shape : torch.Size([1, 154242])
augmentaiton text : ['THE FAVOURITE SYMPATHY HAD BEEN DROWNED IN THE INTELLECTUAL SYMPATHY AND INTELLECTUAL SENSIBILITY']
aug WER: 0.9655172413793104
Gain()
aug_wavs.shape : torch.Size([1, 154242])
augmentaiton text : ['THE LABOURED ACCOMPLISHED THREE HUNDRED THIRTY SEVEN ACCUSTOMED AS MONSIEUR ADVISED RUINED NINETEEN YEARS AGO ONLY']
aug WER: 0.7931034482758621
Reverb()
aug_wavs.shape : torch.Size([1, 154242])
augmentaiton text : ["HE'S LABORED AWHILE LYNDE AND I'M ON THE BLIND NARROWLY BL

In [None]:
# model = Speech2Text.from_pretrained(
#   "espnet/chai_librispeech_asr_train_conformer-rnn_transducer_raw_en_bpe5000_sp"
#   # "espnet/chai_librispeech_asr_train_rnnt_conformer_raw_en_bpe5000_sp"
# )
# audio = "/home/server08/hdd0/changhun_workspace/LibriSpeech/test-other/367/130732/367-130732-0000.flac"

# import nemo.collections.asr as nemo_asr
# asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/stt_en_conformer_transducer_xlarge")
# print(asr_model.transcribe([audio]))


# from speechbrain.decoders.seq2seq import S2SRNNBeamSearchLM
# from speechbrain.decoders.seq2seq import S2SBeamSearcher

# TODO: check collect_params (compare the number of parameters it returns).

# Unadapted reference model: only used for baseline transcriptions and never handed
# to the optimizer.
original_model = EncoderDecoderASR.from_hparams("speechbrain/asr-crdnn-rnnlm-librispeech", run_opts={"device" : "cuda"})
# Model adapted at test time; the parameters returned by collect_params are optimized.
model = EncoderDecoderASR.from_hparams("speechbrain/asr-crdnn-rnnlm-librispeech", run_opts={"device" : "cuda"})
params, _ = collect_params(model, train_all=True)
# Use the learning rate from the configuration cell (`lr`) rather than a separate
# hard-coded value, so the configured hyper-parameter actually takes effect.
optim = torch.optim.Adam(params, lr=lr)
mse = torch.nn.MSELoss()

# Light waveform perturbations that produce the "weak" view for the consistency loss,
# each wrapped in RandomApply with p=0.5.
# Disabled candidates: RandomApply([Gain()], p=0.7),
# RandomApply([HighLowPass(sample_rate=16000)], p=0.7),
# RandomApply([PitchShift(n_samples=16000*5, sample_rate=16000)], p=0.5).
weak_transforms = [
    RandomApply([transform], p=0.5)
    for transform in (
        PolarityInversion(),
        Noise(min_snr=0.01, max_snr=0.05),
        Reverb(sample_rate=16000),
    )
]
weak_augmentation = ComposeMany(transforms=weak_transforms, num_augmented_samples=1)

# transcriptions_k collects the adapted model's hypotheses after k adaptation steps.
(
    transcriptions_1,
    transcriptions_3,
    transcriptions_5,
    transcriptions_10,
    transcriptions_20,
    transcriptions_40,
) = ([] for _ in range(6))
# References, unadapted-baseline hypotheses, and two spare accumulators.
gt_texts, ori_transcriptions, durations, werrs = ([] for _ in range(4))

# Test-time adaptation: for each batch, transcribe with the unadapted `original_model`,
# then take `steps` optimizer steps that pull the encoder output of a weakly augmented
# copy toward the (detached) encoder output of the clean input, transcribing with the
# adapted model at selected step counts.
# NOTE(review): `model` and `optim` are never reset between batches, so adaptation
# accumulates across the whole dataset — confirm this online setting is intended.

# 1-indexed step count -> list that collects the adapted model's hypotheses at that step.
eval_checkpoints = {
    1: transcriptions_1,
    3: transcriptions_3,
    5: transcriptions_5,
    10: transcriptions_10,
}

for batch in dataset:
    lens, wavs, texts, files = batch
    wavs = torch.tensor(wavs)
    # Relative lengths (1.0 = full length), one per row, as speechbrain expects.
    rel_lens = torch.ones(wavs.shape[0])
    # References for the corpus-level WER report.
    gt_texts += texts

    with torch.no_grad():
        ori_transcription, _ = original_model.transcribe_batch(wavs, wav_lens=rel_lens)
    ori_transcriptions += ori_transcription
    ori_wer = wer(list(texts), list(ori_transcription))
    print("\noriginal WER: ", ori_wer)

    for i in range(steps):
        model.train()
        weak_wavs = weak_augmentation(wavs.detach().cpu())
        # ComposeMany may stack its samples along a new leading axis; flatten any
        # leading axes so encode_batch receives a [batch, time] tensor.
        weak_wavs = weak_wavs.reshape(-1, weak_wavs.shape[-1])
        original_rep = model.encode_batch(wavs, wav_lens=rel_lens)
        weak_rep = model.encode_batch(weak_wavs, wav_lens=torch.ones(weak_wavs.shape[0]))

        # Broadcast the clean target over every augmented copy (no-op for equal batch sizes).
        loss = mse(weak_rep, original_rep.detach().expand_as(weak_rep))
        optim.zero_grad()
        loss.backward()
        optim.step()

        step = i + 1
        if step in eval_checkpoints:
            model.eval()
            transcription, _ = model.transcribe_batch(wavs, wav_lens=rel_lens)
            ada_wer = wer(list(texts), list(transcription))
            print(f"adapt-{step} WER: ", ada_wer)
            eval_checkpoints[step].extend(transcription)

print("original WER:", wer(gt_texts, ori_transcriptions))
# Score every adaptation checkpoint the loop actually reached (checkpoint <= steps),
# instead of printing nothing unless all four were reached.
for checkpoint, hypotheses in (
    (1, transcriptions_1),
    (3, transcriptions_3),
    (5, transcriptions_5),
    (10, transcriptions_10),
):
    if steps >= checkpoint:
        print(f"TTA-{checkpoint} WER:", wer(gt_texts, hypotheses))

In [None]:
from speechbrain.lobes.augment import TimeDomainSpecAugment, SpecAugment, EnvCorrupt

# Audition a mild and a harsh TimeDomainSpecAugment (speed perturbation disabled:
# perturb_prob=0, speeds=[100]) on the reference utterance `wav`.
weak_augmentation = TimeDomainSpecAugment(
    perturb_prob=0,
    drop_freq_prob=1,
    drop_chunk_prob=1,
    speeds=[100],
    drop_freq_count_low=1,
    drop_freq_count_high=3,
    drop_chunk_count_low=1,
    drop_chunk_count_high=3,
).cuda()

strong_augmentation = TimeDomainSpecAugment(
    perturb_prob=0,
    drop_freq_prob=1,
    drop_chunk_prob=1,
    speeds=[100],
    drop_freq_count_low=5,
    drop_freq_count_high=10,
    drop_chunk_count_low=5,
    drop_chunk_count_high=10,
    drop_chunk_length_low=500,
    drop_chunk_length_high=1000,
    drop_chunk_noise_factor=0,
).cuda()


def _audition(augmenter):
    """Augment `wav` on the GPU, play every output row, and return the augmented batch."""
    augmented = augmenter(wav.cuda(), lengths=torch.ones(1).cuda())
    for row in augmented:
        play_audio(row.cpu().unsqueeze(0), sr)
    return augmented


weak_wavs = _audition(weak_augmentation)
strong_wavs = _audition(strong_augmentation)

In [10]:
from speechbrain.pretrained import SpectralMaskEnhancement
# Listen to each utterance in `dataset` before and after MetricGAN+ speech enhancement.
# A dedicated name keeps the ASR `model` used by other cells from being overwritten.
enhancer = SpectralMaskEnhancement.from_hparams(
  "speechbrain/metricgan-plus-voicebank"
)
# Optional cap on how many utterances to preview; None plays every utterance.
max_preview = None

for preview_idx, batch in enumerate(dataset):
    if max_preview is not None and preview_idx >= max_preview:
        break
    lens, wavs, texts, files = batch
    wavs = torch.tensor(wavs)
    play_audio(wavs, 16000)

    # Inference only: no autograd graph; one relative length (1.0) per row.
    with torch.no_grad():
        enhanced_wavs = enhancer.enhance_batch(wavs, lengths=torch.ones(wavs.shape[0]))
    # Move to the CPU before play_audio converts the tensor to numpy.
    play_audio(enhanced_wavs.cpu(), 16000)

In [23]:
def softmax_entropy(x, dim=-1):
    """Entropy (natural log) of softmax(x) along `dim`.

    Args:
        x: tensor of unnormalized logits.
        dim: axis holding the classes; it is reduced away in the result.

    Returns:
        Tensor of per-distribution entropies with `dim` removed.
    """
    return -(x.softmax(dim) * x.log_softmax(dim)).sum(dim)

# Sanity check: one large gradient step on the mean softmax entropy should sharpen
# each row's distribution (compare the two printed tensors).
# A private, seeded generator keeps the demo reproducible without touching global RNG state.
example = torch.randn(3, 10, generator=torch.Generator().manual_seed(0)).requires_grad_(True)
loss = softmax_entropy(example).mean()
loss.backward()
print(example.softmax(-1))

# Named `step_size` (not `lr`) so this demo does not overwrite the TTA learning rate
# defined in the configuration cell.
step_size = 100
# Detach so the update is a plain tensor rather than a new autograd node.
example = example.detach() - step_size * example.grad
print(example.softmax(-1))

tensor([[0.0504, 0.0179, 0.0354, 0.2088, 0.1452, 0.0404, 0.0036, 0.2185, 0.2515,
         0.0283],
        [0.0299, 0.0581, 0.1207, 0.0303, 0.0414, 0.1477, 0.2327, 0.2001, 0.0532,
         0.0861],
        [0.0265, 0.0504, 0.1683, 0.0301, 0.1908, 0.1733, 0.1566, 0.1049, 0.0401,
         0.0589]], grad_fn=<SoftmaxBackward0>)
tensor([[3.6757e-04, 2.3421e-04, 2.9656e-04, 8.6147e-02, 5.3341e-03, 3.1675e-04,
         1.0701e-04, 1.3849e-01, 7.6843e-01, 2.7130e-04],
        [2.1952e-04, 4.0012e-04, 3.1449e-03, 2.2131e-04, 2.7442e-04, 1.0023e-02,
         8.4803e-01, 1.3644e-01, 3.5548e-04, 8.9204e-04],
        [1.1399e-03, 1.8810e-03, 1.6294e-01, 1.2242e-03, 5.1799e-01, 2.0898e-01,
         9.2027e-02, 1.0004e-02, 1.4972e-03, 2.3170e-03]],
       grad_fn=<SoftmaxBackward0>)
