In [1]:
import torch

from transformers import HfArgumentParser, Seq2SeqTrainingArguments,EarlyStoppingCallback

import logging

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from datasets import load_dataset, concatenate_datasets,Value
import numpy as np
from typing import Union, Optional
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset, AutoModel
from transformers import GlueDataTrainingArguments as DataTrainingArguments

from arguments import ModelArguments, DataArguments
import wandb
from nltk.tokenize import sent_tokenize
import nltk

nltk.download("punkt")
logger = logging.getLogger(__name__)
from transformers import (RobertaForMultipleChoice, RobertaTokenizer, Trainer,
                          TrainingArguments, XLMRobertaForMultipleChoice,
                          XLMRobertaTokenizer)

import pathlib
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from transformers import TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig, prepare_model_for_kbit_training


from utils import *
import numpy as np
from peft import PeftModel    
import logging
import os

# import evaluate 
from evaluate import load 
from torch.utils.data import DataLoader
from tqdm import tqdm


def split_sequence(sequence, chunk_size):
    chunks=[]
    for i in range(0, len(sequence), chunk_size):
        chunks.append(sequence[i: i + chunk_size])
    return chunks
		

def calc_results(prediction, truth, save_file, chunk_size=100):
    

    global bleu_score
    
    if (len(truth) != len(prediction)):
        print ("both files must have same number of instances")
        exit()


    truth_chunks= split_sequence(truth, chunk_size)

    truth_Egyptain=truth_chunks[0]
    truth_Emirati=truth_chunks[1]
    truth_Jordanian=truth_chunks[2]
    truth_Palestinian=truth_chunks[3]

    prediction_chunks= split_sequence(prediction, chunk_size)

    prediction_Egyptain=prediction_chunks[0]
    prediction_Emirati=prediction_chunks[1]
    prediction_Jordanian=prediction_chunks[2]
    prediction_Palestinian=prediction_chunks[3]

    ### get scores
    results_Egyptain = bleu_score.compute(predictions=prediction_Egyptain, references=truth_Egyptain)
    results_Emirati = bleu_score.compute(predictions=prediction_Emirati, references=truth_Emirati)
    results_Jordanian = bleu_score.compute(predictions=prediction_Jordanian, references=truth_Jordanian)
    results_Palestinian = bleu_score.compute(predictions=prediction_Palestinian, references=truth_Palestinian)
    overall_results = bleu_score.compute(predictions=prediction, references=truth)

    #write to a text file
    print('Scores:')
    scores = {
            'Overall': overall_results['bleu']*100,
            'Egyptain': results_Egyptain['bleu']*100,
            'Emirati': results_Emirati['bleu']*100,
            'Jordanian': results_Jordanian['bleu']*100,
            'Palestinian': results_Palestinian['bleu']*100, 
            }
    print(scores)

    with open(save_file, 'w') as score_file:
        score_file.write("Overall: %0.12f\n" % scores["Overall"])
        score_file.write("Egyptain: %0.12f\n" % scores["Egyptain"])
        score_file.write("Emirati: %0.12f\n" % scores["Emirati"])
        score_file.write("Jordanian: %0.12f\n" % scores["Jordanian"])
        score_file.write("Palestinian: %0.12f\n" % scores["Palestinian"])







2024-05-02 23:30:48.319980: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-02 23:30:48.320138: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-02 23:30:48.321361: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-02 23:30:48.328213: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[nltk_data] Downloading package punkt to
[nltk_data] 

In [22]:

# model_name_or_path='core42/jais-13b'
model_name_or_path='core42/jais-13b-chat'


dataset = 'boda/nadi2024'
prompt_key="prompt"
chunk_size =500 
split="test"
per_device_eval_batch_size=4
save_file='outputs/jais_val'
# checkpoint_path='/l/users/abdelrahman.sadallah/nadi/core42/jais-13b/best/'


In [23]:






bleu_score = load("bleu")


print(f"Loading the   {split} datasets")
dataset = get_dataset(
    dataset_name = dataset,
    split=split,
    field=prompt_key)


save_file = save_file


val_dataloader = DataLoader(dataset, batch_size=per_device_eval_batch_size, shuffle=False)  



tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,padding_side='left')
tokenizer.pad_token = tokenizer.eos_token


Loading the   test datasets


In [4]:

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    return_dict=True,
    load_in_4bit=True,
    device_map="auto",
)



