In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union


class CausalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution over the sequence axis.

    A per-channel (depthwise) convolution of width ``kernel_size`` is followed by a
    1x1 (pointwise) convolution that mixes channels. The input is zero-padded only on
    the left by ``kernel_size - 1`` steps, so output position ``t`` depends only on
    input positions ``<= t`` and the sequence length is preserved.

    Args:
        dim: Number of channels (last axis of the input).
        kernel_size: Width of the depthwise convolution window.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # groups=dim => one independent filter per channel.
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        # 1x1 convolution mixes information across channels.
        self.pointwise = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal convolution.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        left_context = self.kernel_size - 1
        channels_first = x.permute(0, 2, 1)  # Conv1d wants (batch, dim, seq_len)
        # Pad only the "past" side so no output position can see the future.
        padded = F.pad(channels_first, (left_context, 0))
        mixed = self.pointwise(self.depthwise(padded))
        return mixed.permute(0, 2, 1)


class MinGRUFunctions:
    """Core MinGRU math, evaluated in log space with a parallel scan.

    :meth:`mingru_step` computes, for every position ``t`` in parallel::

        z_t = sigmoid(gate_t)
        h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_t)

    where ``g`` is :meth:`g_activation`. All methods are stateless.
    """

    @staticmethod
    def g_activation(x: torch.Tensor) -> torch.Tensor:
        """Positive activation: ``x + 0.5`` for ``x >= 0``, else ``sigmoid(x)``.

        Both branches equal 0.5 at ``x = 0`` and the result is strictly positive,
        which lets hidden states be represented in log space. ``torch.where`` is
        used instead of boolean-mask scatter so no data-dependent shapes are
        created (friendlier to ``torch.compile``) and no masked copies are made.
        """
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    @staticmethod
    def log_g_activation(x: torch.Tensor) -> torch.Tensor:
        """Numerically stable ``log(g(x))``.

        ``log(x + 0.5)`` for ``x >= 0`` and ``log(sigmoid(x)) = -softplus(-x)``
        otherwise. The result has the same dtype as ``x``.
        """
        # relu keeps the unselected log branch finite for negative x, so the
        # gradient through torch.where never involves log of a value <= 0.
        return torch.where(x >= 0, torch.log(F.relu(x) + 0.5), -F.softplus(-x))

    @staticmethod
    def parallel_scan(log_gates: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Solve ``h_t = a_t * h_{t-1} + b_t`` for all ``t`` in log space.

        Args:
            log_gates: ``log(a_t)``, shape ``(batch, seq_len, dim)``.
            log_values: ``[log(h_0), log(b_1), ..., log(b_T)]``, shape
                ``(batch, seq_len + 1, dim)``.

        Returns:
            ``[h_0, h_1, ..., h_T]`` (exponentiated), shape ``(batch, seq_len + 1, dim)``.
        """
        # a*_t = sum_{k<=t} log a_k with a*_0 = 0 prepended. The pad spec
        # [0, 0, 1, 0] pads dim -2, i.e. the sequence axis of a 3D tensor.
        cumsum_gates = F.pad(torch.cumsum(log_gates, dim=1), [0, 0, 1, 0])

        # log h_t = a*_t + logcumsumexp(log b - a*)_t
        cumsum_values = torch.logcumsumexp(log_values - cumsum_gates, dim=1)
        return torch.exp(cumsum_gates + cumsum_values)

    @staticmethod
    def mingru_step(gate: torch.Tensor, hidden: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        """Run the MinGRU recurrence over a whole sequence in parallel.

        Args:
            gate: Gate pre-activations, shape ``(batch, seq_len, dim)``.
            hidden: Candidate pre-activations, shape ``(batch, seq_len, dim)``.
            prev_state: Initial state ``h_0`` of shape ``(batch, 1, dim)`` (a batch
                or feature size of 1 is broadcast). Expected to be positive; values
                are clamped to ``>= 1e-12`` before taking the log.

        Returns:
            States ``h_1..h_T``, shape ``(batch, seq_len, dim)``. The dtype follows
            type promotion of the inputs, so half-precision gates (e.g. under
            autocast) combined with a float32 state give a float32 result.
        """
        eps = 1e-12
        log_prev = torch.log(torch.clamp(prev_state, min=eps))
        # Broadcast h_0 over batch/features like an indexed assignment would.
        log_prev = log_prev.expand(gate.shape[0], -1, gate.shape[-1])

        # Clamp so the softplus terms stay well-conditioned.
        gate_clamped = torch.clamp(gate, min=-20, max=20)
        log_forget = -F.softplus(gate_clamped)   # log(1 - z)
        log_update = -F.softplus(-gate_clamped)  # log(z)

        log_candidate = MinGRUFunctions.log_g_activation(hidden)

        # torch.cat applies type promotion: a float32 log(h_0) is NOT downcast to a
        # fp16/bf16 gate dtype (a buffer allocated with gate.dtype would do that and
        # run the entire log-space scan in half precision).
        log_states = torch.cat([log_prev, log_update + log_candidate], dim=1)

        sequence_output = MinGRUFunctions.parallel_scan(log_forget, log_states)
        return sequence_output[:, 1:]  # drop h_0


class MinGRULayer(nn.Module):
    """Single MinGRU layer: a linear projection feeding the parallel MinGRU scan.

    The input is projected to ``2 * hidden_dim`` features and split into gate and
    candidate pre-activations. When ``input_dim != hidden_dim`` a bias-free linear
    projection of the input is added to the output; when the sizes match, no
    residual is added inside this layer.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Size of the hidden state and of the output.
        use_bias: Whether the gate/candidate projection has a bias.
    """

    def __init__(self, input_dim: int, hidden_dim: int, use_bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.gate_projection = nn.Linear(input_dim, hidden_dim * 2, bias=use_bias)
        self.residual_projection = nn.Linear(input_dim, hidden_dim, bias=False) if input_dim != hidden_dim else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize the projections and set the gate/candidate biases."""
        nn.init.xavier_uniform_(self.gate_projection.weight)
        if self.gate_projection.bias is not None:
            # Gate bias -1 => z = sigmoid(-1) ~= 0.27 at init, so each step keeps
            # ~73% of the previous state (the scan uses log(1 - z) = -softplus(gate)).
            nn.init.constant_(self.gate_projection.bias[:self.hidden_dim], -1.0)
            # Candidate (hidden) pre-activation bias.
            nn.init.constant_(self.gate_projection.bias[self.hidden_dim:], 0.0)

        if not isinstance(self.residual_projection, nn.Identity):
            nn.init.xavier_uniform_(self.residual_projection.weight)

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer over a full sequence.

        Args:
            x: Input of shape ``(batch, seq_len, input_dim)``.
            prev_state: Initial hidden state, shape ``(batch, 1, hidden_dim)``.

        Returns:
            ``(output, next_state)``. ``output`` has shape
            ``(batch, seq_len, hidden_dim)`` and includes the residual projection
            when present. ``next_state`` has shape ``(batch, 1, hidden_dim)`` and is
            the final *recurrent* state ``h_T``.
        """
        gate, hidden = self.gate_projection(x).chunk(2, dim=-1)
        hidden_states = MinGRUFunctions.mingru_step(gate, hidden, prev_state)

        # Take the carried state BEFORE adding the residual: the residual term is
        # not part of the recurrence and can make values <= 0, which the next
        # call's log(clamp(prev_state, 1e-12)) would collapse to ~0.
        next_state = hidden_states[:, -1:, :]

        output = hidden_states
        if not isinstance(self.residual_projection, nn.Identity):
            output = output + self.residual_projection(x)

        return output, next_state


class MinGRU(nn.Module):
    """Multi-layer MinGRU stack.

    Layer ``i`` maps ``layer_dims[i] -> layer_dims[i + 1]`` where
    ``layer_dims = [input_dim] + hidden_dims``. In eval mode the (constant)
    initial states are cached and reused by later calls with the same batch size
    and device. Final states are returned to the caller but never cached, so every
    call with ``states=None`` starts from the initial state. To continue a
    sequence across calls, pass the returned states back explicitly.

    Args:
        input_dim: Feature size of the input to the first layer.
        hidden_dims: Hidden size of each layer; one layer is built per entry.
        dropout: Dropout probability applied after every layer except the last.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], dropout: float = 0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        # Build layers
        layer_dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList([
            MinGRULayer(layer_dims[i], layer_dims[i + 1])
            for i in range(self.num_layers)
        ])

        # Dropout between layers only; the last layer's output is left untouched.
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout) if i < self.num_layers - 1 else nn.Identity()
            for i in range(self.num_layers)
        ])

        # Eval-mode cache of the INITIAL states. These are plain attributes, not
        # buffers, so `.to(device)` does not move them (hence the device check).
        self._cached_states: Optional[List[torch.Tensor]] = None
        self._cache_batch_size: int = 0

    def init_states(self, batch_size: int, device: torch.device) -> List[torch.Tensor]:
        """Build initial states ``g(0)`` (all 0.5), one float32 ``(batch_size, 1, dim)`` per layer.

        In eval mode the states are also stored in the cache.
        """
        states = [
            MinGRUFunctions.g_activation(torch.zeros(batch_size, 1, dim, device=device, dtype=torch.float32))
            for dim in self.hidden_dims
        ]

        if not self.training:
            self._cached_states = [state.clone() for state in states]
            self._cache_batch_size = batch_size

        return states

    def get_cached_states(self, batch_size: int, device: torch.device) -> Optional[List[torch.Tensor]]:
        """Return clones of the cached initial states, or ``None`` if they cannot be used.

        The cache is used only in eval mode, for the same batch size, and when
        every cached tensor already lives on ``device``.
        """
        if self.training or self._cached_states is None or self._cache_batch_size != batch_size:
            return None
        target = torch.device(device)
        if any(state.device != target for state in self._cached_states):
            return None
        return [state.clone() for state in self._cached_states]

    def forward(self, x: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run all layers over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, input_dim)``.
            states: Optional per-layer initial states. When ``None``, the initial
                states ``g(0)`` are used (from the cache when valid).

        Returns:
            ``(output, next_states)`` where ``next_states[i]`` is layer ``i``'s final
            state, suitable for passing back as ``states`` to continue the sequence.
        """
        if states is None:
            states = self.get_cached_states(x.size(0), x.device)
            if states is None:
                states = self.init_states(x.size(0), x.device)

        output = x
        next_states = []

        for i, (layer, dropout) in enumerate(zip(self.layers, self.dropouts)):
            output, next_state = layer(output, states[i])
            output = dropout(output)
            next_states.append(next_state)

        # Deliberately do NOT write next_states into the cache: doing so made an
        # unrelated later call with states=None silently continue from this call's
        # final state.
        return output, next_states


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``dropout(down(silu(gate(x)) * up(x)))``.

    The inner width is ``int(dim * expansion_factor * 2/3)`` rounded up to a
    multiple of 8. The 2/3 factor keeps the three-matrix block's parameter count
    equal to a two-matrix FFN of width ``dim * expansion_factor``.

    Args:
        dim: Input and output feature size.
        expansion_factor: Nominal FFN expansion before the 2/3 correction.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, dim: int, expansion_factor: float = 2.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * expansion_factor * 2/3)
        hidden_dim = ((hidden_dim + 7) // 8) * 8  # Round to multiple of 8 for tensor cores

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize all projections, then halve them (smaller initial activations for fp16)."""
        for module in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(module.weight)
            # In-place scale under no_grad instead of mutating `.data`, which
            # bypasses autograd's version tracking.
            with torch.no_grad():
                module.weight.mul_(0.5)

    @torch.amp.autocast('cuda')  # forces CUDA autocast inside forward; no effect on CPU tensors
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward to ``x`` of shape ``(..., dim)``."""
        gate = self.activation(self.gate_proj(x))
        up = self.up_proj(x)
        # Element-wise gating; under CUDA autocast the linear outputs are half precision.
        gated = gate * up
        return self.dropout(self.down_proj(gated))


class MinGRUDecoder(nn.Module):
    """MinGRU-based language-model decoder.

    Pipeline: embedding -> input projection (if ``embed_dim != hidden_dims[0]``)
    -> optional causal conv (residual) -> RMSNorm + MinGRU stack (residual)
    -> RMSNorm + SwiGLU (residual) -> final RMSNorm -> vocabulary projection.

    The residual around the MinGRU stack adds tensors of width ``hidden_dims[0]``
    and ``hidden_dims[-1]``, so those two sizes must match.

    Args:
        vocab_size: Number of tokens.
        embed_dim: Embedding size.
        hidden_dims: Per-layer hidden sizes, or a single int combined with ``num_layers``.
        num_layers: Required when ``hidden_dims`` is an int.
        dropout: Dropout used inside the MinGRU stack and the FFN.
        use_conv: Whether to apply the causal convolution.
        conv_kernel: Causal convolution width.
        ffn_expansion: SwiGLU expansion factor.
        norm_eps: RMSNorm epsilon.
        use_gradient_checkpointing: Checkpoint the MinGRU+FFN block during training.

    Raises:
        ValueError: If ``hidden_dims`` is an int and ``num_layers`` is ``None``.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dims: Union[int, List[int]],
        num_layers: Optional[int] = None,
        dropout: float = 0.1,
        use_conv: bool = True,
        conv_kernel: int = 3,
        ffn_expansion: float = 1.0,
        norm_eps: float = 1e-8,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # Handle hidden dimensions
        if isinstance(hidden_dims, int):
            if num_layers is None:
                raise ValueError("num_layers must be specified when hidden_dims is int")
            hidden_dims = [hidden_dims] * num_layers

        self.embed_dim = embed_dim
        self.hidden_dims = hidden_dims
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Input layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.input_proj = nn.Linear(embed_dim, hidden_dims[0]) if embed_dim != hidden_dims[0] else nn.Identity()

        # Optional convolution
        self.conv = CausalConv1D(hidden_dims[0], conv_kernel) if use_conv else None

        # Core MinGRU
        self.pre_gru_norm = nn.RMSNorm(hidden_dims[0], eps=norm_eps)
        self.mingru = MinGRU(hidden_dims[0], hidden_dims, dropout)

        # Feed-forward network
        self.post_gru_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.ffn = SwiGLU(hidden_dims[-1], ffn_expansion, dropout)

        # Output layers
        self.final_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.output_proj = nn.Linear(hidden_dims[-1], vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Initialize the weights owned directly by the decoder.

        Only the embedding, ``input_proj`` (when it is a Linear) and ``output_proj``
        are initialized here. Submodules such as ``MinGRULayer`` and ``SwiGLU``
        already ran their own ``_init_weights`` in their constructors. Walking
        ``self.modules()`` and re-initializing every ``nn.Linear`` would overwrite
        the MinGRU gate bias (-1) and SwiGLU's 0.5 weight scaling.
        """
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

        for module in (self.input_proj, self.output_proj):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _forward_block(self, h: torch.Tensor) -> torch.Tensor:
        """Pre-norm MinGRU stack and pre-norm SwiGLU, each with a residual connection."""
        h_norm = self.pre_gru_norm(h)
        gru_out, _ = self.mingru(h_norm)
        h = h + gru_out

        h_norm = self.post_gru_norm(h)
        ffn_out = self.ffn(h_norm)
        h = h + ffn_out

        return h

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        h = self.embedding(x)
        h = self.input_proj(h)

        # Optional convolution with residual
        if self.conv is not None:
            h = h + self.conv(h)

        # Recompute the block in backward instead of storing activations (training only).
        if self.use_gradient_checkpointing and self.training:
            h = torch.utils.checkpoint.checkpoint(self._forward_block, h, use_reentrant=False)
        else:
            h = self._forward_block(h)

        h = self.final_norm(h)
        logits = self.output_proj(h)

        return logits


# Example usage with optimizations
if __name__ == "__main__":
    # Backend flags: cuDNN autotuning for fixed input sizes, TF32 matmuls on Ampere+ GPUs.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Pick the device up front so the model is moved BEFORE torch.compile and the
    # inputs live on the same device. Autocast only affects ops of its own device
    # type, so wrapping CPU tensors in autocast('cuda') would be a silent no-op.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = MinGRUDecoder(
        vocab_size=10000,
        embed_dim=512,
        hidden_dims=[512, 512],
        dropout=0.1,
        use_conv=True,
        conv_kernel=3,
        ffn_expansion=2.0,
        use_gradient_checkpointing=True  # Enable memory-efficient training
    ).to(device)

    # Compile model for optimization (PyTorch 2.0+)
    model = torch.compile(model, mode="reduce-overhead")

    # Test forward pass
    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)

    # Mixed precision on CUDA only; on CPU the model runs in float32.
    with torch.amp.autocast(device.type, enabled=use_cuda):
        output = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")  # Should be (4, 128, 10000)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Memory usage measurement (model and inputs are already on the GPU)
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        with torch.amp.autocast('cuda'):
            output = model(input_ids)

        peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
        print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

  @torch.cuda.amp.autocast()  # Enable automatic mixed precision


Input shape: torch.Size([4, 128])
Output shape: torch.Size([4, 128, 10000])
Model parameters: 12,623,632
Peak GPU memory usage: 0.42 GB


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union


class CausalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution over the sequence axis.

    A per-channel (depthwise) convolution of width ``kernel_size`` is followed by a
    1x1 (pointwise) convolution that mixes channels. The input is zero-padded only on
    the left by ``kernel_size - 1`` steps, so output position ``t`` depends only on
    input positions ``<= t`` and the sequence length is preserved.

    Args:
        dim: Number of channels (last axis of the input).
        kernel_size: Width of the depthwise convolution window.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # groups=dim => one independent filter per channel.
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        # 1x1 convolution mixes information across channels.
        self.pointwise = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal convolution.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        left_context = self.kernel_size - 1
        channels_first = x.permute(0, 2, 1)  # Conv1d wants (batch, dim, seq_len)
        # Pad only the "past" side so no output position can see the future.
        padded = F.pad(channels_first, (left_context, 0))
        mixed = self.pointwise(self.depthwise(padded))
        return mixed.permute(0, 2, 1)


class MinGRUFunctions:
    """Core MinGRU math, evaluated in log space with a parallel scan.

    :meth:`mingru_step` computes, for every position ``t`` in parallel::

        z_t = sigmoid(gate_t)
        h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_t)

    where ``g`` is :meth:`g_activation`. All methods are stateless.
    """

    @staticmethod
    def g_activation(x: torch.Tensor) -> torch.Tensor:
        """Positive activation: ``x + 0.5`` for ``x >= 0``, else ``sigmoid(x)``.

        Both branches equal 0.5 at ``x = 0`` and the result is strictly positive,
        which lets hidden states be represented in log space. ``torch.where`` is
        used instead of boolean-mask scatter so no data-dependent shapes are
        created (friendlier to ``torch.compile``) and no masked copies are made.
        """
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    @staticmethod
    def log_g_activation(x: torch.Tensor) -> torch.Tensor:
        """Numerically stable ``log(g(x))``.

        ``log(x + 0.5)`` for ``x >= 0`` and ``log(sigmoid(x)) = -softplus(-x)``
        otherwise. The result has the same dtype as ``x``.
        """
        # relu keeps the unselected log branch finite for negative x, so the
        # gradient through torch.where never involves log of a value <= 0.
        return torch.where(x >= 0, torch.log(F.relu(x) + 0.5), -F.softplus(-x))

    @staticmethod
    def parallel_scan(log_gates: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Solve ``h_t = a_t * h_{t-1} + b_t`` for all ``t`` in log space.

        Args:
            log_gates: ``log(a_t)``, shape ``(batch, seq_len, dim)``.
            log_values: ``[log(h_0), log(b_1), ..., log(b_T)]``, shape
                ``(batch, seq_len + 1, dim)``.

        Returns:
            ``[h_0, h_1, ..., h_T]`` (exponentiated), shape ``(batch, seq_len + 1, dim)``.
        """
        # a*_t = sum_{k<=t} log a_k with a*_0 = 0 prepended. The pad spec
        # [0, 0, 1, 0] pads dim -2, i.e. the sequence axis of a 3D tensor.
        cumsum_gates = F.pad(torch.cumsum(log_gates, dim=1), [0, 0, 1, 0])

        # log h_t = a*_t + logcumsumexp(log b - a*)_t
        cumsum_values = torch.logcumsumexp(log_values - cumsum_gates, dim=1)
        return torch.exp(cumsum_gates + cumsum_values)

    @staticmethod
    def mingru_step(gate: torch.Tensor, hidden: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        """Run the MinGRU recurrence over a whole sequence in parallel.

        Args:
            gate: Gate pre-activations, shape ``(batch, seq_len, dim)``.
            hidden: Candidate pre-activations, shape ``(batch, seq_len, dim)``.
            prev_state: Initial state ``h_0`` of shape ``(batch, 1, dim)`` (a batch
                or feature size of 1 is broadcast). Expected to be positive; values
                are clamped to ``>= 1e-12`` before taking the log.

        Returns:
            States ``h_1..h_T``, shape ``(batch, seq_len, dim)``. The dtype follows
            type promotion of the inputs, so half-precision gates (e.g. under
            autocast) combined with a float32 state give a float32 result.
        """
        eps = 1e-12
        log_prev = torch.log(torch.clamp(prev_state, min=eps))
        # Broadcast h_0 over batch/features so torch.cat sees matching shapes.
        log_prev = log_prev.expand(gate.shape[0], -1, gate.shape[-1])

        # Clamp so the softplus terms stay well-conditioned.
        gate_clamped = torch.clamp(gate, min=-20, max=20)
        log_forget = -F.softplus(gate_clamped)   # log(1 - z)
        log_update = -F.softplus(-gate_clamped)  # log(z)

        log_candidate = MinGRUFunctions.log_g_activation(hidden)

        # torch.cat applies type promotion, so a float32 log(h_0) is not
        # downcast to a fp16/bf16 gate dtype.
        log_states = torch.cat([log_prev, log_update + log_candidate], dim=1)

        sequence_output = MinGRUFunctions.parallel_scan(log_forget, log_states)
        return sequence_output[:, 1:]  # drop h_0


class MinGRULayer(nn.Module):
    """Single MinGRU layer: a linear projection feeding the parallel MinGRU scan.

    The input is projected to ``2 * hidden_dim`` features and split into gate and
    candidate pre-activations. When ``input_dim != hidden_dim`` a bias-free linear
    projection of the input is added to the output; when the sizes match, no
    residual is added inside this layer.

    Args:
        input_dim: Size of the last axis of the input.
        hidden_dim: Size of the hidden state and of the output.
        use_bias: Whether the gate/candidate projection has a bias.
    """

    def __init__(self, input_dim: int, hidden_dim: int, use_bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.gate_projection = nn.Linear(input_dim, hidden_dim * 2, bias=use_bias)
        self.residual_projection = nn.Linear(input_dim, hidden_dim, bias=False) if input_dim != hidden_dim else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize the projections and set the gate/candidate biases."""
        nn.init.xavier_uniform_(self.gate_projection.weight)
        if self.gate_projection.bias is not None:
            # Gate bias -1 => z = sigmoid(-1) ~= 0.27 at init, so each step keeps
            # ~73% of the previous state (the scan uses log(1 - z) = -softplus(gate)).
            nn.init.constant_(self.gate_projection.bias[:self.hidden_dim], -1.0)
            # Candidate (hidden) pre-activation bias.
            nn.init.constant_(self.gate_projection.bias[self.hidden_dim:], 0.0)

        if not isinstance(self.residual_projection, nn.Identity):
            nn.init.xavier_uniform_(self.residual_projection.weight)

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer over a full sequence.

        Args:
            x: Input of shape ``(batch, seq_len, input_dim)``.
            prev_state: Initial hidden state, shape ``(batch, 1, hidden_dim)``.

        Returns:
            ``(output, next_state)``. ``output`` has shape
            ``(batch, seq_len, hidden_dim)`` and includes the residual projection
            when present. ``next_state`` has shape ``(batch, 1, hidden_dim)`` and is
            the final *recurrent* state ``h_T``.
        """
        gate, hidden = self.gate_projection(x).chunk(2, dim=-1)
        hidden_states = MinGRUFunctions.mingru_step(gate, hidden, prev_state)

        # Take the carried state BEFORE adding the residual: the residual term is
        # not part of the recurrence and can make values <= 0, which the next
        # call's log(clamp(prev_state, 1e-12)) would collapse to ~0.
        next_state = hidden_states[:, -1:, :]

        output = hidden_states
        if not isinstance(self.residual_projection, nn.Identity):
            output = output + self.residual_projection(x)

        return output, next_state


class MinGRU(nn.Module):
    """Multi-layer MinGRU stack.

    Layer ``i`` maps ``layer_dims[i] -> layer_dims[i + 1]`` where
    ``layer_dims = [input_dim] + hidden_dims``. In eval mode the (constant)
    initial states are cached and reused by later calls with the same batch size
    and device. Final states are returned to the caller but never cached, so every
    call with ``states=None`` starts from the initial state. To continue a
    sequence across calls, pass the returned states back explicitly.

    Args:
        input_dim: Feature size of the input to the first layer.
        hidden_dims: Hidden size of each layer; one layer is built per entry.
        dropout: Dropout probability applied after every layer except the last.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], dropout: float = 0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        # Build layers
        layer_dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList([
            MinGRULayer(layer_dims[i], layer_dims[i + 1])
            for i in range(self.num_layers)
        ])

        # Dropout between layers only; the last layer's output is left untouched.
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout) if i < self.num_layers - 1 else nn.Identity()
            for i in range(self.num_layers)
        ])

        # Eval-mode cache of the INITIAL states. These are plain attributes, not
        # buffers, so `.to(device)` does not move them (hence the device check).
        self._cached_states: Optional[List[torch.Tensor]] = None
        self._cache_batch_size: int = 0

    def init_states(self, batch_size: int, device: torch.device) -> List[torch.Tensor]:
        """Build initial states ``g(0)`` (all 0.5), one float32 ``(batch_size, 1, dim)`` per layer.

        In eval mode the states are also stored in the cache.
        """
        states = [
            MinGRUFunctions.g_activation(torch.zeros(batch_size, 1, dim, device=device, dtype=torch.float32))
            for dim in self.hidden_dims
        ]

        if not self.training:
            self._cached_states = [state.clone() for state in states]
            self._cache_batch_size = batch_size

        return states

    def get_cached_states(self, batch_size: int, device: torch.device) -> Optional[List[torch.Tensor]]:
        """Return clones of the cached initial states, or ``None`` if they cannot be used.

        The cache is used only in eval mode, for the same batch size, and when
        every cached tensor already lives on ``device``.
        """
        if self.training or self._cached_states is None or self._cache_batch_size != batch_size:
            return None
        target = torch.device(device)
        if any(state.device != target for state in self._cached_states):
            return None
        return [state.clone() for state in self._cached_states]

    def forward(self, x: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run all layers over ``x``.

        Args:
            x: Input of shape ``(batch, seq_len, input_dim)``.
            states: Optional per-layer initial states. When ``None``, the initial
                states ``g(0)`` are used (from the cache when valid).

        Returns:
            ``(output, next_states)`` where ``next_states[i]`` is layer ``i``'s final
            state, suitable for passing back as ``states`` to continue the sequence.
        """
        if states is None:
            states = self.get_cached_states(x.size(0), x.device)
            if states is None:
                states = self.init_states(x.size(0), x.device)

        output = x
        next_states = []

        for i, (layer, dropout) in enumerate(zip(self.layers, self.dropouts)):
            output, next_state = layer(output, states[i])
            output = dropout(output)
            next_states.append(next_state)

        # Deliberately do NOT write next_states into the cache: doing so made an
        # unrelated later call with states=None silently continue from this call's
        # final state.
        return output, next_states


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``dropout(down(silu(gate(x)) * up(x)))``.

    The inner width is ``int(dim * expansion_factor * 2/3)`` rounded up to a
    multiple of 8. The 2/3 factor keeps the three-matrix block's parameter count
    equal to a two-matrix FFN of width ``dim * expansion_factor``.

    Args:
        dim: Input and output feature size.
        expansion_factor: Nominal FFN expansion before the 2/3 correction.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, dim: int, expansion_factor: float = 2.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * expansion_factor * 2/3)
        hidden_dim = ((hidden_dim + 7) // 8) * 8  # Round to multiple of 8 for tensor cores

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize all projections, then halve them (smaller initial activations for fp16)."""
        for module in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(module.weight)
            # In-place scale under no_grad instead of mutating `.data`, which
            # bypasses autograd's version tracking.
            with torch.no_grad():
                module.weight.mul_(0.5)

    # torch.cuda.amp.autocast() is deprecated in favour of torch.amp.autocast('cuda').
    # This forces CUDA autocast inside forward and has no effect on CPU tensors.
    @torch.amp.autocast('cuda')
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward to ``x`` of shape ``(..., dim)``."""
        gate = self.activation(self.gate_proj(x))
        up = self.up_proj(x)
        # Element-wise gating; under CUDA autocast the linear outputs are half precision.
        gated = gate * up
        return self.dropout(self.down_proj(gated))


class MinGRUDecoder(nn.Module):
    """MinGRU-based language-model decoder.

    Pipeline: embedding -> input projection (if ``embed_dim != hidden_dims[0]``)
    -> optional causal conv (residual) -> RMSNorm + MinGRU stack (residual)
    -> RMSNorm + SwiGLU (residual) -> final RMSNorm -> vocabulary projection.

    The residual around the MinGRU stack adds tensors of width ``hidden_dims[0]``
    and ``hidden_dims[-1]``, so those two sizes must match.

    Args:
        vocab_size: Number of tokens.
        embed_dim: Embedding size.
        hidden_dims: Per-layer hidden sizes, or a single int combined with ``num_layers``.
        num_layers: Required when ``hidden_dims`` is an int.
        dropout: Dropout used inside the MinGRU stack and the FFN.
        use_conv: Whether to apply the causal convolution.
        conv_kernel: Causal convolution width.
        ffn_expansion: SwiGLU expansion factor.
        norm_eps: RMSNorm epsilon.
        use_gradient_checkpointing: Checkpoint the MinGRU+FFN block during training.

    Raises:
        ValueError: If ``hidden_dims`` is an int and ``num_layers`` is ``None``.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dims: Union[int, List[int]],
        num_layers: Optional[int] = None,
        dropout: float = 0.1,
        use_conv: bool = True,
        conv_kernel: int = 3,
        ffn_expansion: float = 1.0,
        norm_eps: float = 1e-8,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # Handle hidden dimensions
        if isinstance(hidden_dims, int):
            if num_layers is None:
                raise ValueError("num_layers must be specified when hidden_dims is int")
            hidden_dims = [hidden_dims] * num_layers

        self.embed_dim = embed_dim
        self.hidden_dims = hidden_dims
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Input layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.input_proj = nn.Linear(embed_dim, hidden_dims[0]) if embed_dim != hidden_dims[0] else nn.Identity()

        # Optional convolution
        self.conv = CausalConv1D(hidden_dims[0], conv_kernel) if use_conv else None

        # Core MinGRU
        self.pre_gru_norm = nn.RMSNorm(hidden_dims[0], eps=norm_eps)
        self.mingru = MinGRU(hidden_dims[0], hidden_dims, dropout)

        # Feed-forward network
        self.post_gru_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.ffn = SwiGLU(hidden_dims[-1], ffn_expansion, dropout)

        # Output layers
        self.final_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.output_proj = nn.Linear(hidden_dims[-1], vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Initialize the weights owned directly by the decoder.

        Only the embedding, ``input_proj`` (when it is a Linear) and ``output_proj``
        are initialized here. Submodules such as ``MinGRULayer`` and ``SwiGLU``
        already ran their own ``_init_weights`` in their constructors. Walking
        ``self.modules()`` and re-initializing every ``nn.Linear`` would overwrite
        the MinGRU gate bias (-1) and SwiGLU's 0.5 weight scaling.
        """
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

        for module in (self.input_proj, self.output_proj):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _forward_block(self, h: torch.Tensor) -> torch.Tensor:
        """Pre-norm MinGRU stack and pre-norm SwiGLU, each with a residual connection."""
        h_norm = self.pre_gru_norm(h)
        gru_out, _ = self.mingru(h_norm)
        h = h + gru_out

        h_norm = self.post_gru_norm(h)
        ffn_out = self.ffn(h_norm)
        h = h + ffn_out

        return h

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, seq_len, vocab_size)``."""
        h = self.embedding(x)
        h = self.input_proj(h)

        # Optional convolution with residual
        if self.conv is not None:
            h = h + self.conv(h)

        # Recompute the block in backward instead of storing activations (training only).
        if self.use_gradient_checkpointing and self.training:
            h = torch.utils.checkpoint.checkpoint(self._forward_block, h, use_reentrant=False)
        else:
            h = self._forward_block(h)

        h = self.final_norm(h)
        logits = self.output_proj(h)

        return logits


# Example usage with optimizations
if __name__ == "__main__":
    # Backend flags: cuDNN autotuning for fixed input sizes, TF32 matmuls on Ampere+ GPUs.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Pick the device up front so the model is moved BEFORE torch.compile and the
    # inputs live on the same device. Autocast only affects ops of its own device
    # type, so wrapping CPU tensors in autocast('cuda') would be a silent no-op.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = MinGRUDecoder(
        vocab_size=10000,
        embed_dim=512,
        hidden_dims=[512, 512],
        dropout=0.1,
        use_conv=True,
        conv_kernel=3,
        ffn_expansion=2.0,
        use_gradient_checkpointing=True  # Enable memory-efficient training
    ).to(device)

    # Compile model for optimization (PyTorch 2.0+)
    model = torch.compile(model, mode="reduce-overhead")

    # Test forward pass
    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)

    # Mixed precision on CUDA only; on CPU the model runs in float32.
    with torch.amp.autocast(device.type, enabled=use_cuda):
        output = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")  # Should be (4, 128, 10000)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Memory usage measurement (model and inputs are already on the GPU)
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        with torch.amp.autocast('cuda'):
            output = model(input_ids)

        peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
        print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

  @torch.cuda.amp.autocast()  # Enable automatic mixed precision


Input shape: torch.Size([4, 128])
Output shape: torch.Size([4, 128, 10000])
Model parameters: 12,623,632
Peak GPU memory usage: 0.19 GB


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union


class CausalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution over the sequence axis.

    A per-channel (depthwise) convolution of width ``kernel_size`` is followed by a
    1x1 (pointwise) convolution that mixes channels. The input is zero-padded only on
    the left by ``kernel_size - 1`` steps, so output position ``t`` depends only on
    input positions ``<= t`` and the sequence length is preserved.

    Args:
        dim: Number of channels (last axis of the input).
        kernel_size: Width of the depthwise convolution window.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # groups=dim => one independent filter per channel.
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        # 1x1 convolution mixes information across channels.
        self.pointwise = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the causal convolution.

        Args:
            x: Tensor of shape ``(batch, seq_len, dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.
        """
        left_context = self.kernel_size - 1
        channels_first = x.permute(0, 2, 1)  # Conv1d wants (batch, dim, seq_len)
        # Pad only the "past" side so no output position can see the future.
        padded = F.pad(channels_first, (left_context, 0))
        mixed = self.pointwise(self.depthwise(padded))
        return mixed.permute(0, 2, 1)


class MinGRUFunctions:
    """Core MinGRU math, evaluated in log space with a parallel scan.

    :meth:`mingru_step` computes, for every position ``t`` in parallel::

        z_t = sigmoid(gate_t)
        h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_t)

    where ``g`` is :meth:`g_activation`. All methods are stateless.
    """

    @staticmethod
    def g_activation(x: torch.Tensor) -> torch.Tensor:
        """Positive activation: ``x + 0.5`` for ``x >= 0``, else ``sigmoid(x)``.

        Both branches equal 0.5 at ``x = 0`` and the result is strictly positive,
        which lets hidden states be represented in log space. ``torch.where`` is
        used instead of boolean-mask scatter so no data-dependent shapes are
        created (friendlier to ``torch.compile``) and no masked copies are made.
        """
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    @staticmethod
    def log_g_activation(x: torch.Tensor) -> torch.Tensor:
        """Numerically stable ``log(g(x))``.

        ``log(x + 0.5)`` for ``x >= 0`` and ``log(sigmoid(x)) = -softplus(-x)``
        otherwise. The result has the same dtype as ``x``.
        """
        # relu keeps the unselected log branch finite for negative x, so the
        # gradient through torch.where never involves log of a value <= 0.
        return torch.where(x >= 0, torch.log(F.relu(x) + 0.5), -F.softplus(-x))

    @staticmethod
    def parallel_scan(log_gates: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Solve ``h_t = a_t * h_{t-1} + b_t`` for all ``t`` in log space.

        Args:
            log_gates: ``log(a_t)``, shape ``(batch, seq_len, dim)``.
            log_values: ``[log(h_0), log(b_1), ..., log(b_T)]``, shape
                ``(batch, seq_len + 1, dim)``.

        Returns:
            ``[h_0, h_1, ..., h_T]`` (exponentiated), shape ``(batch, seq_len + 1, dim)``.
        """
        # a*_t = sum_{k<=t} log a_k with a*_0 = 0 prepended. The pad spec
        # [0, 0, 1, 0] pads dim -2, i.e. the sequence axis of a 3D tensor.
        cumsum_gates = F.pad(torch.cumsum(log_gates, dim=1), [0, 0, 1, 0])

        # log h_t = a*_t + logcumsumexp(log b - a*)_t
        cumsum_values = torch.logcumsumexp(log_values - cumsum_gates, dim=1)
        return torch.exp(cumsum_gates + cumsum_values)

    @staticmethod
    def mingru_step(gate: torch.Tensor, hidden: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        """Run the MinGRU recurrence over a whole sequence in parallel.

        Args:
            gate: Gate pre-activations, shape ``(batch, seq_len, dim)``.
            hidden: Candidate pre-activations, shape ``(batch, seq_len, dim)``.
            prev_state: Initial state ``h_0`` of shape ``(batch, 1, dim)`` (a batch
                or feature size of 1 is broadcast). Expected to be positive; values
                are clamped to ``>= 1e-12`` before taking the log.

        Returns:
            States ``h_1..h_T``, shape ``(batch, seq_len, dim)``. The dtype follows
            type promotion of the inputs, so half-precision gates (e.g. under
            autocast) combined with a float32 state give a float32 result.
        """
        eps = 1e-12
        log_prev = torch.log(torch.clamp(prev_state, min=eps))
        # Broadcast h_0 over batch/features so torch.cat sees matching shapes.
        log_prev = log_prev.expand(gate.shape[0], -1, gate.shape[-1])

        # Clamp so the softplus terms stay well-conditioned.
        gate_clamped = torch.clamp(gate, min=-20, max=20)
        log_forget = -F.softplus(gate_clamped)   # log(1 - z)
        log_update = -F.softplus(-gate_clamped)  # log(z)

        log_candidate = MinGRUFunctions.log_g_activation(hidden)

        # torch.cat applies type promotion, so a float32 log(h_0) is not
        # downcast to a fp16/bf16 gate dtype.
        log_states = torch.cat([log_prev, log_update + log_candidate], dim=1)

        sequence_output = MinGRUFunctions.parallel_scan(log_forget, log_states)
        return sequence_output[:, 1:]  # drop h_0


class MinGRULayer(nn.Module):
    """Single MinGRU layer: a linear projection feeding the log-space MinGRU scan.

    Args:
        input_dim: feature size of the incoming sequence.
        hidden_dim: size of the recurrent state and of the layer output.
        use_bias: whether ``gate_projection`` has a bias term.
    """

    def __init__(self, input_dim: int, hidden_dim: int, use_bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # One matmul yields both halves: [gate pre-activation | candidate pre-activation].
        self.gate_projection = nn.Linear(input_dim, hidden_dim * 2, bias=use_bias)
        # Width-matching skip path; only a real projection when the widths differ.
        self.residual_projection = nn.Linear(input_dim, hidden_dim, bias=False) if input_dim != hidden_dim else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the projections and bias the update gate towards keeping state."""
        nn.init.xavier_uniform_(self.gate_projection.weight)
        if self.gate_projection.bias is not None:
            # Gate half: bias -1 starts the update gate at z = sigmoid(-1) ~= 0.27, so each
            # step keeps ~73% of the previous state (the 1 - z term in mingru_step).
            nn.init.constant_(self.gate_projection.bias[:self.hidden_dim], -1.0)
            # Candidate half (the input to g()): start unbiased.
            nn.init.constant_(self.gate_projection.bias[self.hidden_dim:], 0.0)

        if not isinstance(self.residual_projection, nn.Identity):
            nn.init.xavier_uniform_(self.residual_projection.weight)

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer over a chunk, starting from ``prev_state``.

        Args:
            x: input sequence; time runs along dim 1.
            prev_state: recurrent state to start from (concatenated in front of the
                sequence on dim 1 by ``MinGRUFunctions.mingru_step``).

        Returns:
            ``(output, next_state)``. ``output`` is the recurrent sequence, plus
            ``residual_projection(x)`` when the input and hidden widths differ.
            ``next_state`` is the last step of the recurrence *without* that residual
            term, so it can be fed back as ``prev_state``.
        """
        gate, hidden = self.gate_projection(x).chunk(2, dim=-1)
        hidden_seq = MinGRUFunctions.mingru_step(gate, hidden, prev_state)

        # Carry the recurrence's own hidden state. The residual term can be negative, and
        # mingru_step takes log(clamp(prev_state, min=eps)), so feeding back the
        # residual-shifted output would silently reset such units to ~0.
        next_state = hidden_seq[:, -1:, :]

        output = hidden_seq
        # NOTE(review): the residual is only added when the widths differ; with matching
        # widths (Identity) no skip connection is applied inside this layer.
        if not isinstance(self.residual_projection, nn.Identity):
            output = output + self.residual_projection(x)

        return output, next_state


class MinGRU(nn.Module):
    """Multi-layer MinGRU with an inference-time hidden-state cache.

    In eval mode, when ``forward`` is called without ``states``, the final states of
    the previous eval-mode call are reused if they match the batch size and device.
    Otherwise fresh initial states are built. Call ``reset_cache()`` before starting an
    unrelated sequence. In training mode the cache is neither read nor written.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], dropout: float = 0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        # Layer i maps layer_dims[i] -> layer_dims[i + 1].
        layer_dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList([
            MinGRULayer(layer_dims[i], layer_dims[i + 1])
            for i in range(self.num_layers)
        ])

        # Dropout after every layer except the last.
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout) if i < self.num_layers - 1 else nn.Identity()
            for i in range(self.num_layers)
        ])

        # Detached final states of the last eval-mode forward call, and their batch size.
        self._cached_states: Optional[List[torch.Tensor]] = None
        self._cache_batch_size: int = 0

    def init_states(self, batch_size: int, device: torch.device) -> List[torch.Tensor]:
        """Return fresh per-layer initial states g(0) = 0.5, shaped (batch_size, 1, dim).

        In eval mode these also replace the cache.
        """
        states = [
            MinGRUFunctions.g_activation(torch.zeros(batch_size, 1, dim, device=device, dtype=torch.float32))
            for dim in self.hidden_dims
        ]

        if not self.training:
            self._cached_states = [state.clone() for state in states]
            self._cache_batch_size = batch_size

        return states

    def reset_cache(self) -> None:
        """Forget cached states so the next eval-mode call starts from fresh states."""
        self._cached_states = None
        self._cache_batch_size = 0

    def get_cached_states(self, batch_size: int, device: torch.device) -> Optional[List[torch.Tensor]]:
        """Return detached copies of the cached states, or None if they cannot be used.

        The cache is unusable in training mode, when empty, or when it was built for a
        different batch size or device. The device check stops states that live on
        another device from being returned after the model has been moved.
        """
        if self.training or self._cached_states is None:
            return None
        if self._cache_batch_size != batch_size:
            return None
        device = torch.device(device)
        if any(state.device != device for state in self._cached_states):
            return None
        return [state.detach().clone() for state in self._cached_states]

    def forward(self, x: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run all layers over ``x``.

        Args:
            x: input sequence; its batch size is ``x.size(0)``.
            states: optional per-layer initial states. When omitted, the cache (eval
                mode only) or fresh ``init_states`` are used.

        Returns:
            ``(output, next_states)``, where ``next_states[i]`` is layer i's final state.
        """
        if states is None:
            states = self.get_cached_states(x.size(0), x.device)
            if states is None:
                states = self.init_states(x.size(0), x.device)

        output = x
        next_states = []

        for i, (layer, dropout) in enumerate(zip(self.layers, self.dropouts)):
            output, next_state = layer(output, states[i])
            output = dropout(output)
            next_states.append(next_state)

        if not self.training:
            # Detach so the cache does not keep this call's autograd graph alive, and
            # record the batch size the cached states actually have.
            self._cached_states = [state.detach().clone() for state in next_states]
            self._cache_batch_size = x.size(0)

        return output, next_states


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: dropout(down(SiLU(gate(x)) * up(x))).

    Args:
        dim: input and output feature size.
        expansion_factor: hidden width is ``int(dim * expansion_factor * 2/3)``,
            rounded up to a multiple of 8.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, dim: int, expansion_factor: float = 2.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * expansion_factor * 2/3)
        hidden_dim = ((hidden_dim + 7) // 8) * 8  # Round to multiple of 8 for tensor cores

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise all projections, then halve them (smaller activations in fp16)."""
        for module in [self.gate_proj, self.up_proj, self.down_proj]:
            nn.init.xavier_uniform_(module.weight)
            with torch.no_grad():
                module.weight.mul_(0.5)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast. It forces CUDA
    # autocast on for this forward (fp16 matmuls on CUDA tensors); tensors on other
    # devices are unaffected.
    @torch.autocast(device_type="cuda")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform; output has the same feature size as ``x``."""
        gate = self.activation(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))


class MinGRUDecoder(nn.Module):
    """MinGRU-based token decoder.

    Pipeline: embedding -> input projection -> optional causal conv (residual)
    -> RMSNorm -> MinGRU stack (residual) -> RMSNorm -> SwiGLU (residual)
    -> final RMSNorm -> vocabulary projection.

    Args:
        vocab_size: number of token ids and of output logits.
        embed_dim: embedding width.
        hidden_dims: per-layer MinGRU widths, or one int repeated ``num_layers`` times.
        num_layers: required when ``hidden_dims`` is an int.
        dropout: dropout used between MinGRU layers and after the FFN.
        use_conv: whether to add the causal depthwise convolution.
        conv_kernel: kernel size of that convolution.
        ffn_expansion: SwiGLU expansion factor.
        norm_eps: epsilon for all RMSNorm layers.
        use_gradient_checkpointing: recompute the MinGRU/FFN block in backward (training only).

    Raises:
        ValueError: if ``hidden_dims`` is an int and ``num_layers`` is None.

    NOTE(review): the residual around the MinGRU stack requires
    ``hidden_dims[0] == hidden_dims[-1]``.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dims: Union[int, List[int]],
        num_layers: Optional[int] = None,
        dropout: float = 0.1,
        use_conv: bool = True,
        conv_kernel: int = 3,
        ffn_expansion: float = 1.0,
        norm_eps: float = 1e-8,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # Handle hidden dimensions
        if isinstance(hidden_dims, int):
            if num_layers is None:
                raise ValueError("num_layers must be specified when hidden_dims is int")
            hidden_dims = [hidden_dims] * num_layers

        self.embed_dim = embed_dim
        self.hidden_dims = hidden_dims
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Input layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.input_proj = nn.Linear(embed_dim, hidden_dims[0]) if embed_dim != hidden_dims[0] else nn.Identity()

        # Optional convolution
        self.conv = CausalConv1D(hidden_dims[0], conv_kernel) if use_conv else None

        # Core MinGRU
        self.pre_gru_norm = nn.RMSNorm(hidden_dims[0], eps=norm_eps)
        self.mingru = MinGRU(hidden_dims[0], hidden_dims, dropout)

        # Feed-forward network
        self.post_gru_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.ffn = SwiGLU(hidden_dims[-1], ffn_expansion, dropout)

        # Output layers
        self.final_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.output_proj = nn.Linear(hidden_dims[-1], vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Initialise the layers this class owns directly.

        Submodules (MinGRULayer, SwiGLU) initialise themselves in their constructors.
        Re-initialising every nn.Linear here would overwrite that, e.g. the -1 gate
        bias and SwiGLU's 0.5 weight scaling.
        """
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

        for module in (self.input_proj, self.output_proj):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _forward_block(self, h: torch.Tensor) -> torch.Tensor:
        """Pre-norm MinGRU and SwiGLU sub-blocks, each wrapped in a residual connection."""
        h_norm = self.pre_gru_norm(h)
        # Each call processes a whole sequence, so start from fresh initial states rather
        # than whatever MinGRU cached from a previous, unrelated eval-mode call.
        initial_states = self.mingru.init_states(h.size(0), h.device)
        gru_out, _ = self.mingru(h_norm, initial_states)
        h = h + gru_out

        h_norm = self.post_gru_norm(h)
        ffn_out = self.ffn(h_norm)
        h = h + ffn_out

        return h

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``x`` to vocabulary logits (last dim is ``vocab_size``)."""
        h = self.embedding(x)
        h = self.input_proj(h)

        if self.conv is not None:
            h = h + self.conv(h)

        # Checkpointing trades recomputation for activation memory; only useful when training.
        if self.use_gradient_checkpointing and self.training:
            h = torch.utils.checkpoint.checkpoint(self._forward_block, h, use_reentrant=False)
        else:
            h = self._forward_block(h)

        h = self.final_norm(h)
        logits = self.output_proj(h)

        return logits


# Example usage with optimizations
if __name__ == "__main__":
    # Backend switches; both only take effect on CUDA.
    torch.backends.cudnn.benchmark = True  # autotune kernels for fixed input sizes
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ GPUs

    # Pick the device first so the model is compiled for the device it actually runs
    # on; moving an already-compiled model to another device would trigger a recompile.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    model = MinGRUDecoder(
        vocab_size=10000,
        embed_dim=512,
        hidden_dims=[512, 512],
        dropout=0.1,
        use_conv=True,
        conv_kernel=3,
        ffn_expansion=2.0,
        use_gradient_checkpointing=True,  # Enable memory-efficient training
    ).to(device)

    # Compile model for optimization (PyTorch 2.0+)
    model = torch.compile(model, mode="reduce-overhead")

    # Test forward pass
    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)

    # Mixed precision only on CUDA; on CPU this is a plain fp32 forward.
    with torch.autocast(device_type=device.type, enabled=use_cuda):
        output = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")  # Should be (4, 128, 10000)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Peak GPU memory of a second forward pass (the first call above triggered compilation).
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        with torch.autocast(device_type="cuda"):
            output = model(input_ids)

        peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
        print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

  @torch.cuda.amp.autocast()  # Enable automatic mixed precision


