In [1]:
import torchaudio
from audio_augmentations import *
from IPython.display import Audio, display

In [2]:
def play_audio(waveform, sample_rate):
    """Show an inline audio player for a one- or two-channel waveform.

    Args:
        waveform: Object exposing ``.numpy()`` whose result has a 2-D shape;
            the first axis is treated as the channel axis.
        sample_rate: Playback rate forwarded to ``IPython.display.Audio``.

    Raises:
        ValueError: If the first axis does not have length 1 or 2.
    """
    samples = waveform.numpy()
    channel_count, _ = samples.shape

    # Reject unsupported layouts up front so only the playable cases remain.
    if channel_count not in (1, 2):
        raise ValueError("Waveform with more than 2 channels are not supported.")

    if channel_count == 1:
        payload = samples[0]
    else:
        # Two channels are handed over as a (first, second) pair.
        payload = (samples[0], samples[1])
    display(Audio(payload, rate=sample_rate))

# --- Demo configuration ------------------------------------------------------
# Number of augmented copies each ComposeMany wrapper is asked to produce.
num_augmented_samples = 1

# Sample rate handed to the rate-dependent transforms below.
# NOTE(review): assumed to match the rate of `audio_file`; the file's actual
# rate is only known once it is loaded in a later cell — confirm they agree.
SAMPLE_RATE = 16000
# Value passed as PitchShift's `n_samples`: SAMPLE_RATE * 5 samples.
PITCH_SHIFT_N_SAMPLES = SAMPLE_RATE * 5

# Keep the machine-specific dataset root in one place so it is easy to repoint.
LIBRISPEECH_ROOT = "/home/server08/hdd0/changhun_workspace/LibriSpeech"
audio_file = f"{LIBRISPEECH_ROOT}/test-other/367/130732/367-130732-0000.flac"

# Reference: a randomized pipeline kept for comparison (currently disabled).
# transforms = [
#     RandomApply([PolarityInversion()], p=0.8),
#     RandomApply([Noise(min_snr=0.001, max_snr=0.005)], p=0.3),
#     RandomApply([Gain()], p=0.2),
#     HighLowPass(sample_rate=SAMPLE_RATE),
#     RandomApply([Delay(sample_rate=SAMPLE_RATE)], p=0.5),
#     RandomApply([PitchShift(
#         n_samples=PITCH_SHIFT_N_SAMPLES,
#         sample_rate=SAMPLE_RATE
#     )], p=0.4),
#     RandomApply([Reverb(sample_rate=SAMPLE_RATE)], p=0.3)
# ]


# Transforms to audition one at a time; commented entries are switched off.
transforms = [
    PolarityInversion(),
    Noise(min_snr=0.01, max_snr=0.05),
#     Gain(),
#     HighLowPass(sample_rate=SAMPLE_RATE),
#     Delay(sample_rate=SAMPLE_RATE),
    PitchShift(
        n_samples=PITCH_SHIFT_N_SAMPLES,
        sample_rate=SAMPLE_RATE,
        pitch_cents_min=-7.0,
        pitch_cents_max=-3.0,
    ),
    Reverb(sample_rate=SAMPLE_RATE),
]

# One ComposeMany per transform, so each effect is applied in isolation.
augmentation_list = [
    ComposeMany(transforms=[transform], num_augmented_samples=num_augmented_samples)
    for transform in transforms
]

In [3]:
# Load the clip at `audio_file`; the call's two results are bound to `wav` and `sr`.
wav, sr = torchaudio.load(audio_file)
# Play the unmodified clip so it can be compared by ear with the augmented versions.
play_audio(wav, sr)

In [4]:
# Apply every single-transform pipeline to the original clip and play each result.
for augmentation in augmentation_list:
    aug_wavs = augmentation(wav)
    for aug_wav in aug_wavs:
        # play_audio unpacks a 2-D shape, so a leading axis is added to each item here.
        # NOTE(review): assumes each `aug_wav` is 1-D — TODO confirm against ComposeMany's output.
        play_audio(aug_wav.unsqueeze(0), sr)

