# Interactive train and test with WER in sentence length order

In [None]:
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import sys
sys.path.append('/home/catskills/Desktop/openasr20/end2end_asr_pytorch')

In [None]:
import os
os.environ['IN_JUPYTER']='True'

In [None]:
%matplotlib inline
import matplotlib.pylab as plt
import random

In [None]:
from utils import constant

In [None]:
import json, logging, math, os, random, time, torch
import numpy as np
import torch.nn as nn
from glob import glob
from torch.autograd import Variable
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

In [None]:
from models.asr.transformer import Transformer, Encoder, Decoder

In [None]:
from utils.data_loader import SpectrogramDataset, AudioDataLoader, BucketingSampler
from utils.functions import save_model, load_model, init_transformer_model, init_optimizer
from utils.lstm_utils import LM
from utils.metrics import calculate_metrics, calculate_cer, calculate_wer, calculate_cer_en_zh
from utils.optimizer import NoamOpt

In [None]:
import time
import numpy as np
from tqdm import tqdm
from utils import constant
from utils.functions import save_model
from utils.optimizer import NoamOpt
from utils.metrics import calculate_metrics, calculate_cer, calculate_wer
from torch.autograd import Variable
from torch.cuda.amp import GradScaler, autocast
import torch
import logging
import sys
from torch.utils.tensorboard import SummaryWriter

In [None]:
language='amharic'
stage='NIST'

In [None]:
chunks = list(sorted(glob(f'{stage}/openasr20_{language}/build/transcription_split/*.txt')))
print(len(chunks), 'chunks')

In [None]:
def chunk_size(fn):
    """Return ``(word_count, fn)`` for the transcription file ``fn``.

    The count comes first so that sorting a list of these tuples orders files
    by sentence length, then by path.

    Words are whitespace-separated tokens; runs of spaces, tabs or newlines
    count as a single separator. An empty or whitespace-only file is counted
    as one word. This matches the earlier ``''.split(' ') == ['']`` behavior
    and keeps ``log2`` of the count finite.
    """
    # Transcriptions contain non-ASCII script (e.g. Amharic); decode them as
    # UTF-8 explicitly instead of relying on the platform's locale encoding.
    with open(fn, 'r', encoding='utf-8') as f:
        words = f.read().split()
    return (max(len(words), 1), fn)

In [None]:
size_chunks=list(sorted([chunk_size(fn) for fn in chunks]))

In [None]:
size_chunk_files={x:[] for x,y in size_chunks}
for x,y in size_chunks:
    size_chunk_files[x].append(y)

In [None]:
size_chunk_distribution={x:0 for x,y in size_chunks}
for x,y in size_chunks:
    size_chunk_distribution[x] += 1

In [None]:
n_sentences=len(size_chunks)

In [None]:
plt.figure(figsize=(10,4))
plt.hist([np.log2(x) for x,y in size_chunks], bins=max(size_chunk_distribution))
plt.title("Zipf's Law for sentences")
xlabel("$\log_2(|S|)$");

In [None]:
L=[]
for text in size_chunk_files[1]:
    audio=text.replace('transcription','audio').replace('txt', 'wav')
    L.append(f'{audio},{text}')

random.shuffle(L)

In [None]:
manifest_file_path=f'analysis/{language}/size_1.csv'
with open(manifest_file_path,'w') as f:
    f.write('\n'.join(L))

In [None]:
model_dir=f'save/{language}_end2end_asr_pytorch_drop0.1_cnn_batch12_4_vgg_layer4'

In [None]:
args=constant.args

In [None]:
args.continue_from=None
args.cuda = True
args.labels_path = f'analysis/{language}/{language}_characters.json'
args.lr = 1e-4
args.name = f'{language}_end2end_asr_pytorch_drop0.1_cnn_batch12_4_vgg_layer4'
args.save_folder = f'save'
args.epochs = 5
args.save_every = 1
args.feat_extractor = f'vgg_cnn'
args.dropout = 0.1
args.num_layers = 4
args.num_heads = 8
args.dim_model = 512
args.dim_key = 64
args.dim_value = 64
args.dim_input = 161
args.dim_inner = 2048
args.dim_emb = 512
args.shuffle=True
args.min_lr = 1e-6
args.k_lr = 1
args.sample_rate=8000
args.train_manifest_list = [manifest_file_path]

In [None]:
audio_conf = dict(sample_rate=args.sample_rate,
                  window_size=args.window_size,
                  window_stride=args.window_stride,
                  window=args.window,
                  noise_dir=args.noise_dir,
                  noise_prob=args.noise_prob,
                  noise_levels=(args.noise_min, args.noise_max))

In [None]:
audio_conf

In [None]:
with open(args.labels_path, 'r') as label_file:
    labels = str(''.join(json.load(label_file)))

In [None]:
labels

