In [None]:
!pip install transformers torch wandb datasets tqdm huggingface_hub bitsandbytes

In [None]:
from huggingface_hub import login

# Login with your token
login("your token")

import wandb
wandb.login(key="your token")

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import LlamaForCausalLM, PreTrainedTokenizerFast
import wandb
from tqdm import tqdm
from datasets import load_dataset
from torch.utils.data import DataLoader
import re

import bitsandbytes as bnb

class LatentThoughtLLM(nn.Module):
    def __init__(self, model_name="meta-llama/Llama-3.2-1B-Instruct"):
        super().__init__()

        self.model = LlamaForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
        )

        self.hidden_size = self.model.config.hidden_size
        self.tokenizer = PreTrainedTokenizerFast.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
        # Normalization layers
        self.rms_norm = nn.LayerNorm(self.hidden_size, elementwise_affine=False)

    def forward(self, input_ids, attention_mask=None, labels=None, prompt_mask=None, alpha_schedule=0.0):
        # First pass to get hidden states
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
                output_hidden_states=True
            )
        
        # Get base embeddings and hidden states
        token_embeddings = self.model.get_input_embeddings()(input_ids)
        last_hidden = outputs.hidden_states[-1]
        
        # Normalize hidden states
        normalized_hidden = self.rms_norm(last_hidden)
        scale_factor = token_embeddings.norm(dim=-1).mean() / normalized_hidden.norm(dim=-1).mean()
        normalized_hidden = normalized_hidden * scale_factor
        
        # Shift hidden states left by one position
        normalized_hidden = torch.cat([normalized_hidden[:, 1:], normalized_hidden[:, -1:]], dim=1)

        mixed_embeddings = token_embeddings.clone()
        is_reasoning_token = prompt_mask.bool() & attention_mask.bool() 
        
        # Mix embeddings only for reasoning tokens
        mixed_embeddings[is_reasoning_token] = (
            alpha_schedule * normalized_hidden[is_reasoning_token] + 
            (1 - alpha_schedule) * token_embeddings[is_reasoning_token]
        )

        # Debug prints
        print(f"\nBefore mixing:")
        print(f"First hidden norm: {token_embeddings.norm(dim=-1).mean():.3f} std: {token_embeddings.std(dim=-1).mean():.3f}")
        print(f"Normalized hidden states norm: {normalized_hidden.norm(dim=-1).mean():.3f} std: {normalized_hidden.std(dim=-1).mean():.3f}")
        print(f"Alpha: {alpha_schedule:.3f}")
        print(f"Number of tokens to mix: {is_reasoning_token.bool().sum().item()}")
        print(f"Mixed embeddings norm: {mixed_embeddings[is_reasoning_token].norm(dim=-1).mean():.3f} std: {mixed_embeddings[is_reasoning_token].std(dim=-1).mean():.3f}")

        # Forward through model again with mixed embeddings
        final_outputs = self.model(
            inputs_embeds=mixed_embeddings,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True
        )

        final_hidden = final_outputs.hidden_states[-1]
        
        return final_outputs, last_hidden, final_hidden, is_reasoning_token



def LatentThoughtLoss(outputs, hidden_states, target_embeddings, is_reasoning_token):
    # Standard LM loss from outputs, only for reasoning tokens
    ce_loss = outputs.loss
    
    # Embedding matching loss only for reasoning tokens
    mse_loss = F.mse_loss(
        hidden_states[is_reasoning_token],
        target_embeddings[is_reasoning_token]
    )
    
    cos_loss = 1 - F.cosine_similarity(
        hidden_states[is_reasoning_token],
        target_embeddings[is_reasoning_token],
        dim=-1
    ).mean()

    clipped_mse = min(mse_loss, 2)
    embed_loss = 0.5 * clipped_mse + 1 * cos_loss
    
    # Combine losses
    total_loss = ce_loss + embed_loss
    
    return {
        'total_loss': total_loss,
        'ce_loss': ce_loss,
        'mse_loss': mse_loss,
        'cos_loss': cos_loss
    }

def extract_boxed_answer(text):
    """Extract answer from \boxed{} format"""
    match = re.search(r'boxed{(.*?)}', text)
    return match.group(1) if match else None

