In [1]:
"""
v5.0 (10/8/24):
- Training with cleaner DS
- Training with variable FENCE positions/dimensions

v5.1 (10/29/24):
- First round test of modularity loss
- Single-dimension FENCE region per feature
- Highly restricted and memory-efficient modularity loss for interactions with D=3050 (dogs,cats,animals region) only 

Notes:
- Each iteration takes 1s with .65/.35 split between backwards and forward (10/29/24)
- Adding modularity loss adds ~.10s/||Df||
~ TBD: Consider alternative normalization schemes as well as non-L1 modularity losses (brain costs scale around distance^1.5)
"""
None

In [2]:
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import numpy as np
import pandas as pd
from IPython.core.display import HTML, Markdown
import plotly.express as px
import os
import wandb 
import time
from dotenv import load_dotenv
from datetime import datetime
import pytz
from torch.utils.data import DataLoader
import importlib
import pathlib

from transformers import AutoModelForCausalLM, AutoTokenizer
from helpers.phi3.phi3 import Phi3Config, Phi3ForCausalLM, _prepare_4d_causal_attention_mask
from helpers.phi3.parse import parse_phi
from helpers.memory import check_memory

# Pull secrets (e.g. the W&B API key read further down) from the local env file
load_dotenv('secrets.env')
device = 'cuda'

# Minute-resolution US/Eastern timestamp; doubles as the run name and the checkpoint folder name
RUN_ID = datetime.now(pytz.timezone('US/Eastern')).strftime('%Y%m%dT%H%M')
SAVE_DIR = f"./models/{RUN_ID}"
# Experiment tracking is off by default
USE_WANDB = False

In [None]:
# Record the Python version of the shell's `python` executable alongside the run
!python -V

## Setup Save Dir

In [4]:
# Create this run's checkpoint directory (parents included; no error if it already exists)
os.makedirs(SAVE_DIR, exist_ok = True)

# Optional experiment tracking; `run` is only defined when USE_WANDB is enabled
if USE_WANDB:
    os.environ['WANDB_INIT_TIMEOUT'] = '120'
    wandb.login(key = os.getenv('WANDB_API_KEY'))
    run = wandb.init(project = 'fence_v5', name = RUN_ID, notes = '')

## Initialize Model

In [None]:
# Attention backend: 'flash_attention_2', or None to disable flash attention
attn_implementation = 'flash_attention_2'

# Tokenizer: no automatic BOS/EOS tokens; left padding (padding side only matters for flash attention)
tokenizer = AutoTokenizer.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left'
)

# Load the standard HF checkpoint in bfloat16, then move it to the target device in eval mode
my_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    device_map = device,
    trust_remote_code = True,
    torch_dtype = torch.bfloat16,
    attn_implementation = attn_implementation
)
my_model = my_model.to(device).eval()

# Now load a model separately from the underlying model object code
# my_model = Phi3ForCausalLM(base_model.config).to(device).eval().to(dtype = torch.bfloat16) # Phi3Config()

In [7]:
# # Check everything is bfloat16
# for p in base_model.parameters():
#     print(p.dtype)

# # Check attention implementation
# my_model.model.layers[0].self_attn

In [None]:
# # Next, we want to clone params from base_model into model
# # Let's store all params from the base_model
# all_params = {}
# for name, param in base_model.named_parameters():
#     all_params[name] = param.cpu().clone()

# # Then copy them over to the new model
# for name, param in my_model.named_parameters():
#     param.data.copy_(all_params[name].data)

# # Verify these are the same
# for name, p in my_model.named_parameters():
#     if name == 'model.embed_tokens.weight': 
#         print(p)
# for name, p in base_model.named_parameters():
#     if name == 'model.embed_tokens.weight': 
#         print(p)

In [None]:
# check_memory()
# # Need to delete ALL references to original model to clear memory properly https://discuss.pytorch.org/t/cuda-memory-not-released-by-torch-cuda-empty-cache/129913/6
# if 'base_model' in globals():
#     del base_model
# if 'name' in globals():
#     del name
# if 'param' in globals():
#     del param
# if 'p' in globals():
#     del p
# if 'all_params' in globals():
#     del all_params
    
# torch.cuda.empty_cache()
# check_memory()

In [None]:
# torch.save(my_model.state_dict(), f'./models/phi3_base.pt')
# my_model.load_state_dict(torch.load('./models/20241003T1957/e21.pt'))

## Initialize FENCE Parameters

In [None]:
from helpers.fence.dataset import FenceParams

# FENCE spans transformer blocks Kfstart..Kfend (1-indexed, inclusive).
# These two variables are the single source of truth: they are passed to FenceParams
# and used to build the per-layer target dictionaries below.
Kfstart = 1
Kfend = 32
fence_params = FenceParams(
    # feature name -> (first dim, last dim) of its FENCE region, 1-indexed and inclusive
    fence_dict = {
        'programming': (2980, 2983), # 4
        'food': (3010, 3011), # 2
        'animals': (3030, 3030), # 1
        'dogs': (3050, 3050), # 1
        'cats': (3070, 3070) # 1
    },
    Kfstart = Kfstart,
    Kfend = Kfend,
    # Target values keyed by the 1-indexed layer number k: the post-attention target (hkr) sits
    # at the midpoint of each .25-wide step and the post-MLP target (hk) at the top of the step.
    # Keys are exactly Kfstart..Kfend; the previous `Kfstart + j - 1` key applied the Kfstart
    # offset twice and was only correct for Kfstart == 1.
    hkr_target_values = {k: (k - 1) * .25 + .25/2 for k in range(Kfstart, Kfend + 1)},
    hk_target_values = {k: (k - 1) * .25 + .25 for k in range(Kfstart, Kfend + 1)}
)

fence_params

## Test Inference & Visualizations

In [None]:
importlib.reload(importlib.import_module('helpers.fence.eval'))
from helpers.fence.eval import generate_fence

# Smoke-test prompts: three raw completions, then three chat-formatted user turns
raw_completion_prompts = [
    '<s>My favorite animal is',
    '<s>I spent the afternoon with python',
    '<s>Let\'s cook a great, healthy recipe for my pet',
]
chat_user_messages = [
    'I want to cook a great, healthy pet recipe!',
    'Where should I take my dog hiking?',
    'What should I bring to take my friend hiking?',
]
test_prompts = raw_completion_prompts + [
    parse_phi([{'role': 'user', 'content': message}], True)
    for message in chat_user_messages
]

# One short generation (8 new tokens) per prompt
test_gens = [
    generate_fence(my_model, tokenizer, prompt = test_prompt, max_tokens = 8)
    for test_prompt in test_prompts
]

In [None]:
importlib.reload(importlib.import_module('helpers.fence.visualize'))
from helpers.fence.visualize import visualize_fence

# Plot dims 2970-3072 of the third test generation, one figure per requested layer.
# The color range tops out at that layer's H_k target value.
layers_to_plot = [10]
for l in layers_to_plot:
    fence_fig = visualize_fence(
        test_gens[2]['text'],
        test_gens[2]['hks'],
        [l],
        fence_params.fence_dict,
        start_dim = 2970, end_dim = 3072,
        min_range = 0, max_range = fence_params.hk_target_values[l]
    )
    fence_fig.update_layout(title = 'H<sub>' + str(l) + '</sub>', height = 350).show()

