In [None]:
import html
import importlib
import os
import pathlib
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.express as px
import pytz
import torch
import torch.nn.functional as F
import wandb
from dotenv import load_dotenv
from IPython.core.display import HTML, Markdown
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from helpers.phi3.phi3 import Phi3Config, Phi3ForCausalLM, _prepare_4d_causal_attention_mask
from helpers.phi3.parse import parse_phi
from helpers.memory import check_memory

# Device string handed to model loading and to the generation helpers further down.
device = 'cuda'

# Pull environment variables from the local secrets file (kept out of the notebook itself).
load_dotenv('secrets.env')

## Load Model

In [None]:
# Attention backend: 'flash_attention_2', or None to disable flash attention.
attn_implementation = 'flash_attention_2' # None/flash_attention_2

# Single source for the checkpoint id shared by the tokenizer and the model.
phi3_checkpoint = 'microsoft/Phi-3-mini-4k-instruct'

# Tokenizer: no automatic BOS/EOS insertion. Padding side only matters for
# flash attention, which needs left padding.
tokenizer = AutoTokenizer.from_pretrained(
    phi3_checkpoint,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left'
)

# Stock HF model in bfloat16, placed on `device` and switched to inference mode.
my_model = AutoModelForCausalLM.from_pretrained(
    phi3_checkpoint,
    device_map = device,
    trust_remote_code = True,
    torch_dtype = torch.bfloat16,
    attn_implementation = attn_implementation
)
my_model = my_model.to(device).eval()

In [None]:
# torch.save(my_model.state_dict(), f'./models/phi3_base.pt')
# Overwrite the base weights with a saved checkpoint.
# - map_location puts the tensors on `device` regardless of where they were saved from.
# - weights_only=True restricts unpickling to tensors/plain containers, so a tampered
#   checkpoint file cannot execute arbitrary code on load.
my_model.load_state_dict(
    torch.load('./models/20241003T1957/e11.pt', map_location = device, weights_only = True)
)

