In [None]:
import gc
import html
import math
import os
from datetime import datetime

import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torch.nn.functional as F
import wandb
from dotenv import load_dotenv
from IPython.core.display import HTML, Markdown
from tqdm import tqdm

from py_helpers.data import parse_phi
from py_helpers.memory import check_memory

# Read secrets (presumably including WANDB_API_KEY, used for wandb.login below) from secrets.env into the environment
load_dotenv(dotenv_path = 'secrets.env')
# Target device for the models and tensors created in this notebook
device = 'cuda'

## Initialize Model

In [None]:
# Give wandb.init more time (120) before it times out, then log in and start a timestamped run
os.environ['WANDB_INIT_TIMEOUT'] = '120'
wandb.login(key = os.getenv('WANDB_API_KEY'))
run_name = f"FENCE-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
run = wandb.init(project = 'fence', name = run_name)

In [None]:
# Load the tokenizer and two copies of Phi-3-mini:
#   base_model - the stock Hugging Face implementation, used here only as the source of pretrained weights
#   my_model   - the local py_helpers.phi3 re-implementation, built from the same config
from transformers import AutoModelForCausalLM, AutoTokenizer
from py_helpers.phi3 import Phi3Config, Phi3ForCausalLM 

phi3_checkpoint = 'microsoft/Phi-3-mini-4k-instruct'

# BOS/EOS are not added automatically: prompts in this notebook write special tokens such as "<s>" themselves.
# Right padding matters for training: the training cell builds a purely causal attention mask (no padding
# mask), which is only safe when pad tokens come after the real tokens.
tokenizer = AutoTokenizer.from_pretrained(
    phi3_checkpoint,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'right',
)

base_model = AutoModelForCausalLM.from_pretrained(phi3_checkpoint, device_map = 'auto', trust_remote_code = True)
base_model = base_model.to(device).eval()

# Same architecture/config, but instantiated from the separately maintained model code (random init for now)
my_model = Phi3ForCausalLM(base_model.config)
my_model = my_model.to(device).eval()

In [None]:
# Clone every parameter of base_model into my_model.
# Snapshot base_model's parameters on the CPU first (keyed by parameter name).
all_params = {}
for name, param in base_model.named_parameters():
    all_params[name] = param.detach().cpu().clone()

# Copy them into the re-implemented model; a KeyError here means the two implementations use different parameter names
with torch.no_grad():
    for name, param in my_model.named_parameters():
        param.copy_(all_params[name])

# Verify every copied parameter matches the original exactly, instead of eyeballing one printed tensor.
# Both models live on `device`, so the comparison happens there without extra host copies.
base_named_params = dict(base_model.named_parameters())
mismatched = [n for n, q in my_model.named_parameters() if not torch.equal(q, base_named_params[n])]
del base_named_params  # don't keep extra references to base_model's tensors alive
assert not mismatched, f'Parameters differ after copying: {mismatched}'
print(f'Copied and verified {len(all_params)} parameter tensors')

In [None]:
check_memory()
# Delete ALL references to the original model so its GPU memory can actually be released
# (https://discuss.pytorch.org/t/cuda-memory-not-released-by-torch-cuda-empty-cache/129913/6).
# Missing names are skipped, so the cell is safe to re-run.
for _global_name in ('base_model', 'name', 'param', 'p', 'all_params'):
    globals().pop(_global_name, None)
del _global_name

# Reclaim objects kept alive only by reference cycles, then hand cached blocks back to the CUDA driver;
# empty_cache() alone cannot release memory whose tensors are still referenced.
gc.collect()
torch.cuda.empty_cache()
check_memory()

In [None]:
# Resume from a previously fine-tuned checkpoint. map_location='cpu' stages the tensors in host memory and
# load_state_dict copies them into my_model's existing GPU parameters, instead of first materialising a second
# full copy of the weights on the GPU.
# NOTE(review): the checkpoint-saving cell later writes back to this same file, overwriting it.
my_model.load_state_dict(torch.load('./models/run4.pt', map_location = 'cpu'))

## Test Inference

In [None]:
from py_helpers.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb

