In [1]:
import math
import os
import urllib.request
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.optim as optim
from einops import rearrange
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import get_cosine_schedule_with_warmup

# Autograd anomaly detection is a debugging aid that slows down every forward/backward pass;
# flip this to True only while tracking down NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY);

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f9ed9f03070>

In [2]:
# Configuration flags (1 = on, 0 = off) and the compute device.
USE_MAMBA = 1
# Read by train(); presumably selects a stateful hidden-state update path (off by default).
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# User-defined hyperparameters
d_model = 8            # model / embedding width
state_size = 128       # SSM state size N per channel (second dim of S6.A)
seq_len = 100          # tokens per sample (tokenizer max_length)
batch_size = 256       # DataLoader batch size
last_batch_size = 81   # only for the very last batch of the dataset

# Globals for the optional recurrent-state path; not read in the default configuration.
current_batch_size = batch_size
different_batch_size = False
h_new, temp_buffer = None, None

In [4]:
class S6(nn.Module):
    """Selective state-space (S6) layer with input-dependent delta, B and C.

    Maps an input of shape (batch, length, d_model) to an output of the same shape.

    Args:
        seq_len: sequence length the layer was configured with. Stored as an attribute only;
            the forward pass sizes its state from the input, so other lengths also work.
        d_model: channel width of the input and output.
        state_size: size N of the per-channel state.
        device: device on which the parameters are created.
    """

    def __init__(self, seq_len, d_model, state_size, device='cuda'):
        super(S6, self).__init__()

        self.fc_delta = nn.Linear(d_model, d_model, device=device)
        self.fc_B = nn.Linear(d_model, state_size, device=device)
        self.fc_C = nn.Linear(d_model, state_size, device=device)

        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size

        # NOTE(review): A is xavier-initialized (entries of both signs), so exp(delta * A) can
        # exceed 1; the exponent is clamped at 10 in `discretization` to bound it.
        self.A = nn.Parameter(torch.empty(d_model, state_size, device=device))
        nn.init.xavier_uniform_(self.A)

    def discretization(self, delta, B):
        """Return (dA, dB), both shaped (batch, length, d_model, state_size).

        dB = delta (x) B (outer product per position); dA = exp(min(delta (x) A, 10)).
        """
        dB = torch.einsum("bld,bln->bldn", delta, B)
        dA = torch.exp(torch.clamp(torch.einsum("bld,dn->bldn", delta, self.A), max=10.0))
        return dA, dB

    def forward(self, x):
        """Apply the layer to ``x`` of shape (batch, length, d_model)."""
        B = self.fc_B(x)
        C = self.fc_C(x)
        delta = F.softplus(self.fc_delta(x))

        dA, dB = self.discretization(delta, B)

        # Size the state from the actual input length rather than the configured seq_len so
        # inputs of any length broadcast correctly.
        # NOTE(review): `h` starts from zeros and is updated once, so `dA * h` is identically
        # zero: there is no recurrence across time steps (each position's state is x_t * dB_t)
        # and `self.A` receives zero gradient. A true selective scan would iterate
        # h_t = dA_t * h_{t-1} + dB_t * x_t over the length axis.
        h = torch.zeros(x.size(0), x.size(1), self.d_model, self.state_size, device=x.device)
        h = torch.einsum('bldn,bldn->bldn', dA, h) + x.unsqueeze(-1) * dB

        y = torch.einsum('bln,bldn->bld', C, h)

        return y


