# In [None]:
# hybrid_mamba_training.py

"""
Hybrid Mamba-Transformer Small Language Model Training Script
Optimized for RTX 4060 8GB VRAM - Local Inference Ready
Author: AI Assistant for Soumyaranjan Sahoo
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import transformers
from transformers import AutoTokenizer, TrainingArguments, Trainer
import json
import numpy as np
from typing import Optional, List, Dict, Any
import math
import warnings
warnings.filterwarnings("ignore")

# Pick the compute device once at import time and report what was found.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print(f"Using device: {device}")
if _cuda_ok:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {_gpu_props.total_memory / 1024**3:.1f} GB")

class MambaBlock(nn.Module):
    """Simplified Mamba-style block for the hybrid architecture.

    Pre-norm residual pipeline:
    LayerNorm -> in_proj (split into content ``x`` and gate ``z``) ->
    causal depthwise Conv1d -> SiLU -> input-dependent mixing term ->
    SiLU(z) gating -> out_proj -> + residual.

    NOTE: this is not a full selective-scan SSM. There is no recurrence over
    the sequence; mixing along the time axis happens only inside the causal
    depthwise convolution (receptive field of ``d_conv`` tokens).

    Args:
        d_model: Width of the residual stream.
        d_state: Width of the projected B / delta features.
        expand_factor: Inner width multiplier (``d_inner = d_model * expand_factor``).
        d_conv: Kernel size of the causal depthwise convolution (default 4,
            the value that used to be hard-coded).
    """
    def __init__(self, d_model, d_state=16, expand_factor=2, d_conv=4):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.expand_factor = expand_factor
        self.d_inner = d_model * expand_factor
        self.d_conv = d_conv

        # Linear projections
        self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
        # Padding by (kernel_size - 1) and keeping the first L outputs makes
        # output t depend only on inputs t-(d_conv-1)..t, i.e. strictly causal.
        # (Symmetric padding of 2 with a kernel of 4 let position t see t+1.)
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            kernel_size=d_conv,
            bias=True,
            padding=d_conv - 1,
            groups=self.d_inner,
        )
        self.x_proj = nn.Linear(self.d_inner, d_state * 2, bias=False)
        self.dt_proj = nn.Linear(self.d_inner, d_state, bias=True)
        self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)

        # State space parameters. A_log is not used by the simplified forward
        # pass (so it receives no gradient); it is kept so checkpoints keep the
        # same state_dict keys.
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.d_inner))

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq_len, d_model).

        Returns:
            Tensor of the same shape (block output plus the residual input).
        """
        seq_len = x.shape[1]

        residual = x
        x = self.norm(x)

        # Content branch and gate, each (B, L, d_inner)
        x, z = self.in_proj(x).chunk(2, dim=-1)

        # Causal depthwise convolution over the time axis
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len].transpose(1, 2)
        x = nn.functional.silu(x)

        # Selective mechanism: only the second half of x_proj (B) is used;
        # delta comes from dt_proj.
        _, B_proj = self.x_proj(x).chunk(2, dim=-1)  # (B, L, d_state)
        delta = nn.functional.softplus(self.dt_proj(x))  # (B, L, d_state)

        # Simplified SSM: per-token skip term plus a scalar mix broadcast over d_inner
        mix = torch.sum(B_proj * delta, dim=-1, keepdim=True)  # (B, L, 1)
        y = x * self.D + mix

        # Gate and output
        y = y * nn.functional.silu(z)
        return self.out_proj(y) + residual

