In [None]:
import gc
import json
import os
import pickle
import pprint
import urllib.request
from functools import lru_cache
from pathlib import Path

import jiwer
import torchaudio
import torchaudio.transforms as T
from datasets import load_dataset
from dsets.dset_config.dset_config import DatasetConfig
from dsets.dsets import ENGLISH_DATASETS, get_datasets
from dsets.helpers.helpers import apply_parallel
from fastai.callback.tensorboard import TensorBoardCallback
from fastai.data.all import *
from fastai.callback.all import SaveModelCallback
from fastai.text.all import *
from fastai.vision.all import *
from itpsaudio.aug_transforms import AddNoise, RandomReverbration
from itpsaudio.core import *
from itpsaudio.transforms import *
from transformers import AutoModelForCTC, Wav2Vec2CTCTokenizer, Wav2Vec2ForCTC

# Run switches: TEST_RUN uses a single small dataset and a short fit.
TEST_RUN = False
NUM_EPOCHS = 1

# Directory layout, relative to this notebook.
datadir = Path("../../data/")
modeldir = datadir.joinpath("models", "audio_en")

# Checkpoint to fine-tune. Alternatives tried previously:
#   "facebook/wav2vec2-xls-r-300m", "facebook/wav2vec2-base"
pretrained_model_name = "OthmaneJ/distil-wav2vec2"
# Hub ids contain "/", which is not usable inside a single path component.
pretrained_model_save_name = "_".join(pretrained_model_name.split("/"))

pretrained_save_path = modeldir / pretrained_model_save_name
logdir = datadir.joinpath("logs", "audio_en", pretrained_model_save_name)

en_vocab = "../../notebooks/assets/en_vocab.json"


In [None]:
datasets = [
    # DatasetConfig(name='itps', split='train', lang='en', kind=None),
    DatasetConfig(name='librispeech', split='dev', lang=None, kind='clean'),
    # DatasetConfig(name='librispeech', split='train', lang=None, kind='clean'),
    # DatasetConfig(name='librispeech', split='train', lang=None, kind='other'),
    DatasetConfig(name='ljl', split='train', lang=None, kind=None),
    DatasetConfig(name='nict_spreds', split='train', lang='en', kind=None),
]

# df.pkl caches the metadata frame, including the derived "audio_length" and
# "sr" columns that a later cell writes. When it exists, the raw datasets are
# not loaded again. Previously they were always loaded and then discarded.
# NOTE(review): the cache is not keyed on TEST_RUN or `datasets`. Delete
# df.pkl after changing either, otherwise the stale frame is reused.
if os.path.exists("df.pkl"):
  df = pd.read_pickle("df.pkl")
elif TEST_RUN:
  # Smoke test: only the first configured dataset (or a default English one).
  p, df = get_datasets([datasets[0] if datasets else ENGLISH_DATASETS[0]])
elif datasets:
  p, df = get_datasets(datasets)
else:
  # NOTE(review): load_dataset returns a Hugging Face `Dataset`, not a pandas
  # DataFrame. Later cells use `df.columns`, `df["filename"]` and
  # `df["text"]`, so this branch needs converting/renaming before use.
  df = load_dataset("common_voice", "en", split="train")



 # Model Training

In [None]:


@lru_cache(maxsize=None)
def get_audio_length(s):
  """Return the duration, in seconds, of the audio file at path ``s``.

  The file is decoded with ``torchaudio.load``. The duration is the number of
  samples in the first channel divided by the sample rate. Results are
  memoised per path.
  """
  waveform, sample_rate = torchaudio.load(s)
  n_samples = waveform[0].shape[0]
  return n_samples / sample_rate


In [None]:

# Special tokens registered on the CTC tokenizer built from the English vocab.
special_tokens = dict(
    bos_token="[BOS]",
    eos_token="[EOS]",
    unk_token="[UNK]",
    pad_token="[PAD]",
)
wav2vec2tok = Wav2Vec2CTCTokenizer(en_vocab, **special_tokens)
# fastai-side wrapper; the HF tokenizer stays reachable as `tok.tokenizer`.
tok = ENTransformersTokenizer(tok=wav2vec2tok)


In [None]:
# Clip durations (seconds) are computed only when missing. They are persisted
# with the rest of `df` whenever df.pkl is written.
if "audio_length" not in df.columns:
  df["audio_length"] = df["filename"].apply(get_audio_length)


In [None]:
# Longest clip, in seconds.
df["audio_length"].agg("max")


