In [None]:
import copy
import torch
import torch.nn as nn
import torchaudio
from IPython.display import Audio, display

from speechbrain.pretrained import EncoderDecoderASR
# from espnet2.bin.asr_inference import Speech2Text

from audio_augmentations import *
from speechbrain.lobes.augment import TimeDomainSpecAugment

from jiwer import wer

from data import load_dataset

In [None]:
def collect_params(model, train_params, bias_only=False):
    """Select the parameters of ``model`` to adapt and mark them trainable.

    A parameter is selected when its dotted name (as yielded by
    ``model.named_parameters()``) matches at least one requested group:

    * ``"all"``    -> every parameter
    * ``"enc"``    -> names containing ``"enc"``
    * ``"dec"``    -> names containing ``"dec"``
    * ``"linear"`` -> names containing ``"fc"``
    * ``"LN"``     -> names containing ``"norm"``

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        train_params: Container of group keywords; membership is tested with
            ``in``, so a list/tuple/set of the keywords above is expected.
        bias_only: When True, keep only parameters whose last dotted name
            component contains ``"bias"`` (e.g. ``fc.bias``, ``rnn.bias_ih_l0``).
            Defaults to False, which keeps every matched parameter.

    Returns:
        tuple: ``(params, names)`` -- the selected parameter objects and their
        names, in ``named_parameters()`` order.

    Side effects:
        Sets ``requires_grad = True`` on every selected parameter. Parameters
        that are not selected are left untouched (they are NOT frozen here).
    """
    # Group keyword -> substring that must occur in the parameter name.
    name_patterns = {"enc": "enc", "dec": "dec", "linear": "fc", "LN": "norm"}

    select_all = "all" in train_params
    active_patterns = [
        pattern for keyword, pattern in name_patterns.items()
        if keyword in train_params
    ]

    params = []
    names = []
    for name, param in model.named_parameters():
        name = str(name)
        matched = select_all or any(pattern in name for pattern in active_patterns)
        if not matched:
            continue

        # Only look at the leaf attribute name so that a module called e.g.
        # "bias_net" does not make all of its weights count as biases.
        if bias_only and "bias" not in name.rsplit(".", 1)[-1]:
            continue

        param.requires_grad = True
        params.append(param)
        names.append(name)

    return params, names

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
split = ["test-other"]
# dataset_name = "chime"
# dataset_dir = "/home/server08/hdd0/changhun_workspace/CHiME3"
dataset_name = 'librispeech'
# NOTE(review): machine-specific absolute path -- adjust for your environment.
dataset_dir = '/home/server17/hdd/changhun_workspace/LibriSpeech'

batch_size = 1
extra_noise = 0.00
steps = 10  # NOTE(review): in the visible cells only the commented-out loop reads this
lr = 2e-5

# The adaptation loop draws random noise, so fix the RNG for repeatable runs.
seed = 0
torch.manual_seed(seed)

# Fall back to CPU instead of crashing when no GPU is visible.
device = "cuda" if torch.cuda.is_available() else "cpu"
asr_source = "speechbrain/asr-crdnn-rnnlm-librispeech"

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dataset = load_dataset(split, dataset_name, dataset_dir, batch_size, extra_noise)

# ---------------------------------------------------------------------------
# Models, optimiser and losses
# ---------------------------------------------------------------------------
# Two independent copies of the same pretrained checkpoint: `original_model`
# is never passed to the optimiser below, `model` is the one being adapted.
original_model = EncoderDecoderASR.from_hparams(asr_source, run_opts={"device": device})
model = EncoderDecoderASR.from_hparams(asr_source, run_opts={"device": device}).requires_grad_(True)

# Adapt every parameter of `model`.
params, _ = collect_params(model, train_params=['all'])
optim = torch.optim.Adam(params, lr=lr)

mse = nn.MSELoss()        # loss used to update the input perturbation
l1_loss = nn.L1Loss()
model_mse = nn.MSELoss()  # loss used to update the model

# ---------------------------------------------------------------------------
# Result accumulators (one transcription list per adaptation-step checkpoint)
# ---------------------------------------------------------------------------
transcriptions_1 = []
transcriptions_3 = []
transcriptions_5 = []
transcriptions_10 = []
transcriptions_20 = []
transcriptions_40 = []
gt_texts = []
ori_transcriptions = []
durations = []
werrs = []

In [None]:
# Per-utterance test-time adaptation with an adversarially updated input
# perturbation: each step first moves `noise` in the direction that increases
# the distance between clean and noisy encoder outputs, then updates the model
# to shrink that distance again.
NOISE_INIT_STD = 0.01   # scale of the initial Gaussian perturbation
NOISE_STEP_SIZE = 0.3   # gradient-ascent step size applied to the perturbation
ADAPT_STEPS = 5         # adaptation iterations per batch