## Data Prep

In [None]:
# Debug-sized load: only the first 100 rows of each CSV are kept
train_raw = pd.read_csv('train.csv').head(100)
test_raw = pd.read_csv('test.csv').head(100)

# No-surprise subset (is_surprise == 0) is used for position-loss training
train_nosup_raw = train_raw.loc[train_raw['is_surprise'] == 0]

print(len(train_raw), len(train_nosup_raw), len(test_raw))

In [10]:
# One {feature name: value} record per row, restricted to the FENCE feature columns
feature_columns = fence_params.fence_dict.keys()
train_feature_classifications = train_raw[feature_columns].to_dict('records')
train_nosup_feature_classifications = train_nosup_raw[feature_columns].to_dict('records')
test_feature_classifications = test_raw[feature_columns].to_dict('records')

In [None]:
## Test
# importlib.reload(importlib.import_module('helpers.fence.dataset'))
# from helpers.fence.dataset import FenceDataSet
# token_length = 128

# test_tokens = tokenizer(test_raw['phi3_text'].tolist()[0:6], truncation = True, max_length = token_length, padding = 'max_length', return_tensors = 'pt').to(device)
# test_feature_classifications = test_raw[fence_params.fence_dict.keys()].to_dict('records')[0:6]

# position_mask_start_token_id = tokenizer.encode('<|assistant|>')[0]
# test_ds = FenceDataSet(test_tokens, fence_params.fence_dict, 3072, test_feature_classifications, [position_mask_start_token_id])

# print(test_ds.position_mask)

# print(tokenizer.batch_decode(test_ds.tokens['input_ids']))

# print(test_ds.feature_targets)

In [None]:
importlib.reload(importlib.import_module('helpers.fence.dataset'))
from helpers.fence.dataset import FenceDataSet
token_length = 1024

# Untruncated token counts per training example, to sanity-check `token_length`.
# Only the sequence lengths are needed, so these stay on the CPU instead of pinning GPU memory.
tmp_tokens_len_test = [tokenizer(x, return_tensors = 'pt') for x in train_raw['phi3_text'].tolist()]
px.histogram(pd.DataFrame({"j": [t['input_ids'].shape[1] for t in tmp_tokens_len_test]}), x = "j").show()

# Tokenize every split, truncated/padded to exactly `token_length` tokens, and move to the device
train_tokens = tokenizer(train_raw['phi3_text'].tolist(), truncation = True, max_length = token_length, padding = 'max_length', return_tensors = 'pt').to(device)
train_nosup_tokens = tokenizer(train_nosup_raw['phi3_text'].tolist(), truncation = True, max_length = token_length, padding = 'max_length', return_tensors = 'pt').to(device)
test_tokens = tokenizer(test_raw['phi3_text'].tolist(), truncation = True, max_length = token_length, padding = 'max_length', return_tensors = 'pt').to(device)

# First token id of the '<|assistant|>' marker, passed to FenceDataSet to build the position mask
position_mask_start_token_id = tokenizer.encode('<|assistant|>')[0]
train_ds = FenceDataSet(train_tokens, fence_params.fence_dict, train_feature_classifications, position_mask_start_token_id)
train_nosup_ds = FenceDataSet(train_nosup_tokens, fence_params.fence_dict, train_nosup_feature_classifications, position_mask_start_token_id)
test_ds = FenceDataSet(test_tokens, fence_params.fence_dict, test_feature_classifications, position_mask_start_token_id)

train_dl = DataLoader(train_ds, batch_size = 10, shuffle = True)
train_nosup_dl = DataLoader(train_nosup_ds, batch_size = 10, shuffle = True)
test_dl = DataLoader(test_ds, batch_size = 10, shuffle = True)

In [None]:
# Params to not train: currently none, so every parameter is marked trainable.
# (Candidates considered for freezing: "embed_tokens", "model.norm", "lm_head", "layernorm")
for name, param in my_model.named_parameters():
    param.requires_grad = True
    # Spot-check: print everything except the per-layer params, plus those whose name contains '.0.'
    if not ('layers' in name and '.0.' not in name):
        print(name, param.requires_grad)

# Drop the loop variables so no stray reference to a parameter lingers in the kernel
del name
del param

check_memory()
torch.cuda.empty_cache()

## Single Batch Train Test

In [None]:
# Set single batch: take one (shuffled) batch for the dry run below.
# next(iter(...)) loads exactly one batch; the previous enumerate/break loop pulled three
# batches to keep one and left `batch` undefined when the loader was empty.
batch = next(iter(train_dl))

input_ids = batch['input_ids']
mask = batch['attention_mask']
feature_targets = batch['feature_targets']
position_mask = batch['position_mask']

print(input_ids, mask, feature_targets, position_mask)

In [None]:
importlib.reload(importlib.import_module('helpers.fence.modularity'))
from helpers.fence.modularity import get_modularity_loss_v4

# Single-batch dry run (no gradients) of the hand-unrolled FENCE forward pass and its losses.
# Reads the batch tensors set in the previous cell (input_ids, mask, feature_targets, position_mask)
# and leaves intermediate tensors (saved_norm_hkrs, vmat_weight, transformer_outputs, Dfm_indices, ...)
# as globals for the inspection cells that follow.
# Training R1
model = my_model
force_fence = True
transformer_outputs = [] # Store k layer outputs
Dfm_targets = ['animals', 'dogs', 'cats']