In [None]:
# Distribution of clip durations before filtering. The axes are labelled so
# the figure stands on its own; the trailing `;` suppresses the Axes repr.
ax = df["audio_length"].plot.hist()
ax.set(title="Clip duration (before filtering)", xlabel="audio length (s)", ylabel="number of clips");


In [None]:
# Total hours of audio before filtering.
df["audio_length"].agg("sum") / 60 / 60


In [None]:
MAX_AUDIO_LENGTH = 15  # seconds; longer clips are dropped

# Keep clips under the length cap that also have a transcript, then renumber
# rows 0..n-1.
keep_rows = (df["audio_length"] < MAX_AUDIO_LENGTH) & df["text"].notna()
df = df[keep_rows].reset_index(drop=True)


In [None]:
# Total hours of audio after filtering.
df["audio_length"].agg("sum") / 60 / 60


In [None]:
# Fixed seed so the 80/20 train/valid split is identical across kernel
# restarts. Without it, validation metrics (and the "cer" the checkpoint
# callback monitors) are not comparable between runs.
SPLIT_SEED = 42
splits = RandomSplitter(valid_pct=0.2, seed=SPLIT_SEED)(df)


In [None]:
# Rows of `df` turned into model items by AudioBatchTransform, split by `splits`.
audio_tfm = AudioBatchTransform()
tfms = TfmdLists(df, audio_tfm, splits=splits)


In [None]:


def load_t_model(mod_path,
        attention_dropout=0.08,
        hidden_dropout=0.08,
        feat_proj_dropout=0.08,
        mask_time_prob=0.05,
        mask_feature_prob=0.05,
        layerdrop=0.08,
        ctc_zero_infinity=True,
        pad_token_id=tok.tokenizer.pad_token_id,
        vocab_size=len(tok.tokenizer),
        **kwargs
):
    """Load a pretrained ``Wav2Vec2ForCTC`` from ``mod_path`` with fine-tuning overrides.

    The dropout, masking, layerdrop and CTC settings below are forwarded to
    ``Wav2Vec2ForCTC.from_pretrained``. ``pad_token_id`` and ``vocab_size``
    default to the notebook tokenizer ``tok``; these defaults are read once,
    when this cell runs. Any extra ``kwargs`` are passed through unchanged.
    """
    overrides = {
        "attention_dropout": attention_dropout,
        "hidden_dropout": hidden_dropout,
        "feat_proj_dropout": feat_proj_dropout,
        "mask_time_prob": mask_time_prob,
        "mask_feature_prob": mask_feature_prob,
        "layerdrop": layerdrop,
        "ctc_zero_infinity": ctc_zero_infinity,
        "pad_token_id": pad_token_id,
        "vocab_size": vocab_size,
    }
    return Wav2Vec2ForCTC.from_pretrained(mod_path, **overrides, **kwargs)


In [None]:


@lru_cache(maxsize=None)
def get_sr(x):
    """Return the sample rate of the audio file at path ``x`` (memoised per path)."""
    return torchaudio.load(x)[1]

# Add the per-clip sample rate once, then write the enriched frame to the
# df.pkl cache that the data-loading cell reads.
if "sr" not in df.columns:
  df["sr"] = df["filename"].apply(get_sr)
  df.to_pickle("df.pkl")


In [None]:
# Distinct sample rates present in the training data.
pd.unique(df["sr"])


In [None]:
SAMPLE_NOISE_URL = "https://pytorch-tutorial-assets.s3.amazonaws.com/VOiCES_devkit/distant-16k/distractors/rm1/babb/Lab41-SRI-VOiCES-rm1-babb-mc01-stu-clo.wav"
NOISE_SAMPLE_PATH = Path("noise_sample.wav")

# Download the background-noise clip once.
# - `requests` is not imported explicitly anywhere in this notebook.
# - The previous code wrote whatever the server returned, including an HTTP
#   error page, into the .wav.
# - urlretrieve raises on HTTP error statuses.
# - Writing to a temporary name first means a failed download never leaves a
#   corrupt noise_sample.wav behind.
if not NOISE_SAMPLE_PATH.exists():
    _tmp_noise_path = NOISE_SAMPLE_PATH.with_name(NOISE_SAMPLE_PATH.name + ".part")
    urllib.request.urlretrieve(SAMPLE_NOISE_URL, _tmp_noise_path)
    _tmp_noise_path.replace(NOISE_SAMPLE_PATH)



In [None]:
# Sort keys for SortedDL: clip durations (seconds) of the train and valid rows.
# Despite the `text_lens` names, these are audio lengths.
train_text_lens = df.loc[splits[0], "audio_length"].tolist()
val_text_lens = df.loc[splits[1], "audio_length"].tolist()
srtd_dl = partial(SortedDL, res=train_text_lens)

