In [1]:
## Importing Libraries

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
import matplotlib.pyplot as plt
import pandas as pd

## Hyperparameters and Configuration

In [None]:
# Hyperparameters and run configuration. Later cells read these constants.
SEQUENCE_LENGTH = 64             # characters per training window

EMBEDDING_DIM = 128
HIDDEN_DIM = EMBEDDING_DIM * 2   # model width
NUM_LAYERS = 4

BATCH_SIZE = 2048
EPOCHS = 1
LEARNING_RATE = 1e-4
VALIDATION_SPLIT = 0.1           # fraction of the text held out for validation

# Seed the global RNGs so shuffling, weight initialisation and sampling are
# repeatable across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Data Preparation

We use the TinyShakespeare dataset, a small character-level text corpus drawn from a subset of Shakespeare's plays. It is often used to test sequence models because its varied vocabulary and style make next-character prediction a challenging task.

In [3]:
## Utility Functions

def load_data(filename):
    """Read the UTF-8 text file at ``filename`` and return its full contents as one string."""
    with open(filename, mode='r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents

def create_char_mappings(text):
    """Build the character vocabulary of ``text``.

    Returns:
        A tuple ``(chars, char_to_idx, idx_to_char)`` where ``chars`` is the
        sorted list of distinct characters, ``char_to_idx`` maps each character
        to its position in that list, and ``idx_to_char`` is the inverse map.
    """
    vocabulary = sorted(set(text))
    idx_to_char = dict(enumerate(vocabulary))
    char_to_idx = {symbol: position for position, symbol in idx_to_char.items()}
    return vocabulary, char_to_idx, idx_to_char

## Dataset

In [4]:
class CharDataset(Dataset):
    """Sliding-window character dataset for next-character prediction.

    Item ``i`` is the pair ``(x, y)`` where ``x`` holds the indices of
    ``text[i : i + seq_length]`` and ``y`` the indices of the same window
    shifted one character to the right.

    Args:
        text: Source string.
        seq_length: Number of characters per window.
        char_to_idx: Mapping from character to integer index. Every character
            of ``text`` must be a key (a ``KeyError`` is raised at construction
            otherwise).
    """

    def __init__(self, text, seq_length, char_to_idx):
        self.text = text
        self.seq_length = seq_length
        self.char_to_idx = char_to_idx
        # Encode the whole corpus once; __getitem__ then only slices this
        # tensor instead of re-running a Python lookup loop per sample.
        self.encoded = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)

    def __len__(self):
        # Clamp at zero: a negative length makes len() raise ValueError when
        # the text is shorter than one window.
        return max(0, len(self.text) - self.seq_length)

    def __getitem__(self, idx):
        size = len(self)
        if idx < 0:
            idx += size
        # Raising IndexError lets plain iteration (`for x, y in dataset`)
        # terminate and rejects windows that would run past the text.
        if not 0 <= idx < size:
            raise IndexError(f"index {idx} out of range for dataset of size {size}")
        stop = idx + self.seq_length
        # Slices are views on the shared encoded tensor; the DataLoader's
        # default collate copies them when stacking a batch.
        x = self.encoded[idx:stop]
        y = self.encoded[idx + 1:stop + 1]
        return x, y

In [5]:
def prepare_data(text, seq_length, batch_size, val_split, num_workers=12):
    """Build train/validation DataLoaders over a character corpus.

    The last ``int(len(text) * val_split)`` characters form the validation
    text; everything before them is the training text.

    Args:
        text: Full corpus string.
        seq_length: Window length passed to ``CharDataset``.
        batch_size: Batch size for both loaders.
        val_split: Fraction of ``text`` reserved for validation.
        num_workers: Worker processes per DataLoader (defaults to the
            previously hard-coded value of 12).

    Returns:
        ``(train_loader, val_loader, chars, char_to_idx, idx_to_char)``.
    """
    chars, char_to_idx, idx_to_char = create_char_mappings(text)

    # Split by an explicit index. The negative-slice form `text[:-val_size]`
    # returns an EMPTY training text when val_size is 0.
    val_size = int(len(text) * val_split)
    split_at = len(text) - val_size
    train_text, val_text = text[:split_at], text[split_at:]

    train_dataset = CharDataset(train_text, seq_length, char_to_idx)
    val_dataset = CharDataset(val_text, seq_length, char_to_idx)

    # Pinned host memory only helps host-to-GPU copies; requesting it without
    # CUDA just triggers a warning.
    pin_memory = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin_memory)

    return train_loader, val_loader, chars, char_to_idx, idx_to_char

In [6]:
# !wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=19zosLuU0z4MxIMKbGVYEGlg52QyfbTIy' -O input.txt

