In [1]:
from datasets import load_dataset, get_dataset_split_names


In [2]:
def load_huggingface_dataset(dataset_name, *args, **kwargs):
    """Load a Hugging Face dataset by name.

    Args:
        dataset_name: hub id or local path passed as the first argument.
        *args: extra positional arguments (e.g. a config name), forwarded as-is.
        **kwargs: keyword arguments such as ``split="train"``, forwarded as-is.

    Returns:
        Whatever ``load_dataset`` returns for these arguments.
    """
    return load_dataset(dataset_name, *args, **kwargs)

In [3]:
# Only the hub's "train" split is loaded; it is re-split locally below.
dataset = load_huggingface_dataset(
    "mwitiderrick/swahili",
    split="train",
)


In [4]:
def generate_dataset_splits(dataset, holdout_size=0.1, seed=42):
    """Split `dataset` into train / test / validation subsets.

    A fraction `holdout_size` of the rows is held out and then halved: one half
    becomes the test split and the other the validation split. Both splits
    shuffle with `seed`, so the partition is reproducible. With the defaults,
    the ratio is 90% / 5% / 5%.

    Args:
        dataset: object exposing ``train_test_split`` (e.g. ``datasets.Dataset``).
        holdout_size: fraction of rows removed from training for test + val.
        seed: shuffle seed used by both splits.

    Returns:
        Tuple ``(train_dataset, test_dataset, val_dataset)``.
    """
    first_split = dataset.train_test_split(test_size=holdout_size, shuffle=True, seed=seed)
    holdout_split = first_split["test"].train_test_split(test_size=0.5, shuffle=True, seed=seed)
    return first_split["train"], holdout_split["test"], holdout_split["train"]

In [5]:
# 90% train / 5% test / 5% validation, reproducible via the fixed seed.
(train_dataset, test_dataset, val_dataset) = generate_dataset_splits(dataset)

In [6]:
import re
def remove_non_text_symbols(text):
    """Strip every character outside 7-bit ASCII (emojis, accented letters, symbols).

    Non-ASCII whitespace, such as a non-breaking space (U+00A0), is first
    converted to a plain space. This keeps the words it separated from being
    glued into one bogus word. Every other run of non-ASCII characters is
    deleted. HTML tags are not removed.

    Args:
        text: input string.

    Returns:
        The cleaned string (possibly empty).
    """
    # For str patterns `\s` matches Unicode whitespace. Excluding \x00-\x7F
    # leaves ordinary ASCII whitespace (space, tab, newline) untouched.
    text = re.sub(r'[^\S\x00-\x7F]', ' ', text)
    return re.sub(r'[^\x00-\x7F]+', '', text)

In [7]:
def clean_dataset(dataset):
    """Run `remove_non_text_symbols` over every "text" field, then drop rows left empty.

    Args:
        dataset: object exposing ``map`` and ``filter`` (e.g. ``datasets.Dataset``).

    Returns:
        The filtered dataset.
    """
    stripped = dataset.map(lambda row: {"text": remove_non_text_symbols(row["text"])})
    return stripped.filter(lambda row: len(row["text"]) > 0)

In [8]:
# NOTE(review): the train split is not cleaned (its line is commented out);
# confirm that is intentional before training on it.
# train_dataset = clean_dataset(train_dataset)
test_dataset, val_dataset = clean_dataset(test_dataset), clean_dataset(val_dataset)

In [9]:
from transformers import PreTrainedTokenizerFast
from tokenizers import Tokenizer

In [10]:
def tokenize(tokenizer, dataset, max_length=256):
    """Tokenize the "text" column of `dataset` row by row.

    Args:
        tokenizer: callable ``tokenizer(text, padding=..., max_length=...)``
            returning a dict of new columns (e.g. ``{"input_ids": [...]}``).
        dataset: object exposing ``map`` (e.g. ``datasets.Dataset``).
        max_length: padded/truncated length of every row. The default of 256
            matches the ``max_seq_len`` used by the models in this notebook.

    Returns:
        The mapped dataset with the tokenizer's output columns added.
    """
    return dataset.map(
        lambda example: tokenizer(example["text"], padding=True, max_length=max_length)
    )


