# Layer exploration (continued) 
This notebook explores the model's layers one at a time so that we are comfortable inspecting and modifying them by hand.

In [None]:
# Run on 1 x RTX A6000
!pip install -q wandb -U
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U datasets scipy ipywidgets matplotlib
!pip install plotly.express
!pip install scikit-learn
!pip install -U flash-attn --no-build-isolation
!pip install pyyaml
!pip install pyarrow
!pip install termcolor
!pip install pandas
!pip install tqdm
!pip install python-dotenv
# If distutils error, https://stackoverflow.com/a/78050586

In [1]:
### Load libraries
# import flash_attn
# from dotenv import main
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import jinja2
import os
import sys
import re
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # for quantization
import plotly
from transformers import pipeline, set_seed
from tqdm import tqdm

# auth for gated repos (like llama) - gen token here: https://huggingface.co/settings/tokens
from huggingface_hub import login, notebook_login

# `notebook_login` does not take a token: a value passed positionally lands in its `new_session`
# parameter, so the original call never used HF_TOKEN. Hand the token to `login` when it is set;
# otherwise fall back to the interactive widget, re-using a cached login if one exists.
hf_token = os.getenv('HF_TOKEN')
if hf_token:
    login(token = hf_token)
else:
    notebook_login(new_session = False)

# model ids
model_id = ["microsoft/Phi-3-mini-4k-instruct"]

# Set seed for reproducibility
torch.random.manual_seed(0)

# Increase max width of pd df columns
pd.set_option('max_colwidth', 300)

# Instantiate jinja environment - used later for icl prompting
environment = jinja2.Environment()

# Device that tokenized inputs are moved to by the cells and helpers below
device = 'cuda'

# requirements.txt
# !pip3 freeze > requirements.txt

User is already logged in.


In [2]:
# Define utility functions 
# mem. monitoring! 
def check_memory():
    """Print the allocated, reserved and total memory of CUDA device 0, in GB."""
    def _to_gb(n_bytes):
        # bytes -> GB using 1024-based units, as three successive divisions
        return n_bytes / 1024 / 1024 / 1024

    print("Allocated: %fGB" % _to_gb(torch.cuda.memory_allocated(0)))
    print("Reserved: %fGB" % _to_gb(torch.cuda.memory_reserved(0)))
    print("Total: %fGB" % _to_gb(torch.cuda.get_device_properties(0).total_memory))

# notification/text-to-speech
def text_to_speech(text):
    """
    Speak `text` aloud using the platform's command-line speech tool.

    Uses `say` on macOS and `espeak` on Linux; on any other platform a notice is printed instead.
    The text is passed as a single process argument rather than being interpolated into a shell
    command, so quotes, `$`, backticks or `;` in `text` are spoken literally and cannot run commands.
    This is best-effort: a missing speech binary is reported, not raised.
    """
    import subprocess  # local import keeps this helper self-contained

    if sys.platform == 'darwin':
        command = 'say'
    elif sys.platform.startswith('linux'):
        command = 'espeak'
    else:
        print("Text-to-speech is not supported on this platform.")
        return

    try:
        # check = False: a non-zero exit status is ignored, as it was with os.system
        subprocess.run([command, text], check = False)
    except FileNotFoundError:
        print(f"Text-to-speech command '{command}' is not installed.")

# parse + template phi inputs
def parse_phi(messages: list[dict], append_response_start = True) -> str:
    """
    Render a multi-turn conversation as one Phi-3 chat-formatted prompt string.

    Params
        messages: list of dicts, each with a 'role' and a 'content' key.
        append_response_start: when True, the prompt ends with an opening assistant tag
            so that the model continues with the assistant's reply.

    Returns
        The prompt string. It starts with the '<s>' token, followed by one
        role tag / content / end tag group per message, joined by newlines.

    Example output (append_response_start = True):
    # <s><|system|>
    # You are a helpful AI assistant.<|end|>
    # <|user|>
    # Guess my dog's name!<|end|>
    # <|assistant|>
    """
    turns = []
    for message in messages:
        turns.append(f"<|{message['role']}|>\n{message['content']}<|end|>")

    rendered = '<s>' + '\n'.join(turns)

    if not append_response_start:
        return rendered

    return rendered + "\n<|assistant|>"

