In [None]:
from dataclasses import dataclass, field
from typing import Optional
from datasets import load_dataset
from transformers import HfArgumentParser

import pandas as pd
import os,pickle,time, random
from tqdm import tqdm
import evaluate

from google.generativeai.types import HarmCategory, HarmBlockThreshold
import google.generativeai as genai

In [3]:
def flip_coin():
    """Simulate a single fair coin toss.

    Returns:
        bool: True when ``random.randint(0, 1)`` draws 1 (heads),
        False when it draws 0 (tails).
    """
    # Idea adapted from:
    # https://github.com/globien/easy-python/blob/master/102_%E6%8A%9B%E7%A1%AC%E5%B8%81%E8%BF%9E%E7%BB%AD%E6%AD%A3%E9%9D%A2%E9%97%AE%E9%A2%98.py
    landed_heads = random.randint(0, 1) == 1
    return landed_heads

In [None]:
def load_pkl_data(path):
    """Read one pickled object from ``path`` and return it.

    Args:
        path: location of the pickle file; it is opened in binary read mode.

    Returns:
        Whatever object ``pickle.load`` reconstructs from the file.
    """
    # NOTE(review): unpickling can execute arbitrary code, so this should only
    # be pointed at files from a trusted source.
    with open(path, 'rb') as pkl_file:
        return pickle.load(pkl_file)

def preprocessing(data):
    """Turn a source -> translations mapping into records plus a reference list.

    Args:
        data: mapping that is iterated by key; each key is a source sentence
            (string) and ``data[key]`` is an iterable of translation strings.

    Returns:
        tuple: ``(records, references)`` where ``records`` is a list of dicts
        with the keys ``'English'`` (the key with newline characters stripped
        from both ends) and ``'Gold_Hinglish'`` (the stripped translations),
        and ``references`` holds the same translation lists in the same order.
    """
    records = []
    references = []
    for source_sentence in tqdm(data):
        english = source_sentence.strip('\n')
        hinglish_refs = [candidate.strip('\n') for candidate in data[source_sentence]]

        records.append({'English': english, 'Gold_Hinglish': hinglish_refs})
        # The very same list object is shared with the record above.
        references.append(hinglish_refs)
    return records, references

def transform_unified_golden_label(data_valid, pad_to=7):
    """Build per-sentence reference lists padded to a common minimum length.

    Args:
        data_valid: mapping that is iterated by key; ``data_valid[key]`` is an
            iterable of reference translation strings for that key.
        pad_to: minimum number of references per entry. Shorter entries are
            right-padded with empty strings. Defaults to 7, the width that was
            previously hard-coded.

    Returns:
        list[list[str]]: one list per key (in iteration order) containing the
        references with newline characters stripped from both ends, followed
        by ``''`` padding up to ``pad_to`` items.

    Note:
        Entries that already hold more than ``pad_to`` references are NOT
        truncated, so the output is only rectangular when ``pad_to`` is at
        least the largest reference count in ``data_valid``.
    """
    gold_label = []

    for source_sentence in data_valid:
        references = [ref.strip('\n') for ref in data_valid[source_sentence]]
        # A non-positive difference multiplies to an empty list, i.e. no padding.
        references.extend([''] * (pad_to - len(references)))
        gold_label.append(references)

    return gold_label

def post_processing(data):
    """Strip known answer prefixes and bracket noise from raw completions.

    For every string, each prefix marker below is checked in order; when the
    marker is present, only the text after its last occurrence is kept.
    Afterwards all square brackets, parentheses and newline characters are
    removed and surrounding whitespace is trimmed.

    Args:
        data: iterable of completion strings.

    Returns:
        list[str]: the cleaned strings, in input order.
    """
    prefix_markers = (
        "[Hinglish]:",
        "Hindi-English]:",
        "Hindi-English:",
        "Hinglish translation:",
    )
    noise_chars = ("[", "]", "(", ")", "\n")

    cleaned = []
    for text in data:
        for marker in prefix_markers:
            if marker in text:
                text = text.split(marker)[-1]

        for char in noise_chars:
            text = text.replace(char, "")
        cleaned.append(text.strip())

    return cleaned

