In [None]:
!pip install transformers datasets accelerate evaluate jiwer

## **Evaluation**

In [None]:
import os
import csv
import json
import torch
import evaluate
import numpy as np
from string import punctuation
from datasets import load_dataset, Dataset
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers.models.whisper.english_normalizer import BasicTextNormalizer

# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Root directory that data paths are resolved against (the test-set CSVs are
# looked up under "<base_path>/test-sets").
base_path = "../"

In [None]:
# Model identifier passed to `from_pretrained` for both the tokenizer and the model.
model_card = "kasunw/sinhala-transliterator"

In [None]:
# Load the tokenizer and the seq2seq model identified by `model_card`, then
# place the model on the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_card)
model = AutoModelForSeq2SeqLM.from_pretrained(model_card)
model = model.to(device)

# Language codes handed to `translate`: source first, then target.
src_lang_id = "en"
tgt_lang_id = "si"

In [None]:
def translate(input_text, model, tokenizer, src_lang="en", tgt_lang="si"):
  """Generate the target-language text for a single input string.

  The input is lower-cased before tokenization, and generation is forced to
  begin with the tokenizer's language id for `tgt_lang`. The output length is
  capped at three times the number of input tokens.

  Args:
    input_text: Source string to convert.
    model: Generation model exposing `.device` and `.generate(...)`.
    tokenizer: Callable tokenizer exposing `src_lang`, `get_lang_id`,
      `eos_token_id` and `batch_decode`. Its `src_lang` attribute is
      overwritten with `src_lang` as a side effect.
    src_lang: Language code assigned to `tokenizer.src_lang`.
    tgt_lang: Language code whose id is used as the forced first token.

  Returns:
    The first decoded sequence, with special tokens removed.
  """
  tokenizer.src_lang = src_lang

  # Tokenize to PyTorch tensors and move them to the device the model lives on.
  encoded = tokenizer(input_text.lower(), return_tensors="pt")
  encoded = encoded.to(model.device)
  source_len = encoded.input_ids.shape[-1]

  generation_kwargs = dict(
    forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
    min_length=None,
    max_length=3 * source_len,
    streamer=None,
    pad_token_id=tokenizer.eos_token_id,
  )
  generated = model.generate(**encoded, **generation_kwargs)

  decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
  return decoded[0]

## **Evaluation Metrics**

In [None]:
# Metric objects loaded by name through the `evaluate` library, in the order
# word error rate, character error rate, BLEU.
metric, cer_metric, bleu_metric = (
    evaluate.load(metric_name) for metric_name in ("wer", "cer", "bleu")
)

# `compute_metrics` passes both strings through `normalizer` before scoring
# only when `do_normalize_text` is True.
normalizer = BasicTextNormalizer()
do_normalize_text = False

In [None]:
def compute_metrics(ref_str, pred_str, find_wer=True, find_cer=True, find_bleu=True):
    """Score one prediction against one reference.

    Both strings are cleaned first: optionally passed through the module-level
    `normalizer` (when `do_normalize_text` is set), then stripped of
    surrounding whitespace and ASCII punctuation.

    Args:
        ref_str: Reference (ground-truth) text.
        pred_str: Predicted text.
        find_wer: Whether to compute the word error rate.
        find_cer: Whether to compute the character error rate.
        find_bleu: Whether to compute the BLEU score.

    Returns:
        A `(wer, cer, bleu)` tuple. A metric that was not requested is `None`.
        If either cleaned string is empty, the requested metrics take the
        fixed values 1.0 (WER), 1.0 (CER) and 0.0 (BLEU).
    """

    def _clean(text):
        # Shared clean-up path for the prediction and the reference.
        if do_normalize_text:
            text = normalizer(text)
        return text.strip().strip(punctuation).strip()

    pred_str = _clean(pred_str)
    ref_str = _clean(ref_str)

    # The metrics are only evaluated when both sides are non-empty.
    scorable = bool(ref_str) and bool(pred_str)

    wer = cer = bleu = None
    if find_wer:
        wer = metric.compute(predictions=[pred_str], references=[ref_str]) if scorable else 1.0
    if find_cer:
        cer = cer_metric.compute(predictions=[pred_str], references=[ref_str]) if scorable else 1.0
    if find_bleu:
        if scorable:
            bleu = bleu_metric.compute(predictions=[pred_str], references=[ref_str])["bleu"]
        else:
            bleu = 0.0

    return wer, cer, bleu


## **Testset Evaluation**

In [None]:
# Locations of the two test-set CSV files under "<base_path>/test-sets".
dataset_1_path, dataset_2_path = (
    os.path.join(base_path, "test-sets", csv_name)
    for csv_name in ("Sinhala-Test-set-1.csv", "Sinhala-Test-set-2.csv")
)

