In [None]:
"""
Decompose Phi-3 into its constituent parts and run each part piece by piece. Tests logit lens, position loss, and feature clustering via the MLP layer.
"""
None

In [1]:
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import numpy as np
import pandas as pd
from IPython.core.display import HTML, Markdown
import plotly.express as px 
import os
import wandb 
from dotenv import load_dotenv
from datetime import datetime
import pytz
from torch.utils.data import DataLoader
import importlib
import pathlib

from transformers import AutoModelForCausalLM, AutoTokenizer
from helpers.phi3.phi3 import Phi3Config, Phi3ForCausalLM, _prepare_4d_causal_attention_mask
from helpers.phi3.parse import parse_phi
from helpers.memory import check_memory

# Load environment variables from the local secrets.env file (kept out of the notebook itself)
load_dotenv('secrets.env')
# Device used for the model and for every tensor created below; hard-coded, so a CUDA GPU is required
device = 'cuda'

## Initialize Model

In [2]:
# Attention backend: None selects the default implementation, 'flash_attention_2' enables flash attention
attn_implementation = None # None/flash_attention_2

# Tokenizer: no automatic BOS/EOS tokens are added, so prompts must carry their own special tokens.
# Padding side only matters for flash attention, which needs left padding.
tokenizer = AutoTokenizer.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left'
)

# Stock Phi-3 model from the Hugging Face Hub, loaded in bfloat16 directly onto `device`.
# NOTE(review): trust_remote_code executes modeling code downloaded from the Hub; no `revision` is pinned,
# so a newer version of that code can be picked up silently (see the download warning in the cell output).
my_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    device_map = device,
    trust_remote_code = True,
    torch_dtype = torch.bfloat16,
    attn_implementation = attn_implementation
)
my_model = my_model.to(device).eval()

tokenizer_config.json:   0%|          | 0.00/3.44k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

configuration_phi3.py:   0%|          | 0.00/11.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:
- configuration_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_phi3.py:   0%|          | 0.00/73.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct:
- modeling_phi3.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json:   0%|          | 0.00/16.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.67G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

## Initialize FENCE Parameters

In [3]:
# Feature-dimension ranges per FENCE concept. Pass indices starting at 1.
fence_dict = dict(
    dogs = (3064, 3072),
    test = (3050, 3052)
)

# First and last layer (1-indexed) covered by the FENCE targets
Kfstart, Kfend = 1, 32

# Target value per layer, keyed by layer number:
#   - 'hkrs': target for the residual stream after the self-attention block (midpoint of each 0.25 step)
#   - 'hks' : target for the residual stream after the MLP block (end of each 0.25 step)
Kf_target_values = dict(
    hkrs = {Kfstart + j - 1: (j - 1) * .25 + .25/2 for j in range(Kfstart, Kfend + 1)},
    hks = {Kfstart + j - 1: (j - 1) * .25 + .25 for j in range(Kfstart, Kfend + 1)}
)

print(fence_dict)
Kf_target_values

{'dogs': (3064, 3072), 'test': (3050, 3052)}


{'hkrs': {1: 0.125,
  2: 0.375,
  3: 0.625,
  4: 0.875,
  5: 1.125,
  6: 1.375,
  7: 1.625,
  8: 1.875,
  9: 2.125,
  10: 2.375,
  11: 2.625,
  12: 2.875,
  13: 3.125,
  14: 3.375,
  15: 3.625,
  16: 3.875,
  17: 4.125,
  18: 4.375,
  19: 4.625,
  20: 4.875,
  21: 5.125,
  22: 5.375,
  23: 5.625,
  24: 5.875,
  25: 6.125,
  26: 6.375,
  27: 6.625,
  28: 6.875,
  29: 7.125,
  30: 7.375,
  31: 7.625,
  32: 7.875},
 'hks': {1: 0.25,
  2: 0.5,
  3: 0.75,
  4: 1.0,
  5: 1.25,
  6: 1.5,
  7: 1.75,
  8: 2.0,
  9: 2.25,
  10: 2.5,
  11: 2.75,
  12: 3.0,
  13: 3.25,
  14: 3.5,
  15: 3.75,
  16: 4.0,
  17: 4.25,
  18: 4.5,
  19: 4.75,
  20: 5.0,
  21: 5.25,
  22: 5.5,
  23: 5.75,
  24: 6.0,
  25: 6.25,
  26: 6.5,
  27: 6.75,
  28: 7.0,
  29: 7.25,
  30: 7.5,
  31: 7.75,
  32: 8.0}}