# Ordered clean-up rules used by post_processing_v2: (marker, keep_tail).
#   keep_tail=True  -> keep the text AFTER the last occurrence of the marker
#                      (the marker introduces the translation).
#   keep_tail=False -> keep the text BEFORE the first occurrence of the marker
#                      (the marker starts trailing commentary).
# Rules are applied one after another to the running result, so their order
# is significant and must not be changed casually.
_POST_PROCESSING_V2_RULES = (
    ("[Hinglish]:", True),
    ("Hindi-English]:", True),
    ("Hindi-English:", True),
    ("Hinglish translation:", True),
    ("[English] se [Hinglish] mein:", True),
    ("Hinglish:", True),
    ("[English]:", True),
    ("Hinglish Translation:", True),
    ("hum is prakaar keh sakte hain:", True),
    ("[Translation in Hinglish]:", True),
    ("Note:", False),
    ("(Translation:", False),
    ("Hinglish mein hai:", True),
    ("Explanation:", False),
    ("Translation Breakdown:", False),
    ("[English] to [Hinglish] translation:", True),
    ("[English to Hinglish translation]", True),
    # Previously this marker was detected but the split used the marker of the
    # rule above, so the prefix was never actually removed.
    ("[English to Hinglish]:", True),
    ("(Hinglish translation)", True),
    ("(In Hinglish", False),
    ("[Hinglish] mein yeh anuvad hai:", True),
)


def post_processing_v2(data):
    '''
    Extended clean-up of raw completions (created by Aditya).

    Each string is passed through the ordered rules in
    ``_POST_PROCESSING_V2_RULES``: prefix-style markers keep only what follows
    them, commentary-style markers (notes, explanations, breakdowns) keep only
    what precedes them. Finally all square brackets, parentheses and newline
    characters are removed and surrounding whitespace is trimmed.

    Args:
        data: iterable of completion strings.

    Returns:
        list[str]: the cleaned strings, in input order.
    '''
    res = []
    for sent in data:
        for marker, keep_tail in _POST_PROCESSING_V2_RULES:
            if marker in sent:
                pieces = sent.split(marker)
                sent = pieces[-1] if keep_tail else pieces[0]

        for char in ("[", "]", "(", ")", "\n"):
            sent = sent.replace(char, "")
        res.append(sent.strip())
    return res


def request_Gemini(dataset,prompt,value_temp=0.1,model_id="gemini-1.5-flash",sleep_seconds=4):
    """Ask a Gemini model to judge two translations of each source sentence.

    For every item a coin is flipped to decide which translation is shown
    first, so that the judge cannot learn a fixed position for either system.

    Args:
        dataset: iterable of mappings providing the keys ``'English'``,
            ``'translation_reference'`` and ``'translation_model'``.
        prompt: format string with the placeholders ``{original_sent}``,
            ``{first_sent}`` and ``{second_sent}``.
        value_temp: sampling temperature passed to the generation config.
        model_id: name of the Gemini model to query.
        sleep_seconds: pause before each request (simple rate limiting).
            Defaults to the previously hard-coded 4 seconds.

    Returns:
        tuple: ``(res, res_flip)``. ``res[i]`` is the judge's text for item i,
        or ``""`` when the response carried no readable text.
        ``res_flip[i]`` is True when ``translation_reference`` was shown first
        and False when ``translation_model`` was shown first. Both lists have
        one entry per item, so they stay aligned with ``dataset``.

    Note:
        The presentation order comes from the unseeded ``random`` module, so
        it differs between runs unless the caller seeds ``random`` first.
    """
    res=[]
    res_flip=[]
    # ref:https://ai.google.dev/gemini-api/docs/quickstart?lang=python
    # The key is read from the environment so that it never has to be written
    # into the notebook; an unset variable falls back to the empty string.
    genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", ""))

    #safety check:https://ai.google.dev/gemini-api/docs/safety-settings
    # Blocking is disabled for every category listed here.
    safety_settings = {HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
                        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                        }
    model = genai.GenerativeModel(model_name=model_id,
                                  generation_config = genai.GenerationConfig(temperature=value_temp,),
                                  safety_settings=safety_settings,
                                 )

    for index, item in enumerate(tqdm(dataset)):

        if sleep_seconds:
            time.sleep(sleep_seconds) # wait, and then move to the next request

        # form the prompt; the coin decides which translation comes first
        result_flip_coin = flip_coin()
        if result_flip_coin:
            first_sent, second_sent = item['translation_reference'], item['translation_model']
        else:
            first_sent, second_sent = item['translation_model'], item['translation_reference']
        final_prompt = prompt.format(original_sent=item['English'],
                                     first_sent=first_sent,
                                     second_sent=second_sent)

        response = model.generate_content(final_prompt)
        try:
            final_response = response.text
        except ValueError as err:
            # `.text` raises when the response holds no usable text part
            # (e.g. a blocked or empty candidate). Keep a placeholder instead
            # of aborting and losing every answer collected so far.
            print("Warning: no text returned for item {}: {}".format(index, err))
            final_response = ""
        res.append(final_response)
        res_flip.append(result_flip_coin)

    return res,res_flip