Input shape: torch.Size([4, 128])
Output shape: torch.Size([4, 128, 10000])
Model parameters: 12,623,632
Peak GPU memory usage: 0.19 GB


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union

# Conditional import for mixed precision
try:
    from torch.amp import autocast
    HAS_AMP = True
except ImportError:
    HAS_AMP = False
    # Dummy decorator if AMP is not available
    def autocast():
        def decorator(func):
            return func('cuda')
        return decorator


class CausalConv1D(nn.Module):
    """Depthwise-separable 1-D convolution that only sees current and past steps.

    A per-channel (depthwise) convolution of width ``kernel_size`` is followed by a
    1x1 (pointwise) convolution that mixes channels. The sequence is left-padded by
    ``kernel_size - 1``, so output step t depends only on input steps <= t and the
    output length equals the input length.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        self.pointwise = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq_len, dim) -> (batch, seq_len, dim)."""
        left_context = self.kernel_size - 1
        channels_first = x.permute(0, 2, 1)
        padded = F.pad(channels_first, (left_context, 0))
        mixed = self.pointwise(self.depthwise(padded))
        return mixed.permute(0, 2, 1)


class MinGRUFunctions:
    """Stateless numerical kernels for the log-space MinGRU recurrence."""

    @staticmethod
    def g_activation(x: torch.Tensor) -> torch.Tensor:
        """Positive activation g(x): x + 0.5 where x >= 0, sigmoid(x) elsewhere."""
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    @staticmethod
    def log_g_activation(x: torch.Tensor) -> torch.Tensor:
        """log(g(x)), evaluated without forming g(x) first.

        Uses log(x + 0.5) where x >= 0 and log(sigmoid(x)) = -softplus(-x) elsewhere.
        The result is cast back to ``x.dtype``.
        """
        # clamp keeps the unused branch finite for negative x (log of a negative is NaN).
        log_linear = torch.log(x.clamp(min=0) + 0.5)
        log_sigmoid = -F.softplus(-x)
        return torch.where(x >= 0, log_linear, log_sigmoid).to(x.dtype)

    @staticmethod
    def parallel_scan(log_gates: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Evaluate h_t = a_t * h_{t-1} + b_t for every t at once, in log space.

        Args:
            log_gates: log a_1..log a_T along dim 1.
            log_values: log h_0 followed by log b_1..log b_T along dim 1 (one entry
                longer than ``log_gates`` on that axis).

        Returns:
            h_0..h_T (exponentiated), same length as ``log_values`` on dim 1.

        NOTE(review): the pad spec ``[0, 0, 1, 0]`` targets the second-to-last axis,
        i.e. dim 1 only for 3-D (batch, seq, dim) inputs.
        """
        # log(a_1 * ... * a_t); the leading 0 (empty product) lines up with h_0.
        log_a_star = F.pad(log_gates.cumsum(dim=1), [0, 0, 1, 0])
        log_h0_plus_b_star = (log_values - log_a_star).logcumsumexp(dim=1)
        return (log_a_star + log_h0_plus_b_star).exp()

    @staticmethod
    def mingru_step(gate: torch.Tensor, hidden: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        """Run h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_t), with z_t = sigmoid(gate_t).

        ``prev_state`` is h_0 and is concatenated in front of the sequence on dim 1.
        Returns h_1..h_T (the h_0 slot of the scan is dropped).
        """
        eps = 1e-12

        k = gate.clamp(min=-20, max=20)
        log_one_minus_z = -F.softplus(k)  # log(1 - sigmoid(k))
        log_z = -F.softplus(-k)  # log(sigmoid(k))

        # log of a non-positive state is undefined, so floor it at eps first.
        log_h0 = prev_state.clamp(min=eps).log()
        log_tilde_h = MinGRUFunctions.log_g_activation(hidden)

        log_values = torch.cat([log_h0, log_z + log_tilde_h], dim=1)
        return MinGRUFunctions.parallel_scan(log_one_minus_z, log_values)[:, 1:]


class MinGRULayer(nn.Module):
    """Single MinGRU layer: a linear projection feeding the log-space MinGRU scan.

    Args:
        input_dim: feature size of the incoming sequence.
        hidden_dim: size of the recurrent state and of the layer output.
        use_bias: whether ``gate_projection`` has a bias term.
    """

    def __init__(self, input_dim: int, hidden_dim: int, use_bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # One matmul yields both halves: [gate pre-activation | candidate pre-activation].
        self.gate_projection = nn.Linear(input_dim, hidden_dim * 2, bias=use_bias)
        # Width-matching skip path; only a real projection when the widths differ.
        self.residual_projection = nn.Linear(input_dim, hidden_dim, bias=False) if input_dim != hidden_dim else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the projections and bias the update gate towards keeping state."""
        nn.init.xavier_uniform_(self.gate_projection.weight)
        if self.gate_projection.bias is not None:
            # Gate half: bias -1 starts the update gate at z = sigmoid(-1) ~= 0.27, so each
            # step keeps ~73% of the previous state (the 1 - z term in mingru_step).
            nn.init.constant_(self.gate_projection.bias[:self.hidden_dim], -1.0)
            # Candidate half (the input to g()): start unbiased.
            nn.init.constant_(self.gate_projection.bias[self.hidden_dim:], 0.0)

        if not isinstance(self.residual_projection, nn.Identity):
            nn.init.xavier_uniform_(self.residual_projection.weight)

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer over a chunk, starting from ``prev_state``.

        Args:
            x: input sequence; time runs along dim 1.
            prev_state: recurrent state to start from (concatenated in front of the
                sequence on dim 1 by ``MinGRUFunctions.mingru_step``).

        Returns:
            ``(output, next_state)``. ``output`` is the recurrent sequence, plus
            ``residual_projection(x)`` when the input and hidden widths differ.
            ``next_state`` is the last step of the recurrence *without* that residual
            term, so it can be fed back as ``prev_state``.
        """
        gate, hidden = self.gate_projection(x).chunk(2, dim=-1)
        hidden_seq = MinGRUFunctions.mingru_step(gate, hidden, prev_state)

        # Carry the recurrence's own hidden state. The residual term can be negative, and
        # mingru_step takes log(clamp(prev_state, min=eps)), so feeding back the
        # residual-shifted output would silently reset such units to ~0.
        next_state = hidden_seq[:, -1:, :]

        output = hidden_seq
        # NOTE(review): the residual is only added when the widths differ; with matching
        # widths (Identity) no skip connection is applied inside this layer.
        if not isinstance(self.residual_projection, nn.Identity):
            output = output + self.residual_projection(x)

        return output, next_state


class MinGRU(nn.Module):
    """Multi-layer MinGRU with an inference-time hidden-state cache.

    In eval mode, when ``forward`` is called without ``states``, the final states of
    the previous eval-mode call are reused if they match the batch size and device.
    Otherwise fresh initial states are built. Call ``reset_cache()`` before starting an
    unrelated sequence. In training mode the cache is neither read nor written.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], dropout: float = 0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        # Layer i maps layer_dims[i] -> layer_dims[i + 1].
        layer_dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList([
            MinGRULayer(layer_dims[i], layer_dims[i + 1])
            for i in range(self.num_layers)
        ])

        # Dropout after every layer except the last.
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout) if i < self.num_layers - 1 else nn.Identity()
            for i in range(self.num_layers)
        ])

        # Detached final states of the last eval-mode forward call, and their batch size.
        self._cached_states: Optional[List[torch.Tensor]] = None
        self._cache_batch_size: int = 0

    def init_states(self, batch_size: int, device: torch.device) -> List[torch.Tensor]:
        """Return fresh per-layer initial states g(0) = 0.5, shaped (batch_size, 1, dim).

        In eval mode these also replace the cache.
        """
        states = [
            MinGRUFunctions.g_activation(torch.zeros(batch_size, 1, dim, device=device, dtype=torch.float32))
            for dim in self.hidden_dims
        ]

        if not self.training:
            self._cached_states = [state.clone() for state in states]
            self._cache_batch_size = batch_size

        return states

    def reset_cache(self) -> None:
        """Forget cached states so the next eval-mode call starts from fresh states."""
        self._cached_states = None
        self._cache_batch_size = 0

    def get_cached_states(self, batch_size: int, device: torch.device) -> Optional[List[torch.Tensor]]:
        """Return detached copies of the cached states, or None if they cannot be used.

        The cache is unusable in training mode, when empty, or when it was built for a
        different batch size or device. The device check stops states that live on
        another device from being returned after the model has been moved.
        """
        if self.training or self._cached_states is None:
            return None
        if self._cache_batch_size != batch_size:
            return None
        device = torch.device(device)
        if any(state.device != device for state in self._cached_states):
            return None
        return [state.detach().clone() for state in self._cached_states]

    def forward(self, x: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run all layers over ``x``.

        Args:
            x: input sequence; its batch size is ``x.size(0)``.
            states: optional per-layer initial states. When omitted, the cache (eval
                mode only) or fresh ``init_states`` are used.

        Returns:
            ``(output, next_states)``, where ``next_states[i]`` is layer i's final state.
        """
        if states is None:
            states = self.get_cached_states(x.size(0), x.device)
            if states is None:
                states = self.init_states(x.size(0), x.device)

        output = x
        next_states = []

        for i, (layer, dropout) in enumerate(zip(self.layers, self.dropouts)):
            output, next_state = layer(output, states[i])
            output = dropout(output)
            next_states.append(next_state)

        if not self.training:
            # Detach so the cache does not keep this call's autograd graph alive, and
            # record the batch size the cached states actually have.
            self._cached_states = [state.detach().clone() for state in next_states]
            self._cache_batch_size = x.size(0)

        return output, next_states


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: dropout(down(SiLU(gate(x)) * up(x))).

    Args:
        dim: input and output feature size.
        expansion_factor: hidden width is ``int(dim * expansion_factor * 2/3)``,
            rounded up to a multiple of 8.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, dim: int, expansion_factor: float = 2.0, dropout: float = 0.0):
        super().__init__()
        hidden_dim = int(dim * expansion_factor * 2/3)
        hidden_dim = ((hidden_dim + 7) // 8) * 8  # Round to multiple of 8 for tensor cores

        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise all projections, then halve them (smaller activations in fp16)."""
        for module in [self.gate_proj, self.up_proj, self.down_proj]:
            nn.init.xavier_uniform_(module.weight)
            with torch.no_grad():
                module.weight.mul_(0.5)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast. It forces CUDA
    # autocast on for this forward (fp16 matmuls on CUDA tensors); tensors on other
    # devices are unaffected.
    @torch.autocast(device_type="cuda")
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform; output has the same feature size as ``x``."""
        gate = self.activation(self.gate_proj(x))
        up = self.up_proj(x)
        return self.dropout(self.down_proj(gate * up))


class MinGRUDecoder(nn.Module):
    """MinGRU-based token decoder.

    Pipeline: embedding -> input projection -> optional causal conv (residual)
    -> RMSNorm -> MinGRU stack (residual) -> RMSNorm -> SwiGLU (residual)
    -> final RMSNorm -> vocabulary projection.

    Args:
        vocab_size: number of token ids and of output logits.
        embed_dim: embedding width.
        hidden_dims: per-layer MinGRU widths, or one int repeated ``num_layers`` times.
        num_layers: required when ``hidden_dims`` is an int.
        dropout: dropout used between MinGRU layers and after the FFN.
        use_conv: whether to add the causal depthwise convolution.
        conv_kernel: kernel size of that convolution.
        ffn_expansion: SwiGLU expansion factor.
        norm_eps: epsilon for all RMSNorm layers.
        use_gradient_checkpointing: recompute the MinGRU/FFN block in backward (training only).

    Raises:
        ValueError: if ``hidden_dims`` is an int and ``num_layers`` is None.

    NOTE(review): the residual around the MinGRU stack requires
    ``hidden_dims[0] == hidden_dims[-1]``.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dims: Union[int, List[int]],
        num_layers: Optional[int] = None,
        dropout: float = 0.1,
        use_conv: bool = True,
        conv_kernel: int = 3,
        ffn_expansion: float = 1.0,
        norm_eps: float = 1e-8,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # Handle hidden dimensions
        if isinstance(hidden_dims, int):
            if num_layers is None:
                raise ValueError("num_layers must be specified when hidden_dims is int")
            hidden_dims = [hidden_dims] * num_layers

        self.embed_dim = embed_dim
        self.hidden_dims = hidden_dims
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Input layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.input_proj = nn.Linear(embed_dim, hidden_dims[0]) if embed_dim != hidden_dims[0] else nn.Identity()

        # Optional convolution
        self.conv = CausalConv1D(hidden_dims[0], conv_kernel) if use_conv else None

        # Core MinGRU
        self.pre_gru_norm = nn.RMSNorm(hidden_dims[0], eps=norm_eps)
        self.mingru = MinGRU(hidden_dims[0], hidden_dims, dropout)

        # Feed-forward network
        self.post_gru_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.ffn = SwiGLU(hidden_dims[-1], ffn_expansion, dropout)

        # Output layers
        self.final_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.output_proj = nn.Linear(hidden_dims[-1], vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Initialise the layers this class owns directly.

        Submodules (MinGRULayer, SwiGLU) initialise themselves in their constructors.
        Re-initialising every nn.Linear here would overwrite that, e.g. the -1 gate
        bias and SwiGLU's 0.5 weight scaling.
        """
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

        for module in (self.input_proj, self.output_proj):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _forward_block(self, h: torch.Tensor) -> torch.Tensor:
        """Pre-norm MinGRU and SwiGLU sub-blocks, each wrapped in a residual connection."""
        h_norm = self.pre_gru_norm(h)
        # Each call processes a whole sequence, so start from fresh initial states rather
        # than whatever MinGRU cached from a previous, unrelated eval-mode call.
        initial_states = self.mingru.init_states(h.size(0), h.device)
        gru_out, _ = self.mingru(h_norm, initial_states)
        h = h + gru_out

        h_norm = self.post_gru_norm(h)
        ffn_out = self.ffn(h_norm)
        h = h + ffn_out

        return h

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``x`` to vocabulary logits (last dim is ``vocab_size``)."""
        h = self.embedding(x)
        h = self.input_proj(h)

        if self.conv is not None:
            h = h + self.conv(h)

        # Checkpointing trades recomputation for activation memory; only useful when training.
        if self.use_gradient_checkpointing and self.training:
            h = torch.utils.checkpoint.checkpoint(self._forward_block, h, use_reentrant=False)
        else:
            h = self._forward_block(h)

        h = self.final_norm(h)
        logits = self.output_proj(h)

        return logits


# Example usage with optimizations
if __name__ == "__main__":
    # Backend switches; both only take effect on CUDA.
    torch.backends.cudnn.benchmark = True  # autotune kernels for fixed input sizes
    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 matmuls on Ampere+ GPUs

    # Pick the device first so the model is compiled for the device it actually runs
    # on; moving an already-compiled model to another device would trigger a recompile.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"

    model = MinGRUDecoder(
        vocab_size=10000,
        embed_dim=512,
        hidden_dims=[512, 512],
        dropout=0.1,
        use_conv=True,
        conv_kernel=3,
        ffn_expansion=2.0,
        use_gradient_checkpointing=True,  # Enable memory-efficient training
    ).to(device)

    # Compile model for optimization (PyTorch 2.0+)
    model = torch.compile(model, mode="reduce-overhead")

    # Test forward pass
    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)

    # Mixed precision only on CUDA; on CPU this is a plain fp32 forward.
    with torch.autocast(device_type=device.type, enabled=use_cuda):
        output = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")  # Should be (4, 128, 10000)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Peak GPU memory of a second forward pass (the first call above triggered compilation).
    if use_cuda:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        with torch.autocast(device_type="cuda"):
            output = model(input_ids)

        peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB
        print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

  @torch.cuda.amp.autocast()  # Enable automatic mixed precision


Input shape: torch.Size([4, 128])
Output shape: torch.Size([4, 128, 10000])
Model parameters: 12,623,632
Peak GPU memory usage: 0.19 GB


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union

# Conditional import for mixed precision
try:
    from torch.cuda.amp import autocast
    HAS_AMP = True
except ImportError:
    HAS_AMP = False
    # Dummy decorator if AMP is not available
    def autocast():
        def decorator(func):
            return func
        return decorator


class CausalConv1D(nn.Module):
    """Depthwise-separable 1-D convolution that only sees current and past steps.

    A per-channel (depthwise) convolution of width ``kernel_size`` is followed by a
    1x1 (pointwise) convolution that mixes channels. The sequence is left-padded by
    ``kernel_size - 1``, so output step t depends only on input steps <= t and the
    output length equals the input length.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        self.pointwise = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, seq_len, dim) -> (batch, seq_len, dim)."""
        left_context = self.kernel_size - 1
        channels_first = x.permute(0, 2, 1)
        padded = F.pad(channels_first, (left_context, 0))
        mixed = self.pointwise(self.depthwise(padded))
        return mixed.permute(0, 2, 1)


class MinGRUFunctions:
    """Stateless numerical kernels for the log-space MinGRU recurrence."""

    @staticmethod
    def g_activation(x: torch.Tensor) -> torch.Tensor:
        """Positive activation g(x): x + 0.5 where x >= 0, sigmoid(x) elsewhere."""
        return torch.where(x >= 0, x + 0.5, torch.sigmoid(x))

    @staticmethod
    def log_g_activation(x: torch.Tensor) -> torch.Tensor:
        """log(g(x)), evaluated without forming g(x) first.

        Uses log(x + 0.5) where x >= 0 and log(sigmoid(x)) = -softplus(-x) elsewhere.
        The result is cast back to ``x.dtype``.
        """
        # clamp keeps the unused branch finite for negative x (log of a negative is NaN).
        log_linear = torch.log(x.clamp(min=0) + 0.5)
        log_sigmoid = -F.softplus(-x)
        return torch.where(x >= 0, log_linear, log_sigmoid).to(x.dtype)

    @staticmethod
    def parallel_scan(log_gates: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Evaluate h_t = a_t * h_{t-1} + b_t for every t at once, in log space.

        Args:
            log_gates: log a_1..log a_T along dim 1.
            log_values: log h_0 followed by log b_1..log b_T along dim 1 (one entry
                longer than ``log_gates`` on that axis).

        Returns:
            h_0..h_T (exponentiated), same length as ``log_values`` on dim 1.

        NOTE(review): the pad spec ``[0, 0, 1, 0]`` targets the second-to-last axis,
        i.e. dim 1 only for 3-D (batch, seq, dim) inputs.
        """
        # log(a_1 * ... * a_t); the leading 0 (empty product) lines up with h_0.
        log_a_star = F.pad(log_gates.cumsum(dim=1), [0, 0, 1, 0])
        log_h0_plus_b_star = (log_values - log_a_star).logcumsumexp(dim=1)
        return (log_a_star + log_h0_plus_b_star).exp()

    @staticmethod
    def mingru_step(gate: torch.Tensor, hidden: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        """Run h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_t), with z_t = sigmoid(gate_t).

        ``prev_state`` is h_0 and is concatenated in front of the sequence on dim 1.
        Returns h_1..h_T (the h_0 slot of the scan is dropped).
        """
        eps = 1e-12

        k = gate.clamp(min=-20, max=20)
        log_one_minus_z = -F.softplus(k)  # log(1 - sigmoid(k))
        log_z = -F.softplus(-k)  # log(sigmoid(k))

        # log of a non-positive state is undefined, so floor it at eps first.
        log_h0 = prev_state.clamp(min=eps).log()
        log_tilde_h = MinGRUFunctions.log_g_activation(hidden)

        log_values = torch.cat([log_h0, log_z + log_tilde_h], dim=1)
        return MinGRUFunctions.parallel_scan(log_one_minus_z, log_values)[:, 1:]


class MinGRULayer(nn.Module):
    """Single MinGRU layer: a linear projection feeding the log-space MinGRU scan.

    Args:
        input_dim: feature size of the incoming sequence.
        hidden_dim: size of the recurrent state and of the layer output.
        use_bias: whether ``gate_projection`` has a bias term.
    """

    def __init__(self, input_dim: int, hidden_dim: int, use_bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # One matmul yields both halves: [gate pre-activation | candidate pre-activation].
        self.gate_projection = nn.Linear(input_dim, hidden_dim * 2, bias=use_bias)
        # Width-matching skip path; only a real projection when the widths differ.
        self.residual_projection = nn.Linear(input_dim, hidden_dim, bias=False) if input_dim != hidden_dim else nn.Identity()

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise the projections and bias the update gate towards keeping state."""
        nn.init.xavier_uniform_(self.gate_projection.weight)
        if self.gate_projection.bias is not None:
            # Gate half: bias -1 starts the update gate at z = sigmoid(-1) ~= 0.27, so each
            # step keeps ~73% of the previous state (the 1 - z term in mingru_step).
            nn.init.constant_(self.gate_projection.bias[:self.hidden_dim], -1.0)
            # Candidate half (the input to g()): start unbiased.
            nn.init.constant_(self.gate_projection.bias[self.hidden_dim:], 0.0)

        if not isinstance(self.residual_projection, nn.Identity):
            nn.init.xavier_uniform_(self.residual_projection.weight)

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer over a chunk, starting from ``prev_state``.

        Args:
            x: input sequence; time runs along dim 1.
            prev_state: recurrent state to start from (concatenated in front of the
                sequence on dim 1 by ``MinGRUFunctions.mingru_step``).

        Returns:
            ``(output, next_state)``. ``output`` is the recurrent sequence, plus
            ``residual_projection(x)`` when the input and hidden widths differ.
            ``next_state`` is the last step of the recurrence *without* that residual
            term, so it can be fed back as ``prev_state``.
        """
        gate, hidden = self.gate_projection(x).chunk(2, dim=-1)
        hidden_seq = MinGRUFunctions.mingru_step(gate, hidden, prev_state)

        # Carry the recurrence's own hidden state. The residual term can be negative, and
        # mingru_step takes log(clamp(prev_state, min=eps)), so feeding back the
        # residual-shifted output would silently reset such units to ~0.
        next_state = hidden_seq[:, -1:, :]

        output = hidden_seq
        # NOTE(review): the residual is only added when the widths differ; with matching
        # widths (Identity) no skip connection is applied inside this layer.
        if not isinstance(self.residual_projection, nn.Identity):
            output = output + self.residual_projection(x)

        return output, next_state


class MinGRU(nn.Module):
    """Multi-layer MinGRU with state caching.

    Stacks ``MinGRULayer`` modules mapping ``input_dim -> hidden_dims[0] -> ...
    -> hidden_dims[-1]``, with dropout after every layer except the last.

    In eval mode the module keeps detached copies of the most recent final
    states, together with the batch size they were produced for.

    NOTE(review): because of that cache, an eval-mode ``forward(x)`` without
    ``states`` starts from the final states of the previous eval call when the
    batch size matches. Pass ``states`` explicitly for independent sequences.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], dropout: float = 0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        # Layer i maps layer_dims[i] -> layer_dims[i + 1].
        layer_dims = [input_dim] + hidden_dims
        self.layers = nn.ModuleList([
            MinGRULayer(layer_dims[i], layer_dims[i + 1])
            for i in range(self.num_layers)
        ])

        # Dropout between layers only; the last layer's output is left untouched.
        self.dropouts = nn.ModuleList([
            nn.Dropout(dropout) if i < self.num_layers - 1 else nn.Identity()
            for i in range(self.num_layers)
        ])

        # Eval-time state cache and the batch size those states belong to.
        self._cached_states: Optional[List[torch.Tensor]] = None
        self._cache_batch_size: int = 0

    def _update_cache(self, states: List[torch.Tensor], batch_size: int) -> None:
        """Store detached copies of ``states`` as the cache for ``batch_size``."""
        # Detaching keeps eval calls made with grad enabled from chaining
        # autograd graphs (and their memory) across calls.
        self._cached_states = [state.detach().clone() for state in states]
        self._cache_batch_size = batch_size

    def init_states(self, batch_size: int, device: torch.device, dtype: torch.dtype = torch.float32) -> List[torch.Tensor]:
        """Return fresh initial states ``g(0)``, one ``(batch_size, 1, dim)`` tensor per layer.

        In eval mode the new states also replace the cache.

        Args:
            batch_size: Leading dimension of each state.
            device: Device to allocate on.
            dtype: Element type of the states (float32 by default).
        """
        states = [
            MinGRUFunctions.g_activation(torch.zeros(batch_size, 1, dim, device=device, dtype=dtype))
            for dim in self.hidden_dims
        ]

        if not self.training:
            self._update_cache(states, batch_size)

        return states

    def get_cached_states(self, batch_size: int, device: torch.device) -> Optional[List[torch.Tensor]]:
        """Return copies of the cached states moved to ``device``, or None.

        The cache is only used in eval mode and only when it was produced for
        the same ``batch_size``.
        """
        if (self._cached_states is not None
                and self._cache_batch_size == batch_size
                and not self.training):
            return [state.to(device).clone() for state in self._cached_states]
        return None

    def forward(self, x: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run every layer over ``x``.

        Args:
            x: Input whose first dimension is the batch; presumably
                ``(batch, seq_len, input_dim)`` (TODO confirm).
            states: Optional per-layer initial states. When omitted, the cache
                is used if valid, otherwise fresh states are created.

        Returns:
            ``(output, next_states)`` with one final state per layer.
        """
        batch_size = x.size(0)
        if states is None:
            states = self.get_cached_states(batch_size, x.device)
            if states is None:
                states = self.init_states(batch_size, x.device)

        output = x
        next_states = []

        for i, (layer, dropout) in enumerate(zip(self.layers, self.dropouts)):
            output, next_state = layer(output, states[i])
            output = dropout(output)
            next_states.append(next_state)

        if not self.training:
            # Record the batch size as well; otherwise a later call could pick
            # up cached states built for a different batch size.
            self._update_cache(next_states, batch_size)

        return output, next_states


class SwiGLU(nn.Module):
    """Gated feed-forward block: ``dropout(down(silu(gate(x)) * up(x)))``.

    The inner width is ``int(dim * expansion_factor * 2/3)`` rounded up to the
    next multiple of 8. All projections are bias-free, and their weights are
    Xavier-initialised and then halved.
    """

    def __init__(self, dim: int, expansion_factor: float = 2.0, dropout: float = 0.0):
        super().__init__()
        inner = int(dim * expansion_factor * 2/3)
        # Multiple-of-8 widths are friendlier to tensor-core matmuls.
        remainder = inner % 8
        if remainder:
            inner += 8 - remainder

        self.gate_proj = nn.Linear(dim, inner, bias=False)
        self.up_proj = nn.Linear(dim, inner, bias=False)
        self.down_proj = nn.Linear(inner, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init, then scale every weight by 0.5 (kept small for fp16)."""
        for proj in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(proj.weight)
            with torch.no_grad():
                proj.weight.mul_(0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate branch first, then the linear "up" branch, multiplied elementwise.
        mixed = self.activation(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(mixed))


class MinGRUDecoder(nn.Module):
    """Complete MinGRU-based decoder model with optimizations.

    Pipeline: embedding -> optional input projection -> optional causal conv
    (residual) -> RMSNorm + MinGRU stack (residual) -> RMSNorm + SwiGLU
    (residual) -> final RMSNorm -> vocabulary projection.

    The residual around the MinGRU stack adds a ``hidden_dims[-1]``-wide output
    to a ``hidden_dims[0]``-wide stream, so those two widths must match.

    Args:
        vocab_size: Number of token ids / output logits.
        embed_dim: Embedding width.
        hidden_dims: Per-layer MinGRU widths, or one width repeated ``num_layers`` times.
        num_layers: Required when ``hidden_dims`` is an int.
        dropout: Dropout used between MinGRU layers and inside the FFN.
        use_conv: Whether to apply the causal convolution.
        conv_kernel: Kernel size of the causal convolution.
        ffn_expansion: Expansion factor passed to ``SwiGLU``.
        norm_eps: Epsilon for every RMSNorm.
        use_gradient_checkpointing: Recompute the MinGRU/FFN block during
            backward (training only) to save memory.

    Raises:
        ValueError: If ``hidden_dims`` is an int and ``num_layers`` is None.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dims: Union[int, List[int]],
        num_layers: Optional[int] = None,
        dropout: float = 0.1,
        use_conv: bool = True,
        conv_kernel: int = 3,
        ffn_expansion: float = 1.0,
        norm_eps: float = 1e-8,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        # Handle hidden dimensions
        if isinstance(hidden_dims, int):
            if num_layers is None:
                raise ValueError("num_layers must be specified when hidden_dims is int")
            hidden_dims = [hidden_dims] * num_layers

        self.embed_dim = embed_dim
        self.hidden_dims = hidden_dims
        self.use_gradient_checkpointing = use_gradient_checkpointing

        # Input layers
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.input_proj = nn.Linear(embed_dim, hidden_dims[0]) if embed_dim != hidden_dims[0] else nn.Identity()

        # Optional convolution
        self.conv = CausalConv1D(hidden_dims[0], conv_kernel) if use_conv else None

        # Core MinGRU
        self.pre_gru_norm = nn.RMSNorm(hidden_dims[0], eps=norm_eps)
        self.mingru = MinGRU(hidden_dims[0], hidden_dims, dropout)

        # Feed-forward network
        self.post_gru_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.ffn = SwiGLU(hidden_dims[-1], ffn_expansion, dropout)

        # Output layers
        self.final_norm = nn.RMSNorm(hidden_dims[-1], eps=norm_eps)
        self.output_proj = nn.Linear(hidden_dims[-1], vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Initialise the weights owned directly by the decoder.

        ``MinGRULayer`` and ``SwiGLU`` run their own ``_init_weights`` on
        construction (gate bias -1.0, 0.5 weight scaling). Only the decoder's
        own projections are touched here, so that custom initialisation is
        not overwritten.
        """
        nn.init.normal_(self.embedding.weight, mean=0.0, std=0.02)

        for module in (self.input_proj, self.output_proj):
            # input_proj may be nn.Identity when embed_dim == hidden_dims[0].
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _forward_block(self, h: torch.Tensor) -> torch.Tensor:
        """Pre-norm MinGRU sub-block and pre-norm SwiGLU sub-block, each with a residual.

        NOTE(review): ``self.mingru`` is called without ``states``. In eval mode
        MinGRU then resumes from the final states of its previous eval call
        (its cache), so repeated eval calls are not independent.
        """
        h_norm = self.pre_gru_norm(h)
        gru_out, _ = self.mingru(h_norm)
        h = h + gru_out

        h_norm = self.post_gru_norm(h)
        ffn_out = self.ffn(h_norm)
        h = h + ffn_out

        return h

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map token ids ``x`` (assumed ``(batch, seq_len)``) to logits over the vocabulary."""
        h = self.embedding(x)
        h = self.input_proj(h)

        # Optional convolution with residual
        if self.conv is not None:
            h = h + self.conv(h)

        if self.use_gradient_checkpointing and self.training:
            # Explicit import: ``import torch`` alone does not guarantee that
            # the ``torch.utils.checkpoint`` submodule is loaded.
            from torch.utils.checkpoint import checkpoint
            h = checkpoint(self._forward_block, h, use_reentrant=False)
        else:
            h = self._forward_block(h)

        h = self.final_norm(h)
        logits = self.output_proj(h)

        return logits


# Example usage with optimizations
if __name__ == "__main__":
    # Enable optimizations
    torch.backends.cudnn.benchmark = True  # Optimize for fixed input sizes
    torch.backends.cuda.matmul.allow_tf32 = True  # Enable TF32 on Ampere GPUs

    model = MinGRUDecoder(
        vocab_size=10000,
        embed_dim=512,
        hidden_dims=[512, 512],
        dropout=0.1,
        use_conv=True,
        conv_kernel=3,
        ffn_expansion=2.0,
        use_gradient_checkpointing=True,  # Enable memory-efficient training
    )

    # Compile model for optimization (PyTorch 2.0+)
    model = torch.compile(model, mode="reduce-overhead")

    batch_size, seq_len = 4, 128
    input_ids = torch.randint(0, 10000, (batch_size, seq_len))

    # Test forward pass on the CPU. CUDA autocast does not apply to CPU ops,
    # so this pass runs in full precision without an autocast context.
    output = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")  # Should be (4, 128, 10000)
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Memory usage example on the GPU
    if torch.cuda.is_available():
        model = model.cuda()
        input_ids = input_ids.cuda()

        # Clear cache and measure memory
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # torch.autocast(device_type="cuda") is the non-deprecated replacement
        # for torch.cuda.amp.autocast(); it defaults to float16.
        with torch.autocast(device_type="cuda"):
            output = model(input_ids)

        peak_memory = torch.cuda.max_memory_allocated() / 1024**3  # GiB
        print(f"Peak GPU memory usage: {peak_memory:.2f} GB")

  with torch.cuda.amp.autocast():


Input shape: torch.Size([4, 128])
Output shape: torch.Size([4, 128, 10000])
Model parameters: 12,623,632


  with torch.cuda.amp.autocast():


Peak GPU memory usage: 0.19 GB


In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union
import time # 모델 실행 속도 측정을 위해 추가

# Conditional import for mixed precision
try:
    from torch.amp import autocast  # device-agnostic autocast API
    HAS_AMP = True
except ImportError:
    HAS_AMP = False

    class autocast:
        """No-op stand-in for ``torch.amp.autocast`` on builds that lack it.

        It accepts the same constructor arguments so call sites need no
        changes, and works both as a context manager and as a decorator. No
        precision change is ever applied.
        """

        def __init__(self, device_type: str, dtype: Optional[torch.dtype] = None, enabled: bool = True, cache_enabled: Optional[bool] = None):
            # Stored only for API compatibility / introspection.
            self.device_type, self.dtype, self.enabled = device_type, dtype, enabled

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Returning None lets any exception raised in the body propagate.
            return None

        def __call__(self, func):
            import functools

            @functools.wraps(func)
            def passthrough(*args, **kwargs):
                return func(*args, **kwargs)

            return passthrough


class CausalConv1D(nn.Module):
    """Causal depthwise-separable 1D convolution over the sequence axis.

    Input and output are ``(batch, seq_len, dim)``. The sequence is left-padded
    with ``kernel_size - 1`` zeros, so output step ``t`` only sees input steps
    ``<= t``. A per-channel temporal filter is followed by a 1x1 channel mixer.
    """

    def __init__(self, dim: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # Temporal filter applied to each channel independently (groups=dim).
        self.depthwise = nn.Conv1d(dim, dim, kernel_size, groups=dim)
        # 1x1 convolution mixing channels at each time step.
        self.pointwise = nn.Conv1d(dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        left_pad = self.kernel_size - 1
        # Conv1d expects channels before time; pad the time axis on the left only.
        channels_first = F.pad(x.transpose(1, 2), (left_pad, 0))
        mixed = self.pointwise(self.depthwise(channels_first))
        return mixed.transpose(1, 2)


class MinGRUFunctions:
    """MinGRU core mathematical functions.

    The recurrence is ``h_t = (1 - z_t) * h_{t-1} + z_t * g(hidden_t)`` with
    ``z_t = sigmoid(gate_t)``. It is evaluated for a whole sequence at once via
    a log-space parallel scan.
    """

    @staticmethod
    def g_activation(x: torch.Tensor) -> torch.Tensor:
        """Strictly positive activation: ``x + 0.5`` for ``x >= 0``, else ``sigmoid(x)``."""
        positive_mask = x >= 0
        x_pos = x + 0.5
        x_neg = torch.sigmoid(x)  # non-inplace so both branches stay valid for torch.where

        return torch.where(positive_mask, x_pos, x_neg)

    @staticmethod
    def log_g_activation(x: torch.Tensor) -> torch.Tensor:
        """Return ``log(g_activation(x))`` without NaNs in values or gradients.

        ``torch.where`` evaluates and differentiates *both* branches for every
        element. ``log(x + 0.5)`` would be NaN for ``x < -0.5`` and would give a
        NaN gradient at exactly ``x == -0.5`` (0 * inf) even where that branch
        is not selected. ``relu(x) + 0.5`` keeps the log argument >= 0.5 and
        leaves the selected ``x >= 0`` values unchanged.
        """
        positive_mask = x >= 0
        log_val_pos = torch.log(F.relu(x) + 0.5)
        log_val_neg = -F.softplus(-x)  # == log(sigmoid(x)), stable for large |x|

        result = torch.where(positive_mask, log_val_pos, log_val_neg)
        return result.to(x.dtype)

    @staticmethod
    def parallel_scan(log_gates: torch.Tensor, log_values: torch.Tensor) -> torch.Tensor:
        """Solve ``h_t = a_t * h_{t-1} + b_t`` for all t in log space.

        Args:
            log_gates: ``log a_t`` along dim 1 (length T).
            log_values: ``[log h_0, log b_1, ..., log b_T]`` along dim 1
                (length T + 1).

        Returns:
            ``[h_0, h_1, ..., h_T]`` (length T + 1 along dim 1).
        """
        cumsum_log_gates = torch.cumsum(log_gates, dim=1)
        # Prepend a zero so index 0 (the initial state) is not decayed.
        padded_cumsum_log_gates = F.pad(cumsum_log_gates, (0, 0, 1, 0))
        adjusted_values = log_values - padded_cumsum_log_gates
        cumsum_values = torch.logcumsumexp(adjusted_values, dim=1)
        log_output = padded_cumsum_log_gates + cumsum_values
        return torch.exp(log_output)

    @staticmethod
    def mingru_step(gate: torch.Tensor, hidden: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        """Run the MinGRU recurrence over a full sequence.

        Args:
            gate: Gate pre-activations; ``sigmoid(gate)`` is the update weight.
            hidden: Candidate pre-activations, passed through ``g``.
            prev_state: Positive initial state with a length-1 dim 1,
                concatenated in front of the sequence.

        Returns:
            Hidden states for every step (the initial state is dropped).
        """
        # Clone prev_state to avoid TensorAlias/FakeTensor issues under torch.compile.
        prev_state_for_ops = prev_state.clone()

        eps = 1e-12
        # Clamp so a zero (or non-positive) state cannot produce -inf / NaN.
        log_prev = torch.log(torch.clamp(prev_state_for_ops, min=eps))

        gate_clamped = torch.clamp(gate, min=-20, max=20)
        log_forget = -F.softplus(gate_clamped)   # log(1 - sigmoid(gate))
        log_update = -F.softplus(-gate_clamped)  # log(sigmoid(gate))
        log_candidate = MinGRUFunctions.log_g_activation(hidden)

        log_states_for_scan = torch.cat([log_prev, log_update + log_candidate], dim=1)
        sequence_output = MinGRUFunctions.parallel_scan(log_forget, log_states_for_scan)

        # Clone so callers (and the next step's prev_state) never get a view,
        # which can confuse torch.compile.
        return sequence_output[:, 1:, :].clone()


class MinGRULayer(nn.Module):
    """Single MinGRU layer with optimized initialization.

    ``forward`` returns ``(output, next_state)``. ``output`` is the MinGRU
    hidden sequence plus a residual of the input: a learned projection when
    the widths differ, identity otherwise. ``next_state`` is the last
    *recurrent* hidden state.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Size of the recurrent state / output features.
        use_bias: Whether the gate projection has a bias term.
    """

    def __init__(self, input_dim: int, hidden_dim: int, use_bias: bool = True):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        # Single projection producing [gate | candidate] halves.
        self.gate_projection = nn.Linear(input_dim, hidden_dim * 2, bias=use_bias)
        self.residual_projection = nn.Linear(input_dim, hidden_dim, bias=False) if input_dim != hidden_dim else nn.Identity()
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise projections; gate bias -1.0, candidate bias 0.0.

        ``mingru_step`` uses ``sigmoid(gate)`` as the update weight, so the
        -1.0 bias starts it near 0.27, i.e. favouring the previous state.
        """
        nn.init.xavier_uniform_(self.gate_projection.weight)
        if self.gate_projection.bias is not None:
            nn.init.constant_(self.gate_projection.bias[:self.hidden_dim], -1.0)
            nn.init.constant_(self.gate_projection.bias[self.hidden_dim:], 0.0)
        if not isinstance(self.residual_projection, nn.Identity):
            nn.init.xavier_uniform_(self.residual_projection.weight)

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer over a whole sequence.

        Args:
            x: Input, assumed ``(batch, seq_len, input_dim)`` (TODO confirm).
            prev_state: Recurrent state from the previous call, assumed
                ``(batch, 1, hidden_dim)``.

        Returns:
            ``(output, next_state)``. ``next_state`` is taken from the MinGRU
            hidden sequence *before* the residual is added.
        """
        gate, hidden = self.gate_projection(x).chunk(2, dim=-1)
        hidden_seq = MinGRUFunctions.mingru_step(gate, hidden, prev_state)

        # The carried state goes through log(clamp(., eps)) in mingru_step, so
        # it must be the positive MinGRU hidden state itself. Adding the
        # residual (which is applied even for identity) would corrupt the
        # recurrence and could make it non-positive.
        next_state = hidden_seq[:, -1:, :].clone()

        output = hidden_seq + self.residual_projection(x)
        return output, next_state


class MinGRU(nn.Module):
    """Stack of MinGRU layers with an eval-mode state cache.

    Layer ``i`` maps ``hidden_dims[i - 1]`` (``input_dim`` for the first layer)
    to ``hidden_dims[i]``. Dropout follows every layer except the last.

    NOTE(review): in eval mode, ``forward(x)`` without ``states`` starts from
    the final states of the previous eval call when the batch size matches, so
    successive calls are not independent. Pass ``states`` explicitly for fresh
    sequences.
    """

    def __init__(self, input_dim: int, hidden_dims: List[int], dropout: float = 0.0):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.num_layers = len(hidden_dims)

        # Each layer's input width is the previous layer's output width.
        in_dims = [input_dim, *hidden_dims[:-1]]
        self.layers = nn.ModuleList(
            MinGRULayer(d_in, d_out) for d_in, d_out in zip(in_dims, hidden_dims)
        )

        last = self.num_layers - 1
        self.dropouts = nn.ModuleList(
            nn.Identity() if idx == last else nn.Dropout(dropout)
            for idx in range(self.num_layers)
        )

        self._cached_states: Optional[List[torch.Tensor]] = None
        self._cache_batch_size: int = 0

    def init_states(self, batch_size: int, device: torch.device, dtype: torch.dtype = torch.float32) -> List[torch.Tensor]:
        """Build fresh ``g(0)`` states, one ``(batch_size, 1, width)`` tensor per layer.

        In eval mode, copies of the new states also become the cache.
        """
        fresh = [
            MinGRUFunctions.g_activation(torch.zeros(batch_size, 1, width, device=device, dtype=dtype))
            for width in self.hidden_dims
        ]

        if not self.training:
            self._cached_states = [s.clone() for s in fresh]
            self._cache_batch_size = batch_size
        return fresh

    def get_cached_states(self, batch_size: int, device: torch.device) -> Optional[List[torch.Tensor]]:
        """Return copies of the cached states on ``device``, or None if unusable.

        The cache is usable only in eval mode and only for the batch size it was
        recorded with.
        """
        usable = (
            self._cached_states is not None
            and self._cache_batch_size == batch_size
            and not self.training
        )
        if not usable:
            return None
        # Copies, so callers can never mutate the cache.
        return [cached.to(device, non_blocking=True).clone() for cached in self._cached_states]

    def forward(self, x: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run every layer over ``x``; returns ``(output, per-layer final states)``."""
        batch = x.size(0)
        if states is None:
            states = self.get_cached_states(batch, x.device)
        if states is None:
            states = self.init_states(batch, x.device, dtype=x.dtype)

        output = x
        final_states = []
        for idx, (layer, drop) in enumerate(zip(self.layers, self.dropouts)):
            output, state_out = layer(output, states[idx])
            output = drop(output)
            final_states.append(state_out)

        if not self.training:
            # Detached so later eval calls do not hold on to this call's graph.
            self._cached_states = [s.detach().clone() for s in final_states]
            self._cache_batch_size = batch
        return output, final_states


class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``dropout(down(silu(gate(x)) * up(x)))``.

    The inner width is ``int(dim * expansion_factor)`` rounded up to a multiple
    of 8. All projections are bias-free and Xavier-uniform initialised.
    """

    def __init__(self, dim: int, expansion_factor: float = 2.0, dropout: float = 0.0):
        super().__init__()
        # Ceiling to the next multiple of 8 (tensor-core friendly).
        width = -(-int(dim * expansion_factor) // 8) * 8

        self.gate_proj = nn.Linear(dim, width, bias=False)
        self.up_proj = nn.Linear(dim, width, bias=False)
        self.down_proj = nn.Linear(width, dim, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.SiLU()
        self._init_weights()

    def _init_weights(self):
        for proj in (self.gate_proj, self.up_proj, self.down_proj):
            nn.init.xavier_uniform_(proj.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = self.activation(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learnable scale.

    The normalisation is computed on a float32 copy of the input and cast back
    to the input dtype before ``weight`` is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        inv_rms = (x.pow(2).mean(dim=-1, keepdim=True) + self.eps).rsqrt()
        return x * inv_rms

    def forward(self, x):
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight

class MinGRUDecoder(nn.Module):
    """Complete MinGRU-based decoder model with optimizations.

    Pipeline: embedding -> optional input projection -> optional causal conv
    (residual) -> [RMSNorm -> MinGRU stack (residual) -> RMSNorm -> SwiGLU
    (residual)] -> final RMSNorm -> vocabulary projection.

    ``forward(input_ids, states=None)`` returns ``(logits, next_states)``, where
    ``next_states`` are the MinGRU stack's final states.

    Args:
        vocab_size: Number of token ids / output logits.
        embed_dim: Embedding width.
        hidden_dims: Per-layer MinGRU widths, or one width repeated ``num_layers`` times.
        num_layers: Required when ``hidden_dims`` is an int.
        dropout: Dropout used between MinGRU layers and inside the FFN.
        use_conv: Whether to apply the causal convolution.
        conv_kernel: Kernel size of the causal convolution.
        ffn_expansion: Expansion factor passed to ``SwiGLU``.
        norm_eps: Epsilon for every RMSNorm.
        use_gradient_checkpointing: Recompute the MinGRU/FFN block during
            backward (training only) to save memory.

    Raises:
        ValueError: If ``hidden_dims`` is an int and ``num_layers`` is None.
    """

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        hidden_dims: Union[int, List[int]],
        num_layers: Optional[int] = None,
        dropout: float = 0.1,
        use_conv: bool = True,
        conv_kernel: int = 3,
        ffn_expansion: float = 2.0,
        norm_eps: float = 1e-5,
        use_gradient_checkpointing: bool = False,
    ):
        super().__init__()

        if isinstance(hidden_dims, int):
            if num_layers is None:
                raise ValueError("num_layers must be specified when hidden_dims is an int")
            actual_hidden_dims = [hidden_dims] * num_layers
        else:
            actual_hidden_dims = hidden_dims

        self.embed_dim = embed_dim
        self.hidden_dims_list = actual_hidden_dims
        self.use_gradient_checkpointing = use_gradient_checkpointing

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        first_gru_dim = actual_hidden_dims[0]
        self.input_proj = nn.Linear(embed_dim, first_gru_dim) if embed_dim != first_gru_dim else nn.Identity()
        self.conv = CausalConv1D(first_gru_dim, conv_kernel) if use_conv else None
        self.pre_gru_norm = RMSNorm(first_gru_dim, eps=norm_eps)
        self.mingru = MinGRU(first_gru_dim, actual_hidden_dims, dropout)

        # The residual around the MinGRU stack requires first_gru_dim == last_gru_dim.
        last_gru_dim = actual_hidden_dims[-1]
        self.post_gru_norm = RMSNorm(last_gru_dim, eps=norm_eps)
        self.ffn = SwiGLU(last_gru_dim, ffn_expansion, dropout)

        self.final_norm = RMSNorm(last_gru_dim, eps=norm_eps)
        self.output_proj = nn.Linear(last_gru_dim, vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Initialise the weights owned directly by the decoder.

        ``MinGRULayer`` and ``SwiGLU`` run their own ``_init_weights`` on
        construction (e.g. the -1.0 gate bias). Only the decoder's own
        projections are touched here, so that initialisation is not overwritten.
        """
        nn.init.normal_(self.embedding.weight, mean=0.0, std=self.embed_dim ** -0.5)
        for module in (self.input_proj, self.output_proj):
            # input_proj is nn.Identity when embed_dim == first_gru_dim.
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def _forward_mingru_ffn_block(self, h: torch.Tensor, initial_states_for_mingru: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Pre-norm MinGRU and pre-norm SwiGLU sub-blocks, each with a residual.

        This is the unit that is optionally gradient-checkpointed.

        Returns:
            The processed tensor and the MinGRU stack's final states.
        """
        h_norm_pre_gru = self.pre_gru_norm(h)
        # NOTE(review): with initial_states_for_mingru=None, MinGRU in eval mode
        # resumes from its cached states of the previous eval call.
        gru_out, next_mingru_states = self.mingru(h_norm_pre_gru, states=initial_states_for_mingru)
        h = h + gru_out

        h_norm_post_gru = self.post_gru_norm(h)
        ffn_out = self.ffn(h_norm_post_gru)
        h = h + ffn_out

        return h, next_mingru_states

    def forward(self, input_ids: torch.Tensor, states: Optional[List[torch.Tensor]] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Map token ids to logits.

        Args:
            input_ids: Token ids, assumed ``(batch, seq_len)``.
            states: Optional initial states for the MinGRU stack.

        Returns:
            ``(logits, next_states)``.
        """
        h = self.embedding(input_ids)
        h = self.input_proj(h)
        if self.conv is not None:
            h = h + self.conv(h)

        if self.use_gradient_checkpointing and self.training:
            # Explicit import: ``import torch`` alone does not guarantee that
            # the ``torch.utils.checkpoint`` submodule is loaded.
            from torch.utils.checkpoint import checkpoint
            h, next_model_states = checkpoint(
                self._forward_mingru_ffn_block,
                h,
                states,
                use_reentrant=False,
                preserve_rng_state=True,
            )
        else:
            h, next_model_states = self._forward_mingru_ffn_block(h, initial_states_for_mingru=states)

        h = self.final_norm(h)
        logits = self.output_proj(h)
        return logits, next_model_states


# Example usage with optimizations
if __name__ == "__main__":
    if torch.backends.cudnn.is_available():
        torch.backends.cudnn.benchmark = True
    if hasattr(torch.backends.cuda, 'matmul') and hasattr(torch.backends.cuda.matmul, 'allow_tf32'):
        torch.backends.cuda.matmul.allow_tf32 = True

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = MinGRUDecoder(
        vocab_size=10000,
        embed_dim=256,
        hidden_dims=[256, 256],
        dropout=0.1,
        use_conv=True,
        conv_kernel=3,
        ffn_expansion=2.0,
        norm_eps=1e-5,
        use_gradient_checkpointing=False,
    ).to(device)

    compiled_model = model  # Fall back to the eager model if compilation is unavailable.
    if hasattr(torch, 'compile'):
        print("Attempting to compile model...")
        try:
            # torch.compile is lazy: most compilation errors only surface on the
            # first forward call, not here.
            compiled_model = torch.compile(model, mode="reduce-overhead")
            print("Model compiled successfully!")
        except Exception as e:
            print(f"Model compilation failed: {e}")
            print("Proceeding without compilation.")
    else:
        print("torch.compile not available. Proceeding without compilation.")

    compiled_model.eval()  # Evaluation mode for inference timing (disables dropout).

    batch_size, seq_len = 4, 64
    vocab_size = model.output_proj.out_features
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    # NOTE: with states=None the MinGRU stack reuses its eval-mode cache, so
    # each call continues from the previous call's final states. Timing is
    # unaffected, but successive logits differ.
    initial_model_states = None

    print(f"\nInput shape: {input_ids.shape}")

    warm_up_iterations = 10
    num_iterations = 50
    # Timed iterations discarded before averaging (on top of the warm-up).
    skip_iterations = warm_up_iterations // 2
    amp_dtype = torch.bfloat16 if (device.type == 'cuda' and torch.cuda.is_bf16_supported()) else torch.float16
    use_amp = HAS_AMP and device.type == 'cuda'

    def _run_standard() -> torch.Tensor:
        """One forward pass in the model's own precision; returns the logits."""
        logits, _ = compiled_model(input_ids, states=initial_model_states)
        return logits

    def _run_amp() -> torch.Tensor:
        """One forward pass under autocast with ``amp_dtype``; returns the logits."""
        with autocast(device_type=device.type, dtype=amp_dtype, enabled=True):
            logits, _ = compiled_model(input_ids, states=initial_model_states)
        return logits

    def _benchmark(run_once, iterations: int, skip: int):
        """Call ``run_once`` ``iterations`` times; return ``(avg_ms, last_output)``.

        The first ``skip`` iterations are excluded from the average. CUDA
        events time GPU work; ``time.perf_counter`` is used on the CPU.
        """
        counted = iterations - skip
        if counted <= 0:  # Too few iterations to skip any: average over all.
            skip, counted = 0, iterations
        use_events = device.type == 'cuda'
        if use_events:
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
        total_ms = 0.0
        last_output = None
        for i in range(iterations):
            if use_events:
                start_event.record()
                last_output = run_once()
                end_event.record()
                torch.cuda.synchronize()
                elapsed_ms = start_event.elapsed_time(end_event)
            else:
                t0 = time.perf_counter()
                last_output = run_once()
                elapsed_ms = (time.perf_counter() - t0) * 1000
            if i >= skip:
                total_ms += elapsed_ms
        return total_ms / counted, last_output

    # Inference only: building an autograd graph would inflate both the
    # timings and memory use.
    with torch.no_grad():
        print("\nPerforming warm-up passes...")
        if use_amp:
            print(f"Using AMP with dtype: {amp_dtype}")
        warm_up_pass = _run_amp if use_amp else _run_standard
        for _ in range(warm_up_iterations):
            warm_up_pass()
        if device.type == 'cuda':
            torch.cuda.synchronize()

        if use_amp:
            print(f"\nMeasuring execution time with AMP ({amp_dtype}) over {num_iterations} iterations...")
            avg_time_amp_ms, output_logits = _benchmark(_run_amp, num_iterations, skip_iterations)
            print(f"Average inference time (AMP, {amp_dtype} on CUDA): {avg_time_amp_ms:.3f} ms per iteration")
            print(f"Output shape (AMP): {output_logits.shape}")

        model_dtype = model.embedding.weight.dtype
        print(f"\nMeasuring execution time with standard precision ({model_dtype}) over {num_iterations} iterations...")
        avg_time_std_ms, output_logits = _benchmark(_run_standard, num_iterations, skip_iterations)
        if device.type == 'cuda':
            print(f"Average inference time ({model_dtype} on GPU): {avg_time_std_ms:.3f} ms per iteration")
        else:
            print(f"Average inference time (CPU, {model_dtype}): {avg_time_std_ms:.3f} ms per iteration")

    print(f"Final output shape: {output_logits.shape}")
    expected_shape = (batch_size, seq_len, vocab_size)
    assert output_logits.shape == expected_shape, \
        f"Output shape mismatch! Expected {expected_shape}, got {output_logits.shape}"

    print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
    # Peak-memory measurement is left out here because it would perturb the timings.



Using device: cuda
Attempting to compile model...
Model compiled successfully!

Input shape: torch.Size([4, 64])

Performing warm-up passes...
Using AMP with dtype: torch.bfloat16





Measuring execution time with AMP (torch.bfloat16) over 50 iterations...
Average inference time (AMP, torch.bfloat16 on CUDA): 10.142 ms per iteration
Output shape (AMP): torch.Size([4, 64, 10000])

Measuring execution time with standard precision (torch.float32) over 50 iterations...




Average inference time (torch.float32 on GPU): 12.526 ms per iteration
Final output shape: torch.Size([4, 64, 10000])
Model parameters: 5,853,968
