# Interactive training and testing, with WER computed on all samples

## Setup

In [None]:
# Delete the `runs` directory in the notebook's working directory, left by earlier sessions
# (presumably training logs — TODO confirm which component writes it).
!rm -rf runs

In [None]:
from IPython.core.display import display, HTML
# Widen the notebook's content area to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# IPython autoreload, mode 2: reload imported modules before each cell executes, so edits
# to the project code on sys.path are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import os
import sys

# Make the end2end_asr_pytorch checkout importable (presumably where the models/, utils/
# and Trainer modules imported below live).
END2END_ASR_ROOT = '/home/catskills/Desktop/openasr20/end2end_asr_pytorch'
sys.path.append(END2END_ASR_ROOT)

# Environment flag presumably read by project code to detect notebook execution — TODO confirm.
os.environ['IN_JUPYTER'] = 'True'

from chunk_size import chunk_size
from glob import glob
from models.asr.transformer import Transformer, Encoder, Decoder
from torch.autograd import Variable
from torch.cuda.amp import GradScaler, autocast
from utils import constant
from utils.data_loader import SpectrogramDataset, AudioDataLoader, BucketingSampler
from utils.functions import save_model, load_model, init_transformer_model, init_optimizer
from utils.lstm_utils import LM
from utils.metrics import calculate_metrics, calculate_cer, calculate_wer, calculate_cer_en_zh
from utils.optimizer import NoamOpt
import json, logging, math, os, random, time, torch, sys, random
import numpy as np
import torch.nn as nn
from Trainer import Trainer

language = 'amharic'
stage = 'NIST'

# Transcript chunks produced by the data-preparation step for this language.
chunk_pattern = f'{stage}/openasr20_{language}/build/transcription_split/*.txt'
chunks = sorted(glob(chunk_pattern))
print(len(chunks), 'chunks')

# chunk_size() yields a (size, transcript_path) pair per file; sorting groups equal sizes.
# NOTE(review): what "size" measures (characters, words, duration) is defined in chunk_size — confirm.
size_chunks = sorted(chunk_size(fn) for fn in chunks)

# size -> transcript paths with that size, and size -> number of such chunks.
# Keys are inserted in first-seen order of the sorted list.
size_chunk_files = {}
size_chunk_distribution = {}
for size, path in size_chunks:
    size_chunk_files.setdefault(size, []).append(path)
    size_chunk_distribution[size] = size_chunk_distribution.get(size, 0) + 1

sizes = list(size_chunk_files)

# One manifest line per transcript chunk, "<audio .wav path>,<transcript .txt path>",
# ordered by descending (size, path).
L = []
for _size, text in reversed(sorted(size_chunks)):
    # The audio lives in the parallel "audio" tree. Swap only the file extension:
    # a blanket .replace('txt', 'wav') would also rewrite any "txt" elsewhere in the path.
    audio = text.replace('transcription', 'audio')
    if audio.endswith('.txt'):
        audio = audio[:-len('.txt')] + '.wav'
    L.append(f'{audio},{text}')

print(f'{len(L)} samples in training set')

# Manifest read by SpectrogramDataset: one "audio,transcript" pair per line, no trailing newline.
manifest_file_path = f'analysis/{language}/size_1.csv'
with open(manifest_file_path, 'w') as manifest:
    print(*L, sep='\n', end='', file=manifest)

args = constant.args

# Run identity: checkpoints for this configuration live in <save_folder>/<name>.
args.name = f'{language}_end2end_asr_pytorch_drop0.1_cnn_batch12_4_vgg_layer4'
args.save_folder = 'save'
model_dir = f'{args.save_folder}/{args.name}'

# Overrides applied on top of the project's default arguments.
_arg_overrides = {
    # Data / vocabulary
    'labels_path': f'analysis/{language}/{language}_characters.json',
    'train_manifest_list': [manifest_file_path],
    'sample_rate': 8000,
    'augment': True,
    'shuffle': True,
    # Training schedule
    'lr': 1e-4,
    'min_lr': 1e-6,
    'k_lr': 1,
    'epochs': 5,
    'save_every': 1,
    # Transformer architecture
    'feat_extractor': 'vgg_cnn',
    'dropout': 0.1,
    'num_layers': 4,
    'num_heads': 8,
    'dim_model': 512,
    'dim_key': 64,
    'dim_value': 64,
    # NOTE(review): 161 = 320 // 2 + 1 frequency bins; confirm it matches the spectrogram
    # produced at sample_rate=8000 with args.window_size.
    'dim_input': 161,
    'dim_inner': 2048,
    'dim_emb': 512,
    # Hardware and resume point
    'cuda': True,
    'continue_from': f'{model_dir}/best_model.th',
}
for _attr, _value in _arg_overrides.items():
    setattr(args, _attr, _value)

