In [1]:
#!/usr/bin/env python
# coding: utf-8

import torch
import random
from transformers import HfArgumentParser, Seq2SeqTrainingArguments,EarlyStoppingCallback

import logging

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from datasets import load_dataset, concatenate_datasets,Value
import numpy as np
from typing import Union, Optional
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset, AutoModel
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    #glue_compute_metrics,
    glue_output_modes,
    glue_tasks_num_labels,
    set_seed,GenerationConfig
)
from arguments import ModelArguments, DataArguments
import wandb
from nltk.tokenize import sent_tokenize
import nltk
from evaluate import load

nltk.download("punkt")
logger = logging.getLogger(__name__)

import pathlib
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from evaluate import load
# from utils import get_dataset
import numpy as np
from peft import PeftModel    
import logging
import os
from torch.utils.data import DataLoader 
from tqdm import tqdm
from typing import List


from llama import Dialog, Llama
import fire

2024-04-26 17:50:58.649229: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-26 17:50:58.649632: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-26 17:50:58.715261: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-04-26 17:50:58.861551: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[nltk_data] Downloading package punkt to
[nltk_data] 

In [2]:
import os
import re
from bisect import bisect_left
from collections import defaultdict
import pandas as pd
import numpy as np
import math
from datasets import load_dataset, load_from_disk
import json
import random
import re
from evaluate import load

  



# Instruction headers ("prompt heads") placed at the very start of every prompt.
# get_dataset() picks one of the four depending on whether few-shot examples
# and/or the explicit error list are included.
# NOTE(review): none of the heads ends with whitespace or a newline, and
# generate_prompt() appends the first "Erroneous sentence:" line directly
# after the head, so the two run together ("...sentence.Erroneous sentence:").
# Presumably any fine-tuned checkpoint saw the same format -- confirm before
# changing it.

# Few-shot examples, with the explicit error list.
few_shots_explicit = '''You are given a few examples of erroneous sentences and their corrections. For every example, you are given the explicit errors and their corrections. You are also given an erroneous input sentence and its correction, followed by the explicit errors in this sentence. Output a simple explanation for each error in the erroneous sentence.'''

# Few-shot examples, without the explicit error list.
few_shots_no_explicit = '''You are given a few examples of erroneous sentences and their corrections. You are also given an erroneous input sentence and its correction. Output a simple explanation for each error in the erroneous sentence.'''

# Zero-shot, with the explicit error list.
no_few_shots_explicit = '''You are given an erroneous input sentence and its correction, followed by the explicit errors in this sentence. Output a simple explanation for each error in the erroneous sentence.'''

# Zero-shot, without the explicit error list.
no_few_shots_no_explicit = '''You are given an erroneous input sentence and its correction. Output a simple explanation for each error in the erroneous sentence.'''



