In [1]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
from typing import List, Tuple
import os
from tqdm import tqdm
import re
import json
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

class ArabicGPTTPU:
    def __init__(self, model_name: str = "aubmindlab/aragpt2-base"):
        """
        Initialize the Arabic GPT model with TPU support
        Args:
            model_name: Name of the pre-trained model to use
        """
        # Initialize TPU device
        self.device = xm.xla_device()
        print(f"Using TPU device: {self.device}")

        # Initialize tokenizer and model
        print("Loading tokenizer and model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name)

        # Add special tokens for Arabic
        special_tokens = {
            'pad_token': '[PAD]',
            'bos_token': '[BOS]',
            'eos_token': '[EOS]',
            'unk_token': '[UNK]'
        }
        self.tokenizer.add_special_tokens(special_tokens)
        self.model.resize_token_embeddings(len(self.tokenizer))

        # Move model to TPU
        self.model = self.model.to(self.device)

    def save_model(self, save_dir: str = 'arabic_gpt_tpu_model'):
        """
        Save the model and tokenizer to disk
        Args:
            save_dir: Directory to save the model and tokenizer
        """
        print(f"Saving model to {save_dir}...")

        # Create directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

        # Move model to CPU before saving
        self.model = self.model.to('cpu')

        # Save model
        model_path = os.path.join(save_dir, 'model')
        self.model.save_pretrained(model_path)

        # Save tokenizer
        tokenizer_path = os.path.join(save_dir, 'tokenizer')
        self.tokenizer.save_pretrained(tokenizer_path)

        # Save model configuration
        config = {
            'model_name': self.model.config.model_type,
            'vocab_size': self.model.config.vocab_size,
            'max_position_embeddings': self.model.config.max_position_embeddings,
            'num_attention_heads': self.model.config.num_attention_heads,
            'num_hidden_layers': self.model.config.num_hidden_layers,
        }

        with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=4)

        print("Model saved successfully!")

        # Move model back to TPU
        self.model = self.model.to(self.device)

    @classmethod
    def load_model(cls, load_dir: str = 'arabic_gpt_tpu_model'):
        """
        Load a saved model and tokenizer
        Args:
            load_dir: Directory containing the saved model and tokenizer
        Returns:
            ArabicGPTTPU instance with loaded model
        """
        print(f"Loading model from {load_dir}...")

        # Create instance
        instance = cls()

        # Load model
        model_path = os.path.join(load_dir, 'model')
        instance.model = AutoModelForCausalLM.from_pretrained(model_path)
        instance.model = instance.model.to(instance.device)

        # Load tokenizer
        tokenizer_path = os.path.join(load_dir, 'tokenizer')
        instance.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

        print("Model loaded successfully!")
        return instance

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize Arabic text
        """
        # Remove URLs
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        # Remove special characters and numbers
        text = re.sub(r'[^\u0600-\u06FF\s]', '', text)
        # Remove extra spaces
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def preprocess_data(self, file_path: str = '/kaggle/input/arabic-news/Arabic_news.csv') -> List[str]:
        """
        Preprocess the Arabic news dataset
        Args:
            file_path: Path to the dataset file
        Returns:
            List of preprocessed sentences
        """
        print(f"Loading dataset from {file_path}")
        df = pd.read_csv(file_path)

        # Clean and preprocess the text
        print("Preprocessing text data...")
        df=df[:30000]
        df['text'] = df['text'].fillna('')
        df['text'] = df['text'].apply(self.clean_text)
        df = df[df['text'].str.len() > 0]

        # Split into sentences and create training pairs
        sentences = []
        for text in tqdm(df['text'], desc="Processing sentences"):
            # Split by common Arabic sentence endings
            text_sentences = [s.strip() for s in text.split('۔') if len(s.strip()) > 0]
            # Add BOS and EOS tokens
            text_sentences = [f"[BOS] {s} [EOS]" for s in text_sentences]
            sentences.extend(text_sentences)

        print(f"Total number of sentences: {len(sentences)}")
        return sentences

    def prepare_batch(self, sentences: List[str], batch_size: int) -> torch.Tensor:
        """
        Prepare a batch of sentences for training
        """
        encodings = self.tokenizer(sentences,
                                 padding=True,
                                 truncation=True,
                                 max_length=128,  # Reduced from 512 to save memory
                                 return_tensors="pt")
        return encodings

    def train(self, sentences: List[str], epochs: int = 3, batch_size: int = 8, gradient_accumulation_steps: int = 4):
        """
        Fine-tune the model on the Arabic news dataset using TPU with memory-efficient training
        Args:
            sentences: List of training sentences
            epochs: Number of training epochs
            batch_size: Batch size for training (reduced for TPU memory constraints)
            gradient_accumulation_steps: Number of steps to accumulate gradients
        """
        self.model.train()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=5e-5)

        # Create dataset
        encodings = self.prepare_batch(sentences, batch_size)
        dataset = torch.utils.data.TensorDataset(
            encodings["input_ids"],
            encodings["attention_mask"]
        )

        # Create TPU-optimized data loader with smaller batch size
        train_loader = pl.ParallelLoader(
            torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True
            ),
            [self.device]
        ).per_device_loader(self.device)

        # Calculate effective batch size
        effective_batch_size = batch_size * gradient_accumulation_steps

        for epoch in range(epochs):
            total_loss = 0
            optimizer.zero_grad()
            running_loss = 0
            num_batches = 0

            # Reset the data loader for each epoch
            train_loader = pl.ParallelLoader(
                torch.utils.data.DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=True
                ),
                [self.device]
            ).per_device_loader(self.device)

            # Calculate total number of batches for this epoch
            total_batches = len(train_loader)
            mid_point = total_batches // 2

            for batch_idx, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")):
                input_ids, attention_mask = batch

                # Forward pass
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=input_ids
                )
                loss = outputs.loss / gradient_accumulation_steps  # Normalize loss

                # Backward pass
                loss.backward()

                # Update weights if we've accumulated enough gradients
                if (batch_idx + 1) % gradient_accumulation_steps == 0:
                    # TPU optimizer step
                    xm.optimizer_step(optimizer)
                    optimizer.zero_grad()

                    # Mark step for TPU
                    xm.mark_step()

                # Update running loss
                running_loss += loss.item() * gradient_accumulation_steps
                num_batches += 1

                # Print progress at middle of epoch
                if batch_idx == mid_point:
                    avg_loss = running_loss / num_batches
                    xm.master_print(f"\nEpoch {epoch + 1}/{epochs} - Midpoint, Average Loss: {avg_loss:.4f}\n")

            # Calculate and print final epoch average loss
            epoch_avg_loss = running_loss / num_batches
            xm.master_print(f"\nEpoch {epoch + 1}/{epochs} completed, Final Average Loss: {epoch_avg_loss:.4f}\n")

            # Save checkpoint after each epoch
            self.save_model(f'arabic_gpt_tpu_model_epoch_{epoch + 1}')

    def predict_next_word(self, sentence: str, num_predictions: int = 5) -> List[Tuple[str, float]]:
        """
        Predict the next word given a sentence
        Args:
            sentence: Input sentence
            num_predictions: Number of predictions to return
        Returns:
            List of tuples containing (predicted_word, probability)
        """
        # Set model to eval mode once
        self.model.eval()

        with torch.no_grad():
            # Add BOS token to input
            input_text = f"[BOS] {sentence}"
            inputs = self.tokenizer(input_text, return_tensors="pt")
            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)

            # Get logits for the last position only
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=False,
                output_attentions=False
            )
            last_token_logits = outputs.logits[0, -1, :]

            # Get top k predictions
            top_k_logits, top_k_indices = torch.topk(last_token_logits, num_predictions)
            probs = torch.softmax(top_k_logits, dim=-1)

            # Convert to list of predictions
            predictions = []
            for idx, (token_id, prob) in enumerate(zip(top_k_indices, probs)):
                word = self.tokenizer.decode([token_id])
                predictions.append((word, prob.item()))

            return predictions

def main():
    # Initialize the model
    print("Initializing model...")
    model = ArabicGPTTPU()

    # Load and preprocess data
    print("Loading and preprocessing data...")
    sentences = model.preprocess_data()

    # Train the model with memory-efficient settings
    print("Training the model...")
    model.train(
        sentences,
        epochs=10,
        batch_size=8,  # Reduced batch size
        gradient_accumulation_steps=4  # Added gradient accumulation
    )

    # Example usage with the trained model
    print("\nTesting predictions with the trained model:")
    test_sentences = [
        "مرحبا كيف حالك",
        "انا اريد ان اتعلم",
        "شكرا جزيلا على",
        "السلام عليكم ورحمة",
        "اهلا وسهلا بكم في",
        "انا احب القراءة",
        "هذا الكتاب جميل"
    ]

    for test_sentence in test_sentences:
        predictions = model.predict_next_word(test_sentence)
        print(f"\nInput sentence: {test_sentence}")
        print("Predictions:")
        for word, prob in predictions:
            print(f"{word}: {prob:.4f}")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Initializing model...


E0000 00:00:1746982925.701381     299 common_lib.cc:621] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:232


Using TPU device: xla:0
Loading tokenizer and model...


E0000 00:00:1746982939.930728     299 common_lib.cc:612] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/common_lib.cc:230
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Loading and preprocessing data...
Loading dataset from /kaggle/input/arabic-news/Arabic_news.csv
Preprocessing text data...


Processing sentences: 100%|██████████| 30000/30000 [00:00<00:00, 265329.61it/s]


Total number of sentences: 30000
Training the model...


Epoch 1/10:   0%|          | 0/3750 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Epoch 1/10:  50%|█████     | 1875/3750 [06:27<04:51,  6.43it/s]


Epoch 1/10 - Midpoint, Average Loss: 5.0612



Epoch 1/10: 100%|██████████| 3750/3750 [11:51<00:00,  5.27it/s]



Epoch 1/10 completed, Final Average Loss: 4.6538

Saving model to arabic_gpt_tpu_model_epoch_1...
Model saved successfully!


Epoch 2/10:  50%|█████     | 1875/3750 [04:09<03:45,  8.33it/s]


Epoch 2/10 - Midpoint, Average Loss: 4.0368



Epoch 2/10: 100%|██████████| 3750/3750 [08:02<00:00,  7.77it/s]



Epoch 2/10 completed, Final Average Loss: 4.0393

Saving model to arabic_gpt_tpu_model_epoch_2...
Model saved successfully!


Epoch 3/10:  50%|█████     | 1875/3750 [03:53<03:40,  8.49it/s]


Epoch 3/10 - Midpoint, Average Loss: 4.0512



Epoch 3/10: 100%|██████████| 3750/3750 [07:46<00:00,  8.04it/s]



Epoch 3/10 completed, Final Average Loss: 4.0400

Saving model to arabic_gpt_tpu_model_epoch_3...
Model saved successfully!


Epoch 4/10:  50%|█████     | 1875/3750 [03:54<03:52,  8.08it/s]


Epoch 4/10 - Midpoint, Average Loss: 4.0423



Epoch 4/10: 100%|██████████| 3750/3750 [08:03<00:00,  7.75it/s]



Epoch 4/10 completed, Final Average Loss: 4.0396

Saving model to arabic_gpt_tpu_model_epoch_4...
Model saved successfully!


Epoch 5/10:  50%|█████     | 1875/3750 [04:33<04:37,  6.77it/s]


Epoch 5/10 - Midpoint, Average Loss: 4.0357



Epoch 5/10: 100%|██████████| 3750/3750 [09:14<00:00,  6.77it/s]



Epoch 5/10 completed, Final Average Loss: 4.0392

Saving model to arabic_gpt_tpu_model_epoch_5...
Model saved successfully!


Epoch 6/10:  50%|█████     | 1875/3750 [04:28<03:49,  8.17it/s]


Epoch 6/10 - Midpoint, Average Loss: 4.0402



Epoch 6/10: 100%|██████████| 3750/3750 [08:19<00:00,  7.51it/s]



Epoch 6/10 completed, Final Average Loss: 4.0398

Saving model to arabic_gpt_tpu_model_epoch_6...
Model saved successfully!


Epoch 7/10:  50%|█████     | 1875/3750 [03:56<03:50,  8.13it/s]


Epoch 7/10 - Midpoint, Average Loss: 4.0370



Epoch 7/10: 100%|██████████| 3750/3750 [07:48<00:00,  8.00it/s]



Epoch 7/10 completed, Final Average Loss: 4.0395

Saving model to arabic_gpt_tpu_model_epoch_7...
Model saved successfully!


Epoch 8/10:  50%|█████     | 1875/3750 [03:53<03:42,  8.41it/s]


Epoch 8/10 - Midpoint, Average Loss: 4.0425



Epoch 8/10: 100%|██████████| 3750/3750 [07:46<00:00,  8.03it/s]



Epoch 8/10 completed, Final Average Loss: 4.0399

Saving model to arabic_gpt_tpu_model_epoch_8...
Model saved successfully!


Epoch 9/10:  50%|█████     | 1875/3750 [03:51<03:43,  8.40it/s]


Epoch 9/10 - Midpoint, Average Loss: 4.0356



Epoch 9/10: 100%|██████████| 3750/3750 [07:43<00:00,  8.09it/s]



Epoch 9/10 completed, Final Average Loss: 4.0399

Saving model to arabic_gpt_tpu_model_epoch_9...
Model saved successfully!


Epoch 10/10:  50%|█████     | 1875/3750 [03:52<03:40,  8.51it/s]


Epoch 10/10 - Midpoint, Average Loss: 4.0407



Epoch 10/10: 100%|██████████| 3750/3750 [07:44<00:00,  8.08it/s]



Epoch 10/10 completed, Final Average Loss: 4.0397

Saving model to arabic_gpt_tpu_model_epoch_10...
Model saved successfully!

Testing predictions with the trained model:

Input sentence: مرحبا كيف حالك
Predictions:
 يا: 0.7270
 في: 0.0914
 إذا: 0.0724
 ؟: 0.0597
 مع: 0.0495

Input sentence: انا اريد ان اتعلم
Predictions:
 من: 0.3163
 شيئا: 0.2006
 كيف: 0.1968
 الفرنسية: 0.1552
 اللغة: 0.1312

Input sentence: شكرا جزيلا على
Predictions:
 هذا: 0.3313
 جهود: 0.2204
 هذه: 0.2106
 ما: 0.1247
 كل: 0.1130

Input sentence: السلام عليكم ورحمة
Predictions:
 الله: 0.9881
 الرحمن: 0.0054
 الباري: 0.0026
 لله: 0.0023
 ربي: 0.0017

Input sentence: اهلا وسهلا بكم في
Predictions:
 هذا: 0.3324
 موقع: 0.2732
 منتدى: 0.1655
 هذه: 0.1581
 منتديات: 0.0707

Input sentence: انا احب القراءة
Predictions:
 كثيرا: 0.3755
�: 0.2477
 في: 0.1619
 والكتابة: 0.1165
 منذ: 0.0984

Input sentence: هذا الكتاب جميل
Predictions:
 جدا: 0.5786
 في: 0.2053
 من: 0.1052
 للغاية: 0.0615
 ومفيد: 0.0494
