In [None]:
#|default_exp lsd
#|export

from perturbative_llm_cognition import core

import torch
import numpy as np
import math
import types

from torch import nn
from transformers.models.mistral.modeling_mistral import (
    apply_rotary_pos_emb,
    repeat_kv,
)

import nltk
from nltk.corpus import stopwords

In [14]:
#|export

from typing import Any


class LSDPerturbedLLM:
    """Generate with a causal LM whose attention is perturbed but kept on a "leash".

    Self-attention in layers ``layer_start <= i < layer_end`` can be swapped for
    a perturbed version: attention logits are divided by a scaling factor, get
    Gaussian noise, optionally a self-attention penalty while decoding, and the
    resulting probabilities are flattened with an exponent. During generation
    every step runs both the unperturbed ("teacher") model and the perturbed
    model. The Jensen-Shannon divergence between their next-token distributions
    sets how much of the teacher's logits is blended into the final logits: the
    more the perturbed pass drifts, the stronger the pull back.
    """

    def __init__(
        self,
        layer_start: int = 21,                                # First layer to perturb (inclusive)
        layer_end: int = 27,                                  # Layer after the last perturbed one (exclusive)
        attention_scaling_factor: float = 1.4,                # Attention logits are divided by this (> 0)
        attention_noise: float = 0.3,                         # Std-dev of Gaussian noise added to attention logits
        attention_diagonal_penalty: float = 0.2,              # Subtracted from the last key's logit when decoding one token
        attention_probability_smoothing_factor: float = 0.5,  # Exponent on attention probabilities (> 0; < 1 flattens)
        js_tolerance: float = 0.2,                            # JS divergence at which the teacher blend reaches its max
        teacher_blend_min: float = 0.10,                      # Teacher blend weight when the two passes agree
        teacher_blend_max: float = 0.50,                      # Teacher blend weight when drift >= js_tolerance
        debug: bool = False,                                  # Print per-step diagnostics
        ):
        """Validate the settings, load the tokenizer/model and remember the original attention.

        Raises:
            ValueError: if the teacher blend bounds are inverted, or the scaling /
                smoothing factors are not positive.
        """
        self.debug = debug

        # A non-positive scaling factor divides by zero / flips the logits; a
        # non-positive exponent would give masked (probability 0) keys weight,
        # because 0 ** 0 == 1.
        if attention_scaling_factor <= 0:
            raise ValueError('attention_scaling_factor must be positive')
        if attention_probability_smoothing_factor <= 0:
            raise ValueError('attention_probability_smoothing_factor must be positive')

        self.layer_start = max(layer_start, 0)
        # The upper bound is clamped to the model's real layer count once it is loaded below.
        self.layer_end = layer_end
        self.js_tolerance = max(0.0, js_tolerance)
        self.teacher_blend_min = max(0.0, teacher_blend_min)
        self.teacher_blend_max = min(1.0, teacher_blend_max)

        if self.teacher_blend_min > self.teacher_blend_max:
            raise ValueError('Invalid teacher parameters')

        self.attention_scaling_factor = attention_scaling_factor
        self.attention_noise = attention_noise
        self.attention_diagonal_penalty = attention_diagonal_penalty
        self.attention_probability_smoothing_factor = attention_probability_smoothing_factor

        # Defined up front so the attribute exists before the first generation step.
        self.teacher_blend = self.teacher_blend_min

        self.tokenizer, self.model = core.load_tokenizer_and_model()
        # NOTE(review): presumably forces the eager code path so the additive
        # attention mask the perturbed forward expects is produced; confirm for
        # the installed transformers version.
        self.model.config.attn_implementation = 'eager'
        self.model.config._attn_implementation = 'eager'
        self.device = self.model.device

        # Never target layers that do not exist.
        self.layer_end = min(self.layer_end, len(self.model.model.layers))

        self._store_original_attention_functions()

        self.stop_words = self._load_stop_word_ids()

    def _load_stop_word_ids(self):
        """Tokenize NLTK's English stop words, downloading the corpus if it is missing."""
        try:
            words = stopwords.words('english')
        except LookupError:
            nltk.download('stopwords')
            words = stopwords.words('english')
        return self.tokenizer(' '.join(words), return_tensors='pt').to(self.device)['input_ids']

    def _store_original_attention_functions(self):
        """Remember each target layer's original attention ``forward`` so it can be restored."""
        self.original_attention_forwards = {}
        for i, block in enumerate(self.model.model.layers):
            if self.is_target_layer(i):
                self.original_attention_forwards[i] = block.self_attn.forward

    def reset_to_base_model(self):
        """Restore the stored original attention ``forward`` on every target layer."""
        for i, block in enumerate(self.model.model.layers):
            if self.is_target_layer(i) and i in self.original_attention_forwards:
                block.self_attn.forward = self.original_attention_forwards[i]

    def is_target_layer(self, index: int) -> bool:
        """Return True if layer ``index`` lies in ``[layer_start, layer_end)``."""
        return self.layer_start <= index < self.layer_end

    def apply_perturbation(self):
        """Bind a perturbed attention ``forward`` on every target layer.

        The closures are rebuilt on every call so they pick up the current
        perturbation settings.
        """
        for i, block in enumerate(self.model.model.layers):
            if self.is_target_layer(i):
                self.create_perturbed_forward(block.self_attn, i)

    def create_perturbed_forward(self, layer, layer_idx):
        """Bind a perturbed attention ``forward`` to ``layer`` and return ``layer``.

        The perturbation settings are captured now, not at call time. The body
        follows an eager Mistral attention forward (Q/K/V projections, RoPE,
        KV-cache update, grouped-KV expansion) and then perturbs the attention
        logits and probabilities.
        """
        attention_scaling_factor = self.attention_scaling_factor
        attention_noise = self.attention_noise
        diag_penalty = self.attention_diagonal_penalty
        attention_probability_smoothing_factor = self.attention_probability_smoothing_factor
        debug = self.debug

        def forward(
            self,
            hidden_states,
            attention_mask=None,
            position_ids=None,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
            cache_position=None,
            **kwargs,  # extra keyword arguments some transformers versions pass; ignored here
        ):
            b, q_len, _ = hidden_states.size()

            # Q K V projections
            q = self.q_proj(hidden_states).view(b, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(hidden_states).view(b, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(hidden_states).view(b, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

            # RoPE (rotary positional embeddings)
            seq_len = v.shape[-2]
            cos, sin = self.rotary_emb(
                v,
                position_ids if position_ids is not None else torch.arange(seq_len, device=v.device).unsqueeze(0))

            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None)

            # Cache update
            if past_key_value is not None:
                k, v = past_key_value.update(k, v, layer_idx, {'sin': sin, 'cos': cos, 'cache_position': cache_position})

            # Expand grouped KV heads to match the query heads
            k = repeat_kv(k, self.num_key_value_groups)
            v = repeat_kv(v, self.num_key_value_groups)

            # Attention logits (additive mask, sliced to the current key length)
            attention_scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask[:, :, :, : k.size(-2)]

            # === Perturbation modifications ===
            attention_scores = attention_scores / attention_scaling_factor

            # Add noise for exploration
            noise = torch.randn_like(attention_scores) * attention_noise
            attention_scores = attention_scores + noise

            # Penalise self-attention. Only applied when a single query token is
            # processed (decoding); the last key is then that token itself.
            # Prefill steps (q_len > 1) are not penalised.
            if attention_scores.shape[-2] == 1:
                attention_scores[..., -1] = attention_scores[..., -1] - diag_penalty

            # Softmax and flattening in float32 so small probabilities survive
            # before the exponent/renormalisation; cast back afterwards.
            attention_probabilities = nn.functional.softmax(attention_scores, dim=-1, dtype=torch.float32)

            # p ** s with s < 1 raises small probabilities relative to large ones
            attention_probabilities = attention_probabilities ** attention_probability_smoothing_factor
            attention_probabilities = attention_probabilities / attention_probabilities.sum(dim=-1, keepdim=True)
            attention_probabilities = attention_probabilities.to(q.dtype)

            attention_probabilities = nn.functional.dropout(attention_probabilities, p=self.attention_dropout, training=self.training)

            # Weighted sum
            attention_output = torch.matmul(attention_probabilities, v).transpose(1, 2).contiguous().view(b, q_len, -1)
            attention_output = self.o_proj(attention_output)

            if debug:
                print(f'[L{layer_idx}] attention_scaling_factor={attention_scaling_factor} attention_noise={attention_noise} diag={diag_penalty}')

            return (
                attention_output,
                (attention_probabilities if output_attentions else None),
                (past_key_value if use_cache else None),
            )

        # Bind the forward function to the layer
        layer.forward = types.MethodType(forward, layer)
        return layer

    @staticmethod
    def js_divergence(p, q, epsilon=1e-6, threshold=1e-6):
        """Jensen-Shannon divergence (natural log) between distributions on the last dim.

        Entries where both ``p`` and ``q`` are at most ``threshold`` are ignored;
        the remaining entries get ``epsilon`` added and are renormalised before
        the divergence is computed. Each row of a batched input is treated as its
        own distribution.

        Args:
            p, q: probability tensors of the same shape ``(..., vocab)``.
            epsilon: smoothing added to kept entries to avoid ``log(0)``.
            threshold: entries at or below this in both inputs are dropped.
        Returns:
            Tensor of shape ``p.shape[:-1]`` (0-d for 1-D inputs).
        Raises:
            ValueError: if some row has no entry above ``threshold``.
        """
        mask = (p > threshold) | (q > threshold)
        if not bool(mask.any(dim=-1).all()):
            raise ValueError('No valid probabilities found in input distributions')

        zero = torch.zeros((), dtype=p.dtype, device=p.device)
        # Dropped entries become exactly 0 so they do not affect the row sums.
        p = torch.where(mask, p + epsilon, zero)
        q = torch.where(mask, q + epsilon, zero)
        p = p / p.sum(-1, keepdim=True)
        q = q / q.sum(-1, keepdim=True)

        m = 0.5 * (p + q)

        def kl(a, b):
            # Dropped entries are 0 in a and b; exclude them so 0 * log(0) cannot give NaN.
            terms = a * (torch.log(a) - torch.log(b))
            return torch.where(mask, terms, zero).sum(-1)

        return 0.5 * kl(p, m) + 0.5 * kl(q, m)

    def set_teacher_blend(self, pert_logits, base_logits):
        """Set ``self.teacher_blend`` from the drift between perturbed and base logits.

        The blend moves linearly from ``teacher_blend_min`` (no divergence) to
        ``teacher_blend_max`` (batch-mean JS divergence >= ``js_tolerance``).
        With ``js_tolerance == 0`` any positive divergence gives the maximum.

        Returns:
            The new teacher blend weight (also stored on ``self.teacher_blend``).
        """
        # Debug: check if logits are identical
        are_identical = torch.allclose(pert_logits, base_logits, atol=1e-6)
        if are_identical:
            print("WARNING: Perturbed and base logits are identical!")

        # Float32 distributions so the divergence is not computed in a low-precision dtype.
        p_base = torch.softmax(base_logits.float(), dim=-1)
        p_pert = torch.softmax(pert_logits.float(), dim=-1)

        # JS divergence (batch-mean scalar)
        divergence = self.js_divergence(p_pert, p_base).mean().item()
        if math.isnan(divergence):
            print(f'nan divergence: {divergence}')
            divergence = 0.0

        if self.js_tolerance > 0.0:
            # Fraction of the tolerance still unused: 1 at zero drift, 0 at/above tolerance.
            teacher_blend_percentage = (self.js_tolerance - divergence) / self.js_tolerance
            teacher_blend_percentage = min(max(0.0, teacher_blend_percentage), 1.0)
        else:
            # Zero tolerance: any drift pulls the blend straight to its maximum.
            teacher_blend_percentage = 1.0 if divergence <= 0.0 else 0.0

        self.teacher_blend = self.teacher_blend_max - (self.teacher_blend_max - self.teacher_blend_min) * teacher_blend_percentage

        if self.debug:
            print(f'Divergence: {divergence}, Teacher blend percentage: {teacher_blend_percentage}')
        return self.teacher_blend

    @staticmethod
    def _apply_repetition_penalty(logits, token_ids, penalty):
        """Penalise, in place, every logit whose token id appears in ``token_ids``.

        Negative logits are multiplied by ``penalty`` and non-negative ones are
        divided by it, so with ``penalty > 1`` seen tokens become less likely.
        Each distinct token is penalised once, however often it occurs.

        Args:
            logits: ``(batch, vocab)`` tensor, modified in place.
            token_ids: ``(batch, seq)`` int64 tensor of previously seen ids.
        Returns:
            ``logits``.
        """
        scores = torch.gather(logits, -1, token_ids)
        scores = torch.where(scores < 0, scores * penalty, scores / penalty)
        # Duplicate ids all write the same penalised value, so this stays deterministic.
        logits.scatter_(-1, token_ids, scores)
        return logits

    @staticmethod
    def _sample_next_token(probs, top_p):
        """Sample one token id per row from ``probs`` with nucleus (top-p) filtering.

        Keeps the most probable tokens up to and including the first one whose
        cumulative probability exceeds ``top_p`` (the top token is always kept).
        ``top_p >= 1`` samples from the full distribution.

        Returns:
            ``(batch, 1)`` tensor of sampled token ids.
        """
        if top_p >= 1.0:
            return torch.multinomial(probs, num_samples=1)

        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cum = torch.cumsum(sorted_probs, dim=-1)
        # Shift right by one so the token that crosses top_p is kept as well.
        mask = cum > top_p
        mask[..., 1:] = mask[..., :-1].clone()
        mask[..., 0] = False
        sorted_probs[mask] = 0
        sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp_min(1e-9)
        next_idx = torch.multinomial(sorted_probs, num_samples=1)
        return torch.gather(sorted_idx, -1, next_idx)

    @torch.no_grad()
    def generate_with_leash(self,
                            prompt: str,
                            max_new_tokens: int = 128,
                            temperature: float = 0.7,
                            top_p: float = 0.95,
                            repetition_penalty: float = 1.15,
                            only_new_tokens: bool = False):
        """Generate text from perturbed logits that are leashed to the base model.

        Each step feeds the prompt (first step) or the newest token through the
        unperturbed model and the perturbed model, each with its own KV cache.
        It then blends their last-position logits with ``self.teacher_blend``
        (see ``set_teacher_blend``), applies temperature, repetition penalty
        (over prompt and generated tokens) and top-p, and samples one token.
        Generation stops after ``max_new_tokens`` or at the EOS token. Only a
        single prompt (batch size 1) is supported. The base attention is
        restored before returning, even if generation raises.

        Returns:
            The decoded prompt + continuation, or only the continuation when
            ``only_new_tokens`` is true.
        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError('temperature must be positive')

        input_ids = self.tokenizer(prompt, return_tensors='pt').to(self.device)['input_ids']
        generated = input_ids.clone()

        base_past_key_values = None
        perturbed_past_key_values = None

        # Kept for backwards compatibility; it is not updated during generation.
        self.running_divergence = 0.0

        try:
            for _ in range(max_new_tokens):
                # Full prompt on the first step, then only the newest token (the caches hold the rest).
                if base_past_key_values is None:
                    current_input_ids = input_ids
                else:
                    current_input_ids = generated[:, -1:]

                # ---- BASE pass (perturbation off)
                self.reset_to_base_model()
                base_output = self.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=base_past_key_values)
                base_logits = base_output.logits[:, -1, :]

                # ---- PERTURBED pass (with perturbations)
                self.apply_perturbation()
                perturbed_output = self.model(
                    input_ids=current_input_ids,
                    use_cache=True,
                    past_key_values=perturbed_past_key_values)
                perturbed_logits = perturbed_output.logits[:, -1, :]

                self.set_teacher_blend(perturbed_logits, base_logits)

                base_past_key_values = base_output.past_key_values
                perturbed_past_key_values = perturbed_output.past_key_values

                # ---- Final logits for sampling: teacher blend
                if self.debug:
                    print(f'teacher blend: {self.teacher_blend}')
                final_logits = (1 - self.teacher_blend) * perturbed_logits + self.teacher_blend * base_logits

                # ---- Sample
                final_logits = final_logits / temperature
                if repetition_penalty > 1.0:
                    self._apply_repetition_penalty(final_logits, generated, repetition_penalty)

                probs = torch.softmax(final_logits, dim=-1)
                next_token = self._sample_next_token(probs, top_p)

                if self.debug:
                    print(self.tokenizer.decode(next_token[0]))

                generated = torch.cat([generated, next_token], dim=-1)

                if next_token.item() == self.tokenizer.eos_token_id:
                    break
        finally:
            # Do not leave the model perturbed for later callers.
            self.reset_to_base_model()

        if only_new_tokens:
            new_tokens = generated[:, input_ids.shape[-1]:]
            llm_response = self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
        else:
            llm_response = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        return llm_response


In [16]:
#|test
# Demo: perturb a band of attention layers and generate with the divergence leash.
model = LSDPerturbedLLM(
    layer_start=5,                              # First layer to perturb (inclusive)
    layer_end=25,                               # Exclusive upper bound (see is_target_layer); 30 was also tried.
                                                # NOTE(review): __init__ may adjust this value - check model.layer_end.
    attention_scaling_factor=1.7,               # Attention logits are divided by this
    attention_noise=0.3,                        # Std-dev of Gaussian noise added to attention logits
    attention_diagonal_penalty=-0.9,            # Subtracted from the self-attention logit while decoding,
                                                # so a negative value boosts self-attention
    attention_probability_smoothing_factor=0.5, # Exponent on attention probabilities (<1 flattens them)
    js_tolerance=0.3,                           # JS divergence at which the teacher blend reaches its maximum
    teacher_blend_min=0.1,                      # Teacher (base-model) weight when both passes agree
    teacher_blend_max=0.3,                      # Teacher weight once divergence reaches js_tolerance
)

# Generate text with the leash mechanism.
# Alternative prompts that have been tried are kept commented out below.
#prompt = 'You feel a slipping feeling pulling you in'#As a friend describe what you see looking out at the sea
prompt = 'As a friend describe what you see looking out at the sea'
#prompt = 'Why did we choose to go to the Moon? '
#prompt = 'What do you see standing on the surface of the moon?'
#prompt = 'What is the capital of France?'
#prompt = 'What bear is best?'

#prompt = 'You feel a slipping feeling pulling you in'
#prompt = 'What is love?'
#prompt = 'Why do all cultures share a fear of the dark?'

print(f'Prompt: {prompt}')
response = model.generate_with_leash(
    prompt=prompt,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.9,  # 100 was also tried
    only_new_tokens=False    # Decode the prompt together with the continuation
)

print(f'Response: {response}')



# Sample responses from earlier runs (kept for reference):
#I am a person who is standing on an beach and I can describe in detail about how it feels to be able of being described as if Describing
# 1. A flag, a manned by what or who is it that stands in front-what does one can I find out there and where are You may not be able to determine any information about something else than just an object; this question has no clear answer as its meaningless
#The view of an endless expanse Of The ocean's vastness. From this vantage point 10 miles away At nighttime/night time-lapse (or) It feels like being in another worldly place(?)

# Notes:
#So many parts of their brain firing all at once with no clear point, just all their synapses firing with no clear point
#Overwhelmed by the concept



Loading checkpoint shards: 100%|██████████| 3/3 [00:09<00:00,  3.20s/it]


Prompt: As a friend describe what you see looking out at the sea
Response: As a friend describe what you see looking out at the sea, I would say it's like someone has taken over an old photograph and put all of his artistry into this vast landscape.

You can hear in your mind as he speaks with that hushed tone; for many years ago when we were young boys here on our island home underwater caverns where legend says the earth is filled up from below – but no one knows how they came to be or why there are soaring cliffs scattered across these picturesque vistas above ground level along its shores before going back down through rocky mountain ranges stretching off towards great distances away during recent times past centuries gone by


In [None]:
import gc

import torch

# Reclaim memory between runs. The collector only frees objects that are no
# longer referenced, so run `del model` first if the weights should go too.
collected = gc.collect()

# PyTorch's CUDA caching allocator keeps freed blocks reserved; hand the
# unused ones back to the driver so other processes / runs can use them.
if torch.cuda.is_available():
    torch.cuda.empty_cache()