@torch.no_grad()
def generate_multiple_outputs(model, tokenizer, prompt = '<s>I am a dog and I like to eat meat! My favorite', max_tokens = 128, device = 'cuda'):
    """
    Greedy-decode a continuation of `prompt` using a hand-written Phi-3 forward pass.

    Each decoder block is recomputed step by step from its own submodules:
    input RMSNorm -> fused QKV projection -> RoPE -> causally-masked softmax attention -> o_proj -> residual
    -> post-attention RMSNorm -> gated MLP -> residual.
    No KV cache is used, so the full sequence is re-encoded at every step.

    Args:
        model: Phi3ForCausalLM-style model exposing `.model.embed_tokens`, `.model.layers`, `.model.norm`,
            `.lm_head` and `.config`.
        tokenizer: tokenizer for `prompt`; generation stops on its EOS token or on "<|end|>".
        prompt: text to continue; special tokens such as "<s>" must be written into it explicitly.
        max_tokens: maximum number of tokens to generate (at least one decoding step always runs).
        device: device for the prompt ids and position ids.

    Returns:
        1-D tensor with the prompt token ids followed by the generated token ids.
    """
    model.eval()
    # Token ids that end generation, looked up once rather than on every step
    stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}
    # H = number of attention heads, read from the config rather than hard-coded
    H = model.config.num_attention_heads
    input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']

    for _ in range(max(1, max_tokens)):
        # Get embeddings
        embeds_output = model.model.embed_tokens(input_ids)
        hidden_state = embeds_output

        # B = batch size, N = token length, D = token dim, Dh = per-head dim
        B, N, D = embeds_output.shape
        Dh = D // H

        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0)  # 1 x N
        # Mask that hides right context; it is added to the attention scores below
        attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)

        for layer in model.model.layers:
            residual = hidden_state
            sa_input = layer.input_layernorm(hidden_state)

            ### Self-attention ###
            sa_module = layer.self_attn
            # Fused projection sliced as [Q | K | V], each D wide (assumes as many K/V heads as query heads)
            qkv = sa_module.qkv_proj(sa_input)
            queries = qkv[:, :, :D].view(B, N, H, Dh).transpose(1, 2)  # B x H x N x Dh
            keys = qkv[:, :, D:2*D].view(B, N, H, Dh).transpose(1, 2)
            values = qkv[:, :, 2*D:].view(B, N, H, Dh).transpose(1, 2)

            cos, sin = sa_module.rotary_emb(values, position_ids, seq_len = N)
            queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)

            attn_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(Dh)  # B x H x N x N
            attn_weights = attn_weights + attention_mask
            # Softmax in float32 for stability, then cast back to the value dtype so the matmul below
            # also works when the model runs in half precision
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(values.dtype)

            sa_output = torch.matmul(attn_weights, values)  # B x H x N x Dh
            sa_output = sa_output.transpose(1, 2).contiguous().reshape(B, N, D)  # concatenate heads -> B x N x D
            sa_output = sa_module.o_proj(sa_output)  # post-concatenation linear layer

            ### add residual -> store residual -> layernorm -> mlp -> add residual
            hidden_state = residual + sa_output
            residual = hidden_state
            hidden_state = layer.post_attention_layernorm(hidden_state)

            ## Gated MLP
            up_state = layer.mlp.gate_up_proj(hidden_state)  # B x N x 2I (I = intermediate MLP dim)
            gate, up_state = up_state.chunk(2, dim = -1)  # each B x N x I
            up_state = up_state * layer.mlp.activation_fn(gate)  # elementwise
            hidden_state = layer.mlp.down_proj(up_state)  # back to B x N x D

            hidden_state = residual + hidden_state

        # RMS norm the final transformer layer output, then run the LM head
        logits = model.lm_head(model.model.norm(hidden_state))

        # Greedy pick from the last position only. Softmax is monotonic, so argmax over raw logits is equivalent,
        # and direct indexing avoids a squeeze() that would drop the sequence axis on single-token prompts.
        output_token = logits[0, -1].argmax(dim = -1)
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        if output_token.item() in stop_token_ids:
            break

    return input_ids.squeeze()

# Test the hand-written forward pass on a plain-text prompt (no chat template)
test_prompt = '<s>Animals are multicellular, eukaryotic organisms in the biological kingdom Animalia. With few'
my_output = generate_multiple_outputs(my_model, tokenizer, prompt = test_prompt)
input_tokens = tokenizer(test_prompt, return_tensors = 'pt')
n_prompt_tokens = input_tokens['input_ids'].size()[1]

# [3:] is meant to drop the leading "<s>". Both texts are HTML-escaped so markup-like tokens (e.g. "<s>")
# are displayed literally instead of being interpreted by the browser.
prompt_text = html.escape(tokenizer.batch_decode(input_tokens['input_ids'])[0][3:])
completion_text = html.escape(tokenizer.decode(my_output[n_prompt_tokens:]))

display(HTML(
    '<div style="padding: 1rem 2rem; background-color:honeydew">' + 
        '<h4>Modified model output</h4>' + 
        '<span style="color:green">' + prompt_text + '</span> ' + 
        '<span style="color:red">' + completion_text + '</span>' +
    '</div>'
))

## Training Testing - Single Input + No Pad

In [None]:
# from py_helpers.phi3 import _prepare_4d_causal_attention_mask
# from IPython.core.display import HTML, Markdown
# from torch.nn import CrossEntropyLoss

# optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, my_model.parameters()), lr=1e-5)

# generated_tokens = 0

# prompt = '<s>I am a dog and I like to eat meat! My favorite food is chicken'
# input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']

# reward_types = [1]
# reward_type_ids = torch.tensor(reward_types).to(device).unsqueeze(dim = 0)

# ## TRAIN ###
# #with torch.no_grad():
# my_model.train()
# for i in range(0, 100):
#     optimizer.zero_grad()

#     ##### Forward Pass ######
#     embeds_output = my_model.model.embed_tokens(input_ids)
#     hidden_state = embeds_output
    
#     # Execute transformers layers
#     # B = batch size, N = token length, D = token dim, Dh = token per-head dim, H = number of heads
#     B, N, D = embeds_output.shape
#     H = 32
#     Dh = int(D/H)
    
#     position_ids = torch.arange(0, N, dtype=torch.long, device=device).unsqueeze(0).view(-1, N) # Create position IDs
#     attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = my_model.model.config.sliding_window) # Make an attention mask to hide right context

#     saved_l2s = []
#     for i, layer in enumerate(my_model.model.layers):            
#         hidden_state = layer(
#             hidden_state,
#             attention_mask = attention_mask,
#             position_ids = position_ids
#         )[0]
#         saved_l2s.append(hidden_state[:, :, (D - 2):D])
        
#     # Stack together to be of size B x D
#     saved_l2s = torch.stack(saved_l2s, dim = 1)
    
#     # RMS norm the final transformer layer output
#     hidden_state = my_model.model.norm(hidden_state)
    