In [5]:
class MambaBlock(nn.Module):
    """One Mamba-style block mapping (batch, seq, d_model) to (batch, seq, d_model).

    Data path: RMSNorm -> Linear(d_model -> 2*d_model) -> Conv1d(kernel 3) -> SiLU -> S6 -> SiLU,
    multiplied elementwise by SiLU(D(input)) where D sees the block's un-normalized input,
    then projected back to d_model by `out_proj`.
    """

    def __init__(self, seq_len, d_model, state_size, device='cuda'):
        super().__init__()
        inner = 2 * d_model  # width of the expanded inner representation

        # Submodules are created in a fixed order: it determines RNG consumption during init
        # and the order of parameters seen by the optimizer.
        self.inp_proj = nn.Linear(d_model, inner, device=device)
        self.out_proj = nn.Linear(inner, d_model, device=device)
        self.D = nn.Linear(d_model, inner, device=device)
        nn.init.constant_(self.out_proj.bias, 1.0)

        self.S6 = S6(seq_len, inner, state_size, device)

        self.conv = nn.Conv1d(inner, inner, kernel_size=3, padding=1, device=device)
        nn.init.xavier_uniform_(self.conv.weight)
        if self.conv.bias is not None:
            nn.init.constant_(self.conv.bias, 0.0)
        self.norm = RMSNorm(d_model, device=device)

    def forward(self, x):
        # NOTE(review): the raw input only feeds the D gate below; it is not added back to
        # the block output.
        residual = x
        normed = self.norm(x)

        # Conv1d wants (batch, channels, seq): swap the last two axes around the convolution.
        projected = self.inp_proj(normed).transpose(1, 2)
        conv_out = F.silu(self.conv(projected)).transpose(1, 2)

        ssm_branch = F.silu(self.S6(conv_out))
        gate = F.silu(self.D(residual))
        return self.out_proj(ssm_branch * gate)