# print(parse_phi([
#     {'role': 'system', 'content': 'Hello'}, {'role': 'user', 'content': '1+1?'}, {'role': 'assistant', 'content': '2'}
# ], False))

# model eval
def eval_model(model, tokenizer, prompt, max_new_tokens = 1):
    """
    Greedily generate a continuation of `prompt` and return the decoded sequence.

    Params
        model: model exposing `.eval()` and `.generate()`.
        tokenizer: tokenizer used both to encode `prompt` and to decode the generated ids.
        prompt: fully formatted prompt string.
        max_new_tokens: maximum number of tokens to generate. Defaults to 1, the value that was
            previously hard-coded. NOTE(review): one token cannot hold the JSON answer that
            `get_model_performance` parses, so callers probably want a larger value - confirm.

    Returns
        The first decoded sequence from `tokenizer.batch_decode` (special tokens are not skipped).

    Relies on the module-level `device` for placing the tokenized inputs.
    """
    tokens = tokenizer(prompt, return_tensors = 'pt').to(device)
    model.eval()
    with torch.no_grad():
        res = model.generate(
            **tokens,
            max_new_tokens = max_new_tokens,
            # Greedy decoding. `temperature`/`top_p` were dropped because they only apply
            # when do_sample = True.
            do_sample = False,
            # The original list held this same id twice (once directly, once via the token string).
            eos_token_id = [tokenizer.eos_token_id]
        )
    return tokenizer.batch_decode(res)[0]

# assess model perf
def get_model_performance(eval_df, base_model, tokenizer, verbose = False):
    """
    Run the model over every row of `eval_df` and score its JSON answers.

    Params
        eval_df: DataFrame with the columns read below: 'llm_input', 'question',
            'difficulty' and 'solution'.
        base_model, tokenizer: forwarded to `eval_model`.
        verbose: when True, the per-question records are included in the result.

    Returns
        dict with 'responses' (number of responses that could be parsed) and 'accuracy'
        (share of parsed responses whose 'answer' equals the row's 'solution'); plus
        'val_dict' (list of per-question dicts) when `verbose` is True.
        'accuracy' is NaN when no response could be parsed.

    Rows whose response contains no parsable JSON object are reported and skipped.
    """
    val = []
    for idx, row in tqdm(eval_df.iterrows(), total = len(eval_df)):
        response = eval_model(model = base_model, tokenizer = tokenizer, prompt = row['llm_input'])

        # Everything that can fail on a malformed output lives inside the try, including the
        # extraction itself (previously an empty match list raised IndexError outside of it).
        try:
            candidates = re.findall(r'(?=.*"rationale")(?=.*"answer"){.*?}', response)
            if not candidates:
                raise ValueError('no JSON object with "rationale" and "answer" found in response')
            response_json = candidates[-1] # last match in the decoded text

            response_dict = json.loads(response_json)

            # validate model preds against correct answer
            is_correct_pred = int(response_dict['answer'] == row['solution'])

            val.append({'question': row['question'], 'response': response_json,
                        'difficulty': row['difficulty'],
                        'answer': response_dict['answer'],
                        'rationale': response_dict['rationale'],
                        'correct_solution': row['solution'],
                        'is_correct_pred': is_correct_pred})

        except Exception as e:
            # deliberate best-effort: one bad response must not abort the whole evaluation
            print("Exception occurred:", e)

    val_df = pd.DataFrame(val)

    # metrics - guard against an empty result (no 'is_correct_pred' column, division by zero)
    n_responses = len(val_df)
    if n_responses:
        accuracy = sum(val_df['is_correct_pred'])/n_responses
    else:
        accuracy = float('nan')

    perf_dict = {'responses': n_responses, 'accuracy': accuracy}
    if verbose:
        perf_dict['val_dict'] = val

    return perf_dict

In [3]:
# Utility functions (cont.) - instantiate base_model; load eval_dict
def reload_base_model(model_id = "microsoft/Phi-3-mini-4k-instruct", add_tokenizer = True, quantize = False):
    """
    Load a fresh copy of the causal LM for `model_id`.

    Params
        model_id: Hugging Face model id passed to `from_pretrained`.
        add_tokenizer: kept for backward compatibility only. The previous implementation
            loaded a tokenizer when this was True but never returned it, so the result never
            depended on this flag; callers load their own tokenizer.
        quantize: when True, load with 4-bit NF4 quantization (double quantization,
            bfloat16 compute). Defaults to False, matching the previous behaviour where the
            quantization config was built but not passed to the model.

    Returns
        The loaded model.
    """
    load_kwargs = {
        # NOTE: the original author flagged device_map as a source of errors in their setup
        'device_map': 'auto',
        'trust_remote_code': True,
    }

    if quantize:
        load_kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit = True,
            bnb_4bit_use_double_quant = True,
            bnb_4bit_quant_type = 'nf4',
            bnb_4bit_compute_dtype = torch.bfloat16
        )

    base_model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)

    return base_model

