In [1]:
import numpy as np
import torch
import os
import pandas as pd
import textstat

from pathlib import Path
from huggingface_hub import login

from transformers import AutoTokenizer
from datasets import load_dataset

from vllm import LLM, SamplingParams
from utils.gpu_management import reset_vllm_gpu_environment

from zeus.monitor import ZeusMonitor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"cuDNN version: {torch.backends.cudnn.version()}")

MODELS = [
    "meta-llama/Llama-3.2-3B-Instruct", 
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.3-70B-Instruct"
]

MAX_SEQ_LEN = 8192
NUM_SAMPLES = 10_000
NUM_BATCHES = 50
SAMPLES_PER_BATCH = NUM_SAMPLES / NUM_BATCHES

# See for reference: https://docs.vllm.ai/en/v0.5.5/dev/sampling_params.html
SAMPLING_PARAMS = SamplingParams(
    temperature=0.8, 
    top_p=0.95,
    min_tokens=1,  # this is key as some models may refuse to generate anything if set to 0.
    max_tokens=128,
)

NUM_GPUS = torch.cuda.device_count()

CSV_FILE_PATH = Path(f"data/simulation_data.csv")

CUDA available: False
cuDNN version: 90100


In [3]:
def add_instruction(sentence_pair, tokenizer: AutoTokenizer = None):

    message = [
        {"role": "system", "content": "You are a helpful chatbot that translates text from German to English. Only provide the translation, nothing else."},
        {"role": "user", "content": {sentence_pair['translation']['de']}}
        # {"role": "user", "content": f"Please translate the following sentence from German to English: \n\n{sentence_pair['translation']['de']}"}
    ]

    sentence_pair["input_formatted"] = tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    sentence_pair["target"] = sentence_pair["translation"]["en"]
    
    return sentence_pair


def compute_text_metrics(row):

    text = row["input_text"]

    row["flesch_reading_ease"] = textstat.flesch_reading_ease(text)
    row["smog_index"] = textstat.smog_index(text)
    row["automated_readability_index"] = textstat.automated_readability_index(text)
    row["lexical_diversity"] = len(set(text.split())) / len(text.split()) if len(text.split()) > 0 else 0
    row["syllable_count"] = textstat.syllable_count(text)
    row["complex_word_count"] = textstat.difficult_words(text)
    row["avg_word_length"] = sum(len(word) for word in text.split()) / len(text.split()) if len(text.split()) > 0 else 0
    row["sentence_length"] = len(text.split())
    row["flesch_kincaid_grade"] = textstat.flesch_kincaid_grade(text)
    row["coleman_liau_index"] = textstat.coleman_liau_index(text)
    row["dale_chall_readability_score"] = textstat.dale_chall_readability_score(text)
    row["linsear_write_formula"] = textstat.linsear_write_formula(text)
    row["text_standard"] = textstat.text_standard(text)
    row["fernandez_huerta"] = textstat.fernandez_huerta(text)
    row["szigriszt_pazos"] = textstat.szigriszt_pazos(text)
    row["gutierrez_polini"] = textstat.gutierrez_polini(text)
    row["crawford"] = textstat.crawford(text)

    try:
        row["gulpease_index"] = textstat.gulpease_index(text)
    except ZeroDivisionError:
        row["gulpease_index"] = np.nan
        
    try:
        row["osman"] = textstat.osman(text)
    except ZeroDivisionError:
        row["osman"] = np.nan

    return row    

In [4]:
dataset = load_dataset('wmt14', 'de-en', split='train')
dataset = dataset.shuffle().select(range(NUM_SAMPLES))

if CSV_FILE_PATH.exists():
    print("Loaded file")
    df = pd.read_csv(CSV_FILE_PATH)
else:
    df = pd.DataFrame()
df["input_text"] = [dataset[idx]['translation']['de'] for idx in range(len(dataset["translation"]))]

for model_name in MODELS: 

    tokenizer = AutoTokenizer.from_pretrained(model_name, padding=True, truncation=True, max_length=512)
    tokenizer.pad_token = tokenizer.eos_token
    # print(tokenizer.chat_template)
    dataset_formatted = dataset.map(lambda sentence_pair: add_instruction(sentence_pair, tokenizer))

    for batch in range(NUM_BATCHES):

        subset = dataset_formatted.select(range(
            int(SAMPLES_PER_BATCH * batch), 
            int(SAMPLES_PER_BATCH * (batch + 1))
        ))
        
        outputs = "None"

        for idx, output in enumerate(outputs): 
            df.loc[df["input_text"] == subset[idx]['translation']['de'], f"output_{model_name.replace('/', '_')}"] = output
    
        df.to_csv(CSV_FILE_PATH)

df = df.apply(compute_text_metrics, axis=1)
df.to_csv(CSV_FILE_PATH)

Loaded file


Map: 100%|███████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 6671.98 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 7131.41 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:01<00:00, 7200.27 examples/s]


In [5]:
df

Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,input_text,output_meta-llama_Llama-3.2-3B-Instruct,output_meta-llama_Llama-3.1-8B-Instruct,output_meta-llama_Llama-3.3-70B-Instruct,flesch_reading_ease,smog_index,automated_readability_index,lexical_diversity,...,coleman_liau_index,dale_chall_readability_score,linsear_write_formula,text_standard,fernandez_huerta,szigriszt_pazos,gutierrez_polini,crawford,gulpease_index,osman
0,0,0,Nach dem 11. September wurde der Sicherheit vo...,N,N,N,63.36,0.0,15.5,0.937500,...,20.26,16.86,4.5,7th and 8th grade,102.68,101.49,25.71,1.7,56.5,18.87
1,1,1,Die Kommission hat eine umfangreiche Untersuch...,o,o,o,4.47,0.0,21.6,0.875000,...,25.23,16.27,14.0,16th and 17th grade,58.52,54.55,18.06,6.0,33.4,-1.35
2,2,2,Und tatsächlich ist Werner Heuser die Ausnahme...,n,n,n,62.68,0.0,14.0,1.000000,...,14.85,19.34,10.5,8th and 9th grade,99.50,94.55,35.04,2.9,49.6,39.87
3,3,3,In einem neuen Fenster öffnet sich dann ein Ad...,e,e,e,45.76,0.0,15.9,1.000000,...,17.57,19.34,11.5,11th and 12th grade,87.50,83.56,30.48,3.7,45.5,28.49
4,4,4,Zusammensetzung der Ausschüsse: siehe Protokoll,,,,49.48,0.0,21.6,1.000000,...,27.00,19.67,2.5,7th and 8th grade,93.74,89.70,11.97,1.3,63.0,-17.09
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9995,9995,9995,"Bei einer Verhandlungspause brechen Nog , Jake...",,,,78.59,0.0,11.9,0.894737,...,11.89,17.69,10.5,11th and 12th grade,110.48,110.96,39.59,1.6,54.0,54.90
9996,9996,9996,Zudem müssen die Mitgliedstaaten unter Aufsich...,,,,49.82,0.0,18.4,1.000000,...,22.43,18.86,7.5,18th and 19th grade,91.58,88.40,22.75,3.3,41.3,9.03
9997,9997,9997,Ich möchte ihnen meine Hochachtung aussprechen...,,,,81.63,0.0,13.7,0.933333,...,15.36,18.07,6.5,5th and 6th grade,113.54,112.92,34.34,1.4,50.3,40.48
9998,9998,9998,Deshalb stellt das Paket der Änderungsanträge ...,,,,75.54,0.0,14.0,0.904762,...,12.82,17.46,11.5,7th and 8th grade,107.42,102.77,37.50,2.2,50.4,50.51
