<a href="https://colab.research.google.com/github/mazen200555/curriculum/blob/master/Small_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

# Configuration optimized for 15GB VRAM
class Config:
    """Hyper-parameters shared by the model definition and the training loop.

    All values are plain class attributes, so they can be read either from the
    class itself or from an instance created with ``Config()``.
    """

    # --- Model architecture ---
    vocab_size = 50257  # GPT-2 vocabulary size
    max_position_embeddings = 1024
    hidden_size = 768
    intermediate_size = 3072
    num_hidden_layers = 10
    num_attention_heads = 12
    layer_norm_eps = 1e-12

    # --- Regularisation ---
    hidden_dropout_prob = 0.1
    attention_probs_dropout_prob = 0.1

    # --- Optimisation ---
    learning_rate = 5e-5
    epochs = 3
    batch_size = 1
    gradient_accumulation_steps = 8
    warmup_steps = 1000
    max_grad_norm = 1.0
    use_mixed_precision = True

# Dataset class for text processing
class TextDataset(Dataset):
    """Map-style dataset that tokenizes every text eagerly at construction time.

    Each text is passed to ``tokenizer`` with truncation and ``max_length``
    padding; the returned ``input_ids`` and ``attention_mask`` are stored as
    tensors and served by index.
    """

    def __init__(self, texts, tokenizer, max_length=1024):
        """Tokenize ``texts`` one by one and cache the resulting tensors.

        Args:
            texts: iterable of strings.
            tokenizer: callable returning a mapping with ``'input_ids'`` and
                ``'attention_mask'`` entries for a single text.
            max_length: length every sample is truncated / padded to.
        """
        self.tokenizer = tokenizer
        self.input_ids = []
        self.attention_masks = []

        for sample in texts:
            encoded = tokenizer(
                sample,
                truncation=True,
                max_length=max_length,
                padding="max_length",
            )
            ids = torch.tensor(encoded['input_ids'])
            mask = torch.tensor(encoded['attention_mask'])
            self.input_ids.append(ids)
            self.attention_masks.append(mask)

    def __len__(self):
        """Number of tokenized samples."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return ids, mask and labels (an independent copy of the ids) for ``idx``."""
        ids = self.input_ids[idx]
        mask = self.attention_masks[idx]
        # Labels are cloned so downstream code can modify them without
        # touching the cached input ids (autoregressive language modeling).
        return {'input_ids': ids, 'attention_mask': mask, 'labels': ids.clone()}