In [11]:
class KiswahiliSilabiTokenizer(PreTrainedTokenizerFast):
    """Syllable-level Kiswahili tokenizer wrapping a ``tokenizers.Tokenizer``.

    Text is split on single spaces. Each word is segmented greedily into the
    longest pieces present in the vocabulary; characters that start no
    vocabulary entry become ``unk_token``. Word boundaries are marked with
    ``space_token`` (unless " " itself is in the vocabulary). Every sequence is
    wrapped in ``sos_token`` ... ``eos_token``.

    NOTE(review): missing special tokens are added only to the local
    ``self._vocab`` dict. ``__call__`` maps tokens to ids via the inherited
    ``convert_tokens_to_ids``, which presumably resolves them through the
    wrapped backend tokenizer rather than ``self._vocab``. Confirm that
    tokenizer.json already defines these tokens.
    """

    def __init__(self, tokenizer, unk_token="[UNK]", sos_token="[SOS]", eos_token="[EOS]", space_token="[SPACE]", pad_token="[PAD]", **kwargs):
        super().__init__(tokenizer_object=tokenizer, **kwargs)
        self._vocab = tokenizer.get_vocab()
        self.unk_token = unk_token
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.space_token = space_token
        self.pad_token = pad_token

        # Give each missing special token the next free id in the local vocab
        # (order: SOS, EOS, UNK, SPACE, PAD).
        for special in (self.sos_token, self.eos_token, self.unk_token, self.space_token, self.pad_token):
            if special not in self._vocab:
                self._vocab[special] = len(self._vocab)

    def __call__(self, text, **kwargs):
        """Tokenize `text` (see `tokenize` for kwargs) and return ``{"input_ids": [...]}``."""
        ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))

        return {"input_ids": ids}

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the tokenizer from ``<pretrained_model_name_or_path>/tokenizer.json``."""
        tokenizer = Tokenizer.from_file(f"{pretrained_model_name_or_path}/tokenizer.json")
        return cls(tokenizer, **kwargs)

    def _encode_with_byte_fallback(self, text):
        """Greedy longest-match segmentation of `text` against ``self._vocab``.

        A character that starts no vocabulary entry becomes one ``space_token``
        if it is a space, otherwise exactly one ``unk_token``. Despite the
        method name, no byte-level fallback is performed.
        """
        tokens = []
        i = 0
        while i < len(text):
            # Try the longest candidate first, shrinking towards one character.
            for j in range(len(text), i, -1):
                piece = text[i:j]
                if piece in self._vocab:
                    tokens.append(piece)
                    i = j
                    break
            else:
                # No vocabulary entry starts here: emit a single fallback token.
                # (append, not extend: extend would split "[UNK]" into characters.)
                tokens.append(self.space_token if text[i] == " " else self.unk_token)
                i += 1
        return tokens

    def tokenize(self, text, **kwargs):
        """Convert `text` into a list of token strings.

        Keyword args:
            handle_whitespace (bool, default True): emit a space token after
                every word, including the last one.
            padding (bool, default False): if true, truncate to `max_length`
                and right-pad with ``pad_token``. Truncation can cut off the
                trailing ``eos_token``.
            max_length (int): required when `padding` is true.

        Raises:
            ValueError: `padding` is true but `max_length` is missing.
        """
        handle_whitespace = kwargs.get("handle_whitespace", True)
        tokens = [self.sos_token]  # Start of sentence token
        for word in text.split(" "):
            tokens.extend(self._encode_with_byte_fallback(word))
            if handle_whitespace:
                tokens.extend(self._encode_with_byte_fallback(" "))
        tokens.append(self.eos_token)  # End of sentence token

        padding = kwargs.get("padding", False)
        if padding:
            max_length = kwargs.get("max_length", None)
            if max_length is not None:
                tokens = tokens[:max_length]
                tokens.extend([self.pad_token] * (max_length - len(tokens)))
            else:
                raise ValueError("max_length must be specified if padding is True")
        return tokens

    def tokens_to_sentence(self, tokens):
        """Join token strings back into text.

        The SOS, EOS and PAD markers are dropped, and each ``space_token`` becomes " ".
        """
        sentence = "".join(tokens)
        for marker in (self.eos_token, self.sos_token, self.pad_token):
            sentence = sentence.replace(marker, "")
        return sentence.replace(self.space_token, " ")

In [12]:
# Reads ./silabi_tokenizer/tokenizer.json.
silabi_tokenizer = KiswahiliSilabiTokenizer.from_pretrained(
    pretrained_model_name_or_path="./silabi_tokenizer"
)

In [13]:
# Only the test split is tokenized for now; the train and val lines stay commented out.
# train_tokenized_dataset = tokenize(silabi_tokenizer, train_dataset)
test_tokenized_dataset = tokenize(tokenizer=silabi_tokenizer, dataset=test_dataset)
# val_tokenized_dataset = tokenize(silabi_tokenizer, val_dataset)

In [14]:
import torch
from torch import nn
import numpy as np

In [15]:
from transformers.models.mamba2.modeling_mamba2 import Mamba2Block
from linear_attention_transformer import LinearAttentionTransformer
from transformers.models.mamba2 import Mamba2Config, Mamba2ForCausalLM


In [16]:
# Transformations
# Take the first tokenized example. Indexing the row first avoids reading the
# entire "input_ids" column (which can materialize every row) just to get one element.
test_value = test_tokenized_dataset[0]["input_ids"]

In [17]:
# Build a one-example LongTensor batch directly from the token-id list
# (no NumPy round trip); shape (1, sequence_length).
input_tensor = torch.tensor(test_value, dtype=torch.long).unsqueeze(0)


In [18]:
# Hugging Face Mamba2 config for the standalone smoke test below. It is named
# `mamba2_hf_config` rather than `mamba_config` so that re-running this cell
# cannot clobber the `MambaConfig` dataclass later bound to `mamba_config`
# and passed to `Limba`. hidden_size is the model dimension.
mamba2_hf_config = Mamba2Config(vocab_size=silabi_tokenizer.vocab_size, hidden_size=512, num_heads=16)
mamba_block = Mamba2Block(config=mamba2_hf_config, layer_idx=0)

The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


In [19]:
mamba_embedding_dim = 512  # must equal the Mamba2 hidden_size (512)
vocab_size = silabi_tokenizer.vocab_size
mamba_embedding_layer = torch.nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=mamba_embedding_dim,
)
# (batch, sequence_length) -> (batch, sequence_length, mamba_embedding_dim)
mamba_embedded_tensor = mamba_embedding_layer(input_tensor)

In [20]:
mamba_embedded_tensor.size()

torch.Size([1, 256, 512])

In [21]:
# Smoke test: run the standalone Mamba2Block on the embedded example tensor.
mamba_block(mamba_embedded_tensor)

tensor([[[ 3.2783, -1.0339, -0.9427,  ..., -0.5613, -0.8762, -2.3706],
         [-0.6368, -2.1302, -0.2276,  ...,  1.2049, -1.0844, -0.2040],
         [ 1.0469, -2.3309,  1.2093,  ...,  3.6927, -1.4960, -0.9574],
         ...,
         [ 1.2692,  0.0251,  0.2406,  ..., -0.1946,  0.1979,  1.2899],
         [ 1.2692,  0.0251,  0.2406,  ..., -0.1946,  0.1979,  1.2899],
         [ 1.2692,  0.0251,  0.2406,  ..., -0.1946,  0.1979,  1.2899]]],
       grad_fn=<AddBackward0>)

In [22]:
# (batch, sequence_length {256}) -> [embedding] -> (batch, sequence_length {256}, dimension {512}) -> [linformer] -> (batch, sequence_length {256}, dimension {512})

In [23]:
# Standalone linear-attention transformer for a shape smoke test.
# NOTE(review): `causal` is not passed here, so this exploratory layer uses the
# library default (presumably bidirectional); confirm before reusing it for LM training.
linformer = LinearAttentionTransformer(
    dim=512,
    depth=1,
    heads=8,
    max_seq_len=256,
    n_local_attn_heads=4,
)

In [24]:
embedding_dim = 512  # must equal the linear-attention transformer's `dim`
vocab_size = silabi_tokenizer.vocab_size
embedding_layer = torch.nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=embedding_dim,
)
# (batch_size, sequence_length) -> (batch_size, sequence_length, embedding_dim)
embedded_tensor = embedding_layer(input_tensor)


In [25]:
embedded_tensor.size()

torch.Size([1, 256, 512])

In [26]:
# Smoke test: run the standalone linear-attention transformer on the embedded example.
linformer(embedded_tensor)

tensor([[[ 0.1979, -1.5463,  0.2702,  ...,  1.0276, -0.1961,  0.2674],
         [ 1.5884, -0.2209,  0.0208,  ...,  1.3840, -0.6087,  0.4319],
         [ 0.6247,  0.5986, -0.4655,  ...,  1.4478, -0.2209, -0.9150],
         ...,
         [ 1.6710,  0.5188,  2.0017,  ..., -2.3979, -0.8950, -0.6265],
         [ 1.6710,  0.5188,  2.0017,  ..., -2.3979, -0.8950, -0.6265],
         [ 1.6710,  0.5188,  2.0017,  ..., -2.3979, -0.8950, -0.6265]]],
       grad_fn=<AddBackward0>)

In [27]:
from dataclasses import dataclass

@dataclass
class LinformerConfig:
    """Hyper-parameters forwarded to ``LinearAttentionTransformer`` via `to_dict`.

    Fields:
        dim: model width of the attention block.
        heads: number of attention heads.
        depth: number of stacked layers.
        max_seq_len: maximum sequence length.
        n_local_attn_heads: how many of `heads` use local attention.
        causal: mask future positions. Defaults to True because the training
            loop in this notebook predicts token t+1 from position t. Without a
            causal mask, attention could see the very token it must predict.
    """
    dim: int
    heads: int
    depth: int
    max_seq_len: int
    n_local_attn_heads: int
    causal: bool = True

    def to_dict(self):
        """Return the fields as keyword arguments for ``LinearAttentionTransformer``."""
        return {
            "dim": self.dim,
            "heads": self.heads,
            "max_seq_len": self.max_seq_len,
            "n_local_attn_heads": self.n_local_attn_heads,
            "depth": self.depth,
            "causal": self.causal,
        }
@dataclass
class MambaConfig:
    """Minimal Mamba hyper-parameters plus a builder for the HF ``Mamba2Config``.

    Fields:
        hidden_size: model width, also exposed as `dim`.
        num_heads: number of Mamba2 heads.
    """
    hidden_size: int
    num_heads: int

    @property
    def dim(self):
        """Alias for `hidden_size`, so the blocks can read ``config.dim`` uniformly."""
        return self.hidden_size

    def mamba2_config(self, vocab_size):
        """Build a ``Mamba2Config`` carrying `vocab_size` and this config's sizes."""
        return Mamba2Config(
            vocab_size=vocab_size,
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
        )

        


