# Layer exploration (continued) 
We're trying to explore the layers so we're comfortable modifying things by hand. 

In [110]:
### Load libraries
# import flash_attn
from dotenv import main
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import jinja2
import os
import sys
import re
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # for quantization
import plotly
from transformers import pipeline, set_seed
from tqdm import tqdm

# auth for gated repos (like llama) - gen token here: https://huggingface.co/settings/tokens
from huggingface_hub import login, notebook_login

# notebook_login() does not accept a token (its first parameter is `new_session`),
# so a token from the environment has to go through login(token=...) instead.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token = hf_token)
else:
    # no token in the environment: reuse an existing cached login if there is one
    notebook_login(new_session = False)

# model ids - kept as a list; later cells load model_id[0]
model_id = ["microsoft/Phi-3-mini-4k-instruct"]

# Set seed for reproducibility (fixes torch's RNG seed)
torch.random.manual_seed(0)

# Increase max width of pd df columns so long text cells are not cut off when a df is displayed
pd.set_option('max_colwidth', 300)

# Instantiate jinja environment - used later for icl prompting 
environment = jinja2.Environment()

# device that tokenized inputs are moved to in later cells; hard-coded, so a CUDA GPU is assumed
device = 'cuda'

# requirements.txt
# !pip3 freeze > requirements.txt

User is already logged in.


In [3]:
# Define utility functions 
# mem. monitoring! 
def check_memory():
    """Print allocated, reserved and total memory of CUDA device 0, in units of 1024**3 bytes."""
    bytes_per_unit = 1024 ** 3

    # (message template, reader) pairs; each reader is called just before its line is printed
    probes = (
        ("Allocated: %fGB", lambda: torch.cuda.memory_allocated(0)),
        ("Reserved: %fGB", lambda: torch.cuda.memory_reserved(0)),
        ("Total: %fGB", lambda: torch.cuda.get_device_properties(0).total_memory),
    )
    for template, read_bytes in probes:
        print(template % (read_bytes() / bytes_per_unit))

# notification/text-to-speech
def text_to_speech(text):
    """Speak ``text`` aloud using the platform's command-line speech tool.

    Runs ``say`` on macOS and ``espeak`` on Linux. On any other platform a
    notice is printed and nothing is spoken. This is a best-effort
    notification helper: if the speech command cannot be started, a message
    is printed instead of raising.

    Args:
        text: The message to speak. It is passed to the tool as a single
            argument, so quotes and shell metacharacters are spoken literally
            rather than interpreted by a shell.

    Returns:
        None.
    """
    import subprocess  # local import: only this helper needs it

    if sys.platform == 'darwin':
        command = ['say', str(text)]
    elif sys.platform.startswith('linux'):
        command = ['espeak', str(text)]
    else:
        print("Text-to-speech is not supported on this platform.")
        return

    try:
        # argument list + no shell: `text` can never break out of the command
        subprocess.run(command, check = False)
    except OSError as err:
        # e.g. the speech tool is not installed; keep this helper non-fatal
        print(f"Text-to-speech command {command[0]!r} could not be run: {err}")

# parse + template phi inputs
def parse_phi(messages: list[dict], append_response_start = True) -> str:
    """
    Converts a multi-turn conversation into a Phi-3 chat-formatted prompt string.

    Args:
        messages: Conversation turns in order. Each dict must have a 'role'
            key (e.g. 'system', 'user', 'assistant') and a 'content' key.
        append_response_start: When true, the prompt ends with an open
            '<|assistant|>' tag so the model writes the next assistant turn.

    Returns:
        The prompt string. It always starts with '<s>', so the tokenizer
        should not add a BOS token of its own.

    Raises:
        KeyError: If a message is missing 'role' or 'content'.

    Output format:
    # <s><|system|>
    # You are a helpful AI assistant.<|end|>
    # <|user|>
    # Guess my dog's name!<|end|>
    # <|assistant|>
    """
    # one "<|role|>\ncontent<|end|>" chunk per turn, joined by newlines
    turns = [f"<|{m['role']}|>\n{m['content']}<|end|>" for m in messages]
    prompt = '<s>' + '\n'.join(turns)

    if append_response_start:
        prompt += "\n<|assistant|>"

    return prompt