# 1. Dynamic Attention Mechanism
class DynamicAttention(nn.Module):
    """Multi-head self-attention followed by a learned mix of three projections.

    Standard scaled dot-product attention produces a context tensor; that
    context is passed through three linear "mode" projectors (local, global,
    specialized) whose outputs are blended with per-sequence softmax weights
    predicted from the mean token representation.

    Args:
        hidden_size: model width; must be divisible by ``num_attention_heads``.
        num_attention_heads: number of attention heads.
        dropout_prob: dropout applied to the attention probabilities.
        causal: when True (default) each position may only attend to itself
            and earlier positions. This is required for next-token training
            and left-to-right sampling; pass ``False`` for bidirectional
            attention. The flag is a plain attribute (no parameter or buffer),
            so checkpoints are unaffected.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by
            ``num_attention_heads`` (the mode projectors expect the
            concatenated heads to be exactly ``hidden_size`` wide).
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.1, causal=True):
        super(DynamicAttention, self).__init__()

        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_attention_heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.causal = causal

        # Regular attention components
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        # Complexity estimator: predicts the blend weights of the three modes
        self.complexity_estimator = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Linear(hidden_size // 4, 3)  # 3 modes: local, global, specialized
        )

        # One projector per attention mode
        self.local_projector = nn.Linear(hidden_size, hidden_size)
        self.global_projector = nn.Linear(hidden_size, hidden_size)
        self.specialized_projector = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout_prob)
        self.output_linear = nn.Linear(hidden_size, hidden_size)

    def _split_heads(self, tensor, batch_size, seq_length):
        """Reshape [batch, seq, all_head] -> [batch, heads, seq, head_size]."""
        return tensor.view(
            batch_size, seq_length, self.num_attention_heads, self.attention_head_size
        ).transpose(1, 2)

    def forward(self, hidden_states, attention_mask=None):
        """Apply attention and the mode mixture.

        Args:
            hidden_states: tensor of shape [batch, seq, hidden].
            attention_mask: optional [batch, seq] mask, non-zero/True for
                tokens that may be attended to and 0/False for padding.

        Returns:
            Tensor of shape [batch, seq, hidden].
        """
        batch_size, seq_length, _ = hidden_states.size()

        # Blend weights come from the mean over *all* positions.
        # NOTE(review): this pooled estimate still sees the whole sequence
        # (future and padded tokens included) even when ``causal`` is True.
        avg_token_repr = hidden_states.mean(dim=1)
        complexity_scores = F.softmax(self.complexity_estimator(avg_token_repr), dim=-1)

        query = self._split_heads(self.query(hidden_states), batch_size, seq_length)
        key = self._split_heads(self.key(hidden_states), batch_size, seq_length)
        value = self._split_heads(self.value(hidden_states), batch_size, seq_length)

        # Scaled dot-product scores: [batch, heads, seq, seq]
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Padding mask: large negative bias on masked key positions.
        # Casting first lets bool and integer masks work alike.
        if attention_mask is not None:
            keep = attention_mask.unsqueeze(1).unsqueeze(2).to(attention_scores.dtype)
            attention_scores = attention_scores + (1.0 - keep) * -10000.0

        # Causal mask: hide keys that lie in the future of each query. The
        # diagonal is never hidden, so every softmax row keeps a finite entry.
        if self.causal:
            positions = torch.arange(seq_length, device=attention_scores.device)
            is_future = positions.unsqueeze(0) > positions.unsqueeze(1)  # [query, key]
            attention_scores = attention_scores.masked_fill(is_future, float("-inf"))

        attention_probs = F.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)

        # Merge heads back: [batch, seq, all_head]
        context = torch.matmul(attention_probs, value)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_length, self.all_head_size)

        # Weight each mode's projection by its per-sequence score
        local_context = self.local_projector(context) * complexity_scores[:, 0].unsqueeze(1).unsqueeze(2)
        global_context = self.global_projector(context) * complexity_scores[:, 1].unsqueeze(1).unsqueeze(2)
        specialized_context = self.specialized_projector(context) * complexity_scores[:, 2].unsqueeze(1).unsqueeze(2)

        output = local_context + global_context + specialized_context
        output = self.output_linear(output)

        return output

# 2. Multi-Resolution Token Processing
class MultiResolutionTokenProcessor(nn.Module):
    """Process a sequence at full resolution and at strided, coarser resolutions.

    Level 0 applies a Linear -> LayerNorm -> GELU stack to every token. Each
    further level ``i`` downsamples the sequence with a Conv1d whose kernel and
    stride are ``2**i``, applies its own stack, upsamples again with a matching
    ConvTranspose1d and adds the result to the level-0 output.
    """

    def __init__(self, hidden_size, num_resolutions=3):
        super().__init__()
        self.num_resolutions = num_resolutions
        self.hidden_size = hidden_size

        coarse_levels = range(1, num_resolutions)

        # One token-wise processor per resolution level (level 0 = original).
        self.processors = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.GELU(),
            )
            for _ in range(num_resolutions)
        )

        # Strided convolutions that shrink the sequence by 2**level ...
        self.pooling_layers = nn.ModuleList(
            nn.Conv1d(hidden_size, hidden_size, kernel_size=2 ** level, stride=2 ** level, padding=0)
            for level in coarse_levels
        )

        # ... and transposed convolutions that expand it back.
        self.unpooling_layers = nn.ModuleList(
            nn.ConvTranspose1d(hidden_size, hidden_size, kernel_size=2 ** level, stride=2 ** level, padding=0)
            for level in coarse_levels
        )

    def forward(self, hidden_states):
        """Return the sum of all resolution branches, shape [batch, seq, hidden]."""
        _, seq_length, _ = hidden_states.size()

        # Full-resolution branch
        branches = [self.processors[0](hidden_states)]

        # Convolutions expect channels first: [batch, hidden, seq]
        channels_first = hidden_states.transpose(1, 2)

        for level in range(1, self.num_resolutions):
            factor = 2 ** level
            if seq_length < factor:
                # Sequence is shorter than one pooling window; skip this level.
                continue

            # Downsample, then process token-wise in [batch, reduced_seq, hidden]
            coarse = self.pooling_layers[level - 1](channels_first).transpose(1, 2)
            coarse = self.processors[level](coarse).transpose(1, 2)

            # Pad the coarse sequence when upsampling would fall short of seq_length
            shortfall = seq_length - coarse.size(2) * factor
            if shortfall > 0:
                coarse = F.pad(coarse, (0, shortfall))

            restored = self.unpooling_layers[level - 1](coarse)

            # Drop any positions beyond the original length
            if restored.size(2) > seq_length:
                restored = restored[:, :, :seq_length]

            branches.append(restored.transpose(1, 2))  # back to [batch, seq, hidden]

        return torch.stack(branches, dim=0).sum(dim=0)

# 3. Sparse-Dense Hybrid Layer
class SparseDenseHybridLayer(nn.Module):
    """Feed-forward block blending a dense MLP with a grouped-convolution path.

    The dense path is Linear -> GELU -> Linear. The "sparse" path uses the same
    widths but as kernel-size-1 grouped ``Conv1d`` layers, so each group of
    channels is only connected to its own group. The two outputs are mixed as
    ``(1 - sparsity_ratio) * dense + sparsity_ratio * sparse``.

    Args:
        hidden_size: input/output width.
        intermediate_size: inner width of both paths.
        sparsity_ratio: weight of the grouped path in the final mix.
        sparse_groups: number of channel groups in the grouped convolutions
            (default 8, the previously hard-coded value). ``Conv1d`` requires
            both ``hidden_size`` and ``intermediate_size`` to be divisible by it.
    """

    def __init__(self, hidden_size, intermediate_size, sparsity_ratio=0.8, sparse_groups=8):
        super(SparseDenseHybridLayer, self).__init__()
        self.sparsity_ratio = sparsity_ratio

        # Dense processing path
        self.dense_ff = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size)
        )

        # Sparse processing path (using grouped convolutions)
        self.sparse_groups = sparse_groups
        self.sparse_ff = nn.Sequential(
            nn.Conv1d(hidden_size, intermediate_size, kernel_size=1, groups=self.sparse_groups),
            nn.GELU(),
            nn.Conv1d(intermediate_size, hidden_size, kernel_size=1, groups=self.sparse_groups)
        )

    def forward(self, hidden_states):
        """Mix both paths; input and output are [batch, seq, hidden]."""
        dense_output = self.dense_ff(hidden_states)

        # Conv1d works on [batch, channels, seq], so transpose in and out.
        sparse_output = self.sparse_ff(hidden_states.transpose(1, 2)).transpose(1, 2)

        return (1 - self.sparsity_ratio) * dense_output + self.sparsity_ratio * sparse_output

# 4. Adaptive Parameter Efficiency
class AdaptiveParameterModule(nn.Module):
    """Soft mixture of feed-forward experts with one gate vector per sequence.

    Every expert (Linear -> GELU -> Linear) is run on the full input; the
    router turns the mean token representation into softmax weights, and the
    expert outputs are summed using those weights.
    """

    def __init__(self, hidden_size, num_experts=4, expert_size=None):
        super().__init__()
        # Default inner width of each expert is twice the model width.
        inner = hidden_size * 2 if expert_size is None else expert_size

        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, inner),
                nn.GELU(),
                nn.Linear(inner, hidden_size),
            )
            for _ in range(num_experts)
        )

        # Router: maps a pooled sequence representation to one logit per expert.
        self.router = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, num_experts),
        )

    def forward(self, hidden_states):
        """Return the gate-weighted sum of expert outputs, [batch, seq, hidden]."""
        # Unpacking three dimensions keeps the requirement of a 3-D input.
        _batch, _seq, _width = hidden_states.size()

        # One gate distribution per sequence: [batch, num_experts]
        pooled = hidden_states.mean(dim=1)
        gates = F.softmax(self.router(pooled), dim=-1)

        # Scale each expert's output by its gate, broadcast over seq and hidden.
        weighted = [
            expert(hidden_states) * gates[:, index].unsqueeze(1).unsqueeze(2)
            for index, expert in enumerate(self.experts)
        ]

        return torch.stack(weighted).sum(dim=0)

# 5. Advanced Transformer Layer
class AdvancedTransformerLayer(nn.Module):
    """One model layer: four residual sub-blocks applied in sequence.

    Order: dynamic attention (on layer-normed input), multi-resolution
    processing, sparse/dense feed-forward (on layer-normed input), and the
    adaptive expert mixture. Each sub-block's output is added to the stream.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size

        self.attention = DynamicAttention(
            width,
            config.num_attention_heads,
            config.attention_probs_dropout_prob,
        )
        self.multi_resolution = MultiResolutionTokenProcessor(width)
        self.sparse_dense = SparseDenseHybridLayer(width, config.intermediate_size)
        self.adaptive_params = AdaptiveParameterModule(width)

        self.attention_layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.output_layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(self, hidden_states, attention_mask=None):
        """Run the layer.

        Returns:
            ``(hidden_states, None)`` - the second element is always ``None``.
        """
        # 1) Self-attention on the normalised stream
        normed = self.attention_layernorm(hidden_states)
        hidden_states = hidden_states + self.attention(normed, attention_mask)

        # 2) Multi-resolution token processing (no layer norm in front)
        hidden_states = hidden_states + self.multi_resolution(hidden_states)

        # 3) Sparse/dense feed-forward on the normalised stream
        normed = self.output_layernorm(hidden_states)
        hidden_states = hidden_states + self.sparse_dense(normed)

        # 4) Adaptive expert mixture (no layer norm in front)
        hidden_states = hidden_states + self.adaptive_params(hidden_states)

        return hidden_states, None