class HybridAttentionBlock(nn.Module):
    """Pre-norm Transformer block: multi-head self-attention followed by a GELU MLP.

    Args:
        d_model: Hidden width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout applied to the attention weights and the MLP output.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """
    def __init__(self, d_model, n_heads=8, dropout=0.1):
        super().__init__()
        if d_model % n_heads != 0:
            # Otherwise the qkv reshape in forward() fails with an opaque error.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout)
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, attention_mask=None):
        """Apply self-attention then the MLP, each with a residual connection.

        Args:
            x: (batch, seq_len, d_model) hidden states.
            attention_mask: Optional mask broadcastable to
                (batch, n_heads, seq_len, seq_len). Positions where it equals 0
                are excluded from attention. The caller is responsible for
                making it causal.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B, L, D = x.shape

        # Self-attention
        residual = x
        h = self.norm1(x)

        qkv = self.qkv(h).reshape(B, L, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (B, n_heads, L, head_dim)

        # Scaled dot-product attention
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # -1e9 is not representable in float16, so masked_fill raises when
            # scores are fp16 (e.g. under fp16 autocast). The dtype's most
            # negative finite value works for every float dtype and gives
            # masked keys (near-)zero weight after softmax.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = self.dropout(torch.softmax(scores, dim=-1))
        attn_output = torch.matmul(attn_weights, v).transpose(1, 2).reshape(B, L, D)
        x = residual + self.o_proj(attn_output)

        # MLP
        x = x + self.mlp(self.norm2(x))

        return x

class HybridMambaTransformer(nn.Module):
    """Causal language model that stacks MambaBlock and HybridAttentionBlock layers.

    Token embeddings plus learned absolute position embeddings feed a stack of
    layers whose types come from ``layer_pattern``. A final LayerNorm and a
    bias-free ``lm_head`` produce next-token logits.

    Args:
        vocab_size: Size of the token embedding table and of the logits.
        d_model: Hidden width shared by all layers.
        n_layers: Layer count used to build the default ``layer_pattern``.
        n_heads: Attention heads in transformer layers.
        d_state: Passed to every MambaBlock.
        expand_factor: Passed to every MambaBlock.
        dropout: Passed to every HybridAttentionBlock.
        max_seq_length: Rows in the position-embedding table, i.e. the longest
            sequence ``forward`` accepts.
        layer_pattern: Optional list of layer types. ``'mamba'`` builds a
            MambaBlock; any other value builds a HybridAttentionBlock. When
            None, the first ``int(0.7 * n_layers)`` layers are Mamba and the
            rest are transformer layers.
    """
    def __init__(self,
                 vocab_size=32000,
                 d_model=768,
                 n_layers=12,
                 n_heads=12,
                 d_state=16,
                 expand_factor=2,
                 dropout=0.1,
                 max_seq_length=2048,
                 layer_pattern=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_layers = n_layers
        self.max_seq_length = max_seq_length

        # Default layer pattern: Mamba for early layers, Transformer for later layers
        if layer_pattern is None:
            # 70% Mamba, 30% Transformer
            mamba_layers = int(n_layers * 0.7)
            self.layer_pattern = ['mamba'] * mamba_layers + ['transformer'] * (n_layers - mamba_layers)
        else:
            self.layer_pattern = layer_pattern

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, d_model)
        self.embed_positions = nn.Embedding(max_seq_length, d_model)

        # Hybrid layers
        self.layers = nn.ModuleList()
        for layer_type in self.layer_pattern:
            if layer_type == 'mamba':
                self.layers.append(MambaBlock(d_model, d_state, expand_factor))
            else:  # transformer
                self.layers.append(HybridAttentionBlock(d_model, n_heads, dropout))

        # Output
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Activation checkpointing is off until gradient_checkpointing_enable() is called.
        self.gradient_checkpointing = False
        self._checkpoint_kwargs = {"use_reentrant": False}

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero the Linear biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Recompute layer activations during backward to reduce activation memory.

        The signature mirrors the transformers ``PreTrainedModel`` hook, so both a
        direct call and ``Trainer(..., gradient_checkpointing=True)`` work.

        Args:
            gradient_checkpointing_kwargs: Extra keyword arguments for
                ``torch.utils.checkpoint.checkpoint`` (``use_reentrant``
                defaults to False).
        """
        kwargs = {"use_reentrant": False}
        if gradient_checkpointing_kwargs:
            kwargs.update(gradient_checkpointing_kwargs)
        self._checkpoint_kwargs = kwargs
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Turn activation checkpointing back off."""
        self.gradient_checkpointing = False

    def _build_attention_mask(self, attention_mask, seq_len, device):
        """Return the mask handed to transformer layers (nonzero = may attend).

        A 2-D ``(batch, seq_len)`` padding mask is combined with the causal mask
        into ``(batch, 1, seq_len, seq_len)``. Any other shape is treated as a
        pre-built mask and returned unchanged.
        """
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        causal = causal.view(1, 1, seq_len, seq_len)
        if attention_mask is None:
            return causal
        if attention_mask.dim() == 2:
            # Padding mask: keep keys where the mask is nonzero, and stay causal.
            return causal & attention_mask.bool()[:, None, None, :]
        return attention_mask

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute logits and, when ``labels`` are given, the causal-LM loss.

        Args:
            input_ids: (batch, seq_len) token ids; seq_len <= max_seq_length.
            attention_mask: Optional (batch, seq_len) padding mask (1 = token,
                0 = padding), or a pre-built mask broadcastable to
                (batch, n_heads, seq_len, seq_len).
            labels: Optional (batch, seq_len) target ids. Positions set to -100
                are ignored (CrossEntropyLoss's default ``ignore_index``).

        Returns:
            dict with ``'loss'`` (scalar tensor or None) and ``'logits'``
            of shape (batch, seq_len, vocab_size).
        """
        B, L = input_ids.shape

        # Embeddings
        positions = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
        x = self.embed_tokens(input_ids) + self.embed_positions(positions)

        # One mask for all transformer layers, built once instead of per layer.
        attn_mask = self._build_attention_mask(attention_mask, L, x.device)

        use_checkpoint = self.gradient_checkpointing and self.training and torch.is_grad_enabled()
        if use_checkpoint:
            from torch.utils.checkpoint import checkpoint

        for layer, layer_type in zip(self.layers, self.layer_pattern):
            layer_args = (x,) if layer_type == 'mamba' else (x, attn_mask)
            if use_checkpoint:
                x = checkpoint(layer, *layer_args, **self._checkpoint_kwargs)
            else:
                x = layer(*layer_args)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.vocab_size), shift_labels.view(-1))

        return {'loss': loss, 'logits': logits}

    def generate(self, input_ids, max_length=100, temperature=0.8, top_p=0.9, eos_token_id=None):
        """Autoregressively sample a continuation with temperature + nucleus sampling.

        Args:
            input_ids: (batch, prompt_len) prompt token ids.
            max_length: Total length (prompt + new tokens) at which to stop.
            temperature: Softmax temperature. Values <= 0 pick the argmax token
                (greedy decoding).
            top_p: Sample from the smallest set of tokens whose cumulative
                probability exceeds ``top_p``.
            eos_token_id: Stop once every sequence has emitted this id. When
                None, the ``eos_token_id`` of a module-level ``tokenizer``
                global is used if one exists (the original script relied on
                that global). Otherwise generation runs to ``max_length``.

        Returns:
            (batch, <= max_length) tensor of prompt plus generated ids.
        """
        if eos_token_id is None:
            eos_token_id = getattr(globals().get("tokenizer"), "eos_token_id", None)

        self.eval()
        finished = torch.zeros(input_ids.size(0), dtype=torch.bool, device=input_ids.device)
        with torch.no_grad():
            for _ in range(max_length - input_ids.size(1)):
                # Only the most recent max_seq_length tokens fit in the position table.
                context = input_ids[:, -self.max_seq_length:]
                logits = self.forward(context)['logits'][:, -1, :]

                if temperature <= 0:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature

                    # Top-p sampling: drop tokens beyond the nucleus, always keeping the top one.
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False

                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits = logits.masked_fill(indices_to_remove, -float('inf'))

                    probs = torch.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, 1)

                input_ids = torch.cat([input_ids, next_token], dim=-1)

                if eos_token_id is not None:
                    # Works for any batch size, unlike next_token.item().
                    finished |= next_token.squeeze(-1) == eos_token_id
                    if bool(finished.all()):
                        break

        return input_ids