In [None]:
# add PAD_CHAR, SOS_CHAR, EOS_CHAR
labels = constant.PAD_CHAR + constant.SOS_CHAR + constant.EOS_CHAR + labels
label2id, id2label = {}, {}
count = 0
for i in range(len(labels)):
    if labels[i] not in label2id:
        label2id[labels[i]] = count
        id2label[count] = labels[i]
        count += 1
    else:
        print("multiple label: ", labels[i])

In [None]:
constant.args.continue_from=f'{model_dir}/best_model.th'

In [None]:
constant.args.continue_from

In [None]:
if constant.args.continue_from:
        model, opt, epoch, metrics, loaded_args, label2id, id2label = load_model(
            constant.args.continue_from)
        start_epoch = epoch  # index starts from zero
        verbose = constant.args.verbose
else:
    model = init_transformer_model(constant.args, label2id, id2label)
    opt = init_optimizer(constant.args, model, "noam")

In [None]:
start_epoch = epoch
metrics = None
loaded_args = None
verbose = True

In [None]:
constant.USE_CUDA=True

In [None]:
train_data = SpectrogramDataset(audio_conf, manifest_filepath_list=args.train_manifest_list, label2id=label2id, normalize=True, augment=args.augment)

In [None]:
class Trainer():
    """
    Trainer class: runs mixed-precision (torch.cuda.amp) training epochs and
    logs loss, CER and learning rate to TensorBoard via SummaryWriter.
    """
    def __init__(self):
        logging.info("Trainer is initialized")
        self.writer = SummaryWriter()

    @staticmethod
    def _decode_ids(seqs, id2label, pad_token):
        """Map each id sequence in ``seqs`` to a string, stopping at the first ``pad_token``."""
        strs = []
        for seq in seqs:
            chars = []
            for x in seq:
                if int(x) == pad_token:
                    break
                chars.append(id2label[int(x)])
            strs.append(''.join(chars))
        return strs

    def train(self, model, train_loader, train_sampler, opt, loss_type, start_epoch, num_epochs, label2id, id2label, last_metrics=None):
        """
        Training
        args:
            model: Model object
            train_loader: DataLoader object of the training set
            train_sampler: sampler behind train_loader; ``shuffle(epoch)`` is called after each epoch
            opt: Optimizer object (its ``_rate`` attribute is logged)
            loss_type: loss name forwarded to calculate_metrics
            start_epoch: start epoch (> 0 if you resume the process)
            num_epochs: last epoch (exclusive upper bound of the epoch range)
            label2id: vocabulary map, forwarded to save_model
            id2label: id -> character map used to decode predictions and targets
            last_metrics: (if resume) -- accepted but not used by this method
        """
        history = []
        smoothing = constant.args.label_smoothing

        logging.info("name " +  constant.args.name)

        # Gradient clipping is deliberately forced off for this run. Set once
        # here instead of on every iteration.
        constant.args.clip = False

        # One GradScaler for the whole run, so its loss-scale estimate carries
        # over between epochs instead of being reset each epoch.
        scaler = GradScaler()

        training_pass = 0

        for epoch in range(start_epoch, num_epochs):
            sys.stdout.flush()
            total_loss, total_cer, total_wer, total_char, total_word = 0, 0, 0, 0, 0

            logging.info("TRAIN")
            model.train()
            pbar = tqdm(iter(train_loader), leave=True, total=len(train_loader))
            for i, data in enumerate(pbar):
                src, tgt, src_percentages, src_lengths, tgt_lengths = data

                if constant.USE_CUDA:
                    src = src.cuda()
                    tgt = tgt.cuda()

                opt.zero_grad()

                with autocast():
                    pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

                    try: # handle case for CTC
                        strs_gold = self._decode_ids(gold_seq, id2label, constant.PAD_TOKEN)
                        strs_hyps = self._decode_ids(hyp_seq, id2label, constant.PAD_TOKEN)
                    except Exception as e:
                        print(e)
                        logging.info("NaN predictions")
                        continue

                    seq_length = pred.size(1)
                    sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)

                    loss, _ = calculate_metrics(
                        pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=smoothing, loss_type=loss_type)

                    # Skip the batch on any non-finite loss. NaN and -inf must
                    # be caught as well as +inf before they reach backward().
                    if not math.isfinite(loss.item()):
                        logging.info("Found non-finite loss %s, skipping batch", loss.item())
                        continue

                    if constant.args.verbose:
                        logging.info("GOLD %s", strs_gold)
                        logging.info("HYP %s", strs_hyps)

                    for hyp, ref in zip(strs_hyps, strs_gold):
                        hyp = hyp.replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                        ref = ref.replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
                        # CER ignores spaces; WER is computed on the spaced strings.
                        total_cer += calculate_cer(hyp.replace(' ', ''), ref.replace(' ', ''))
                        total_wer += calculate_wer(hyp, ref)
                        total_char += len(ref.replace(' ', ''))
                        total_word += len(ref.split(" "))

                # Backward and optimizer step run outside autocast, as in the AMP recipe.
                scaler.scale(loss).backward()

                if constant.args.clip:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), constant.args.max_norm)

                scaler.step(opt)
                scaler.update()

                total_loss += loss.item()

                TRAIN_LOSS = total_loss/(i+1)
                # Guard against a zero character count (e.g. only skipped or empty-gold batches so far).
                CER = total_cer*100/max(total_char, 1)
                pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                    (epoch+1), TRAIN_LOSS, CER, opt._rate))
                self.writer.add_scalar("Loss/train", TRAIN_LOSS, training_pass+1)
                self.writer.add_scalar("CER/train", CER, training_pass+1)
                self.writer.add_scalar("LR/train", opt._rate, training_pass+1)
                self.writer.flush()
                training_pass += 1

            n_batches = max(len(train_loader), 1)
            logging.info("(Epoch {}) TRAIN LOSS:{:.4f} CER:{:.2f}% LR:{:.7f}".format(
                (epoch+1), total_loss/n_batches, total_cer*100/max(total_char, 1), opt._rate))

            # NOTE(review): train_cer / train_wer are raw edit-distance sums, not rates.
            metrics = {}
            metrics["train_loss"] = total_loss / n_batches
            metrics["train_cer"] = total_cer
            metrics["train_wer"] = total_wer
            metrics["history"] = history
            history.append(metrics)

            if epoch % constant.args.save_every == 0:
                save_model(model, (epoch+1), opt, metrics,
                        label2id, id2label, best_model=False)

            save_model(model, (epoch+1), opt, metrics,
                        label2id, id2label, best_model=True)

            train_sampler.shuffle(epoch)