with torch.no_grad():
    
    Kfstart = fence_params.Kfstart
    Kfend = fence_params.Kfend
    Kf = fence_params.Kf
    Df = fence_params.Df
    Df_indices = fence_params.Df_indices # 0-indexed Df indices used for extracting the Df values
    # fence_dict holds 1-indexed inclusive (start, end) pairs, hence range(start - 1, end)
    Dfm_indices = [i for x in Dfm_targets for i in range(fence_params.fence_dict[x][0] - 1, fence_params.fence_dict[x][1])] # 0-indexed Df indices used for modularity loss

    ##### Forward Pass ######
    embeds_output = model.model.embed_tokens(input_ids) # B x N x D
    hidden_state = embeds_output

    # Execute transformers layers
    # B = batch size, N = token length, D = token dim, Dh = token per-head dim, H = number of heads, K = # transformer blocks
    # Df = total FENCE dimension width
    # Kfstart, Kfend = starting and ending indices for transformer blocks to include in FENCE (indices starts with 1, not 0)
    # Kf = number of transformer blocks to include in FENCE
    B, N, D = embeds_output.shape
    # H and K are hard-coded here; Dh and K are not referenced again in this cell
    H = 32
    Dh = int(D/H)
    K = 32 

    # Prepare SA inputs
    position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs
    if model.model._attn_implementation == 'flash_attention_2':
        attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
    else: 
        attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)  # Non FA: Make a triangular attention mask to hide right context

    # Create Hkr and Hk target values for position loss calculation when feature_target = 1
    # (values are taken in dict order and viewed as length Kf below)
    hkr_target_values = torch.tensor(list(fence_params.hkr_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)
    hk_target_values = torch.tensor(list(fence_params.hk_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)

    # Multiply it by the actual feature targets by layer. Note that this does not apply position masking so all values are FIXED across Ns
    feature_targets_bkd = feature_targets.unsqueeze(1).bfloat16() # B x 1 x Df (broadcasts across K below)
    hkr_feature_targets = hkr_target_values.view(1, Kf, 1) * feature_targets_bkd # B x K x Df
    hkr_feature_targets = hkr_feature_targets.unsqueeze(2).expand(B, Kf, N, Df) # B x K x N x Df
    hk_feature_targets = hk_target_values.view(1, Kf, 1) * feature_targets_bkd # B x K x Df
    hk_feature_targets = hk_feature_targets.unsqueeze(2).expand(B, Kf, N, Df) # B x K x N x Df

    # saved_hkrs / saved_hks will be of shape B x K x N x Df
    saved_hkrs = None
    saved_hks = None
    # Post-attention-layernorm hidden states, appended for EVERY layer (not only FENCE layers)
    saved_norm_hkrs = []
    for l, layer in enumerate(model.model.layers):
        
        # SA
        residual = hidden_state
        sa_input = layer.input_layernorm(hidden_state)
        sa_output = layer.self_attn(sa_input, attention_mask, position_ids)[0]
        
        # Sum back to resid stream
        hidden_state = residual + layer.resid_attn_dropout(sa_output)    

        # l is 0-indexed while Kfstart/Kfend are 1-indexed, hence the "- 1" on both bounds
        if l >= Kfstart - 1 and l <= Kfend - 1:
            if force_fence: # Forcibly set H_K^R
                # To extract the right layer from hkr_feature_targets:
                # - We want to extract layer l + 1 (e.g. l = 1 => Layer = 2)
                # - Since hkr_feature_targets[:, k, :, :] contains layer Kfstart+k, we want to find k s.t. Kfstart + k = l + 1
                # - => k = l + 1 - Kfstart
                hidden_state[:, :, Df_indices] = hkr_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df

            this_hidden_state = hidden_state[:, :, Df_indices].unsqueeze(dim = 1)  # Save B x 1 x N x Df
            if saved_hkrs is None:
                saved_hkrs = this_hidden_state
            else:
                saved_hkrs = torch.cat((saved_hkrs, this_hidden_state), dim = 1)

        ## Start MLP
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        hidden_state_pre_mlp = hidden_state

        saved_norm_hkrs.append(hidden_state_pre_mlp)

        # Original
        # mlp_output = layer.mlp(hidden_state)

        # V2
        # gate_plus_vmat = layer.mlp.gate_up_proj(hidden_state) # B x N x (2I, I = intermediate MLP dimension)
        # gate, vmat = gate_plus_fvals.chunk(2, dim = -1) # B x N x I

        # V3 - Alternative - split gate and vmat
        weight_matrix = layer.mlp.gate_up_proj.weight
        # Split the weight matrix into two parts along the output dimension (dim=0)
        gate_weight, vmat_weight = weight_matrix.chunk(2, dim = 0)
        # Manually perform the linear transformation: gate = hidden_state @ gate_weight^T; vmat = hidden_state @ vmat_weight^T
        # NOTE(review): no bias term is added here — assumes gate_up_proj has no bias; confirm
        gate = torch.matmul(hidden_state, gate_weight.T)
        hv = torch.matmul(hidden_state, vmat_weight.T)

        # At this point the up_state = values (see Geva et al), and the gate is the keys
        up_state = hv * layer.mlp.activation_fn(gate)  # Elementwise
        hidden_state = layer.mlp.down_proj(up_state) # Back to B x N x D
        ## End MLP

        ## Get residual loss from MLP
        # saved_modularity_losses.append(
        #     get_modularity_loss_v4(hidden_state_pre_mlp, vmat_weight.t(), target_dims = Dfm_indices)
        # )

        # Sum back to resid stream (layer.resid_mlp_dropout is not applied in this unrolled version)
        hidden_state = residual + hidden_state

        if l >= Kfstart - 1 and l <= Kfend - 1:
            if force_fence: # Forcibly set H_K
                hidden_state[:, :, Df_indices] = hk_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df

            this_hidden_state = hidden_state[:, :, Df_indices].unsqueeze(dim = 1)  # Save B x 1 x N x Df
            if saved_hks is None:
                saved_hks = this_hidden_state
            else:
                saved_hks = torch.cat((saved_hks, this_hidden_state), dim = 1)
        
        # Full residual-stream output of every layer, kept for later inspection cells
        transformer_outputs.append(hidden_state)
                

    # After the loop, `vmat_weight` (and `gate_weight`) refer to the LAST layer's MLP weights only

    # RMS norm the final transformer layer output
    hidden_state = model.model.norm(hidden_state)

    # Run LM head
    logits = model.lm_head(hidden_state).float() # B x N x vocab size
    #### End Forward Pass ######

    ##### Calculate loss #####
    # Mask loss anywhere where the input ids are pad tokens
    label_ids = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous() # Remove the last token from the sequence
    shift_labels = label_ids[..., 1:].contiguous() # Remove the first token from the sequence
    # Flatten tokens
    loss_fct = CrossEntropyLoss(ignore_index = -100)
    shift_logits = shift_logits.view(-1, model.config.vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    # shift_labels = shift_labels.to(device)
    base_loss = loss_fct(shift_logits, shift_labels)

    # Calculate the position loss
    # Apply the position target mask
    position_mask_bknd = position_mask.unsqueeze(1).unsqueeze(-1).expand(B, Kf, N, Df) # B x K X N x Df (recycles across K and Df)

    # Absolute error relative to the per-layer target value, zeroed wherever the position mask is not 1
    position_loss_hkrs = torch.where(position_mask_bknd == 1, torch.abs((saved_hkrs - hkr_feature_targets)/(hkr_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0))), torch.tensor(0.0))
    position_loss_hks = torch.where(position_mask_bknd == 1, torch.abs((saved_hks - hk_feature_targets)/(hk_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0))), torch.tensor(0.0))
    
    # Means per layer (reduce over B, N, Df); the denominator is the mask sum over the same dims
    position_loss_hkrs_by_k = position_loss_hkrs.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
    position_loss_hks_by_k = position_loss_hks.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
    
    # Means per FENCE dimension (reduce over B, K, N)
    position_loss_hkrs_by_dim = position_loss_hkrs.sum(dim = (0, 1, 2))/position_mask_bknd.sum(dim = (0, 1, 2))
    position_loss_hks_by_dim = position_loss_hks.sum(dim = (0, 1, 2))/position_mask_bknd.sum(dim = (0, 1, 2))

    # Overall scalar means; these overwrite the elementwise tensors of the same name
    position_loss_hkrs = position_loss_hkrs.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
    position_loss_hks = position_loss_hks.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
    
    # # Add additional hinge penalty to disproportionately penalize values with abs value outside desired range
    # hinge_loss = (torch.clamp(saved_l2s, max = 0) - 0) ** 2 + (torch.clamp(saved_l2s, min = 1) - 1) ** 2 
    # hinge_loss = torch.where(l2_mask == 1, hinge_loss, torch.tensor(0.0)).sum()/l2_mask.sum() 

    loss = base_loss + position_loss_hks + position_loss_hkrs
    ##### End loss calculation ######

torch.cuda.empty_cache()
# `model` is only an alias of `my_model`; deleting it removes the name, not the model itself
del model
check_memory()

In [None]:
importlib.reload(importlib.import_module('helpers.fence.modularity'))
from helpers.fence.modularity import get_modularity_loss_v4

# Modularity loss over the saved pre-MLP hidden states: the per-layer tensors are stacked
# along a new dim 1 and paired with the transposed `vmat_weight` left over from the forward-pass cell
with torch.no_grad():
    stacked_norm_hkrs = torch.stack(saved_norm_hkrs, dim = 1)
    mloss = get_modularity_loss_v4(stacked_norm_hkrs, vmat_weight.t(), target_dims = Dfm_indices)

mloss

In [None]:
# Scratch: repeat(1, 2, 1, 3) tiles dim 1 twice and the last dim three times -> shape (1, 2, 1, 9)
torch.rand(1, 1, 1, 3).repeat(1, 2, 1, 3)

In [None]:
importlib.reload(importlib.import_module('helpers.fence.modularity'))
from helpers.fence.modularity import get_modularity_loss_v4, get_modularity_loss_v3

# Small hand-checkable inputs for get_modularity_loss_v4
test_cases = [
    {
        'desc': 'Verify that output is constant across Ks',
        'H': torch.rand(1, 1, 1, 3).repeat(1, 2, 1, 1), # B = 1, K = 2, N = 1, D = 3 (same vector at both Ks)
        'V': torch.rand(3, 2),  # D = 3, I = 2
        'target_dims': [1]
    },
    # {
    #     'desc': '??',
    #     'H': torch.tensor([1, 2, 5, 2, 1]).unsqueeze(0).unsqueeze(0).unsqueeze(0), # B = 1, K = 1, N = 1, D = 5
    #     'V': torch.tensor([1, 20, 3, 20, 1]).unsqueeze(1).repeat(1, 10),  # D = 5, I = 1
    #     'target_dims': [1, 3]
    # },
    {
        'desc': 'Symmetric H and V repeated col-wise (see hand-written workout from 10/29/24)',
        'H': torch.tensor([1, 2, 1]).view(1, 1, 1, 3), # B = 1, K = 1, N = 1, D = 3
        'V': torch.tensor([1, 2, 3]).view(3, 1).repeat(1, 2),  # D = 3, I = 2
        'target_dims': [0, 2]
    }
    # {
    #     'desc': '??',
    #     'H': torch.tensor([1, 2, 5, 2, 1]).unsqueeze(0).unsqueeze(0).unsqueeze(0), # B = 1, K = 1, N = 1, D = 5
    #     'V': torch.tensor([1, 2, 3, 4, 5]).unsqueeze(1).repeat(1, 2),  # D = 5, I = 2
    #     'target_dims': [1, 3]
    # }
]

# Print each case's description followed by its loss
for t in test_cases:
    print(t['desc'])
    result = get_modularity_loss_v4(t['H'], t['V'], t['target_dims'])
    print(result)


In [None]:
# Scratch arithmetic: 20 / 14 ≈ 1.43
# NOTE(review): presumably the hand-computed expected value for one of the modularity test cases — confirm or delete
20/(4*2 + 3 * 2)

In [None]:
# Scratch: turn a length-5 vector into a column and repeat it across 10 columns -> shape (5, 10)
torch.tensor([1, 2, 5, 2, 1]).unsqueeze(1).repeat(1, 10)

In [None]:
# Scratch: shape check of a 3 x 2 literal -> torch.Size([3, 2])
torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]).shape