# 6. Main Creative Language Model
class CreativeLanguageModel(nn.Module):
    """Language model: token + learned position embeddings, a stack of
    ``AdvancedTransformerLayer`` blocks, a final LayerNorm and an output
    projection whose weight is tied to the token embedding matrix.

    Args:
        config: object exposing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``num_hidden_layers``,
            ``layer_norm_eps`` and ``hidden_dropout_prob`` (plus whatever
            ``AdvancedTransformerLayer`` reads).
    """

    def __init__(self, config):
        super(CreativeLanguageModel, self).__init__()
        self.config = config

        # Token embeddings
        self.token_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)

        # Position embeddings
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)

        # Layer stack with our advanced transformer layers
        self.layers = nn.ModuleList([
            AdvancedTransformerLayer(config) for _ in range(config.num_hidden_layers)
        ])

        # Output normalization and projection
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights between input embeddings and output projection
        self.output_projection.weight = self.token_embeddings.weight

        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) for Linear/Embedding weights, zeros for biases,
        identity for LayerNorm. Other module types keep their default init."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Args:
            input_ids: LongTensor [batch, seq].
            attention_mask: optional [batch, seq] mask, 0 for padding.
            labels: optional LongTensor [batch, seq]; entries equal to -100
                are ignored by the loss.

        Returns:
            dict with ``'logits'`` ([batch, seq, vocab]) and ``'loss'``
            (scalar tensor, or ``None`` when no labels were passed).

        Raises:
            ValueError: if ``seq`` exceeds ``config.max_position_embeddings``
                (the position table has no entry for those positions).
        """
        batch_size, seq_length = input_ids.size()

        max_positions = self.config.max_position_embeddings
        if seq_length > max_positions:
            raise ValueError(
                f"Sequence length {seq_length} exceeds max_position_embeddings {max_positions}"
            )

        # Create position IDs
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Combine token and position embeddings
        hidden_states = self.token_embeddings(input_ids) + self.position_embeddings(position_ids)
        hidden_states = self.dropout(hidden_states)

        # Apply transformer layers
        for layer in self.layers:
            hidden_states, _ = layer(hidden_states, attention_mask)

        # Final layer norm and output projection
        hidden_states = self.layer_norm(hidden_states)
        logits = self.output_projection(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Padded targets must not contribute: without this the loss of a
            # heavily padded batch is dominated by predicting padding.
            if attention_mask is not None:
                shift_labels = shift_labels.masked_fill(attention_mask[..., 1:] == 0, -100)

            # Sum / count equals the usual mean, but stays finite (0) instead
            # of NaN when every target in the batch is ignored.
            num_targets = (shift_labels != -100).sum().clamp(min=1)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
                reduction='sum',
            ) / num_targets

        return {'logits': logits, 'loss': loss}

# 7. Data loading and preparation
def load_and_prepare_data(config):
    """Build train/validation loaders from a slice of WikiText-2.

    Takes the first 5000 training lines and the first 500 validation lines,
    drops lines that are empty after stripping, and tokenizes them with the
    GPT-2 tokenizer (its pad token is set to its EOS token, since
    ``TextDataset`` pads every sample to ``max_length``).

    Args:
        config: provides ``max_position_embeddings`` (sample length) and
            ``batch_size``.

    Returns:
        ``(train_loader, val_loader, tokenizer)``; only the training loader
        is shuffled.
    """

    def _non_blank(lines):
        # Keep only lines that contain something besides whitespace.
        return [line for line in lines if line.strip()]

    # Small public corpus from the Hugging Face hub
    corpus = load_dataset("wikitext", "wikitext-2-v1")

    # Slice first, then filter, so the kept counts can be below 5000 / 500
    train_texts = _non_blank(corpus["train"]["text"][:5000])
    val_texts = _non_blank(corpus["validation"]["text"][:500])

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    sample_length = config.max_position_embeddings
    train_dataset = TextDataset(train_texts, tokenizer, max_length=sample_length)
    val_dataset = TextDataset(val_texts, tokenizer, max_length=sample_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

    return train_loader, val_loader, tokenizer

# 8. Training loop with memory optimization
def train_model(model, train_loader, val_loader, config):
    """Train ``model`` with gradient accumulation and optional CUDA mixed precision.

    After every epoch the model is evaluated on ``val_loader`` and, when the
    mean validation loss improves, a checkpoint is written to
    ``best_creative_llm_model.pt``.

    Args:
        model: module whose forward returns a dict with a ``'loss'`` entry
            when called with ``input_ids``, ``attention_mask`` and ``labels``.
        train_loader: iterable of batches (dicts with ``'input_ids'``,
            ``'attention_mask'``, ``'labels'``) that supports ``len()``.
        val_loader: same format, used for evaluation only.
        config: provides ``learning_rate``, ``epochs``,
            ``gradient_accumulation_steps``, ``max_grad_norm`` and
            ``use_mixed_precision``. ``config.warmup_steps`` is not read here;
            the LR shape is whatever ``OneCycleLR`` produces.

    Returns:
        The trained model (same object, moved to the selected device).
    """
    import gc  # used for the per-epoch cleanup; not imported at module level

    # Move the model first so the optimizer is built on the final parameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Mixed precision only makes sense on CUDA; on CPU run plain FP32
    use_amp = bool(config.use_mixed_precision) and device.type == "cuda"

    accum_steps = max(1, config.gradient_accumulation_steps)
    num_train_batches = len(train_loader)
    num_val_batches = len(val_loader)

    # Every accumulation group - including a shorter final one - produces one
    # optimizer step, so round up. At least 1 keeps OneCycleLR constructible.
    updates_per_epoch = max(1, math.ceil(num_train_batches / accum_steps))

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=config.learning_rate,
        steps_per_epoch=updates_per_epoch,
        epochs=config.epochs
    )

    # With enabled=False the scaler's scale/unscale_/step/update are
    # pass-throughs, so one code path serves both precisions.
    scaler = GradScaler(enabled=use_amp)

    best_val_loss = float('inf')

    for epoch in range(config.epochs):
        # ---- Training ----
        model.train()
        train_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.epochs} [Train]")
        optimizer.zero_grad()

        for i, batch in enumerate(progress_bar):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Size of the accumulation group this micro-batch belongs to; the
            # last group of an epoch may be shorter than accum_steps.
            group_start = (i // accum_steps) * accum_steps
            group_size = min(accum_steps, num_train_batches - group_start)

            with autocast(enabled=use_amp):
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                batch_loss = outputs['loss']

            # Average (not sum) the gradients over the group
            scaler.scale(batch_loss / group_size).backward()

            is_last_batch = (i + 1) == num_train_batches
            if (i + 1) % accum_steps == 0 or is_last_batch:
                # Unscale before clipping so the threshold applies to true gradients
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                scaler.step(optimizer)
                scaler.update()

                optimizer.zero_grad()
                scheduler.step()

            train_loss += batch_loss.item()
            progress_bar.set_postfix({'loss': train_loss / (i + 1)})

        # Release cached memory once per epoch; doing this on every
        # micro-batch only slows the loop down.
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        # ---- Validation ----
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.epochs} [Val]")

            for i, batch in enumerate(progress_bar):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

                val_loss += outputs['loss'].item()
                progress_bar.set_postfix({'loss': val_loss / (i + 1)})

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        # An empty validation set can never count as an improvement
        avg_val_loss = val_loss / num_val_batches if num_val_batches else float('inf')

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
            }, 'best_creative_llm_model.pt')

        print(f"Epoch {epoch+1}/{config.epochs} - Train Loss: {train_loss / max(1, num_train_batches):.4f} - Val Loss: {avg_val_loss:.4f}")

    return model


# 9. Text generation function
def generate_text(model, tokenizer, prompt, max_length=100, temperature=1.0, top_k=50, top_p=0.95):
    """Sample a continuation of ``prompt`` one token at a time.

    Args:
        model: module returning a dict with ``'logits'`` ([batch, seq, vocab])
            when called as ``model(ids, attention_mask=mask)``. If it exposes
            ``model.config.max_position_embeddings`` only that many most
            recent tokens are fed back in.
        tokenizer: provides ``encode(prompt, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``.
        prompt: text to continue.
        max_length: number of new tokens to sample.
        temperature: logits are divided by this value; must be > 0.
        top_k: keep only the ``top_k`` highest-scoring tokens (0 disables).
        top_p: nucleus sampling - keep the smallest set of top tokens whose
            cumulative probability reaches ``top_p`` (values outside (0, 1)
            or ``None`` disable it).

    Returns:
        The decoded prompt followed by the sampled tokens.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    model.eval()
    device = next(model.parameters()).device

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

    # The position table is finite, so never feed more tokens than it can index
    max_context = getattr(getattr(model, "config", None), "max_position_embeddings", None)

    sequence = input_ids.clone()

    with torch.no_grad():
        for _ in range(max_length):
            context = sequence if max_context is None else sequence[:, -max_context:]
            attention_mask = torch.ones_like(context)

            outputs = model(context, attention_mask=attention_mask)

            # Logits of the last position, in float32 for stable sampling
            next_token_logits = outputs['logits'][:, -1, :].float() / temperature

            # Top-k: drop everything scoring strictly below the k-th best token
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                kth_best = torch.topk(next_token_logits, k=k)[0][:, -1].unsqueeze(-1)
                next_token_logits = next_token_logits.masked_fill(
                    next_token_logits < kth_best, float('-inf')
                )

            # Top-p: drop a token once the probability mass of the tokens
            # ranked before it already exceeds top_p. The best token has no
            # mass before it, so it always survives.
            if top_p is not None and 0.0 < top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True, dim=-1)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
                sorted_remove = mass_before > top_p
                remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
                next_token_logits = next_token_logits.masked_fill(remove, float('-inf'))

            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            sequence = torch.cat((sequence, next_token), dim=1)

    # Decode prompt + continuation (first batch row)
    return tokenizer.decode(sequence[0], skip_special_tokens=True)