# Training configuration optimized for RTX 4060 8GB
class ModelConfig:
    """Hyper-parameters for the hybrid model and its Trainer run (class attributes)."""

    # Model architecture - optimized for 8GB VRAM
    # Must cover every id emitted by the tokenizer loaded in train_model
    # (meta-llama/Llama-3.2-1B has 128,256 entries). The previous value of
    # 32000 caused out-of-range embedding lookups.
    vocab_size = 128256
    d_model = 512  # Reduced from 768 for memory efficiency
    n_layers = 8   # Balanced for performance vs memory
    n_heads = 8    # must divide d_model
    d_state = 16
    expand_factor = 2
    dropout = 0.1
    max_seq_length = 1024  # Reduced for memory efficiency

    # Training hyperparameters
    batch_size = 2          # Small batch for 8GB VRAM
    gradient_accumulation_steps = 8  # Effective batch size = 2 * 8 = 16
    learning_rate = 5e-4
    weight_decay = 0.01
    max_steps = 5000       # Adjust based on dataset size
    warmup_steps = 500
    save_steps = 500
    eval_steps = 500

    # Memory optimization
    fp16 = True            # Use mixed precision (only effective on CUDA)
    gradient_checkpointing = True
    dataloader_num_workers = 2

def create_training_dataset():
    """Return the raw training corpus as a fresh list of strings.

    This is a tiny placeholder corpus for demonstration purposes; replace it
    with your real dataset.
    """
    # Append further training texts to this list.
    return [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "Machine learning models require large datasets for training.",
        "Python is a versatile programming language.",
        "Deep learning networks can solve complex problems.",
    ]