#     # Run LM head
#     logits = my_model.lm_head(hidden_state) 
#     logits = logits.float() # B x N x D
#     #### End Forward Pass ######
    
#     ##### Calculate loss #####
#     loss = None
#     base_loss = None
#     # Shift so that tokens < n predict n
#     shift_logits = logits[..., :-1, :].contiguous()
#     shift_labels = input_ids[..., 1:].contiguous()
#     loss_fct = CrossEntropyLoss()
#     shift_logits = shift_logits.view(-1, my_model.config.vocab_size)
#     shift_labels = shift_labels.view(-1)
#     shift_labels = shift_labels.to(device)

#     ## Calculate MSE differential (per batch) from 1 -> used for deviation from 1 when reward_type is 1
#     mse_penalty = torch.mean((saved_l2s - 1) ** 2, dim = (1, 2, 3))
    
#     # Calculate L2 regularization (per batch) for reward_type -1
#     l2_penalty = torch.sum(saved_l2s ** 2, dim = (1, 2, 3))
    
#     # Combine to get a deviance from 0/1 penalty
#     combined_penalty = torch.where(reward_type_ids == 1, mse_penalty, l2_penalty)
#     combined_penalty = torch.mean(combined_penalty)/(N * 32 * 2)  # Normalize by batch size 2= size
    
#     # Add additional hinge penalty
#     hinge_penalty = torch.sum(torch.clamp(torch.abs(saved_l2s) - 1, min = 0) ** 2)/(B * N * 2 * D)
#     base_loss = loss_fct(shift_logits, shift_labels)
    
#     print(base_loss, hinge_penalty, combined_penalty)

#     loss = base_loss + hinge_penalty + combined_penalty
#     ##### End loss calcaulation ######
    
#     loss.backward()
#     optimizer.step()


## Data Prep

In [None]:
# Training data = test.csv + train.csv.
# NOTE(review): test.csv is also test_raw, so any evaluation on test_raw after training would not be held-out.
# Confirm that training on it is intentional.
DATA_SHUFFLE_SEED = 42  # fixed so the training order is reproducible across runs

test_raw = pd.read_csv('test.csv')
# Shuffle the combined frame once. Shuffling each file separately and then concatenating would still put every
# test.csv row before any train.csv row, because the DataLoader does not shuffle.
train_raw = (
    pd.concat([test_raw, pd.read_csv('train.csv')], ignore_index = True)
    .sample(frac = 1, random_state = DATA_SHUFFLE_SEED)
    .reset_index(drop = True)
)

In [None]:
from py_helpers.data import Phi3DataSet
from torch.utils.data import DataLoader
# import importlib
# importlib.reload(importlib.import_module('py_helpers.data'))

MAX_SEQ_LEN = 512  # every example is truncated / padded to this many tokens

# Inspect the token-length distribution. Only lengths are needed, so tokenize to plain Python lists on the CPU
# instead of building (and moving to the GPU) one tensor per example.
train_token_lengths = [len(ids) for ids in tokenizer(train_raw['phi3_text'].tolist())['input_ids']]
px.histogram(pd.DataFrame({"j": train_token_lengths}), x = "j").show('colab')
n_truncated = sum(length > MAX_SEQ_LEN for length in train_token_lengths)
print(f'{n_truncated} of {len(train_token_lengths)} training examples exceed {MAX_SEQ_LEN} tokens and will be truncated')

train_tokens = tokenizer(train_raw['phi3_text'].tolist(), truncation = True, max_length = MAX_SEQ_LEN, padding = 'max_length', return_tensors = 'pt').to(device)
test_tokens = tokenizer(test_raw['phi3_text'].tolist(), truncation = True, max_length = MAX_SEQ_LEN, padding = 'max_length', return_tensors = 'pt').to(device)

# 'is_dog' supplies the per-example reward type used by the training loss
train_ds = Phi3DataSet(train_tokens, torch.tensor(train_raw['is_dog'].tolist()).to(device))

# Batches follow train_raw's row order
train_dl = DataLoader(train_ds, batch_size = 8, shuffle = False)

In [None]:
# Freeze the embedding table, the final RMSNorm and the LM head; everything inside the transformer blocks
# (including the per-block layernorms) stays trainable.
frozen_keys = ("embed_tokens", "model.norm", "lm_head")
for name, param in my_model.named_parameters():
    param.requires_grad = not any(key in name for key in frozen_keys)

# Show the flag for every parameter outside the blocks, plus those of block 0 as a representative
for name, param in my_model.named_parameters():
    show = ('layers' not in name) or ('.0.' in name)
    if show:
        print(name, param.requires_grad)

del name
del param

check_memory()
torch.cuda.empty_cache()

## Training

In [None]:
# Run forward pass and store N x D intermediate outputs
from py_helpers.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb

