In [None]:
# Original installations with minor cleanup and ALLaM compatibility
!pip install transformers datasets torch accelerate bitsandbytes wandb arabic-reshaper python-bidi
!pip install git+https://github.com/MagedSaeed/Bohour.git
!pip install -U transformers sentencepiece accelerate datasets evaluate

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting arabic-reshaper
  Downloading arabic_reshaper-3.0.0-py3-none-any.whl.metadata (12 kB)
Collecting python-bidi
  Downloading python_bidi-0.6.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
 

In [None]:
import os
import re
from collections import Counter

import torch
import arabic_reshaper
from bidi.algorithm import get_display
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    TrainingArguments, Trainer
)
from peft import (
    LoraConfig, get_peft_model,
    prepare_model_for_kbit_training, TaskType
)



In [None]:
import collections
import re
from datasets import load_dataset

# Plain-text markers that delimit each bayt and separate its two hemistichs in
# the training completions (they are interpolated into strings, not registered
# as special tokens).
ST_BAYT_TOKEN = "<|bayt|>"
VERSE_TOKEN = "<|verse|>"
ED_BAYT_TOKEN = "<|endbayt|>"

# Meter / theme control-token tables, numbered in listing order starting at 1.
# NOTE(review): neither table is referenced anywhere else in the visible code
# (prompts spell meter and theme out in Arabic), and these tokens are never
# added to the tokenizer — confirm whether they are still needed.
meter_tokens = {
    name: f"<|meter_{position}|>"
    for position, name in enumerate(
        ["الطويل", "البسيط", "الكامل", "الوافر", "السريع", "الخفيف"], start=1
    )
}

# The final `None` entry is the fallback for poems without a theme.
theme_tokens = {
    name: f"<|theme_{position}|>"
    for position, name in enumerate(
        ["مدح", "رثاء", "غزل", "وصف", "حكمة", "فخر", "هجاء", "حماسة", "عتاب", "زهد", None],
        start=1,
    )
}

# Arabic diacritics and Quranic annotation marks removed before comparing text.
DIAC_RE = re.compile(r"[\u0610-\u061A\u064B-\u065F\u0670\u06D6-\u06ED]")


def strip_harakat(t):
    """Return *t* with every character matched by ``DIAC_RE`` removed.

    ``None`` (or any other falsy input) is treated as the empty string.
    """
    source = t if t else ""
    return DIAC_RE.sub("", source)


# Letter folding applied by `norm`:
#   إ أ آ ٱ -> ا   (alef variants to bare alef)
#   ى       -> ي   (alef maqsura to yeh)
#   ة       -> ه   (teh marbuta to heh)
MAP = str.maketrans({
    "\u0625": "\u0627",
    "\u0623": "\u0627",
    "\u0622": "\u0627",
    "\u0671": "\u0627",
    "\u0649": "\u064a",
    "\u0629": "\u0647",
})


def norm(t):
    """Normalise Arabic text: strip diacritics, fold letter variants, trim whitespace."""
    folded = strip_harakat(t).translate(MAP)
    return folded.strip()

def extract_qafiyah(verse, m=2):
    """Return the last ``m`` Arabic letters of *verse*'s final word (a rhyme proxy).

    The verse is normalised with ``norm`` (diacritics stripped, letter variants
    folded), tatweel is removed, every remaining character outside the Arabic
    letter range U+0621–U+064A becomes a word separator, and the tail of the
    last word is returned.

    Args:
        verse: hemistich text; ``None``/empty is accepted.
        m: number of trailing letters to keep. Values below 1 behave like 1,
            matching the previous loop, which always took at least one letter.

    Returns:
        The rhyme string, or ``None`` when the verse contains no Arabic letters.
    """
    # U+0640 (tatweel / kashida) lies inside the letter range below but is an
    # elongation mark, not a letter; delete it (rather than turning it into a
    # space) so it neither counts toward the rhyme nor splits a word in two.
    text = norm(verse).replace("\u0640", "")
    letters_only = "".join(ch if "\u0621" <= ch <= "\u064A" else " " for ch in text)
    words = letters_only.split()
    if not words:
        return None
    return words[-1][-max(m, 1):]