# print(parse_phi([
#     {'role': 'system', 'content': 'Hello'}, {'role': 'user', 'content': '1+1?'}, {'role': 'assistant', 'content': '2'}
# ], False))

# model eval
def eval_model(model, tokenizer, prompt):
    """Greedy-decode a continuation of ``prompt`` and return the decoded text.

    Args:
        model: A causal LM exposing ``eval()`` and ``generate()``.
        tokenizer: The tokenizer matching ``model``.
        prompt: The fully templated prompt string.

    Returns:
        The first decoded sequence from ``generate`` as a string.

    Relies on the module-level ``device`` for where the inputs are placed.
    NOTE(review): ``max_new_tokens`` is 1, so at most one new token is
    generated per call - confirm that is intended for the callers that parse
    a JSON answer out of the result.
    """
    tokens = tokenizer(prompt, return_tensors = 'pt').to(device)
    model.eval()
    with torch.no_grad():
        res = model.generate(
            **tokens,
            max_new_tokens = 1,
            # greedy decoding: temperature / top_p only apply when sampling, so they are not passed
            do_sample = False,
            # same id the tokenizer maps its eos_token to; listing it twice added nothing
            eos_token_id = tokenizer.eos_token_id
        )
    return tokenizer.batch_decode(res)[0]

# assess model perf
def get_model_performance(eval_df, base_model, tokenizer, verbose = False): 
    """Run the model over every row of ``eval_df`` and score its answers.

    For each row the prompt in ``row['llm_input']`` is sent through
    ``eval_model``. The last single-line ``{...}`` span in the response that
    is followed on its line by both ``"rationale"`` and ``"answer"`` is
    parsed as JSON, and its ``answer`` is compared with ``row['solution']``.
    Rows whose response cannot be parsed are reported and left out of the
    metrics.

    Args:
        eval_df: DataFrame with 'llm_input', 'question', 'difficulty' and
            'solution' columns.
        base_model: Model passed through to ``eval_model``.
        tokenizer: Tokenizer passed through to ``eval_model``.
        verbose: When true, the per-row records are included in the result.

    Returns:
        dict with 'responses' (number of parsed responses) and 'accuracy'
        (fraction of parsed responses that were correct, NaN when nothing
        could be parsed); plus 'val_dict' (list of per-row dicts) if verbose.
    """
    # compiled once: the same pattern is applied to every response
    json_pattern = re.compile(r'(?=.*"rationale")(?=.*"answer"){.*?}')

    val = []
    for _, row in tqdm(eval_df.iterrows(), total = len(eval_df)):
        response = eval_model(model = base_model, tokenizer = tokenizer, prompt = row['llm_input'])

        # error handling for malformed outputs: no JSON-looking span at all
        candidates = json_pattern.findall(response)
        if not candidates:
            print("Exception occurred: no JSON object with 'rationale' and 'answer' found in response")
            continue
        response_json = candidates[-1]  # last match in the decoded text

        try:
            response_dict = json.loads(response_json)
            answer = response_dict['answer']

            # validation dictionary; is_correct_pred is 1 for a match, else 0
            val.append({'question': row['question'], 'response': response_json,
                        'difficulty': row['difficulty'],
                        'answer': answer,
                        'rationale': response_dict['rationale'],
                        'correct_solution': row['solution'],
                        'is_correct_pred': int(answer == row['solution'])})
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            # invalid JSON or a missing key: report and move on to the next row
            print("Exception occurred:", e)

    # metrics 
    n_responses = len(val)
    if n_responses:
        accuracy = sum(v['is_correct_pred'] for v in val) / n_responses
    else:
        accuracy = float('nan')  # nothing parsed: accuracy is undefined rather than an error

    perf_dict = {'responses': n_responses, 'accuracy': accuracy}
    if verbose:
        perf_dict['val_dict'] = val

    return perf_dict

