## Install Dependencies

In [None]:
!pip install transformers
!pip install sentencepiece
!pip install datasets
!pip install jiwer

## Import Dependencies

In [None]:
from transformers import AutoModelWithLMHead, AutoTokenizer, AutoModelForSeq2SeqLM, BartTokenizer, BartForConditionalGeneration
import datasets

import pandas as pd
import numpy as np

import warnings
warnings.filterwarnings('ignore')

In [None]:
# (model, tokenizer) pairs already loaded in this kernel session, keyed by
# model name/path. Without this cache every call would re-run
# `from_pretrained`, which is costly when the function is applied row by row.
_LOADED_MODELS = {}


def _load_model_and_tokenizer(model_name):
  """Return the (model, tokenizer) pair for `model_name`, loading it on first use."""
  if model_name not in _LOADED_MODELS:
    # BART classes are used here; for T5 checkpoints swap in
    # AutoModelForSeq2SeqLM / AutoTokenizer instead.
    model = BartForConditionalGeneration.from_pretrained(model_name) #BART
    tokenizer = BartTokenizer.from_pretrained(model_name) #BART
    _LOADED_MODELS[model_name] = (model, tokenizer)
  return _LOADED_MODELS[model_name]


def error_correct(text, model_name):
  """Generate corrected candidates for `text` with a BART seq2seq model.

  Args:
    text: The input string to correct (e.g. one transcript).
    model_name: Name or path passed to `from_pretrained` for both the model
      and the tokenizer. Each distinct value is loaded once and then reused.

  Returns:
    A list of 5 decoded strings (`num_return_sequences=5`), produced by beam
    search with 5 beams, with special tokens stripped.
  """
  model, tokenizer = _load_model_and_tokenizer(model_name)

  input_ids = tokenizer.encode(text, return_tensors='pt', add_special_tokens=True)

  generated_ids = model.generate(input_ids=input_ids,
                                 num_return_sequences=5,
                                 num_beams=5,
                                 max_length=512,
                                 no_repeat_ngram_size=2,
                                 repetition_penalty=3.5,
                                 length_penalty=1.0,
                                 early_stopping=True
                                 )

  return [
      tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True)
      for g in generated_ids
  ]

## Trained Models
#### Please note that all the trained models are currently available on the Hugging Face Hub; however, due to the anonymity requirement, we share the trained models via the following Google Drive link: https://drive.google.com/drive/folders/1uowiKAgk3DW48QumeCXEoDzuijN8iqTV?usp=sharing


In [None]:
# ---------------------------------------------------------------------------
# Baseline checkpoints (from literature)
# ---------------------------------------------------------------------------
bart_base = 'facebook/bart-base'
bart_large = 'facebook/bart-large'

# ---------------------------------------------------------------------------
# BART fine-tuned models (paths relative to this notebook)
# ---------------------------------------------------------------------------

# Standard objective
bart_clinical = '../models/bart-finetuned-pubmed'
bart_pubmed = '../models/bart-paraphrase-pubmed-1.1'
bart_mlm = '../models/bart-mlm-pubmed'

# Hybrid objective
bart_mlm_paraphrasing = '../models/bart-mlm-paraphrasing'
bart_paraphrasing_mlm = '../models/bart-paraphrasing-mlm'

# Domain-specific objective
bart_med_term = '../models/bart-mlm-pubmed-medterm'
bart_cm = '../models/bart-med-term-conditional-masking'
bart_cm_0 = '../models/bart-med-term-conditional-masking-0'

# ---------------------------------------------------------------------------
# T5 fine-tuned models (paths relative to this notebook)
# ---------------------------------------------------------------------------

# Standard objective
t5_clinical = '../models/t5-small-finetuned-pubmed'
t5_pubmed = '../models/t5-small-paraphrase-pubmed'
t5_mlm = '../models/t5-small-mlm-pubmed'

# Hybrid objective
t5_mlm_paraphrasing = '../models/t5-small-mlm-paraphrasing'
t5_small_paraphrasing_mlm = '../models/t5-small-paraphrasing-mlm'

# Domain-specific objective
t5_small_med_term_mlm = '../models/t5-small-med-term-mlm'
t5_small_cm = '../models/t5-small-med-term-conditional-masking'
t5_small_cm_0 = '../models/t5-small-med-term-conditional-masking-0'


## Evaluation - WER

In [None]:
# Word-error-rate metric object; its `compute` method is called with the
# predictions and references during evaluation.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package — confirm the
# installed version still provides it.
metric = datasets.load_metric('wer')

## Evaluation Data

#### This clinical conversational dataset was collected in collaboration with NHS Grampian. Therefore, the dataset is available upon request.

In [None]:
# CSV files holding the evaluation data, one per transcription source.
# NOTE(review): only the AWS file sits under `test/` — confirm the other three
# paths are not missing that directory.
test_data_aws = '../datasets/test/refs_and_trans_aws_gb.csv'
test_data_ms = '../datasets/refs_and_trans_ms_gb.csv'
test_data_ibm = '../datasets/refs_and_trans_ibm_gb.csv'
test_data_google = '../datasets/refs_and_trans_google_gb.csv'

In [None]:
def calculate_wer(test_data, pre_trained_model):
  """Score a model's corrected transcripts against the references in a CSV.

  The CSV is read with pandas and must contain `refs` and `trans` columns.
  Rows with a missing value in any column are dropped. Every remaining
  transcript is passed to `error_correct`, and each returned candidate is
  paired with that row's reference, so the metric is computed over all
  candidates of all rows.

  Args:
    test_data: Path to the CSV file.
    pre_trained_model: Model name or path forwarded to `error_correct`.

  Returns:
    Whatever the module-level `metric.compute` returns for the collected
    predictions and references.

  Raises:
    ValueError: If no rows remain after dropping incomplete ones.
  """
  test_df = pd.read_csv(test_data).dropna()

  if test_df.empty:
    raise ValueError(f'No complete rows to evaluate in {test_data!r}')

  # Build the flat prediction/reference lists directly instead of mutating and
  # copying each `iterrows` row into an intermediate DataFrame.
  predictions = []
  references = []
  for reference, transcript in zip(test_df['refs'], test_df['trans']):
    for candidate in error_correct(transcript, pre_trained_model):
      predictions.append(candidate)
      references.append(reference)

  return metric.compute(predictions=predictions, references=references)

In [None]:
# Evaluate the `bart_clinical` checkpoint on each of the four test CSVs,
# in the order AWS, MS, IBM, Google.
score_aws, score_ms, score_ibm, score_google = [
    calculate_wer(csv_path, bart_clinical)
    for csv_path in (test_data_aws, test_data_ms, test_data_ibm, test_data_google)
]

# Report the scores only after all four evaluations have finished.
for label, value in (
    ('score_aws: ', score_aws),
    ('score_ms: ', score_ms),
    ('score_ibm: ', score_ibm),
    ('score_google: ', score_google),
):
  print(label, value)