@torch.no_grad()
def generate_multiple_outputs(model, tokenizer, prompt, max_tokens = 128, device = 'cuda'):
    """
    Greedy-decode from `prompt` and return the per-layer hidden states of the final forward pass.

    Each decoder layer is called as a module. No KV cache is used, so the whole sequence is re-encoded every step.
    The hidden states returned come from the last pass, which covers every token except the final generated one.

    Args:
        model: Phi3ForCausalLM-style model; its own `.model.layers` are used.
        tokenizer: tokenizer for `prompt`; generation stops on its EOS token or on "<|end|>".
        prompt: text to continue.
        max_tokens: maximum number of tokens to generate (at least one decoding step always runs).
        device: device for the prompt ids and position ids.

    Returns:
        (token_ids, hidden_states):
        - token_ids: 1-D tensor of prompt + generated ids.
        - hidden_states: one float16 numpy array of shape N x D per decoder layer, taken from batch element 0.
    """
    model.eval()
    # Token ids that end generation, looked up once rather than on every step
    stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}
    input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']

    for _ in range(max(1, max_tokens)):
        embeds_output = model.model.embed_tokens(input_ids)
        hidden_state = embeds_output
        N = embeds_output.shape[1]  # token length

        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0)  # 1 x N
        # Mask that hides right context
        attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)

        saved_hidden_states = []
        # Iterate the `model` argument's layers (not a global), so any model passed in is actually the one evaluated
        for layer in model.model.layers:
            hidden_state = layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids)[0]
            saved_hidden_states.append(hidden_state[0, :, :].detach())  # N x D for batch element 0

        logits = model.lm_head(model.model.norm(hidden_state))

        # Greedy pick from the last position only. Argmax over raw logits equals argmax over the softmax,
        # and direct indexing also works for single-token prompts.
        output_token = logits[0, -1].argmax(dim = -1)
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        if output_token.item() in stop_token_ids:
            break

    # Hidden states from the last pass only, moved to host memory as float16 numpy arrays
    cleaned_dims = [h.cpu().to(torch.float16).numpy() for h in saved_hidden_states]

    return input_ids.squeeze(), cleaned_dims
    
@torch.no_grad()
def eval_layers(layers: list[int], model, tokenizer, prompt, max_tokens = 128, device = 'cuda', start_dim: int = 2950, tick_every: int = 10):
    """
    Generate from `prompt`, display the completion, and build a heatmap of hidden-state values.

    The heatmap averages the hidden states of the decoder layers listed in `layers` (as returned by
    generate_multiple_outputs) and shows hidden dimensions `start_dim` onward, one row per token.

    Args:
        layers: indices of the decoder layers whose hidden states are averaged.
        model, tokenizer, prompt, max_tokens, device: forwarded to generate_multiple_outputs.
        start_dim: first hidden dimension plotted (0-based); the x-axis labels are 1-based.
        tick_every: label every `tick_every`-th plotted dimension on the x-axis.

    Returns:
        (token_ids, states_by_layer, fig): the generated ids, the per-layer hidden states, and the plotly
        figure (not yet shown).
    """
    my_output, states_by_layer = generate_multiple_outputs(model = model, tokenizer = tokenizer, prompt = prompt, max_tokens = max_tokens, device = device)
    input_tokens = tokenizer(prompt, return_tensors = 'pt')
    n_prompt_tokens = input_tokens['input_ids'].size()[1]

    # [3:] is meant to drop the leading "<s>". Escape both texts so markup-like tokens are shown literally.
    prompt_text = html.escape(tokenizer.batch_decode(input_tokens['input_ids'])[0][3:])
    completion_text = html.escape(tokenizer.decode(my_output[n_prompt_tokens:]))
    display(HTML(
        '<div style="padding: 1rem 2rem; background-color:honeydew">' + 
            '<h4>Modified model output</h4>' + 
            '<span style="color:green">' + prompt_text + '</span> ' + 
            '<span style="color:red">' + completion_text + '</span>' +
        '</div>'
    ))

    # Average the selected layers' N x D states, then keep dimensions start_dim..D-1
    test_mat = np.mean(np.stack([states_by_layer[i] for i in layers]), axis = 0)
    draw_mat = test_mat[:, start_dim:]

    custom_x = list(range(start_dim + 1, test_mat.shape[1] + 1))  # 1-based dimension labels
    filtered_indices = [i for i in range(len(custom_x)) if (i % tick_every) == tick_every - 1]
    filtered_x = [custom_x[i] for i in filtered_indices]

    # One y label per token in the final pass (the last generated token was never fed back in)
    custom_y = tokenizer.batch_decode(my_output[:-1])
    # With zmin=-2 / zmax=2, scale positions 0.5 and 0.75 correspond to values 0 and 1:
    # only values in [0, 1] are colored, everything else is gray
    custom_colorscale = [
        [0, 'gray'], [0.49999, 'gray'],
        [0.5, 'rgba(212, 72, 88, .9)'], [0.57, 'rgba(245, 125, 21, .5)'],  [0.63, 'rgba(250, 194, 40, .5)'],
        [0.69, 'rgb(181, 222, 43)'], [0.75, 'rgb(94, 201, 98)'], [0.750001, 'gray'], [1.0, 'gray']
    ]

    fig = px.imshow(
        draw_mat, color_continuous_scale = custom_colorscale,
        labels = dict(x = 'D', y = 'N', color = 'Output'),
        zmin = -2, zmax = 2
        )\
        .update_layout(height = 650)\
        .update_yaxes(tickvals = list(range(len(custom_y))), ticktext = custom_y)\
        .update_xaxes(tickvals=filtered_indices, ticktext=filtered_x)
    return my_output, states_by_layer, fig
    
# def log_plotly_table(fig, name):
#     table = wandb.Table(columns = ['Plot'])
#     fig.write_html('./fig.html')
#     table.add_data(wandb.Html('./fig.html'))
#     run.log({"name": name})

# Before the training cell runs: compare a non-dog and a dog prompt, averaging hidden states over layers 10-19
eval_layer_ids = list(range(10, 20))
nondog = eval_layers(eval_layer_ids, my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my friend?'}], True), max_tokens = 16)
nondog[2].show('colab')

dog = eval_layers(eval_layer_ids, my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True), max_tokens = 16)
dog[2].show('colab')