In [5]:
# Utility functions (cont.) - instantiate base_model; load eval_dict
def reload_base_model(model_id = "microsoft/Phi-3-mini-4k-instruct", add_tokenizer = True): 
    """Load ``model_id`` as a 4-bit NF4-quantized causal LM and return the model.

    Args:
        model_id: Hugging Face model id to load.
        add_tokenizer: When True, a tokenizer for ``model_id`` is also
            instantiated (no BOS/EOS insertion, left padding).

    Returns:
        The loaded model only.

    NOTE(review): the tokenizer built for ``add_tokenizer`` is not returned or
    stored anywhere; existing callers bind the single return value to the
    model, so the return shape is left as is.
    """
    # quantization settings: 4-bit NF4 with double quantization, bfloat16 compute
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit = True,
        bnb_4bit_use_double_quant = True,
        bnb_4bit_quant_type = 'nf4',
        bnb_4bit_compute_dtype = torch.bfloat16,
    )

    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map = 'auto',
        quantization_config = quant_cfg,
        trust_remote_code = True,
    )

    if add_tokenizer == True: 
        # no bos token since the prompt templating function already pre-pends it
        AutoTokenizer.from_pretrained(
            model_id,
            add_eos_token = False,
            add_bos_token = False,
            padding_side = 'left',
        )

    return loaded_model

def load_eval_df(file_path = os.getcwd() + '/data/question.json', includes_math = False): # turn off math for now due to high failure rate
    """Load the evaluation questions and build one model prompt per question.

    Reads the base prompt (a list of chat messages) from
    ``<cwd>/data/base_prompt.json`` and the questions from ``file_path``.
    Two columns are added to the question table:

    * ``full_question``: the question text followed by one "code. text" line
      per entry in the row's ``options``.
    * ``llm_input``: ``parse_phi`` applied to the base prompt plus the full
      question as a final message.

    Args:
        file_path: Path of the questions JSON file. Note the default is
            computed from the working directory at the time this function is
            defined.
        includes_math: When false, rows whose ``type`` is 'math' are dropped.

    Returns:
        The resulting DataFrame.
    """
    # load base prompt (context managers so the file handles are closed)
    bp_file_path = os.getcwd() + '/data/base_prompt.json'
    with open(bp_file_path, encoding = 'utf-8') as bp_file:
        bp_json = json.load(bp_file)

    # load eval questions 
    with open(file_path, encoding = 'utf-8') as q_file:
        q_json = json.load(q_file)

    def build_full_question(row):
        # question, then each answer option on its own line as "<code>. <text>"
        option_lines = [o['code'] + '. ' + o['text'] for o in row['options']]
        return row['question'] + '\n' + '\n'.join(option_lines)

    def build_llm_input(row):
        # NOTE(review): the question is appended with role 'assistant', and parse_phi then
        # opens another '<|assistant|>' turn - confirm against base_prompt.json that this
        # is intended rather than role 'user'
        return parse_phi(bp_json + [{'role': 'assistant', 'content': row['full_question']}])

    # both columns are built the same way with or without math questions
    eval_df = pd.DataFrame(q_json).assign(
        full_question = lambda df: df.apply(build_full_question, axis = 1),
        llm_input = lambda df: df.apply(build_llm_input, axis = 1)
    )

    if not includes_math:
        eval_df = eval_df[eval_df['type'] != 'math']

    return eval_df

In [4]:
# Load bnb config, base model, and tokenizer
# quantization settings: 4-bit NF4 weights with double quantization, bfloat16 compute dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type = 'nf4',
    bnb_4bit_compute_dtype = torch.bfloat16
)

# model_id is a one-element list, hence the [0]
base_model = AutoModelForCausalLM.from_pretrained(
    model_id[0],
    device_map = 'auto', # NOTE(review): author reports this setting as the source of errors - cause not established here
    quantization_config = bnb_config,
    trust_remote_code = True
)

# Load tokenizer - remove bos token since my function already pre-pends
# (parse_phi starts every prompt with '<s>'); no automatic EOS, padding on the left
tokenizer = AutoTokenizer.from_pretrained(model_id[0],
                                         add_eos_token = False,
                                         add_bos_token = False,
                                         padding_side = 'left')

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Load self-attention layer
The goal is to load the self-attention layer, know where it corresponds to on the diagram, and be able to identify its inputs and outputs (along with the dims of each).