class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for causal-LM training.

    Defined at module level rather than inside ``train_model`` so it can be
    pickled when DataLoader worker processes are started with the "spawn"
    method (the default on Windows and macOS).

    Each item is a dict of 1-D tensors of length ``max_length``:
    ``input_ids``, ``attention_mask`` and ``labels``. ``labels`` is a copy of
    ``input_ids`` with padding positions set to -100.
    """

    def __init__(self, texts, tokenizer, max_length):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()
        # train_model sets pad_token = eos_token, so padding must be hidden from
        # the loss explicitly; -100 is CrossEntropyLoss's default ignore_index.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
        }


def train_model():
    """Build, train and save the hybrid model.

    Writes the tokenizer, ``config.json`` (architecture hyper-parameters) and
    ``pytorch_model.bin`` (state dict) to ``./hybrid-mamba-final``. Those are
    the files the inference script loads.

    Returns:
        (model, tokenizer) after training.
    """
    import os

    config = ModelConfig()
    output_dir = './hybrid-mamba-final'

    # Initialize tokenizer (using LLaMA tokenizer as base)
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    tokenizer.pad_token = tokenizer.eos_token

    # The embedding table must cover every id the tokenizer can emit,
    # otherwise lookups go out of range.
    vocab_size = max(config.vocab_size, len(tokenizer))
    if vocab_size != config.vocab_size:
        print(f"Adjusting vocab_size from {config.vocab_size} to {vocab_size} to match the tokenizer")

    # Create model
    print("Creating hybrid model...")
    model = HybridMambaTransformer(
        vocab_size=vocab_size,
        d_model=config.d_model,
        n_layers=config.n_layers,
        n_heads=config.n_heads,
        d_state=config.d_state,
        expand_factor=config.expand_factor,
        dropout=config.dropout,
        max_seq_length=config.max_seq_length
    )

    # Model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Model size: ~{total_params * 4 / 1024**3:.2f} GB (FP32)")

    # Move to device
    model = model.to(device)

    # A plain nn.Module has no gradient_checkpointing_enable() unless the model
    # class defines it; guard so training does not die with AttributeError.
    if config.gradient_checkpointing and hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    elif config.gradient_checkpointing:
        print("Gradient checkpointing is not supported by this model; continuing without it.")

    # fp16 mixed precision needs a CUDA device; fall back to FP32 on CPU.
    use_fp16 = config.fp16 and torch.cuda.is_available()

    # Create dataset
    texts = create_training_dataset()
    train_dataset = TextDataset(texts, tokenizer, config.max_seq_length)

    # Training arguments
    training_args = TrainingArguments(
        output_dir='./hybrid-mamba-model',
        overwrite_output_dir=True,
        num_train_epochs=3,  # ignored because max_steps > 0
        per_device_train_batch_size=config.batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        weight_decay=config.weight_decay,
        max_steps=config.max_steps,
        warmup_steps=config.warmup_steps,
        logging_steps=100,
        save_steps=config.save_steps,
        eval_steps=config.eval_steps,
        save_total_limit=3,
        prediction_loss_only=True,
        fp16=use_fp16,
        # Checkpointing is already configured on the model above; leaving this
        # False keeps Trainer from calling the hook a second time.
        gradient_checkpointing=False,
        dataloader_num_workers=config.dataloader_num_workers,
        remove_unused_columns=False,
        # The string "none" disables all integrations; None does not (in
        # transformers 4.x it means "report to every installed integration").
        report_to="none",
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
    )

    # Start training
    print("Starting training...")
    trainer.train()

    # Save final model
    print("Saving model...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    # For a plain nn.Module, Trainer.save_model only dumps the state dict
    # (as model.safetensors when safetensors saving is enabled). The inference
    # script rebuilds the model from config.json and loads pytorch_model.bin,
    # so write both explicitly.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(output_dir, "pytorch_model.bin"))
    model_config = {
        "vocab_size": vocab_size,
        "d_model": config.d_model,
        "n_layers": config.n_layers,
        "n_heads": config.n_heads,
        "d_state": config.d_state,
        "expand_factor": config.expand_factor,
        "dropout": config.dropout,
        "max_seq_length": config.max_seq_length,
        "layer_pattern": list(model.layer_pattern),
    }
    with open(os.path.join(output_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2)

    print("Training completed!")
    return model, tokenizer

def test_model(model, tokenizer):
    """Print sample generations for a fixed set of prompts.

    Args:
        model: Trained model exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer used to encode the prompts and decode the output.
    """
    model.eval()

    prompts = (
        "The future of artificial intelligence is",
        "In machine learning, we often use",
        "Python programming allows us to",
    )

    print("\nTesting model generation:")
    print("=" * 50)

    for text in prompts:
        print(f"Prompt: {text}")

        encoded = tokenizer(text, return_tensors='pt').to(device)
        with torch.no_grad():
            sequences = model.generate(
                encoded['input_ids'],
                max_length=100,
                temperature=0.8,
                top_p=0.9
            )

        decoded = tokenizer.decode(sequences[0], skip_special_tokens=True)
        print(f"Generated: {decoded}")
        print("-" * 30)

if __name__ == "__main__":
    for banner in (
        "Hybrid Mamba-Transformer Training Script",
        "Optimized for RTX 4060 8GB VRAM",
        "=" * 50,
    ):
        print(banner)

    # Keep these exact module-level names: HybridMambaTransformer.generate
    # reads a global ``tokenizer`` to find the EOS id.
    model, tokenizer = train_model()
    test_model(model, tokenizer)

    print("\nModel files saved to: ./hybrid-mamba-final/")
    print("You can now use this model for local inference!")


# In [None]:
#inference.py
"""
Local Inference Script for Hybrid Mamba-Transformer Model
Optimized for RTX 4060 8GB VRAM
"""

import torch
import torch.nn as nn
import json
from transformers import AutoTokenizer
import warnings
warnings.filterwarnings("ignore")

# Import model architecture (assuming it's in the same directory)
from hybrid_mamba_training import HybridMambaTransformer

class HybridModelInference:
    """Load a trained HybridMambaTransformer from disk and generate text with it.

    ``model_path`` must contain ``config.json`` (the model's hyper-parameters)
    and the tokenizer files. ``pytorch_model.bin`` (a state dict) is loaded
    when present; otherwise the model keeps its random initialization.
    """

    def __init__(self, model_path="./hybrid-mamba-final", device=None):
        """Build the tokenizer and model and move the model to the chosen device.

        Args:
            model_path: Directory written by the training script.
            device: Torch device string (e.g. "cuda", "cpu"). Defaults to CUDA
                when available.

        Raises:
            FileNotFoundError: If ``config.json`` is missing.
            KeyError: If ``config.json`` lacks a required hyper-parameter.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        print(f"Loading model on: {self.device}")

        # Load configuration
        with open(f"{model_path}/config.json", "r", encoding="utf-8") as f:
            config = json.load(f)

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # When a prompt (e.g. a long chat history) exceeds the budget, drop the
        # oldest tokens so the trailing "Assistant: " cue survives truncation.
        self.tokenizer.truncation_side = "left"

        # Initialize model
        self.model = HybridMambaTransformer(
            vocab_size=config["vocab_size"],
            d_model=config["d_model"],
            n_layers=config["n_layers"],
            n_heads=config["n_heads"],
            d_state=config["d_state"],
            expand_factor=config["expand_factor"],
            dropout=config["dropout"],
            max_seq_length=config["max_seq_length"],
            # Configs that omit the pattern fall back to the model's default split.
            layer_pattern=config.get("layer_pattern"),
        )

        # Load trained weights
        weights_path = f"{model_path}/pytorch_model.bin"
        try:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot run code on load.
            checkpoint = torch.load(weights_path, map_location=self.device, weights_only=True)
        except FileNotFoundError:
            print("Warning: No trained weights found. Using randomly initialized model.")
        else:
            self.model.load_state_dict(checkpoint)
            print("Model weights loaded successfully!")

        self.model.to(self.device)
        self.model.eval()

        # Model info. The weights are never cast here, so report their real
        # footprint instead of assuming FP16.
        params = list(self.model.parameters())
        total_params = sum(p.numel() for p in params)
        weight_bytes = sum(p.numel() * p.element_size() for p in params)
        print(f"Model parameters: {total_params:,}")
        print(f"Estimated memory usage: ~{weight_bytes / 1024**3:.2f} GB (weights, {params[0].dtype})")

    def generate_text(self,
                     prompt,
                     max_length=200,
                     temperature=0.8,
                     top_p=0.9,
                     repetition_penalty=1.1,
                     do_sample=True):
        """Generate a continuation of ``prompt`` and return only the new text.

        Args:
            prompt: Input text.
            max_length: Number of new tokens to generate.
            temperature: Sampling temperature forwarded to the model.
            top_p: Nucleus-sampling threshold forwarded to the model.
            repetition_penalty: Accepted for API compatibility; currently unused.
            do_sample: Accepted for API compatibility; currently unused.

        Returns:
            The decoded generated text, with special tokens removed.
        """
        # Reserve room for the new tokens inside the position table, but always
        # keep at least one prompt token (a budget <= 0 breaks truncation).
        prompt_budget = max(1, self.model.max_seq_length - max_length)

        # Tokenize input
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=prompt_budget
        ).to(self.device)

        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            # Use the model's generate method
            generated_ids = self.model.generate(
                inputs["input_ids"],
                max_length=input_length + max_length,
                temperature=temperature,
                top_p=top_p
            )

        # Decode only the new tokens
        new_tokens = generated_ids[0][input_length:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return generated_text

    def chat_interface(self):
        """Run an interactive chat loop on stdin/stdout.

        Typing 'quit' exits and 'clear' resets the conversation history.
        Ctrl-C or end-of-input also exit.
        """
        print("\n" + "="*60)
        print("Hybrid Mamba-Transformer Chat Interface")
        print("Type 'quit' to exit, 'clear' to reset context")
        print("="*60)

        conversation_history = ""

        while True:
            try:
                user_input = input("\nYou: ").strip()

                if user_input.lower() == 'quit':
                    print("Goodbye!")
                    break

                if user_input.lower() == 'clear':
                    conversation_history = ""
                    print("Context cleared!")
                    continue

                if not user_input:
                    continue

                # Build prompt with conversation history
                if conversation_history:
                    prompt = f"{conversation_history}\nHuman: {user_input}\nAssistant: "
                else:
                    prompt = f"Human: {user_input}\nAssistant: "

                # Generate response
                print("Assistant: ", end="", flush=True)
                response = self.generate_text(
                    prompt,
                    max_length=150,
                    temperature=0.7,
                    top_p=0.9
                )

                print(response)

                # Update conversation history (keep it manageable)
                conversation_history = f"{conversation_history}\nHuman: {user_input}\nAssistant: {response}"

                # Truncate history if too long: keep the last 10 lines
                # (about 5 exchanges when responses are single-line).
                if len(conversation_history) > 2000:
                    lines = conversation_history.split('\n')
                    conversation_history = '\n'.join(lines[-10:])

            except (KeyboardInterrupt, EOFError):
                # input() keeps raising EOFError once stdin is closed; letting it
                # reach the generic handler below would loop forever.
                print("\nGoodbye!")
                break
            except Exception as e:
                print(f"Error: {e}")

    def benchmark_performance(self):
        """Time generation for a few fixed prompts and print tokens/second."""
        import time

        print("\nRunning performance benchmark...")

        test_prompts = [
            "The future of artificial intelligence",
            "In the field of machine learning",
            "Python is a programming language that",
            "The benefits of renewable energy include",
            "Space exploration has led to"
        ]

        total_time = 0.0
        total_tokens = 0

        for i, prompt in enumerate(test_prompts, start=1):
            print(f"\nTest {i}/{len(test_prompts)}: {prompt}...")

            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            response = self.generate_text(prompt, max_length=100, temperature=0.8)
            generation_time = time.perf_counter() - start_time

            # Count only the generated text; add_special_tokens=False avoids
            # counting special tokens (e.g. BOS) the tokenizer would prepend.
            tokens_generated = len(self.tokenizer.encode(response, add_special_tokens=False))
            tokens_per_second = tokens_generated / generation_time if generation_time > 0 else 0.0

            print(f"Generated: {response[:100]}...")
            print(f"Time: {generation_time:.2f}s, Tokens: {tokens_generated}, Speed: {tokens_per_second:.1f} tokens/s")

            total_time += generation_time
            total_tokens += tokens_generated

        avg_speed = total_tokens / total_time if total_time > 0 else 0.0
        print(f"\nAverage performance: {avg_speed:.1f} tokens/second")
        print(f"Total time: {total_time:.2f}s, Total tokens: {total_tokens}")

def main():
    """Parse CLI flags, load the model, and run the requested mode.

    Precedence is --benchmark, then --chat, then --prompt. When none of them
    is given, a usage hint is printed.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Hybrid Mamba-Transformer Inference")
    cli_options = [
        ("--model_path", {"type": str, "default": "./hybrid-mamba-final",
                          "help": "Path to the trained model"}),
        ("--prompt", {"type": str, "help": "Text prompt for generation"}),
        ("--max_length", {"type": int, "default": 200,
                          "help": "Maximum generation length"}),
        ("--temperature", {"type": float, "default": 0.8,
                           "help": "Sampling temperature"}),
        ("--top_p", {"type": float, "default": 0.9,
                     "help": "Top-p sampling threshold"}),
        ("--chat", {"action": "store_true",
                    "help": "Start interactive chat interface"}),
        ("--benchmark", {"action": "store_true",
                         "help": "Run performance benchmark"}),
    ]
    for flag, settings in cli_options:
        parser.add_argument(flag, **settings)

    args = parser.parse_args()

    engine = HybridModelInference(args.model_path)

    if args.benchmark:
        engine.benchmark_performance()
        return
    if args.chat:
        engine.chat_interface()
        return
    if not args.prompt:
        print("Use --prompt for single generation, --chat for interactive mode, or --benchmark for testing")
        return

    # Single generation
    response = engine.generate_text(
        args.prompt,
        max_length=args.max_length,
        temperature=args.temperature,
        top_p=args.top_p,
    )
    print(f"Prompt: {args.prompt}")
    print(f"Response: {response}")

if __name__ == "__main__":
    main()