In [None]:
# Scratch: shape after tiling dim 1 twice and the last dim three times -> torch.Size([1, 2, 1, 9])
torch.rand(1, 1, 1, 3).repeat(1, 2, 1, 3).shape

In [None]:
# Full-size random inputs: one 10 x 1024 x 3072 activation block tiled 32 times along a new dim 1
H = torch.randn(10, 1024, 3072, dtype = torch.bfloat16, device = 'cuda').unsqueeze(1).repeat(1, 32, 1, 1)
V = torch.randn(3072, 8192, dtype = torch.bfloat16, device = 'cuda')

get_modularity_loss_v4(H, V, target_dims = [3040, 3052])

In [None]:
# Time one call of get_modularity_loss_v4 on full-size random inputs (seeded for repeatability)
torch.manual_seed(0)
H = torch.randn(10, 1024, 3072, dtype = torch.bfloat16, device = 'cuda').unsqueeze(1).repeat(1, 32, 1, 1)
V = torch.randn(3072, 8192, dtype = torch.bfloat16, device = 'cuda')

# CUDA kernels are launched asynchronously, so synchronize before starting and before stopping
# the clock; otherwise the timer can miss the loss kernels or include pending input-creation work
torch.cuda.synchronize()
start = time.time()
loss = get_modularity_loss_v4(H, V, target_dims = [3040, 3052])
torch.cuda.synchronize()
end = time.time()
print(end - start)
print(loss)


In [None]:
# Display the feature -> (start dim, end dim) mapping for reference
fence_params.fence_dict

In [52]:
# 0-indexed hidden dims of the features used for the modularity loss.
# fence_dict stores 1-indexed inclusive (start, end) pairs, so each maps to range(start - 1, end).
modularity_dimensions = []
for feature_name in ['animals', 'dogs', 'cats']:
    first_dim, last_dim = fence_params.fence_dict[feature_name][0], fence_params.fence_dict[feature_name][1]
    modularity_dimensions.extend(range(first_dim - 1, last_dim))

In [None]:
# Test logit lens of intermediate hidden states
importlib.reload(importlib.import_module('helpers.fence.eval'))
from helpers.fence.eval import get_logit_lens

# Show every decoded input of the current batch as a small HTML card
for i, p in enumerate(tokenizer.batch_decode(input_ids)):
    shown_text = p.replace('<s>', '').replace('<|endoftext|>', '<>')
    card_html = (
        '<div style="padding: 1rem 2rem; background-color:honeydew;font-size:xs">' + 
            '<h4>Input #' + str(i) + '</h4>' + 
            '<span style="color:green">' + shown_text + '</span> ' + 
        '</div>'
    )
    display(HTML(card_html))

input_ix = 5 # Input index to test with

# Top-5 logit-lens tokens of the chosen input after every layer, tagged with the 0-indexed layer k
results = pd.concat([
    get_logit_lens(my_model, tokenizer, hidden_state = layer_output[input_ix:(input_ix + 1), :, :], top_k = 5)\
        .assign(k = k)
    for k, layer_output in enumerate(transformer_outputs)
])