## Test Inference & Visualizations with Base Model

In [4]:
# Reload helpers.fence.eval before importing from it, so edits to the module are picked up without a kernel restart
importlib.reload(importlib.import_module('helpers.fence.eval'))
from helpers.fence.eval import generate_fence

# Test
# Chat-formatted prompt built by parse_phi (the trailing True presumably appends the assistant tag - confirm in helpers.phi3.parse)
dog_prompt = parse_phi([{'role': 'user', 'content': 'What\'s your favorite animal?'}], True)
dog_gens = generate_fence(my_model, tokenizer, prompt = dog_prompt, max_tokens = 12)

# Raw (non-chat) prompt; the BOS token is written out by hand because the tokenizer was loaded with add_bos_token = False
animal_prompt = '<s>Animals are,'
animal_gens = generate_fence(my_model, tokenizer, prompt = animal_prompt, max_tokens = 12)

In [5]:
# Reload helpers.fence.visualize before importing from it, so edits to the module are picked up without a kernel restart
importlib.reload(importlib.import_module('helpers.fence.visualize'))
from helpers.fence.visualize import visualize_fence

# One figure per inspected layer: the stored `hks` states of the dog prompt over dimensions 2950-3072,
# with the upper range bound set to that layer's `hks` target value.
for l in (1, 15, 30):
    layer_fig = visualize_fence(
        dog_gens['text'],
        dog_gens['hks'],
        [l],
        fence_dict,
        start_dim = 2950, end_dim = 3072,
        min_range = 0, max_range = Kf_target_values['hks'][l]
    )
    layer_fig.update_layout(title = 'H<sub>' + str(l) + '</sub>', height = 300).show('colab')

## Test Component-by-Component Inference

In [6]:
from helpers.misc import is_notebook
from helpers.phi3.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb
import math