# Per-subset DataLoader kwargs: nothing extra for train, precomputed keys for valid.
dl_kwargs = [dict(), dict(val_res=val_text_lens)]

# Background-noise clip used by the AddNoise augmentation.
_waveform, sr = torchaudio.load("noise_sample.wav")
noise = TensorAudio(_waveform, sr=sr)
# NOTE(review): `noise_t` is not referenced by any later cell; the dataloader
# cell builds its own AddNoise.
noise_t = AddNoise(range(1, 10), noise, p=1)

# Cache of source-rate -> 16 kHz resamplers. A few common rates are pre-seeded,
# then every rate present in the data gets its own entry.
resampler = {
    48000: T.Resample(48000, 16000),
    16000: noop,
    22050: T.Resample(22050, 16000),
}
for sr in df["sr"].unique():
    resampler[int(sr)] = T.Resample(sr, 16000)

@Transform
def resample(x: TensorAudio):
    """Resample ``x`` to 16 kHz, creating and caching a resampler for unseen rates."""
    src_sr = x.sr
    if src_sr not in resampler:
        resampler[src_sr] = T.Resample(src_sr, 16000)
    resampled = resampler[src_sr](x)
    return TensorAudio(resampled, sr=16000)

class Range():
  """Plain ``start``/``stop`` holder, used as AddNoise's range argument in the dataloader cell."""
  def __init__(self, start, stop):
    self.start = start
    self.stop = stop

# Per-item pipeline: augmentations run at the clip's own rate, then `resample`
# brings it to 16 kHz, then `tok` is applied.
item_tfms = [
    RandomReverbration(p=0.2),
    AddNoise(Range(5, 10), noise, power=3, p=0.2),
    resample,
    tok,
]

# Batch-level padding. Audio padding value can be anything.
# NOTE(review): the original note said text "HAS TO BE PADDED WITH -100 WHEN
# USING TRANSFORMERS LOSS", yet text is padded with the tokenizer's pad id
# here. Confirm whether Pad_Audio_Batch handles this, or whether -100 was
# intended (the WER/CER decoding would then need to map -100 back).
batch_tfms = [
    Pad_Audio_Batch(pad_idx_audio=0,
                    pad_idx_text=tok.tokenizer.pad_token_id,
                    pad_first=True,
                    seq_len=1),
    squeeze,
]

dls = tfms.dataloaders(
    bs=2,
    after_item=item_tfms,
    before_batch=batch_tfms,
    shuffle=True,
    n_inp=1,
    dl_type=srtd_dl,
    dl_kwargs=dl_kwargs,
)


In [None]:
# Sanity check: fetch one training batch (DataLoaders delegates to `.train`).
dls.train.one_batch()


In [None]:
# Render a training batch, decoding the text with `tok`.
dls.train.show_batch(unique=False, tok=tok)


In [None]:
def _greedy_decode(pred, labels):
    """Decode argmax ids of ``pred.logits`` and the reference ``labels`` with ``tok``.

    Returns ``(pred_str, label_str)``, two lists of strings.
    """
    logits = pred.logits.detach().cpu().numpy()
    pred_str = tok.batch_decode(np.argmax(logits, axis=-1))
    label_str = tok.batch_decode(labels)
    return pred_str, label_str


def wer(pred, labels):
    """Word error rate of the greedy decoding of ``pred`` against ``labels``."""
    pred_str, label_str = _greedy_decode(pred, labels)
    # jiwer takes the reference first, then the hypothesis.
    return jiwer.wer(label_str, pred_str)


def cer(pred, labels):
    """Character error rate of the greedy decoding of ``pred`` against ``labels``."""
    pred_str, label_str = _greedy_decode(pred, labels)
    return jiwer.cer(label_str, pred_str)


class TransformersLearner(Learner):
    """fastai ``Learner`` whose training loss is computed by a Hugging Face model.

    The batch step passes the targets to the model as ``labels=`` and uses the
    ``"loss"`` entry of the model output. It is therefore used with
    ``loss_func=noop``.
    """
    def _do_one_batch(self):
        """Run forward, loss, backward and optimizer step for one batch.

        Fires 'after_pred', 'after_loss' and 'before_backward', and (through
        ``_with_events``) the 'step' events. In non-training mode, or without
        targets, it stops after 'after_loss'.
        """
        # Only the first input goes to the model. Labels are cast to a plain
        # torch.Tensor, presumably so the HF model gets a vanilla tensor rather
        # than a fastai tensor subclass (confirm).
        self.pred = self.model(self.xb[0], labels=cast(self.yb[0], torch.Tensor))
        self('after_pred')
        # The loss is computed inside the HF model from `labels`.
        self.loss_grad = self.pred["loss"]
        self.loss = self.loss_grad.clone()
        # NOTE(review): fastai's Recorder presumably maintains its own smoothed
        # loss; confirm this assignment is still needed.
        self.smooth_loss = self.loss_grad.clone()
        self('after_loss')
        # Validation/inference batches stop here: no backward pass or step.
        if not self.training or not len(self.yb): return
        self('before_backward')
        # NOTE(review): no 'after_backward' event is fired and
        # CancelBackwardException is not handled here. Callbacks relying on
        # either will not behave as in the stock fastai loop; confirm.
        self.loss_grad.backward()
        self._with_events(self.opt.step, 'step', CancelStepException)
        self.opt.zero_grad()