def load_eval_df(file_path = os.getcwd() + '/data/question.json', includes_math = False): # turn off math for now due to high failure rate
    """
    Load the evaluation questions and build the prompt for each one.

    Params
        file_path: path to the questions JSON. Each record is expected to provide 'question',
            'options' (items with 'code' and 'text') and 'type', as read below.
        includes_math: when False (default), rows whose 'type' is 'math' are dropped.

    Returns
        DataFrame of the questions with two added columns:
        'full_question' (question text followed by one "<code>. <text>" line per option) and
        'llm_input' (base prompt plus the question, rendered with `parse_phi`).

    The base prompt is always read from '<cwd>/data/base_prompt.json'.
    """
    # load base prompt
    bp_file_path = os.getcwd() + '/data/base_prompt.json'
    with open(bp_file_path, encoding = 'utf-8') as bp_file:
        bp_json = json.load(bp_file)

    # load eval questions
    with open(file_path, encoding = 'utf-8') as q_file:
        q_json = json.load(q_file)

    def _full_question(row):
        options = '\n'.join([o['code'] + '. ' + o['text'] for o in row['options']])
        return row['question'] + '\n' + options

    def _llm_input(row):
        # NOTE(review): the question is appended with role 'assistant' and parse_phi then opens
        # another assistant turn - 'user' may have been intended. Left unchanged; confirm.
        return parse_phi(bp_json + [{'role': 'assistant', 'content': row['full_question']}])

    eval_df = pd.DataFrame(q_json).assign(
        full_question = lambda df: df.apply(_full_question, axis = 1),
        llm_input = lambda df: df.apply(_llm_input, axis = 1)
    )

    if not includes_math:
        eval_df = eval_df[eval_df['type'] != 'math']

    return eval_df

In [4]:
# # Load bnb config, base model, and tokenizer
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit = True,
#     bnb_4bit_use_double_quant = True,
#     bnb_4bit_quant_type = 'nf4',
#     bnb_4bit_compute_dtype = torch.bfloat16
# )

# base_model = AutoModelForCausalLM.from_pretrained(
#     model_id[0],
#     device_map = 'auto', # not sure what's up with device_map, but this is what causes errors
#     quantization_config = bnb_config,
#     trust_remote_code = True
# )

# # Load tokenizer - remove bos token since my function already pre-pends
# tokenizer = AutoTokenizer.from_pretrained(model_id[0],
#                                          add_eos_token = False,
#                                          add_bos_token = False,
#                                          padding_side = 'left')

# Breaking apart phi-3 (+ checking if outputs flow through analogously) 
Here we recreate Phi-3 layer by layer, breaking it down to the most granular level possible so that intermediate outputs can be tracked and modified. (The self_attn repro code has been taken out for now, but it can be recovered via git history.) To check that the recreation is equivalent, we run a forward pass
with the original (not broken apart) Phi-3 model as a baseline and capture its intermediate outputs with forward hooks.

In [6]:
# Fresh copy of the base model (default model id)
base_model = reload_base_model()

# Evaluation questions (math questions are excluded by default)
eval_df = load_eval_df()