In [103]:
class MambaBlock(nn.Module):
    """RMS-normalised residual wrapper around a single Hugging Face ``Mamba2Block``.

    Args:
        vocab_size: forwarded to ``config.mamba2_config`` to build the HF config.
        config: object exposing ``dim`` and ``mamba2_config(vocab_size)``.
        layer_idx: layer index handed to ``Mamba2Block``.
    """

    def __init__(self,vocab_size, config: MambaConfig,layer_idx=0):
        super().__init__()
        self.vocab_size = vocab_size
        # One RMSNorm instance is applied both before and after the mixer, so
        # the two normalisations share a weight.
        self.normalization = nn.modules.normalization.RMSNorm(config.dim)
        hf_config = config.mamba2_config(self.vocab_size)
        self.mamba2_block = Mamba2Block(config=hf_config, layer_idx=layer_idx)

    def forward(self, x):
        """Return ``norm(norm(x) + mamba2_block(norm(x)))``.

        NOTE(review): the transformers ``Mamba2Block`` appears to add its own
        residual internally, so this outer residual may count the input twice.
        Confirm that this is intended.
        """
        normed = self.normalization(x)
        return self.normalization(normed + self.mamba2_block(normed))
        

In [145]:
class LinformerBlock(nn.Module):
    """RMS-normalised residual wrapper around a ``LinearAttentionTransformer``.

    Args:
        vocab_size: accepted (apparently for symmetry with `MambaBlock`) but unused.
        config: object whose ``dim`` sizes the norm and whose ``to_dict()``
            supplies the transformer's keyword arguments.
    """

    def __init__(self,vocab_size,config:LinformerConfig):
        super().__init__()
        # A single RMSNorm is reused before and after the transformer (shared weight).
        self.normalization = nn.modules.normalization.RMSNorm(config.dim)
        transformer_kwargs = config.to_dict()
        self.linformer = LinearAttentionTransformer(**transformer_kwargs)

    def forward(self, x):
        """Return ``norm(linformer(norm(x)) + norm(x))``."""
        normed = self.normalization(x)
        return self.normalization(self.linformer(normed) + normed)