def transform_digital_result_old(data):
    """Map judge verdict strings to 0/1 labels, ignoring presentation order.

    A verdict containing ``"Translation_2 \\n"`` becomes 1; otherwise one
    containing ``"Translation_1 \\n"`` becomes 0. Verdicts containing neither
    marker are skipped, so the result may be shorter than ``data``.

    Args:
        data: iterable of verdict strings.

    Returns:
        list[int]: the labels of the verdicts that matched a marker.
    """
    # Checked in this order: "Translation_2" wins when both markers occur.
    marker_to_label = (("Translation_2 \n", 1), ("Translation_1 \n", 0))

    labels = []
    for verdict in tqdm(data):
        for marker, label in marker_to_label:
            if marker in verdict:
                labels.append(label)
                break

    return labels

def transform_digital_result(data,res_flip):
    '''
    Convert judge verdicts into 0/1 labels, undoing the random presentation order.

    if original_translation is better, annotate as 0
    if PPO_translation is better, annotate as 1

    ``res_flip[i]`` being truthy means the judge's "Translation_2" choice maps
    to label 1 and "Translation_1" to label 0; when it is falsy the mapping is
    reversed. Verdicts containing neither ``"Translation_2 \\n"`` nor
    ``"Translation_1 \\n"`` are skipped, so the result may be shorter than
    ``data``.
    '''
    labels = []
    for verdict, flipped in tqdm(zip(data, res_flip)):
        if "Translation_2 \n" in verdict:
            chose_second = True
        elif "Translation_1 \n" in verdict:
            chose_second = False
        else:
            # No recognizable verdict: nothing is recorded for this item.
            continue

        # Label is 1 exactly when the chosen slot agrees with the flip flag.
        labels.append(1 if chose_second == bool(flipped) else 0)

    return labels

def calculate_new_model_win_rate(result):
    """Print and return the share of 1-labels in ``result``.

    Args:
        result: 0/1 labels, either an object exposing ``to_list()`` (such as a
            pandas Series) or any plain iterable such as the list returned by
            ``transform_digital_result``.

    Returns:
        float: ``sum(labels) / len(labels)``, or NaN when there are no labels
        (previously this case raised ZeroDivisionError).
    """
    labels = result.to_list() if hasattr(result, "to_list") else list(result)
    win_rate = sum(labels) / len(labels) if labels else float("nan")

    print("The win rate of our model is: {:.4f}".format(win_rate))
    return win_rate

In [23]:
def evaluate_chrf(prediction,reference):
    '''
    Score translations with the chrF metric loaded through ``evaluate``.

    Args:
        prediction: passed to the metric as ``predictions``.
        reference: passed to the metric as ``references``.

    Returns:
        The ``'score'`` entry of the metric's result.
    '''
    metric = evaluate.load("chrf") # chrF
    return metric.compute(predictions=prediction, references=reference)['score']

def evaluate_chrf_plus_plus(prediction,reference):
    '''
    Score translations with chrF++ (the chrF metric with ``word_order=2``).

    chrF+ (word_order=1) and chrF++ (word_order=2) produce scores that
    correlate better with human judgements than plain chrF (word_order=0).
    https://huggingface.co/spaces/evaluate-metric/chrf

    Args:
        prediction: passed to the metric as ``predictions``.
        reference: passed to the metric as ``references``.

    Returns:
        The ``'score'`` entry of the metric's result.
    '''
    metric = evaluate.load("chrf") # chrF++ is selected via word_order below
    outcome = metric.compute(
        predictions=prediction,
        references=reference,
        word_order=2,
    )
    return outcome['score']

def evaluate_comet(lst_prediction,lst_reference,lst_source):
    '''
    Score translations with the COMET metric loaded through ``evaluate``.

    Args:
        lst_prediction: passed to the metric as ``predictions``.
        lst_reference: passed to the metric as ``references``.
        lst_source: passed to the metric as ``sources``.

    Returns:
        The ``'mean_score'`` entry of the metric's result.
    '''
    metric = evaluate.load('comet')
    outcome = metric.compute(
        predictions=lst_prediction,
        references=lst_reference,
        sources=lst_source,
    )
    return outcome['mean_score']