def print_trainable_parameters(model):
    """
    Print how many of the model's parameters are trainable.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where ``param`` provides ``numel()`` and
            ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Prints one line with the trainable count, the total count and the
    trainable percentage. A model without any parameters is reported as 0%
    instead of raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count

    # Guard the ratio: a parameter-less model would otherwise divide by zero.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

def put_explicit_error(example):
    """
    Render the corrections of ``example`` as a numbered "Errors:" section.

    Args:
        example: mapping with a ``'corrections'`` sequence; every entry has
            an ``'error'`` and a ``'correct'`` string.

    Returns:
        A string starting with ``"Errors:\\n"`` followed by one
        ``"<n>. Error: <error>, Correction: <correct>\\n"`` line per
        correction (numbered from 1).
    """
    def _shown(token):
        # A bare '-' is rendered as a single space
        # (presumably it marks an empty side of an insertion/deletion -- confirm).
        return ' ' if token == '-' else token

    pieces = ["Errors:\n"]
    for number, correction in enumerate(example['corrections'], start=1):
        fixed = _shown(correction['correct'])
        wrong = _shown(correction['error'])
        pieces.append(f"{number}. Error: {wrong}, Correction: {fixed}\n")

    return ''.join(pieces)


def generate_prompt(example, prompt_head = None,num_shots=0, explicit_errors = 0,train = 1, dataset = None, field = None):
    """
    Build the prompt (and gold label text) for one example and store both on it.

    Args:
        example: mutable mapping with ``'incorrect_sentence'``,
            ``'correct_sentence'`` and ``'corrections'`` (each correction has
            an ``'explanation'``; ``'error'``/``'correct'`` are also needed
            when ``explicit_errors`` is set).
        prompt_head: instruction text placed at the start of the prompt.
            ``None`` is treated as an empty head.
        num_shots: number of demonstrations to prepend (0 = zero-shot).
        explicit_errors: when truthy, add the "Errors:" section produced by
            ``put_explicit_error`` to every demonstration and to the target.
        train: when truthy, the gold explanations are appended to the prompt.
        dataset: pool the demonstrations are drawn from; must support
            ``len()`` and ``select(indices)``. Required when ``num_shots`` > 0.
        field: key under which the prompt is stored (``None`` -> ``'prompt'``).

    Returns:
        The same ``example`` with ``example[field]`` set to the stripped
        prompt and ``example['label']`` set to the numbered gold explanations.

    Raises:
        ValueError: if ``num_shots`` > 0 but ``dataset`` is missing or empty.
    """
    # NOTE(review): the head is concatenated with no separator, so a head that
    # does not end in whitespace runs straight into "Erroneous sentence:".
    # Left as is on purpose; changing it changes the prompt format.
    full_prompt = prompt_head or ''
    if field is None:
        field = 'prompt'

    if num_shots:
        if dataset is None or len(dataset) == 0:
            raise ValueError("num_shots > 0 requires a non-empty `dataset` to draw demonstrations from")

        # Sample WITHOUT replacement so one prompt never repeats a demonstration,
        # and draw one spare so the target example can be dropped if it is
        # picked: otherwise its gold explanations would leak into its own prompt.
        pool_size = min(num_shots + 1, len(dataset))
        idx = np.random.choice(len(dataset), size=pool_size, replace=False)
        samples = [
            sample for sample in dataset.select(idx)
            if sample['incorrect_sentence'] != example['incorrect_sentence']
        ][:num_shots]

        for sample in samples:
            full_prompt += f"Erroneous sentence: {sample['incorrect_sentence']}\n"
            full_prompt += f"Correct sentence: {sample['correct_sentence']}\n"
            if explicit_errors:
                full_prompt += put_explicit_error(sample)

            full_prompt += "Explanations:\n"
            for i, correction in enumerate(sample['corrections']):
                full_prompt += f"{i+1}. {correction['explanation']}\n"

    full_prompt += f"Erroneous sentence: {example['incorrect_sentence']}\n"
    full_prompt += f"Correct sentence: {example['correct_sentence']}\n"

    if explicit_errors:
        full_prompt += put_explicit_error(example)

    full_prompt += "Explanations:\n"

    # Gold answer: one numbered explanation per correction.
    labels = ''.join(
        f"{i+1}. {correction['explanation']}\n"
        for i, correction in enumerate(example['corrections'])
    )
    if train:
        full_prompt += labels

    example[field] = full_prompt.strip()
    example['label'] = labels

    return example





def get_dataset(dataset_path, split='train', field='prompt', num_shots=0,explicit_errors = 0):
    """
    Load one split of a dataset and attach a prompt to every example.

    Args:
        dataset_path: identifier/path handed to ``load_dataset``.
        split: name of the split to use; the gold explanations are appended
            to the prompt only when this is ``'train'``.
        field: column in which the generated prompt is stored.
        num_shots: number of few-shot demonstrations per prompt; they are
            drawn from the same split that is being mapped.
        explicit_errors: when truthy, prompts include the explicit error list.

    Returns:
        The mapped split, with the columns written by ``generate_prompt``.
    """
    split_data = load_dataset(dataset_path)[split]

    # (uses few-shot examples?, lists explicit errors?) -> instruction header
    head_by_setting = {
        (True, True): few_shots_explicit,
        (True, False): few_shots_no_explicit,
        (False, True): no_few_shots_explicit,
        (False, False): no_few_shots_no_explicit,
    }
    prompt_head = head_by_setting[(bool(num_shots), bool(explicit_errors))]

    prompt_kwargs = {
        "field": field,
        "prompt_head": prompt_head,
        "train": split == 'train',
        'num_shots': num_shots,
        'dataset': split_data,
        'explicit_errors': explicit_errors,
    }
    return split_data.map(generate_prompt, fn_kwargs=prompt_kwargs)



In [3]:
# configs = [
#     { 'model_name': 'cognitivecomputations/dolphin-2.8-mistral-7b-v02',
#      'chkpt': "/l/users/abdelrahman.sadallah/UWFE-Mixtral/cognitivecomputations/dolphin-2.8-mistral-7b-v02/best",
#      'shots': 0,
#     'explicit_errors': 0
#     },
#     { 'model_name': 'mistralai/Mistral-7B-v0.1',
#      'chkpt': "/l/users/abdelrahman.sadallah/UWFE-mistral-explicit-errors/mistralai/Mistral-7B-v0.1/best",
#      'shots': 0,
#     'explicit_errors': 1
#     },
#     { 'model_name': 'mistralai/Mistral-7B-v0.1',
#      'chkpt': "/l/users/abdelrahman.sadallah/UWFE-Mixtral/mistralai/Mistral-7B-v0.1/best/",
#      'shots': 0,
#     'explicit_errors': 0
#     },
#     { 'model_name': 'mistralai/Mixtral-8x7B-v0.1',
#      'chkpt': "/l/users/abdelrahman.sadallah/UWFE-mixtral-explicit-errors/mistralai/Mixtral-8x7B-v0.1/checkpoint-18500/",
#      'shots': 0,
#     'explicit_errors': 1
#     },
#     { 'model_name': 'mistralai/Mistral-7B-v0.1',
#         'chkpt': "",
#         'shots': 5,
#         'explicit_errors': 0,
#         'base': True
#     },
#     { 'model_name': 'mistralai/Mistral-7B-v0.1',
#         'chkpt': "",
#         'shots': 5,
#         'explicit_errors': 1,
#         'base': True
#     },
#     { 'model_name': 'mistralai/Mistral-7B-v0.1',
#         'chkpt': "",
#         'shots': 0,
#         'explicit_errors': 0,
#         'base': True
#     },
#     { 'model_name': 'mistralai/Mistral-7B-v0.1',
#         'chkpt': "",
#         'shots': 0,
#         'explicit_errors': 1,
#         'base': True
#     },
#     { 'model_name': 'Meta-Llama-3-8B',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B',
#         'shots': 0,
#         'explicit_errors': 0,
#     },
#     { 'model_name': 'Meta-Llama-3-8B',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B',
#         'shots': 0,
#         'explicit_errors': 1,
#     },
#     { 'model_name': 'Meta-Llama-3-8B',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B',
#         'shots': 5,
#         'explicit_errors': 0,
#     },

#     { 'model_name': 'Meta-Llama-3-8B',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B',
#         'shots': 5,
#         'explicit_errors': 1,
#     },
#     { 'model_name': 'Meta-Llama-3-8B-Instruct',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B-Instruct',
#         'shots': 0,
#         'explicit_errors': 0,
#     },
#     { 'model_name': 'Meta-Llama-3-8B-Instruct',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B-Instruct',
#         'shots': 0,
#         'explicit_errors': 1,
#     },
#     { 'model_name': 'Meta-Llama-3-8B-Instruct',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B-Instruct',
#         'shots': 5,
#         'explicit_errors': 0,
#     },

#     { 'model_name': 'Meta-Llama-3-8B-Instruct',
#         'chkpt': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B-Instruct',
#         'shots': 5,
#         'explicit_errors': 1,
#     },
#     { 'model_name': 'chatgpt',
#         'shots': 0,
#         'explicit_errors': 0,
#      },
#     { 'model_name': 'chatgpt',
#         'shots': 0,
#         'explicit_errors': 1,
#      },
#     { 'model_name': 'chatgpt',
#         'shots': 5,
#         'explicit_errors': 0,
#      },
#     { 'model_name': 'chatgpt',
#         'shots': 5,
#         'explicit_errors': 1,
#      }


    
# ]


# import random

# ids = random.sample(range(0, 100), 20)
# for i,c in enumerate(configs):
#     configs[i]['id'] = str(ids[i])

# random.shuffle(configs)

# import json
# with open('evaluation_sheet.json','w') as f:
#      json.dump(configs, f)



In [4]:
# Rows of the test split used for the human evaluation; the same ten rows are
# selected for every prompt setting.
_HUMAN_EVAL_ROWS = [0,7,77,54,15,48,55,100,97,31]


def _human_eval_subset(num_shots, explicit_errors):
    """Build test-split prompts for one setting and keep only the human-eval rows."""
    prompts = get_dataset(
        dataset_path = "boda/kaneko_data",
        split='test',
        field='prompt',
        num_shots=num_shots,
        explicit_errors = explicit_errors)
    return prompts.select(_HUMAN_EVAL_ROWS)


# Built in this fixed order: zero-shot first, then five-shot, each without and
# then with the explicit error list.
no_shot_no_explicit_val_dataset = _human_eval_subset(0, 0)
no_shot_explicit_val_dataset = _human_eval_subset(0, 1)
five_shot_no_explicit_val_dataset = _human_eval_subset(5, 0)
five_shot_explicit_val_dataset = _human_eval_subset(5, 1)




In [35]:
import json
from os import walk

# Evaluation configurations, read from the sheet stored next to the notebook.
configs = []
with open('evaluation_sheet.json') as json_file:
   configs = json.load(json_file)

# Ids that already have an output file anywhere under human_eval/: the id is
# the file name up to its first '.'.
done = [
    fn.split('.')[0]
    for _dirpath, _dirnames, filenames in walk('human_eval')
    for fn in filenames
]

In [6]:
def write_output(filename,predicitons, labels, Errors, corrects ):
    """
    Write predictions next to their references as a plain-text report.

    Args:
        filename: path of the report; missing parent directories are created.
        predicitons: model outputs, one per example.
        labels: gold explanations, aligned with ``predicitons``.
        Errors: erroneous input sentences, aligned with ``predicitons``.
        corrects: corrected sentences, aligned with ``predicitons``.

    For every prediction the file gets an ``Error_sentence``,
    ``Correct_sentence``, ``Prediction`` and ``Label`` entry followed by a
    line of 30 dashes. The file is always written as UTF-8 so generated text
    with non-ASCII characters cannot fail on platforms with another default.

    Raises:
        IndexError: if one of the aligned sequences is shorter than
            ``predicitons``.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        # e.g. "human_eval/" may not exist yet on a fresh checkout.
        os.makedirs(parent_dir, exist_ok=True)

    with open(filename, 'w', encoding='utf-8') as f:
        for i, prediction in enumerate(predicitons):
            f.write(f"Error_sentence: {Errors[i]}\n\n")
            f.write(f"Correct_sentence: {corrects[i]}\n\n")
            f.write(f"Prediction: {prediction}\n\n")
            f.write(f"Label: {labels[i]}\n\n")
            f.write('-' * 30)
            f.write('\n\n\n')
            
        

