In [None]:
import torch

from transformers import HfArgumentParser, Seq2SeqTrainingArguments,EarlyStoppingCallback

import logging

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
from datasets import load_dataset, concatenate_datasets,Value
import numpy as np
from typing import Union, Optional
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GlueDataset, AutoModel
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    #glue_compute_metrics,
    glue_output_modes,
    glue_tasks_num_labels,
    set_seed,
)
from transformers import (
    TrainingArguments,
    Trainer
)
import evaluate
from peft import get_peft_model
from arguments import ModelArguments, DataArguments
import wandb
from nltk.tokenize import sent_tokenize
import nltk
from evaluate import load

nltk.download("punkt")
logger = logging.getLogger(__name__)
from transformers import (RobertaForMultipleChoice, RobertaTokenizer, Trainer,
                          TrainingArguments, XLMRobertaForMultipleChoice,
                          XLMRobertaTokenizer)

import pathlib
from datasets import load_dataset
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from transformers import TrainingArguments
from trl import SFTTrainer
from evaluate import load
from peft import LoraConfig, prepare_model_for_kbit_training

import re
from pathlib import Path
from utils_seq import *
import numpy as np
from peft import PeftModel    
import logging
import os
from huggingface_hub import login


In [None]:
import re
from datasets import load_dataset
import pyarabic.araby as araby
from transformers import AutoTokenizer


# DEFAULT_ARABIC_SYSTEM_PROMPT = '''
# The following is a sentence in {dialect} Arabic dialect. Please translate it to Modern Standard Arabic (MSA).
# '''.strip()


def clean_text(text):
    '''
    Strips noise from a piece of text.

    Applied in order: remove tokens starting with "http", remove @-mentions,
    collapse every whitespace run into a single space, then remove tokens that
    start with a caret ("^").

    Args:
        text: the input string.

    Returns:
        The cleaned string (no final trim is applied).
    '''
    substitutions = (
        (r'http\S+', ''),
        (r'@[^\s]+', ''),
        (r'\s+', ' '),
        (r'\^[^ ]+', ''),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned


def print_trainable_parameters(model):
    '''
    Prints the number of trainable parameters in the model.

    Every tensor yielded by ``model.named_parameters()`` is counted. A tensor
    counts as trainable when its ``requires_grad`` flag is True.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
    '''
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard against ZeroDivisionError for a model without parameters.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}'
    )



def generate_arabic_training_prompt(example, tokenizer, field='prompt', max_length=512):
    '''
    Builds the dialect -> MSA instruction prompt for one example and tokenizes it.

    The rendered prompt is also written back into ``example[field]``.

    Args:
        example: a row with 'source', 'target' and 'dialect' keys.
        tokenizer: a Hugging Face tokenizer. It is called once with the prompt and
            once with ``text_target=`` for the target sentence.
        field: key under which the rendered prompt is stored in ``example``.
        max_length: length every sequence is truncated / padded to.

    Returns:
        The tokenizer's encoding of the prompt plus a ``labels`` list: the target
        token ids, with padding positions replaced by -100 so the loss ignores them.
    '''
    source = example['source']
    target = example['target']
    dialect = example['dialect']

    system_prompt = (
        f'The following is a sentence in {dialect} Arabic dialect. '
        'Please translate it to Modern Standard Arabic (MSA).'
    )
    prompt = f'### Instruction: {system_prompt}\n\n### Input:\n{source}\n\n### Response:'

    example[field] = prompt
    model_inputs = tokenizer(
        prompt,
        max_length=max_length,
        truncation=True,
        padding='max_length',
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=target,
        max_length=max_length,
        truncation=True,
        padding='max_length',
    )
    label_ids = labels['input_ids']

    # Mask padding with -100 (the loss' ignore index). The attention mask is used
    # instead of comparing with pad_token_id because this file sets
    # pad_token = eos_token; the genuine EOS of the target must stay in the labels.
    label_mask = labels.get('attention_mask')
    if label_mask is None:
        label_mask = [int(tok != tokenizer.pad_token_id) for tok in label_ids]
    model_inputs['labels'] = [
        tok if keep else -100 for tok, keep in zip(label_ids, label_mask)
    ]
    return model_inputs


def get_dataset(dataset_name='boda/nadi2024', split='train', field='prompt', tokenizer=None):
    '''
    Loads ONE split of a dialect -> MSA dataset and tokenizes every row with
    generate_arabic_training_prompt().

    Args:
        dataset_name: dataset id passed to ``load_dataset``.
        split: split name to load (e.g. 'train', 'dev').
        field: column that receives the rendered prompt text.
        tokenizer: tokenizer to use; ideally the one matching the model being
            trained. When None, "UBC-NLP/AraT5v2-base-1024" is loaded and its pad
            token is set to its EOS token (the previous behaviour).

    Returns:
        The mapped dataset for the requested split.
    '''
    if tokenizer is None:
        # Fallback kept for backward compatibility; passing a tokenizer avoids
        # reloading it on every call and keeps it consistent with the model.
        tokenizer = AutoTokenizer.from_pretrained("UBC-NLP/AraT5v2-base-1024")
        tokenizer.pad_token = tokenizer.eos_token

    dataset = load_dataset(dataset_name, split=split)
    return dataset.map(
        generate_arabic_training_prompt,
        fn_kwargs={'field': field, 'tokenizer': tokenizer},
    )