def filter_by_qafiyah(sample, min_ratio=0.3):
    """Keep a poem only if the rhymes of its second hemistichs are consistent enough.

    Odd-indexed entries of ``sample['poem verses']`` are treated as the second
    hemistich of each bayt (where the rhyme falls); their qafiyah is extracted
    with ``extract_qafiyah``.

    Args:
        sample: dataset row with a ``'poem verses'`` list (``None`` is tolerated).
        min_ratio: minimum share of bayts that must carry the most common rhyme.

    Returns:
        ``True`` when at least two rhymes were found and the most frequent one
        covers at least ``min_ratio`` of them, else ``False``.
    """
    verses = sample['poem verses'] or []
    qafiyahs = [q for q in map(extract_qafiyah, verses[1::2]) if q]

    # Fewer than two bayts with a detectable rhyme cannot show consistency.
    if len(qafiyahs) < 2:
        return False

    _, top_count = collections.Counter(qafiyahs).most_common(1)[0]
    return top_count / len(qafiyahs) >= min_ratio

def get_qafiyah_majority(poem, min_ratio=0.5):
    """Return the dominant rhyme of a marker-formatted poem, or ``None``.

    *poem* is expected in the ``"<|bayt|> a <|verse|> b <|endbayt|> ..."``
    layout built by ``join_verses``; the qafiyah is taken from the text after
    ``VERSE_TOKEN`` in every bayt.

    Args:
        poem: marker-formatted poem string (``None``/empty accepted).
        min_ratio: minimum share of detected rhymes the most common one must
            reach to be returned.

    Returns:
        The most common qafiyah when its share is at least ``min_ratio``,
        otherwise ``None`` (also ``None`` when markers or rhymes are missing).
    """
    if not poem or ED_BAYT_TOKEN not in poem or VERSE_TOKEN not in poem:
        return None

    qafiyahs = []
    for bayt in poem.split(ED_BAYT_TOKEN):
        if VERSE_TOKEN not in bayt:
            continue  # e.g. the trailing fragment after the last end marker
        second_half = bayt.split(VERSE_TOKEN)[1].strip()
        qafiyah = extract_qafiyah(second_half)
        if qafiyah:
            qafiyahs.append(qafiyah)

    if not qafiyahs:
        return None

    dominant_qafiyah, count = collections.Counter(qafiyahs).most_common(1)[0]
    return dominant_qafiyah if count / len(qafiyahs) >= min_ratio else None

