In [None]:
#|default_exp lsd
#|export

from perturbative_llm_cognition import core

import torch
from torch import nn

In [None]:
#|export

class LSDPerturbedLLM:
    """Causal LM whose attention is perturbed in a band of decoder layers, on an adaptive leash.

    At each generation step the model is run twice on the same tokens: once
    unperturbed ("base" / teacher) and once with the attention-perturbation
    hooks installed. The Jensen-Shannon divergence between the two next-token
    distributions is smoothed with an EMA and mapped through a sigmoid gain:
    as drift exceeds ``js_tolerance`` the perturbation strengths shrink
    towards neutral and the sampled logits are blended more heavily towards
    the base logits.

    Perturbations, applied to layers ``layer_start`` .. ``layer_end`` (inclusive):

    * attention temperature -- the output of ``q_proj`` is divided by
      ``attention_temperature``, which divides the pre-softmax attention
      scores by it (> 1 flattens attention);
    * diagonal penalty -- at single-token decode steps,
      ``attention_diagonal_penalty`` is subtracted from the additive attention
      mask at the current token's own key, lowering that key's score.

    NOTE(review): the hooks assume a Llama-style module layout
    (``model.model.layers[i].self_attn.q_proj``), no normalisation of the
    queries after ``q_proj``, and that decoder layers call ``self_attn`` with
    ``hidden_states`` / ``attention_mask`` / ``cache_position`` keyword
    arguments and a floating-point additive mask (eager attention). Confirm
    these for the model returned by ``core.load_tokenizer_and_model``.
    """

    def __init__(
        self,
        layer_start: int = 21,                              # First layer to perturb (inclusive)
        layer_end: int = 30,                                # Final layer to perturb (inclusive)
        attention_temperature_target: float = 1.20,         # Temperature > 1 flattens attention (scores /= temperature)
        attention_diagonal_penalty_target: float = 0.40,    # Subtracted from the most-recent key's score at decode (q_len == 1)
        swiglu_skew_target: float = 0.15,                   # Magnitude skew exponent (stored; not applied yet)
        swiglu_noise_target: float = 0.12,                  # Zero-mean structured noise scale (stored; not applied yet)
        js_tolerance: float = 0.12,                         # Tolerated JS divergence before the leash tightens
        js_softness: float = 0.02,                          # Width of the sigmoid transition around js_tolerance
        divergence_smoothing: float = 0.8,                  # EMA factor for the JS divergence, clamped to [0, 1]
        teacher_blend_min: float = 0.05,                    # Min teacher blend weight
        teacher_blend_max: float = 0.60,                    # Max teacher blend weight when drift is high
        verbose: bool = False,                              # Print per-step diagnostics
        ):
        """Validate the leash configuration, then load the tokenizer and model.

        Raises:
            ValueError: if ``layer_end < layer_start``, if no layer of the loaded
                model falls in the requested range, or if
                ``teacher_blend_min > teacher_blend_max``.
        """

        self.layer_start = max(layer_start, 0)
        self.layer_end = layer_end
        if self.layer_end < self.layer_start:
            raise ValueError(f'layer_end ({layer_end}) must be >= layer_start ({self.layer_start})')

        self.attention_temperature_target = max(1.0, attention_temperature_target)
        self.attention_diagonal_penalty_target = max(0.0, attention_diagonal_penalty_target)
        self.swiglu_skew_target = max(0.0, swiglu_skew_target)
        self.swiglu_noise_target = max(0.0, swiglu_noise_target)
        self.js_tolerance = max(0.0, js_tolerance)
        # Strictly positive: it is the divisor of the sigmoid argument in step().
        self.js_softness = max(1e-6, js_softness)
        # An EMA factor outside [0, 1] would make running_divergence diverge.
        self.divergence_smoothing = min(1.0, max(0.0, divergence_smoothing))
        self.teacher_blend_min = min(1.0, max(0.0, teacher_blend_min))
        self.teacher_blend_max = min(1.0, teacher_blend_max)

        if self.teacher_blend_min > self.teacher_blend_max:
            raise ValueError('Invalid teacher parameters')

        self.verbose = verbose
        self.running_divergence = 0.0

        # TODO: swiglu_skew_target / swiglu_noise_target are stored but no
        # SwiGLU perturbation is implemented yet.
        self.set_perturbation_parameters(
            attention_temperature=self.attention_temperature_target,
            attention_diagonal_penalty=self.attention_diagonal_penalty_target,
            teacher_blend=self.teacher_blend_min,
        )

        self.tokenizer, self.model = core.load_tokenizer_and_model()
        # NOTE(review): which attribute takes effect depends on the installed
        # transformers version; both are set to request eager attention.
        self.model.config.attn_implementation = 'eager'
        self.model.config._attn_implementation = 'eager'
        self.device = self.model.device

        # Clamp the (inclusive) final layer to the model's actual depth.
        num_layers = len(self.model.model.layers)
        self.layer_end = min(self.layer_end, num_layers - 1)
        if self.layer_start > self.layer_end:
            raise ValueError(
                f'layer_start ({self.layer_start}) is beyond the last layer ({num_layers - 1})')

        self._store_original_attention_functions()

    def _log(self, message):
        """Print ``message`` only when the instance was created with ``verbose=True``."""
        if getattr(self, 'verbose', False):
            print(message)

    def set_perturbation_parameters(self, **kwargs):
        """Set perturbation parameters as attributes, clamped to valid ranges.

        ``attention_temperature`` is clamped to >= 1, ``teacher_blend`` to
        [0, 1], and every other parameter to >= 0.
        """

        for parameter, value in kwargs.items():
            if parameter == 'attention_temperature':
                setattr(self, parameter, max(1.0, value))
            elif parameter == 'teacher_blend':
                setattr(self, parameter, max(0.0, min(1.0, value)))
            else:
                setattr(self, parameter, max(0.0, value))

    def _store_original_attention_functions(self):
        """Initialise bookkeeping for the perturbation hooks.

        Perturbations are installed as removable hooks rather than by replacing
        ``self_attn.forward``, so there are no original functions to restore;
        only the list of hook handles is prepared.
        """
        self._hook_handles = []

    def reset_to_base_model(self):
        """Remove every perturbation hook, restoring the unperturbed model."""
        for handle in getattr(self, '_hook_handles', []):
            handle.remove()
        self._hook_handles = []

    def is_target_layer(self, index: int) -> bool:
        """Return True if decoder layer ``index`` lies in ``layer_start..layer_end`` (inclusive)."""
        return self.layer_start <= index <= self.layer_end

    def apply_perturbation(self):
        """Install all perturbations (currently only the attention perturbation)."""
        self.apply_attention_perturbation()

    def apply_attention_perturbation(self):
        """Install the attention-perturbation hooks on every target layer.

        Idempotent: existing hooks are removed first, so repeated calls never
        stack perturbations. The hooks read ``self.attention_temperature`` and
        ``self.attention_diagonal_penalty`` when they fire, so parameter
        updates take effect without reinstalling.

        Raises:
            AttributeError: if a target layer's attention module has no ``q_proj``.
        """
        self.reset_to_base_model()

        for i, block in enumerate(self.model.model.layers):
            if not self.is_target_layer(i):
                continue

            attention = block.self_attn
            if not hasattr(attention, 'q_proj'):
                raise AttributeError(
                    f'Layer {i} attention has no q_proj; only Llama-style attention is supported')

            self._log(f"Applying perturbation to layer {i}")
            self._hook_handles.append(
                attention.q_proj.register_forward_hook(self._scale_queries_hook))
            self._hook_handles.append(
                attention.register_forward_pre_hook(self._diagonal_penalty_hook, with_kwargs=True))

    def _scale_queries_hook(self, module, args, output):
        """Forward hook on ``q_proj``: divide the query projections by the temperature.

        Attention scores are bilinear in the queries, and rotary position
        embeddings act linearly on them, so this divides every pre-softmax
        score by ``attention_temperature``. This assumes no normalisation is
        applied to the queries after ``q_proj`` (TODO confirm per model).
        Returning ``None`` leaves the output untouched.
        """
        if self.attention_temperature == 1.0:
            return None
        return output / self.attention_temperature

    def _diagonal_penalty_hook(self, module, args, kwargs):
        """Forward pre-hook on ``self_attn``: penalise the most recent key at decode steps.

        Subtracts ``attention_diagonal_penalty`` from the additive attention mask
        at the current token's own key. NOTE(review): this relies on the mask
        being added to the pre-softmax scores, as in HF eager attention.

        The hook is a no-op (returns ``None``) unless all of these hold:
        the query length is 1, the mask is a floating-point tensor passed by
        keyword, and the penalty is positive.
        """
        penalty = self.attention_diagonal_penalty
        mask = kwargs.get('attention_mask')
        hidden_states = kwargs.get('hidden_states', args[0] if args else None)
        if (penalty <= 0.0 or mask is None or hidden_states is None
                or not torch.is_floating_point(mask) or hidden_states.shape[-2] != 1):
            return None

        # The current token's key sits at its cache position. The mask may be
        # wider than the key length in some transformers versions, so the last
        # column is only a fallback when cache_position is not supplied.
        cache_position = kwargs.get('cache_position')
        key_index = int(cache_position[-1]) if cache_position is not None else mask.shape[-1] - 1

        # Clone rather than edit in place: the caller's mask is presumably
        # shared by all decoder layers (and by the base pass).
        mask = mask.clone()
        mask[..., key_index] -= penalty
        return args, {**kwargs, 'attention_mask': mask}

    @staticmethod
    def js_divergence(p, q, eps=1e-8):
        """Jensen-Shannon divergence (natural log) along the last dimension.

        Inputs are cast to float32 so that ``eps`` is not lost in low-precision
        dtypes; they are then smoothed by ``eps`` and renormalised, which avoids
        ``log(0)``. Returns one value per leading index.
        """
        p = p.float() + eps
        q = q.float() + eps

        # Renormalise after smoothing
        p = p / p.sum(-1, keepdim=True)
        q = q / q.sum(-1, keepdim=True)

        m = 0.5 * (p + q)
        def kl(a, b):
            return (a * (torch.log(a) - torch.log(b))).sum(-1)
        return 0.5 * kl(p, m) + 0.5 * kl(q, m)

    def step(self, pert_logits, base_logits):
        """Update the drift estimate and return the next perturbation parameters.

        Computes the batch-mean JS divergence between the perturbed and base
        next-token distributions and folds it into ``self.running_divergence``
        (EMA with factor ``divergence_smoothing``). It then derives a gain in
        (0, 1) that is high while drift is below ``js_tolerance`` and falls as
        drift grows.

        Returns:
            dict with ``attention_temperature``, ``attention_diagonal_penalty``
            (both scaled towards neutral by the gain) and ``teacher_blend``
            (rises from ``teacher_blend_min`` towards ``teacher_blend_max`` as
            the gain falls).
        """
        if torch.allclose(pert_logits, base_logits, atol=1e-6):
            print("WARNING: Perturbed and base logits are identical!")

        # Float32 softmax: low-precision logits would make the divergence noisy.
        p_base = torch.softmax(base_logits.float(), dim=-1)
        p_pert = torch.softmax(pert_logits.float(), dim=-1)
        divergence = self.js_divergence(p_pert, p_base).mean().item()
        self._log(f"JS divergence: {divergence}")

        beta = self.divergence_smoothing
        self.running_divergence = beta * self.running_divergence + (1 - beta) * divergence

        # Higher drift -> smaller gain. torch.sigmoid is overflow-safe for large arguments.
        gain = torch.sigmoid(
            torch.tensor((self.js_tolerance - self.running_divergence) / self.js_softness)).item()

        attention_temperature = 1.0 + (self.attention_temperature_target - 1.0) * gain
        attention_diagonal_penalty = self.attention_diagonal_penalty_target * gain
        teacher_blend = self.teacher_blend_min + (1 - gain) * (self.teacher_blend_max - self.teacher_blend_min)

        return {
            "attention_temperature": float(attention_temperature),
            "attention_diagonal_penalty": float(attention_diagonal_penalty),
            "teacher_blend": float(teacher_blend),
        }

    @staticmethod
    def _apply_repetition_penalty(logits, recent_tokens, penalty):
        """Penalise each token id in ``recent_tokens`` exactly once.

        Negative logits are multiplied by ``penalty`` and non-negative ones
        divided by it. Duplicated ids are not penalised repeatedly, because the
        gathered values are computed once and scattered back.

        Args:
            logits: (batch, vocab) float tensor.
            recent_tokens: (batch, window) integer token ids.
            penalty: repetition penalty factor.
        Returns:
            A new logits tensor; the input is not modified.
        """
        picked = logits.gather(-1, recent_tokens)
        picked = torch.where(picked < 0, picked * penalty, picked / penalty)
        return logits.scatter(-1, recent_tokens, picked)

    @staticmethod
    def _sample_top_p(logits, top_p):
        """Sample one token id per row with nucleus (top-p) filtering.

        The smallest set of highest-probability tokens whose cumulative
        probability exceeds ``top_p`` is kept; the top token is always kept.
        With ``top_p >= 1`` the full distribution is sampled.
        Returns a (batch, 1) tensor of token ids.
        """
        probs = torch.softmax(logits, dim=-1)
        if top_p >= 1.0:
            return torch.multinomial(probs, num_samples=1)

        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Shift right so that the token crossing the threshold is still kept.
        drop = cumulative > top_p
        drop[..., 1:] = drop[..., :-1].clone()
        drop[..., 0] = False
        sorted_probs = sorted_probs.masked_fill(drop, 0.0)
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        choice = torch.multinomial(sorted_probs, num_samples=1)
        return torch.gather(sorted_idx, -1, choice)

    @torch.no_grad()
    def generate_with_leash(self,
                            prompt: str,
                            max_new_tokens: int = 128,
                            temperature: float = 0.7,
                            top_p: float = 0.95,
                            repetition_penalty: float = 1.15,
                            repetition_window: int = 50):
        """Sample a continuation of ``prompt`` with the leashed perturbation.

        Each step runs a base pass (hooks removed) and a perturbed pass (hooks
        installed). It then updates the leash via ``step``, blends the logits
        with the resulting ``teacher_blend``, and samples with temperature,
        repetition penalty and top-p. Generation stops after ``max_new_tokens``
        or when the tokenizer's EOS token is sampled.

        The two passes diverge in their hidden states, so each keeps its own
        KV cache; after the prompt only the newest token is fed. Leash state
        is reset at the start of every call, and the perturbation hooks are
        always removed on exit.

        Args:
            prompt: input text.
            max_new_tokens: maximum number of sampled tokens.
            temperature: sampling temperature (floored at 1e-5).
            top_p: nucleus-sampling threshold; >= 1 disables it.
            repetition_penalty: penalty for tokens in the recent window; 1 disables it.
            repetition_window: number of most recent tokens (prompt included)
                subject to the repetition penalty; <= 0 disables it.
        Returns:
            The decoded prompt plus continuation, with special tokens skipped.
        """
        generated = self.tokenizer(prompt, return_tensors="pt").to(self.device)['input_ids']
        eos_token_id = getattr(self.tokenizer, 'eos_token_id', None)

        # Every call starts from the initial, un-drifted leash state.
        self.running_divergence = 0.0
        self.set_perturbation_parameters(
            attention_temperature=self.attention_temperature_target,
            attention_diagonal_penalty=self.attention_diagonal_penalty_target,
            teacher_blend=self.teacher_blend_min,
        )

        base_past = None
        perturbed_past = None
        step_input = generated  # the whole prompt first, then only the newest token

        try:
            for _ in range(max_new_tokens):
                # ---- BASE pass (perturbation off)
                self.reset_to_base_model()
                base_output = self.model(
                    input_ids=step_input,
                    use_cache=True,
                    past_key_values=base_past)
                base_past = base_output.past_key_values
                base_logits = base_output.logits[:, -1, :]

                # ---- PERTURBED pass (hooks read the current leash parameters)
                self.apply_perturbation()
                perturbed_output = self.model(
                    input_ids=step_input,
                    use_cache=True,
                    past_key_values=perturbed_past)
                perturbed_past = perturbed_output.past_key_values
                perturbed_logits = perturbed_output.logits[:, -1, :]

                self.set_perturbation_parameters(**self.step(perturbed_logits, base_logits))
                self._log(f'temperature={self.attention_temperature}, '
                          f'diag={self.attention_diagonal_penalty}, '
                          f'teacher_blend={self.teacher_blend}')

                # ---- Final logits for sampling: teacher blend (in float32)
                final_logits = ((1 - self.teacher_blend) * perturbed_logits.float()
                                + self.teacher_blend * base_logits.float())
                final_logits = final_logits / max(1e-5, temperature)

                if repetition_penalty != 1.0 and repetition_window > 0:
                    final_logits = self._apply_repetition_penalty(
                        final_logits, generated[:, -repetition_window:], repetition_penalty)

                next_token = self._sample_top_p(final_logits, top_p)
                generated = torch.cat([generated, next_token], dim=-1)
                step_input = next_token

                if eos_token_id is not None and bool((next_token == eos_token_id).all()):
                    break
        finally:
            # Never leave the hooks installed on the shared model.
            self.reset_to_base_model()

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)


In [None]:
#|test
# Sampling in generate_with_leash is stochastic; seed it so the response is reproducible.
torch.manual_seed(0)

model = LSDPerturbedLLM(
    layer_start=21,                              # First layer to perturb
    layer_end=30,                               # Final layer to perturb
    attention_temperature_target=1.40,          # Temperature > 1 flattens attention
    attention_diagonal_penalty_target=0.40,     # Penalty on the most recent key at decode steps
    js_tolerance=0.25,                          # Tolerated JS divergence
    teacher_blend_min=0.10,                     # Min teacher blend weight
    teacher_blend_max=0.60,                     # Max teacher blend weight
)

# Generate text with the leash mechanism
prompt = 'Describe what you see looking out at the sea using a 100 word sentence'
print(f'Prompt: {prompt}')
response = model.generate_with_leash(
    prompt=prompt,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.5
)

print(f'Response: {response}')