In [7]:
# Read the corpus and build the loaders plus the character vocabulary.
text = load_data('./input.txt')
(train_loader, val_loader, chars, char_to_idx, idx_to_char) = prepare_data(
    text,
    seq_length=SEQUENCE_LENGTH,
    batch_size=BATCH_SIZE,
    val_split=VALIDATION_SPLIT,
)
vocab_size = len(chars)

# Quick sanity summary of corpus and split sizes.
print(f"Total characters: {len(text)}")
print(f"Vocabulary size: {vocab_size}")
print(f"Train dataset size: {len(train_loader.dataset)}")
print(f"Validation dataset size: {len(val_loader.dataset)}")

Total characters: 1115394
Vocabulary size: 65
Train dataset size: 1003791
Validation dataset size: 111475


## Data Visualization

In [8]:
# Function to convert index sequence to character sequence
def indices_to_text(indices, idx_to_char):
    """Decode a sequence of index tensors into a string.

    Each element of ``indices`` must expose ``.item()``; the resulting integer
    is looked up in ``idx_to_char``.
    """
    decoded = []
    for idx in indices:
        decoded.append(idx_to_char[idx.item()])
    return ''.join(decoded)

# Pull one batch from the training loader to inspect it.
dataiter = iter(train_loader)
batch_x, batch_y = next(dataiter)

print(f"Input shape: {batch_x.shape}")
print(f"Target shape: {batch_y.shape}")

# Show the first few input/target pairs; newlines are stripped so each
# sequence prints on a single line.
num_samples = 3
for i in range(num_samples):
    input_line = indices_to_text(batch_x[i], idx_to_char).replace('\n','')
    target_line = indices_to_text(batch_y[i], idx_to_char).replace('\n','')
    print(f"Sample {i+1}: ------------------------------" )
    print("Input sequence :", input_line)
    print("Target sequence:", target_line)
    print()


Input shape: torch.Size([2048, 64])
Target shape: torch.Size([2048, 64])
Sample 1: ------------------------------
Input sequence : st out, like to itself,No father owning it,--which is, indeed,
Target sequence: t out, like to itself,No father owning it,--which is, indeed,M

Sample 2: ------------------------------
Input sequence : ill should live! 'True, noble prince!'Cousin, thou wert not won
Target sequence: ll should live! 'True, noble prince!'Cousin, thou wert not wont

Sample 3: ------------------------------
Input sequence : ep you!Both Tribunes:Farewell, farewell.SICINIUS:This is a
Target sequence: p you!Both Tribunes:Farewell, farewell.SICINIUS:This is a 



## Training Function

In [9]:
def print_vram_usage(device="cuda"):
    """Print current, reserved and peak allocated CUDA memory for ``device`` in MB."""
    bytes_per_mb = 1024**2
    allocated = torch.cuda.memory_allocated(device) / bytes_per_mb
    reserved = torch.cuda.memory_reserved(device) / bytes_per_mb
    max_allocated = torch.cuda.max_memory_allocated(device) / bytes_per_mb
    print(f"Allocated: {allocated:.2f} MB, Reserved: {reserved:.2f} MB, Max Allocated: {max_allocated:.2f} MB")

In [10]:
def validate(model, dataloader, criterion, device, epoch, step):
    """Evaluate ``model`` on every batch of ``dataloader`` without gradients.

    The model is switched to eval mode and is expected to return a 2-tuple
    whose first element is the logits tensor.

    Returns:
        A list with one ``(step, epoch, loss)`` tuple per batch, where ``step``
        and ``epoch`` are the values passed in and ``loss`` is a Python float.
    """
    model.eval()
    records = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, _ = model(inputs)
            # Flatten batch and time so the criterion sees (N, vocab) vs (N,).
            flat_logits = logits.view(-1, logits.size(-1))
            batch_loss = criterion(flat_logits, targets.view(-1))
            records.append((step, epoch, batch_loss.item()))
    return records

In [11]:
from tqdm import tqdm

def train(model, dataloader, criterion, optimizer, device, epoch, step):
    """Run one training epoch and return per-step loss and memory records.

    Mixed precision (autocast + gradient scaling) is enabled only when
    ``device`` is a CUDA device; on CPU the epoch runs in full precision so the
    notebook's CPU fallback works.

    Args:
        model: Module returning a 2-tuple whose first element is the logits.
        dataloader: Iterable of ``(x, y)`` index tensors.
        criterion: Loss taking ``(N, vocab)`` logits and ``(N,)`` targets.
        optimizer: Optimizer over ``model``'s parameters.
        device: Target device (``torch.device`` or device string).
        epoch: Epoch number, used for logging and in the returned records.
        step: Global step count before this epoch.

    Returns:
        ``(losses, step, vram_usage)`` where ``losses`` is a list of
        ``(step, epoch, loss)`` tuples, ``step`` is the updated global step, and
        ``vram_usage`` lists allocated CUDA memory in MB per step (0.0 on CPU).
    """
    model.train()
    losses = []
    vram_usage = []

    device_type = torch.device(device).type
    use_cuda = device_type == 'cuda'
    # A disabled GradScaler is a transparent pass-through, so the loop body is
    # identical on CPU and GPU.
    scaler = torch.amp.GradScaler('cuda', enabled=use_cuda)

    pbar = tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast(device_type, enabled=use_cuda):
            output, _ = model(x)
            loss = criterion(output.view(-1, output.size(-1)), y.view(-1))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        step += 1
        # Read the loss once: each .item() forces a device synchronisation.
        loss_value = loss.item()
        losses.append((step, epoch, loss_value))

        # Show the current VRAM usage in the progress-bar postfix.
        allocated = torch.cuda.memory_allocated(device) / (1024**2) if use_cuda else 0.0
        vram_usage.append(allocated)
        pbar.set_postfix(loss=f'{loss_value:.4f}', step=step, vram=f'{allocated:.2f} MB')
    return losses, step, vram_usage