In [None]:
# SaveModelCallback saves via `learn.save`, which resolves names relative to
# `learn.path / learn.model_dir` (by default "./models"). The previous relative
# fname "../../data/models/audio_en/save_model_cb" therefore resolved to
# "../data/models/audio_en/...", one directory too high and usually
# nonexistent. An absolute path is kept as-is by pathlib's `/`.
modeldir.mkdir(parents=True, exist_ok=True)
save_model_fname = (modeldir / "save_model_cb").resolve()

cbs = [TensorBoardCallback(log_dir=logdir, trace_model=False, log_preds=False),
       # Keep the checkpoint with the lowest character error rate.
       SaveModelCallback(comp=np.less, monitor="cer", fname=save_model_fname),
       ]

metrics = [Perplexity(), wer, cer]
learn = TransformersLearner(dls, load_t_model(pretrained_model_name),
                loss_func=noop, # Loss is calculated in Transformers internally
                metrics=metrics,
                cbs=cbs)


In [None]:

# Collect unreachable Python objects first so their CUDA tensors are actually
# freed, then release the allocator's now-unused cached blocks. In the
# original order, memory freed by gc.collect() stayed cached.
gc.collect()
torch.cuda.empty_cache()


In [None]:
# Re-check that a training batch still loads after building the learner.
dls.train.one_batch()


In [None]:
# Render a training batch with default options, decoding text with `tok`.
dls.train.show_batch(tok=tok)


In [None]:
start_lr, end_lr = 1e-7, 10

# LR range test over 100 iterations, stopping early on divergence. No
# suggestion functions are requested, so only the recorded curve is of interest.
r = learn.lr_find(
    start_lr=start_lr, end_lr=end_lr, num_it=100, stop_div=True, suggest_funcs=()
)


In [None]:

def save_model(learn, pretrained_save_path, tokenizer=None):
  """Save the fine-tuned HF model and its vocabulary to ``pretrained_save_path``.

  The model is written with ``learn.model.save_pretrained``. The tokenizer
  vocabulary is written as ``en_vocab.json`` in the same directory.

  Args:
    learn: Learner whose ``model`` exposes ``save_pretrained``.
    pretrained_save_path: target directory (str or Path). It is honoured
      as given; the previous version silently replaced it with a hard-coded
      Colab Drive path.
    tokenizer: object exposing ``.tokenizer.get_vocab()``, like the
      notebook-level ``tok``. Defaults to ``tok``.
  """
  if tokenizer is None:
    tokenizer = tok
  learn.model.save_pretrained(pretrained_save_path)
  save_dir = Path(pretrained_save_path)
  # Guard in case save_pretrained did not create the directory.
  save_dir.mkdir(parents=True, exist_ok=True)
  with open(save_dir / "en_vocab.json", "w") as f:
    json.dump(tokenizer.tokenizer.get_vocab(), f)



In [None]:

if TEST_RUN:
  learn.fit_one_cycle(1, 1e-3)
else:
  learn.model.freeze_feature_extractor()
  # `cbs` were attached when the Learner was constructed. Passing them to
  # fit_one_cycle again registered every callback a second time for this fit
  # (duplicate TensorBoard logging). Removing the temporary copies afterwards
  # also detached the shared instances from the learner.
  learn.fit_one_cycle(NUM_EPOCHS, lr_max=1e-4)
  save_model(learn, pretrained_save_path)
#   learn.model = load_model(pretrained_save_path)
#   learn.fit_one_cycle(1, lr_max=1e-4,)
#   save_model(learn, pretrained_save_path)


In [None]:
# Evaluation corpora. Both TEST_RUN branches were identical, so a single list
# is used. Uncomment entries to evaluate on more sets.
test_datasets = [
    DatasetConfig(name='librispeech', split='test', lang=None, kind='clean'),
    # DatasetConfig(name='librispeech', split='test', lang=None, kind='other'),
    # DatasetConfig(name='ljl', split='test', lang=None, kind=None),
    # DatasetConfig(name='nict_spreds', split='test', lang='en', kind=None),
]