In [18]:
from openai import OpenAI, AsyncOpenAI
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
from tqdm import tqdm   

def chatgpt_inf(val_dataset,c):
    """
    Query the OpenAI chat model for every example and write the report.

    Args:
        val_dataset: iterable of examples with ``'prompt'``, ``'label'``,
            ``'incorrect_sentence'`` and ``'correct_sentence'``.
        c: configuration dict; only ``c['id']`` is used, to name the output
            file ``human_eval/<id>.txt``.

    Each prompt is sent as a single user message through the module-level
    ``client``. Responses are lower-cased before being stored. An example
    whose request fails is reported and left out of the output file.
    """
    run_id = c['id']
    model_name = "gpt-3.5-turbo-0125"
    outputs = []
    labels = []
    error_sentences = []
    correct_sentences = []
    for example in tqdm(val_dataset):
        prompt = example['prompt']
        try:
            completion = client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", "content": prompt}],
            )
            response = completion.choices[0].message.content.lower()
        except Exception as exc:
            # Best effort: keep going, but say what failed. `Exception` (not a
            # bare except) so Ctrl-C / kernel interrupt still stops the run.
            print(f"Error: {type(exc).__name__}: {exc}")
            continue

        # Append all four together so the lists can never get out of step.
        outputs.append(response)
        labels.append(example["label"])
        error_sentences.append(example["incorrect_sentence"])
        correct_sentences.append(example["correct_sentence"])

    filename = f"human_eval/{run_id}.txt"
    assert len(outputs) == len(labels) == len(error_sentences) == len(correct_sentences)

    write_output(filename,outputs, labels, error_sentences, correct_sentences)