## Training Loop and Text Generation

In [12]:
def train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs):
    """Train for ``epochs`` epochs, validating after each one.

    Per epoch this calls ``train`` then ``validate`` and prints a one-line
    summary: the last training-batch loss, the MEAN validation loss over all
    validation batches, the wall time, and the mean allocated VRAM.

    Returns:
        ``(model, train_losses_df, val_losses_df)``; both DataFrames have the
        columns ``step``, ``epoch`` and ``loss_value`` with one row per batch.
    """
    all_train_losses = []
    all_val_losses = []
    step = 0

    for epoch in range(1, epochs + 1):
        epoch_start_time = time.time()

        # Training phase (progress is shown by tqdm inside train()).
        epoch_train_losses, step, vram_usage = train(model, train_loader, criterion, optimizer, device, epoch, step)
        all_train_losses.extend(epoch_train_losses)

        # Validation phase.
        epoch_val_losses = validate(model, val_loader, criterion, device, epoch, step)
        all_val_losses.extend(epoch_val_losses)

        epoch_time = time.time() - epoch_start_time

        # Guard against empty loaders so the summary never raises IndexError.
        nan = float('nan')
        last_train_loss = epoch_train_losses[-1][2] if epoch_train_losses else nan
        # Average over all validation batches; a single (last) batch is a noisy
        # estimate and disagrees with the per-epoch means used by the plots.
        mean_val_loss = float(np.mean([rec[2] for rec in epoch_val_losses])) if epoch_val_losses else nan
        mean_vram = float(np.mean(vram_usage)) if vram_usage else nan

        print(f'Epoch {epoch}/{epochs}, Train Loss: {last_train_loss:.4f}, '
              f'Val Loss: {mean_val_loss:.4f}, Epoch Time: {epoch_time:.2f}s, '
              f'Average Vram Usage: {mean_vram:.2f}MB')

    train_losses_df = pd.DataFrame(all_train_losses, columns=['step', 'epoch', 'loss_value'])
    val_losses_df = pd.DataFrame(all_val_losses, columns=['step', 'epoch', 'loss_value'])
    return model, train_losses_df, val_losses_df


In [13]:
def generate_text(model, char_to_idx, idx_to_char, start_text, device, max_length=500):
    """Sample ``max_length`` characters from a model that returns ``(logits, state)``.

    At every step the last ``SEQUENCE_LENGTH`` characters of the running text
    are encoded and passed to the model together with the second value returned
    by the previous call (``None`` on the first step). The next character is
    drawn from the softmax of the logits at the final position.

    Returns:
        ``start_text`` followed by the sampled characters.
    """
    model.eval()
    generated = start_text
    state = None

    with torch.no_grad():
        for _ in range(max_length):
            window = generated[-SEQUENCE_LENGTH:]
            token_ids = [char_to_idx[ch] for ch in window]
            x = torch.tensor([token_ids]).to(device)
            logits, state = model(x, state)
            distribution = torch.softmax(logits[0, -1], dim=0)
            sampled_idx = torch.multinomial(distribution, 1).item()
            generated += idx_to_char[sampled_idx]

    return generated

In [14]:
# Registry of loss histories, keyed by model name.
loss_comparison_dict = {}

def add_loss_to_comparison(model_name, train_losses_df, val_losses_df):
    """
    Store a model's training and validation losses in ``loss_comparison_dict``.

    An existing entry under the same ``model_name`` is overwritten.
    """
    entry = dict(train=train_losses_df, val=val_losses_df)
    loss_comparison_dict[model_name] = entry