In [None]:
# Load the evaluation corpora. `tdf` is the per-clip frame filtered below;
# `tp` (presumably path info returned by get_datasets) is not used later.
tp, tdf = get_datasets(test_datasets)


In [None]:
tdf["audio_length"] = apply_parallel(tdf["filename"], get_audio_length, 16)
# Apply the same length cap as the training data. The hard-coded 15 would
# silently diverge if MAX_AUDIO_LENGTH changed. Also drop rows without a
# transcript.
test_keep = (tdf["audio_length"] < MAX_AUDIO_LENGTH) & tdf["text"].notna()
tdf = tdf[test_keep].reset_index(drop=True)

abt = AudioBatchTransform()
t_tfms = TfmdLists(tdf, abt)


In [None]:
# Smoke tests evaluate only the first 100 test clips.
t_tfms = TfmdLists(tdf.iloc[:100] if TEST_RUN else tdf, abt)


In [None]:
# `dls.new` clones the *training* DataLoader, which was built with
# shuffle=True and therefore drop_last=True. Override both so every test clip
# is evaluated exactly once, in a reproducible order. Batch-averaged WER/CER
# depend on the batch grouping.
# NOTE(review): the clone still carries the training `after_item`
# augmentations (RandomReverbration / AddNoise); confirm they do not fire
# during evaluation.
t_dl = dls.new(t_tfms, shuffle=False, drop_last=False)


In [None]:
# Inspect the callbacks currently attached to the learner.
learn.cbs


In [None]:
# Detach the TensorBoard logger before scoring the test set, selecting it by
# type rather than by position. `learn.cbs[4]` depends on fastai's default
# callback list and on earlier cells, and re-running the cell would remove a
# different callback.
# NOTE(review): index 4 is assumed to have been TensorBoardCallback (defaults
# first, then `cbs`); confirm against the `learn.cbs` output above.
learn.remove_cb(TensorBoardCallback).validate(dl=t_dl)


In [None]:
def get_preds(xs):
  """Greedy-decode ``learn.model``'s predictions for the input batch ``xs``.

  Runs the model without recording an autograd graph, takes the argmax over
  the last axis of the logits and decodes the ids with ``tok``. Returns the
  list of decoded strings.
  """
  with torch.no_grad():
    preds = learn.model(xs)
  pred_ids = TensorText(np.argmax(preds.logits.detach().cpu().numpy(), axis=-1))
  return tok.batch_decode(pred_ids)


# Inspect the first test batch: WER/CER plus (prediction, target) pairs.
# The forward pass is run once for both metrics; the old loop ran the model
# three times with autograd enabled.
xs, y = next(iter(t_dl))
with torch.no_grad():
  out = learn.model(xs)
print(wer(out, y))
print(cer(out, y))
pprint.pprint(dict(enumerate(zip(get_preds(xs), tok.batch_decode(y)))))



In [None]:
# Decode every test batch as (predictions, targets). Inference needs no
# autograd graph, so disable it to cut memory and time.
with torch.no_grad():
  comp = [(get_preds(xs), tok.batch_decode(y)) for xs, y in t_dl]


In [None]:
# Print up to the first two prediction/target pairs of each batch, for the
# first 10 batches.
# - The old check `(i+1 % 10) == 0` parsed as `i + (1 % 10) == 0`, which is
#   never true, so it printed everything.
# - The old indexing [0]/[1] raised IndexError on a final batch of size 1.
for i, (preds, targs) in enumerate(comp):
  for pred, targ in zip(preds[:2], targs[:2]):
    print("Pred: ", pred)
    print("Targ: ", targ)
  if (i + 1) % 10 == 0:
    break



In [None]:
if not TEST_RUN:
  # Export under `modeldir` (datadir/models/audio_en) instead of a hard-coded
  # Colab Drive path, which only exists on that one machine. The directory is
  # created up front because learn.save / torch.save do not create parents.
  export_dir = modeldir.resolve()
  export_dir.mkdir(parents=True, exist_ok=True)
  # Absolute path, so fastai does not re-root it under learn.path/model_dir.
  learn.save(export_dir / "export")
  with open(export_dir / "export_tokenizer.pkl", "wb") as f:
    pickle.dump(tok, f)
  learn.model.save_pretrained(export_dir)
  torch.save(learn.model, export_dir / "export_torch_model.pth")


In [None]:
# neptune.stop()