@torch.no_grad()
def generate_fence_with_force(model, tokenizer, prompt, echo_output = True, max_tokens = 128, device = 'cuda'):
    """
    Greedy-decodes from `prompt` by running the forward pass component by component (embeddings -> per-layer
    self-attention and MLP -> final norm -> LM head) and stores the FENCE-relevant intermediate hidden states.

    The two FENCE hooks inside the layer loop mark where residual-stream features can be overwritten to force
    FENCE. Both are currently disabled (the assignment is commented out), so no hidden state is modified.
    No loss is calculated here.

    Params:
        @model: A Phi-3 style causal LM exposing `model.model.embed_tokens`, `model.model.layers`,
            `model.model.norm` and `model.lm_head`.
        @tokenizer: The tokenizer belonging to `model`.
        @prompt: A single prompt string. Batch size 1 is assumed; only batch index 0 is stored.
        @echo_output: If True, display the prompt and the generated continuation.
        @max_tokens: Maximum number of generated tokens. At least one token is always generated.
        @device: Device the input ids and position ids are placed on.

    Returns a dictionary with keys:
        - `text`: The decoded tokens (prompt + generation), as a list with one string per token
        - `hksas`: The outputs of the SA component (after the output projection, before the residual add)
        - `hkrs`: The residual stream after the SA component is added
        - `hkmlps`: The outputs of the MLP component (before the residual add)
        - `hks`: The residual stream after the MLP component is added

    Each hidden-state entry is a list with one float16 numpy array of shape N x D per layer. They are taken from
    the final forward pass only, so they cover every token in `text` except the last generated one.
    """
    model.eval()
    generated_tokens = 0

    # <|end|> closes an assistant turn, so it stops generation just like EOS
    stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}
    use_flash_attention = model.model._attn_implementation == 'flash_attention_2'
    # Number of attention heads; falls back to 32 if the config does not expose it.
    # The q/k/v split below assumes keys/values use the same number of heads as queries (no grouped-query attention).
    H = getattr(model.model.config, 'num_attention_heads', 32)

    input_ids_0 = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']
    input_ids = input_ids_0

    while True:
        # No KV cache: every iteration re-runs the full sequence
        embeds_output = model.model.embed_tokens(input_ids)
        hidden_state = embeds_output

        B, N, D = embeds_output.shape
        Dh = D // H

        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs

        # Flash attention = use default attention mask 2d
        if use_flash_attention:
            attention_mask = None
        # Non flash-attention: Make a triangular attention mask to hide right context
        else:
            attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)

        # Re-created on every pass, so after the loop they hold the final forward pass only
        saved_sa_outputs = []
        saved_hkrs = []
        saved_mlp_outputs = []
        saved_hks = []

        ### Transformer Blocks ###
        for i, layer in enumerate(model.model.layers):

            residual = hidden_state
            sa_input = layer.input_layernorm(hidden_state)

            ### SA ###
            sa_module = layer.self_attn
            # Equivalent to: sa_output = sa_module(sa_input, attention_mask, position_ids)[0]
            qkv = sa_module.qkv_proj(sa_input)
            # Split the fused projection into Q/K/V and reshape each to B x H x N x Dh
            queries = qkv[:, :, :D].view(B, N, H, Dh).transpose(1, 2)
            keys = qkv[:, :, D:2*D].view(B, N, H, Dh).transpose(1, 2)
            values = qkv[:, :, 2*D:].view(B, N, H, Dh).transpose(1, 2)

            if use_flash_attention:
                # Because the input can be padded, the absolute sequence length depends on the max position id.
                rotary_seq_len = max(N, position_ids[:, -1].max().item()) + 1
                cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = rotary_seq_len)
                queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
                # Flash attention is called with B x N x H x Dh inputs, so undo the head/sequence transpose
                queries = queries.transpose(1, 2)
                keys = keys.transpose(1, 2)
                values = values.transpose(1, 2)
                sa_output = sa_module._flash_attention_forward(queries, keys, values, attention_mask, N)
                sa_output = sa_output.reshape(B, N, D).contiguous()
            else:
                cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = N)
                queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
                attn_weights = torch.matmul(queries, keys.transpose(2, 3))/math.sqrt(Dh)  # Should be shape B x H x N x N
                attn_weights = attn_weights + attention_mask # Attention mask is upper triangular of negative infinity
                attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(values.dtype)
                sa_output = torch.matmul(attn_weights, values) # B x H x N x D/H
                sa_output = sa_output.transpose(1, 2).contiguous() # Reorder into B x N x H x D/H
                sa_output = sa_output.reshape(B, N, D) # Concatenate vertically back into B x N x D

            # Finally post-concatenation linear layer
            sa_output = sa_module.o_proj(sa_output)

            saved_sa_outputs.append(sa_output[0, :, :].detach())

            ### add residual -> store residual -> layernorm -> mlp -> add residual
            hidden_state = residual + sa_output

            # FENCE hook (post-SA residual). `i` is the 0-indexed layer, i.e. layer number i + 1,
            # which is the key used in Kf_target_values.
            if Kfstart - 1 <= i <= Kfend - 1:
                # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hkrs'][i + 1] # B x N x Df
                pass

            residual = hidden_state
            saved_hkrs.append(hidden_state[0, :, :].detach())

            hidden_state = layer.post_attention_layernorm(hidden_state)
            ## MLP
            up_state = layer.mlp.gate_up_proj(hidden_state) # B x N x (2I, I = intermediate MLP dimension)
            gate, up_state = up_state.chunk(2, dim = -1) # B x N x I
            up_state = up_state * layer.mlp.activation_fn(gate)  # Elementwise
            hidden_state = layer.mlp.down_proj(up_state) # Back to B x N x D
            ## End MLP

            saved_mlp_outputs.append(hidden_state[0, :, :].detach())

            hidden_state = residual + hidden_state

            # FENCE hook (post-MLP residual), same layer indexing as above
            if Kfstart - 1 <= i <= Kfend - 1:
                # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hks'][i + 1] # B x N x Df
                pass

            saved_hks.append(hidden_state[0, :, :].detach())

        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state)

        # Greedy decoding: argmax over the vocabulary at the last position, then append to the running sequence.
        # Indexing (rather than squeeze) also works for a one-token prompt; softmax is not needed for an argmax.
        output_token = torch.argmax(logits[0, -1, :])
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        # Break while loop if EOS or generation > max tokens
        generated_tokens = generated_tokens + 1
        if output_token.item() in stop_token_ids or generated_tokens >= max_tokens:
            break

    # Hidden states of the last pass only, moved to CPU as float16 numpy arrays
    all_hksas = [h.cpu().to(torch.float16).numpy() for h in saved_sa_outputs]
    all_hkrs = [h.cpu().to(torch.float16).numpy() for h in saved_hkrs]
    all_hkmlps = [h.cpu().to(torch.float16).numpy() for h in saved_mlp_outputs]
    all_hks = [h.cpu().to(torch.float16).numpy() for h in saved_hks]

    final_output = input_ids.squeeze()
    decoded_text = tokenizer.batch_decode(final_output) # One string per token
    decoded_output = tokenizer.decode(final_output[input_ids_0.size()[1]:]) # Generated continuation only

    if echo_output:
        prompt_text = tokenizer.batch_decode(input_ids_0)[0][3:] # [3:] drops the leading '<s>'
        if is_notebook():
            from IPython.display import display
            display(HTML(
                '<div style="padding: 1rem 2rem; background-color:honeydew">' +
                    '<h4>Modified model output</h4>' +
                    '<span style="color:green">' + prompt_text + '</span> ' +
                    '<span style="color:red">' + decoded_output + '</span>' +
                '</div>'
            ))
        else:
            # ANSI escape codes (green prompt, red generation); avoids depending on a coloring package
            print('\033[32m' + prompt_text + '\033[0m', '\033[31m' + decoded_output + '\033[0m')

    return {
        'text': decoded_text,
        'hksas': all_hksas,
        'hkrs': all_hkrs,
        'hkmlps': all_hkmlps,
        'hks': all_hks
    }