# 10. Main function to run everything
def main():
    """End-to-end run: load data, build the model, train it, print a few
    sampled continuations and save the final weights to
    ``creative_llm_model.pt``."""
    config = Config()

    # Start from an empty CUDA cache
    torch.cuda.empty_cache()

    train_loader, val_loader, tokenizer = load_and_prepare_data(config)

    model = CreativeLanguageModel(config)

    # Report the model size
    total_params = 0
    for parameter in model.parameters():
        total_params += parameter.numel()
    print(f"Model created with {total_params:,} total parameters")

    model = train_model(model, train_loader, val_loader, config)

    # Sample a short continuation for each demo prompt
    prompts = (
        "In a world where AI has become",
        "The solution to climate change is",
        "Once upon a time in a distant galaxy",
    )

    print("\n=== Generated Text Samples ===")
    for prompt in prompts:
        generated_text = generate_text(model, tokenizer, prompt, max_length=50)
        print(f"\nPrompt: {prompt}")
        print(f"Generated: {generated_text}")

    # Persist the final weights together with the config object
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'config': config,
    }
    torch.save(checkpoint, 'creative_llm_model.pt')

# Define the nullcontext class for Python < 3.7
class nullcontext:
    """Context manager that does nothing.

    ``with nullcontext(value) as x`` binds ``x`` to ``value``; exceptions
    raised inside the block are not suppressed.
    """

    def __init__(self, enter_result=None):
        self.enter_result = enter_result

    def __enter__(self):
        return self.enter_result

    def __exit__(self, *excinfo):
        # Returning None (falsy) lets any in-flight exception propagate.
        return None