def print_final_losses(loss_dict):
    """Print the final train/validation loss of every model in ``loss_dict``.

    ``loss_dict`` maps a model name to ``{'train': df, 'val': df}``. For each
    frame the last ``loss_value`` row of the last epoch group is reported.

    NOTE: a later cell defines another function named ``print_final_losses``;
    once that cell has run, it replaces this definition in the notebook
    namespace.
    """
    for model_name in loss_dict:
        losses = loss_dict[model_name]
        train_last_per_epoch = losses['train'].groupby('epoch')['loss_value'].last()
        val_last_per_epoch = losses['val'].groupby('epoch')['loss_value'].last()
        final_train = train_last_per_epoch.iloc[-1]
        final_val = val_last_per_epoch.iloc[-1]
        print(f"{model_name}: Final Train Loss: {final_train:.4f}, Final Val Loss: {final_val:.4f}")

In [15]:
# Function to plot loss curves
def plot_loss(train_losses_df, val_losses_df):
    """Plot per-step training loss and end-of-epoch validation loss.

    Both frames need the columns ``step``, ``epoch`` and ``loss_value``.
    Training loss is drawn as one faint line per epoch with a marker on the
    last step of each epoch; validation uses the last row of each epoch.
    """
    plt.figure(figsize=(10, 5))

    # One training curve per epoch. Only the first carries a legend label so
    # "Training Loss" appears exactly once in the legend.
    for position, epoch in enumerate(train_losses_df['epoch'].unique()):
        epoch_train_losses = train_losses_df[train_losses_df['epoch'] == epoch]
        plt.plot(epoch_train_losses['step'], epoch_train_losses['loss_value'],
                 color='blue', alpha=0.3,
                 label='Training Loss' if position == 0 else '_nolegend_')

    # Mark the training loss at the end of each epoch.
    last_train_losses = train_losses_df.groupby('epoch').last().reset_index()
    plt.scatter(last_train_losses['step'], last_train_losses['loss_value'],
                color='blue')

    # Validation loss at the end of each epoch, as a line with markers.
    last_val_losses = val_losses_df.groupby('epoch').last().reset_index()
    plt.plot(last_val_losses['step'], last_val_losses['loss_value'],
             color='orange', label='Validation Loss')
    plt.scatter(last_val_losses['step'], last_val_losses['loss_value'],
                color='orange')

    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.show()

# Function to print final loss values
def print_final_losses(train_losses_df, val_losses_df=None):
    """Print final loss values for one model or for a dictionary of models.

    This name is defined twice in the notebook and the later definition wins,
    so this version accepts both call styles:

    * ``print_final_losses(train_df, val_df)`` prints the final training and
      validation loss of a single run.
    * ``print_final_losses(loss_dict)`` takes
      ``{model_name: {'train': df, 'val': df}}`` and prints one line per model.

    "Final" is the last ``loss_value`` row of the last epoch group.

    Raises:
        TypeError: if ``val_losses_df`` is omitted and the first argument is
            not a dictionary.
    """
    def _final(df):
        return df.groupby('epoch')['loss_value'].last().iloc[-1]

    if val_losses_df is None:
        if not isinstance(train_losses_df, dict):
            raise TypeError(
                "print_final_losses expects (train_df, val_df) or a dict of "
                "{'train': df, 'val': df} entries"
            )
        for model_name, losses in train_losses_df.items():
            final_train = _final(losses['train'])
            final_val = _final(losses['val'])
            print(f"{model_name}: Final Train Loss: {final_train:.4f}, Final Val Loss: {final_val:.4f}")
        return

    print("Final Training Loss:", _final(train_losses_df))
    print("Final Validation Loss:", _final(val_losses_df))

In [16]:
# Function to plot loss curves for multiple models stored in loss_comparison_dict
def plot_loss_comparisons():
    """
    Plot training curves and per-epoch mean validation loss for every model in
    ``loss_comparison_dict``.

    Each model gets one colour, shared by its training line, its end-of-epoch
    training markers (circles) and its validation markers (squares). The most
    recently added model also gets a red star on its final training loss.
    Prints a message and returns without plotting if no model is registered.
    """
    if not loss_comparison_dict:
        # Nothing registered: indexing the key list below would raise IndexError.
        print("No models in loss_comparison_dict; nothing to plot.")
        return

    plt.figure(figsize=(10, 5))

    # The last model added receives the final-point highlight.
    last_model_name = list(loss_comparison_dict.keys())[-1]

    for model_name, losses in loss_comparison_dict.items():
        train_losses_df = losses['train']
        val_losses_df = losses['val']

        # Training curve; keep its colour so this model's markers match it
        # (scatter otherwise draws from a separate colour cycle).
        (train_line,) = plt.plot(train_losses_df['step'], train_losses_df['loss_value'],
                                 label=f'{model_name} train', linestyle='-', alpha=0.7)
        model_color = train_line.get_color()

        # Training loss at the end of each epoch.
        last_train_losses = train_losses_df.groupby('epoch').last().reset_index()
        plt.scatter(last_train_losses['step'], last_train_losses['loss_value'],
                    marker='o', s=50, color=model_color)

        # Mean validation loss per epoch, placed at the epoch's last step.
        avg_val_losses = val_losses_df.groupby('epoch').agg({'loss_value': 'mean', 'step': 'last'}).reset_index()
        plt.scatter(avg_val_losses['step'], avg_val_losses['loss_value'], marker='s', s=50,
                    color=model_color, label=f'{model_name} val avg')

        # Highlight the final training loss of the most recent model.
        if model_name == last_model_name:
            final_step = train_losses_df['step'].iloc[-1]
            final_loss = train_losses_df['loss_value'].iloc[-1]
            plt.scatter(final_step, final_loss, marker='*', s=100, color='red', zorder=5)

    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.title('Training Loss Comparison')
    plt.legend()  # shows both the training and the validation-average labels
    plt.grid(True)
    plt.show()


