In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from typing import List, Tuple
import os
from tqdm import tqdm
import re
import json
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

class ArabicGPTTPU:
    def __init__(self, model_name: str = "aubmindlab/aragpt2-base"):
        """
        Initialize the Arabic GPT model with TPU support
        Args:
            model_name: Name of the pre-trained model to use
        """
        # Initialize TPU device
        self.device = xm.xla_device()
        print(f"Using TPU device: {self.device}")

        # Initialize tokenizer and model
        print("Loading tokenizer and model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

        # Add special tokens for Arabic
        special_tokens = {
            'pad_token': '[PAD]',
            'bos_token': '[BOS]',
            'eos_token': '[EOS]',
            'unk_token': '[UNK]'
        }
        self.tokenizer.add_special_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))

        # Move model to TPU
        self.model = self.model.to(self.device)

    def save_model(self, save_dir: str = 'arabic_gpt_tpu_model'):
        """
        Save the model and tokenizer to disk
        Args:
            save_dir: Directory to save the model and tokenizer
        """
        print(f"Saving model to {save_dir}...")

        # Create directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

        # Move model to CPU before saving
        self.model = self.model.to('cpu')

        # Save model
        model_path = os.path.join(save_dir, 'model')
        self.model.save_pretrained(model_path)

        # Save tokenizer
        tokenizer_path = os.path.join(save_dir, 'tokenizer')
        self.tokenizer.save_pretrained(tokenizer_path)

        # Save model configuration
        config = {
            'model_name': self.model.config.model_type,
            'vocab_size': self.model.config.vocab_size,
            'max_position_embeddings': self.model.config.max_position_embeddings,
            'num_attention_heads': self.model.config.num_attention_heads,
            'num_hidden_layers': self.model.config.num_hidden_layers,
        }

        with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)

        print("Model saved successfully!")

        # Move model back to TPU
        self.model = self.model.to(self.device)

    @classmethod
    def load_model(cls, load_dir: str = 'arabic_gpt_tpu_model'):
        """
        Load a saved model and tokenizer
        Args:
            load_dir: Directory containing the saved model and tokenizer
        Returns:
            ArabicGPTTPU instance with loaded model
        """
        print(f"Loading model from {load_dir}...")

        # Create instance
        instance = cls()

        # Load model
        model_path = os.path.join(load_dir, 'model')
        instance.model = AutoModelForCausalLM.from_pretrained(model_path)
        instance.model = instance.model.to(instance.device)

        # Load tokenizer
        tokenizer_path = os.path.join(load_dir, 'tokenizer')
        instance.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        print("Model loaded successfully!")
        return instance

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize Arabic text
        """
        # Remove URLs
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        # Remove special characters and numbers
        text = re.sub(r'[^\u0600-\u06FF\s]', '', text)
        # Remove extra spaces
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def preprocess_data(self, file_path: str = '/kaggle/input/arabic-news/Arabic_news.csv') -> List[str]:
        """
        Preprocess the Arabic news dataset
        Args:
            file_path: Path to the dataset file
        Returns:
            List of preprocessed sentences
        """
        print(f"Loading dataset from {file_path}")
        df = pd.read_csv(file_path)

        # Clean and preprocess the text
        print("Preprocessing text data...")
        df=df[:15000]
        df['text'] = df['text'].fillna('')
        df['text'] = df['text'].apply(self.clean_text)
        df = df[df['text'].str.len() > 0]

        # Split into sentences and create training pairs
        sentences = []
        for text in tqdm(df['text'], desc="Processing sentences"):
            # Split by common Arabic sentence endings
            text_sentences = [s.strip() for s in text.split('۔') if len(s.strip()) > 0]
            # Add BOS and EOS tokens
            text_sentences = [f"[BOS] {s} [EOS]" for s in text_sentences]
            sentences.extend(text_sentences)

        print(f"Total number of sentences: {len(sentences)}")
        return sentences

    def prepare_batch(self, sentences: List[str], batch_size: int) -> torch.Tensor:
        """
        Prepare a batch of sentences for training
        """
        encodings = self.tokenizer(sentences,
                                 padding=True,
                                 truncation=True,
                                 max_length=128,  # Reduced from 512 to save memory
                                 return_tensors="pt")
        return encodings

    def train(self, sentences: List[str], epochs: int = 3, batch_size: int = 8, gradient_accumulation_steps: int = 4):
        """
        Fine-tune the model on the Arabic news dataset using TPU with memory-efficient training
        Args:
            sentences: List of training sentences
            epochs: Number of training epochs
            batch_size: Batch size for training (reduced for TPU memory constraints)
            gradient_accumulation_steps: Number of steps to accumulate gradients
        """
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)

        # Create dataset
        encodings = self.prepare_batch(sentences, batch_size)
        dataset = torch.utils.data.TensorDataset(
            encodings["input_ids"],
            encodings["attention_mask"]
        )

        # Create TPU-optimized data loader with smaller batch size
        train_loader = pl.ParallelLoader(
            torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True
            ),
            [self.device]
        ).per_device_loader(self.device)

        # Calculate effective batch size
        effective_batch_size = batch_size * gradient_accumulation_steps

        for epoch in range(epochs):
            total_loss = 0
            optimizer.zero_grad()
            running_loss = 0
            num_batches = 0

            # Reset the data loader for each epoch
            train_loader = pl.ParallelLoader(
                torch.utils.data.DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=True
                ),
                [self.device]
            ).per_device_loader(self.device)

            # Calculate total number of batches for this epoch
            total_batches = len(train_loader)
            mid_point = total_batches // 2

            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")):
                input_ids, attention_mask = batch

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids
                )
                loss = outputs.loss / gradient_accumulation_steps  # Normalize loss

                # Backward pass
                loss.backward()

                # Update weights if we've accumulated enough gradients
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # TPU optimizer step
                    xm.optimizer_step(optimizer)
                    optimizer.zero_grad()

                    # Mark step for TPU
                    xm.mark_step()

                # Update running loss
                running_loss += loss.item() * gradient_accumulation_steps
                num_batches += 1

                # Print progress at middle of epoch
                if batch_idx == mid_point:
                    avg_loss = running_loss / num_batches
                    xm.master_print(f"\nEpoch {epoch + 1}/{epochs} - Midpoint, Average Loss: {avg_loss:.4f}\n")

            # Calculate and print final epoch average loss
            epoch_avg_loss = running_loss / num_batches
            xm.master_print(f"\nEpoch {epoch + 1}/{epochs} completed, Final Average Loss: {epoch_avg_loss:.4f}\n")
            if (epoch+1)%5==0:
            # Save checkpoint after each epoch
                self.save_model(f'arabic_gpt_tpu_model_epoch_{epoch + 1}')

    def predict_next_word(self, sentence: str, num_predictions: int = 5) -> List[Tuple[str, float]]:
        """
        Predict the next word given a sentence
        Args:
            sentence: Input sentence
            num_predictions: Number of predictions to return
        Returns:
            List of tuples containing (predicted_word, probability)
        """
        # Set model to eval mode once
        self.model.eval()

        with torch.no_grad():
            # Add BOS token to input
            input_text = f"[BOS] {sentence}"
            inputs = self.tokenizer(input_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            # Get logits for the last position only
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=False,
                output_attentions=False
            )
            last_token_logits = outputs.logits[0, -1, :]

            # Get top k predictions
            top_k_logits, top_k_indices = torch.topk(last_token_logits, num_predictions)
            probs = torch.softmax(top_k_logits, dim=-1)

            # Convert to list of predictions
            predictions = []
            for idx, (token_id, prob) in enumerate(zip(top_k_indices, probs)):
                word = self.tokenizer.decode([token_id])
                predictions.append((word, prob.item()))

            return predictions

def main():
    # Initialize the model
    print("Initializing model...")
    model = ArabicGPTTPU()

    # Load and preprocess data
    print("Loading and preprocessing data...")
    sentences = model.preprocess_data()

    # Train the model with memory-efficient settings
    print("Training the model...")
    model.train(
        sentences,
        epochs=30,
        batch_size=8,  # Reduced batch size
        gradient_accumulation_steps=4  # Added gradient accumulation
    )

    # Example usage with the trained model
    print("\nTesting predictions with the trained model:")
    test_sentences = [
        "مرحبا كيف حالك",
        "انا اريد ان اتعلم",
        "شكرا جزيلا على",
        "السلام عليكم ورحمة",
        "اهلا وسهلا بكم في",
        "انا احب القراءة",
        "هذا الكتاب جميل"
    ]

    for test_sentence in test_sentences:
        predictions = model.predict_next_word(test_sentence)
        print(f"\nInput sentence: {test_sentence}")
        print("Predictions:")
        for word, prob in predictions:
            print(f"{word}: {prob:.4f}")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Initializing model...


E0000 00:00:1746993159.515732     297 common_lib.cc:621] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:232


Using TPU device: xla:0
Loading tokenizer and model...


E0000 00:00:1746993173.211080     297 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:230
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Loading and preprocessing data...
Loading dataset from /kaggle/input/arabic-news/Arabic_news.csv
Preprocessing text data...


Processing sentences: 100%|██████████| 15000/15000 [00:00<00:00, 273231.51it/s]


Total number of sentences: 15000
Training the model...


Epoch 1/30:   0%|          | 0/1875 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Epoch 1/30:  50%|████▉     | 936/1875 [03:17<03:04,  5.09it/s] 


Epoch 1/30 - Midpoint, Average Loss: 5.5270



Epoch 1/30: 100%|██████████| 1875/1875 [05:51<00:00,  5.34it/s]



Epoch 1/30 completed, Final Average Loss: 5.0689



Epoch 2/30:  50%|████▉     | 936/1875 [02:33<02:58,  5.26it/s]


Epoch 2/30 - Midpoint, Average Loss: 4.2404



Epoch 2/30: 100%|██████████| 1875/1875 [05:07<00:00,  6.10it/s]



Epoch 2/30 completed, Final Average Loss: 4.1808



Epoch 3/30:  50%|████▉     | 936/1875 [02:33<02:54,  5.37it/s]


Epoch 3/30 - Midpoint, Average Loss: 3.9407



Epoch 3/30: 100%|██████████| 1875/1875 [05:07<00:00,  6.09it/s]



Epoch 3/30 completed, Final Average Loss: 3.9220



Epoch 4/30:  50%|████▉     | 936/1875 [02:33<03:02,  5.15it/s]


Epoch 4/30 - Midpoint, Average Loss: 3.7573



Epoch 4/30: 100%|██████████| 1875/1875 [05:06<00:00,  6.12it/s]



Epoch 4/30 completed, Final Average Loss: 3.7476



Epoch 5/30:  50%|████▉     | 936/1875 [02:34<02:58,  5.25it/s]


Epoch 5/30 - Midpoint, Average Loss: 3.6086



Epoch 5/30: 100%|██████████| 1875/1875 [05:08<00:00,  6.07it/s]



Epoch 5/30 completed, Final Average Loss: 3.6062

Saving model to arabic_gpt_tpu_model_epoch_5...
Model saved successfully!


Epoch 6/30:  50%|████▉     | 936/1875 [02:06<02:04,  7.57it/s]


Epoch 6/30 - Midpoint, Average Loss: 3.4695



Epoch 6/30: 100%|██████████| 1875/1875 [04:05<00:00,  7.64it/s]



Epoch 6/30 completed, Final Average Loss: 3.4724



Epoch 7/30:  50%|████▉     | 936/1875 [02:04<02:06,  7.41it/s]


Epoch 7/30 - Midpoint, Average Loss: 3.4758



Epoch 7/30: 100%|██████████| 1875/1875 [04:08<00:00,  7.54it/s]



Epoch 7/30 completed, Final Average Loss: 3.4723



Epoch 8/30:  50%|████▉     | 936/1875 [02:03<02:08,  7.30it/s]


Epoch 8/30 - Midpoint, Average Loss: 3.4823



Epoch 8/30: 100%|██████████| 1875/1875 [04:10<00:00,  7.50it/s]



Epoch 8/30 completed, Final Average Loss: 3.4725



Epoch 9/30:  50%|████▉     | 937/1875 [02:03<02:06,  7.43it/s]


Epoch 9/30 - Midpoint, Average Loss: 3.4737



Epoch 9/30: 100%|██████████| 1875/1875 [04:06<00:00,  7.60it/s]



Epoch 9/30 completed, Final Average Loss: 3.4704



Epoch 10/30:  50%|████▉     | 936/1875 [02:03<02:08,  7.28it/s]


Epoch 10/30 - Midpoint, Average Loss: 3.4627



Epoch 10/30: 100%|██████████| 1875/1875 [04:07<00:00,  7.59it/s]



Epoch 10/30 completed, Final Average Loss: 3.4708

Saving model to arabic_gpt_tpu_model_epoch_10...
Model saved successfully!


Epoch 11/30:  50%|████▉     | 937/1875 [02:03<02:02,  7.68it/s]


Epoch 11/30 - Midpoint, Average Loss: 3.4726



Epoch 11/30: 100%|██████████| 1875/1875 [04:19<00:00,  7.23it/s]



Epoch 11/30 completed, Final Average Loss: 3.4709



Epoch 12/30:  50%|████▉     | 937/1875 [02:15<02:13,  7.04it/s]


Epoch 12/30 - Midpoint, Average Loss: 3.4750



Epoch 12/30: 100%|██████████| 1875/1875 [04:31<00:00,  6.91it/s]



Epoch 12/30 completed, Final Average Loss: 3.4703



Epoch 13/30:  50%|████▉     | 937/1875 [02:17<02:07,  7.33it/s]


Epoch 13/30 - Midpoint, Average Loss: 3.4789



Epoch 13/30: 100%|██████████| 1875/1875 [04:35<00:00,  6.80it/s]



Epoch 13/30 completed, Final Average Loss: 3.4714



Epoch 14/30:  50%|████▉     | 937/1875 [02:18<02:15,  6.90it/s]


Epoch 14/30 - Midpoint, Average Loss: 3.4735



Epoch 14/30: 100%|██████████| 1875/1875 [04:35<00:00,  6.80it/s]



Epoch 14/30 completed, Final Average Loss: 3.4730



Epoch 15/30:  50%|████▉     | 937/1875 [02:17<02:12,  7.09it/s]


Epoch 15/30 - Midpoint, Average Loss: 3.4592



Epoch 15/30: 100%|██████████| 1875/1875 [04:34<00:00,  6.82it/s]



Epoch 15/30 completed, Final Average Loss: 3.4715

Saving model to arabic_gpt_tpu_model_epoch_15...
Model saved successfully!


Epoch 16/30:  50%|████▉     | 937/1875 [02:15<02:11,  7.13it/s]


Epoch 16/30 - Midpoint, Average Loss: 3.4701



Epoch 16/30: 100%|██████████| 1875/1875 [04:31<00:00,  6.90it/s]



Epoch 16/30 completed, Final Average Loss: 3.4709



Epoch 17/30:  50%|████▉     | 937/1875 [02:16<02:09,  7.25it/s]


Epoch 17/30 - Midpoint, Average Loss: 3.4719



Epoch 17/30: 100%|██████████| 1875/1875 [04:33<00:00,  6.86it/s]



Epoch 17/30 completed, Final Average Loss: 3.4697



Epoch 18/30:  50%|████▉     | 937/1875 [02:16<02:15,  6.95it/s]


Epoch 18/30 - Midpoint, Average Loss: 3.4738



Epoch 18/30: 100%|██████████| 1875/1875 [04:32<00:00,  6.87it/s]



Epoch 18/30 completed, Final Average Loss: 3.4698



Epoch 19/30:  50%|████▉     | 937/1875 [02:16<02:00,  7.81it/s]


Epoch 19/30 - Midpoint, Average Loss: 3.4761



Epoch 19/30: 100%|██████████| 1875/1875 [04:21<00:00,  7.16it/s]



Epoch 19/30 completed, Final Average Loss: 3.4716



Epoch 20/30:  50%|████▉     | 937/1875 [02:06<02:00,  7.78it/s]


Epoch 20/30 - Midpoint, Average Loss: 3.4811



Epoch 20/30: 100%|██████████| 1875/1875 [04:17<00:00,  7.29it/s]



Epoch 20/30 completed, Final Average Loss: 3.4709

Saving model to arabic_gpt_tpu_model_epoch_20...
Model saved successfully!


Epoch 21/30:  50%|████▉     | 937/1875 [02:14<02:08,  7.29it/s]


Epoch 21/30 - Midpoint, Average Loss: 3.4699



Epoch 21/30: 100%|██████████| 1875/1875 [04:32<00:00,  6.88it/s]



Epoch 21/30 completed, Final Average Loss: 3.4704



Epoch 22/30:  50%|████▉     | 937/1875 [02:16<02:10,  7.17it/s]


Epoch 22/30 - Midpoint, Average Loss: 3.4744



Epoch 22/30: 100%|██████████| 1875/1875 [04:33<00:00,  6.85it/s]



Epoch 22/30 completed, Final Average Loss: 3.4711



Epoch 23/30:  50%|████▉     | 937/1875 [02:16<02:15,  6.92it/s]


Epoch 23/30 - Midpoint, Average Loss: 3.4684



Epoch 23/30: 100%|██████████| 1875/1875 [04:32<00:00,  6.87it/s]



Epoch 23/30 completed, Final Average Loss: 3.4714



Epoch 24/30:  50%|████▉     | 937/1875 [02:15<02:01,  7.72it/s]


Epoch 24/30 - Midpoint, Average Loss: 3.4695



Epoch 24/30: 100%|██████████| 1875/1875 [04:29<00:00,  6.95it/s]



Epoch 24/30 completed, Final Average Loss: 3.4716



Epoch 25/30:  50%|████▉     | 937/1875 [02:22<02:24,  6.48it/s]


Epoch 25/30 - Midpoint, Average Loss: 3.4794



Epoch 25/30: 100%|██████████| 1875/1875 [04:44<00:00,  6.59it/s]



Epoch 25/30 completed, Final Average Loss: 3.4705

Saving model to arabic_gpt_tpu_model_epoch_25...
Model saved successfully!


Epoch 26/30:  50%|████▉     | 937/1875 [02:07<02:03,  7.58it/s]


Epoch 26/30 - Midpoint, Average Loss: 3.4818



Epoch 26/30: 100%|██████████| 1875/1875 [04:16<00:00,  7.31it/s]



Epoch 26/30 completed, Final Average Loss: 3.4710



Epoch 27/30:  50%|████▉     | 937/1875 [02:08<02:09,  7.26it/s]


Epoch 27/30 - Midpoint, Average Loss: 3.4802



Epoch 27/30: 100%|██████████| 1875/1875 [04:17<00:00,  7.29it/s]



Epoch 27/30 completed, Final Average Loss: 3.4711



Epoch 28/30:  50%|████▉     | 937/1875 [02:07<02:03,  7.59it/s]


Epoch 28/30 - Midpoint, Average Loss: 3.4610



Epoch 28/30: 100%|██████████| 1875/1875 [04:09<00:00,  7.52it/s]



Epoch 28/30 completed, Final Average Loss: 3.4712



Epoch 29/30:  50%|████▉     | 936/1875 [01:59<02:07,  7.39it/s]


Epoch 29/30 - Midpoint, Average Loss: 3.4726



Epoch 29/30: 100%|██████████| 1875/1875 [03:59<00:00,  7.83it/s]



Epoch 29/30 completed, Final Average Loss: 3.4697



Epoch 30/30:  50%|████▉     | 936/1875 [01:59<02:06,  7.42it/s]


Epoch 30/30 - Midpoint, Average Loss: 3.4722



Epoch 30/30: 100%|██████████| 1875/1875 [03:59<00:00,  7.84it/s]



Epoch 30/30 completed, Final Average Loss: 3.4708

Saving model to arabic_gpt_tpu_model_epoch_30...
Model saved successfully!

Testing predictions with the trained model:

Input sentence: مرحبا كيف حالك
Predictions:
 يا: 0.8551
�: 0.0441
 في: 0.0376
 إذا: 0.0344
 أيها: 0.0288

Input sentence: انا اريد ان اتعلم
Predictions:
 من: 0.2998
 الفرنسية: 0.2593
 شيئا: 0.2137
 كيف: 0.1600
 أن: 0.0671

Input sentence: شكرا جزيلا على
Predictions:
 هذا: 0.2444
 هذه: 0.2081
 حسن: 0.2038
 جهود: 0.1719
 ما: 0.1718

Input sentence: السلام عليكم ورحمة
Predictions:
 الله: 0.9915
 الباري: 0.0028
 الرحمن: 0.0026
 ربي: 0.0018
 لله: 0.0013

Input sentence: اهلا وسهلا بكم في
Predictions:
 موقع: 0.3632
 هذا: 0.2535
 هذه: 0.1682
 منتدى: 0.1108
 حلقة: 0.1042

Input sentence: انا احب القراءة
Predictions:
 كثيرا: 0.3133
 والكتابة: 0.2417
�: 0.2235
 في: 0.1192
 لكن: 0.1023

Input sentence: هذا الكتاب جميل
Predictions:
 جدا: 0.4483
 في: 0.1922
 ومفيد: 0.1303
 للقارئ: 0.1189
 لكنه: 0.1103


In [9]:
import shutil

shutil.make_archive(
    '/kaggle/working/test2',  # Output ZIP file path (no .zip extension)
    'zip',                                            # Archive format
    root_dir='/kaggle/working',                       # Base directory
    base_dir='arabic_gpt_tpu_model_epoch_30'          # Folder to zip (relative to root_dir)
)


'/kaggle/working/test2.zip'