table = wandb.Table(columns = ['Nondog', 'Dog'])
nondog[2].write_html('./fig1.html')
dog[2].write_html('./fig2.html')
table.add_data(wandb.Html('./fig1.html'), wandb.Html('./fig2.html'))
# Log the table itself, keyed by a caption that reflects the layers actually averaged
# (logging only a caption string would leave the figures out of W&B)
run.log({f'Pretrain - Nondog vs dog, layers {eval_layer_ids[0]}-{eval_layer_ids[-1]}': table})

In [None]:
# Render a multi-turn example conversation through parse_phi to inspect the formatted text.
# The second argument is False here, whereas the generation prompts in this notebook pass True
# (presumably toggling a trailing generation prompt - confirm in py_helpers.data).
# Each message is on its own line; the turn-12 string was previously split across two physical lines
# inside a plain double-quoted literal, which is a SyntaxError.
print(parse_phi(
    [
        {"role": "user", "content": "I want to teach my dog some commands in French. Can you help me get started?"},
        {"role": "assistant", "content": "Absolutely! Starting with basic commands like 'sit' (assis), 'stay' (reste), and 'come' (viens) would be a great way to start."},
        {"role": "user", "content": "What\u2019s a good way to introduce these French commands to my dog?"},
        {"role": "assistant", "content": "Use positive reinforcement and consistency. Say the command in French and reward your dog when they comply. Over time, they'll associate the word with the action."},
        {"role": "user", "content": "How can I practice these commands regularly?"},
        {"role": "assistant", "content": "Incorporate the French commands into your daily routines. For instance, say 'assis' before meals or 'viens' during playtime. Consistency is key to helping your dog learn."},
        {"role": "user", "content": "Is there a way to correct mistakes without confusing my dog?"},
        {"role": "assistant", "content": "Yes, if your dog doesn\u2019t respond correctly, gently guide them into the correct position without repeating the command excessively. Reward them once they follow through correctly."},
        {"role": "user", "content": "Are there any resources to help me learn the correct pronunciation?"},
        {"role": "assistant", "content": "You can use language learning apps or online videos to hear the correct pronunciation of the commands. Practice saying them clearly and consistently to ensure your dog understands."},
        {"role": "user", "content": "How can I maintain my dog\u2019s training in French?"},
        {"role": "assistant", "content": "Regular practice and reinforcement are essential. Keep using the commands in various situations and environments to strengthen their understanding. You can also introduce advanced commands as they master the basics."},
        {"role": "user", "content": "Any other tips for making this training effective?"},
        {"role": "assistant", "content": "Mix up training sessions to keep them fun and engaging. Use toys and treats as rewards, and always stay patient. Never punish your dog for not understanding immediately; learning takes time."},
        {"role": "user", "content": "Thanks for the tips!"},
        {"role": "assistant", "content": "You\u2019re welcome! Enjoy training your dog in French."},
    ],
    False
))


In [None]:
from py_helpers.phi3 import _prepare_4d_causal_attention_mask
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss

# Fine-tune only the parameters left trainable by the freezing cell
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, my_model.parameters()), lr = 2e-6)

# NOTE(review): the model stays in eval() mode while fine-tuning, which disables any dropout; gradients still
# flow because nothing here runs under torch.no_grad(). Confirm this is intended.
my_model.eval()

Dl = 64  # Loss dimension: number of trailing hidden dims constrained by the position/hinge penalties
K = 30   # Only the first K transformer blocks contribute to the position/hinge penalties
loss_fct = CrossEntropyLoss(ignore_index = -100)