S2SRNNBeamSearchLM(
  (emb): Embedding(
    (Embedding): Embedding(1000, 128)
  )
  (dec): AttentionalRNNDecoder(
    (proj): Linear(in_features=2048, out_features=1024, bias=True)
    (attn): LocationAwareAttention(
      (mlp_enc): Linear(in_features=512, out_features=1024, bias=True)
      (mlp_dec): Linear(in_features=1024, out_features=1024, bias=True)
      (mlp_attn): Linear(in_features=1024, out_features=1, bias=False)
      (conv_loc): Conv1d(1, 10, kernel_size=(201,), stride=(1,), padding=(100,), bias=False)
      (mlp_loc): Linear(in_features=10, out_features=1024, bias=True)
      (mlp_out): Linear(in_features=512, out_features=1024, bias=True)
      (softmax): Softmax(dim=-1)
    )
    (drop): Dropout(p=0.15, inplace=False)
    (rnn): GRUCell(
      (rnn_cells): ModuleList(
        (0): GRUCell(1152, 1024)
      )
      (dropout_layers): ModuleList()
    )
  )
  (fc): Linear(
    (w): Linear(in_features=1024, out_features=1000, bias=True)
  )
  (softmax): LogSoftmax(dim=-1)
  (lm): RNNLM(
    (embedding): Embedding(
      (Embedding): Embedding(1000, 128)
    )
    (dropout): Dropout(p=0.0, inplace=False)
    (rnn): LSTM(
      (rnn): LSTM(128, 2048, num_layers=2, batch_first=True)
    )
    (dnn): Sequential(
      (linear): Linear(
        (w): Linear(in_features=2048, out_features=512, bias=True)
      )
      (norm): LayerNorm(
        (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      )
      (act): LeakyReLU(negative_slope=0.01)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (out): Linear(
      (w): Linear(in_features=512, out_features=1000, bias=True)
    )
  )
  (log_softmax): Softmax(
    (act): LogSoftmax(dim=-1)
  )
)

In [2]:
import torch
import torch.nn as nn

from jiwer import wer

from data import load_dataset
from main import collect_params, setup_optimizer

# Evaluation data: which split of which corpus, and where it lives on disk.
split = ["test-other"]
dataset_name = "chime"
dataset_dir = "/home/server08/hdd0/changhun_workspace/CHiME3"

# Loader settings forwarded to load_dataset.
batch_size = 1
extra_noise = 0

# Adaptation settings (`steps` bounds the per-batch update loop further down).
steps = 10
lr = 2e-6

dataset = load_dataset(
    split,
    dataset_name,
    dataset_dir,
    batch_size,
    extra_noise,
)

In [3]:
# Repeats the load_dataset call that ends the configuration cell, with the same arguments.
dataset = load_dataset(split, dataset_name, dataset_dir, batch_size, extra_noise)

Read text: 100%|██████████| 2310/2310 [00:00<00:00, 23265.43it/s]

[INFO]    There are 2310 samples.





In [4]:
from speechbrain.pretrained import EncoderDecoderASR
import torchaudio
import torch
from audio_augmentations import *
# from espnet2.bin.asr_inference import Speech2Text

# model = Speech2Text.from_pretrained(
#   "espnet/chai_librispeech_asr_train_conformer-rnn_transducer_raw_en_bpe5000_sp"
#   # "espnet/chai_librispeech_asr_train_rnnt_conformer_raw_en_bpe5000_sp"
# )
# audio = "/home/server08/hdd0/changhun_workspace/LibriSpeech/test-other/367/130732/367-130732-0000.flac"

# import nemo.collections.asr as nemo_asr
# asr_model = nemo_asr.models.EncDecRNNTBPEModel.from_pretrained("nvidia/stt_en_conformer_transducer_xlarge")
# print(asr_model.transcribe([audio]))


# from speechbrain.decoders.seq2seq import S2SRNNBeamSearchLM
# from speechbrain.decoders.seq2seq import S2SBeamSearcher

# TODO: verify collect_params (compare the number of parameters)

# Two instances of the same pretrained recogniser are loaded: in the loop below
# `original_model` only produces baseline transcriptions, while `model` is the
# instance whose parameters are handed to the optimizer.
original_model = EncoderDecoderASR.from_hparams(
    "speechbrain/asr-crdnn-rnnlm-librispeech",
    run_opts={"device": "cuda"},
)
model = EncoderDecoderASR.from_hparams(
    "speechbrain/asr-crdnn-rnnlm-librispeech",
    run_opts={"device": "cuda"},
)

# NOTE(review): the configuration cell defines `lr = 2e-6`, but the optimizer is
# built with a literal 1e-6 here — confirm which learning rate is intended.
params, _ = collect_params(model, train_all=True)
optim, _ = setup_optimizer(params, lr=1e-6)
mse = torch.nn.MSELoss()

# Each active transform is wrapped in RandomApply with p=0.5; commented entries
# are switched off.
weak_transforms = [
    RandomApply([PolarityInversion()], p=0.5),
    RandomApply([Noise(min_snr=0.01, max_snr=0.05)], p=0.5),
    # RandomApply([Gain()], p=0.7),
    # RandomApply([HighLowPass(sample_rate=16000)], p=0.7),
    # RandomApply([PitchShift(n_samples=16000*5, sample_rate=16000)], p=0.5),
    RandomApply([Reverb(sample_rate=16000)], p=0.5),
]
weak_augmentation = ComposeMany(
    transforms=weak_transforms,
    num_augmented_samples=1,
)

# Hypotheses collected over the whole dataset, one list per adaptation checkpoint.
transcriptions_1 = []
transcriptions_3 = []
transcriptions_5 = []
transcriptions_10 = []
transcriptions_20 = []  # declared but not filled by this cell
transcriptions_40 = []  # declared but not filled by this cell
gt_texts = []
ori_transcriptions = []
durations = []  # declared but not filled by this cell
werrs = []  # declared but not filled by this cell

# Number of completed adaptation steps -> list that receives the hypotheses
# produced at that point. Replaces four copy-pasted `if i == ...` branches.
checkpoint_transcriptions = {
    1: transcriptions_1,
    3: transcriptions_3,
    5: transcriptions_5,
    10: transcriptions_10,
}

# Every encode/transcribe call below uses this same `wav_lens` value, so it is
# built once instead of on every call.
# NOTE(review): presumably one entry per utterance, i.e. this assumes
# batch_size == 1 — confirm before raising batch_size.
wav_lens = torch.tensor([1.0])

# NOTE(review): `model` and `optim` are never reset inside this loop, so updates
# made for one batch carry over to the next — confirm this is intended.
for batch in dataset:
    lens, wavs, texts, files = batch
    wavs = torch.tensor(wavs)

    # Fix: the references were never accumulated, so the corpus-level WER at the
    # end compared every hypothesis list against an empty `gt_texts`.
    gt_texts += list(texts)

    # Baseline: transcribe with the untouched reference model.
    model.eval()
    with torch.no_grad():
        ori_transcription, _ = original_model.transcribe_batch(wavs, wav_lens=wav_lens)
    ori_transcriptions += ori_transcription
    ori_wer = wer(list(texts), list(ori_transcription))
    print("\noriginal WER: ", ori_wer)

    for i in range(steps):
        # One adaptation step: pull the encoding of the augmented audio towards
        # the (detached) encoding of the original audio.
        model.train()
        weak_wavs = weak_augmentation(wavs.detach().cpu())
        original_rep = model.encode_batch(wavs, wav_lens=wav_lens)
        weak_rep = model.encode_batch(weak_wavs, wav_lens=wav_lens)

        loss = mse(weak_rep, original_rep.detach())
        optim.zero_grad()
        loss.backward()
        optim.step()

        # At the configured checkpoints, transcribe with the adapted model and
        # record both the per-utterance WER and the hypothesis.
        completed_steps = i + 1
        if completed_steps in checkpoint_transcriptions:
            model.eval()
            transcription, _ = model.transcribe_batch(wavs, wav_lens=wav_lens)
            ada_wer = wer(list(texts), list(transcription))
            print(f"adapt-{completed_steps} WER: ", ada_wer)
            checkpoint_transcriptions[completed_steps].extend(transcription)

# Corpus-level summary. Only checkpoints that were actually reached are reported;
# for steps == 10 this prints exactly the same lines as before.
print("original WER:", wer(gt_texts, ori_transcriptions))
for completed_steps, hypotheses in checkpoint_transcriptions.items():
    if completed_steps <= steps:
        print(f"TTA-{completed_steps} WER:", wer(gt_texts, hypotheses))

  super(AdamW, self).__init__(params, defaults)



original WER:  1.0
adapt-1 WER:  0.7931034482758621
adapt-3 WER:  1.0
adapt-5 WER:  1.0
adapt-10 WER:  0.9655172413793104

original WER:  0.13793103448275862
adapt-1 WER:  0.10344827586206896
adapt-3 WER:  0.10344827586206896
adapt-5 WER:  0.10344827586206896
adapt-10 WER:  0.10344827586206896

original WER:  0.5517241379310345
adapt-1 WER:  0.7931034482758621
adapt-3 WER:  0.8275862068965517
adapt-5 WER:  0.7586206896551724
adapt-10 WER:  0.8275862068965517

original WER:  0.20689655172413793
adapt-1 WER:  0.1724137931034483
adapt-3 WER:  0.1724137931034483
adapt-5 WER:  0.1724137931034483
adapt-10 WER:  0.20689655172413793

original WER:  0.20689655172413793
adapt-1 WER:  0.20689655172413793
adapt-3 WER:  0.20689655172413793
adapt-5 WER:  0.20689655172413793
adapt-10 WER:  0.20689655172413793

original WER:  0.13793103448275862
adapt-1 WER:  0.13793103448275862
adapt-3 WER:  0.13793103448275862
adapt-5 WER:  0.1724137931034483
adapt-10 WER:  0.1724137931034483

original WER:  0.1724