# Tokenizer: BOS/EOS are not added automatically because the prompts already carry '<s>'
tokenizer = AutoTokenizer.from_pretrained(
    model_id[0],
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
max_tokens = 128
from py_helpers.phi3 import _prepare_4d_causal_attention_mask

# Manual greedy generation that re-implements the decoder-layer forward pass by hand.
# Only the first decoder layer is run (layers[0:1]), so the generated text is not expected to
# match the full model; the point is to exercise each sub-module explicitly.
# The "line NNN" notes refer to transformers' modeling_phi3.py at the time of writing.
with torch.no_grad():
    prompt = '<s>I am a dog and I like to eat meat! My favorite'
    base_model.eval()
    generated_tokens = 0
    input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']
    input_ids_1 = input_ids # reference to the prompt-only ids

    output_dict = {}

    # token ids that end generation
    stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]

    while True:
        N = input_ids.shape[1]

        # get embeddings
        embeds_output = base_model.model.embed_tokens(input_ids)
        hidden_state = embeds_output

        position_ids = torch.arange(0, N, dtype=torch.long, device=device).unsqueeze(0).view(-1, N) # Create position IDs
        attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = base_model.model.config.sliding_window) # Make an attention mask to hide right context

        ##### TRANSFORMER BLOCK #####

        # Iterate over the sliced layers directly: indexing `layers[idx]` with the enumerate
        # position would pick the wrong layers as soon as the slice does not start at 0.
        for decoder_layer in base_model.model.layers[0:1]:

            # store residuals
            residual = hidden_state # line 851
            hidden_states = decoder_layer.input_layernorm(hidden_state) # layer norm on hidden states - line 853 (https://github.com/huggingface/transformers/blob/main/src/transformers/models/phi3/modeling_phi3.py#L810)

            # now, self attn - line 856
            attn_outputs, self_attn_weights, present_key_value = decoder_layer.self_attn(
                hidden_states = hidden_states,
                attention_mask = attention_mask,
                position_ids = position_ids,
                output_attentions = True # asks the attention module to also return its attention weights
                # past_key_value / use_cache are optional (caching) and not supplied here
            )

            # line 865
            hidden_states = residual + decoder_layer.resid_attn_dropout(attn_outputs)

            residual = hidden_states # line 867
            hidden_states = decoder_layer.post_attention_layernorm(hidden_states) # line 868

            hidden_states = decoder_layer.mlp(hidden_states)
            hidden_states = residual + decoder_layer.resid_mlp_dropout(hidden_states)

            outputs = (hidden_states,) # layer output; kept under this name for comparison against the hooks

            # The next layer consumes the un-normalised layer output.
            hidden_state = hidden_states

        # The final norm is applied once, after the last layer. It used to sit inside the loop,
        # which would have normalised between layers if more than one layer were run.
        hidden_state = base_model.model.norm(hidden_state)

        # run LM head
        logits = base_model.lm_head(hidden_state) # remember you need to use the version w/ causal LM

        # Greedy pick for the last position. Softmax is monotonic, so argmax over raw logits
        # selects the same token, and explicit indexing avoids relying on squeeze().
        output_token = logits[0, -1].argmax()
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        # Break while loop if EOS or generation > max tokens
        generated_tokens = generated_tokens + 1
        if output_token.item() in stop_token_ids or generated_tokens >= max_tokens:
            break

    final_output = tokenizer.decode(input_ids.squeeze())

    print(final_output)
   

In [7]:
# Prompt and token ids shared by the hook-based baseline and the manual forward pass below
prompt = '<s>I am a dog and I like to eat meat! My favorite'
encoded_prompt = tokenizer(prompt, return_tensors = 'pt').to(device)
input_ids = encoded_prompt['input_ids']

In [162]:
def getOutputs(name):
    """Build a forward hook that stores the hooked module's output in `layer_outputs[name]`."""
    def _store_output(model, input, output):
        # forward-hook signature is (module, inputs, output); only the output is recorded
        layer_outputs[name] = output
    return _store_output

layer_outputs = {}

# Modules whose outputs are captured during one forward pass of the unmodified model.
first_layer = base_model.model.layers[0]
hook_targets = {
    'embed': base_model.model.embed_tokens, # embed layer
    'trans_one': first_layer, # first transformers block
    'sa_layer_norm': first_layer.input_layernorm, # layernorm applied to hidden states before sa
    'self_attn': first_layer.self_attn, # sub-component of the first block
    'resid_attn_dropout': first_layer.resid_attn_dropout, # dropout that happens after sa
    'mlp': first_layer.mlp, # mlp
    'final_output': base_model.model.layers[-1], # last transformers block (was hard-coded as index 31)
}

# add hooks
hooks = [module.register_forward_hook(getOutputs(name)) for name, module in hook_targets.items()]

try:
    # forward pass
    with torch.no_grad():
        base_model(input_ids)
finally:
    # Always detach the hooks, even if the forward pass raises; otherwise they stay
    # registered on the model and keep overwriting layer_outputs on later calls.
    for hook in hooks:
        hook.remove()