# Run the full pipeline only when this file is executed as a script
# (not when it is imported as a module).
if __name__ == "__main__":
    main()

Model created with 321,406,534 total parameters


  scaler = GradScaler() if config.use_mixed_precision else None
  with autocast() if config.use_mixed_precision else nullcontext():
Epoch 1/3 [Train]: 100%|██████████| 3227/3227 [30:57<00:00,  1.74it/s, loss=1.01]
Epoch 1/3 [Val]: 100%|██████████| 337/337 [02:15<00:00,  2.48it/s, loss=0.31]


Epoch 1/3 - Train Loss: 1.0132 - Val Loss: 0.3104


Epoch 2/3 [Train]: 100%|██████████| 3227/3227 [30:31<00:00,  1.76it/s, loss=0.293]
Epoch 2/3 [Val]: 100%|██████████| 337/337 [02:14<00:00,  2.51it/s, loss=0.196]


Epoch 2/3 - Train Loss: 0.2929 - Val Loss: 0.1961


Epoch 3/3 [Train]:  99%|█████████▉| 3189/3227 [30:00<00:22,  1.72it/s, loss=0.207]

In [None]:
# --- extract_weights.py ---
# PURPOSE: Load a PyTorch .pt model, extract weights/config, save as .npz/.json
# REQUIREMENT: Must be run in an environment WITH PyTorch installed.