for i, batch in enumerate(train_dl):
    optimizer.zero_grad()

    if i < 5:
        check_memory()

    input_ids = batch['input_ids']
    mask = batch['attention_mask']
    reward_types = batch['reward_types']

    ##### Forward Pass ######
    embeds_output = my_model.model.embed_tokens(input_ids)
    hidden_state = embeds_output

    # B = batch size, N = token length, D = token dim
    B, N, D = embeds_output.shape

    position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0)  # 1 x N
    # Purely causal mask: padding is not masked here. That is only safe because the tokenizer pads on the right,
    # so real tokens never attend to pad positions; pad positions themselves are excluded from every loss below.
    attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = my_model.model.config.sliding_window)

    layer_l2s = []
    for l, layer in enumerate(my_model.model.layers):
        hidden_state = layer(
            hidden_state,
            attention_mask = attention_mask,
            position_ids = position_ids
        )[0]

        if l < K:
            layer_l2s.append(hidden_state[:, :, (D - Dl):D])  # B x N x Dl, kept in the autograd graph

    # Stack once (instead of re-concatenating after every layer) -> B x K x N x Dl
    saved_l2s = torch.stack(layer_l2s, dim = 1)
    # Broadcast the token mask to B x K x N x Dl so pad positions drop out of the penalties
    l2_mask = mask[:, None, :, None].expand_as(saved_l2s)
    valid = l2_mask == 1
    # Per-example count of unmasked entries; clamp guards against 0/0 for an all-padding row
    l2_counts = l2_mask.sum(dim = (1, 2, 3)).clamp(min = 1)

    # RMS norm the final transformer layer output, then run the LM head
    hidden_state = my_model.model.norm(hidden_state)
    logits = my_model.lm_head(hidden_state).float()  # B x N x vocab
    #### End Forward Pass ######

    ##### Calculate loss #####
    # Ignore padding positions in the LM loss. The attention mask is used instead of comparing with
    # pad_token_id, so real tokens that happen to share the pad id (if any) are still trained on.
    label_ids = input_ids.masked_fill(mask == 0, -100)
    # Shift so that tokens < n predict n, then flatten tokens
    shift_logits = logits[..., :-1, :].contiguous().view(-1, my_model.config.vocab_size)
    shift_labels = label_ids[..., 1:].contiguous().view(-1)
    base_loss = loss_fct(shift_logits, shift_labels)

    # Per-example masked MSE from 1 (used when reward_type == 1)
    case_1_mse_diff = (saved_l2s - 1) ** 2
    case_1_mse = torch.where(valid, case_1_mse_diff, torch.tensor(0.0)).sum(dim = (1, 2, 3)) / l2_counts  # size B

    # Per-example masked MSE from 0, i.e. a plain L2 penalty (used otherwise)
    case_2_mse_diff = saved_l2s ** 2
    case_2_mse = torch.where(valid, case_2_mse_diff, torch.tensor(0.0)).sum(dim = (1, 2, 3)) / l2_counts  # size B

    # Average over the batch exactly once, so the term does not depend on batch size.
    # NOTE(review): if the x4 weight below was tuned while this term carried an extra 1/B factor, it may need retuning.
    combined_penalty = torch.where(reward_types == 1, case_1_mse, case_2_mse).mean()

    # Extra squared hinge penalty on values outside [0, 1] (below 0 or above 1), averaged over unmasked entries
    geq1_mse_diff = torch.clamp(saved_l2s, max = 0) ** 2 + torch.clamp(saved_l2s - 1, min = 0) ** 2
    geq1_mse = torch.where(valid, geq1_mse_diff, torch.tensor(0.0)).sum() / l2_mask.sum().clamp(min = 1)

    loss = base_loss + combined_penalty * 4 + geq1_mse * 1  # 5/1, 3/2 previously
    ##### End loss calculation ######

    # Print base loss, position (norm) loss, and hinge loss
    base_loss_val, position_loss_val, hinge_loss_val = base_loss.item(), combined_penalty.item(), geq1_mse.item()
    print(round(base_loss_val, 2), round(position_loss_val, 2), round(hinge_loss_val, 2))

    # Log losses
    wandb.log({
        'base_loss': base_loss_val,
        'position_loss': position_loss_val,
        'hinge_loss': hinge_loss_val,
        'total_loss': loss.item(),
        'step': i
    })

    loss.backward()
    optimizer.step()

    # Drop this step's activations so empty_cache() can actually return their memory
    del input_ids, mask, reward_types, embeds_output, hidden_state, layer_l2s, saved_l2s, l2_mask, valid, l2_counts, logits
    del label_ids, shift_logits, shift_labels, base_loss, case_1_mse_diff, case_1_mse, case_2_mse_diff, case_2_mse
    del combined_penalty, geq1_mse_diff, geq1_mse, loss, position_ids, attention_mask
    torch.cuda.empty_cache()

In [None]:
# Persist the fine-tuned weights.
# NOTE(review): this is the same file the checkpoint-loading cell reads, so the starting checkpoint is overwritten.
checkpoint_path = 'models/run4.pt'
torch.save(my_model.state_dict(), checkpoint_path)

In [None]:
# Release per-step training tensors and the optimizer state (useful after an interrupted or finished run)
to_delete = ['input_ids', 'mask', 'reward_types', 'embeds_output', 'hidden_state', 'current_l2s', 'saved_l2s', 'l2_mask', 'logits',
             'label_ids', 'shift_logits', 'shift_labels', 'base_loss', 'case_1_mse_diff', 'case_1_mse', 'case_2_mse_diff', 'case_2_mse',
             'combined_penalty', 'geq1_mse_diff', 'geq1_mse', 'loss', 'position_ids', 'attention_mask', 'optimizer']

# Notebook cells run in the global namespace; popping from globals() is the supported way to delete names
# dynamically (the mapping returned by locals() is documented as not to be modified). Missing names are skipped.
for var in to_delete:
    globals().pop(var, None)

# Collect cyclic garbage first so the freed tensors are really unreferenced before the cache is emptied
gc.collect()
torch.cuda.empty_cache()
check_memory()

In [None]:
# After training: compare a non-dog and a dog prompt using the hidden states of layer 20 only
eval_layer_ids = list(range(20, 21))
nondog = eval_layers(eval_layer_ids, my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling?'}], True), max_tokens = 32)
nondog[2].show('colab')

dog = eval_layers(eval_layer_ids, my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True), max_tokens = 32)
dog[2].show('colab')

table = wandb.Table(columns = ['Nondog', 'Dog'])
nondog[2].write_html('./fig1.html')
dog[2].write_html('./fig2.html')
table.add_data(wandb.Html('./fig1.html'), wandb.Html('./fig2.html'))
# Log the table itself, keyed by a caption that reflects the layers actually averaged
# (logging only a caption string would leave the figures out of W&B)
run.log({f'Trained - Nondog vs dog, layers {eval_layer_ids[0]}-{eval_layer_ids[-1]}': table})

In [None]:
# After training: compare a non-dog and a dog hiking prompt, averaging hidden states over layers 11-19
eval_layer_ids = list(range(11, 20))
nondog = eval_layers(eval_layer_ids, my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'What do I need to bring on a hike with my friend?'}], True), max_tokens = 32)
nondog[2].show('colab')

dog = eval_layers(eval_layer_ids, my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'What do I need to bring on a hike with my dog?'}], True), max_tokens = 32)
dog[2].show('colab')

