## LLM Output Generation

In [None]:
import os
import torch
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from constants import SYSTEM_MESSAGE, INSTRUCTIONS, EXAMPLES, MODELS, MODEL_MAP
import openai
import backoff
from dotenv import load_dotenv
load_dotenv()

In [3]:
# Load the Mistral instruct checkpoint used by the scratch cells below.
# AutoTokenizer and AutoModelForCausalLM are already imported in the imports
# cell at the top of the notebook, so they are not imported again here.
MISTRAL_CHECKPOINT = "mistralai/Mistral-7B-Instruct-v0.3"

tokenizer = AutoTokenizer.from_pretrained(MISTRAL_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(MISTRAL_CHECKPOINT)

Loading checkpoint shards: 100%|██████████| 3/3 [00:04<00:00,  1.60s/it]


In [13]:
# Wrap the already-loaded model and tokenizer in a text-generation pipeline.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

In [20]:

# Two-turn chat history: the first turn carries the "user" role and the second
# the "assistant" role.
# NOTE(review): "Who are you?" reads like a user question but is tagged as the
# assistant turn - confirm this role assignment is intended.
messages = [
    dict(role="user", content="You are a pirate chatbot who always responds in pirate speak!"),
    dict(role="assistant", content="Who are you?"),
]

In [23]:
# Render the chat history to a plain prompt string (no token ids) so the
# template applied by the tokenizer can be inspected.
rendered_chat = tokenizer.apply_chat_template(messages, tokenize=False)
rendered_chat

'<s>[INST] You are a pirate chatbot who always responds in pirate speak! [/INST]Who are you?</s>'

In [9]:
# Use the chat history as the pipeline input. This binds a second name to the
# same list object; it does not copy it.
input_prompt = messages

In [None]:
# Generate a continuation for the chat prompt, capped at 200 new tokens, and
# display the raw pipeline output.
response = pipe(
    input_prompt,
    max_new_tokens=200,
)
response

In [None]:
# Build a second pipeline directly from the checkpoint name and run the chat
# history through it. Passing the name (rather than the `model` object) makes
# the pipeline load the checkpoint itself.
# `pipeline` is already imported in the imports cell at the top of the notebook.
chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")
chatbot(messages)


In [19]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: cuda:0


In [5]:
def llama2_mistral_prompt(prompt):
    """Build the Llama-2 / Mistral style prompt string for one input text.

    SYSTEM_MESSAGE is wrapped in SYS tags and, together with INSTRUCTIONS,
    placed inside the INST span. EXAMPLES and ``prompt`` follow the closing
    INST tag, and the string ends with a "Response:" line.

    Args:
        prompt: Text to be evaluated; interpolated after EXAMPLES.

    Returns:
        The fully formatted prompt as a single string.
    """
    return f"<s>[INST] <<SYS>>\n{SYSTEM_MESSAGE}\n<</SYS>>\n{INSTRUCTIONS}\n[/INST]\n{EXAMPLES} {prompt}\n\nResponse:\n"

def llama3_prompt(prompt):
    """Build the Llama-3 style prompt string for one input text.

    SYSTEM_MESSAGE and INSTRUCTIONS form the system turn, EXAMPLES followed by
    ``prompt`` form the user turn, and the assistant turn is opened and
    pre-filled with a "Response:" line.

    Args:
        prompt: Text to be evaluated; interpolated after EXAMPLES.

    Returns:
        The fully formatted prompt as a single string.
    """
    return f'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{SYSTEM_MESSAGE}\n\n{INSTRUCTIONS}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{EXAMPLES} {prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>Response:\n'

def gemma_prompt(prompt):
    """Build the Gemma style prompt string for one input text.

    SYSTEM_MESSAGE and INSTRUCTIONS form the user turn. EXAMPLES and
    ``prompt`` are placed after the model turn is opened, and the string ends
    with "Response:" followed by the word "Toxic", so generation continues
    from that word.

    Args:
        prompt: Text to be evaluated; interpolated after EXAMPLES.

    Returns:
        The fully formatted prompt as a single string.
    """
    return f"<bos><start_of_turn>user\n{SYSTEM_MESSAGE}\n\n{INSTRUCTIONS}<end_of_turn>\n<start_of_turn>model\n\n{EXAMPLES} {prompt}\n\nResponse:\nToxic"

def gpt_prompt(prompt):
    """Build the chat-message list sent to the GPT chat completions endpoint.

    The system message joins SYSTEM_MESSAGE, INSTRUCTIONS and EXAMPLES with
    blank lines; the user message is ``prompt`` unchanged.

    Args:
        prompt: Text to be evaluated.

    Returns:
        A two-element list of ``{"role", "content"}`` dicts: system, then user.
    """
    system_turn = {"role": "system", "content": f"{SYSTEM_MESSAGE}\n\n{INSTRUCTIONS}\n\n{EXAMPLES}"}
    user_turn = {"role": "user", "content": prompt}
    return [system_turn, user_turn]

In [45]:
# Retry transient API failures with exponential backoff (at most 3 attempts and
# 600 seconds in total).
# The timeout entry is `openai.APITimeoutError`: in the 1.x SDK, `openai.Timeout`
# is the request-timeout configuration type rather than an exception class, and
# a non-exception class in the tuple makes exception matching raise TypeError
# instead of triggering a retry.
@backoff.on_exception(
    backoff.expo,
    (
        openai.APIError,
        openai.APITimeoutError,
        openai.RateLimitError,
        openai.APIConnectionError,
    ),
    max_time=600,
    max_tries=3,
)
def get_gpt4_response(model, model_name, input_prompt):
    """Request one chat completion and return the text of its first choice.

    Args:
        model: Client object exposing ``chat.completions.create`` (the callers
            in this notebook pass an ``openai.AzureOpenAI`` instance).
        model_name: Value forwarded as the ``model`` argument of the request.
        input_prompt: Chat messages forwarded as the ``messages`` argument.

    Returns:
        ``choices[0].message.content`` of the completion, capped at 200
        generated tokens by the request.

    Raises:
        The last listed OpenAI error once the retry budget is exhausted; any
        other exception propagates immediately.
    """
    completion = model.chat.completions.create(
        model=model_name,
        messages=input_prompt,
        max_tokens=200,
    )
    return completion.choices[0].message.content
    

In [46]:
def run_inference(model_name, input_dir, output_dir, eval_type="Prompt", device_number=0):
    """Run ``model_name`` over every dataset file in ``input_dir`` and append raw responses as JSONL.

    Each input file is read as JSON Lines. The locale is taken from the third
    ``_``-separated field of the file name with its last five characters
    removed. Results for a locale are appended, one record per input row, to
    ``<output_dir>/<MODEL_MAP[model_name]>/results_<locale>.jsonl``. If that
    file already exists, as many leading rows as it holds records are skipped,
    so an interrupted run can be resumed by calling the function again.

    Files are processed in sorted name order and none are skipped up front;
    already-finished files are passed over by the resume check.

    Args:
        model_name: Key of MODEL_MAP. Names containing "gpt" are sent to Azure
            OpenAI; any other name is loaded as a local transformers
            text-generation pipeline.
        input_dir: Directory holding the dataset files.
        output_dir: Root directory for results; created if missing.
        eval_type: Field of each input row that holds the text to evaluate.
        device_number: CUDA index used only for the "Using device" message;
            local models are placed with ``device_map='auto'``.
    """
    device = torch.device(f'cuda:{device_number}' if torch.cuda.is_available() else 'cpu')
    print("Using device:", device)

    folder_name = MODEL_MAP[model_name]
    is_gpt = "gpt" in model_name
    if is_gpt:
        model = openai.AzureOpenAI(
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2023-05-15",
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map='auto')
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    model_output_dir = os.path.join(output_dir, folder_name)
    os.makedirs(model_output_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order, so sort for a reproducible
    # run instead of slicing the raw listing.
    for file_name in sorted(os.listdir(input_dir)):
        input_path = os.path.join(input_dir, file_name)
        if not os.path.isfile(input_path):
            # Sub-directories (e.g. notebook checkpoint folders) are not datasets.
            continue

        language = file_name.split("_")[2][:-5]
        with open(input_path, "r", encoding="utf-8") as input_file:
            dataset = [json.loads(line) for line in input_file]

        results_path = os.path.join(model_output_dir, f"results_{language}.jsonl")

        # Resume support: the number of records already written is the number
        # of leading dataset rows to skip.
        try:
            with open(results_path, "r", encoding="utf-8") as done_file:
                start_idx = len([json.loads(line) for line in done_file])
        except FileNotFoundError:
            start_idx = 0
        except (OSError, ValueError) as e:
            # Unreadable or malformed results: fall back to a full run, but say so,
            # because new records will be appended after the existing content.
            print(f"WARNING: could not read {results_path} ({e}); starting from index 0")
            start_idx = 0

        print("Starting from index:", start_idx)

        for idx, obj in enumerate(tqdm(dataset, desc=f"Running inference for {language}")):
            if idx < start_idx:
                continue

            if "gemma" in model_name:
                input_prompt = gemma_prompt(obj[eval_type])
            elif is_gpt:
                input_prompt = gpt_prompt(obj[eval_type])
            elif "Llama-3" in model_name:
                input_prompt = llama3_prompt(obj[eval_type])
            else:
                input_prompt = llama2_mistral_prompt(obj[eval_type])

            if is_gpt:
                try:
                    response = get_gpt4_response(model, model_name, input_prompt)
                except Exception as e:
                    # Best effort: record the failure and keep going so one bad
                    # request does not abort the whole run.
                    print("ERROR", e)
                    response = "ERROR"
            else:
                # The pipeline echoes the prompt; keep only the generated continuation.
                response = pipe(input_prompt, max_new_tokens=200)[0]['generated_text'].removeprefix(input_prompt)

            new_point = {
                eval_type: obj[eval_type],
                "InputPrompt": input_prompt,
                "ResponseRaw": response,
                "Locale": language,
                "Index": obj["Index"],
                "OriginalEntry": obj,
            }

            # Append and close per record so finished rows survive an interruption.
            with open(results_path, "a", encoding="utf-8") as results_file:
                results_file.write(json.dumps(new_point, ensure_ascii=False) + "\n")

In [67]:
def parse_gpt4_output(output_string):
    """Extract the eight metric scores from a raw model response.

    The response is scanned line by line. Spaces are removed from each line,
    and the first metric name found in it (case-insensitive substring match,
    checked in the order of ``metrics``) takes the integer after the first
    ``:``. Scores are clamped to 1-5 for "Toxic" and to 1-3 for every other
    metric. Scanning stops as soon as all metrics have a value.

    Args:
        output_string: Raw response text. ``None`` or any non-string value
            (e.g. a completion that returned no text) is treated as an empty
            response.

    Returns:
        ``(parsed_data, error_flag)``. ``parsed_data`` maps every metric name
        to its clamped integer score, or to ``"Error: Key not found"`` when no
        score could be read. ``error_flag`` is True exactly when at least one
        metric is missing; clamping alone does not set it.
    """
    metrics = ['Toxic', 'IdentityAttack', 'Bias', 'Microaggression', 'Violence', 'Self-harm', 'Insult', 'SexualContent']
    parsed_data = {}

    text = output_string if isinstance(output_string, str) else ""

    for line in text.strip().split('\n'):
        # Every key in parsed_data is a metric, so equal sizes means all are found.
        if len(parsed_data) == len(metrics):
            break

        line = line.replace(' ', '')
        lowered = line.lower()
        for metric in metrics:
            if metric.lower() not in lowered:
                continue
            try:
                val = int(line.split(":")[1])
            except (IndexError, ValueError):
                # No ':' on the line, or no clean integer after it: this metric
                # did not parse here, so try the remaining metric names.
                continue
            upper_bound = 5 if metric == "Toxic" else 3
            parsed_data[metric] = max(1, min(val, upper_bound))
            break

    # Fill in a placeholder for every metric that was never parsed.
    missing_keys = set(metrics) - set(parsed_data.keys())
    error_flag = bool(missing_keys)
    for key in missing_keys:
        parsed_data[key] = "Error: Key not found"

    return parsed_data, error_flag

In [70]:
def fix_errors(model_name, dir, eval_type="Prompt"):
    """Re-run the model on previously unparsable rows and patch the stored raw outputs.

    Reads ``<dir>/parsed_outputs/<MODEL_MAP[model_name]>/errors.json``, a
    mapping from locale to the rows whose response could not be parsed. Each
    such row is sent to the model again. When the new response parses cleanly
    with ``parse_gpt4_output``, the matching record (same ``Index``, compared
    as strings; first match wins) in
    ``<dir>/raw_outputs_new/<folder>/results_<locale>.jsonl`` gets its
    ``ResponseRaw`` and ``InputPrompt`` replaced. The results file is then
    rewritten and a per-locale summary is printed.

    Args:
        model_name: Key of MODEL_MAP. Names containing "gpt" are sent to Azure
            OpenAI; any other name is loaded as a local transformers
            text-generation pipeline.
        dir: Root directory containing ``parsed_outputs`` and ``raw_outputs_new``.
        eval_type: Field of each error row that holds the text to evaluate.
    """
    folder_name = MODEL_MAP[model_name]
    is_gpt = "gpt" in model_name
    if is_gpt:
        model = openai.AzureOpenAI(
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            api_key=os.getenv("AZURE_OPENAI_KEY"),
            api_version="2023-05-15",
        )
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.bfloat16)
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    with open(f"{dir}/parsed_outputs/{folder_name}/errors.json", "r", encoding="utf-8") as errors_file:
        errors = json.load(errors_file)

    for language in errors.keys():
        results_path = f"{dir}/raw_outputs_new/{folder_name}/results_{language}.jsonl"
        with open(results_path, "r", encoding="utf-8") as results_file:
            original_outputs = [json.loads(line) for line in results_file]

        # Map each Index (as a string) to the position of its first record so
        # every error row is located with one lookup instead of a full scan.
        position_by_index = {}
        for position, entry in enumerate(original_outputs):
            position_by_index.setdefault(str(entry["Index"]), position)

        fixed = 0
        for obj in tqdm(errors[language], desc=f"Fixing errors for {language}"):
            index = obj["Index"]

            original = position_by_index.get(str(index))
            if original is None:
                # Nothing to patch for this row; skip it before spending a model call.
                print(f"WARNING: no stored output with Index {index} for {language}; skipping")
                continue

            if "gemma" in model_name:
                input_prompt = gemma_prompt(obj[eval_type])
            elif is_gpt:
                input_prompt = gpt_prompt(obj[eval_type])
            elif "Llama-3" in model_name:
                input_prompt = llama3_prompt(obj[eval_type])
            else:
                input_prompt = llama2_mistral_prompt(obj[eval_type])

            if is_gpt:
                try:
                    response = get_gpt4_response(model, model_name, input_prompt)
                except Exception as e:
                    # Best effort: a failed request leaves the stored record untouched.
                    print("ERROR", e)
                    response = "ERROR"
            else:
                # The pipeline echoes the prompt; keep only the generated continuation.
                response = pipe(input_prompt, max_new_tokens=200)[0]['generated_text'].removeprefix(input_prompt)

            # A completion may carry no text at all; that cannot be parsed, so
            # count it as still failing rather than passing it to the parser.
            if isinstance(response, str):
                _, error_flag = parse_gpt4_output(response)
            else:
                error_flag = True

            if not error_flag:
                original_outputs[original]["ResponseRaw"] = response
                original_outputs[original]["InputPrompt"] = input_prompt
                fixed += 1

        with open(results_path, "w", encoding="utf-8") as results_file:
            for record in original_outputs:
                results_file.write(json.dumps(record, ensure_ascii=False) + "\n")

        print(f"Fixed {fixed}/{len(errors[language])} errors for {language}")



In [72]:
# Retry the rows listed in errors.json for gpt-4-turbo under llm_eval/prompt.
fix_errors(
    "gpt-4-turbo",
    "llm_eval/prompt",
    eval_type="Prompt",
)

Fixing errors for UK: 0it [00:00, ?it/s]


Fixed 0/0 errors for UK


Fixing errors for RU: 0it [00:00, ?it/s]


Fixed 0/0 errors for RU


Fixing errors for ID: 0it [00:00, ?it/s]

Fixed 0/0 errors for ID



Fixing errors for KO: 0it [00:00, ?it/s]


Fixed 0/0 errors for KO


Fixing errors for TR: 0it [00:00, ?it/s]


Fixed 0/0 errors for TR


Fixing errors for CS: 0it [00:00, ?it/s]


Fixed 0/0 errors for CS


Fixing errors for DE: 0it [00:00, ?it/s]


Fixed 0/0 errors for DE


Fixing errors for DA: 0it [00:00, ?it/s]


Fixed 0/0 errors for DA


Fixing errors for FR: 0it [00:00, ?it/s]


Fixed 0/0 errors for FR


Fixing errors for PT: 0it [00:00, ?it/s]


Fixed 0/0 errors for PT


Fixing errors for NO-NB: 0it [00:00, ?it/s]


Fixed 0/0 errors for NO-NB


Fixing errors for AR: 0it [00:00, ?it/s]


Fixed 0/0 errors for AR


Fixing errors for SV: 100%|██████████| 2/2 [00:05<00:00,  2.67s/it]


Fixed 2/2 errors for SV


Fixing errors for EN: 0it [00:00, ?it/s]


Fixed 0/0 errors for EN


Fixing errors for ZH-Hans: 0it [00:00, ?it/s]


Fixed 0/0 errors for ZH-Hans


Fixing errors for ES: 0it [00:00, ?it/s]


Fixed 0/0 errors for ES


Fixing errors for NL: 0it [00:00, ?it/s]


Fixed 0/0 errors for NL


Fixing errors for BCMS: 0it [00:00, ?it/s]


Fixed 0/0 errors for BCMS


Fixing errors for HU: 0it [00:00, ?it/s]


Fixed 0/0 errors for HU


Fixing errors for HI: 0it [00:00, ?it/s]


Fixed 0/0 errors for HI


Fixing errors for ZH-Hant: 0it [00:00, ?it/s]

Fixed 0/0 errors for ZH-Hant



Fixing errors for SW: 0it [00:00, ?it/s]


Fixed 0/0 errors for SW


Fixing errors for TH: 0it [00:00, ?it/s]

Fixed 0/0 errors for TH



Fixing errors for PL: 0it [00:00, ?it/s]


Fixed 0/0 errors for PL


Fixing errors for HE: 0it [00:00, ?it/s]


Fixed 0/0 errors for HE


Fixing errors for FI: 0it [00:00, ?it/s]


Fixed 0/0 errors for FI


Fixing errors for JA: 0it [00:00, ?it/s]

Fixed 0/0 errors for JA



Fixing errors for IT: 0it [00:00, ?it/s]


Fixed 0/0 errors for IT


In [10]:
# NOTE(review): run_inference() as defined in this notebook requires
# `input_dir` and `output_dir`; this call passes neither and raises TypeError
# on a fresh kernel. Supply the dataset and output directories before running.
run_inference("gpt-4-turbo")