def prepare_dataset(tokenizer, max_length=1024):
    dataset = load_dataset("ant-des/filtered_reasoning_deepseek", split='train')
    
    def process_example(example):
        prompt = example['messages'][0]['content']
        answer = example['answer']
        
        # Tokenize prompt and answer separately
        prompt_tokens = tokenizer(prompt, add_special_tokens=True)  # Include BOS token
        answer_tokens = tokenizer(answer, add_special_tokens=False)  # Don't add extra tokens
        
        # Combine and truncate if needed
        combined_input_ids = prompt_tokens['input_ids'] + answer_tokens['input_ids']
        if len(combined_input_ids) > max_length:
            combined_input_ids = combined_input_ids[:max_length]
        
        # Pad to max_length
        padding_length = max_length - len(combined_input_ids)
        input_ids = combined_input_ids + [tokenizer.pad_token_id] * padding_length
        
        # Create attention mask (1 for real tokens, 0 for padding)
        attention_mask = [1] * len(combined_input_ids) + [0] * padding_length
        
        # Create prompt mask (0 for prompt, 1 for answer)
        prompt_mask = [0] * len(prompt_tokens['input_ids']) + [1] * (len(answer_tokens['input_ids']))
        prompt_mask = prompt_mask + [0] * padding_length  # Add zeros for padding
        
        # Create labels (-100 for prompt and padding)
        labels = [-100] * len(prompt_tokens['input_ids'])  # Mask prompt
        labels.extend(answer_tokens['input_ids'])  # Add answer tokens as labels
        labels.extend([-100] * padding_length)  # Mask padding
        
        # Ensure all tensors are the same length
        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
        prompt_mask = prompt_mask[:max_length]
        labels = labels[:max_length]
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'prompt_mask': torch.tensor(prompt_mask),
            'labels': torch.tensor(labels)
        }

    processed_dataset = dataset.map(
        process_example,
        remove_columns=dataset.column_names
    )

    processed_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'prompt_mask', 'labels'])
    return processed_dataset


def train_step(batch, model, step):
    # Calculate alpha for progressive schedule
    alpha = min(step / 1000, 1.0)  # Progressive schedule over 1000 steps
    #alpha = 0.0

    # Move batch to device
    batch = {k: v.to(model.device) for k, v in batch.items()}
    
    # Forward pass
    outputs, hidden_states, final_hidden, is_reasoning_token = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        labels=batch['labels'],
        prompt_mask=batch['prompt_mask'],
        alpha_schedule=alpha
    )
    # Calculate loss
    loss_dict = LatentThoughtLoss(
        outputs=outputs,
        hidden_states=hidden_states,
        target_embeddings=final_hidden,
        is_reasoning_token=is_reasoning_token
    )

    # Backward pass
    loss_dict['total_loss'].backward()
    
    # Log metrics
    metrics = {
        'train/loss': loss_dict['total_loss'].item(),
        'train/ce_loss': loss_dict['ce_loss'].item(),
        'train/mse_loss': loss_dict['mse_loss'].item(),
        'train/cos_loss': loss_dict['cos_loss'].item(),
        'train/alpha': alpha,
    }
    
    return metrics

def main():
    # Initialize wandb
    wandb.init(project="latent-thought-cot", name="first-experiment")
    
    model = LatentThoughtLLM("meta-llama/Llama-3.2-1B-Instruct")
    
    # Optimizer setup
    #optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    optimizer = bnb.optim.AdamW8bit(
        model.parameters(),
        lr=1e-5,
        betas=(0.9, 0.95)
    )
    
    # Load dataset
    train_dataset = prepare_dataset(model.tokenizer, max_length=1024)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
    
    # Training loop
    num_epochs = 10
    global_step = 0
    max_grad_norm = 1.0
    gradient_accumulation_steps = 8 

    for epoch in range(num_epochs):

        model.train()
        
        # Training loop with tqdm
        progress_bar = tqdm(
            train_loader,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            leave=True
        )
        
        for batch_idx, batch in enumerate(progress_bar):
           
            # Training step
            metrics = train_step(batch, model, global_step)
            torch.cuda.empty_cache()
            
            # Log to wandb
            wandb.log(metrics)
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{metrics['train/loss']:.4f}",
                'ce_loss': f"{metrics['train/ce_loss']:.4f}",
                'mse_loss': f"{metrics['train/mse_loss']:.4f}",
                'cos_loss': f"{metrics['train/cos_loss']:.4f}",
                'alpha': f"{metrics['train/alpha']:.2f}"
            })
            
            global_step += 1

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()

            
            # Debug prints every 100 steps
            if global_step % 50 == 0:
                model.eval()
                with torch.no_grad():
                    # Move sample batch to device
                    sample_batch = {k: v[:1].to(model.device) for k, v in batch.items()}

                    # Get non-padding tokens (move to same device as output_tokens)
                    valid_tokens = sample_batch['attention_mask'][0].bool().to(model.device)
                    
                    print("\nOriginal prompt:")
                    print(model.tokenizer.decode(sample_batch['input_ids'][0][valid_tokens]))
                    
                    sample_output = model(
                        input_ids=sample_batch['input_ids'],
                        attention_mask=sample_batch['attention_mask'],
                        prompt_mask=sample_batch['prompt_mask'],
                        alpha_schedule=metrics['train/alpha']
                    )
                    
                    # Decode and print sample (only non-padding tokens)
                    output_tokens = sample_output[0].logits[0].argmax(dim=-1)
                    decoded = model.tokenizer.decode(
                        output_tokens[valid_tokens].cpu(),
                        skip_special_tokens=True
                    )
                    print("\nSample output:")
                    print(decoded)
                    print("\nExtracted answer:", extract_boxed_answer(decoded))
                
                model.train()
                torch.cuda.empty_cache()

if __name__ == "__main__":
    main()