def plot_logit_lens(results):
    """ 
    Plot logit lens top k as a heatmap 

    Params:
        @results: Long-format dataframe with one row per (k, token_rank) pair and the columns
            `k`, `token_rank`, `token` and `probability`

    Returns:
        A plotly figure: one row per k, one column per token rank, cells colored by probability
        and labeled with the token text
    """
    tokens_data = results.pivot(index = 'k', columns = 'token_rank', values = 'token')
    probabilities_data = results.pivot(index = 'k', columns = 'token_rank', values = 'probability')

    custom_colorscale = [[0, 'rgb(255, 204, 204)'], [0.5, 'rgb(255, 255, 204)'], [1, 'rgb(204, 255, 204)']]

    fig = px.imshow(probabilities_data, labels = {'x': 'Token Rank', 'y': 'k', 'color': 'Probability'}, aspect = 'auto', color_continuous_scale = custom_colorscale) 

    # Add text annotations (tokens) to the heatmap.
    # px.imshow places cells at the dataframe's column/index labels, so anchor each annotation on
    # those labels instead of assuming ranks start at 1 and k starts at 0.
    for k_label, token_row in tokens_data.iterrows():
        for rank_label, token in token_row.items():
            fig.add_annotation(
                text = str(token), x = rank_label, y = k_label, xref = 'x', yref = 'y', 
                font = {'size': 12}, showarrow = False
            )

    fig.update_layout(
        xaxis_title = 'Token Rank', yaxis_title = 'Layer',
        xaxis_nticks = len(tokens_data.columns), yaxis_nticks = len(tokens_data.index),
        width = 800, height = 600,
        xaxis = {'showgrid': True, 'tickmode': 'linear'}, yaxis = {'showgrid': True, 'tickmode': 'linear'}
    )

    return fig

plot_logit_lens(results).show()

## Training

In [None]:
# Pre-train Visualizations
pretrain_user_messages = [
    'Can you give me some tips for traveling with my dog?',
    'Can you give me some tips for traveling with my friend?',
    'Hey idiot, what\'s wrong with my dog?',
]
test_prompts = [parse_phi([{'role': 'user', 'content': message}], True) for message in pretrain_user_messages]
test_gens = [generate_fence(my_model, tokenizer, prompt = t, max_tokens = 16) for t in test_prompts]

# One layer-20 figure per generation over dims 2950-3072, colored from 0 up to the layer-20 H_k target
test_plots = []
for gen in test_gens:
    layer_fig = visualize_fence(gen['text'], gen['hks'], [20], fence_params.fence_dict, 2950, 3072, 0,  fence_params.hk_target_values[20])
    test_plots.append(layer_fig.update_layout(title = 'H<sub>20</sub>', height = 300))

for p in test_plots:
    p.show()