In [17]:
def plot_separate_train_val(loss_dict):
    """
    Draw training loss and validation loss for each model in two side-by-side panels.

    The left panel shows the raw per-step training loss. The right panel shows
    the validation loss averaged per epoch, plotted at the last step recorded
    for that epoch.
    """
    fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(20, 5))

    # Left panel: training loss exactly as recorded.
    for model_name, losses in loss_dict.items():
        train_df = losses['train']
        ax_train.plot(train_df['step'].values, train_df['loss_value'].values,
                      label=f'{model_name} Train')
    ax_train.set_title('Training Loss Comparison')
    ax_train.set_xlabel('Steps')
    ax_train.set_ylabel('Loss')
    ax_train.legend()
    ax_train.grid(True)

    # Right panel: mean validation loss per epoch, with the epoch's last step as x.
    for model_name, losses in loss_dict.items():
        val_avg = (
            losses['val']
            .groupby('epoch')
            .agg({'loss_value': 'mean', 'step': 'last'})
            .reset_index()
        )
        ax_val.plot(val_avg['step'], val_avg['loss_value'], label=f'{model_name} Val')
    ax_val.set_title('Validation Loss (Epoch Avg) Comparison')
    ax_val.set_xlabel('Steps')
    ax_val.set_ylabel('Loss')
    ax_val.legend()
    ax_val.grid(True)

    fig.tight_layout()
    plt.show()


In [18]:
# Use the first validation window as the prompt for text generation.
val_sample, _ = next(iter(val_loader))
first_window = val_sample[0][:SEQUENCE_LENGTH]
start_text = ''.join(idx_to_char[token.item()] for token in first_window)