In [8]:
def llama_inf(val_dataloader,c):
    """
    Run a Llama 3 checkpoint over the batches and write the report.

    Args:
        val_dataloader: iterable of batches; each batch is a list of examples
            with ``'prompt'``, ``'label'``, ``'incorrect_sentence'`` and
            ``'correct_sentence'``. Batches must not exceed the
            ``max_batch_size`` of 2 the generator is built with.
        c: configuration dict with ``'model_name'`` (``'Meta-Llama-3-8B'`` or
            ``'Meta-Llama-3-8B-Instruct'``), ``'chkpt'`` (checkpoint
            directory) and ``'id'`` (names ``human_eval/<id>.txt``).

    The Instruct model is queried through ``chat_completion`` with one user
    message per prompt; the base model through ``text_completion``.

    Raises:
        ValueError: if ``c['model_name']`` is not one of the supported names.
    """
    import os

    tokenizer_paths = {
        'Meta-Llama-3-8B-Instruct': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B-Instruct/tokenizer.model',
        'Meta-Llama-3-8B': '/l/users/abdelrahman.sadallah/llama3/Meta-Llama-3-8B/tokenizer.model',
    }
    if c['model_name'] not in tokenizer_paths:
        # Fail early with a clear message instead of hitting an unbound
        # `tokenizer_path` when the generator is built.
        raise ValueError(
            f"Unsupported Llama model {c['model_name']!r}; expected one of {sorted(tokenizer_paths)}"
        )
    tokenizer_path = tokenizer_paths[c['model_name']]
    # Decided once, so the completion call and the way its result is unpacked
    # can never disagree.
    is_instruct = c['model_name'] == 'Meta-Llama-3-8B-Instruct'

    # Single-process settings read by the distributed setup in Llama.build.
    os.environ['RANK'] = '0'
    os.environ['WORLD_SIZE'] = '1'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    generator = Llama.build(
        ckpt_dir=c['chkpt'],
        tokenizer_path=tokenizer_path,
        max_seq_len=2048,
        max_batch_size=2,
        model_parallel_size = 1
        )

    predicitons = []
    labels = []
    error_sentences = []
    correct_sentences = []
    for batch in tqdm(val_dataloader):
        prompts = [x['prompt'] for x in batch]
        for ex in batch:
            labels.append(ex['label'])
            error_sentences.append(ex['incorrect_sentence'])
            correct_sentences.append(ex['correct_sentence'])

        if is_instruct:
            # One single-turn dialog per prompt.
            dialogs = [[{"role": "user", "content": p}] for p in prompts]
            model_outputs = generator.chat_completion(
                dialogs,
                max_gen_len=256,
                temperature=0.0001,
            )
            # Chat results wrap the text in a message dict.
            predicitons.extend(z['generation']['content'] for z in model_outputs)
        else:
            model_outputs = generator.text_completion(
                prompts,
                max_gen_len=256,
                temperature=0.0001,
            )
            predicitons.extend(z['generation'] for z in model_outputs)

    filename = f"human_eval/{c['id']}.txt"
    assert len(predicitons) == len(labels) == len(error_sentences) == len(correct_sentences)
    write_output(filename,predicitons, labels, error_sentences, correct_sentences)