In [None]:
# Read each CSV as UTF-8 through the `datasets` CSV loader and keep its
# "train" split.
dataset_1, dataset_2 = (
    load_dataset("csv", data_files=csv_path, encoding="utf-8")["train"]
    for csv_path in (dataset_1_path, dataset_2_path)
)

In [None]:
def results_to_csv(file_name, rows, fields=("source", "target", "prediction", "wer", "cer", "bleu")):
  """Write evaluation results to a CSV file: one header row, then `rows`.

  Args:
    file_name: Path of the CSV file to create (an existing file is overwritten).
    rows: Iterable of row sequences, one per evaluated example.
    fields: Column names written as the header row.

  Returns:
    None. A confirmation line is printed once the file has been written.
  """
  # newline="" leaves line endings to the csv module (otherwise blank lines
  # appear between rows on Windows). The explicit UTF-8 encoding keeps
  # non-ASCII text such as Sinhala writable regardless of the platform's
  # default encoding.
  with open(file_name, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(fields)
    writer.writerows(rows)
  print(f"[INFO] Wrote results to {file_name}")

In [None]:
def eval_dataset(test_dataset, output_filename, max_n=None, find_wer=True, find_cer=True, find_bleu=True):
  """Run the model over a test set, score each row, and save per-row results.

  For every row, the text in "Column1" is converted with `translate` (using
  the module-level `model`, `tokenizer`, `src_lang_id` and `tgt_lang_id`) and
  scored against the reference in "Column2" with `compute_metrics`. The
  averages of the enabled metrics are printed and all rows are written to
  `output_filename` through `results_to_csv`.

  Args:
    test_dataset: Iterable of mapping-like rows with "Column1" (source text)
      and "Column2" (reference text) entries.
    output_filename: Path of the CSV file that receives the per-row results.
    max_n: Maximum number of rows to evaluate; `None` evaluates every row.
    find_wer: Whether to compute and report the word error rate.
    find_cer: Whether to compute and report the character error rate.
    find_bleu: Whether to compute and report the BLEU score.

  Returns:
    None.
  """
  wer_list, cer_list, bleu_list = [], [], []
  output_data = []

  for i, ele in enumerate(test_dataset):
    # `i` is zero-based, so stop as soon as exactly `max_n` rows were scored.
    if max_n is not None and i >= max_n:
      break

    sing_txt, si_txt = ele["Column1"], ele["Column2"]
    pred_txt = translate(sing_txt, model=model, tokenizer=tokenizer, src_lang=src_lang_id, tgt_lang=tgt_lang_id)
    # Forward all three switches so a disabled metric is not computed at all.
    wer, cer, bleu = compute_metrics(
      si_txt, pred_txt, find_wer=find_wer, find_cer=find_cer, find_bleu=find_bleu
    )

    # One value per column of the CSV header; a disabled metric is None,
    # which the csv module writes as an empty cell.
    output_data.append([sing_txt, si_txt, pred_txt, wer, cer, bleu])

    if find_wer:
      wer_list.append(wer)
    if find_cer:
      cer_list.append(cer)
    if find_bleu:
      bleu_list.append(bleu)

  print("============================================")
  for label, values, enabled in (
    ("wer", wer_list, find_wer),
    ("cer", cer_list, find_cer),
    ("bleu", bleu_list, find_bleu),
  ):
    if enabled:
      # Report nan directly for an empty run instead of triggering numpy's
      # "mean of empty slice" warning.
      average = np.mean(values) if values else float("nan")
      print(f"Avg {label}: {average}")

  results_to_csv(file_name=output_filename, rows=output_data)

  # Release cached GPU memory once the evaluation pass is finished.
  if device == "cuda":
    torch.cuda.empty_cache()

In [None]:
# Quick run on a small prefix of test set 1 (bounded by `max_n`); pass
# max_n=None to score every row. Per-row results go to the named CSV file.
eval_dataset(dataset_1, output_filename="test_set_1_results.csv", max_n=5)

In [None]:
# Quick run on a small prefix of test set 2 (bounded by `max_n`); pass
# max_n=None to score every row. Per-row results go to the named CSV file.
eval_dataset(dataset_2, output_filename="test_set_2_results.csv", max_n=5)

## **General Inference**

In [None]:
# Interactive prompt: convert each typed query until the user enters "q"
# (case-insensitive).
while (sing_txt := input("Type the Singlish query (q for quit): ")).lower() != "q":
  res = translate(sing_txt, model=model, tokenizer=tokenizer, src_lang=src_lang_id, tgt_lang=tgt_lang_id)
  print(f"{sing_txt} --> {res}")
  print("-----------------------------------")