In [None]:
loss_type = args.loss
model = model.cuda(0)
num_epochs = start_epoch + constant.args.epochs

In [None]:
trainer = Trainer()

In [None]:
start_epoch, num_epochs

In [None]:
args.batch_size = 10
train_sampler = BucketingSampler(train_data, batch_size=args.batch_size)
train_loader = AudioDataLoader(train_data, num_workers=args.num_workers, batch_sampler=train_sampler)

In [None]:
trainer.train(model, train_loader, train_sampler, opt, loss_type, start_epoch, num_epochs, label2id, id2label, metrics)

In [None]:
args.verbose = True

In [None]:
smoothing = constant.args.label_smoothing

In [None]:
model.eval();

In [None]:
valid_loader = train_loader
total_valid_loss, total_valid_cer, total_valid_wer, total_valid_char, total_valid_word = 0, 0, 0, 0, 0
for i, (data) in enumerate(valid_loader):
    src, tgt, src_percentages, src_lengths, tgt_lengths = data
    src = src.cuda()
    tgt = tgt.cuda()
    with autocast():
        pred, gold, hyp_seq, gold_seq = model(src, src_lengths, tgt, verbose=False)

    seq_length = pred.size(1)
    sizes = Variable(src_percentages.mul_(int(seq_length)).int(), requires_grad=False)

    loss, num_correct = calculate_metrics(
        pred, gold, input_lengths=sizes, target_lengths=tgt_lengths, smoothing=smoothing, loss_type=loss_type)

    if loss.item() == float('Inf'):
        logging.info("Found infinity loss, masking")
        loss = torch.where(loss != loss, torch.zeros_like(loss), loss) # NaN masking
        continue

    try: # handle case for CTC
        strs_gold, strs_hyps = [], []
        for ut_gold in gold_seq:
            str_gold = ""
            for x in ut_gold:
                if int(x) == constant.PAD_TOKEN:
                    break
                str_gold = str_gold + id2label[int(x)]
            strs_gold.append(str_gold)
        for ut_hyp in hyp_seq:
            str_hyp = ""
            for x in ut_hyp:
                if int(x) == constant.PAD_TOKEN:
                    break
                str_hyp = str_hyp + id2label[int(x)]
            strs_hyps.append(str_hyp)
    except Exception as e:
        print(e)
        logging.info("NaN predictions")
        continue

    for j in range(len(strs_hyps)):
        strs_hyps[j] = strs_hyps[j].replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
        strs_gold[j] = strs_gold[j].replace(constant.SOS_CHAR, '').replace(constant.EOS_CHAR, '')
        cer = calculate_cer(strs_hyps[j].replace(' ', ''), strs_gold[j].replace(' ', ''))
        wer = calculate_wer(strs_hyps[j], strs_gold[j])
        success = 'SUCCESS' if strs_gold[j] == strs_hyps[j] else ''
        print(f'[{j}] cer {cer} wer {wer} gold {strs_gold[j]}:{len(strs_gold[j])} hyp {strs_hyps[j]}:{len(strs_hyps[j])} {success}')
        total_valid_cer += cer
        total_valid_wer += wer
        total_valid_char += len(strs_gold[j].replace(' ', ''))
        total_valid_word += len(strs_gold[j].split(" "))

    total_valid_loss += loss.item()
    break

In [None]:
strs_gold, strs_hyps