In [33]:

def inference(prompts, tokenizer, model):
    """
    Generate a continuation for every prompt in a batch.

    Args:
        prompts: list of prompt strings.
        tokenizer: tokenizer used to encode the (padded) batch and to decode
            the generated ids; its ``eos_token_id`` is used as pad id during
            generation.
        model: model exposing ``device`` and ``generate``.

    Returns:
        List of decoded strings, one per prompt, holding only the tokens
        produced after the padded prompt width, with special tokens removed.
    """
    batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

    generation_kwargs = dict(
        max_new_tokens=512,
        do_sample=True,
        temperature=0.1,
        repetition_penalty=5.0,
        pad_token_id=tokenizer.eos_token_id,
        early_stopping=True,
        num_beams=3,
    )
    with torch.no_grad():
        generated = model.generate(**batch, **generation_kwargs)

    # Everything up to the padded prompt width is the echoed input.
    prompt_width = batch.input_ids.shape[1]
    continuation = generated[:, prompt_width:]
    return tokenizer.batch_decode(continuation, skip_special_tokens=True)
        

def inf(val_dataloader,c):
    """
    Run a Hugging Face causal LM (optionally with a PEFT adapter) and write the report.

    Args:
        val_dataloader: iterable of batches; each batch is a list of examples
            with ``'prompt'``, ``'label'``, ``'incorrect_sentence'`` and
            ``'correct_sentence'``.
        c: configuration dict with ``'model_name'`` (base model to load),
            optional ``'chkpt'`` (adapter checkpoint; empty or missing means
            the plain base model) and ``'id'`` (names ``human_eval/<id>.txt``).

    The base model is loaded 4-bit (NF4) quantized with bfloat16 compute.
    """
    bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=False
        )
    # NOTE(review): trust_remote_code=True executes code shipped with the model
    # repository -- only use it with model sources you trust.
    model = AutoModelForCausalLM.from_pretrained(
        c['model_name'],
        quantization_config=bnb_config,
        trust_remote_code=True,
        use_flash_attention_2=0,
    )

    # Configs without a 'chkpt' key are treated like an empty checkpoint.
    adapter_checkpoint = c.get('chkpt')
    if adapter_checkpoint:
        print(f"Loading adapter from {adapter_checkpoint} on top of {c['model_name']}")
        model = PeftModel.from_pretrained(model, adapter_checkpoint)
    else:
        print(f"Loading Base Model {c['model_name']}")

    tokenizer = AutoTokenizer.from_pretrained(c['model_name'])
    model = model.eval()
    # Define PAD Token = BOS Token
    tokenizer.pad_token = tokenizer.bos_token
    # inference() drops the first `input_ids.shape[1]` tokens of every output
    # row, which only isolates the continuation when shorter prompts are
    # padded on the left; make that explicit instead of relying on the
    # tokenizer's default.
    tokenizer.padding_side = "left"
    model.config.pad_token_id = model.config.bos_token_id

    predicitons = []
    labels = []
    error_sentences = []
    correct_sentences = []

    for batch in tqdm(val_dataloader):
        prompts = [x['prompt'] for x in batch]
        for ex in batch:
            labels.append(ex['label'])
            error_sentences.append(ex['incorrect_sentence'])
            correct_sentences.append(ex['correct_sentence'])

        output_text = inference(prompts=prompts, tokenizer=tokenizer, model=model)
        predicitons.extend(output_text)

    filename = f"human_eval/{c['id']}.txt"

    assert len(predicitons) == len(labels) == len(error_sentences) == len(correct_sentences)
    write_output(filename,predicitons, labels, error_sentences, correct_sentences)