In [None]:

# Authenticate with the Hugging Face Hub only when a token is provided through the
# HF_TOKEN environment variable. Access tokens must never be committed to source;
# any token that was ever hardcoded here should be revoked.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)

bleu = evaluate.load("sacrebleu")

def preprocess_logits_for_metrics(logits, labels):
    """
    Reduces model logits to predicted token ids before the Trainer accumulates them.

    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed:
    only the argmax ids are kept instead of the vocabulary-sized logits.

    Args:
        logits: the logits tensor, or a tuple/list whose first element is the
            logits tensor (e.g. model outputs that also carry other tensors).
        labels: label ids; returned unchanged.

    Returns:
        ``(pred_ids, labels)``, where ``pred_ids`` is the argmax over the last dim.
    """
    # Unwrap only when the model returned several outputs; indexing a plain
    # tensor with [0] would silently keep just the first batch row.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels
    
def compute_metrics(eval_pred):
    '''
    Computes sacreBLEU between decoded predictions and decoded references.

    Reads the module-level ``tokenizer`` and ``bleu`` objects.

    Args:
        eval_pred: evaluation output whose ``predictions[0]`` holds predicted
            token ids (as returned by preprocess_logits_for_metrics) and whose
            ``label_ids`` holds reference ids with -100 at ignored positions.

    Returns:
        Dict with every field reported by the sacrebleu metric.
    '''
    predictions, labels = eval_pred.predictions[0], eval_pred.label_ids

    # -100 is not a valid token id; map it to the pad id before decoding. This is
    # applied to predictions too, which may be padded with -100 across batches.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # sacrebleu expects a list of references for each prediction.
    result = bleu.compute(
        predictions=decoded_preds,
        references=[[ref] for ref in decoded_labels],
    )
    return dict(result)

In [None]:

# Parse CLI arguments into the project's model/data dataclasses and the
# Seq2SeqTrainingArguments. NOTE: with no explicit args this reads sys.argv, so
# this cell is meant to run as an exported script rather than inside a kernel.
parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments))

model_args, data_args, training_args = parser.parse_args_into_dataclasses()

# Echo every argument value for the run log.
for args_group in (model_args, data_args, training_args):
    for arg in vars(args_group):
        print(arg, getattr(args_group, arg))

# Seed before the model and LoRA adapters are created below; the Trainer only
# seeds in its own constructor, which runs after that initialisation.
set_seed(training_args.seed)

wandb.init(project=model_args.wandb_project, name=model_args.wandb_run_name)

## load tokenizer
# Module-level `tokenizer` is read by compute_metrics() when decoding.
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
# Keep the tokenizer's own pad token when it has one; fall back to EOS otherwise.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print("Loading the datasets")
train_dataset = get_dataset(
    dataset_name=data_args.dataset,
    split='train',
    field=data_args.prompt_key,
)
val_dataset = get_dataset(
    dataset_name=data_args.dataset,
    split='dev',
    field=data_args.prompt_key,
)

# Load the seq2seq base model. The cache directory can be overridden with the
# MODEL_CACHE_DIR environment variable; the default is the previous hardcoded path.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True,
    use_flash_attention_2=model_args.use_flash_attention_2,
    cache_dir=os.environ.get("MODEL_CACHE_DIR", "/scratch/afz225/.cache"),
)

# Outputs go under <output_dir>/<model name>; a hub id containing "/" yields
# nested directories.
save_path = f'{training_args.output_dir}/{model_args.model_name_or_path}'
training_args.output_dir = save_path


# LoRA hyper-parameters.
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64
# Presumably T5-style layer names (q/k/v/o attention, wi_0/wi_1/wo feed-forward);
# verify against model.named_modules() when switching base models.
lora_target_modules = [
    'q', 'v', 'k', 'o', 'wi_0', 'wi_1', 'wo'
]
max_seq_length = 256  # NOTE(review): not used in the visible code

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    # The base model is an AutoModelForSeq2SeqLM, so use the seq2seq PEFT task
    # type (CAUSAL_LM would wrap it as a decoder-only model).
    task_type="SEQ_2_SEQ_LM",
    target_modules=lora_target_modules,
)
model = get_peft_model(model, peft_config)


# NOTE: EarlyStoppingCallback expects metric_for_best_model / load_best_model_at_end
# to be set in training_args (supplied on the command line).
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

history = trainer.train()

# Save the fine-tuned model into a "best" sub-directory of the run's output path.
best_model_dir = f"{save_path}/best"
trainer.save_model(best_model_dir)

# Report the directory the model was actually written to.
print("Training completed. Model saved at", best_model_dir)

# eval_results = trainer.evaluate(val_dataset)

# print("Evaluation Results:", eval_results)
# wandb.log(eval_results)