**Self-note:** remember to wrap forward passes in `with torch.no_grad():` so you don't track/accumulate grads...

In [7]:
# Re-instantiate model (rebinds base_model; reload_base_model returns only the model,
# so the tokenizer loaded in the earlier cell stays in use)
base_model = reload_base_model()

# Load eval dict (defaults: questions from ./data/question.json, rows with type 'math' dropped)
eval_df = load_eval_df()

`flash-attention` package not found, consider installing for better performance: No module named 'flash_attn'.
Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [28]:
# '<s>' is written by hand because the tokenizer was loaded with add_bos_token = False
sen = "<s>My dog is a good boy who likes to"

# tokenize sentence 
dog_tok = tokenizer(sen, return_tensors = 'pt').to(device)
print(f"Token dims: {dog_tok['input_ids'].squeeze().shape}")

# gen. embeddings / hidden states (ref. of "hidden states" changes over time) 
# inspection only, so run the lookup without autograd tracking
with torch.no_grad():
    dog_embed = base_model.model.embed_tokens(dog_tok['input_ids'])
print(f"Embedding dims: {dog_embed.squeeze().shape}")

#################### NOW ENTERING TRANSFORMERS ###########################

Token dims: torch.Size([11])
Embedding dims: torch.Size([11, 3072])


In [15]:
# get position id's again (o.w. will silently fail since model looks for dims)
# this comes from line ~1064 in https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L243
seq_length = dog_tok['input_ids'].shape[1]

# positions 0 .. seq_length-1 with a leading batch axis -> shape (1, seq_length)
position_ids = torch.arange(seq_length, dtype=torch.long, device = device).unsqueeze(0)

# basically, it's just tracking seq. length of inputs

In [115]:
# this is a single transformers block :) - we're going to go inside of it
one_block = base_model.model.layers[0]
attn = one_block.self_attn # shorthand for the block's attention module
# print(one_block)