print(layer_outputs['trans_one'])

(tensor([[[-0.0636,  0.1626,  0.0082,  ...,  0.1023, -0.0550, -0.0476],
         [ 0.0550,  0.0246,  0.0246,  ...,  0.0624, -0.0468, -0.0152],
         [ 0.0213,  0.0441, -0.0356,  ...,  0.0537,  0.0057, -0.1038],
         ...,
         [-0.0206, -0.0517, -0.0347,  ...,  0.0051, -0.0113,  0.0405],
         [ 0.0935,  0.0172, -0.0123,  ...,  0.0196,  0.0150, -0.0505],
         [-0.0741, -0.0551, -0.0603,  ...,  0.0099,  0.0236, -0.0074]]],
       device='cuda:0'), DynamicCache())


In [166]:
from py_helpers.phi3 import _prepare_4d_causal_attention_mask
from py_helpers.phi3 import apply_rotary_pos_emb

# Manual forward pass through every decoder layer with self-attention written out step by step.
# Intermediate names (hidden_states_one, attn_output, mlp, ...) are kept so they can be compared
# against the hook captures in `layer_outputs`.
with torch.no_grad():

    embeds_output = base_model.model.embed_tokens(input_ids)
    hidden_state = embeds_output

    # Shapes are the same for every layer, so compute them once, outside the loop.
    B, N, D = embeds_output.shape # B is batch, N is tok. length, D is embedding dimensions
    H = base_model.model.config.num_attention_heads # number of sa heads, read from the config instead of hard-coding 32
    Dh = D // H # per-head dimension

    position_ids = torch.arange(0, N, dtype=torch.long, device=device).unsqueeze(0).view(-1, N) # Create position IDs
    attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = base_model.model.config.sliding_window) # Make an attention mask to hide right context

    # print(attention_mask, torch.where(attention_mask != 0, torch.tensor(1), attention_mask)) # this makes it easier to see the diagonal

    ##### TRANSFORMER BLOCK #####
    for transformer_block in base_model.model.layers:

        residual = hidden_state
        hidden_states_one = transformer_block.input_layernorm(hidden_state)

        # ---- self attention, broken out ----
        sa = transformer_block.self_attn
        qkv = sa.qkv_proj(hidden_states_one)

        # Split the fused projection into query, key, value. Slicing in equal D-sized chunks
        # assumes keys/values use as many heads as queries - TODO confirm for other models.
        query_states = qkv[..., :D]
        key_states = qkv[..., D: 2*D]
        value_states = qkv[..., 2*D:]

        # distribute across the heads: (B, N, D) -> (B, H, N, Dh)
        query_states = query_states.view(B, N, H, Dh).transpose(1, 2)
        key_states = key_states.view(B, N, H, Dh).transpose(1, 2)
        value_states = value_states.view(B, N, H, Dh).transpose(1, 2)

        # rotary embeddings are taken from the model rather than re-created
        cos, sin = sa.rotary_emb(value_states, position_ids, seq_len = N) # prep. for rotation
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # scaled dot-product scores: (B, H, N, N) - relation of each token to each token
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(Dh)

        # masked positions get a large negative score and therefore ~zero weight after softmax
        attn_weights = attn_weights + attention_mask

        # softmax in float32, then back to the value dtype
        attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(value_states.dtype)

        # weights x values
        attn_output = torch.matmul(attn_weights, value_states)

        # merge the heads back: (B, H, N, Dh) -> (B, N, D)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(B, N, D)

        # output projection
        attn_output = sa.o_proj(attn_output)
        # ---- end self attention ----

        hidden_states_two = residual + attn_output # the residual dropout layer is skipped here

        residual = hidden_states_two # line 867
        hidden_states_three = transformer_block.post_attention_layernorm(hidden_states_two) # line 868

        mlp = transformer_block.mlp(hidden_states_three)
        hidden_states_four = residual + mlp # dropout doesn't do anything right now

        hidden_state = hidden_states_four

    # final norm + LM head, applied once after the last layer
    hidden_state = base_model.model.norm(hidden_state)
    logits = base_model.lm_head(hidden_state)
    logits = logits.float()

print(logits)