def load_dataset_and_preprocess():
    """Load ``arbml/ashaar`` and turn it into prompt/completion pairs for fine-tuning.

    Pipeline (each step is a ``datasets`` map/filter over every split):

    1. ``process_verse``     - keep only whitelisted Arabic letters, space and a
       fixed set of diacritics; remap a few look-alike code points.
    2. ``filter_poems``      - keep poems with an even number (>= 2) of
       hemistichs, each at least 5 characters long.
    3. ``map_meters``        - normalise raw meter labels to canonical names.
    4. ``filter_meters``     - keep only the six selected meters.
    5. ``filter_by_qafiyah`` - drop poems with an inconsistent rhyme.
    6. ``join_verses``       - build the instruction prompt and the
       marker-delimited completion; rows without a dominant qafiyah are emptied
       and then removed.

    Returns:
        The ``DatasetDict`` returned by ``load_dataset`` with ``prompt`` and
        ``completion`` columns added.
    """
    print("Loading the ashaar dataset...")
    ashaar = load_dataset("arbml/ashaar")

    print(f"Dataset loaded with {len(ashaar['train'])} entries")

    selected_meters = ["الخفيف", "الطويل", "الكامل", "البسيط", "السريع", "الوافر"]

    # Character whitelist for process_verse, built once rather than per character.
    # The original letter string lacked 'إ' (alef + hamza below) and 'آ'
    # (alef + madda), so those letters were silently deleted from every verse
    # (e.g. "إذا" became "ذا"); they are appended here.
    chars = 'ابتثجحخدذرزسشصضطظعغفقكلمنهويىئءأؤة ى' + 'إآ'
    diacs = 'ْ~ًٌٍَُِّ'
    allowed_chars = set(chars + diacs)
    map_chars = {'ک': 'ك', 'ﺑ': 'ب', 'ٹ': 'ث', 'ی': 'ى'}

    def process_verse(sample):
        """Drop every non-whitelisted character, remapping a few look-alikes."""
        cleaned = []
        for verse in sample['poem verses']:
            kept = []
            for char in verse:
                if char in allowed_chars:
                    kept.append(char)
                elif char in map_chars:
                    kept.append(map_chars[char])
            cleaned.append(''.join(kept))
        sample['poem verses'] = cleaned
        return sample

    def filter_poems(sample):
        """Keep poems made of complete bayts whose hemistichs are not trivially short."""
        poem = sample['poem verses']
        if len(poem) < 2 or len(poem) % 2 != 0:
            return False
        return all(len(verse) >= 5 for verse in poem)

    def map_meters(sample):
        """Normalise raw meter labels to the canonical names in ``selected_meters``.

        NOTE(review): 'بسيط' is matched exactly while the other meters use a
        substring test. A label such as 'مجزوء الكامل' is therefore folded into
        'الكامل', whereas a label like 'مخلع البسيط' is left as-is (and later
        filtered out). Confirm this asymmetry against the dataset's labels.
        """
        meter = sample['poem meter']
        if meter:
            if meter == 'بسيط':
                sample['poem meter'] = 'البسيط'
            elif 'خفيف' in meter:
                sample['poem meter'] = 'الخفيف'
            elif 'طويل' in meter:
                sample['poem meter'] = 'الطويل'
            elif 'كامل' in meter:
                sample['poem meter'] = 'الكامل'
            elif 'سريع' in meter:
                sample['poem meter'] = 'السريع'
            elif 'وافر' in meter:
                sample['poem meter'] = 'الوافر'
        return sample

    def filter_meters(sample):
        """Keep only poems whose (normalised) meter is one of the selected six."""
        return sample['poem meter'] in selected_meters

    def join_verses(sample):
        """Build the instruction prompt and marker-delimited completion for one poem.

        Returns empty strings when ``get_qafiyah_majority`` finds no dominant
        rhyme; those rows are removed right after this map.
        """
        verses = sample['poem verses']
        meter = sample['poem meter']
        theme = sample['poem theme']
        # `.get(key, default)` still yields None when the column exists but is
        # empty, which used to put the literal text "None" into the prompt.
        title = sample.get('poem title') or 'بدون عنوان'

        # Each bayt: "<|bayt|> first-half <|verse|> second-half <|endbayt|> ".
        poem = ''.join([f'{ST_BAYT_TOKEN} {verses[i]} {VERSE_TOKEN} {verses[i+1]} {ED_BAYT_TOKEN} '
                        for i in range(0, len(verses) - 1, 2)])

        qafiyah = get_qafiyah_majority(poem)
        if not qafiyah:
            return {"prompt": "", "completion": ""}

        prompt = f"""أنشئ قصيدة عربية فصيحة وفقاً للمواصفات التالية:

العنوان: {title}
البحر: {meter}
نوع القصيدة: {theme if theme else 'عامة'}
القافية: {qafiyah}

يجب أن تكون القصيدة:
- ملتزمة بقواعد بحر {meter} وتفعيلاته
- منتهية كل بيت بحرف :{qafiyah}
- متناسبة مع موضوع :{title}

الهيكل المطلوب:
{ST_BAYT_TOKEN} الشطر الأول {VERSE_TOKEN} الشطر الثاني {ED_BAYT_TOKEN}

اكتب القصيدة:"""

        completion = poem.strip()

        return {"prompt": prompt, "completion": completion}

    print("Processing verses...")
    ashaar = ashaar.map(process_verse)
    print("Filtering poems by structure...")
    ashaar = ashaar.filter(filter_poems)
    print("Mapping meters...")
    ashaar = ashaar.map(map_meters)
    print("Filtering by meter...")
    ashaar = ashaar.filter(filter_meters)
    print("Filtering by qafiyah...")
    ashaar = ashaar.filter(filter_by_qafiyah)
    print("Creating prompt-completion pairs...")
    processed_data = ashaar.map(join_verses)
    processed_data = processed_data.filter(lambda x: x["prompt"] != "" and x["completion"] != "")

    print(f"Final dataset size: {len(processed_data['train'])} entries")
    return processed_data

