In [None]:
#|default_exp lsd
#|export

from perturbative_llm_cognition import core

import torch
import numpy as np
import math
import types

from torch import nn
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv

In [None]:
#|export

class LSDPerturbedLLM:
    """Language model whose attention is perturbed in a range of decoder layers.

    The model is loaded through `core.load_tokenizer_and_model()`. During
    generation (`generate_with_leash`) every step runs the model twice: once
    with the original attention forwards and once with the perturbed ones.
    The two sets of next-token logits are blended; the weight of the base
    logits (the "leash") grows with the Jensen-Shannon divergence between
    the two next-token distributions.
    """

    def __init__(
        self, 
        layer_start: int = 21,                                  # Start layer to perturb    
        layer_end: int = 27,                                    # End layer to perturb
        attention_scaling_factor: float = 1.4,                  # Scaling factor for the attention scores
        attention_noise: float = 0.3,                           # Noise applied to the attention scores
        attention_diagonal_penalty:float = 0.2,                 # Penalty for the diagonal attention scores
        attention_probability_smoothing_factor: float = 0.5,    # Smoothing factor for the attention probs
        js_tolerance: float = 0.2,                              # Tolerated JS divergence before leashing is maximised
        perturbation_blend_min: float = 0.10,                   # Min perturbation blend 
        perturbation_blend_max: float = 0.50,                   # Max perturbation blend 
        debug: bool = False                                     # debug flag, when true debug information is printed
        ):
        """Validates the parameters, loads the model and records its attention forwards.

        `layer_start` is inclusive and `layer_end` is exclusive (see
        `is_target_layer`). `layer_start` is clamped to be >= 0 and `layer_end`
        is clamped to the number of decoder layers of the loaded model.

        Raises:
            ValueError: if the (clamped) minimum blend exceeds the maximum blend.
        """

        # Initialises the target layers; the upper bound is clamped once the
        # model is loaded and its number of layers is known
        self.layer_start = max(layer_start, 0)
        self.layer_end = layer_end

        # Initialises and validates leash parameters
        self.js_tolerance = max(0.0, js_tolerance)
        self.perturbation_blend_min = max(0.0, perturbation_blend_min)
        self.perturbation_blend_max = min(1.0, perturbation_blend_max)

        # Checks if the leash parameters are valid
        if self.perturbation_blend_min > self.perturbation_blend_max:
            raise ValueError('Invalid leash parameters')

        # Initialises the attention parameters
        self.attention_scaling_factor = attention_scaling_factor
        self.attention_noise = attention_noise
        self.attention_diagonal_penalty = attention_diagonal_penalty
        self.attention_probability_smoothing_factor = attention_probability_smoothing_factor

        # Debug flag
        self.debug = debug

        # Initialises the model
        self.tokenizer, self.model = core.load_tokenizer_and_model()
        self.model.config.attn_implementation = 'eager'
        self.model.config._attn_implementation = 'eager'  # Also try this
        self.device = self.model.device

        # Clamps the (exclusive) end layer to the number of layers of the model.
        # `min` is required here: `max` would silently discard any smaller,
        # user supplied end layer.
        self.layer_end = min(self.layer_end, len(self.model.model.layers))

        # Stores original attention functions
        self._store_original_attention_functions()

    def _store_original_attention_functions(self):
        """Stores the original attention forward function of every decoder layer.

        All layers are recorded (not only the current target range) so that
        `reset_to_base_model` can still restore the model if `layer_start` or
        `layer_end` are changed after construction.
        """
        self.original_attention_forwards = {
            i: block.self_attn.forward
            for i, block in enumerate(self.model.model.layers)
        }

    def reset_to_base_model(self):
        """Reapplies the stored original functions for the unperturbed execution of the model as part of the leash mechanism"""
        for i, block in enumerate(self.model.model.layers):
            original_forward = self.original_attention_forwards.get(i)
            if original_forward is not None:
                block.self_attn.forward = original_forward

    def is_target_layer(self, index: int) -> bool:
        """Returns True if `layer_start <= index < layer_end`"""
        return self.layer_start <= index < self.layer_end

    def apply_perturbation(self):
        """Applies perturbations by replacing the attention forward function of every target layer"""
        for i, block in enumerate(self.model.model.layers):
            if self.is_target_layer(i):
                self.create_perturbed_forward(block.self_attn, i)

    def create_perturbed_forward(self, layer, layer_idx):
        """Binds a perturbed attention forward function to `layer` and returns the layer.

        The perturbation parameters are read from this object when the function
        is created, so later changes only take effect after the next call to
        `apply_perturbation`.
        """

        # Retrieves the perturbation parameters
        attention_scaling_factor = self.attention_scaling_factor
        attention_noise = self.attention_noise
        diag_penalty = self.attention_diagonal_penalty
        attention_probability_smoothing_factor = self.attention_probability_smoothing_factor
        debug = self.debug

        # Creates the perturbed forward function.
        # Inside `forward`, `self` is the attention layer, not the LSDPerturbedLLM.
        def forward(
            self,
            hidden_states,
            attention_mask=None,
            position_ids=None,
            past_key_value=None,
            output_attentions=False,
            use_cache=False,
            cache_position=None,
        ):
            b, q_len, _ = hidden_states.size()
            
            # Query, Key, Value projections, reshaped to (b, heads, q_len, head_dim)
            q = self.q_proj(hidden_states).view(b, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(hidden_states).view(b, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(hidden_states).view(b, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            
            # Applies Rotary Positional Embeddings (RoPE) to the query and key.
            # Falls back to positions 0..seq_len-1 when no position_ids are given.
            seq_len = v.shape[-2]
            cos, sin = self.rotary_emb(
                v,
                position_ids if position_ids is not None else torch.arange(seq_len, device=v.device).unsqueeze(0))
            q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None)
            
            # Cache update
            if past_key_value is not None:
                k, v = past_key_value.update(k, v, layer_idx, {'sin': sin, 'cos': cos, 'cache_position': cache_position})
            
            # Expand KV so keys/values match the number of query heads
            k = repeat_kv(k, self.num_key_value_groups)
            v = repeat_kv(v, self.num_key_value_groups)
            
            # Attention logits
            attention_scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask[:, :, :, : k.size(-2)]
            
            # Perturbation modifications 

            # Scales the attention scores (modification)
            attention_scores = attention_scores / attention_scaling_factor

            # Adds Gaussian noise
            noise = torch.randn_like(attention_scores) * attention_noise
            attention_scores = attention_scores + noise

            # Applies diagonal penalty (modification): only when a single query
            # position is processed, the score of the last key is reduced.
            # NOTE(review): passes with more than one query position (e.g. the
            # prompt pass) receive no diagonal penalty - confirm this is intended.
            if attention_scores.shape[-2] == 1:
                attention_scores[..., -1] = attention_scores[..., -1] - diag_penalty

            # Calculates the attention probs
            attention_probs = nn.functional.softmax(attention_scores, dim=-1).to(q.dtype)

            # Flatten distributions (modification): raises the probs to a power
            # and renormalises them so each row sums to one again
            attention_probs = attention_probs ** attention_probability_smoothing_factor
            attention_probs = attention_probs / attention_probs.sum(dim=-1, keepdim=True)

            # Dropout only has an effect when the layer is in training mode
            attention_probs = nn.functional.dropout(attention_probs, p=self.attention_dropout, training=self.training)#Disabled
            
            attention_output = torch.matmul(attention_probs, v).transpose(1, 2).contiguous().view(b, q_len, -1)
            attention_output = self.o_proj(attention_output)
            
            if debug:
                print(f'[L{layer_idx}] attention_scaling_factor={attention_scaling_factor} attention_noise={attention_noise} diag={diag_penalty}')
            
            return (
                attention_output,
                (attention_probs if output_attentions else None),
                (past_key_value if use_cache else None),
            )
        
        # Binds the perturbed forward function to the layer
        layer.forward = types.MethodType(forward, layer)
        return layer

    @staticmethod
    def js_divergence(p, q, epsilon=1e-6, threshold=1e-6):
        """Returns the Jensen-Shannon divergence (natural log) between `p` and `q`.

        Only entries where at least one of the distributions exceeds
        `threshold` are used. `epsilon` is added to the kept entries before
        they are renormalised, to avoid log(0).

        Note: the boolean mask flattens the inputs, so all kept entries are
        treated as one single distribution even if the inputs have a batch
        dimension.

        Raises:
            ValueError: if no entry of either distribution exceeds `threshold`.
        """
        mask = (p > threshold) | (q > threshold)

        if not mask.any():
            raise ValueError('No valid probs found in input distributions')

        # Adds small epsilon to avoid log(0) issues
        p = p[mask] + epsilon
        q = q[mask] + epsilon

        # Renormalises values
        p = p / p.sum(-1, keepdim=True)
        q = q / q.sum(-1, keepdim=True)

        m = 0.5 * (p + q)

        def kl_divergence(a, b):
            return (a * (torch.log(a) - torch.log(b))).sum(-1)

        return 0.5 * kl_divergence(p, m) + 0.5 * kl_divergence(q, m)

    def set_perturbation_blend(self, pert_logits, base_logits):
        """Sets `self.perturbation_blend` from the divergence of the two logit sets.

        The blend moves linearly from `perturbation_blend_min` (divergence 0)
        to `perturbation_blend_max` (divergence >= `js_tolerance`). Note that
        `generate_with_leash` uses the blend as the weight of the *base*
        logits, so a larger divergence pulls the output towards the base model.

        Raises:
            ValueError: if the perturbed and base logits are identical.
        """
        # Raises an error if the perturbed and base logits are identical
        if torch.allclose(pert_logits, base_logits, atol=1e-6):
            raise ValueError("WARNING: Perturbed and base logits are identical!")

        # Probability Distributions
        p_base = torch.softmax(base_logits, dim=-1)
        p_pert = torch.softmax(pert_logits, dim=-1)
        
        # JS divergence 
        divergence = self.js_divergence(p_pert, p_base).mean().item()
        if np.isnan(divergence):
            print(f'nan divergence: {divergence}')
            divergence = 0.0
            
        # Blend ratio based on the JS divergence
        js_tolerance = self.js_tolerance
        if js_tolerance > 0:
            perturbation_blend_percentage = (js_tolerance - divergence) / js_tolerance
        else:
            # A tolerance of zero would divide by zero: nothing is tolerated,
            # so the leash is maximised
            perturbation_blend_percentage = 0.0
        perturbation_blend_percentage = min(max(0.0, perturbation_blend_percentage), 1.0)
        self.perturbation_blend = self.perturbation_blend_max - (self.perturbation_blend_max - self.perturbation_blend_min) * perturbation_blend_percentage

        if self.debug:
            print(f'Divergence: {divergence}, Perturbation blend percentage: {perturbation_blend_percentage}')


    @torch.no_grad()
    def generate_with_leash(self,
                            prompt: str,
                            max_new_tokens: int = 128,
                            temperature: float = 0.7,
                            top_p: float = 0.95,
                            repetition_penalty: float = 1.15,
                            only_new_tokens: bool = False):
        """Generates text by sampling from blended base / perturbed logits.

        Args:
            prompt: Text the generation starts from.
            max_new_tokens: Maximum number of tokens to sample.
            temperature: The blended logits are divided by this value; must be > 0.
            top_p: Nucleus sampling threshold; values >= 1.0 disable it.
            repetition_penalty: Applied to every token already present in the
                sequence (prompt included) when > 1.0.
            only_new_tokens: When True only the newly generated text is returned,
                otherwise the prompt is included.

        Returns:
            The decoded text, with special tokens skipped.

        Raises:
            ValueError: if `temperature` is not positive, or (from
                `set_perturbation_blend`) if both passes give identical logits.

        The attention forwards of the model are always restored to the
        originals before this method returns or raises. Sampling and the
        end-of-sequence check assume a batch of one prompt.
        """
        if temperature <= 0:
            raise ValueError('temperature must be positive')

        input_ids = self.tokenizer(prompt, return_tensors='pt').to(self.device)['input_ids']
        generated = input_ids.clone()
        
        base_past_key_values = None
        perturbed_past_key_values = None

        self.running_divergence = 0.0

        try:
            # Runs the model for the maximum number of new tokens
            for _ in range(max_new_tokens):
                # Determines input_ids for this step
                if base_past_key_values is None:
                    current_input_ids = input_ids# Full prompt is used for the first step
                else:
                    current_input_ids = generated[:, -1:]# Only the last generated token is used for subsequent steps

                # The model is executed twice as part of the leash mechanism
                # Unperturbed execution
                self.reset_to_base_model()
                base_output = self.model(
                    input_ids=current_input_ids, 
                    use_cache=True, 
                    past_key_values=base_past_key_values)
                
                base_logits = base_output.logits[:, -1, :]

                # Perturbed execution
                self.apply_perturbation()
                perturbed_output = self.model(
                    input_ids=current_input_ids, 
                    use_cache=True, 
                    past_key_values=perturbed_past_key_values)

                perturbed_logits = perturbed_output.logits[:, -1, :]

                # Updates the past key values
                base_past_key_values = base_output.past_key_values
                perturbed_past_key_values = perturbed_output.past_key_values            
                
                # Perturbation blend calculated and applied to create final logits
                # (the blend is the weight of the base logits)
                self.set_perturbation_blend(perturbed_logits, base_logits)
                final_logits = (1 - self.perturbation_blend) * perturbed_logits + self.perturbation_blend * base_logits

                if self.debug:
                    print(f'Perturbation blend: {self.perturbation_blend}')

                # Applies temperature to the final logits
                final_logits = final_logits / temperature

                # Applies repetition penalty to the final logits of every token
                # already in the sequence: negative logits are multiplied by the
                # penalty, the others divided by it
                if repetition_penalty > 1.0:
                    seen_tokens = torch.unique(generated[0])
                    seen_logits = final_logits[0, seen_tokens]
                    final_logits[0, seen_tokens] = torch.where(
                        seen_logits < 0,
                        seen_logits * repetition_penalty,
                        seen_logits / repetition_penalty)

                probs = torch.softmax(final_logits, dim=-1)

                # Applies top-p sampling to the final logits
                if top_p < 1.0:
                    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                    cum = torch.cumsum(sorted_probs, dim=-1)
                    mask = cum > top_p
                    # Shifts the mask right by one so the token that crosses
                    # the threshold is kept; the top token is always kept
                    mask[..., 1:] = mask[..., :-1].clone()
                    mask[..., 0] = False
                    sorted_probs[mask] = 0
                    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True).clamp_min(1e-9)
                    next_idx = torch.multinomial(sorted_probs, num_samples=1)
                    next_token = torch.gather(sorted_idx, -1, next_idx)
                else:
                    next_token = torch.multinomial(probs, num_samples=1)

                if self.debug:
                    print(f'Next token: {self.tokenizer.decode(next_token[0])}')

                generated = torch.cat([generated, next_token], dim=-1)#Generated tokens are updated

                # Breaks if the end of the sequence token is generated
                if next_token.item() == self.tokenizer.eos_token_id:
                    break
        finally:
            # Never leaves the model with perturbed attention forwards
            self.reset_to_base_model()

        # Decodes the generated tokens
        # Only returns the new tokens if the flag is set
        if only_new_tokens:
            new_tokens = generated[:, input_ids.shape[-1]:]
            llm_response = self.tokenizer.decode(new_tokens[0], skip_special_tokens=True)
        else:
            llm_response = self.tokenizer.decode(generated[0], skip_special_tokens=True)

        return llm_response


In [None]:
#|test
# Seeds torch's random number generator so the attention noise and the token
# sampling can be repeated between runs (as far as the backend is deterministic)
SEED = 42
torch.manual_seed(SEED)

# Builds the perturbed model; layers 5 (inclusive) to 25 (exclusive) are requested
model = LSDPerturbedLLM(
    layer_start=5,
    layer_end=25,
    attention_scaling_factor=1.7,
    attention_noise=0.3,
    attention_diagonal_penalty=-0.9,
    attention_probability_smoothing_factor=0.5,
    js_tolerance=0.3,
    perturbation_blend_min=0.1,
    perturbation_blend_max=0.3,
)

prompt = 'As a friend describe what you see looking out at the sea'

print(f'Prompt: {prompt}')

# Generates a response and keeps only the newly generated text
response = model.generate_with_leash(
    prompt=prompt,
    max_new_tokens=128,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.9,
    only_new_tokens=True
)

print(f'Response: {response}')

In [None]:
# Runs a full garbage collection pass; `collected` holds the value returned by
# gc.collect() (the number of unreachable objects it found)
import gc
collected = gc.collect()