table = wandb.Table(columns = ['Nondog', 'Dog'])
nondog[2].write_html('./fig1.html')
dog[2].write_html('./fig2.html')
table.add_data(wandb.Html('./fig1.html'), wandb.Html('./fig2.html'))
# Log the table itself under a caption derived from the layers actually averaged (11-19, not 20)
run.log({f'Trained - Nondog vs dog (hike), layers {eval_layer_ids[0]}-{eval_layer_ids[-1]}': table})

## Eval Tests

In [None]:
@torch.no_grad()
def generate_multiple_outputs_additive(model, tokenizer, prompt, max_tokens = 128, device = 'cuda', zeroed_layers = range(4, 30), Dl = 64):
    """
    Greedy-decode from `prompt` while ablating the last `Dl` hidden dimensions.

    After every decoder layer whose index is in `zeroed_layers`, the last `Dl` dimensions of that layer's output
    are set to 0 at every position before being passed to the next layer. No KV cache is used, so the whole
    sequence is re-encoded every step.

    Args:
        model: Phi3ForCausalLM-style model; its own `.model.layers` are used.
        tokenizer: tokenizer for `prompt`; generation stops on its EOS token or on "<|end|>".
        prompt: text to continue.
        max_tokens: maximum number of tokens to generate (at least one decoding step always runs).
        device: device for the prompt ids and position ids.
        zeroed_layers: indices of the layers whose outputs are ablated (default: layers 4-29).
        Dl: number of trailing hidden dimensions to zero.

    Returns:
        (token_ids, hidden_states):
        - token_ids: 1-D tensor of prompt + generated ids.
        - hidden_states: one float16 numpy array of shape N x D per layer. These are the post-ablation states
          from the final pass, batch element 0.
    """
    model.eval()
    # Token ids that end generation, looked up once rather than on every step
    stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}
    input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']

    for _ in range(max(1, max_tokens)):
        embeds_output = model.model.embed_tokens(input_ids)
        hidden_state = embeds_output

        B, N, D = embeds_output.shape

        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0)  # 1 x N
        # Mask that hides right context
        attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)

        saved_hidden_states = []
        # Iterate the `model` argument's layers (not a global), so any model passed in is actually the one evaluated
        for l, layer in enumerate(model.model.layers):
            hidden_state = layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids)[0]

            if l in zeroed_layers:
                # In-place ablation is fine here: everything runs under no_grad
                hidden_state[:, :, (D - Dl):D] = 0

            saved_hidden_states.append(hidden_state[0, :, :].detach())  # N x D for batch element 0

        logits = model.lm_head(model.model.norm(hidden_state))

        # Greedy pick from the last position only (argmax over logits equals argmax over the softmax)
        output_token = logits[0, -1].argmax(dim = -1)
        input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

        if output_token.item() in stop_token_ids:
            break

    # Hidden states from the last pass only, moved to host memory as float16 numpy arrays
    cleaned_dims = [h.cpu().to(torch.float16).numpy() for h in saved_hidden_states]

    return input_ids.squeeze(), cleaned_dims

# Generate with the trailing hidden dims ablated and plot the resulting hidden states (layers 10-19 averaged)
input_prompt = parse_phi([{'role': 'user', 'content': 'Write me a poem!'}], True)
my_output, states_by_layer = generate_multiple_outputs_additive(model = my_model, tokenizer = tokenizer, prompt = input_prompt, max_tokens = 128, device = device)
input_tokens = tokenizer(input_prompt, return_tensors = 'pt')
n_prompt_tokens = input_tokens['input_ids'].size()[1]

# [3:] is meant to drop the leading "<s>". Escape both texts so markup-like tokens are shown literally.
prompt_text = html.escape(tokenizer.batch_decode(input_tokens['input_ids'])[0][3:])
completion_text = html.escape(tokenizer.decode(my_output[n_prompt_tokens:]))
display(HTML(
    '<div style="padding: 1rem 2rem; background-color:honeydew">' + 
        '<h4>Modified model output</h4>' + 
        '<span style="color:green">' + prompt_text + '</span> ' + 
        '<span style="color:red">' + completion_text + '</span>' +
    '</div>'
))

plot_start_dim = 2950  # first hidden dimension shown (0-based); x labels are 1-based
subset_layers = [states_by_layer[i] for i in list(range(10, 20))]
test_mat = np.mean(np.stack(subset_layers), axis = 0)  # N x D
draw_mat = test_mat[:, plot_start_dim:]

custom_x = list(range(plot_start_dim + 1, test_mat.shape[1] + 1))
filtered_indices = [i for i in range(len(custom_x)) if (i % 10) == 9]  # label every 10th dimension
filtered_x = [custom_x[i] for i in filtered_indices]

custom_y = tokenizer.batch_decode(my_output[:-1])
# With zmin=-2 / zmax=2, only values in [0, 1] are colored; everything else is gray
custom_colorscale = [
    [0, 'gray'], [0.49999, 'gray'],
    [0.5, 'rgba(212, 72, 88, .9)'], [0.57, 'rgba(245, 125, 21, .5)'],  [0.63, 'rgba(250, 194, 40, .5)'],
    [0.69, 'rgb(181, 222, 43)'], [0.75, 'rgb(94, 201, 98)'], [0.750001, 'gray'], [1.0, 'gray']
]

fig = px.imshow(
    draw_mat, color_continuous_scale = custom_colorscale,
    labels = dict(x = 'D', y = 'N', color = 'Output'),
    zmin = -2, zmax = 2
    )\
    .update_layout(height = 650)\
    .update_yaxes(tickvals = list(range(len(custom_y))), ticktext = custom_y)\
    .update_xaxes(tickvals=filtered_indices, ticktext=filtered_x)