tensor([[[ 1.6646,  1.0017, -0.4470,  ...,  0.0000,  0.0000,  0.0000],
         [ 4.4798,  9.5102,  6.6458,  ...,  0.0000,  0.0000,  0.0000],
         [ 6.7997, 11.5268, 12.2446,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 6.5071,  6.5425, 13.4693,  ...,  0.0000,  0.0000,  0.0000],
         [10.5837,  6.0312,  9.2012,  ...,  0.0000,  0.0000,  0.0000],
         [10.3935,  7.4350, 11.9247,  ...,  0.0000,  0.0000,  0.0000]]],
       device='cuda:0')


In [190]:
from py_helpers.phi3 import _prepare_4d_causal_attention_mask

@torch.no_grad()
def generate_multiple_outputs(model, tokenizer, prompt = '<s>I am a dog and I like to eat meat! My favorite', max_tokens = 128, device = 'cuda',
                              n_modified_layers = 30, off_diagonal_scale = 0.5):
    """
    Greedy generation with a hand-written forward pass whose attention scores are perturbed.

    In the first `n_modified_layers` decoder layers, every pre-softmax attention score that is
    not on the diagonal (a token attending to itself) is multiplied by `off_diagonal_scale`.
    No key/value cache is used: the whole sequence is re-run for every generated token.

    Params
        model: causal LM exposing `model.embed_tokens`, `model.layers`, `model.norm`,
            `model.config` and `lm_head`, as used below.
        tokenizer: tokenizer used to encode `prompt` and decode the result.
        prompt: a single prompt string (the code assumes batch size 1).
        max_tokens: maximum number of tokens to generate.
        device: device the token ids and helper tensors are placed on.
        n_modified_layers: number of leading layers whose scores are scaled (default 30,
            the previously hard-coded cutoff).
        off_diagonal_scale: multiplier for off-diagonal scores (default 0.5, the previously
            hard-coded value).

    Returns
        The decoded prompt plus generated tokens. Generation stops at the EOS token,
        at "<|end|>", or after `max_tokens` tokens.
    """
    # Imported here so the function does not depend on another cell having been run first.
    from py_helpers.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb

    model.eval()
    H = model.model.config.num_attention_heads # number of sa heads, read from the config instead of hard-coding 32
    stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")]

    generated_tokens = 0
    input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']

    while True:
        # Get embeddings
        embeds_output = model.model.embed_tokens(input_ids)
        hidden_state = embeds_output
        B, N, D = embeds_output.shape # B is batch, N is tok. length, D is embedding dimensions
        Dh = D // H # per-head dimension

        # Get some parameters needed for transformers layers
        position_ids = torch.arange(0, N, dtype=torch.long, device=device).unsqueeze(0).view(-1, N) # Create position IDs
        attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = model.model.config.sliding_window) # Make an attention mask to hide right context

        # Score multiplier: 1.0 on the diagonal, `off_diagonal_scale` elsewhere. It depends only
        # on the current length, so it is built once per step instead of once per layer.
        # NOTE(review): this scales pre-softmax scores. A negative score multiplied by 0.5 gets
        # larger, so this is not a uniform down-weighting of past tokens - confirm the intent.
        score_scale = torch.full((B, H, N, N), float(off_diagonal_scale), device = device)
        diag_indices = torch.arange(N, device = device)
        score_scale[:, :, diag_indices, diag_indices] = 1.0

        # Execute transformers layers
        for i, transformer_block in enumerate(model.model.layers):
            residual = hidden_state
            hidden_states_one = transformer_block.input_layernorm(hidden_state)

            # ---- self attention, broken out ----
            sa = transformer_block.self_attn
            qkv = sa.qkv_proj(hidden_states_one)

            # Split the fused projection into query, key, value. Slicing in equal D-sized chunks
            # assumes keys/values use as many heads as queries - TODO confirm for other models.
            query_states = qkv[..., :D]
            key_states = qkv[..., D: 2*D]
            value_states = qkv[..., 2*D:]

            # distribute across the heads: (B, N, D) -> (B, H, N, Dh)
            query_states = query_states.view(B, N, H, Dh).transpose(1, 2)
            key_states = key_states.view(B, N, H, Dh).transpose(1, 2)
            value_states = value_states.view(B, N, H, Dh).transpose(1, 2)

            # rotary embeddings are taken from the model rather than re-created
            cos, sin = sa.rotary_emb(value_states, position_ids, seq_len = N) # prep. for rotation
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

            # scaled dot-product scores: (B, H, N, N)
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(Dh)

            # Scale first, then mask. For visible positions the mask adds 0, so the result equals
            # the previous (mask-then-scale) order; masked positions still receive the full
            # negative mask value, so a scale of 0 cannot un-mask future tokens.
            if i < n_modified_layers:
                attn_weights = attn_weights * score_scale

            # masked positions get a large negative score and therefore ~zero weight after softmax
            attn_weights = attn_weights + attention_mask

            # softmax in float32, then back to the value dtype
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(value_states.dtype)

            # weights x values
            attn_output = torch.matmul(attn_weights, value_states)

            # merge the heads back: (B, H, N, Dh) -> (B, N, D)
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.reshape(B, N, D)

            # output projection
            attn_output = sa.o_proj(attn_output)
            # ---- end self attention ----

            hidden_states_two = residual + attn_output

            residual = hidden_states_two
            hidden_states_three = transformer_block.post_attention_layernorm(hidden_states_two)

            mlp = transformer_block.mlp(hidden_states_three)
            hidden_states_four = residual + mlp # dropout doesn't do anything right now

            hidden_state = hidden_states_four

        # RMS norm the final transformer layer output
        hidden_state = model.model.norm(hidden_state)

        # Run LM head
        logits = model.lm_head(hidden_state)

        # Greedy pick for the last position. Softmax is monotonic, so argmax over raw logits
        # selects the same token, and explicit indexing avoids relying on squeeze().
        output_token = logits[0, -1].argmax()
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        # Break while loop if EOS or generation > max tokens
        generated_tokens = generated_tokens + 1
        if output_token.item() in stop_token_ids or generated_tokens >= max_tokens:
            break

    final_output = tokenizer.decode(input_ids.squeeze())
    return final_output