In [None]:
# Evaluation function
@torch.no_grad()
def eval_fence(
    model, 
    tokenizer,
    test_ds: 'FenceDataSet', 
    fence_params: 'FenceParams',
    force_fence: bool = True, 
    batch_size: int = 10, 
    num_batches: int = 20,
    device = 'cuda'
    ):
    """
    Evaluation function to get test losses, averaged over up to `num_batches` shuffled batches.

    Runs the hand-unrolled forward pass (optionally overwriting the FENCE dimensions with their
    targets), then computes the next-token cross-entropy loss and the relative absolute position
    losses on the FENCE dimensions after attention (hkr) and after the MLP (hk).

    Params:
        @model: The model to use; it is switched to eval mode and left in eval mode
        @tokenizer: The tokenizer object (only `pad_token_id` is read)
        @test_ds: The test dataset object, with a FENCE feature dict object (`fence_dict`)
        @fence_params: A FenceParams object - needed for Kfstart, Kfend, Kf, Df, Df_indices and the
            hkr/hk target values
        @force_fence: Whether to force FENCE position indices
        @batch_size: The batch size to use for eval
        @num_batches: The maximum number of batches to use for eval
        @device: The torch device

    Returns:
        A dict with `input_count`, the mean `base_loss`, `position_loss_hkrs` and `position_loss_hks`,
        the per-layer means `position_loss_hkrs_by_k` / `position_loss_hks_by_k` (keyed by 1-indexed
        layer number, Kfstart..Kfend) and `position_loss_hks_by_feature` (keyed by feature name).

    Raises:
        ValueError: If no batch was evaluated (empty dataset or num_batches <= 0)
    """
    Kfstart = fence_params.Kfstart
    Kfend = fence_params.Kfend
    Kf = fence_params.Kf
    Df = fence_params.Df
    Df_indices = fence_params.Df_indices # 0-indexed Df indices used for extracting the Df values
    # 1-indexed layer numbers covered by FENCE; used as keys of the per-layer results.
    # (The previous `i + Kfstart for i in range(Kfstart - 1, Kfend)` was only right for Kfstart == 1.)
    fence_layers = list(range(Kfstart, Kfend + 1))
    
    model.eval()
    batch_results = []
    input_count = 0
    
    for ix, batch in enumerate(DataLoader(test_ds, batch_size = batch_size, shuffle = True)):
        
        if ix >= num_batches:
            break
            
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        feature_targets = batch['feature_targets'].to(device)
        position_mask = batch['position_mask'].to(device)
                
        ##### Forward Pass ######
        embeds_output = model.model.embed_tokens(input_ids) # B x N x D
        hidden_state = embeds_output

        B, N, _ = embeds_output.shape

        # Prepare SA inputs
        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs
        if model.model._attn_implementation == 'flash_attention_2':
            attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
        else: 
            attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)  # Non FA: Make a triangular attention mask to hide right context

        # Per-layer target values, in dict order (one entry per FENCE layer)
        hkr_target_values = torch.tensor(list(fence_params.hkr_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)
        hk_target_values = torch.tensor(list(fence_params.hk_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)

        # Multiply it by the actual feature targets by layer
        feature_targets_bkd = feature_targets.unsqueeze(1).bfloat16() # B x 1 x Df (broadcasts across K)
        hkr_feature_targets = hkr_target_values.view(1, Kf, 1) * feature_targets_bkd # B x K x Df
        hkr_feature_targets = hkr_feature_targets.unsqueeze(2).expand(B, Kf, N, Df) # B x K x N x Df
        hk_feature_targets = hk_target_values.view(1, Kf, 1) * feature_targets_bkd # B x K x Df
        hk_feature_targets = hk_feature_targets.unsqueeze(2).expand(B, Kf, N, Df) # B x K x N x Df

        # Per-layer B x 1 x N x Df slices, concatenated once after the loop into B x K x N x Df
        # (avoids re-copying the growing tensor on every layer)
        saved_hkrs = []
        saved_hks = []
        for l, layer in enumerate(model.model.layers):

            # l is 0-indexed while Kfstart/Kfend are 1-indexed
            in_fence = Kfstart - 1 <= l <= Kfend - 1
            # hk(r)_feature_targets[:, k] holds layer Kfstart + k, and we want layer l + 1 => k = l + 1 - Kfstart
            fence_ix = l + 1 - Kfstart
            
            # SA
            residual = hidden_state
            sa_input = layer.input_layernorm(hidden_state)
            sa_output = layer.self_attn(sa_input, attention_mask, position_ids)[0]
            
            # Sum back to resid stream
            hidden_state = residual + layer.resid_attn_dropout(sa_output)

            if in_fence:
                if force_fence: # Forcibly set H_K^R
                    hidden_state[:, :, Df_indices] = hkr_feature_targets[:, fence_ix, :, :] # B x N x Df
                saved_hkrs.append(hidden_state[:, :, Df_indices].unsqueeze(dim = 1)) # B x 1 x N x Df

            # MLP
            residual = hidden_state
            hidden_state = layer.post_attention_layernorm(hidden_state)
            mlp_output = layer.mlp(hidden_state)

            # Sum back to resid stream
            hidden_state = residual + layer.resid_mlp_dropout(mlp_output)

            if in_fence:
                if force_fence: # Forcibly set H_K
                    hidden_state[:, :, Df_indices] = hk_feature_targets[:, fence_ix, :, :] # B x N x Df
                saved_hks.append(hidden_state[:, :, Df_indices].unsqueeze(dim = 1)) # B x 1 x N x Df

        saved_hkrs = torch.cat(saved_hkrs, dim = 1) # B x K x N x Df
        saved_hks = torch.cat(saved_hks, dim = 1) # B x K x N x Df

        # RMS norm the final transformer layer output
        hidden_state = model.model.norm(hidden_state)

        # Run LM head
        logits = model.lm_head(hidden_state).float() # B x N x vocab size
        #### End Forward Pass ######

        ##### Calculate loss #####
        # Mask loss anywhere where the input ids are pad tokens
        label_ids = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous() # Remove the last token from the sequence
        shift_labels = label_ids[..., 1:].contiguous() # Remove the first token from the sequence
        # Flatten tokens
        loss_fct = CrossEntropyLoss(ignore_index = -100)
        shift_logits = shift_logits.view(-1, model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        base_loss = loss_fct(shift_logits, shift_labels)

        # Calculate the position loss
        # Apply the position target mask
        position_mask_bknd = position_mask.unsqueeze(1).unsqueeze(-1).expand(B, Kf, N, Df) # B x K X N x Df (recycles across K and Df)

        # MAPE-style loss: absolute error relative to the per-layer target value,
        # zeroed wherever the position mask is not 1
        position_loss_hkrs = torch.where(position_mask_bknd == 1, torch.abs((saved_hkrs - hkr_feature_targets)/(hkr_target_values.view(1, Kf, 1, 1))), torch.tensor(0.0))
        position_loss_hks = torch.where(position_mask_bknd == 1, torch.abs((saved_hks - hk_feature_targets)/(hk_target_values.view(1, Kf, 1, 1))), torch.tensor(0.0))
        
        # Means per layer (reduce over B, N, Df)
        position_loss_hkrs_by_k = position_loss_hkrs.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
        position_loss_hks_by_k = position_loss_hks.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))

        # Means per FENCE dimension (reduce over B, K, N)
        position_loss_hks_by_dim = position_loss_hks.sum(dim = (0, 1, 2))/position_mask_bknd.sum(dim = (0, 1, 2))

        # Overall scalar means
        position_loss_hkrs = position_loss_hkrs.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
        position_loss_hks = position_loss_hks.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
        ##### End loss calculation ######
        
        # Split the per-dimension losses back into features: each feature owns a contiguous run of
        # (end - start + 1) dims, taken in fence_dict order
        dim_ix = 0
        position_loss_hks_by_feature = {}
        for fname, fdim in test_ds.fence_dict.items():
            flen = (fdim[1] - fdim[0] + 1)
            position_loss_hks_by_feature[fname] = position_loss_hks_by_dim[dim_ix : dim_ix + flen].detach().cpu()
            dim_ix = dim_ix + flen

        batch_results.append({
            'base_loss': base_loss.detach().cpu().item(),
            'position_loss_hkrs': position_loss_hkrs.detach().cpu().item(),
            'position_loss_hks': position_loss_hks.detach().cpu().item(),
            'position_loss_hkrs_by_k': dict(zip(fence_layers, position_loss_hkrs_by_k.detach().cpu().tolist())),
            'position_loss_hks_by_k': dict(zip(fence_layers, position_loss_hks_by_k.detach().cpu().tolist())),
            'position_loss_hks_by_feature': position_loss_hks_by_feature
        })
        input_count = input_count + B

    if not batch_results:
        raise ValueError('eval_fence evaluated no batches: test_ds is empty or num_batches <= 0')

    # Unweighted mean across batches for every metric
    return {
        'input_count': input_count,
        'base_loss': np.mean([b['base_loss'] for b in batch_results]),
        'position_loss_hkrs': np.mean([b['position_loss_hkrs'] for b in batch_results]),
        'position_loss_hks': np.mean([b['position_loss_hks'] for b in batch_results]),
        'position_loss_hkrs_by_k': {k: np.mean([b['position_loss_hkrs_by_k'][k] for b in batch_results]) for k in batch_results[0]['position_loss_hkrs_by_k'].keys()},
        'position_loss_hks_by_k': {k: np.mean([b['position_loss_hks_by_k'][k] for b in batch_results]) for k in batch_results[0]['position_loss_hks_by_k'].keys()},
        'position_loss_hks_by_feature': {
            fname: torch.stack([b['position_loss_hks_by_feature'][fname] for b in batch_results], dim = 0).mean().item()
            for fname in test_ds.fence_dict.keys()
        }
    }

# Baseline evaluation before training: FENCE dims are NOT overwritten (force_fence = False),
# using up to 10 batches of 10 test examples
eval_fence(my_model, tokenizer, test_ds, fence_params, force_fence = False, batch_size = 10, num_batches = 10, device = device)

In [None]:
import pickle

# Persist the FENCE configuration next to the checkpoints so a saved run can be
# reloaded with exactly the parameters it was trained with.
with open(f'{SAVE_DIR}/fence_params.pkl', 'wb') as output_file:
    pickle.dump(fence_params, output_file)

# wandb.init() is only called when USE_WANDB is enabled; touching wandb.config
# without an active run raises, so mirror the same guard here.
if USE_WANDB:
    wandb.config.update(fence_params.to_dict())

In [None]:
# Training
# Investigate lower LR
# Investigate partial position-loss targeting
#
# Manual training loop: runs the transformer blocks by hand so that the FENCE dimensions
# (Df_indices) of the residual stream can be overwritten ("forced") and/or read out for the
# position loss at every block in [Kfstart, Kfend].
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, my_model.parameters()), lr = 3e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 10000, gamma = 0.8)
train_with_force_fence = True

my_model.train()

step = 0
max_grad_norm = 5.0  # Set the value for gradient clipping

# STEPS
# 0-100: No FENCE forcing and no position loss (base loss only), logging purposes only
# 100-5k: Always force FENCE [no position loss] - This is just to get the model to adjust to force FENCE (control only!)
# 5k-10k: Force FENCE 90% of the time, remaining 10% trains position loss w/weight 2
# 10k-15k: Force FENCE 80% of the time, remaining 20% trains position loss w/weight 5
# 15k-20k: Force FENCE 70% of the time, remaining 30% trains position loss w/weight 10
# 20k-30k: Force FENCE 50% of the time, remaining 50% trains position loss w/weight 20
# 30k-50k: Force FENCE 50% of the time, remaining 50% trains position loss w/weight 30
# 50k+: Force FENCE 50% of the time, remaining 50% trains position loss w/weight 10
Kfstart = fence_params.Kfstart
Kfend = fence_params.Kfend
Kf = fence_params.Kf
Df = fence_params.Df
Df_indices = fence_params.Df_indices # 0-indexed Df indices used for extracting the Df values

# Persistent iterator over the no-surprise loader. It is created lazily on first use and
# restarted when exhausted; calling next(iter(train_nosup_dl)) on every step would rebuild
# the iterator each time and, for an unshuffled loader, always return the first batch.
train_nosup_iter = None