# Audio front-end settings handed to SpectrogramDataset, copied from args.
audio_conf = {
    'sample_rate': args.sample_rate,
    'window_size': args.window_size,
    'window_stride': args.window_stride,
    'window': args.window,
    'noise_dir': args.noise_dir,
    'noise_prob': args.noise_prob,
    'noise_levels': (args.noise_min, args.noise_max),
}

# Character vocabulary for `language`: a JSON collection of characters joined into one string.
# Read as UTF-8 explicitly so decoding the (presumably non-ASCII) vocabulary does not depend
# on the machine's locale.
with open(args.labels_path, 'r', encoding='utf-8') as label_file:
    labels = ''.join(json.load(label_file))

# add PAD_CHAR, SOS_CHAR, EOS_CHAR so they receive the first ids
labels = constant.PAD_CHAR + constant.SOS_CHAR + constant.EOS_CHAR + labels

# Assign consecutive ids in first-seen order; a repeated character keeps its first id and
# is reported.
label2id, id2label = {}, {}
count = 0
for char in labels:
    if char in label2id:
        print("multiple label: ", char)
        continue
    label2id[char] = count
    id2label[count] = char
    count += 1

if constant.args.continue_from:
    # Resume from a checkpoint. load_model also returns the label maps the checkpoint was
    # trained with; they replace the ones built from labels_path above.
    model, opt, epoch, metrics, loaded_args, label2id, id2label = load_model(
        constant.args.continue_from)
    start_epoch = epoch  # index starts from zero
else:
    # Fresh model: there is no checkpoint epoch, so start from zero. Previously
    # `start_epoch = epoch` ran after the if/else and raised NameError on this path.
    model = init_transformer_model(constant.args, label2id, id2label)
    opt = init_optimizer(constant.args, model, "noam")
    start_epoch = 0

# NOTE(review): as before, metrics and args loaded from a checkpoint are discarded here, so
# trainer.train() receives metrics=None on every run — confirm this reset is intended.
metrics = None
loaded_args = None
verbose = True

# Run on CUDA device 0.
constant.USE_CUDA = True
model = model.cuda(0)

# Dataset over the manifest(s) in args.train_manifest_list; the normalize/augment flags are
# passed straight through to SpectrogramDataset.
train_data = SpectrogramDataset(
    audio_conf,
    manifest_filepath_list=args.train_manifest_list,
    label2id=label2id,
    normalize=True,
    augment=args.augment,
)

loss_type = args.loss
# Epoch limit handed to trainer.train(): 1000 epochs past the starting epoch.
num_epochs = start_epoch + 1000

## Training

In [None]:
# Training loader: batches of four samples grouped by BucketingSampler.
args.batch_size = 4
train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
train_loader = AudioDataLoader(
    train_data,
    batch_sampler=train_sampler,
    num_workers=args.num_workers,
)

In [None]:
# Project Trainer (from the local Trainer module), constructed without arguments.
trainer = Trainer()

In [None]:
# Train on train_loader, presumably from start_epoch up to num_epochs — confirm against
# Trainer.train. Arguments are positional and must follow that method's parameter order.
trainer.train(model, train_loader, train_sampler, opt, loss_type, start_epoch, num_epochs, label2id, id2label, metrics)

## Eval

In [None]:
# Fresh sampler/loader over the training data for evaluation.
# NOTE(review): BucketingSampler/AudioDataLoader may reorder samples (per-bucket shuffling,
# length-sorted collation), so results are not guaranteed to follow manifest order.
args.batch_size = 4
train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
train_loader = AudioDataLoader(
    train_data,
    batch_sampler=train_sampler,
    num_workers=args.num_workers,
)

In [None]:
# Put the model in evaluation mode (for an nn.Module this disables dropout); the trailing
# `;` suppresses the notebook's echo of the returned module.
model.eval();

In [None]:
R = []

# Evaluation reuses the loader built over the *training* manifest, so these scores measure
# fit on the training data rather than on held-out data.
valid_loader = train_loader

# Label-smoothing value for calculate_metrics. `smoothing` was never defined in this notebook
# (NameError on the first batch); read it from args, falling back to 0.0.
# NOTE(review): confirm `label_smoothing` is the attribute name the training code uses.
smoothing = getattr(args, 'label_smoothing', 0.0)

# NOTE(review): the totals accumulate calculate_cer/calculate_wer results alongside character
# and word counts, which suggests those functions return edit distances (counts), not rates.
total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0