def calculate_cosine_similarity(reference_embedding,generated_embedding):
    """Cosine similarity of two embedding vectors, clipped below at zero.

    Args:
        reference_embedding: 1-D sequence/array of numbers.
        generated_embedding: 1-D sequence/array of numbers of the same length.

    Returns:
        float: the cosine similarity when it is positive, otherwise 0.0.
        0.0 is also returned when either vector has zero norm, where the
        cosine is undefined (the division would otherwise produce NaN).
    """
    # Imported here because numpy is not among the notebook's top-level imports.
    import numpy as np

    norm_product = np.linalg.norm(generated_embedding) * np.linalg.norm(reference_embedding)
    if norm_product == 0:
        return 0.0

    similarity_score = np.dot(generated_embedding, reference_embedding) / norm_product
    # Ensure non-negative score
    return max(float(similarity_score), 0.0)

def evaluate_bleu(generated_text: list, reference_text: list, is_japanese: bool = False) -> float:
    """
    Compute a corpus-level BLEU score with the ``sacrebleu`` metric from ``evaluate``.

    Parameters:
    - generated_text (list): The system outputs, forwarded to the metric as ``predictions``
      (the notebook passes a list with one string per segment).
    - reference_text (list): The references, forwarded to the metric as ``references``
      (the notebook passes one list of reference strings per segment).
    - is_japanese (bool, optional): When True, the ``"ja-mecab"`` tokenizer is requested and
      lowercasing is enabled; otherwise the metric's default tokenizer is used without lowercasing.

    Returns:
    - float: The ``"score"`` entry of the metric result. sacrebleu reports BLEU on a
      0 to 100 scale, not 0 to 1.
    """
    sacrebleu = evaluate.load("sacrebleu")

    # Default arguments: case-sensitive scoring with the metric's default tokenizer.
    bleu_args = {"predictions": generated_text, "references": reference_text, "lowercase": False}

    if is_japanese:
        # Japanese-specific tokenization, scored case-insensitively.
        bleu_args.update(tokenize="ja-mecab", lowercase=True)

    return sacrebleu.compute(**bleu_args)["score"]

In [7]:
# Read the pickled gold data and convert it into padded reference lists
# (one list of reference strings per source sentence).
path_golden = '../../data/MixMT-2022/valid_human_generated.pkl'
data_golden = load_pkl_data(path_golden)
gold_label = transform_unified_golden_label(data_golden)

In [92]:
# Load the reference (original model) completions and clean them up.
path_reference = './Meta-Llama-3.1-8B-Instruct-temp0.1.csv'
dataset_reference = pd.read_csv(path_reference)

# The completion column carries the CSV's file name without the ".csv" suffix.
column_name_reference = os.path.basename(path_reference).split(".csv")[0]
reference_completions = dataset_reference[column_name_reference].to_list()

# post_processing_v2 is the extended cleaner; post_processing is the older variant.
res_reference = post_processing_v2(reference_completions)

In [93]:
# Load the fine-tuned model's completions and clean them up.
path_model = './llama3.1_merge-temp0.1.csv'
dataset_model = pd.read_csv(path_model)

# The completion column carries the CSV's file name without the ".csv" suffix.
column_name_model = os.path.basename(path_model).split(".csv")[0]
model_completions = dataset_model[column_name_model].to_list()

# post_processing_v2 is the extended cleaner; post_processing is the older variant.
res_model = post_processing_v2(model_completions)

In [None]:
# Report BLEU, chrF and chrF++ for both systems against the gold references.
# The chrF metrics were previously printed under the label "CRF"/"CRF++";
# the labels now match the metrics that are actually computed.
for metric_name, metric_fn in (
    ("BLEU", evaluate_bleu),
    ("chrF", evaluate_chrf),
    ("chrF++", evaluate_chrf_plus_plus),
):
    print("######")
    for system_name, system_output in (
        ("original-model", res_reference),
        ("ppo-model", res_model),
    ):
        print("The {} score of {} translation is: {:.2f}".format(
            metric_name, system_name, metric_fn(system_output, gold_label)))

In [None]:
import warnings
# Silence every warning from here on (this affects the whole kernel session).
warnings.filterwarnings("ignore")

# evaluate_comet is given one reference string per source sentence, so only
# the first gold translation of each entry is used.
lst_reference = [references[0] for references in gold_label]

lst_source = dataset_model['English'].to_list()

comet_original = evaluate_comet(res_reference, lst_reference, lst_source)
print("The COMET score of original-model translation is: {:.2f}".format(comet_original))

comet_ppo = evaluate_comet(res_model, lst_reference, lst_source)
print("The COMET score of ppo-model translation is: {:.2f}".format(comet_ppo))