In [None]:
# Run the full download + cleaning + prompt-building pipeline (downloads arbml/ashaar).
processed_dataset = load_dataset_and_preprocess()


Loading the ashaar dataset...


dataset_infos.json:   0%|          | 0.00/34.7k [00:00<?, ?B/s]

train-00000-of-00002.parquet:   0%|          | 0.00/126M [00:00<?, ?B/s]

train-00001-of-00002.parquet:   0%|          | 0.00/151M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/254630 [00:00<?, ? examples/s]

Dataset loaded with 254630 entries
Processing verses...


Map:   0%|          | 0/254630 [00:00<?, ? examples/s]

Filtering poems by structure...


Filter:   0%|          | 0/254630 [00:00<?, ? examples/s]

Mapping meters...


Map:   0%|          | 0/219946 [00:00<?, ? examples/s]

Filtering by meter...


Filter:   0%|          | 0/219946 [00:00<?, ? examples/s]

Filtering by qafiyah...


Filter:   0%|          | 0/107357 [00:00<?, ? examples/s]

Creating prompt-completion pairs...


Map:   0%|          | 0/86164 [00:00<?, ? examples/s]

Filter:   0%|          | 0/86164 [00:00<?, ? examples/s]

Final dataset size: 77873 entries


In [None]:
from pprint import pprint

# Inspect one processed example. Indexing the row first (`ds[i]["col"]`) avoids
# materialising the entire column as a Python list, which `ds["col"][i]` does
# in datasets 3.x.
SAMPLE_IDX = 100
sample_row = processed_dataset['train'][SAMPLE_IDX]
pprint(sample_row['prompt'])
pprint(sample_row['completion'])

('أنشئ قصيدة عربية فصيحة وفقاً للمواصفات التالية:\n'
 '\n'
 'العنوان: لم تمل بي عن العفاف العقار\n'
 'البحر: الخفيف\n'
 'نوع القصيدة: قصيدة قصيره\n'
 'القافية: ار\n'
 '\n'
 'يجب أن تكون القصيدة:\n'
 '- ملتزمة بقواعد بحر الخفيف وتفعيلاته\n'
 '- منتهية كل بيت بحرف :ار\n'
 '- متناسبة مع موضوع :لم تمل بي عن العفاف العقار\n'
 '\n'
 'الهيكل المطلوب:\n'
 '<|bayt|> الشطر الأول <|verse|> الشطر الثاني <|endbayt|>\n'
 '\n'
 'اكتب القصيدة:')
('<|bayt|> لَم تَمل بي عَن العفاف العقارُ <|verse|> أَعشَق الغيد وَالوَقار '
 'وَقارُ <|endbayt|> <|bayt|> أَنظُم الشعر ما حييت وَأَني <|verse|> لابن بَيت '
 'تَهدى لَهُ الأَشعارُ <|endbayt|> <|bayt|> يَتَحلى بِيَ الزَمان تَحلى الغُص '
 '<|verse|> نِ لَما يَزينهُ النوّارُ <|endbayt|> <|bayt|> صَقَلتَني يَد '
 'التَجارُب حَتّى <|verse|> صَحَ عَزمي وَطابَ مِنهُ الغرارُ <|endbayt|> '
 '<|bayt|> وَمَكاني مِن الفَخار مَكان <|verse|> حَسدتهُ الشُموس وَالأَقمارُ '
 '<|endbayt|>')


