In [None]:
from google.colab import drive
drive.mount('/content/drive/')
%cd drive/MyDrive/IW06-07/

In [None]:
!pip3 install speechbrain
!pip3 install deepspeech-gpu
!pip3 install jiwer
%pip install torchaudio==0.10.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

In [None]:
# Download the DeepSpeech v0.9.3 acoustic model (.pbmm) and external scorer
# into the current working directory (-L follows redirects, -O keeps the remote name).
!curl -LO https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-models.pbmm
!curl -LO https://github.com/mozilla/DeepSpeech/releases/download/v0.9.3/deepspeech-0.9.3-models.scorer

In [None]:
import deepspeech
import wave
import numpy as np
import torch
import torchaudio
from speechbrain.pretrained import SpectralMaskEnhancement
from jiwer import compute_measures

DeepSpeech model

In [None]:
# Files of the DeepSpeech 0.9.3 release, resolved relative to the working directory.
model_file_path = 'deepspeech-0.9.3-models.pbmm'
scorer_file_path = 'deepspeech-0.9.3-models.scorer'

# Scorer weights handed to setScorerAlphaBeta below.
# NOTE(review): presumably the values published with the 0.9.3 release — confirm.
lm_alpha = 0.931289039105002
lm_beta = 1.1834137581510284

# Build the acoustic model, then attach the external scorer and set its weights.
model = deepspeech.Model(model_file_path)
model.enableExternalScorer(scorer_file_path)
model.setScorerAlphaBeta(lm_alpha, lm_beta)

In [None]:
def convert(model, audio):
  """Transcribe a 16 kHz WAV file with a speech-to-text model.

  Args:
    model: Object exposing ``stt(samples)``; this notebook passes a
      ``deepspeech.Model``.
    audio: Path (or file object) accepted by ``wave.open``.

  Returns:
    Whatever ``model.stt`` returns for the decoded samples.

  Raises:
    AssertionError: If the file's frame rate is not 16000 Hz.
  """
  # The context manager closes the WAV reader on every path (including a failed
  # rate check), so no file handle is leaked per call.
  with wave.open(audio, 'r') as w:
    frame_rate = int(w.getframerate())
    assert frame_rate == 16000, 'expected 16000 Hz audio, got %d Hz' % frame_rate
    # All frames are read at once and interpreted as int16 samples.
    data = np.frombuffer(w.readframes(w.getnframes()), dtype=np.int16)
  return model.stt(data)
def diff(o, a, e):
  """Score two hypotheses against the same reference transcript with jiwer.

  Args:
    o: Reference transcript.
    a: First hypothesis (this notebook passes the adversarial transcript).
    e: Second hypothesis (this notebook passes the enhanced transcript).

  Returns:
    Tuple ``(compute_measures(o, a), compute_measures(o, e))``.
  """
  measures_for_a = compute_measures(o, a)
  measures_for_e = compute_measures(o, e)
  return measures_for_a, measures_for_e

Speech Enhancement Model

In [None]:
# Load the pretrained MetricGAN+ (VoiceBank) spectral-mask enhancer, using
# `savedir` as its local model directory.
# Run on the GPU when one is visible; otherwise fall back to the CPU instead of
# failing at load time on a CPU-only runtime.
enhance_model = SpectralMaskEnhancement.from_hparams(
    source="speechbrain/metricgan-plus-voicebank",
    savedir="pretrained_models/metricgan-plus-voicebank",
    run_opts={"device": "cuda" if torch.cuda.is_available() else "cpu"},
)

Adversarial Dataset A Parsing

In [None]:
def call_enhance_wer(audio_files, original_signal, adv_signals, enhanced_path):
  """Enhance each adversarial clip and record transcripts and WER measures.

  For every name in ``audio_files`` the original clip is transcribed once.
  Then, for each adversarial target (long, medium, short — in that order), the
  adversarial clip is enhanced, saved as a 16-bit / 16 kHz file, and both the
  enhanced and adversarial clips are transcribed and scored against the
  original transcript.

  Relies on notebook-level state: ``model``, ``enhance_model``, ``editors``,
  ``edit_stats_adv``, ``edit_stats_enh``, ``average_wer_adv`` and
  ``average_wer_enh``.

  Args:
    audio_files: Iterable of file-name suffixes appended to the path prefixes.
    original_signal: Path prefix of the original clips.
    adv_signals: Adversarial path prefixes, indexed 0/1/2 = long/medium/short.
    enhanced_path: Output prefix; files go to
      ``<enhanced_path><target>/enhanced-<name>``.

  Returns:
    None. Transcripts are written to ``editors``; measures are appended to
    ``edit_stats_*`` and WER values are summed into ``average_wer_*``.
  """
  targets = ['long', 'medium', 'short']
  for name in audio_files:
    reference_text = convert(model, original_signal + name)
    for idx, target in enumerate(targets):
      adv_wav = adv_signals[idx] + name

      # Batch of one clip; lengths=[1.] is passed as its length entry.
      noisy_batch = enhance_model.load_audio(adv_wav).unsqueeze(0)
      enhanced_batch = enhance_model.enhance_batch(noisy_batch, lengths=torch.tensor([1.]))
      enhanced_wav = enhanced_path + target + '/enhanced-' + name
      torchaudio.save(enhanced_wav, enhanced_batch.cpu(), sample_rate=16000, bits_per_sample=16)

      # Transcribe the enhanced clip first, then the adversarial one.
      enhanced_text = convert(model, enhanced_wav)
      adversarial_text = convert(model, adv_wav)
      editors[idx].write(reference_text + ', ' + adversarial_text + ', ' + enhanced_text + '\n')

      adv_measures, enh_measures = diff(reference_text, adversarial_text, enhanced_text)
      edit_stats_adv[idx].append(adv_measures)
      edit_stats_enh[idx].append(enh_measures)
      # Despite the names, these hold running sums of WER, not averages.
      average_wer_adv[idx] += adv_measures['wer']
      average_wer_enh[idx] += enh_measures['wer']