In [6]:
class Mamba(nn.Module):
    """A stack of ``num_layers`` MambaBlocks applied one after another.

    Every block preserves the (batch, seq, d_model) shape, so the whole stack does too.
    """

    def __init__(self, seq_len, d_model, state_size, num_layers=3, device='cuda'):
        super().__init__()
        blocks = [MambaBlock(seq_len, d_model, state_size, device) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden


In [7]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned per-feature scale.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight`` and returns a tensor with the
    same shape as ``x``. Any input whose last dimension equals ``d_model`` is accepted.

    Args:
        d_model: size of the normalized (last) dimension.
        eps: constant added inside the square root for numerical stability.
        device: device on which ``weight`` is created.
    """

    def __init__(self, d_model, eps=1e-5, device='cuda'):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model, device=device))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        # Broadcast `weight` against the last axis directly; reshaping it to (1, 1, -1) would
        # force a rank-3 result and silently turn 1-D/2-D inputs into 3-D outputs.
        return x * norm * self.weight

In [8]:
# Smoke test: push random inputs through a freshly built Mamba and check the output shape.
x = torch.rand(batch_size, seq_len, d_model, device=device)
# Create the Mamba model
num_layers = 3
mamba = Mamba(seq_len, d_model, state_size, num_layers, device)

# rmsnorm — pass `device` explicitly: RMSNorm's default is 'cuda', which fails on CPU-only
# machines and would not match `x` whenever `device` is not CUDA.
norm = RMSNorm(d_model, device=device)
x = norm(x)

# Forward pass
test_output = mamba(x)
assert test_output.shape == (batch_size, seq_len, d_model)
print(f"test_output.shape = {test_output.shape}")  # Should be [batch_size, seq_len, d_model]

test_output.shape = torch.Size([256, 100, 8])


In [9]:
class Enwiki8Dataset(Dataset):
    """Map-style dataset over a dict of index-aligned tensors.

    ``len()`` is the length of ``data['input_ids']``. Item ``idx`` is a new dict with the
    same keys as ``data``, each holding a detached copy of that entry's row ``idx``.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        ids = self.data['input_ids']
        return len(ids)

    def __getitem__(self, idx):
        sample = {}
        for name, column in self.data.items():
            sample[name] = column[idx].clone().detach()
        return sample

In [10]:
# Define a function for padding
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
    """Right-pad a 3-D tensor (batch, seq_len, features) along its sequence axis.

    Args:
        sequences: tensor of shape (batch, seq_len, features).
        max_len: length of the result's sequence axis; defaults to ``seq_len + 1``
            (one extra padded step).
        pad_value: value written into every padded position.

    Returns:
        A new tensor (batch, max_len, features) with the dtype and device of ``sequences``:
        the first ``seq_len`` steps hold the input and the remaining steps hold ``pad_value``.
    """
    n_batch, n_steps, n_feat = sequences.shape
    target_len = n_steps + 1 if max_len is None else max_len

    out = sequences.new_full((n_batch, target_len, n_feat), pad_value)
    out[:, :n_steps, :] = sequences
    return out

In [11]:
def train(model, tokenizer, data_loader, optimizer, scheduler, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
    """Train ``model`` for one epoch and return the mean per-batch training loss.

    Each batch's ``batch['input_ids']`` is moved to ``device`` and split into input = steps
    ``[:-1]`` and target = steps ``[1:]``; both are padded along the sequence axis with
    ``tokenizer.pad_token_id``. Mixed precision (autocast + GradScaler) is active only when
    ``device`` is CUDA. Gradients are unscaled and clipped to ``max_grad_norm`` (parameters whose
    name contains ``out_proj.bias`` are excluded from clipping) before the optimizer step, and
    ``scheduler`` is stepped once per batch.

    Args:
        model: module called as ``model(input_data)``.
        tokenizer: only ``tokenizer.pad_token_id`` is used (as the padding value).
        data_loader: iterable of dicts with an ``'input_ids'`` tensor; must support ``len()``.
        optimizer, scheduler: stepped once per batch.
        criterion: loss called as ``criterion(output, target)``.
        device: a ``torch.device`` or device string.
        max_grad_norm: gradient-norm clipping threshold.
        DEBUGGING_IS_ON: if True, print every parameter whose ``.grad`` is None after each step.

    Returns:
        Sum of per-batch ``loss.item()`` divided by ``len(data_loader)``.

    Raises:
        NotImplementedError: if ``USE_MAMBA`` and ``DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM``
            are both enabled; that state-copy path has nothing to copy into for this model.
        ValueError: if ``data_loader`` is empty.
    """
    model.train()

    if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
        # A per-step copy into `model.S6.h` cannot work here: `Mamba` exposes its blocks via
        # `layers` (it has no `S6` attribute) and `S6` keeps no `h` state. Fail up front rather
        # than with an AttributeError after the first optimizer step.
        raise NotImplementedError(
            "DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM is not supported by this model"
        )

    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError("train() received an empty data_loader")

    # AMP (and therefore loss scaling) only applies on CUDA; disable both elsewhere so CPU runs
    # behave the same without AMP warnings.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    total_loss = 0.0
    pbar = tqdm(data_loader, leave=False)

    for step, batch in enumerate(pbar):
        optimizer.zero_grad()

        input_data = batch['input_ids'].clone().to(device)

        # Next-step prediction: the target at position t is the input at position t + 1.
        target = input_data[:, 1:]
        input_data = input_data[:, :-1]

        # pad_sequences_3d appends one padded step by default; pad the target to match.
        input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
        target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

        with torch.amp.autocast('cuda', enabled=use_amp):
            output = model(input_data)
            loss = criterion(output, target)

        scaler.scale(loss).backward()

        # Unscale first so the clipping threshold applies to the true (unscaled) gradients.
        scaler.unscale_(optimizer)
        parameters_to_clip = [
            param for name, param in model.named_parameters()
            if param.grad is not None and 'out_proj.bias' not in name
        ]
        torch.nn.utils.clip_grad_norm_(parameters_to_clip, max_norm=max_grad_norm)

        scaler.step(optimizer)
        scaler.update()

        scheduler.step()

        if DEBUGGING_IS_ON:
            for name, parameter in model.named_parameters():
                if parameter.grad is None:
                    print(f"{name} has no gradient")

        loss_value = loss.item()
        total_loss += loss_value
        current_lr = scheduler.get_last_lr()[0]  # current learning rate
        pbar.set_postfix(loss=f'{loss_value:.4f}', step=step, lr=f'{current_lr:.8f}')

    return total_loss / num_batches

In [12]:
def evaluate(model, data_loader, criterion, device):
    """Return the mean per-batch loss of ``model`` over ``data_loader`` without gradients.

    Batches are prepared the same way as in ``train``: ``batch['input_ids']`` is split into
    input = steps ``[:-1]`` and next-step target = steps ``[1:]``, and both are padded along the
    sequence axis. The padding value is ``pad_token_id`` of the notebook-global ``tokenizer``.

    Raises:
        ValueError: if ``data_loader`` is empty.
    """
    model.eval()
    num_batches = len(data_loader)
    if num_batches == 0:
        raise ValueError("evaluate() received an empty data_loader")

    total_loss = 0.0
    with torch.no_grad():
        for batch in data_loader:
            input_data = batch['input_ids'].clone().detach().to(device)

            # Language-model objective: the target for each position is the next element,
            # so shift the input by one step to build the target.
            target = input_data[:, 1:]
            input_data = input_data[:, :-1]

            # Pad all the sequences in the batch:
            input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
            target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

            # Always compute the loss: guarding this behind a flag left `loss` unbound (or stale).
            output = model(input_data)
            loss = criterion(output, target)
            total_loss += loss.item()
    return total_loss / num_batches

In [13]:
def calculate_perplexity(loss):
    """Return perplexity ``exp(loss)`` for a mean cross-entropy ``loss``.

    Assumes ``loss`` is in nats (natural log), as produced by ``nn.CrossEntropyLoss``.
    Returns ``float('inf')`` instead of raising ``OverflowError`` when ``loss`` is too large for
    ``math.exp`` (for example, after a diverged run).
    """
    try:
        return math.exp(loss)
    except OverflowError:
        return float('inf')

In [14]:
def load_enwiki8_dataset(url="http://mattmahoney.net/dc/enwik8.zip", zip_path="enwik8.zip"):
    """Return the enwik8 corpus as a single UTF-8-decoded string.

    The archive is downloaded from ``url`` to ``zip_path`` only when that file does not already
    exist. The download goes to a temporary ``.part`` file that is renamed on completion, so an
    interrupted download never leaves a truncated archive at ``zip_path``. The ``enwik8`` member
    of the archive is then read and decoded.

    Args:
        url: source URL of the zip archive.
        zip_path: local path at which the archive is cached.
    """
    print("Download and extract enwiki8 data")
    if not os.path.exists(zip_path):
        tmp_path = zip_path + ".part"
        urllib.request.urlretrieve(url, tmp_path)
        os.replace(tmp_path, zip_path)

    with ZipFile(zip_path) as f:
        data = f.read("enwik8").decode("utf-8")

    return data

In [15]:
# Tokenize and encode the dataset
def encode_dataset(tokenizer, text_data):
    """Tokenize ``text_data`` and map the token ids through a new random embedding table.

    Args:
        tokenizer: Hugging Face tokenizer; called with ``truncation=True``,
            ``padding='max_length'``, ``max_length=seq_len``, and its ``vocab`` and
            ``pad_token_id`` are read.
        text_data: the raw corpus, sliced in chunks of 1000 elements. For a ``str`` (as returned
            by ``load_enwiki8_dataset``) each chunk is 1000 characters, which the tokenizer
            presumably treats as one sequence. NOTE(review): truncation to ``seq_len`` tokens
            then discards most of each chunk — confirm this is intended.

    Returns:
        ``(encoded_inputs, attention_mask)``:
        - ``encoded_inputs`` is a float tensor of shape ``(num_sequences, seq_len, d_model)``.
        - ``attention_mask`` has shape ``(num_sequences, seq_len)`` and is 1 wherever the token id
          differs from ``tokenizer.pad_token_id``.

    Raises:
        ValueError: if ``text_data`` is empty.

    Side effect: sets the module-level ``vocab_size``.
    """
    if len(text_data) == 0:
        raise ValueError("encode_dataset() received empty text_data")

    def batch_encode(tokenizer, text_data, batch_size=1000):
        # Tokenize `text_data` in slices of `batch_size` elements and stack the id tensors.
        batched_input_ids = []
        for i in range(0, len(text_data), batch_size):
            batch = text_data[i:i+batch_size]
            inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                               padding='max_length', max_length=seq_len,
                               return_tensors='pt')
            batched_input_ids.append(inputs['input_ids'])
        return torch.cat(batched_input_ids)

    # Encode the argument this function was given, not a notebook-global corpus.
    input_ids = batch_encode(tokenizer, text_data)

    # vocab_size is the number of entries in the tokenizer's vocabulary
    global vocab_size
    vocab_size = len(tokenizer.vocab)
    print(f"vocab_size = {vocab_size}")

    # Embedding table of width d_model (the Mamba model's D).
    # NOTE(review): this table is randomly initialized, never trained and not part of the model,
    # so each token id maps to a fixed random vector — confirm that is intended.
    embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):
        """Embed ``input_ids`` (shape [B, L]) in row chunks of ``batch_size``; returns [B, L, D]."""
        # Accept lists as well as tensors
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)

        num_batches = math.ceil(input_ids.size(0) / batch_size)

        output_embeddings = []
        for i in range(num_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            input_id_batch = input_ids[start_idx:end_idx]

            # No gradients are needed for this fixed lookup
            with torch.no_grad():
                batch_embeddings = embedding_layer(input_id_batch)

            output_embeddings.append(batch_embeddings)

        # Concatenate the per-chunk embeddings into a single tensor
        return torch.cat(output_embeddings, dim=0)

    # Chunking only bounds peak memory (the lookup result is identical for any chunk size);
    # 256 rows of seq_len x d_model floats is small while avoiding one call per row.
    encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=256).float()

    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)

    return encoded_inputs, attention_mask