In [None]:
def load_model(model_name="silma-ai/SILMA-9B-Instruct-v1.0"):
    """Load the base model in 8-bit and attach LoRA adapters for fine-tuning.

    Args:
        model_name: Hugging Face Hub id or local path of the base checkpoint.
            Defaults to SILMA-9B-Instruct (a Gemma-2 based model).

    Returns:
        ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model
        prepared for k-bit training.
    """
    from transformers import BitsAndBytesConfig

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,  # dtype for the non-quantized tensors
        device_map="auto",
        # A bare `load_in_8bit=True` kwarg is deprecated (Transformers warns about
        # it); the supported form is a BitsAndBytesConfig.
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        # Transformers recommends eager attention (not sdpa) for training Gemma-2.
        attn_implementation="eager",
    )

    # Fall back to EOS as the pad token only if the tokenizer defines none.
    # NOTE(review): in that case the causal-LM collator masks EOS together with
    # padding, so the model would never be trained to emit EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    model = prepare_model_for_kbit_training(model)

    # Gemma-2 attention projections are q_proj/k_proj/v_proj/o_proj. The former
    # entry "to_out.0" (a diffusers naming) matched nothing, so the output
    # projection was silently left untrained: the previously logged 12,730,368
    # trainable parameters are exactly r=16 LoRA on q/k/v alone.
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model, tokenizer

model, tokenizer = load_model()

tokenizer_config.json:   0%|          | 0.00/46.9k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/895 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/39.1k [00:00<?, ?B/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

model-00001-of-00005.safetensors:   0%|          | 0.00/3.91G [00:00<?, ?B/s]

model-00003-of-00005.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00004-of-00005.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

model-00005-of-00005.safetensors:   0%|          | 0.00/2.69G [00:00<?, ?B/s]

model-00002-of-00005.safetensors:   0%|          | 0.00/3.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

trainable params: 12,730,368 || all params: 9,254,436,352 || trainable%: 0.1376


In [None]:
class PoetryDataset(torch.utils.data.Dataset):
    """Map-style dataset rendering prompt/completion pairs as single causal-LM sequences.

    Each example becomes ``"[INST] {prompt} [/INST] {completion}{eos}"``, where
    ``{eos}`` is the tokenizer's own EOS token, and is tokenized with padding
    and truncation to ``max_length``. Inference prompts must use the same
    ``"[INST] ... [/INST]"`` template.

    Args:
        examples: indexable sequence of mappings with ``"prompt"`` and
            ``"completion"`` keys (e.g. a ``datasets.Dataset`` split).
        tokenizer: Hugging Face tokenizer; presumably prepends its own BOS token
            when called (true for Gemma tokenizers) — confirm for other models.
        max_length: fixed length every encoded example is padded/truncated to.
    """

    def __init__(self, examples, tokenizer, max_length=512):
        self.examples = examples
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def _format(self, prompt, completion):
        """Render one training sequence as text (BOS is left to the tokenizer)."""
        # The old template hard-coded "<s>" / "</s>". Those are Llama-style
        # literals; for this Gemma tokenizer (<bos>/<eos>) they are tokenized as
        # plain text, so the model never saw a real EOS and could not learn
        # where a poem ends.
        eos = self.tokenizer.eos_token or ""
        return f"[INST] {prompt} [/INST] {completion}{eos}"

    def __getitem__(self, idx):
        example = self.examples[idx]
        full_text = self._format(example["prompt"], example["completion"])

        encodings = self.tokenizer(
            full_text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a batch dimension of size 1; drop it.
        input_ids = encodings["input_ids"][0]
        attention_mask = encodings["attention_mask"][0]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            # NOTE(review): DataCollatorForLanguageModeling(mlm=False) rebuilds
            # labels from input_ids (masking padding), so these labels are
            # overwritten and prompt tokens are trained on as well.
            "labels": input_ids.clone(),
        }

In [None]:
import os
from google.colab import drive

# Google Drive folder used as the Trainer output_dir (checkpoints and logs).
save_path = "/content/drive/MyDrive/Silma_final_v3"

drive.mount('/content/drive')
os.makedirs(save_path, exist_ok=True)

Mounted at /content/drive


In [None]:
def setup_training(processed_data, tokenizer, output_dir=None):
    """Build the training dataset, TrainingArguments and data collator.

    Args:
        processed_data: DatasetDict from ``load_dataset_and_preprocess``; only
            its ``"train"`` split is used.
        tokenizer: tokenizer returned by ``load_model``.
        output_dir: checkpoint/log directory. Defaults to the module-level
            ``save_path`` (the Drive folder created in the mount cell), so
            existing calls behave as before.

    Returns:
        ``(train_dataset, training_args, data_collator)``.
    """
    if output_dir is None:
        output_dir = save_path

    train_dataset = PoetryDataset(processed_data["train"], tokenizer, max_length=512)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=8,
        gradient_accumulation_steps=2,  # 8 x 2 = 16 sequences per optimizer step per device
        save_steps=500,
        save_total_limit=2,
        logging_steps=100,
        # NOTE(review): 2e-5 is low for LoRA adapters, which are commonly trained
        # around 1e-4 to 2e-4 — confirm this is intentional.
        learning_rate=2e-5,
        warmup_steps=500,
        bf16=True,
        fp16=False,
        report_to="wandb",
        logging_dir=os.path.join(output_dir, "logs"),
        dataloader_num_workers=2,
        # PeftModel hides the base model's forward signature, so Trainer cannot
        # infer the label column itself (it warns and uses an empty list).
        label_names=["labels"],
    )

    # Causal-LM collation (mlm=False): labels are rebuilt from input_ids with
    # padding positions set to -100.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    return train_dataset, training_args, data_collator