import torch
import numpy as np
import os
import json
from collections import OrderedDict

# === IMPORTANT: Define Config and Model classes EXACTLY as in your training script ===
# (Copy ALL the class definitions from your original training script here:
# Config, DynamicAttention, MultiResolutionTokenProcessor, SparseDenseHybridLayer,
# AdaptiveParameterModule, AdvancedTransformerLayer, CreativeLanguageModel)
# --- Start of pasted classes ---

# Configuration (Should match the one saved in your .pt file)
class Config:
    """Architecture settings used to rebuild the model for weight extraction.

    Values should match the configuration stored in the ``.pt`` file.
    """

    # --- Core architecture ---
    vocab_size = 50257
    max_position_embeddings = 1024
    hidden_size = 768
    intermediate_size = 3072
    num_hidden_layers = 10
    num_attention_heads = 12
    layer_norm_eps = 1e-12

    # --- Dropout (not used in inference, but part of the class definition) ---
    hidden_dropout_prob = 0.1
    attention_probs_dropout_prob = 0.1

    # --- Extra constructor arguments of the custom sub-modules ---
    # Defaults in case they are missing from a raw state dict.
    num_resolutions = 3
    sparsity_ratio = 0.8
    sparse_groups = 8
    num_experts = 4

import math
import torch.nn as nn
import torch.nn.functional as F

# --- Paste ALL your model class definitions here ---
# (DynamicAttention, MultiResolutionTokenProcessor, SparseDenseHybridLayer,
#  AdaptiveParameterModule, AdvancedTransformerLayer, CreativeLanguageModel)
# --- Make sure they are identical to the training script used ---
# --- to generate creative_llm_model.pt                  ---

# Example (YOU NEED TO PASTE THE FULL DEFINITIONS)
class DynamicAttention(nn.Module):
    """Parameter-only mirror of the training-time attention module.

    Only ``__init__`` is defined: the class exists so a saved state dict can
    be loaded into modules with matching names and shapes.
    """

    def __init__(self, hidden_size, num_attention_heads, dropout_prob=0.1):
        super().__init__()
        head_dim = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = head_dim
        self.all_head_size = num_attention_heads * head_dim

        # Q / K / V projections
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        # Three-way mode scorer
        reduced = hidden_size // 4
        self.complexity_estimator = nn.Sequential(
            nn.Linear(hidden_size, reduced),
            nn.ReLU(),
            nn.Linear(reduced, 3),
        )

        # One projector per attention mode
        self.local_projector = nn.Linear(hidden_size, hidden_size)
        self.global_projector = nn.Linear(hidden_size, hidden_size)
        self.specialized_projector = nn.Linear(hidden_size, hidden_size)

        self.dropout = nn.Dropout(dropout_prob)  # Dropout ignored in eval()
        self.output_linear = nn.Linear(hidden_size, hidden_size)
    # --- forward method NOT needed for weight extraction ---

class MultiResolutionTokenProcessor(nn.Module):
    """Parameter-only mirror of the multi-resolution block (no forward).

    Holds one Linear/LayerNorm/GELU stack per resolution plus the strided
    Conv1d / ConvTranspose1d pairs for every level above the first.
    """

    def __init__(self, hidden_size, num_resolutions=3):
        super().__init__()
        self.num_resolutions = num_resolutions
        self.hidden_size = hidden_size

        coarse_levels = range(1, num_resolutions)

        self.processors = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.GELU(),
            )
            for _ in range(num_resolutions)
        )
        # Kernel size and stride are both 2**level
        self.pooling_layers = nn.ModuleList(
            nn.Conv1d(hidden_size, hidden_size, kernel_size=2 ** level, stride=2 ** level, padding=0)
            for level in coarse_levels
        )
        self.unpooling_layers = nn.ModuleList(
            nn.ConvTranspose1d(hidden_size, hidden_size, kernel_size=2 ** level, stride=2 ** level, padding=0)
            for level in coarse_levels
        )
    # --- forward method NOT needed for weight extraction ---