The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [5]:
checkpoint_path= '/l/users/abdelrahman.sadallah/nadi/core42/jais-13b-chat/best/'


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    # bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False
)
if checkpoint_path:
    print(f'Loading model from {checkpoint_path}')
    adapter_checkpoint  = checkpoint_path
    model = PeftModel.from_pretrained(model, adapter_checkpoint,quantization_config=bnb_config)

else:
    print(f'Loading Base Model {model_name_or_path}')


model = model.eval()

# Define PAD Token = BOS Token
model.config.pad_token_id = model.config.bos_token_id


Loading model from /l/users/abdelrahman.sadallah/nadi/core42/jais-13b-chat/best/


In [17]:
invalid = 0

In [18]:
def inference(prompts, tokenizer, model):
    
    device = "cuda" if torch.cuda.is_available() else "cpu"
    outs = []

    with torch.no_grad():
        for p in prompts:
            
            encoding = tokenizer(p, return_tensors="pt", padding=True).to(model.device)
    
            try:
                outputs = model.generate(
                    **encoding,
                    max_new_tokens=256,
                    do_sample=False,
                    # top_p = 0.9,
                    repetition_penalty=1.4,
                    # temperature=0.9,
                    pad_token_id=tokenizer.eos_token_id,
                )
                answer_tokens = outputs[:, encoding.input_ids.shape[1] :]
                output_text = tokenizer.batch_decode(answer_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                # print(output_text)
            except:
                invalid += 1
                output_text = ['']
        
            outs.append(output_text[0])


    return outs
        

In [19]:
dataset

Dataset({
    features: ['source', 'target', 'dialect', 'prompt'],
    num_rows: 400
})

In [None]:



predictions = []
labels = []


torch.cuda.empty_cache()


for batch in tqdm(val_dataloader):

    prompts = batch['prompt']
    ans = []

    labels.extend(batch['target'])

    output_text = inference(prompts=prompts, tokenizer=tokenizer, model=model)

    predictions.extend(output_text)

    print(len(predictions))
    # break
assert (len(predictions) == len(labels))

save_file =   save_file + '_results.txt'

preds_file = save_file + '_predictions.txt'

with open(preds_file, 'w') as f:
    for item in predictions:
        f.write("%s\n" % item)

calc_results(predictions, labels, save_file,chunk_size)
    

  0%|▎                                                                                                                                                      | 1/500 [00:15<2:06:05, 15.16s/it]

4


  0%|▌                                                                                                                                                      | 2/500 [00:23<1:32:56, 11.20s/it]

8


  1%|▉                                                                                                                                                      | 3/500 [00:32<1:24:35, 10.21s/it]

12


  1%|█▏                                                                                                                                                     | 4/500 [00:41<1:21:06,  9.81s/it]

16


  1%|█▌                                                                                                                                                     | 5/500 [00:50<1:17:26,  9.39s/it]

20


  1%|█▊                                                                                                                                                     | 6/500 [00:58<1:14:29,  9.05s/it]

24


  1%|██                                                                                                                                                     | 7/500 [01:06<1:09:39,  8.48s/it]

28


  2%|██▍                                                                                                                                                    | 8/500 [01:14<1:07:52,  8.28s/it]

32


  2%|██▋                                                                                                                                                    | 9/500 [01:25<1:14:47,  9.14s/it]

36


  2%|███                                                                                                                                                   | 10/500 [01:35<1:18:00,  9.55s/it]

40


  2%|███▎                                                                                                                                                  | 11/500 [01:48<1:25:51, 10.53s/it]

44


  2%|███▌                                                                                                                                                  | 12/500 [01:58<1:25:03, 10.46s/it]

48


  3%|███▉                                                                                                                                                  | 13/500 [02:06<1:19:34,  9.80s/it]

52


  3%|████▏                                                                                                                                                 | 14/500 [02:16<1:17:55,  9.62s/it]

56


  3%|████▌                                                                                                                                                 | 15/500 [02:24<1:14:50,  9.26s/it]

60


  3%|████▊                                                                                                                                                 | 16/500 [02:32<1:12:15,  8.96s/it]

64


  3%|█████                                                                                                                                                 | 17/500 [02:42<1:13:32,  9.14s/it]

68


  4%|█████▍                                                                                                                                                | 18/500 [02:52<1:15:08,  9.35s/it]

72


  4%|█████▋                                                                                                                                                | 19/500 [02:59<1:10:13,  8.76s/it]

76


  4%|██████                                                                                                                                                | 20/500 [03:06<1:06:43,  8.34s/it]

80


  4%|██████▎                                                                                                                                               | 21/500 [03:19<1:15:53,  9.51s/it]

84


  4%|██████▌                                                                                                                                               | 22/500 [03:29<1:18:15,  9.82s/it]

88


  5%|██████▉                                                                                                                                               | 23/500 [03:41<1:23:46, 10.54s/it]

92


  5%|███████▏                                                                                                                                              | 24/500 [03:48<1:15:11,  9.48s/it]

96


  5%|███████▌                                                                                                                                              | 25/500 [03:56<1:10:35,  8.92s/it]

100


  5%|███████▊                                                                                                                                              | 26/500 [04:04<1:07:54,  8.60s/it]

104


  5%|████████                                                                                                                                              | 27/500 [04:19<1:22:29, 10.46s/it]

108


  6%|████████▍                                                                                                                                             | 28/500 [04:29<1:23:10, 10.57s/it]

112


  6%|████████▋                                                                                                                                             | 29/500 [04:41<1:24:16, 10.74s/it]

116


  6%|█████████                                                                                                                                             | 30/500 [04:49<1:18:44, 10.05s/it]

120


  6%|█████████▎                                                                                                                                            | 31/500 [04:59<1:19:00, 10.11s/it]

124


  6%|█████████▌                                                                                                                                            | 32/500 [05:07<1:14:22,  9.54s/it]

128


  7%|█████████▉                                                                                                                                            | 33/500 [05:17<1:14:28,  9.57s/it]

132


  7%|██████████▏                                                                                                                                           | 34/500 [05:27<1:14:41,  9.62s/it]

136


  7%|██████████▌                                                                                                                                           | 35/500 [05:39<1:19:17, 10.23s/it]

140


  7%|██████████▊                                                                                                                                           | 36/500 [05:51<1:23:44, 10.83s/it]

144


  7%|███████████                                                                                                                                           | 37/500 [06:05<1:31:13, 11.82s/it]

148


  8%|███████████▍                                                                                                                                          | 38/500 [06:14<1:25:20, 11.08s/it]

152


  8%|███████████▋                                                                                                                                          | 39/500 [06:25<1:23:53, 10.92s/it]

156


  8%|████████████                                                                                                                                          | 40/500 [06:32<1:14:44,  9.75s/it]

160


  8%|████████████▎                                                                                                                                         | 41/500 [06:48<1:29:43, 11.73s/it]

164


  8%|████████████▌                                                                                                                                         | 42/500 [06:55<1:18:57, 10.34s/it]

168


  9%|████████████▉                                                                                                                                         | 43/500 [07:03<1:13:15,  9.62s/it]

172


  9%|█████████████▏                                                                                                                                        | 44/500 [07:12<1:11:15,  9.38s/it]

176


  9%|█████████████▌                                                                                                                                        | 45/500 [07:25<1:18:26, 10.34s/it]

180


In [None]:
len(predictions) , invalid

In [None]:
device = "cuda"

model.to(device)
def get_response(text,tokenizer=tokenizer,model=model):
    

    
    input_ids = tokenizer(text, return_tensors="pt") # .input_ids
    inputs = input_ids.to(device)
    # input_len = inputs.shape[-1]
    generate_ids = model.generate(
        **inputs,
        top_p=0.5,
        temperature=0.5,
         max_new_tokens=32,
        repetition_penalty=1.2,
        do_sample=True,
    )
    
    answer_tokens = generate_ids[:, inputs.input_ids.shape[1] :]
    response = tokenizer.batch_decode(
        answer_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )[0]
    return response


# text = '''.فيما يلي جملة باللهجة العربية المصرية. يرجى ترجمتها إلى اللغة العربية الفصحى الحديثة
# طّلع الرجالة اللي انت عايزهم من الجراج و قطاع النقل يمشّطوا المناطق دي لحد ما يلاقوهم
# '''

text='''The following is a sentence in Egyptain Arabic dialect. Please translate it to Modern Standard Arabic (MSA).
طّلع الرجالة اللي انت عايزهم من الجراج و قطاع النقل يمشّطوا المناطق دي لحد ما يلاقوهم.
'''
text = dataset['prompt'][0]
print(get_response(text))

# text = "The capital of UAE is"
# print(get_response(text))

In [None]:
next(model.parameters()).is_cuda

In [None]:
dataset['prompt']