In [None]:
class FenceParams():
    """Container for the FENCE dimension layout and the per-layer position-loss target values."""

    def __init__(self, fence_dict: dict[str, tuple[int, int]], Kfstart: int, Kfend: int, hkr_target_values: dict[int, float], hk_target_values: dict[int, float]):
        """
        Creates a new FENCE Params object
        
        Description:
            Creates a FENCE params object which stores the FENCE dict and the position loss target values
        
        Params: 
            @fence_dict: A dict of features and their corresponding fence dimensions, e.g. {'dogs': (3065, 3068), 'cats': (3061, 3064)}. 
             - These dimensions are 1-indexed and inclusive of both the start and ending numbers passed into the tuples. (3061, 3064) means dimensions 3061, 3062, 3063, and 3064.
             - Entries must be passed in ascending order of their start dimension and must not overlap.
            @Kfstart: The index (1-indexed) of the first transformer block with which to calculate position loss.
            @Kfend: The index (1-indexed) of the last transformer block with which to calculate position loss.
            @hkr_target_values: A dict where the keys are the layer index (1-indexed), and the values representing the target FENCE values for each layer's residual stream output.
            @hk_target_values:  A dict where the keys are the layer index (1-indexed), and the values representing the target FENCE values for each layer's transformer block output.

        Raises:
            ValueError: if a fence range is not a valid 1-indexed (start, end) pair with start <= end, if ranges
            overlap or are not passed in ascending order, or if the layer range is empty / not 1-indexed.
        """
        # A reversed range would silently produce a negative width (and a wrong Df); a start below 1
        # would map to index -1, i.e. silently address the LAST dimension of the hidden state.
        for feature, (start, end) in fence_dict.items():
            if start < 1 or end < start:
                raise ValueError(f'FENCE range for "{feature}" must satisfy 1 <= start <= end, got ({start}, {end})')

        ranges_in_order = list(fence_dict.values())
        ranges_sorted = sorted(ranges_in_order)
        if not all(r1[1] < r2[0] for r1, r2 in zip(ranges_sorted, ranges_sorted[1:])):
            raise ValueError('FENCE dict contains overlapping values')

        if not all(curr[0] > prev[0] for prev, curr in zip(ranges_in_order, ranges_in_order[1:])):
            raise ValueError('FENCE dict not passed in order!')

        # Kf must be a positive layer count, otherwise downstream per-layer tensors are empty/ill-shaped.
        if Kfstart < 1 or Kfend < Kfstart:
            raise ValueError(f'Layer range must satisfy 1 <= Kfstart <= Kfend, got Kfstart={Kfstart}, Kfend={Kfend}')

        self.fence_dict = fence_dict
        self.Kfstart = Kfstart
        self.Kfend = Kfend
        self.Kf = Kfend - Kfstart + 1  # number of layers in the inclusive range
        self.hkr_target_values = hkr_target_values
        self.hk_target_values = hk_target_values
        # Total number of fenced dimensions across all features
        self.Df = sum(end - start + 1 for start, end in fence_dict.values())
        # 0-indexed positions of every fenced dimension, in fence_dict order
        self.Df_indices = [i - 1 for start, end in fence_dict.values() for i in range(start, end + 1)]

    def to_dict(self):
        """Returns the stored parameters (plus the derived Df and Df_indices) as a plain dict."""
        return {
            'fence_dict': self.fence_dict,
            'Kfstart': self.Kfstart,
            'Kfend': self.Kfend,
            'hkr_target_values': self.hkr_target_values,
            'hk_target_values': self.hk_target_values,
            'Df': self.Df,
            'Df_indices': self.Df_indices
        }

    def __str__(self):
        return (f"FENCE Params Object:\n"
                f"  FENCE Dictionary: {self.fence_dict}\n"
                f"  Number of position loss layers: {self.Kf} (Layer {self.Kfstart} to layer {self.Kfend})\n"
                f"  Df: {self.Df}  \n"
                f"  hkr_target_values: {self.hkr_target_values}\n"
                f"  hk_target_values: {self.hk_target_values}\n"
                f"  Df_indices (0-indexed): {self.Df_indices}")
        
    def __repr__(self):
        # Same text as __str__ so the object renders readably as a notebook cell output
        return self.__str__()

# Layer range for the position loss. Indices are 1-indexed and inclusive.
Kfstart = 1
Kfend = 32

# Target values grow by .25 per layer within the range: the residual-stream target (hkr) sits at the
# midpoint of each step and the block-output target (hk) at the end of it.
# j counts layers *within* the range (1..Kf), so layer Kfstart + j - 1 gets the j-th step. This keeps
# the keys aligned with the actual layer indices even when Kfstart != 1.
fence_params = FenceParams(
    fence_dict = {
        'programming': (2980, 2983), # 4
        'food': (3010, 3011), # 2
        'animals': (3030, 3032), # 3
        'dogs': (3050, 3053), # 4
        'cats': (3070, 3070) # 1
    },
    Kfstart = Kfstart,
    Kfend = Kfend,
    hkr_target_values = {Kfstart + j - 1: (j - 1) * .25 + .25/2 for j in range(1, Kfend - Kfstart + 2)},
    hk_target_values = {Kfstart + j - 1: (j - 1) * .25 + .25 for j in range(1, Kfend - Kfstart + 2)}
)

fence_params


## Eval

In [None]:
# Re-import the eval helpers so edits to the module are picked up without restarting the kernel.
fence_eval_module = importlib.import_module('helpers.fence.eval')
importlib.reload(fence_eval_module)
from helpers.fence.eval import generate_fence

# Three raw completions followed by three single-turn chat prompts.
raw_completion_prompts = [
    '<s>My favorite animal is',
    '<s>I spent the afternoon with python',
    '<s>Let\'s cook a great, healthy recipe for my pet',
]
chat_questions = [
    'I want to cook a great, healthy pet recipe!',
    'Where should I take my dog hiking?',
    'What should I bring to take my friend hiking?',
]
test_prompts = raw_completion_prompts + [
    parse_phi([{'role': 'user', 'content': question}], True) for question in chat_questions
]