# Test
test_prompt = '<s>I am a dog and I like to eat meat! My favorite'
# Use function - pass the prompt defined here rather than the global `prompt` left over from an earlier cell
print('my_model + manual generation', generate_multiple_outputs(base_model, tokenizer, prompt = test_prompt))

my_model + manual generation <s> I am a dog and I like to eat meat! My favorite food is is.

**response: The sentence is a simple sentence that is is

- [Response]:

- I am a dog. I am a dog. I eat meat.

- I am a dog. I eat meat.

- I am a canine. I eat meat.

- I am a canine. I eat meat.

- I am a canine. I eat meat.


- I am a canine. I eat meat.


- I am a dog. I eat meat.


- I am a dog


In [157]:
# Stand-alone re-implementation of the first decoder layer's self-attention, for comparison
# against the hook capture in layer_outputs['self_attn'].
# NOTE(review): this cell relies on names created elsewhere in the kernel session:
#   - embeds_output, position_ids, attention_mask, apply_rotary_pos_emb come from earlier cells;
#   - layer_one_sa_input is not defined anywhere in the visible notebook. Presumably it is the
#     first layer's input_layernorm output - confirm and define it explicitly.
# It is not wrapped in torch.no_grad(), which is why the displayed result carries a grad_fn.
B, N, D = embeds_output.shape # B is batch, N is tok. length, D is embedding dimensions
H = 32 # number of sa heads (hard-coded; TODO read from the model config)
Dh = int(D/H) # per-head dimension

sa =  base_model.model.layers[0].self_attn # first layer's self-attention module only
qkv = sa.qkv_proj(layer_one_sa_input)

# splitting the fused projection into query, key, value matrices (equal D-sized chunks)
query_states = qkv[..., :D]
key_states = qkv[..., D: 2*D]
value_states = qkv[..., 2*D:]

# check dims - should all be same :) 
# print(query_states.shape, key_states.shape, value_states.shape)

# re-shaping to distribute across the heads: (B, N, D) -> (B, H, N, Dh)
query_states = query_states.view(B, N, H, Dh).transpose(1, 2)
key_states = key_states.view(B, N, H, Dh).transpose(1, 2)
value_states = value_states.view(B, N, H, Dh).transpose(1, 2)

# check dims - should all be same (yet again) :) 
# print(query_states.shape, key_states.shape, value_states.shape) # covers every token, with the smaller per-head dimension

# note: rotary embeddings are taken from the model rather than re-created
cos, sin = sa.rotary_emb(value_states, position_ids, seq_len = N) # prep. for rotation 