In [16]:
# Load the pretrained `bert-base-uncased` tokenizer (downloaded, or read from the local HF cache).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')



In [17]:
# Load cached embeddings if present; otherwise download, tokenize, embed and cache them.
# Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]
encoded_inputs_file = 'encoded_inputs_mamba.pt'

if os.path.exists(encoded_inputs_file):
    print("Loading pre-tokenized data...")
    # The cache stores only tensors, so the restricted tensor-only unpickler is sufficient and
    # a tampered cache file cannot execute code when loaded.
    saved_data = torch.load(encoded_inputs_file, weights_only=True)
    encoded_inputs = saved_data['input_ids']
    attention_mask = saved_data['attention_mask']
else:
    print("Tokenizing raw data...")
    enwiki8_data = load_enwiki8_dataset()
    encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
    torch.save({'input_ids': encoded_inputs, 'attention_mask': attention_mask}, encoded_inputs_file)
    print("Finished tokenizing data")

# Combine into a single dictionary (the 'input_ids' key holds embeddings, not token ids)
data = {
    'input_ids': encoded_inputs,
    'attention_mask': attention_mask
}

# Contiguous (unshuffled) 80/20 split into train and validation sets
total_size = len(data['input_ids'])
train_size = int(total_size * 0.8)

train_data = {key: val[:train_size] for key, val in data.items()}
val_data = {key: val[train_size:] for key, val in data.items()}