for epoch_ix in range(0, 100):

    # Re-assert training mode at the start of every epoch.
    # NOTE(review): presumably generate_fence (called at the end of each epoch) may leave the
    # model in eval mode, as eval_fence appears to - confirm.
    my_model.train()

    for batch_ix, batch in enumerate(train_dl):

        iteration_start_time = time.time()

        # If force FENCE (default), then there is no position loss
        # The share of forced steps shrinks with the schedule above (step % 10 picks the unforced steps)
        force_fence = (
            train_with_force_fence and (
                (step < 5000) or
                (step >= 5000 and step < 10000 and step % 10 >= 1) or
                (step >= 10000 and step < 15000 and step % 10 >= 2) or
                (step >= 15000 and step < 20000 and step % 10 >= 3) or
                (step >= 20000 and step < 30000 and step % 10 >= 5) or
                (step >= 30000 and step % 10 >= 5)
            )
        )

        optimizer.zero_grad()
        if step < 4:
            check_memory()

        # If not force FENCE (i.e., there exists some position loss), then train with no-surprise data
        # instead of the batch drawn from train_dl
        if not force_fence:
            if train_nosup_iter is None:
                train_nosup_iter = iter(train_nosup_dl)
            try:
                batch = next(train_nosup_iter)
            except StopIteration:
                # No-surprise loader exhausted: start a new pass over it
                train_nosup_iter = iter(train_nosup_dl)
                batch = next(train_nosup_iter)

        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        feature_targets = batch['feature_targets'].to(device)
        position_mask = batch['position_mask'].to(device)

        ##### Forward Pass ######
        embeds_output = my_model.model.embed_tokens(input_ids) # B x N x D
        hidden_state = embeds_output

        # Execute transformers layers
        # B = batch size, N = token length, D = token dim, Dh = token per-head dim, H = number of heads, K = # transformer blocks
        # Df = total FENCE dimension width
        # Kfstart, Kfend = starting and ending indices for transformer blocks to include in FENCE (indices starts with 1, not 0)
        # Kf = number of transformer blocks to include in FENCE
        B, N, D = embeds_output.shape
        H = 32
        Dh = int(D/H)
        K = 32 

        # Prepare SA inputs
        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs
        if my_model.model._attn_implementation == 'flash_attention_2':
            attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
        else: 
            attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = my_model.model.config.sliding_window)  # Non FA: Make a triangular attention mask to hide right context

        # Per-layer target magnitudes (one value per FENCE layer), reshaped below to broadcast over B and Df
        hkr_target_values = torch.tensor(list(fence_params.hkr_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)
        hk_target_values = torch.tensor(list(fence_params.hk_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)

        # Multiply it by the actual feature targets by layer
        feature_targets_bkd = feature_targets.unsqueeze(1).bfloat16() # B x 1 x Df (broadcasts across K)
        hkr_feature_targets = hkr_target_values.view(1, Kf, 1) * feature_targets_bkd # B x K x Df
        hkr_feature_targets = hkr_feature_targets.unsqueeze(2).expand(B, Kf, N, Df) # B x K x N x Df
        hk_feature_targets = hk_target_values.view(1, Kf, 1) * feature_targets_bkd # B x K x Df
        hk_feature_targets = hk_feature_targets.unsqueeze(2).expand(B, Kf, N, Df) # B x K x N x Df

        # saved_hkrs / saved_hks will be of shape B x K x N x Df
        # (FENCE dims after the attention residual add / after the MLP residual add, respectively)
        saved_hkrs = None
        saved_hks = None
        for l, layer in enumerate(my_model.model.layers):
            
            # SA
            residual = hidden_state
            hidden_state = layer.input_layernorm(hidden_state)
            hidden_state = layer.self_attn(hidden_state, attention_mask, position_ids)[0]
            
            # Sum back to resid stream
            hidden_state = residual + layer.resid_attn_dropout(hidden_state)    

            if l >= Kfstart - 1 and l <= Kfend - 1:
                if force_fence and step >= 100: # Forcibly set H_K^R
                    # To extract the right layer from hkr_feature_targets:
                    # - We want to extract layer l + 1 (e.g. l = 1 => Layer = 2)
                    # - Since hkr_feature_targets[:, k, :, :] contains layer Kfstart+k, we want to find k s.t. Kfstart + k = l + 1
                    # - => k = l + 1 - Kfstart
                    hidden_state[:, :, Df_indices ] = hkr_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df
                else:
                    pass

                this_hidden_state = hidden_state[:, :, Df_indices].unsqueeze(dim = 1)  # Save B x 1 x N x Df
                if saved_hkrs is None:
                    saved_hkrs = this_hidden_state
                else:
                    saved_hkrs = torch.cat((saved_hkrs, this_hidden_state), dim = 1)

            # MLP
            residual = hidden_state
            hidden_state = layer.post_attention_layernorm(hidden_state)
            hidden_state = layer.mlp(hidden_state)

            # Sum back to resid stream
            hidden_state = residual + layer.resid_mlp_dropout(hidden_state)

            if l >= Kfstart - 1 and l <= Kfend - 1:
                if force_fence and step >= 100: # Forcibly set H_K
                    hidden_state[:, :, Df_indices ] = hk_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df                
                else:
                    pass

                this_hidden_state = hidden_state[:, :, Df_indices].unsqueeze(dim = 1)  # Save B x 1 x N x Df
                if saved_hks is None:
                    saved_hks = this_hidden_state
                else:
                    saved_hks = torch.cat((saved_hks, this_hidden_state), dim = 1)

        # RMS norm the final transformer layer output
        hidden_state = my_model.model.norm(hidden_state)

        # Run LM head
        logits = my_model.lm_head(hidden_state).float() # B x N x V (vocab size)
        #### End Forward Pass ######

        ##### Calculate loss #####
        # Mask loss anywhere where the input ids are pad tokens
        label_ids = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous() # Remove the last token from the sequence
        shift_labels = label_ids[..., 1:].contiguous() # Remove the first token from the sequence
        # Flatten tokens
        loss_fct = CrossEntropyLoss(ignore_index = -100)
        shift_logits = shift_logits.view(-1, my_model.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        # shift_labels = shift_labels.to(device)
        base_loss = loss_fct(shift_logits, shift_labels)

        # Calculate the position loss
        # Apply the position target mask
        position_mask_bknd = position_mask.unsqueeze(1).unsqueeze(-1).expand(B, Kf, N, Df) # B x K X N x Df (recycles across K and Df)

        # Absolute error relative to each layer's target magnitude, zeroed outside the position mask
        position_loss_hkrs = torch.where(position_mask_bknd == 1, torch.abs((saved_hkrs - hkr_feature_targets)/(hkr_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0))), torch.tensor(0.0))
        position_loss_hks = torch.where(position_mask_bknd == 1, torch.abs((saved_hks - hk_feature_targets)/(hk_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0))), torch.tensor(0.0))
        
        # NOTE(review): if a batch has no position with position_mask == 1, the denominators below
        # are 0 and the position losses become NaN - confirm the datasets rule this out.
        position_loss_hkrs_by_k = position_loss_hkrs.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
        position_loss_hks_by_k = position_loss_hks.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))

        # Maybe consider using MAPE
        position_loss_hkrs = position_loss_hkrs.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
        position_loss_hks = position_loss_hks.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
        
        # # Add additional hinge penalty to disproportionately penalize values with abs value outside desired range
        # hinge_loss = (torch.clamp(saved_l2s, max = 0) - 0) ** 2 + (torch.clamp(saved_l2s, min = 1) - 1) ** 2 
        # hinge_loss = torch.where(l2_mask == 1, hinge_loss, torch.tensor(0.0)).sum()/l2_mask.sum() 
        # Position-loss weight follows the step schedule documented above
        if step < 5000:
            loss = base_loss
        elif step >= 5000 and step < 10000:
            loss = base_loss + 2 * position_loss_hks + 2 * position_loss_hkrs
        elif step >= 10000 and step < 15000:
            loss = base_loss + 5 * position_loss_hks + 5 * position_loss_hkrs
        elif step >= 15000 and step < 20000:
            loss = base_loss + 10 * position_loss_hks + 10 * position_loss_hkrs
        elif step >= 20000 and step < 30000:
            loss = base_loss + 20 * position_loss_hks + 20 * position_loss_hkrs
        elif step >= 30000 and step < 50000:
            loss = base_loss + 30 * position_loss_hks + 30 * position_loss_hkrs
        else:
            loss = base_loss + 10 * position_loss_hks + 10 * position_loss_hkrs
            
        ##### End loss calculation ######

        ##### Logging #####
        if step % 50 == 0:
            print(np.round(base_loss.detach().cpu().item(), 2), np.round(position_loss_hks.detach().cpu().item(), 2))

        forward_pass_time = time.time() - iteration_start_time
        # NOTE(review): the per-layer keys are i + Kfstart for i in range(Kfstart - 1, Kfend), which
        # equals the 1-indexed layer number only when Kfstart == 1. Kept unchanged so the keys stay
        # identical to the per-layer keys returned by eval_fence.
        logging_dict = {
            'epoch': epoch_ix,
            'step': step,
            'lr': optimizer.param_groups[0]['lr'],
            'forward_pass_time': forward_pass_time,
            'train': {
                'base_loss': base_loss.detach().cpu().item(),
                'position_loss_hkrs': position_loss_hkrs.detach().cpu().item(),
                'position_loss_hks': position_loss_hks.detach().cpu().item(),
                'position_loss_hkrs_by_k': dict(zip([i + Kfstart for i in list(range(Kfstart - 1, Kfend))], position_loss_hkrs_by_k.detach().cpu().tolist())),
                'position_loss_hks_by_k': dict(zip([i + Kfstart for i in list(range(Kfstart - 1, Kfend))], position_loss_hks_by_k.detach().cpu().tolist())),
                'total_loss': loss.detach().cpu().item()
            }
        }

        # Every 100 steps, also evaluate on the test set with and without forced FENCE values
        if step % 100 == 0:
            logging_dict = {
                **logging_dict, 
                **{'test_forced': eval_fence(my_model, tokenizer, test_ds, fence_params, force_fence = True, batch_size = 10, num_batches = 25, device = device)},
                **{'test_unforced': eval_fence(my_model, tokenizer, test_ds, fence_params, force_fence = False, batch_size = 10, num_batches = 25, device = device)}
            }
            my_model.train()

        backward_start_time = time.time()
        loss.backward()
        backward_pass_time = time.time() - backward_start_time
        logging_dict['backward_pass_time'] = backward_pass_time

        # Log losses (done after the backward pass so that backward_pass_time is part of the logged dict)
        if USE_WANDB:
            wandb.log(logging_dict)
        ##### End Logging #####

        # # Print runtime for every 100 steps
        # if step % 100 == 0:
        #     print(f"Iteration {step} forward pass: {forward_pass_time:.4f} s")
        #     print(f"Iteration {step} backward pass: {backward_pass_time:.4f} s")

        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(my_model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()  # Step the scheduler to decay the learning rate
        step = step + 1

        # Drop references to this iteration's tensors before emptying the CUDA cache
        del input_ids, mask, feature_targets, position_mask, embeds_output, hidden_state, this_hidden_state, logits, residual
        del label_ids, shift_logits, shift_labels, base_loss, position_mask_bknd, position_loss_hkrs, position_loss_hks, position_loss_hkrs_by_k, position_loss_hks_by_k
        del hkr_target_values, hk_target_values, feature_targets_bkd, hkr_feature_targets, hk_feature_targets
        del loss, position_ids, attention_mask
        torch.cuda.empty_cache()

    # Checkpoint every other epoch (files are named with the 1-indexed epoch number)
    if epoch_ix % 2 == 0:
        torch.save(my_model.state_dict(), f"{SAVE_DIR}/e{str(epoch_ix + 1)}.pt")
    
    # End-of-epoch qualitative check: generate from a few fixed prompts and save a FENCE plot for each
    test_prompts = {
        'nondog_text': 'The history of the railroad is',
        'dog_text': 'A great recipe for homemade dog food is',
        'nondog_inst': parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my friend?'}], True),
        'dog_inst': parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True),
        'catdog_inst': parse_phi([{'role': 'user', 'content': 'My dogs and cats make me so mad!'}], True),
    }
    test_gens = {title: generate_fence(my_model, tokenizer, prompt = t, max_tokens = 16, echo_output = False) for title, t in test_prompts.items()}
    
    test_plots = {
        title: visualize_fence(gen['text'], gen['hks'], [20], fence_params.fence_dict, 2950, 3072, 0,  fence_params.hk_target_values[20]).update_layout(title = 'H<sub>20</sub>', height = 350)
        for title, gen in test_gens.items()
    }
    
    for title, p in test_plots.items():
        p.write_html(f"{SAVE_DIR}/{str(epoch_ix + 1)}_{title}.html")