# now, apply rotation to queries and keys
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

# check dims of query_states and key_states - should be same
# print(query_states.shape, key_states.shape) # nice! 

# calculate attention scores: (B, H, N, N) - relation of each token to each token, scaled by sqrt(Dh)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(Dh)

# add in attn. mask 
attn_weights = attn_weights + attention_mask # large negative values from the mask become ~zero weight via softmax

# softmax in float32, then cast back to the value dtype
attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(value_states.dtype)

# attention weights x values
attn_output = torch.matmul(attn_weights, value_states)

###### re-concatenate the per-head outputs ######
# transpose: (B, H, N, Dh) -> (B, N, H, Dh)
attn_output = attn_output.transpose(1, 2).contiguous()

# reshape: (B, N, H, Dh) -> (B, N, D)
attn_output = attn_output.reshape(B, N, D)

# now, apply the linear transform (o_proj) 
attn_output = sa.o_proj(attn_output)
####################


# display the result; compare with the first element of layer_outputs['self_attn']
attn_output

tensor([[[-0.0145,  0.0132, -0.0053,  ...,  0.0355,  0.0086, -0.0197],
         [ 0.0011, -0.0013,  0.0006,  ...,  0.0078,  0.0023, -0.0082],
         [ 0.0158, -0.0120,  0.0128,  ...,  0.0003, -0.0097,  0.0045],
         ...,
         [-0.0010, -0.0067,  0.0182,  ...,  0.0031,  0.0117, -0.0083],
         [-0.0128, -0.0039,  0.0122,  ..., -0.0006,  0.0170, -0.0039],
         [-0.0162, -0.0064,  0.0052,  ...,  0.0034,  0.0092, -0.0065]]],
       device='cuda:0', grad_fn=<UnsafeViewBackward0>)

In [153]:
# Baseline: output captured by the forward hook on the first layer's self_attn module
layer_outputs['self_attn']

(tensor([[[-0.0145,  0.0132, -0.0053,  ...,  0.0355,  0.0086, -0.0197],
          [ 0.0011, -0.0013,  0.0006,  ...,  0.0078,  0.0023, -0.0082],
          [ 0.0158, -0.0120,  0.0128,  ...,  0.0003, -0.0097,  0.0045],
          ...,
          [-0.0010, -0.0067,  0.0182,  ...,  0.0031,  0.0117, -0.0083],
          [-0.0128, -0.0039,  0.0122,  ..., -0.0006,  0.0170, -0.0039],
          [-0.0162, -0.0064,  0.0052,  ...,  0.0034,  0.0092, -0.0065]]],
        device='cuda:0'),
 None,
 None)

In [142]:
# Same expression as the cell above, left over from an earlier run (In [142])
layer_outputs['self_attn']

(tensor([[[-0.0145,  0.0132, -0.0053,  ...,  0.0355,  0.0086, -0.0197],
          [ 0.0011, -0.0013,  0.0006,  ...,  0.0078,  0.0023, -0.0082],
          [ 0.0158, -0.0120,  0.0128,  ...,  0.0003, -0.0097,  0.0045],
          ...,
          [-0.0010, -0.0067,  0.0182,  ...,  0.0031,  0.0117, -0.0083],
          [-0.0128, -0.0039,  0.0122,  ..., -0.0006,  0.0170, -0.0039],
          [-0.0162, -0.0064,  0.0052,  ...,  0.0034,  0.0092, -0.0065]]],
        device='cuda:0'),
 None,
 None)

In [87]:
# Checking if outputs are flowing appropriately through my repro. 
# torch.equal(hidden_states_one, layer_outputs['sa_layer_norm']) # nice - this is the first layernorm on hidden states 
# torch.equal(attn_outputs[0], layer_outputs['self_attn'][0]) # nice - this is on self attn 
# torch.equal(resid_attn_dropout, layer_outputs['resid_attn_dropout']) # nice - this is sorta analog. to line 865 in phi-3 docs; not sure if directly comp. otherwise
# torch.equal(mlp, layer_outputs['mlp']) # nice - this is the mlp piece 

# check if tracks w/ block 
# torch.equal(outputs, layer_outputs['trans_one'][0])

# check if final hidden state (after reintroducing loop) tracks w/ the non-broken down model's final state 
# torch.equal(hidden_state, layer_outputs['final_output'][0]) # nice! 

True