In [10]:
## Move Mixtral to the end of the run order.
# Removing/appending while iterating over the same list skips the element that
# follows each moved entry. list.sort is stable and in place: every non-Mixtral
# config (key False) keeps its relative order and all Mixtral configs
# (key True) end up last. Re-running the cell is a no-op.
configs.sort(key=lambda cfg: cfg['model_name'] == "mistralai/Mixtral-8x7B-v0.1")


In [11]:
# Show the current run order, one model name per line.
for model_name in (cfg['model_name'] for cfg in configs):
    print(model_name)

cognitivecomputations/dolphin-2.8-mistral-7b-v02
Meta-Llama-3-8B
mistralai/Mistral-7B-v0.1
Meta-Llama-3-8B
chatgpt
chatgpt
mistralai/Mistral-7B-v0.1
Meta-Llama-3-8B-Instruct
Meta-Llama-3-8B-Instruct
mistralai/Mistral-7B-v0.1
mistralai/Mistral-7B-v0.1
Meta-Llama-3-8B-Instruct
chatgpt
mistralai/Mistral-7B-v0.1
Meta-Llama-3-8B-Instruct
chatgpt
Meta-Llama-3-8B
Meta-Llama-3-8B
mistralai/Mistral-7B-v0.1
mistralai/Mixtral-8x7B-v0.1