# Reference run with the `generate_fence` helper imported from helpers.fence.eval on the same prompt.
# Note that this does not call `generate_fence_with_force`.
generate_fence(my_model, tokenizer, prompt = dog_prompt, max_tokens = 12)

{'text': ['<s>',
  '<|user|>',
  'What',
  "'",
  's',
  'your',
  'favorite',
  'animal',
  '?',
  '<|end|>',
  '<|assistant|>',
  'As',
  'an',
  'artificial',
  'intelligence',
  ',',
  'I',
  'don',
  "'",
  't',
  'have',
  'personal',
  'prefer'],
 'hkrs': [array([[-0.1338  ,  0.1177  ,  0.0332  , ...,  0.0598  ,  0.042   ,
          -0.063   ],
         [-0.02368 ,  0.04614 , -0.02148 , ...,  0.02954 ,  0.00592 ,
          -0.010376],
         [-0.01233 , -0.02295 ,  0.02368 , ..., -0.007233, -0.01495 ,
           0.03784 ],
         ...,
         [-0.03198 , -0.02039 ,  0.03613 , ...,  0.013794,  0.0315  ,
           0.06396 ],
         [-0.03516 , -0.02295 , -0.02051 , ...,  0.00769 ,  0.02344 ,
          -0.05273 ],
         [-0.0581  ,  0.0547  ,  0.01251 , ..., -0.00952 ,  0.006226,
           0.02173 ]], dtype=float16),
  array([[-2.0605e-01,  3.0078e-01,  4.9805e-02, ...,  1.5234e-01,
           1.9043e-01, -1.8799e-02],
         [-7.0801e-03,  1.0254e-01,  8.5938e-02, ..

In [7]:
# TEST - SINGLE PASS
from helpers.phi3.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb
import math

# Single forward pass over `dog_prompt`, written out component by component at top level so that every
# intermediate tensor (e.g. `fvals`, `hidden_state_pre_mlp`) stays inspectable in later cells.
# After the layer loop, per-layer variables hold the values of the LAST layer.
# NOTE(review): autograd is enabled here (the `fvals` output shows a grad_fn); wrap the cell in
# torch.no_grad() if gradients are not needed downstream.
model = my_model
prompt = dog_prompt

model.eval()
generated_tokens = 0

input_ids_0 = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']
input_ids = input_ids_0

embeds_output = model.model.embed_tokens(input_ids)
hidden_state = embeds_output

B, N, D = embeds_output.shape
# Number of attention heads; falls back to 32 if the config does not expose it.
# The q/k/v split below assumes keys/values use the same number of heads as queries (no grouped-query attention).
H = getattr(model.model.config, 'num_attention_heads', 32)
Dh = D // H

position_ids = torch.arange(0, N, dtype=torch.long, device=device).unsqueeze(0).view(-1, N) # Create position IDs

# Flash attention = use default attention mask 2d
if model.model._attn_implementation == 'flash_attention_2':
    attention_mask = None
# Non flash-attention: Make a triangular attention mask to hide right context
else:
    attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window) 

saved_sa_outputs = []
saved_hkrs = []
saved_mlp_outputs = []
saved_hks = []
### Transformer Blocks ###
for i, layer in enumerate(model.model.layers):            

    residual = hidden_state
    sa_input = layer.input_layernorm(hidden_state)
    
    ### SA ###
    sa_module = layer.self_attn
    # Equivalent to: sa_output = sa_module(sa_input, attention_mask, position_ids)[0]
    qkv = sa_module.qkv_proj(sa_input)
    # Split the fused projection into Q/K/V and reshape each to B x H x N x Dh
    queries = qkv[:, :, :D].view(B, N, H, Dh).transpose(1, 2)
    keys = qkv[:, :, D:2*D].view(B, N, H, Dh).transpose(1, 2)
    values = qkv[:, :, 2*D:].view(B, N, H, Dh).transpose(1, 2)

    if model.model._attn_implementation == 'flash_attention_2':     
        # Because the input can be padded, the absolute sequence length depends on the max position id.
        rotary_seq_len = max(N, position_ids[:, -1].max().item()) + 1
        cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = rotary_seq_len)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
        # Flash attention is called with B x N x H x Dh inputs, so undo the head/sequence transpose
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        sa_output = sa_module._flash_attention_forward(queries, keys, values, attention_mask, N)
        sa_output = sa_output.reshape(B, N, D).contiguous()
    else:    
        cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = N)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)
        attn_weights = torch.matmul(queries, keys.transpose(2, 3))/math.sqrt(Dh)  # Should be shape B x H x N x N
        attn_weights = attn_weights + attention_mask # Attention mask is upper triangular of negative infinity
        attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(values.dtype)
        sa_output = torch.matmul(attn_weights, values) # B x H x N x D/H
        sa_output = sa_output.transpose(1, 2).contiguous() # Reorder into B x N x H x D/H
        sa_output = sa_output.reshape(B, N, D) # Concatenate vertically back into B x N x D

    # Finally post-concatenation linear layer
    sa_output = sa_module.o_proj(sa_output)

    saved_sa_outputs.append(sa_output[0, :, :].detach())
    
    ### add residual -> store residual -> layernorm -> mlp -> add residual
    hidden_state = residual + sa_output

    # FENCE hook (post-SA residual). `i` is the 0-indexed layer, i.e. layer number i + 1,
    # which is the key used in Kf_target_values.
    if Kfstart - 1 <= i <= Kfend - 1:
        # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hkrs'][i + 1] # B x N x Df
        pass
    
    residual = hidden_state
    saved_hkrs.append(hidden_state[0, :, :].detach())

    hidden_state = layer.post_attention_layernorm(hidden_state)
    hidden_state_pre_mlp = hidden_state # Kept for inspection: the MLP input of this layer
    ## MLP       
    # hidden_state is of size B x N x D
    gate_plus_fvals = layer.mlp.gate_up_proj(hidden_state) # B x N x (2I, I = intermediate MLP dimension)
    gate, fvals = gate_plus_fvals.chunk(2, dim = -1) # B x N x I
    # At this point fvals = the values (see Geva et al), and the gate is the keys
    up_state = fvals * layer.mlp.activation_fn(gate)  # Elementwise
    hidden_state = layer.mlp.down_proj(up_state) # Back to B x N x D
    ## End MLP
    
    saved_mlp_outputs.append(hidden_state[0, :, :].detach())

    hidden_state = residual + hidden_state

    # FENCE hook (post-MLP residual), same layer indexing as above
    if Kfstart - 1 <= i <= Kfend - 1:
        # hidden_state[:, 1:, 3064:3072] = Kf_target_values['hks'][i + 1] # B x N x Df
        pass

    saved_hks.append(hidden_state[0, :, :].detach())
        
hidden_state = model.model.norm(hidden_state)
logits = model.lm_head(hidden_state)

# Greedy decoding: argmax over the vocabulary at the last position, then append to the previous tokens.
# Indexing (rather than squeeze) also works for a one-token prompt; softmax is not needed for an argmax.
output_token = torch.argmax(logits[0, -1, :])
input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)


tokenizer.batch_decode([output_token])

['As']

In [12]:
# Second half of the last layer's gate_up_proj output from the single-pass cell (the part multiplied by the activated gate)
fvals

tensor([[[ 0.1992, -0.1187, -0.5352,  ..., -0.2754, -1.1641,  0.6445],
         [-0.1816,  1.8359, -0.9102,  ...,  0.3574, -1.8594,  1.2656],
         [-0.1953,  2.0156,  0.2422,  ...,  1.1172, -0.3496,  0.7734],
         ...,
         [ 0.1133,  1.1875, -0.6523,  ...,  0.1035, -1.4922,  1.1797],
         [-1.1797,  1.8906, -2.5156,  ...,  0.8359, -1.4609, -1.4844],
         [-1.1250,  2.5312, -2.1094,  ...,  0.2217, -1.1172, -0.2129]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<SplitBackward0>)

tensor(0.3145, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)

In [157]:
# Post-attention-layernorm input to the last layer's MLP, saved in the single-pass cell (B x N x D)
hidden_state_pre_mlp.shape

torch.Size([1, 11, 3072])