# NOTE: apply_rotary_pos_emb is defined in a later cell of this notebook - run that cell first
with torch.no_grad():
    # layer_norm on hidden states 
    hidden_states = one_block.input_layernorm(dog_embed)
    
    # enter self_attn layer 
    # this is all of the self_attn stuff at once 
    # self_attn = one_block.self_attn(hidden_states, position_ids = position_ids)
    # print(self_attn[0].shape)

    # hidden_states_two = hidden_states + self_attn[0]

    # # enter MLP 
    # print(one_block.mlp(hidden_states_two).shape)
    # print(one_block.self_attn.head_dim, one_block.self_attn.hidden_size)

    # o_proj is a linear layer that seems to prep. for future transforms; also injects more weights that can 
    # be trained / can hold meaning 
    # o_proj_output = one_block.self_attn.o_proj(dog_embed)
    # print(o_proj_output.shape) # 11 x 3072 
    
    # qkv proj - these are now stacked; like a mega-tensor 
    qkv = attn.qkv_proj(hidden_states)
    print(qkv.shape) 

    # call forward on the attn module 
    # self_attn = one_block.self_attn(hidden_states, position_ids = position_ids)
    bsz, q_len, _ = hidden_states.size()
    print(bsz, q_len)

    # widths of the query slice and of each key / value slice inside the stacked projection
    query_pos = attn.num_heads * attn.head_dim
    kv_width = attn.num_key_value_heads * attn.head_dim
    print(query_pos)

    query_states = qkv[..., :query_pos] # ~1/3 of qkv when num_key_value_heads == num_heads
    key_states = qkv[..., query_pos : query_pos + kv_width]
    value_states = qkv[..., query_pos + kv_width :]
    print(query_states.shape, key_states.shape, value_states.shape)

    # re-shape each (head_dim is D/H) into (batch, heads, seq, head_dim)
    # keys / values are sliced with num_key_value_heads, so they must be split with it too
    query_states = query_states.view(bsz, q_len, attn.num_heads, attn.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, attn.num_key_value_heads, attn.head_dim).transpose(1, 2)
    # print(query_states.shape)

    kv_seq_len = key_states.shape[-2]
    print(kv_seq_len)

    # now, apply rotary embeddings (return to figure out what is going on here) 
    cos, sin = attn.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
    print(query_states.shape)

    # if there are fewer key/value heads than query heads, repeat each one so the head axes match
    # (a no-op when the two head counts are equal)
    n_rep = attn.num_heads // attn.num_key_value_heads
    key_states = key_states.repeat_interleave(n_rep, dim=1)
    value_states = value_states.repeat_interleave(n_rep, dim=1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(attn.head_dim)
    print(attn_weights.shape) # one 11 x 11 score matrix per attention head - the 32 is num_heads, not the number of blocks
    # mlp portion - gate up proj 
    # gate_up_proj_output = one_block.mlp.gate_up_proj(qkv_proj_output)
    # print(gate_up_proj_output.shape)

    # attention mask piece helps ensure that things only pay attention to what occurs before; ow everything "pays attention" to everything 
    # this is a way to force boundaries 

torch.Size([1, 11, 9216])
1 11
3072
torch.Size([1, 11, 3072]) torch.Size([1, 11, 3072]) torch.Size([1, 11, 3072])
11
torch.Size([1, 32, 11, 96])
torch.Size([1, 32, 11, 11])


In [105]:
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the half that moves to the front."""
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate the query and key tensors by the given rotary position embedding.

    Each of ``q`` and ``k`` is mapped to ``t * cos + rotate_half(t) * sin``
    after ``cos`` and ``sin`` have been given one extra axis for broadcasting.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which ``cos`` and ``sin`` are unsqueezed so they broadcast
            against ``q`` and ``k``. With ``cos``/``sin`` of shape
            [batch_size, seq_len, head_dim]: use 1 when ``q`` and ``k`` are
            [batch_size, heads, seq_len, head_dim], and 2 when they are
            [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors, in that order.
    """
    # add the broadcast axis once and reuse it for both tensors
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


In [101]:
# display the `sin` tensor left in scope by the rotary_emb call in the attention walkthrough cell
sin

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 8.4131e-01,  7.3486e-01,  6.2988e-01,  ...,  1.7786e-04,
           1.4675e-04,  1.2118e-04],
         [ 9.0918e-01,  9.9658e-01,  9.7852e-01,  ...,  3.5572e-04,
           2.9349e-04,  2.4235e-04],
         ...,
         [ 9.8926e-01,  3.1470e-01, -7.3975e-01,  ...,  1.4229e-03,
           1.1740e-03,  9.6941e-04],
         [ 4.1211e-01,  9.1113e-01, -1.5100e-01,  ...,  1.6003e-03,
           1.3208e-03,  1.0900e-03],
         [-5.4395e-01,  9.2090e-01,  5.0537e-01,  ...,  1.7786e-03,
           1.4677e-03,  1.2112e-03]]], device='cuda:0', dtype=torch.float16)

In [19]:
# track layer defs 
layer_names = []
for idx, (name, param) in enumerate(base_model.named_parameters()): 

    # store layer names (for testing) 
    layer_names.append({'idx': idx, 'name': name, 'dims': param.shape})

# view layers 
pd.DataFrame(layer_names)

Unnamed: 0,idx,name,dims
0,0,model.embed_tokens.weight,"(32064, 3072)"
1,1,model.layers.0.self_attn.o_proj.weight,"(4718592, 1)"
2,2,model.layers.0.self_attn.qkv_proj.weight,"(14155776, 1)"
3,3,model.layers.0.mlp.gate_up_proj.weight,"(25165824, 1)"
4,4,model.layers.0.mlp.down_proj.weight,"(12582912, 1)"
...,...,...,...
190,190,model.layers.31.mlp.down_proj.weight,"(12582912, 1)"
191,191,model.layers.31.input_layernorm.weight,"(3072,)"
192,192,model.layers.31.post_attention_layernorm.weight,"(3072,)"
193,193,model.norm.weight,"(3072,)"