# One short (8-token) generation per prompt.
test_gens = list(map(
    lambda test_prompt: generate_fence(my_model, tokenizer, prompt = test_prompt, max_tokens = 8),
    test_prompts
))

In [None]:
# Re-import the plotting helpers so edits to the module are picked up without restarting the kernel.
fence_visualize_module = importlib.import_module('helpers.fence.visualize')
importlib.reload(fence_visualize_module)
from helpers.fence.visualize import visualize_fence

# Plot the fenced dimensions of the third test generation, one figure per listed layer.
example_gen = test_gens[2]
for layer_idx in [10]:
    fence_fig = visualize_fence(
        example_gen['text'],
        example_gen['hks'],
        [layer_idx],
        fence_params.fence_dict,
        start_dim = 2970, end_dim = 3072,
        # Colour scale runs from 0 up to this layer's block-output target value
        min_range = 0, max_range = fence_params.hk_target_values[layer_idx]
    )
    fence_fig.update_layout(title = f'H<sub>{layer_idx}</sub>', height = 330).show('colab')

In [None]:
# FenceParams as defined in this notebook has no `feature_classifications` attribute, so the
# original expression raised AttributeError on a fresh Restart & Run All. Display the
# feature -> (start, end) dimension map instead.
# NOTE(review): confirm this is the information the cell was meant to show.
fence_params.fence_dict