train_dataset = Enwiki8Dataset(train_data)
val_dataset = Enwiki8Dataset(val_data)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model
num_layers = 3
model = Mamba(seq_len, d_model, state_size, num_layers, device).to(device)

# Define the loss function and optimizer
learning_rate = 3e-6
# NOTE(review): nn.CrossEntropyLoss treats dim 1 as the class dimension. Here `output` and the
# padded `target` are (batch, seq_len, d_model) embedding tensors, so the softmax runs over
# sequence positions with embedding values as soft targets — confirm this objective is intended.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
num_epochs = 10

# Cosine LR schedule with warmup over the first 10% of steps; train() steps it once per batch.
total_steps = len(train_loader) * num_epochs
warmup_steps = int(0.1 * total_steps)

scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

for epoch in range(num_epochs):
    train_loss = train(model, tokenizer, train_loader, optimizer, scheduler, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=True)
    val_loss = evaluate(model, val_loader, criterion, device)
    val_perplexity = calculate_perplexity(val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

  saved_data = torch.load(encoded_inputs_file)


Loading pre-tokenized data...
[2025-03-25 09:33:01,535] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
2025-03-25 09:33:02.961088: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
                                                                                       

Epoch: 1, Training Loss: 3.8879, Validation Loss: 3.3162, Validation Perplexity: 27.5562


                                                                                        

Epoch: 2, Training Loss: 3.8804, Validation Loss: 3.3121, Validation Perplexity: 27.4414


                                                                                        

Epoch: 3, Training Loss: 3.8749, Validation Loss: 3.2806, Validation Perplexity: 26.5923


                                                                                        

Epoch: 4, Training Loss: 3.8032, Validation Loss: 3.1869, Validation Perplexity: 24.2129


                                                                                        

Epoch: 5, Training Loss: 3.6967, Validation Loss: 3.0699, Validation Perplexity: 21.5404


                                                                                        

Epoch: 6, Training Loss: 3.5888, Validation Loss: 2.9510, Validation Perplexity: 19.1253


                                                                                        

Epoch: 7, Training Loss: 3.4634, Validation Loss: 2.8547, Validation Perplexity: 17.3696


                                                                                        

Epoch: 8, Training Loss: 3.3960, Validation Loss: 2.7959, Validation Perplexity: 16.3778


                                                                                        

Epoch: 9, Training Loss: 3.3531, Validation Loss: 2.7728, Validation Perplexity: 16.0033


                                                                                        

Epoch: 10, Training Loss: 3.3457, Validation Loss: 2.7693, Validation Perplexity: 15.9479