def _ids_to_string(ids):
    """Decode a sequence of token ids with id2label, stopping at the first PAD token."""
    chars = []
    for x in ids:
        token = int(x)
        if token == constant.PAD_TOKEN:
            break
        chars.append(id2label[token])
    return ''.join(chars)


# Pure inference: no_grad keeps autograd from recording a graph for every batch.
with torch.no_grad():
    for data in valid_loader:
        src, tgt, src_percentages, src_lengths, tgt_lengths = data
        src = src.cuda()
        tgt = tgt.cuda()
        with autocast():
            pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

        # Output lengths: src_percentages (presumably each utterance's fraction of the padded
        # input) scaled by the decoder output length.
        seq_length = pred.size(1)
        sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)

        loss, num_correct = calculate_metrics(
            pred, gold, input_lengths=sizes, target_lengths=tgt_lengths,
            smoothing=smoothing, loss_type=loss_type)

        # Skip batches whose loss is not finite. The old check caught only +inf, and its
        # NaN-masking line was dead because `continue` followed it immediately.
        if not math.isfinite(loss.item()):
            logging.info("Found non-finite loss, skipping batch")
            continue

        try:  # handle case for CTC; decoding failures (e.g. an id missing from id2label) skip the batch
            strs_gold = [_ids_to_string(ut_gold) for ut_gold in gold_seq]
            strs_hyps = [_ids_to_string(ut_hyp) for ut_hyp in hyp_seq]
        except Exception as e:
            print(e)
            logging.info("NaN predictions")
            continue

        for j, (hyp, ref) in enumerate(zip(strs_hyps, strs_gold)):
            hyp = hyp.replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
            ref = ref.replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
            # CER is computed with spaces removed from both strings.
            cer = calculate_cer(hyp.replace(' ', ''), ref.replace(' ', ''))
            wer = calculate_wer(hyp, ref)
            success = 'SUCCESS' if ref == hyp else ''
            print(f'[{j}] cer {cer} wer {wer} gold {ref}:{len(ref)} hyp {hyp}:{len(hyp)} {success}')
            R.append((cer, wer, ref, hyp, success))
            total_valid_cer += cer
            total_valid_wer += wer
            total_valid_char += len(ref.replace(' ', ''))
            total_valid_word += len(ref.split(" "))

        total_valid_loss += loss.item()

In [None]:
# Number of utterances scored (batches skipped by the evaluation loop contribute nothing).
len(R)

In [None]:
# First result tuple: (cer, wer, gold, hyp, success).
R[0]

In [None]:
# Pair each result's WER (R[i][1]) with manifest line L[i] by position.
# NOTE(review): this assumes R follows manifest order one-to-one. The evaluation loop skips
# whole batches on a bad loss or a decode error, and the loader may reorder samples within a
# batch, so pairs can be mismatched — verify before trusting RL / RL.json.
RL = [(x[1],y) for x,y in zip(R,L)]

In [None]:
# Split the per-utterance results into parallel CER and WER lists.
CER = [cer for cer, *_ in R]
WER = [wer for _, wer, *_ in R]

In [None]:
# Render matplotlib figures inline in the notebook.
%matplotlib inline
import matplotlib.pylab as plt

In [None]:
# Per-utterance CER vs WER, one point per scored utterance. Axes are labelled so the plot
# can be read without this code.
plt.figure(figsize=(10,10))
plt.scatter(CER, WER, s=1, color='red')
plt.xlabel('CER (calculate_cer)')
plt.ylabel('WER (calculate_wer)')
plt.title('Per-utterance CER vs WER');

In [None]:
# First three (wer, manifest_line) pairs, before sorting.
RL[0:3]

In [None]:
# Sort in place, ascending by WER; ties are broken by the manifest line string.
RL.sort()

In [None]:
# Lowest-WER pairs after the ascending sort.
RL[0:3]

In [None]:
# Rebind RL to a reversed copy so the highest-WER pairs come first.
RL = RL[::-1]

In [None]:
# Highest-WER pairs after reversing.
RL[0:3]

In [None]:
# WER of each pair in RL order, restricted to the first 1000 positions.
wer_by_rank = [wer for wer, _line in RL]
plt.plot(wer_by_rank)
plt.xlim(0,1000)

In [None]:
# Persist the (wer, manifest_line) pairs (highest WER first when the cells above ran in
# order) as a JSON array of two-element arrays.
list_by_wl = f'analysis/{language}/RL.json'
with open(list_by_wl, 'w') as out_file:
    out_file.write(json.dumps(RL))