In [19]:
def generate_text_attention(model, char_to_idx, idx_to_char, start_text, device, max_length=500, context_length=None):
    """Sample text from a stateless model that re-reads its context every step.

    At each step the last ``context_length`` characters of the running text are
    encoded and fed to ``model``; element ``[0]`` of the model's return value is
    taken as logits of shape ``(1, T, vocab)``, and the next character is drawn
    from the softmax of the final position.

    Args:
        model: Callable module; ``model(x)[0]`` must be the logits.
        char_to_idx: Character-to-index map covering every prompt character.
        idx_to_char: Index-to-character map covering every model output index.
        start_text: Non-empty prompt.
        device: Device on which the input tensor is created.
        max_length: Number of characters to generate.
        context_length: Maximum number of trailing characters fed to the model.
            Defaults to the notebook-level ``SEQUENCE_LENGTH``.

    Returns:
        ``start_text`` followed by ``max_length`` sampled characters.

    Raises:
        ValueError: if ``start_text`` is empty (there is nothing to condition on).
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if context_length is None:
        context_length = SEQUENCE_LENGTH

    model.eval()
    current_text = start_text

    with torch.no_grad():
        for _ in range(max_length):
            # Encode the trailing context window as a (1, T) index tensor.
            window = current_text[-context_length:]
            x = torch.tensor([[char_to_idx[ch] for ch in window]], device=device)
            output = model(x)[0]  # no recurrent state is carried between steps
            probs = torch.softmax(output[0, -1], dim=0)
            next_char_idx = torch.multinomial(probs, 1).item()
            current_text += idx_to_char[next_char_idx]

    return current_text

In [20]:
def train_and_test(model_desc, model, start_text):
    """Train ``model``, print a generated sample, and plot loss comparisons.

    Uses the notebook-level ``device``, loaders, ``LEARNING_RATE`` and
    ``EPOCHS``. The resulting loss histories are registered under
    ``model_desc`` in ``loss_comparison_dict`` before plotting.
    """
    # Move the model to the target device and set up loss and optimizer.
    model = model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    adamw = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Train and collect the per-batch loss histories.
    trained_model, train_losses_df, val_losses_df = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=loss_fn,
        optimizer=adamw,
        device=device,
        epochs=EPOCHS,
    )

    # Sample text from the trained model, starting from the prompt.
    sample = generate_text_attention(
        trained_model, char_to_idx, idx_to_char, start_text, device
    )
    print(f"Generated text [{start_text}]:")
    print("-"*50)
    print(sample)

    # Register this run, then redraw the comparison plots including it.
    add_loss_to_comparison(model_desc, train_losses_df, val_losses_df)
    plot_loss_comparisons()
    plot_separate_train_val(loss_comparison_dict)

In [21]:
import torch
import torch.nn as nn

from torch.nn import functional as F
from einops import rearrange

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        dim: Size of the last dimension (length of the ``weight`` vector).
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x: torch.Tensor):
        # The statistics are computed in float32; the result is not cast back
        # to the input dtype.
        x_fp32 = x.to(torch.float32)
        return self.weight * self._norm(x_fp32)

In [None]:
# 수정된 SSM 모듈
class SSM(nn.Module):
    """Selective state-space layer (Mamba style) with input-dependent Δ, B and C.

    For an input ``x`` of shape ``[batch, length, d_inner]`` the layer computes,
    independently for every channel ``d`` and state index ``n``::

        h_k = Ā_k * h_{k-1} + B̄_k * x_k        (h_0 = 0)
        y_k = sum_n C_k[n] * h_k[:, n] + D * x_k

    with ``Ā_k = exp(Δ_k * A)``, ``B̄_k = Δ_k * B_k`` and ``A = -exp(A_log)``.

    The recurrence is evaluated with an explicit loop over the sequence. The
    closed form ``h_k = R_k * Σ_i (B̄_i x_i) / R_i`` with ``R_k = Π Ā_j`` is not
    used: ``R_k`` underflows to zero and ``1 / R_i`` has to be clamped after a
    few dozen steps, which silently zeroes the state at later positions.

    Args:
        d_inner: Channel dimension of the input and output.
        state_size: State dimension N per channel.
        device: Device on which the parameters are created.
    """

    # --- Initialisation ---
    def __init__(self, d_inner, state_size, device='cuda'):
        super(SSM, self).__init__()
        self.d_inner = d_inner          # channel dimension of input u / output y
        self.state_size = state_size    # dimension N of the state vector h
        self.device = device

        # Low-rank bottleneck for Δ. At least 1 so that d_inner < 16 does not
        # create a zero-width projection.
        dt_rank = max(1, d_inner // 16)
        # One projection produces the inputs for Δ (dt_rank), B (N) and C (N).
        self.x_proj = nn.Linear(d_inner, dt_rank + state_size * 2, bias=False, device=device)
        # Expand the Δ bottleneck so every channel has its own step size.
        self.dt_proj = nn.Linear(dt_rank, d_inner, bias=True, device=device)

        # A is diagonal and parameterised in log space: A = -exp(A_log), which
        # is negative for any value of A_log. Initial value: -(1..N), repeated
        # for every channel.
        A = torch.arange(1, state_size + 1, dtype=torch.float32, device=device).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))  # [d_inner, state_size]

        # Learnable feed-through (skip) weight per channel.
        self.D = nn.Parameter(torch.ones(d_inner, device=device))  # [d_inner]

        # Clamp bounds for Δ·A before exponentiation. Stored as Python floats
        # so torch.clamp gets two scalars and nothing has to track the device.
        self.log_eps = float(torch.log(torch.tensor(1e-7)))
        self.exp_clamp_val = 20.0

    # --- Discretisation ---
    def discretization(self, delta, B):
        """Return the discretised ``(Ā, B̄)``.

        Args:
            delta: Step sizes Δ, ``[batch, length, d_inner]``.
            B: Input matrix, ``[batch, length, state_size]``.

        Returns:
            ``delta_A`` and ``delta_B``, both
            ``[batch, length, d_inner, state_size]``.
        """
        A = -torch.exp(self.A_log.float())  # [d_inner, state_size]

        # Ā = exp(Δ·A). Δ [B, L, D, 1] broadcasts against A [D, N].
        log_delta_A = torch.clamp(delta.unsqueeze(-1) * A, min=self.log_eps, max=self.exp_clamp_val)
        delta_A = torch.exp(log_delta_A)

        # B̄ ≈ Δ·B (simple approximation instead of the exact ZOH expression).
        delta_B = delta.unsqueeze(-1) * B.unsqueeze(2)

        return delta_A, delta_B

    # --- Forward pass ---
    def forward(self, x):
        """Apply the selective SSM to ``x`` ``[batch, length, d_inner]``.

        Returns:
            Tensor of shape ``[batch, length, d_inner]``.
        """
        batch, seq_len, _ = x.shape

        # Input-dependent Δ, B, C (the "selective" mechanism).
        x_proj_out = self.x_proj(x)  # [B, L, dt_rank + 2N]
        dt_inter, B_ssm, C_ssm = torch.split(
            x_proj_out,
            [self.dt_proj.in_features, self.state_size, self.state_size],
            dim=-1,
        )
        # softplus keeps Δ strictly positive.
        delta = F.softplus(self.dt_proj(dt_inter))  # [B, L, D]

        delta_A, delta_B = self.discretization(delta, B_ssm)  # both [B, L, D, N]
        delta_B_u = delta_B * x.unsqueeze(-1)                 # B̄_k · x_k, [B, L, D, N]

        # Sequential scan. Only the running state [B, D, N] is carried from
        # step to step; the per-step outputs are [B, D].
        h = torch.zeros(batch, self.d_inner, self.state_size,
                        dtype=delta_B_u.dtype, device=x.device)
        outputs = []
        for k in range(seq_len):
            h = delta_A[:, k] * h + delta_B_u[:, k]
            # y_k = C_k · h_k, contracted over the state dimension.
            outputs.append(torch.einsum('bn,bdn->bd', C_ssm[:, k].to(h.dtype), h))

        if outputs:
            y = torch.stack(outputs, dim=1)  # [B, L, D]
        else:
            # Zero-length sequence: nothing to stack.
            y = x.new_zeros(batch, 0, self.d_inner)

        # Feed-through term D · x_k.
        y = y + x * self.D

        return y

In [None]:

class MambaBlock(nn.Module):
    """Pre-norm residual block: RMSNorm -> in_proj -> causal conv -> SSM -> gate -> out_proj.

    Args:
        d_model: Width of the block's input and output.
        state_size: State dimension passed to the SSM.
        d_conv: Kernel size of the depthwise convolution.
        expand: Expansion factor; the inner width is ``int(expand * d_model)``.
        dropout_prob: Dropout applied to the block output before the residual add.
        device: Device on which the sub-modules with a ``device`` argument are created.
    """

    def __init__(self, d_model, state_size, d_conv=4, expand=2, dropout_prob=0.1, device='cuda'): 
        super(MambaBlock, self).__init__()
        self.d_model = d_model
        self.d_inner = int(expand * d_model)
        self.state_size = state_size
        self.d_conv = d_conv
        self.device = device

        # One projection yields both branches: the SSM input and the gate.
        self.in_proj = nn.Linear(d_model, 2 * self.d_inner, bias=False, device=device)

        # Depthwise 1-D convolution (groups == channels). It is padded by
        # d_conv - 1 and trimmed back to the input length in forward().
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            bias=True,
            device=device,
        )

        # Selective SSM operating on the inner width.
        self.ssm = SSM(self.d_inner, state_size, device=device)

        # Projection back to the model width.
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False, device=device)

        # Pre-normalisation.
        self.norm = RMSNorm(d_model, eps=1e-5)

        # Dropout on the residual branch.
        self.dropout_res = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Map ``x`` ``[B, L, d_model]`` to a tensor of the same shape."""
        _, seq_len, _ = x.shape

        # Normalise, then split the projection into SSM input and gate.
        projected = self.in_proj(self.norm(x))              # [B, L, 2 * d_inner]
        ssm_input, gate = projected.chunk(2, dim=-1)        # each [B, L, d_inner]

        # Conv1d expects channels first. Keeping only the first L outputs means
        # position t sees inputs up to t only (causal).
        conv_out = self.conv1d(ssm_input.transpose(1, 2))[:, :, :seq_len]
        conv_act = F.silu(conv_out.transpose(1, 2))         # [B, L, d_inner]

        # SSM branch, gated by the SiLU of the second half of the projection.
        gated = self.ssm(conv_act) * F.silu(gate)           # [B, L, d_inner]

        # Project back and add the residual. Dropout is active in training mode only.
        return x + self.dropout_res(self.out_proj(gated))   # [B, L, d_model]

