# Arabic Small Language Model (SLM) on TinyStories-like Dataset
This notebook adapts the small GPT-style language-model setup used for the English TinyStories dataset so that it trains on an Arabic stories dataset.

In [None]:
!pip install -U datasets tiktoken torch tqdm

# Loading the Arabic Dataset
We will use `arbml/Arabic_Stories_Corpus` or a similar Arabic stories dataset from Hugging Face.

In [None]:
from datasets import load_dataset, DatasetDict

# Download the Arabic stories corpus and normalise it into a DatasetDict with
# exactly two splits: "train" and "validation".
raw = load_dataset("arbml/Arabic_Stories_Corpus")
if not isinstance(raw, DatasetDict):
    # A bare Dataset came back; treat it as the training split.
    raw = DatasetDict({"train": raw})

# Reuse a held-out split shipped with the dataset (first match wins).
held_out_name = next((name for name in ("validation", "val", "test") if name in raw), None)
if held_out_name is not None:
    dataset = DatasetDict({"train": raw["train"], "validation": raw[held_out_name]})
else:
    # No held-out split available: carve 10% off train, deterministically.
    split_ds = raw["train"].train_test_split(test_size=0.1, seed=42, shuffle=True)
    dataset = DatasetDict({"train": split_ds["train"], "validation": split_ds["test"]})

print(dataset)


In [None]:
# Make sure the story text lives in a column called 'text', which is what
# `processing` reads. Operates on `dataset` built in the loading cell.
train_columns = dataset["train"].column_names
if "text" not in train_columns:
    # Prefer the first string-typed column; fall back to the first column.
    string_columns = [
        name
        for name, feature in dataset["train"].features.items()
        if getattr(feature, "dtype", None) in ("string", "large_string")
    ]
    source_column = string_columns[0] if string_columns else train_columns[0]
    print(f"Renaming column '{source_column}' to 'text'")
    dataset = dataset.rename_column(source_column, "text")
dataset

In [None]:
# Rows x columns for every split.
dataset.shape

In [None]:
type(dataset)

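## Step 1: Tokenization
We use the `cl100k_base` encoding (the GPT-4 tokenizer), which supports multilingual text, including Arabic, better than the `gpt2` encoding does. Its vocabulary has about 100k tokens, which is more than fits in 16 bits, so token ids are stored as `uint32`.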

In [None]:
from tqdm.auto import tqdm
import tiktoken
import os
import numpy as np

In [None]:
# cl100k_base (the GPT-4 tokenizer) is used for its better Arabic coverage than gpt2.
TOKENIZER_NAME = "cl100k_base"
encoding = tiktoken.get_encoding(TOKENIZER_NAME)

In [None]:
def processing(sample_text):
    """Tokenise a single dataset row for `Dataset.map`.

    Args:
        sample_text: a row mapping that contains a 'text' field.

    Returns:
        dict with 'ids' (token ids from `encoding.encode_ordinary`) and
        'len' (the number of ids), used later to size the output memmap.
    """
    token_ids = encoding.encode_ordinary(sample_text['text'])
    return dict(ids=token_ids, len=len(token_ids))

In [None]:
# Tokenise every split once and write the ids into flat `<split>.bin` files.
# All original columns are dropped: `processing` reads the text before the
# columns are removed, and only its `ids` / `len` outputs are needed afterwards.
cols_to_remove = dataset['train'].column_names
token_dtype = np.uint32  # cl100k_base ids exceed the uint16 range
if not all(os.path.exists(f'{split}.bin') for split in dataset):
    tokenized = dataset.map(
        processing,
        remove_columns=cols_to_remove,
        desc="tokenizing the splits",
        num_proc=4,
    )
    for split, dset in tokenized.items():
        arr_len = np.sum(dset['len'], dtype=np.uint64)
        filename = f'{split}.bin'
        arr = np.memmap(filename, dtype=token_dtype, mode='w+', shape=(arr_len,))
        # Write in contiguous shards so the whole split is never materialised at once.
        total_batches = 1 if len(dset) < 1024 else 1024
        idx = 0
        for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'):
            batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy')
            arr_batch = np.concatenate(batch['ids'])
            arr[idx : idx + len(arr_batch)] = arr_batch
            idx += len(arr_batch)
        arr.flush()


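## Step 2: Creating input–output pairs
Each training example is a random window of `block_size` consecutive tokens (the input `x`). The target `y` is the same window shifted one token to the right, so at every position the model learns to predict the next token.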

In [None]:
def get_batch(split):
    """Sample a random batch of (input, target) token windows.

    Args:
        split: 'train' reads train.bin; any other value reads validation.bin,
            falling back to train.bin when no validation file exists.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size) on `device`;
        y is x shifted one token to the right.

    Raises:
        ValueError: if the chosen file holds no more than `block_size` tokens.

    Relies on the globals `batch_size`, `block_size`, `device` and
    `device_type` from the training-configuration cell.
    """
    if split == 'train':
        filename = 'train.bin'
    else:
        # Fallback if validation.bin doesn't exist (some datasets only have train)
        filename = 'validation.bin' if os.path.exists('validation.bin') else 'train.bin'
    # The memmap is reopened on every call; nothing is cached between batches.
    data = np.memmap(filename, dtype=np.uint32, mode='r')
    if len(data) <= block_size:
        raise ValueError(
            f"{filename} has only {len(data)} tokens; need more than block_size={block_size}"
        )
    # Sample offsets explicitly on the CPU: with torch.set_default_device(device)
    # they would otherwise land on the GPU and every numpy slice below would
    # force a device->host synchronisation.
    offsets = torch.randint(len(data) - block_size, (batch_size,), device='cpu').tolist()
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in offsets])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in offsets])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
import numpy as np
from tqdm.auto import tqdm
from contextlib import nullcontext
import os

In [None]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias.

    Args:
        ndim: size of the normalised (last) dimension.
        bias: when falsy, no bias parameter is created and only the learnable
            scale is applied.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = None if not bias else nn.Parameter(torch.zeros(ndim))

    def forward(self, x):
        # eps is fixed at 1e-5.
        normalized_shape = self.weight.shape
        return F.layer_norm(x, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)

In [None]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    `config` must provide n_embd, n_head (dividing n_embd), bias, dropout and
    block_size. When `F.scaled_dot_product_attention` exists it is used with
    `is_causal=True`; otherwise attention is computed explicitly with a
    lower-triangular mask stored in the buffer "bias".
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # One projection produces queries, keys and values side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.flash = hasattr(F, 'scaled_dot_product_attention')
        if not self.flash:
            # Mask of shape (1, 1, block_size, block_size): 1 = may attend.
            n = config.block_size
            causal_mask = torch.tril(torch.ones(n, n)).view(1, 1, n, n)
            self.register_buffer("bias", causal_mask)

    def forward(self, x):
        batch, seq_len, width = x.size()
        head_dim = width // self.n_head

        def to_heads(t):
            # (batch, seq, width) -> (batch, n_head, seq, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = (to_heads(t) for t in self.c_attn(x).split(self.n_embd, dim=2))

        if not self.flash:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            visible = self.bias[:, :, :seq_len, :seq_len]
            att = F.softmax(att.masked_fill(visible == 0, float('-inf')), dim=-1)
            y = self.attn_dropout(att) @ v
        else:
            drop_p = self.attn_dropout.p if self.training else 0.0
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=drop_p, is_causal=True)

        merged = y.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.c_proj(merged))

In [None]:
class MLP(nn.Module):
    """Position-wise feed-forward layer: expand 4x, GELU, project back, dropout."""

    def __init__(self, config):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        h = self.c_fc(x)
        h = self.gelu(h)
        h = self.c_proj(h)
        return self.dropout(h)

In [None]:
class Block(nn.Module):
    """Pre-norm transformer block.

    Computes h = x + attn(ln1(x)) and returns h + mlp(ln2(h)).
    """

    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embd, config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embd, config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))

In [None]:
@dataclass
class GPTConfig:
    """Hyper-parameters of the GPT model.

    Fields:
        block_size: maximum context length; sizes the position embedding and
            the attention mask.
        vocab_size: number of token ids; sizes the token embedding and LM head.
        n_layer: number of transformer blocks.
        n_head: attention heads per block (must divide n_embd).
        n_embd: width of the embeddings / residual stream.
        dropout: dropout probability used throughout the model.
        bias: whether Linear and LayerNorm layers carry bias terms.
    """
    block_size: int
    vocab_size: int
    n_layer: int
    n_head: int
    n_embd: int
    dropout: float = 0.0
    bias: bool = True

class GPT(nn.Module):
    """Decoder-only transformer language model.

    Token + learned position embeddings feed `config.n_layer` pre-norm `Block`s,
    followed by a final LayerNorm and a linear LM head whose weight is shared
    with the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),
            wpe=nn.Embedding(config.block_size, config.n_embd),
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=LayerNorm(config.n_embd, config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.transformer.wte.weight = self.lm_head.weight  # weight tying
        self.apply(self._init_weights)
        # Residual output projections (every '...c_proj.weight') get a smaller
        # std, divided by sqrt(2 * n_layer) -- presumably to keep the residual
        # stream's variance in check as depth grows.
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        """N(0, 0.02) init for Linear / Embedding weights; zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on a (B, T) batch of token ids, T <= block_size.

        Returns:
            (logits, loss). With `targets` (same shape as idx): logits for every
            position, shape (B, T, vocab_size), and the mean cross-entropy loss,
            where target positions equal to -1 are ignored. Without targets:
            logits for the last position only, shape (B, 1, vocab_size), and None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size
        pos = torch.arange(0, t, dtype=torch.long, device=device)
        tok_emb = self.transformer.wte(idx)
        pos_emb = self.transformer.wpe(pos)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            return logits, loss
        else:
            # Inference: only the last position is needed for next-token prediction.
            logits = self.lm_head(x[:, [-1], :])
            return logits, None

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively extend `idx` by `max_new_tokens` tokens.

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: number of tokens to append.
            temperature: softmax temperature; values <= 0 select the most
                likely token (greedy argmax) instead of dividing by zero.
            top_k: if given, sample only among the k most likely tokens.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.

        Note: call `model.eval()` first, otherwise dropout stays active while
        sampling.
        """
        for _ in range(max_new_tokens):
            # Crop the context to the last block_size tokens (position-embedding limit).
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            if temperature <= 0:
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [None]:
# Configuration for Arabic Model
config = GPTConfig(
    vocab_size=encoding.n_vocab,  # cl100k_base: 100277 token ids
    block_size=128,
    n_layer=6,
    n_head=6,
    n_embd=384,
    dropout=0.1,
    bias=True
)
# Seed before building the model so the random weight initialisation is
# reproducible; the global seed in the training-config cell runs too late for it.
torch.manual_seed(42)
model = GPT(config)

In [None]:
eval_iters = 500  # batches per loss estimate; also the evaluation interval of the training loop
batch_size = 32
block_size = 128  # must equal config.block_size (size of the position-embedding table)
gradient_accumulation_steps = 32
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = 'cuda' if 'cuda' in device else 'cpu'
# Mixed precision only applies on CUDA: prefer bfloat16 and fall back to float16
# (which needs GradScaler). On CPU everything runs in float32; reporting that
# also keeps GradScaler disabled there.
if device_type == 'cuda':
    dtype = 'bfloat16' if torch.cuda.is_bf16_supported() else 'float16'
else:
    dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
torch.set_default_device(device)
torch.manual_seed(42)

def estimate_loss(model):
    """Average the loss of `model` over `eval_iters` random batches per split.

    The 'train' split is always evaluated; 'validation' only when
    validation.bin exists. The model is put in eval mode for the measurement
    and switched to train mode afterwards.

    Returns:
        dict mapping split name -> mean loss (Python float).
    """
    results = {}
    model.eval()
    with torch.inference_mode():
        for split in ('train', 'validation'):
            missing_val_file = split == 'validation' and not os.path.exists('validation.bin')
            if missing_val_file:
                continue
            batch_losses = torch.zeros(eval_iters)
            for step in range(eval_iters):
                inputs, targets = get_batch(split)
                with ctx:
                    _, loss = model(inputs, targets)
                batch_losses[step] = loss.item()
            results[split] = batch_losses.mean().item()
    model.train()
    return results


def loss_to_perplexity(loss_value):
    """Convert a mean cross-entropy loss (natural log) into perplexity, exp(loss).

    Returns inf for non-finite losses and when exp(loss) would overflow a
    float (loss above roughly 709), instead of raising OverflowError.
    """
    if not math.isfinite(loss_value):
        return float('inf')
    try:
        return math.exp(loss_value)
    except OverflowError:
        return float('inf')


In [None]:
from torch.optim.lr_scheduler import LinearLR, SequentialLR, CosineAnnealingLR
learning_rate = 1e-4
max_iters = 10000
warmup_steps = 1000
# Floor of the cosine decay. It must be below learning_rate; the previous 5e-4
# made the "decay" phase raise the learning rate instead of lowering it.
min_lr = 5e-5
assert min_lr < learning_rate, "min_lr must be smaller than learning_rate"

# Apply weight decay only to matrices (Linear / Embedding weights); biases and
# LayerNorm parameters (1-D) are left undecayed.
decay_params = [p for p in model.parameters() if p.dim() >= 2]
no_decay_params = [p for p in model.parameters() if p.dim() < 2]
optimizer = torch.optim.AdamW(
    [
        {"params": decay_params, "weight_decay": 0.1},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=learning_rate,
    betas=(0.9, 0.95),
    eps=1e-9,
)
# The training loop calls scheduler.step() once per loop iteration (micro-batch),
# so warmup_steps and max_iters are both counted in micro-batches.
scheduler_warmup = LinearLR(optimizer, total_iters=warmup_steps)
scheduler_decay = CosineAnnealingLR(optimizer, T_max=max_iters - warmup_steps, eta_min=min_lr)
scheduler = SequentialLR(optimizer, schedulers=[scheduler_warmup, scheduler_decay], milestones=[warmup_steps])
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16'))

In [None]:
best_val_loss = float('inf')
best_model_params_path = "best_model_params.pt"
train_loss_list, validation_loss_list = [], []
train_ppl_list, validation_ppl_list = [], []
model = model.to(device)

# `epoch` counts micro-batches; the optimizer steps once every
# `gradient_accumulation_steps` of them.
for epoch in tqdm(range(max_iters)):
    if epoch % eval_iters == 0 and epoch != 0:
        losses = estimate_loss(model)
        train_loss = losses.get('train')
        val_loss = losses.get('validation')
        train_ppl = loss_to_perplexity(train_loss) if train_loss is not None else float('inf')
        val_ppl = loss_to_perplexity(val_loss) if val_loss is not None else float('inf')
        msg = f"Epoch {epoch}: train loss {train_loss:.4f} (ppl {train_ppl:.2f})"
        if val_loss is not None:
            msg += f", val loss {val_loss:.4f} (ppl {val_ppl:.2f})"
            validation_loss_list.append(val_loss)
            validation_ppl_list.append(val_ppl)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), best_model_params_path)
        print(msg)
        train_loss_list.append(train_loss)
        train_ppl_list.append(train_ppl)

    X, y = get_batch("train")  # get_batch already returns tensors on `device`
    with ctx:
        logits, loss = model(X, y)
        loss = loss / gradient_accumulation_steps
    # Backward runs outside autocast, as the PyTorch AMP guidance recommends.
    scaler.scale(loss).backward()

    if ((epoch + 1) % gradient_accumulation_steps == 0) or (epoch + 1 == max_iters):
        # Unscale first so clipping sees the true gradient norm: with float16 the
        # gradients are still multiplied by the loss scale at this point.
        # (A no-op when the scaler is disabled.)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
    scheduler.step()


In [None]:
# Example generation in Arabic
# Training (via estimate_loss) leaves the model in train mode; switch to eval so
# dropout is disabled while sampling.
model.eval()
sentence = "كان يا ما كان، كان هناك فتاة صغيرة."
context = torch.tensor(encoding.encode_ordinary(sentence), device=device).unsqueeze(dim=0)
y = model.generate(context, 200)
print(encoding.decode(y[0].tolist()))

In [None]:
# Final evaluation with perplexity: reload the best checkpoint when one was
# saved, then report loss and perplexity for every evaluated split.
if os.path.exists(best_model_params_path):
    best_state = torch.load(best_model_params_path, map_location=device)
    model.load_state_dict(best_state)
losses = estimate_loss(model)
for split_name, split_loss in losses.items():
    print(f"{split_name} loss: {split_loss:.4f}, ppl: {loss_to_perplexity(split_loss):.2f}")


In [None]:
# estimate_loss (previous cell) switches the model back to train mode;
# disable dropout again before sampling.
model.eval()
sentence = "ذهبت فتاة صغيرة إلى الغابة"
context = torch.tensor(encoding.encode_ordinary(sentence), device=device).unsqueeze(dim=0)
y = model.generate(context, 200)
print(encoding.decode(y[0].tolist()))