In [None]:
@torch.no_grad()
def generate_with_force_fence(model, tokenizer, fence_params: FenceParams, force_pos: list[str], force_neg: list[str], prompt: str, max_tokens = 128, device = 'cuda'):
    """
    Greedy-decodes a continuation of `prompt` while overwriting selected FENCE dimensions at every layer.

    Description:
        Runs the transformer blocks of `model` manually (no KV cache: the full sequence is re-run for every
        new token). For each layer in [Kfstart, Kfend], the fenced dimensions of the hidden state are
        overwritten twice: once after the attention residual add (using hkr_target_values) and once after
        the MLP residual add (using hk_target_values).
         - Features in `force_pos` are set to that layer's target value.
         - Features in `force_neg` are set to 0.
         - All other features are left untouched.
         - If a feature appears in both lists, `force_pos` wins.

    Params:
        @model: The causal LM to run. Must expose model.model.{embed_tokens, layers, norm} and model.lm_head.
        @tokenizer: Tokenizer matching `model`.
        @fence_params: FENCE layout and per-layer target values.
         - The target value dicts are read in insertion order and must contain exactly Kf entries.
        @force_pos: Feature names (keys of fence_params.fence_dict) to force on.
        @force_neg: Feature names (keys of fence_params.fence_dict) to force off.
        @prompt: A single prompt string (batch size 1).
        @max_tokens: Maximum number of tokens to generate; at least one token is always generated.
        @device: Device on which the inputs are placed.

    Returns:
        A tuple (final_output, cleaned_dims):
         - final_output: 1-D tensor of token ids (prompt followed by generated tokens).
         - cleaned_dims: list with one float16 numpy array per batch element, of shape Kf x N x Df, holding
           the fenced block-output values from the LAST forward pass. Empty list if no layer fell in the range.

    Raises:
        ValueError: if force_pos / force_neg contain a name that is not a key of fence_params.fence_dict.
    """
    model.eval()

    # A misspelled feature name would otherwise be silently ignored
    unknown_features = [name for name in [*force_pos, *force_neg] if name not in fence_params.fence_dict]
    if unknown_features:
        raise ValueError(f'Unknown FENCE features: {unknown_features}')

    # Tokenize once; both tensors come from the same encoding
    encoded = tokenizer(prompt, return_tensors = 'pt').to(device)
    input_ids = encoded['input_ids']
    mask = encoded['attention_mask']

    Kfstart = fence_params.Kfstart
    Kfend = fence_params.Kfend
    Kf = fence_params.Kf
    Df = fence_params.Df
    Df_indices = fence_params.Df_indices # 0-indexed Df indices used for extracting the Df values

    ##### Loop-invariant setup ######
    # Per fenced dimension: the multiplier applied to the layer target (1 = force on, 0 = force off) and an
    # explicit boolean flag saying whether the dimension is forced at all. Using a mask instead of a large
    # negative sentinel keeps "untouched" dims correct even if a layer target is 0 or negative.
    multipliers, forced_flags = [], []
    for name, (start, end) in fence_params.fence_dict.items():
        width = end - start + 1
        is_pos = name in force_pos
        is_forced = is_pos or name in force_neg
        multipliers.append(torch.ones(width) if is_pos else torch.zeros(width))
        forced_flags.append(torch.full((width,), is_forced, dtype = torch.bool))
    feature_multipliers = torch.cat(multipliers, dim = 0).to(input_ids.device).bfloat16() # Df
    forced_dims = torch.cat(forced_flags, dim = 0).to(input_ids.device) # Df

    # Per-layer target values (insertion order of the dicts), one row per layer in the FENCE range
    hkr_target_values = torch.tensor(list(fence_params.hkr_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)
    hk_target_values = torch.tensor(list(fence_params.hk_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)
    # Values are identical across batch and token positions, so keep them as Kf x Df and let
    # torch.where broadcast them over B x N.
    hkr_forced_values = hkr_target_values.view(Kf, 1) * feature_multipliers.view(1, Df) # Kf x Df
    hk_forced_values = hk_target_values.view(Kf, 1) * feature_multipliers.view(1, Df) # Kf x Df

    stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}
    uses_flash_attention = model.model._attn_implementation == 'flash_attention_2'

    generated_tokens = 0
    while True:

        ##### Forward Pass ######
        embeds_output = model.model.embed_tokens(input_ids) # B x N x D
        hidden_state = embeds_output

        # B = batch size, N = token length, D = token dim, K = # transformer blocks
        # Df = total FENCE dimension width
        # Kfstart, Kfend = starting and ending indices for transformer blocks to include in FENCE (indices start with 1, not 0)
        # Kf = number of transformer blocks to include in FENCE
        B, N, _ = embeds_output.shape

        # Prepare SA inputs
        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs
        if uses_flash_attention:
            attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
        else:
            attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)  # Non FA: Make a triangular attention mask to hide right context

        # Block-output FENCE values of this pass, one B x N x Df tensor per layer in range
        saved_hks = []
        for layer_idx, layer in enumerate(model.model.layers):
            # layer_idx is 0-indexed; fence_row selects the row for layer (layer_idx + 1) from the Kf x Df tables
            in_fence_range = Kfstart - 1 <= layer_idx <= Kfend - 1
            fence_row = layer_idx + 1 - Kfstart

            # SA
            residual = hidden_state
            hidden_state = layer.input_layernorm(hidden_state)
            hidden_state = layer.self_attn(hidden_state, attention_mask, position_ids)[0]

            # Sum back to resid stream
            hidden_state = residual + layer.resid_attn_dropout(hidden_state)

            if in_fence_range:
                # Forcibly set H_K^R on the forced dims; leave the others as computed
                hidden_state[:, :, Df_indices] = torch.where(
                    forced_dims,
                    hkr_forced_values[fence_row],
                    hidden_state[:, :, Df_indices]
                )

            # MLP
            residual = hidden_state
            hidden_state = layer.post_attention_layernorm(hidden_state)
            hidden_state = layer.mlp(hidden_state)

            # Sum back to resid stream
            hidden_state = residual + layer.resid_mlp_dropout(hidden_state)

            if in_fence_range:
                # Forcibly set H_K on the forced dims; leave the others as computed
                hidden_state[:, :, Df_indices] = torch.where(
                    forced_dims,
                    hk_forced_values[fence_row],
                    hidden_state[:, :, Df_indices]
                )
                # Advanced indexing returns a copy, so later in-place writes do not affect the saved values
                saved_hks.append(hidden_state[:, :, Df_indices])

        # RMS norm the final transformer layer output
        hidden_state = model.model.norm(hidden_state)

        # Run LM head
        logits = model.lm_head(hidden_state).float() # B x N x V
        #### End Forward Pass ######

        # Greedy pick for the last position (softmax is monotonic, so argmax over logits is equivalent),
        # then append it to the running sequence. Explicit indexing also works when N == 1.
        output_token = torch.argmax(logits[0, -1, :])
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)
        # Keep the 2-D attention mask the same length as input_ids
        mask = torch.cat((mask, torch.ones_like(mask[:, :1])), dim = 1)

        # Break while loop if EOS or generation >= max tokens
        generated_tokens = generated_tokens + 1
        if output_token.item() in stop_token_ids or generated_tokens >= max_tokens:
            break

    # Export the fenced values of the last pass: stack to B x Kf x N x Df, then one numpy array per batch element
    if saved_hks:
        cleaned_dims = [h.cpu().to(torch.float16).numpy() for h in torch.stack(saved_hks, dim = 1)]
    else:
        cleaned_dims = []

    final_output = input_ids.squeeze()
    return final_output, cleaned_dims