In [None]:

class Mamba(nn.Module):
    """Language model: embedding -> stack of MambaBlock -> RMSNorm -> tied LM head.

    Args:
        d_model: Model width (embedding size).
        n_layers: Number of MambaBlock layers.
        vocab_size: Number of token classes.
        state_size: SSM state dimension per channel.
        d_conv: Convolution kernel size inside each block.
        expand: Inner-width expansion factor of each block.
        dropout_prob: Dropout probability for the embedding and each block.
        device: Device on which the parameters are created.
    """

    def __init__(self, d_model, n_layers, vocab_size, state_size=16, d_conv=4, expand=2, dropout_prob=0.1, device='cuda'):
        super(Mamba, self).__init__()
        self.d_model = d_model
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.device = device
        self.dropout_prob = dropout_prob

        self.embedding = nn.Embedding(vocab_size, d_model, device=device)
        # Dropout applied to the embeddings.
        self.dropout_emb = nn.Dropout(dropout_prob)

        self.layers = nn.ModuleList([
            MambaBlock(
                d_model=d_model,
                state_size=state_size,
                d_conv=d_conv,
                expand=expand,
                dropout_prob=self.dropout_prob,
                device=device
            )
            for _ in range(n_layers)
        ])
        self.norm_f = RMSNorm(d_model, eps=1e-5)  # final normalisation
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, device=device)

        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embedding.weight

        # Re-initialise all sub-modules.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; called for every module via ``apply``."""
        if isinstance(module, nn.Linear):
            if module.bias is not None:
                nn.init.zeros_(module.bias)
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=0.02)
        elif isinstance(module, nn.Conv1d):
            nn.init.kaiming_normal_(module.weight, nonlinearity='leaky_relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
        # SSM-specific parameters.
        # NOTE(review): this replaces the log(1..N) value that SSM.__init__
        # assigns to A_log. The SSM uses A = -exp(A_log), which is negative for
        # any A_log, so a negative mean is not needed for stability -- confirm
        # this override is intended.
        if hasattr(module, 'A_log'):
             nn.init.normal_(module.A_log, mean=-1.0, std=0.5)
        if hasattr(module, 'D'):
             nn.init.ones_(module.D)

    def forward(self, input_ids, labels=None):
        """Compute next-token logits.

        Args:
            input_ids: Integer token indices, ``[B, L]``.
            labels: Accepted for interface compatibility; ignored.

        Returns:
            ``(logits, None)`` with ``logits`` of shape ``[B, L, vocab_size]``.
            The second element is a placeholder so callers can unpack two
            values.
        """
        x = self.embedding(input_ids)  # [B, L, d_model]
        x = self.dropout_emb(x)

        # Residual connections and dropout live inside each block.
        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)  # [B, L, vocab_size]

        # Return an explicit None. The bare name `_` is not defined in this
        # scope and only resolved through whatever the notebook had bound to a
        # global `_`.
        return logits, None