train_dataset, training_args, data_collator = setup_training(processed_dataset, tokenizer)

In [None]:
def train_model(model, train_dataset, training_args, data_collator, tokenizer=None, save_dir=None):
    """Fine-tune *model* and save the final LoRA adapter together with its tokenizer.

    Args:
        model: PEFT-wrapped model from ``load_model``.
        train_dataset: dataset from ``setup_training``.
        training_args: ``TrainingArguments`` from ``setup_training``.
        data_collator: collator from ``setup_training``.
        tokenizer: tokenizer saved next to the adapter. Defaults to
            ``data_collator.tokenizer`` instead of implicitly reading a global.
        save_dir: destination of the final adapter. Defaults to
            ``training_args.output_dir`` instead of a second hard-coded copy of
            the Drive path.

    Returns:
        The ``Trainer`` used for training.
    """
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    trainer.train()

    if tokenizer is None:
        tokenizer = data_collator.tokenizer
    if save_dir is None:
        save_dir = training_args.output_dir

    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)  # for a PeftModel this writes the adapter weights
    tokenizer.save_pretrained(save_dir)

    return trainer

trainer = train_model(model, train_dataset, training_args, data_collator, tokenizer=tokenizer)

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhatimalhomid[0m ([33mhatimalhomid-education-com[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


It is strongly recommended to train Gemma2 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  return fn(*args, **kwargs)


Step,Training Loss
100,2.9933
200,2.3402
300,1.6547
400,1.5884
500,1.5936
600,1.5542
700,1.5209
800,1.5105
900,1.5014
1000,1.494


  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)
  return fn(*args, **kwargs)


Step,Training Loss
100,2.9933
200,2.3402
300,1.6547
400,1.5884
500,1.5936
600,1.5542
700,1.5209
800,1.5105
900,1.5014
1000,1.494