# Each entry is (prompt, features forced on, features forced off).
# The same prompt is reused so that only the forced features differ between runs
# (assumes parse_phi returns the same string for the same input).
favorite_animal_prompt = parse_phi([{'role': 'user', 'content': 'What\'s your favorite animal?'}], True)

prompt_combinations = [
    # (parse_phi([{'role': 'user', 'content': 'Tell me a funny story!'}], True),  [], []),
    # (parse_phi([{'role': 'user', 'content': 'Tell me a funny story!'}], True),  ['animals', 'dogs'], []),
    # (parse_phi([{'role': 'user', 'content': 'Tell me a funny story!'}], True),  ['animals'], ['dogs']),
    # (parse_phi([{'role': 'user', 'content': 'Tell me a funny story!'}], True),  ['programming', 'animals', 'dogs'], [])
    (favorite_animal_prompt, [], []),
    (favorite_animal_prompt, ['food'], ['animals']),
    (favorite_animal_prompt, ['programming'], ['animals']),
    (favorite_animal_prompt, ['animals', 'dogs'], [])
]

for input_prompt, forced_on, forced_off in prompt_combinations:

    my_output, states_by_layer = generate_with_force_fence(
        model = my_model,
        tokenizer = tokenizer,
        fence_params = fence_params,
        force_pos = forced_on,
        force_neg = forced_off,
        prompt = input_prompt,
        max_tokens = 64,
        device = device
    )

    # Re-tokenize the prompt to know where the generated tokens start in my_output
    input_tokens = tokenizer(input_prompt, return_tensors = 'pt')
    prompt_length = input_tokens['input_ids'].size()[1]

    # [3:] drops the first three characters of the decoded prompt (presumably the leading '<s>' - confirm)
    prompt_text = tokenizer.batch_decode(input_tokens['input_ids'])[0][3:]
    generated_text = tokenizer.decode(my_output[prompt_length:])

    # Escape the decoded text: prompts and generations can contain '<', '>' or '&' (e.g. code or
    # special tokens), which would otherwise be interpreted as markup instead of being shown.
    display(HTML(
        '<div style="padding: 1rem 2rem; width: 35rem; background-color:honeydew">' +
            '<h5 style="margin-top:4px;margin-bottom:4px">Forced positive classifications: ' + (', '.join(forced_on) if len(forced_on) > 0 else 'none') +'  </h5>' +
            '<h5 style="margin-top:4px;margin-bottom:4px">Forced negative classifications: ' + (', '.join(forced_off) if len(forced_off) > 0 else 'none') +'  </h5>' +

            '<span style="color:green">' + html.escape(prompt_text, quote = False) + '</span> ' +
            '<span style="color:red">' + html.escape(generated_text, quote = False) + '</span>' +
        '</div>'
    ))