In [None]:
STATE_SIZE = 8

# Random token batch for a shape check.
x = torch.randint(0, vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)).to(device)

# Build the model on the configured device so the CPU fallback also works
# (the constructor's own default is 'cuda').
mamba = Mamba(HIDDEN_DIM, NUM_LAYERS, vocab_size, STATE_SIZE, d_conv=4, expand=3, device=device).to(device)

# Shape check only: no gradients needed, so skip building the autograd graph.
with torch.no_grad():
    test_output, _ = mamba(x)
print(f"test_output.shape = {test_output.shape}") 

torch.Size([2048, 64, 8])
x.shape=torch.Size([2048, 64, 8]), self.d_model=8
torch.Size([2048, 64, 8])
x.shape=torch.Size([2048, 64, 8]), self.d_model=8
torch.Size([2048, 64, 8])
x.shape=torch.Size([2048, 64, 8]), self.d_model=8
test_output.shape = torch.Size([2048, 64, 8])


In [29]:
from torchinfo import summary

# The model consumes integer token ids of shape [batch, seq_len], so request a
# long-typed dummy input of that shape (torchinfo's default dummy is float).
summary(
    mamba.to(device),
    input_size=(BATCH_SIZE, SEQUENCE_LENGTH),
    dtypes=[torch.long],
    device=device,
)

torch.Size([2048, 64, 8])
x.shape=torch.Size([2048, 64, 8]), self.d_model=8
torch.Size([2048, 64, 8])
x.shape=torch.Size([2048, 64, 8]), self.d_model=8
torch.Size([2048, 64, 8])
x.shape=torch.Size([2048, 64, 8]), self.d_model=8


Layer (type:depth-idx)                   Output Shape              Param #
Mamba                                    [2048, 64, 8]             --
├─ModuleList: 1-1                        --                        --
│    └─MambaBlock: 2-1                   [2048, 64, 8]             --
│    │    └─RMSNorm: 3-1                 [2048, 64, 8]             8
│    │    └─Linear: 3-2                  [2048, 64, 16]            144
│    │    └─Conv1d: 3-3                  [2048, 64, 16]            12,352
│    │    └─Linear: 3-4                  [2048, 64, 16]            272
│    │    └─S6: 3-5                      [2048, 64, 16]            6,672
│    │    └─Linear: 3-6                  [2048, 64, 16]            144
│    │    └─Linear: 3-7                  [2048, 64, 8]             136
│    └─MambaBlock: 2-2                   [2048, 64, 8]             --
│    │    └─RMSNorm: 3-8                 [2048, 64, 8]             8
│    │    └─Linear: 3-9                  [2048, 64, 16]            144
│    

In [30]:
# Train the Mamba model, print a sample generated from `start_text`, and plot the loss curves.
train_and_test("mamba", mamba, start_text)

                                                

torch.Size([2048, 64])
x.shape=torch.Size([2048, 64]), self.d_model=8




RuntimeError: The size of tensor a (64) must match the size of tensor b (8) at non-singleton dimension 1