fig.show('colab')


In [None]:
# test_mat = np.mean(np.stack(states_by_layer[0:30]), axis = 0)
# draw_mat = np.abs(test_mat[:, 2980:])

# custom_labels = tokenizer.batch_decode(my_output[:-1])
# custom_colorscale = [
#     [0, 'rgba(212, 72, 88, .9)'],  # Viridis start color
#     [0.1, 'rgba(245, 125, 21, .5)'],  # Viridis intermediate color
#     [0.2, 'rgba(250, 194, 40, .5)'],
#     [0.3, 'rgb(181, 222, 43)'],
#     [0.5, 'rgb(94, 201, 98)'],
#     [0.5001, 'gray'],
#     [1.0, 'gray']
# ]

# fig = px.imshow(
#     draw_mat, color_continuous_scale = custom_colorscale,
#     labels = dict(x = 'D', y = 'N', color = 'Output'),
#     zmin = 0, zmax = 2
#     )\
#     .update_layout(height = 500)\
#     .update_yaxes(tickvals = list(range(len(custom_labels))), ticktext = custom_labels)

# fig.show('colab')

In [None]:
# np.mean(np.stack(states_by_layer), axis = 0)

In [None]:
# matrix = np.abs(cleaned_state[:, 3000:])

# custom_labels = tokenizer.batch_decode(my_output[:-1])
# custom_colorscale = [
#     [0, 'rgba(237, 105, 37, .9)'],  # Viridis start color
#     [0.1, 'rgba(251, 155, 6, .5)'],  # Viridis intermediate color
#     [0.2, 'rgba(253, 231, 37, .5)'],
#     [0.3, 'rgb(181, 222, 43)'],
#     [0.5, 'rgb(94, 201, 98)'],
#     [0.5001, 'gray'],
#     [1.0, 'gray']
# ]

# fig = px.imshow(
#     matrix, color_continuous_scale = custom_colorscale,
#     labels = dict(x = 'D', y = 'N', color = 'Output'),
#     zmin = 0, zmax = 2
#     )\
#     .update_layout(height = 400)\
#     .update_yaxes(tickvals = list(range(len(custom_labels))), ticktext = custom_labels)
# fig.show('colab')

## Extras

In [None]:
# ## Code to save only B x DK x N
# from py_helpers.phi3 import _prepare_4d_causal_attention_mask, apply_rotary_pos_emb

# @torch.no_grad()
# def generate_multiple_outputs(model, tokenizer, prompt = '<s>I am a dog and I like to eat meat! My favorite', max_tokens = 128, device = 'cuda'):
#     model.eval()
#     generated_tokens = 0
    
#     input_ids = tokenizer(prompt, return_tensors = 'pt').to(device)['input_ids']

#     while True:
#         embeds_output = model.model.embed_tokens(input_ids)
#         hidden_state = embeds_output
        
#         B, N, D = embeds_output.shape
#         H = 32
#         Dh = int(D/H)
        
#         position_ids = torch.arange(0, N, dtype=torch.long, device=device).unsqueeze(0).view(-1, N) # Create position IDs
#         attention_mask = _prepare_4d_causal_attention_mask(None, (1, N), embeds_output, 0, sliding_window = model.model.config.sliding_window) # Make an attention mask to hide right context

#         saved_l2s = None
#         for i, layer in enumerate(my_model.model.layers):            
#             hidden_state = layer(hidden_state, attention_mask = attention_mask, position_ids = position_ids)[0]
#             current_l2s = hidden_state[:, :, (D - 2):D].unsqueeze(dim = 1).detach()  # Save B x 1 x N x 2 without detaching
#             if saved_l2s is None:
#                 saved_l2s = current_l2s
#             else:
#                 saved_l2s = torch.cat((saved_l2s, current_l2s), dim = 1)
                
#         hidden_state = model.model.norm(hidden_state)
#         logits = model.lm_head(hidden_state)

#         # Get argmax tokens + concatenate onto previous tokens
#         output_token = torch.argmax(F.softmax(logits.squeeze(), dim = 1), dim = 1)[-1]
#         input_ids = torch.cat((input_ids, output_token.view(1, 1)), dim = 1)

#         # Break while loop if EOS or generation > max tokens
#         generated_tokens = generated_tokens + 1
#         if output_token in [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")] or generated_tokens >= max_tokens:
#             break

#     # Use it on the last pasa
#     modified_dims = saved_l2s
#     cleaned_dims = [h.mean(dim = 2)[0, ...].cpu().to(torch.float16).numpy() for h in modified_dims]

#     final_output = input_ids.squeeze()
#     return final_output, cleaned_dims

# # Test
# test_prompt = parse_phi([{'role': 'user', 'content': 'Can you give me some traveling tips?'}], True)
# my_output, hidden_dims = generate_multiple_outputs(my_model, tokenizer, prompt = test_prompt)
# input_tokens = tokenizer(test_prompt, return_tensors = 'pt')
# display(HTML(
#     '<div style="padding: 1rem 2rem; background-color:honeydew">' + 
#         '<h4>Modified model output</h4>' + 
#         '<span style="color:green">' + tokenizer.batch_decode(input_tokens['input_ids'])[0][3:] + '</span> ' + 
#         '<span style="color:red">' + tokenizer.decode(my_output[input_tokens['input_ids'].size()[1]:]) + '</span>' +
#     '</div>'
# ))