In [None]:
# Free the training-loop leftovers (tensors, optimizer, scheduler) from the notebook namespace
# so their GPU memory can be released, then report memory usage.
# Note: every entry needs its trailing comma - adjacent string literals are silently
# concatenated ('residual' 'label_ids' becomes 'residuallabel_ids'), which would skip both names.
to_delete = [
    'input_ids', 'mask', 'feature_targets', 'position_mask', 'embeds_output', 'hidden_state', 'this_hidden_state', 'logits', 'residual',
    'label_ids', 'shift_logits', 'shift_labels', 'base_loss', 'position_mask_bknd', 'position_loss_hkrs', 'position_loss_hks', 'position_loss_hkrs_by_k', 'position_loss_hks_by_k',
    'hkr_target_values', 'hk_target_values', 'feature_targets_bkd', 'hkr_feature_targets', 'hk_feature_targets',
    'loss', 'position_ids', 'attention_mask',
    'optimizer', 'scheduler',
]

for var in to_delete:
    # The cell runs at notebook top level, so these names live in globals();
    # pop with a default tolerates names that were never defined or are already deleted.
    globals().pop(var, None)

torch.cuda.empty_cache()
check_memory()

In [None]:
# nondog = visualize_fence(list(range(20, 21)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling?'}], True), train_ds.feature_dict, max_tokens = 32)
# nondog[2].show('colab')

# dog = visualize_fence(list(range(20, 21)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True), train_ds.feature_dict, max_tokens = 32)
# dog[2].show('colab')

# table = wandb.Table(columns = ['Nondog', 'Dog'])
# nondog[2].write_html('./fig1.html')
# dog[2].write_html('./fig2.html')
# table.add_data(wandb.Html('./fig1.html'), wandb.Html('./fig2.html'))
# run.log({"name": 'Trained - Nondog/dog, Layer 20'})