class SparseDenseHybridLayer(nn.Module):
    """Parameter-only mirror of the sparse/dense feed-forward block (no forward)."""

    def __init__(self, hidden_size, intermediate_size, sparsity_ratio=0.8, sparse_groups=8):
        super().__init__()
        # Mixing weight of the grouped path; may be needed by a NumPy re-implementation
        self.sparsity_ratio = sparsity_ratio

        # Dense path: plain two-layer MLP
        self.dense_ff = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size),
        )

        # Sparse path: kernel-size-1 grouped convolutions
        self.sparse_groups = sparse_groups
        self.sparse_ff = nn.Sequential(
            nn.Conv1d(hidden_size, intermediate_size, kernel_size=1, groups=sparse_groups),
            nn.GELU(),
            nn.Conv1d(intermediate_size, hidden_size, kernel_size=1, groups=sparse_groups),
        )
    # --- forward method NOT needed for weight extraction ---

class AdaptiveParameterModule(nn.Module):
    """Parameter-only mirror of the expert-mixture block (no forward)."""

    def __init__(self, hidden_size, num_experts=4, expert_size=None):
        super().__init__()
        # Default expert width is twice the model width
        inner = hidden_size * 2 if expert_size is None else expert_size
        self.num_experts = num_experts  # Needed

        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_size, inner),
                nn.GELU(),
                nn.Linear(inner, hidden_size),
            )
            for _ in range(num_experts)
        )

        # Router producing one logit per expert
        half = hidden_size // 2
        self.router = nn.Sequential(
            nn.Linear(hidden_size, half),
            nn.ReLU(),
            nn.Linear(half, num_experts),
        )
    # --- forward method NOT needed for weight extraction ---

class AdvancedTransformerLayer(nn.Module):
    """Parameter-only mirror of one model layer (no forward).

    Optional settings are read from ``config`` with ``getattr`` so a config
    object lacking them falls back to the listed defaults.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size

        # Optional knobs, with defaults when absent from the config
        num_resolutions = getattr(config, 'num_resolutions', 3)
        sparsity_ratio = getattr(config, 'sparsity_ratio', 0.8)
        sparse_groups = getattr(config, 'sparse_groups', 8)
        num_experts = getattr(config, 'num_experts', 4)

        self.attention = DynamicAttention(
            width, config.num_attention_heads, config.attention_probs_dropout_prob
        )
        self.multi_resolution = MultiResolutionTokenProcessor(width, num_resolutions)
        self.sparse_dense = SparseDenseHybridLayer(
            width, config.intermediate_size, sparsity_ratio, sparse_groups
        )
        self.adaptive_params = AdaptiveParameterModule(width, num_experts)

        self.attention_layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.output_layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
    # --- forward method NOT needed for weight extraction ---

class CreativeLanguageModel(nn.Module):
    """Top-level parameter container for the language model.

    Declares token and position embeddings, a stack of
    ``AdvancedTransformerLayer`` modules, a final LayerNorm, a bias-free
    output projection whose weight is tied to the token embedding, and a
    dropout module. ``_init_weights`` and ``forward`` are intentionally
    absent: they are not needed for weight extraction.

    Args:
        config: Object exposing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``num_hidden_layers``,
            ``layer_norm_eps`` and ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config  # Store config object
        vocab = config.vocab_size
        width = config.hidden_size

        self.token_embeddings = nn.Embedding(vocab, width)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)

        layer_stack = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            layer_stack.append(AdvancedTransformerLayer(config))
        self.layers = layer_stack

        self.layer_norm = nn.LayerNorm(width, eps=config.layer_norm_eps)

        # Weight tying: the output projection shares the very same Parameter
        # object as the input embedding table.
        self.output_projection = nn.Linear(width, vocab, bias=False)
        self.output_projection.weight = self.token_embeddings.weight

        self.dropout = nn.Dropout(config.hidden_dropout_prob)  # Ignored in eval()
        # Weight initialisation (self.apply(self._init_weights)) is skipped:
        # it is not needed when the weights are loaded from a checkpoint.

    # --- _init_weights and forward method NOT needed for weight extraction ---

# --- End of pasted model classes ---

# --- Configuration ---
# Input checkpoint. The default is an Android-style shared "Download" folder;
# edit this path when running the script elsewhere.
MODEL_PATH = '/storage/emulated/0/Download/creative_llm_model.pt'

# Output files will be saved in the current directory (relative paths)
OUTPUT_NPZ_PATH = 'creative_llm_weights.npz'
CONFIG_PATH = 'creative_llm_config.json'

# --- Extraction Logic ---
if not os.path.exists(MODEL_PATH):
    # Provide a more specific error if the Android-like path isn't found
    if MODEL_PATH.startswith('/storage/emulated/0/'):
        print(f"Error: Model file not found at '{MODEL_PATH}'.")
        print("Ensure the file exists in your device's 'Download' folder")
        print("and that this script has permission to read it.")
    else:
        print(f"Error: Model file not found: {MODEL_PATH}")
    # Non-zero status so callers/shells can detect the failure; SystemExit is
    # raised directly because the `exit()` helper only exists when the `site`
    # module is loaded.
    raise SystemExit(1)