In [36]:
import os
# Single-process settings for the distributed setup used by the Llama runs.
os.environ['RANK'] = '0'
os.environ['WORLD_SIZE'] = '1'
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'

# (shots, explicit_errors) -> pre-built prompt dataset for that setting.
val_datasets_by_setting = {
    (0, 0): no_shot_no_explicit_val_dataset,
    (0, 1): no_shot_explicit_val_dataset,
    (5, 0): five_shot_no_explicit_val_dataset,
    (5, 1): five_shot_explicit_val_dataset,
}

for c in configs:

    # Skip configurations that already have an output file.
    if str(c['id']) in done:
        continue

    print(f'----- Generating for {c["model_name"]} -----')

    setting = (c['shots'], c['explicit_errors'])
    if setting not in val_datasets_by_setting:
        # Without this check an unknown combination would silently reuse the
        # dataset selected for the previous configuration.
        raise ValueError(
            f"No prompt dataset for shots={c['shots']!r}, explicit_errors={c['explicit_errors']!r} "
            f"(config id {c['id']!r})"
        )
    val_dataset = val_datasets_by_setting[setting]

    # Identity collate: each batch stays a plain list of example dicts.
    val_dataloader = DataLoader(val_dataset, batch_size=2,collate_fn=lambda x: x )

    ## chatgpt inference
    if c['model_name'] == 'chatgpt':
        chatgpt_inf(val_dataset,c)
    elif 'Meta' in c['model_name']:
        llama_inf(val_dataloader,c)
    ## Mistral and mixtral inference
    else:
        inf(val_dataloader,c)

    done.append(str(c['id']))

----- Generating for mistralai/Mistral-7B-v0.1 -----


`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading Base Model mistralai/Mistral-7B-v0.1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [02:21<00:00, 28.37s/it]


----- Generating for mistralai/Mistral-7B-v0.1 -----


`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading Base Model mistralai/Mistral-7B-v0.1


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [01:14<00:00, 14.97s/it]


In [13]:
# for c in configs:
#  ## chatgpt inference
#     if c['model_name'] == 'chatgpt':
#         chatgpt_inf(val_dataset,c)

In [14]:
# import os
# print(os.environ['WORLD_SIZE'])

In [15]:
done

['2',
 '42',
 '68',
 '53',
 '3',
 '69',
 '89',
 '58',
 '24',
 '87',
 '18',
 '11',
 '72',
 '62',
 '90',
 '74',
 '8',
 '7',
 '76',
 '92']

In [25]:
no_shot_explicit_val_dataset[1]['prompt']

'You are given an erroneous input sentence and its correction, followed by the explicit errors in this sentence. Output a simple explanation for each error in the erroneous sentence.Erroneous sentence: Nevertheless , there are cyber polices who roam around these social networking platforms to prevent these hoaxes spreading around .\nCorrect sentence: Nevertheless , there are cyber police who roam around these social networking platforms to prevent these hoaxes spreading around .\nErrors:\n1. Error: polices, Correction: police\nExplanations:'