# NOTE(review): `model` and `optim` are not reset between batches, so the
# adaptation accumulates across utterances -- confirm this is intended.
for batch in dataset:
    lens, wavs, texts, files = batch
    # as_tensor avoids the copy-construct warning when `wavs` is already a tensor.
    wavs = torch.as_tensor(wavs)
    # Relative lengths of 1.0 for every item, built once and reused below.
    wav_lens = torch.ones(len(wavs))

    model.eval()
    noise = (NOISE_INIT_STD * torch.randn_like(wavs)).requires_grad_(True)
    with torch.no_grad():
        ori_transcription, _ = model.transcribe_batch(wavs, wav_lens=wav_lens)
    ori_wer = wer(list(texts), list(ori_transcription))
    print("\noriginal WER: ", ori_wer)

    for step_idx in range(ADAPT_STEPS):
        model.train()

        # --- (1) update the perturbation -----------------------------------
        # The clean encoding is only a fixed target, so no graph is needed.
        with torch.no_grad():
            clean_enc_outputs = model.encode_batch(wavs, wav_lens=wav_lens)
        noisy_enc_outputs = model.encode_batch(wavs + noise, wav_lens=wav_lens)
        loss = mse(noisy_enc_outputs, clean_enc_outputs)

        # Differentiate w.r.t. the perturbation only; this leaves the model
        # parameters' .grad untouched and frees the graph immediately.
        (noise_grad,) = torch.autograd.grad(loss, noise)
        # Re-create `noise` as a fresh leaf so the next iteration can take its
        # gradient again and the autograd graph does not grow across steps.
        noise = (noise + NOISE_STEP_SIZE * noise_grad).detach().requires_grad_(True)

        # --- (2) update the model ------------------------------------------
        with torch.no_grad():
            clean_enc_outputs = model.encode_batch(wavs, wav_lens=wav_lens)
        ada_noisy_enc_outputs = model.encode_batch(wavs + noise.detach(), wav_lens=wav_lens)
        model_loss = model_mse(ada_noisy_enc_outputs, clean_enc_outputs)
        optim.zero_grad()
        model_loss.backward()
        optim.step()

        # --- (3) evaluate the adapted model on the clean input -------------
        model.eval()
        with torch.no_grad():
            adapt_transcription, _ = model.transcribe_batch(wavs, wav_lens=wav_lens)

        # Score the *adapted* transcription (not the pre-adaptation one).
        adapt_wer = wer(list(texts), list(adapt_transcription))
        print(f"{step_idx}-th adapt WER: ", adapt_wer)

    # ada_noisy_transcription, _ = model.transcribe_batch(wavs + noise, wav_lens=torch.ones(len(wavs)))
    # ori_wer = wer(list(texts), list(ada_noisy_transcription))
    # print(f"adapt noisy WER: {ori_wer}")

    # print("\n\n\n\n\n\n\n")

    # model_loss = mse(clean_enc_outputs, ada_noisy_enc_outputs)

    # optim.zero_grad()
    # model_loss.backward()
    # optim.step()

    # for i in range(steps):
    #     model.train()
    #     weak_wavs = weak_augmentation(wavs.detach().cpu())
    #     original_rep = model.encode_batch(wavs, wav_lens=torch.tensor([1.0]))
    #     weak_rep = model.encode_batch(weak_wavs, wav_lens=torch.tensor([1.0]))

    #     loss = mse(weak_rep, original_rep.detach())
    #     optim.zero_grad()
    #     loss.backward()
    #     optim.step()

    #     if i == 0:
    #         model.eval()
    #         transcription, _ = model.transcribe_batch(wavs, wav_lens=torch.tensor([1.0]))
    #         ada_wer = wer(list(texts), list(transcription))
    #         print("adapt-1 WER: ", ada_wer)
    #         transcriptions_1 += transcription
        
    #     if i == 2:
    #         model.eval()
    #         transcription, _ = model.transcribe_batch(wavs, wav_lens=torch.tensor([1.0]))
    #         ada_wer = wer(list(texts), list(transcription))
    #         print("adapt-3 WER: ", ada_wer)
    #         transcriptions_3 += transcription
        
    #     if i == 4:
    #         model.eval()
    #         transcription, _ = model.transcribe_batch(wavs, wav_lens=torch.tensor([1.0]))
    #         ada_wer = wer(list(texts), list(transcription))
    #         print("adapt-5 WER: ", ada_wer)
    #         transcriptions_5 += transcription

    #     if i == 9:
    #         model.eval()
    #         transcription, _ = model.transcribe_batch(wavs, wav_lens=torch.tensor([1.0]))
    #         ada_wer = wer(list(texts), list(transcription))
    #         print("adapt-10 WER: ", ada_wer)
    #         transcriptions_10 += transcription

# print("original WER:", wer(gt_texts, ori_transcriptions))
# if steps >= 10: 
#     print("TTA-1 WER:", wer(gt_texts, transcriptions_1))
#     print("TTA-3 WER:", wer(gt_texts, transcriptions_3))
#     print("TTA-5 WER:", wer(gt_texts, transcriptions_5))
#     print("TTA-10 WER:", wer(gt_texts, transcriptions_10))