In [None]:
# Path prefixes for adversarial dataset A. Every *_ori / *_adv value is a prefix
# to which an audio file name is appended; each *_adv list is ordered by
# adversarial target: long, medium, short.
long_signal = './adversarial_dataset-A/Adversarial-Examples/long-signals'
long_signals_ori = f'{long_signal}/Original-examples/sample-'
long_signals_adv = [
    f'{long_signal}/adv-{target}-target/adv-long2{target}-'
    for target in ('long', 'medium', 'short')
]

medium_signal = './adversarial_dataset-A/Adversarial-Examples/medium-signals'
medium_signals_ori = f'{medium_signal}/Original-examples/sample-'
medium_signals_adv = [
    f'{medium_signal}/adv-{target}-target/adv-medium2{target}-'
    for target in ('long', 'medium', 'short')
]

short_signal = './adversarial_dataset-A/Adversarial-Examples/short-signals'
short_signals_ori = f'{short_signal}/Original-examples/sample-'
short_signals_adv = [
    f'{short_signal}/adv-{target}-target/adv-short2{target}-'
    for target in ('long', 'medium', 'short')
]

## Define short constants

In [None]:
# Output root for a short-signal run.
# NOTE(review): the long and medium constant cells below also assign
# enhanced_path, so on a top-to-bottom run this value is overwritten.
enhanced_path = './enhanced/baseline/short/'
# Whitespace-separated entries of the local file "short"; the context manager
# closes the file even if the read fails.
with open("short") as shortf:
  shortfs = shortf.read().split()

## Define long constants

In [None]:
# Output root for a long-signal run.
# NOTE(review): the short cell above and the medium cell below also assign
# enhanced_path, so on a top-to-bottom run this value is overwritten.
enhanced_path = './enhanced/baseline/long/'
# Whitespace-separated entries of the local file "long"; the context manager
# closes the file even if the read fails.
with open("long") as longf:
  longfs = longf.read().split()

## Define medium constants

In [None]:
# Output root for a medium-signal run.
# NOTE(review): the short and long constant cells above also assign
# enhanced_path; this is the last assignment, so it is the one in effect after
# a top-to-bottom run.
enhanced_path = './enhanced/baseline/medium/'
# Whitespace-separated entries of the local file "medium"; the context manager
# closes the file even if the read fails.
with open("medium") as mediumf:
  mediumfs = mediumf.read().split()

## Call models

In [None]:
# Per-target accumulators filled by call_enhance_wer; index 0/1/2 corresponds to
# the long/medium/short adversarial target.
# edit_stats_* collect one measures entry per scored clip; average_wer_* hold
# running WER sums.
edit_stats_adv = [[] for _ in range(3)]
edit_stats_enh = [[] for _ in range(3)]
average_wer_adv = [0] * 3
average_wer_enh = [0] * 3

In [None]:
# One transcript log per adversarial target, opened with mode 'w+' (an existing
# file is truncated). The order matches the long/medium/short indexing used by
# call_enhance_wer.
editors = [
    open(enhanced_path + suffix, 'w+')
    for suffix in ('long/translation', 'medium/translation', 'short/translation')
]

In [None]:
# Run the enhancement + WER pipeline on the medium-signal file list.
call_enhance_wer(
    audio_files=mediumfs,
    original_signal=medium_signals_ori,
    adv_signals=medium_signals_adv,
    enhanced_path=enhanced_path,
)

In [None]:
# Show which output root is currently in effect (enhanced_path is assigned in
# several cells above).
print(enhanced_path)

In [None]:
# Close every transcript log in `editors` so buffered lines are flushed to disk.
for editor in editors:
  editor.close()

In [None]:
# Persist the collected measures and WER sums under the current output root.
# The context manager flushes and closes the file even if a write fails.
with open(enhanced_path + 'stats', 'w+') as short_stats:
  short_stats.write(str(edit_stats_adv)+'\n')
  short_stats.write(str(edit_stats_enh)+'\n')
  # Number of clips actually scored: call_enhance_wer appends one entry per
  # audio file to each per-target list, so the count follows whichever file
  # list was processed rather than a hard-coded list name.
  short_stats.write(str(len(edit_stats_adv[0]))+'\n')
  short_stats.write(str(average_wer_adv)+'\n')
  short_stats.write(str(average_wer_enh)+'\n')

In [None]:
# Make sure the stats file handle is closed (flushes any pending writes).
short_stats.close()

In [None]:
# Re-derive the per-target WER sums from the per-clip measures collected by
# call_enhance_wer (index 0/1/2 = long/medium/short target).
# The measures are read from the notebook-level accumulators edit_stats_adv /
# edit_stats_enh; no `adv` / `enh` variable exists at notebook scope on a fresh
# kernel.
a = [0, 0, 0]
e = [0, 0, 0]
for i, (adv_measures, enh_measures) in enumerate(zip(edit_stats_adv, edit_stats_enh)):
  for adv_entry, enh_entry in zip(adv_measures, enh_measures):
    a[i] += adv_entry['wer']
    e[i] += enh_entry['wer']

In [None]:
# Show the per-target WER sums: `a` (adversarial clips) first, then `e` (enhanced clips).
print(a)
print(e)