In [185]:
class LimbaBlock(nn.Module):
    """One Mamba block followed by one linear-attention block.

    The Mamba block works at ``mamba_config.dim``. Its output is projected down
    to ``linformer_config.dim`` for the attention block and then projected back,
    so the block maps ``(..., mamba_config.dim) -> (..., mamba_config.dim)``.
    Dropout is applied after the down-projection and after the attention block.

    Args:
        linformer_config: config for `LinformerBlock` (exposes ``dim``).
        mamba_config: config for `MambaBlock` (exposes ``dim``).
        layer_idx: forwarded to `MambaBlock`.
        vocab_size: forwarded to both sub-blocks.
        dropout: dropout probability.
    """

    def __init__(self,linformer_config,mamba_config,layer_idx=0,vocab_size=1000,dropout = 0.1):
        super().__init__()
        mamba_dim = mamba_config.dim
        attn_dim = linformer_config.dim
        self.mamba_block = MambaBlock(vocab_size, mamba_config, layer_idx=layer_idx)
        self.linformer_block = LinformerBlock(vocab_size, linformer_config)
        self.linformer_mamba_reshape = nn.Linear(attn_dim, mamba_dim)
        self.mamba_linformer_reshape = nn.Linear(mamba_dim, attn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map hidden states ``(batch, seq_len, mamba dim)`` to the same shape."""
        hidden = self.dropout(self.mamba_linformer_reshape(self.mamba_block(x)))
        hidden = self.dropout(self.linformer_block(hidden))
        return self.linformer_mamba_reshape(hidden)
        
        

        

In [187]:
# The attention path runs at 256 dims and the Mamba path at 512.
# LimbaBlock projects between the two widths.
linformer_config = LinformerConfig(
    dim=256,
    heads=8,
    depth=1,
    max_seq_len=256,
    n_local_attn_heads=4,
)
mamba_config = MambaConfig(hidden_size=512, num_heads=16)


In [282]:

class Limba(nn.Module):
    """Token-level language model: embedding -> N `LimbaBlock`s -> LayerNorm -> vocab logits.

    Args:
        linformer_config: config for the attention half of each block.
        mapping_config name kept as ``mamba_config``: config for the Mamba half;
            its ``dim`` is the model width.
        vocab_size: number of token ids (embedding rows and output logits).
        num_layers: number of stacked `LimbaBlock`s.
        dropout: dropout probability used inside every block and on the embeddings.
    """

    def __init__(self, linformer_config, mamba_config, vocab_size, num_layers=6, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=mamba_config.dim)
        # Pass layer_idx / vocab_size / dropout by keyword. Passed positionally,
        # vocab_size would land in LimbaBlock's `layer_idx` slot, and every layer
        # would get the same index and the default vocab_size.
        self.layers = nn.ModuleList(
            LimbaBlock(
                linformer_config,
                mamba_config,
                layer_idx=layer_idx,
                vocab_size=vocab_size,
                dropout=dropout,
            )
            for layer_idx in range(num_layers)
        )

        self.layer_norm = nn.LayerNorm(mamba_config.dim)
        self.dropout = nn.Dropout(dropout)

        self.output_layer = nn.Linear(mamba_config.dim, vocab_size)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        hidden = self.dropout(self.embedding(x))

        for layer in self.layers:
            hidden = layer(hidden)

        hidden = self.layer_norm(hidden)
        return self.output_layer(hidden)

        

        

        

In [284]:
model = Limba(
    linformer_config=linformer_config,
    mamba_config=mamba_config,
    vocab_size=silabi_tokenizer.vocab_size,
)

In [286]:
# Forward the single example through the full model and print the logits
# (Limba.output_layer maps to vocab_size, so the shape is (batch, seq_len, vocab_size)).
print(model(input_tensor))

tensor([[[ 0.8918,  0.1215, -1.0221,  ..., -0.1375,  0.6453,  0.2408],
         [ 0.6800,  0.0238,  0.5443,  ...,  0.3798,  0.3497,  0.3220],
         [-0.3338,  0.0462,  0.9040,  ...,  0.1092,  0.4037,  0.3869],
         ...,
         [-0.8182,  0.2585,  0.8561,  ...,  1.0838,  0.3018, -0.9353],
         [-0.3339, -0.7124,  0.0153,  ..., -0.9043,  1.2537, -0.0372],
         [-0.2685, -1.0179, -0.1490,  ...,  0.3481,  0.0739, -0.3264]]],
       grad_fn=<ViewBackward0>)


In [200]:
from torchinfo import summary
# `limba` is not defined anywhere in this notebook; the model built above is `model`.
summary(model, input_size=(1, 256), dtypes=[torch.long], device="cpu")

Layer (type:depth-idx)                                                      Output Shape              Param #
Limba                                                                       [1, 256, 512]             --
├─Embedding: 1-1                                                            [1, 256, 512]             338,944
├─ModuleList: 1-2                                                           --                        --
│    └─LimbaBlock: 2-1                                                      [1, 256, 512]             --
│    │    └─MambaBlock: 3-1                                                 [1, 256, 512]             2,647,088
│    │    └─Linear: 3-2                                                     [1, 256, 256]             131,328
│    │    └─Dropout: 3-3                                                    [1, 256, 256]             --
│    │    └─LinformerBlock: 3-4                                             [1, 256, 256]             789,248
│    │    └─Dropout: 3-5    

In [202]:
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

class TokenizedDataset(Dataset):
    """Torch `Dataset` yielding each row's ``input_ids`` as a LongTensor.

    Args:
        tokenized_dataset: indexable collection whose items have an
            ``"input_ids"`` list (e.g. a tokenized ``datasets.Dataset``).
        pad_token_id: pad id stored on the instance. When omitted, it falls
            back to the notebook-global ``silabi_tokenizer.pad_token_id``.
    """

    def __init__(self, tokenized_dataset, pad_token_id=None):
        self.tokenized_dataset = tokenized_dataset
        if pad_token_id is None:
            # Backward-compatible fallback to the global tokenizer.
            pad_token_id = silabi_tokenizer.pad_token_id
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        return torch.tensor(self.tokenized_dataset[idx]['input_ids'], dtype=torch.long)


In [204]:
# Only the test split is wrapped for now; the train and val lines stay commented out.
#train_tokenized = TokenizedDataset(train_tokenized_dataset)
test_tokenized = TokenizedDataset(tokenized_dataset=test_tokenized_dataset)
#val_tokenized = TokenizedDataset(val_tokenized_dataset)

In [288]:
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm  # Optional, for progress bar

def train_model(model, train_loader, num_epochs=10, learning_rate=1e-4, device='cuda', pad_token_id=0):
    """Train `model` as a next-token language model with Adam.

    Args:
        model: module mapping token ids ``(batch, seq_len)`` to logits
            ``(batch, seq_len, vocab_size)``.
        train_loader: iterable of LongTensor id batches. It must support
            ``len()``, which is used to average the loss.
        num_epochs: number of passes over `train_loader`.
        learning_rate: Adam learning rate.
        device: device the model and batches are moved to.
        pad_token_id: target id excluded from the loss. It defaults to 0 for
            backward compatibility; pass the tokenizer's actual pad id.

    Prints the mean batch loss after every epoch. Returns None.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss(ignore_index=pad_token_id)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for input_ids in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            input_ids = input_ids.to(device)
            optimizer.zero_grad()

            logits = model(input_ids)
            # Position t predicts token t+1. Drop the last prediction, which has
            # no target, instead of inventing a literal-0 target for it (that
            # target is only ignored when the pad id happens to be 0).
            logits = logits[:, :-1, :]
            targets = input_ids[:, 1:]

            # reshape (not view): the slices above are not contiguous.
            loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")



In [292]:
# Handle for PyTorch's "mps" (Metal Performance Shaders) device. Note that the
# training call later in the notebook passes device="cpu" rather than this handle.
mps_device = torch.device("mps")


In [294]:
# Display the device handle.
mps_device

device(type='mps')

In [236]:
# Batches of 2 examples; `shuffle` is not set here, so the DataLoader default applies.
test_loader = DataLoader(dataset=test_tokenized, batch_size=2)

In [300]:
# NOTE(review): this trains on `test_loader`, i.e. the held-out test split (the
# train split is never tokenized in the cells above). Confirm this is only a
# quick smoke run before reporting any test metrics.
train_model(model=model, train_loader=test_loader, num_epochs=3, device="cpu")

Epoch 1/3:   0%|                         | 2/288338 [00:21<871:54:06, 10.89s/it]


KeyboardInterrupt: 