print(f"Loading model checkpoint from: {MODEL_PATH}")
# Always load onto the CPU: every tensor is converted to a NumPy array below,
# which requires host memory anyway. Mapping to CUDA first would only waste
# VRAM (and can fail with an out-of-memory error on small GPUs).
map_location = torch.device('cpu')
print(f"Using map_location: {map_location}")

try:
    # SECURITY NOTE: torch.load unpickles the file and can execute arbitrary
    # code embedded in it. Only load checkpoints from a trusted source.
    checkpoint = torch.load(MODEL_PATH, map_location=map_location)
except Exception as e:
    # Top-level boundary: report the problem and stop with a failure status.
    print(f"\n--- Error loading checkpoint with torch.load: {e} ---")
    print("This might happen due to:")
    print("  - Corrupted file.")
    print("  - Version mismatch between PyTorch used for saving and loading.")
    print("  - Insufficient RAM/permissions.")
    raise SystemExit(1)

# Determine if checkpoint is state_dict itself or a dictionary containing it
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model_state_dict = checkpoint['model_state_dict']
    print("Loaded 'model_state_dict' from checkpoint dictionary.")
    # Try to load config from the same dictionary
    config_data = checkpoint.get('config')
    if config_data:
        print("Loaded 'config' from checkpoint dictionary.")
    else:
        print("Warning: 'config' not found in checkpoint dictionary. Using Config class defined above.")
        config_data = Config()  # Fallback to class definition
elif isinstance(checkpoint, OrderedDict):
    model_state_dict = checkpoint
    print("Loaded raw state_dict (checkpoint was likely the state_dict itself).")
    print("Warning: 'config' not found in raw state_dict. Using Config class defined above.")
    config_data = Config()  # Fallback to class definition
else:
    print("\n--- Error: Unexpected checkpoint format. ---")
    print("Expected a state_dict (OrderedDict) or a dict containing 'model_state_dict'.")
    print(f"Got type: {type(checkpoint)}")
    raise SystemExit(1)


def _config_attrs(source):
    """Return the public, non-callable attributes of *source* as a dict.

    Attribute names are gathered from every class in the object's MRO and from
    the instance itself, so settings declared as plain class attributes (which
    never appear in ``instance.__dict__``) are included. Values are resolved
    with ``getattr`` so instance overrides win over class defaults.
    """
    names = []
    for klass in reversed(type(source).__mro__):
        names.extend(vars(klass))
    names.extend(getattr(source, '__dict__', {}))

    attrs = {}
    for name in names:
        if name.startswith('__'):
            continue
        value = getattr(source, name)
        if not callable(value):
            attrs[name] = value
    return attrs


# Convert config object/class to a serializable dict
if isinstance(config_data, Config):  # If it's the class instance
    config_dict = _config_attrs(config_data)
elif isinstance(config_data, dict):  # If it was already a dict
    # Copy so that filling in defaults below does not mutate the checkpoint.
    config_dict = dict(config_data)
else:
    print("Warning: Could not determine config format. Saving empty config.")
    config_dict = {}

# Ensure core parameters are present in the final config_dict, using defaults if needed
for key, value in _config_attrs(Config()).items():
    if key not in config_dict:
        print(f"Config Warning: Adding missing parameter '{key}' with default value '{value}'")
        config_dict[key] = value


# Convert state_dict tensors to NumPy arrays
numpy_weights = OrderedDict()
print("\nConverting weights to NumPy arrays...")
for key, tensor in model_state_dict.items():
    print(f"  Converting: {key} | Shape: {tensor.shape} | Dtype: {tensor.dtype}")
    # detach(): .numpy() refuses tensors that still require grad.
    # cpu():    NumPy arrays can only view host memory.
    host_tensor = tensor.detach().cpu()
    if host_tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16 dtype; upcast losslessly to float32.
        host_tensor = host_tensor.float()
    numpy_weights[key] = host_tensor.numpy()

print(f"\nSaving {len(numpy_weights)} weight arrays to: {OUTPUT_NPZ_PATH}")
# Each state_dict key becomes one named array inside the .npz archive.
np.savez(OUTPUT_NPZ_PATH, **numpy_weights)

print(f"Saving configuration to: {CONFIG_PATH}")
# Serialise to a string first and only then touch the file, so a failed
# serialisation can never leave a truncated/partial JSON file behind.
try:
    config_json = json.dumps(config_dict, indent=4)
except TypeError as e:
    print(f"\n--- Error saving config to JSON: {e} ---")
    print("Some value in the config might not be JSON serializable.")
    print("Saving simplified config.")
    # Fallback: stringify keys and values so serialisation cannot fail on type.
    simplified_config = {str(k): str(v) for k, v in config_dict.items()}
    config_json = json.dumps(simplified_config, indent=4)

with open(CONFIG_PATH, 'w', encoding='utf-8') as f:
    f.write(config_json)


print("\n--- Weight Extraction Complete ---")
print(f"NumPy weights saved to: '{OUTPUT_NPZ_PATH}'")
print(f"Config saved to: '{CONFIG_PATH}'")
print("You can now use these files with the NumPy-based inference script.")
print("Ensure PyTorch is NOT required by the next script.")