In [None]:
"""
V5 experiments:
- Add in MLP value decay based off distance matrix
- Training with cleaner DS
"""
None

In [1]:
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import numpy as np
import pandas as pd
from IPython.core.display import HTML, Markdown
import plotly.express as px
import os
import wandb 
from dotenv import load_dotenv
from datetime import datetime
import pytz
from torch.utils.data import DataLoader
import importlib
import pathlib

from transformers import AutoModelForCausalLM, AutoTokenizer
from helpers.phi3.phi3 import Phi3Config, Phi3ForCausalLM, _prepare_4d_causal_attention_mask
from helpers.phi3.parse import parse_phi
from helpers.memory import check_memory

# Load secrets from secrets.env into the environment (the W&B setup below reads WANDB_API_KEY via os.getenv)
load_dotenv('secrets.env')

device = 'cuda'
USE_WANDB = False

# Run identifier = current US/Eastern wall-clock time at minute resolution; checkpoints go under ./models/<RUN_ID>
RUN_ID = datetime.now(pytz.timezone('US/Eastern')).strftime('%Y%m%dT%H%M')
SAVE_DIR = './models/' + RUN_ID

## Setup Save Dir

In [2]:
# Create the per-run checkpoint directory (no error if it already exists)
pathlib.Path(SAVE_DIR).mkdir(exist_ok = True, parents = True)

# Optional experiment tracking; nothing below runs unless USE_WANDB is enabled in the config cell
if USE_WANDB:
    os.environ['WANDB_INIT_TIMEOUT'] = '120'
    wandb.login(key = os.getenv('WANDB_API_KEY'))
    run = wandb.init(project = 'fence_v5', name = RUN_ID, notes = '')

## Initialize Model

In [3]:
# Attention backend: 'flash_attention_2', or None to use the default implementation
attn_implementation = 'flash_attention_2'

# Tokenizer: BOS/EOS are not added automatically.
# Padding side only matters for flash attention, which needs left padding.
tokenizer = AutoTokenizer.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
)

# Stock HF Phi-3 checkpoint in bfloat16, placed on `device` and switched to eval mode
my_model = AutoModelForCausalLM.from_pretrained(
    'microsoft/Phi-3-mini-4k-instruct',
    device_map = device,
    trust_remote_code = True,
    torch_dtype = torch.bfloat16,
    attn_implementation = attn_implementation,
)
my_model = my_model.to(device).eval()

# Alternative: build the model separately from the local Phi-3 model code instead of the HF remote code
# my_model = Phi3ForCausalLM(base_model.config).to(device).eval().to(dtype = torch.bfloat16) # Phi3Config()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# # Check everything is bfloat16
# for p in base_model.parameters():
#     print(p.dtype)

# # Check attention implementation
# my_model.model.layers[0].self_attn

In [None]:
# # Next, we want to clone params from base_model into model
# # Let's store all params from the base_model
# all_params = {}
# for name, param in base_model.named_parameters():
#     all_params[name] = param.cpu().clone()

# # Then copy them over to the new model
# for name, param in my_model.named_parameters():
#     param.data.copy_(all_params[name].data)

# # Verify these are the same
# for name, p in my_model.named_parameters():
#     if name == 'model.embed_tokens.weight': 
#         print(p)
# for name, p in base_model.named_parameters():
#     if name == 'model.embed_tokens.weight': 
#         print(p)

In [None]:
# check_memory()
# # Need to delete all references to original model to clear memory properly https://discuss.pytorch.org/t/cuda-memory-not-released-by-torch-cuda-empty-cache/129913/6
# if 'base_model' in globals():
#     del base_model
# if 'name' in globals():
#     del name
# if 'param' in globals():
#     del param
# if 'p' in globals():
#     del p
# if 'all_params' in globals():
#     del all_params
    
# torch.cuda.empty_cache()
# check_memory()

In [None]:
# torch.save(my_model.state_dict(), f'./models/phi3_base.pt')
# base_model.load_state_dict(torch.load('./models/FENCEV2-20240619T0813/e2.pt'))

## Initialize FENCE Parameters

In [11]:
class FenceParams():
    def __init__(self, fence_dict: dict[str, tuple[int, int]], D: int, Kfstart: int, Kfend: int, hkr_target_values: dict[int, float], hk_target_values: dict[int, float]):
        """
        Creates a new FENCE Params object
        
        Description:
            Creates a FENCE params object which stores the FENCE dict, the hidden-state width and the per-layer
            position loss target values, after validating that they are mutually consistent.
        
        Params: 
            @fence_dict: A dict of features and their corresponding fence dimensions, listed in ascending dimension order,
              e.g. {'cats': (3061, 3064), 'dogs': (3065, 3068)}.
             - These dimensions are 1-indexed and inclusive of both the start and ending numbers passed into the tuples. (3061, 3064) means dimensions 3061, 3062, 3063, and 3064.
             - Spans must satisfy start <= end, lie within 1..D, and must not overlap.
            @D: The dimension of the hidden state.
            @Kfstart: The index (1-indexed) of the first transformer block with which to calculate position loss.
            @Kfend: The index (1-indexed) of the last transformer block with which to calculate position loss.
            @hkr_target_values: A dict where the keys are the layer index (1-indexed), and the values representing the target FENCE values for each layer's residual stream output.
              Must contain exactly the layers Kfstart..Kfend, inserted in layer order (downstream code turns `.values()` into a per-layer tensor).
            @hk_target_values:  A dict where the keys are the layer index (1-indexed), and the values representing the target FENCE values for each layer's transformer block output.
              Same key requirements as hkr_target_values.

        Attributes (besides the params above):
            Kf: number of position-loss layers, Kfend - Kfstart + 1
            Df: total number of FENCE dimensions across all features
            Df_zero_indices: 0-indexed hidden dims covered by the FENCE dict, in fence_dict order
            Dfmask: float tensor of length D with 1 at the FENCE dims and 0 elsewhere

        Raises:
            ValueError: for an inverted or out-of-range span, overlapping spans, features not in ascending order,
              an invalid layer range, or target dicts that do not cover exactly layers Kfstart..Kfend in order.
        """
        spans = list(fence_dict.values())

        for fname, (start, end) in fence_dict.items():
            if start > end:
                raise ValueError(f"FENCE dict entry '{fname}' has start > end: {(start, end)}")
            if start < 1 or end > D:
                raise ValueError(f"FENCE dict entry '{fname}' {(start, end)} is outside dimensions 1..{D}")

        sorted_spans = sorted(spans)
        if not all(r1[1] < r2[0] for r1, r2 in zip(sorted_spans, sorted_spans[1:])):
            raise ValueError('FENCE dict contains overlapping values')

        if not all(curr[0] > prev[0] for prev, curr in zip(spans, spans[1:])):
            raise ValueError('FENCE dict not passed in order!')

        if not 1 <= Kfstart <= Kfend:
            raise ValueError(f'Need 1 <= Kfstart <= Kfend, got Kfstart={Kfstart}, Kfend={Kfend}')

        # Downstream code builds a Kf-long tensor from `.values()`, so keys must be exactly Kfstart..Kfend in order
        expected_layers = list(range(Kfstart, Kfend + 1))
        for tname, targets in (('hkr_target_values', hkr_target_values), ('hk_target_values', hk_target_values)):
            if list(targets.keys()) != expected_layers:
                raise ValueError(f'{tname} must be keyed by layers {Kfstart}..{Kfend} in order, got keys {list(targets.keys())}')

        self.fence_dict = fence_dict
        self.Kfstart = Kfstart
        self.Kfend = Kfend
        self.Kf = Kfend - Kfstart + 1
        self.hkr_target_values = hkr_target_values
        self.hk_target_values = hk_target_values
        self.D = D
        self.Df = sum(end - start + 1 for start, end in spans)
        # Convert the 1-indexed inclusive spans into a flat list of 0-indexed dims
        self.Df_zero_indices = [i - 1 for start, end in spans for i in range(start, end + 1)]
        # Length-D mask marking the FENCE dims (read by the single-batch test cell)
        self.Dfmask = torch.zeros(D)
        self.Dfmask[self.Df_zero_indices] = 1

    def __str__(self):
        return (f"FENCE Params Object:\n"
                f"  FENCE Dictionary: {self.fence_dict}\n"
                f"  Number of position loss layers: {self.Kf} (Layer {self.Kfstart} to layer {self.Kfend})\n"
                f"  Df: {self.Df}  \n"
                f"  hkr_target_values: {self.hkr_target_values}\n"
                f"  hk_target_values: {self.hk_target_values}\n"
                f"  Df_zero_indices: {self.Df_zero_indices}")
        
    def __repr__(self):
        return self.__str__()

# Pass indices starting at 1
# Layer indices are 1-indexed: layer 1 is the first transformer block
Kfstart = 1
Kfend = 32
fence_params = FenceParams(
    fence_dict = {
        'programming': (2980, 2983),
        'food': (3000, 3003),
        'animals': (3020, 3023),
        'dogs': (3040, 3043),
        'cats': (3060, 3063)
    },
    D = 3072,
    Kfstart = Kfstart,
    Kfend = Kfend,
    # Targets are keyed directly by layer number j (j already runs over Kfstart..Kfend).
    # Post-attention residual targets sit halfway between consecutive block-output targets: layer j -> (j - 1) * 0.25 + 0.125
    hkr_target_values = {j: (j - 1) * .25 + .25/2 for j in range(Kfstart, Kfend + 1)},
    # Block-output (post-MLP) targets grow by 0.25 per layer: layer j -> j * 0.25
    hk_target_values = {j: (j - 1) * .25 + .25 for j in range(Kfstart, Kfend + 1)}
)

fence_params

FENCE Params Object:
  FENCE Dictionary: {'programming': (2980, 2983), 'food': (3000, 3003), 'animals': (3020, 3023), 'dogs': (3040, 3043), 'cats': (3060, 3063)}
  Number of position loss layers: 32 (Layer 1 to layer 32)
  Df: 20  
  hkr_target_values: {1: 0.125, 2: 0.375, 3: 0.625, 4: 0.875, 5: 1.125, 6: 1.375, 7: 1.625, 8: 1.875, 9: 2.125, 10: 2.375, 11: 2.625, 12: 2.875, 13: 3.125, 14: 3.375, 15: 3.625, 16: 3.875, 17: 4.125, 18: 4.375, 19: 4.625, 20: 4.875, 21: 5.125, 22: 5.375, 23: 5.625, 24: 5.875, 25: 6.125, 26: 6.375, 27: 6.625, 28: 6.875, 29: 7.125, 30: 7.375, 31: 7.625, 32: 7.875}
  hk_target_values: {1: 0.25, 2: 0.5, 3: 0.75, 4: 1.0, 5: 1.25, 6: 1.5, 7: 1.75, 8: 2.0, 9: 2.25, 10: 2.5, 11: 2.75, 12: 3.0, 13: 3.25, 14: 3.5, 15: 3.75, 16: 4.0, 17: 4.25, 18: 4.5, 19: 4.75, 20: 5.0, 21: 5.25, 22: 5.5, 23: 5.75, 24: 6.0, 25: 6.25, 26: 6.5, 27: 6.75, 28: 7.0, 29: 7.25, 30: 7.5, 31: 7.75, 32: 8.0}
  Df_zero_indices: [2979, 2980, 2981, 2982, 2999, 3000, 3001, 3002, 3019, 3020, 3021, 3

## Test Inference & Visualizations

In [12]:
# Reload so edits to helpers/fence/eval.py are picked up without restarting the kernel
importlib.reload(importlib.import_module('helpers.fence.eval'))
from helpers.fence.eval import generate_fence

# Smoke-test prompts: one raw-text continuation followed by three chat-formatted user questions
test_prompts = [
    '<s>Animals are multicellular, eukaryotic organisms in the biological kingdom Animalia. With few',
    *[
        parse_phi([{'role': 'user', 'content': question}], True)
        for question in (
            'What did you do today?',
            'What should I bring to take my dog hiking?',
            'What should I bring to take my friend hiking?',
        )
    ],
]

# Short 16-token generations for each prompt
test_gens = [
    generate_fence(my_model, tokenizer, prompt = prompt_text, max_tokens = 16)
    for prompt_text in test_prompts
]

In [None]:
# Reload so edits to helpers/fence/visualize.py are picked up without restarting the kernel
importlib.reload(importlib.import_module('helpers.fence.visualize'))
from helpers.fence.visualize import visualize_fence

# Block-output activations of the first test generation at a few depths, colour range capped at each layer's target
for layer_ix in (1, 10, 20, 30):
    fig = visualize_fence(
        test_gens[0]['text'],
        test_gens[0]['hks'],
        [layer_ix],
        fence_params.fence_dict,
        start_dim = 2900,
        end_dim = 3072,
        min_range = 0,
        max_range = fence_params.hk_target_values[layer_ix],
    )
    fig.update_layout(title = f'H<sub>{layer_ix}</sub>', height = 300)
    fig.show('colab')

## Data Prep

In [None]:
# Raw train/test splits
train_raw = pd.read_csv('train.csv')
test_raw = pd.read_csv('test.csv')

# Rows without a "surprise" are the ones used for position-loss training
train_nosup_raw = train_raw.loc[train_raw['is_surprise'] == 0]

print(len(train_raw), len(train_nosup_raw), len(test_raw))

In [None]:
# One {feature_name: label} dict per row, restricted to the FENCE feature columns (in fence_dict order)
feature_columns = list(fence_params.fence_dict)
train_feature_classifications = train_raw[feature_columns].to_dict('records')
train_nosup_feature_classifications = train_nosup_raw[feature_columns].to_dict('records')
test_feature_classifications = test_raw[feature_columns].to_dict('records')

In [None]:
## Test
# importlib.reload(importlib.import_module('helpers.fence.dataset'))
# from helpers.fence.dataset import FenceDataSet
# token_length = 128

# test_tokens = tokenizer(test_raw['phi3_text'].tolist()[0:6], truncation = True, max_length = token_length, padding = 'max_length', return_tensors = 'pt').to(device)
# test_feature_classifications = test_raw[fence_params.fence_dict.keys()].to_dict('records')[0:6]

# position_mask_start_token_id = tokenizer.encode('<|assistant|>')[0]
# test_ds = FenceDataSet(test_tokens, fence_params.fence_dict, 3072, test_feature_classifications, [position_mask_start_token_id])

# print(test_ds.position_mask)

# print(tokenizer.batch_decode(test_ds.tokens['input_ids']))

In [None]:
importlib.reload(importlib.import_module('helpers.fence.dataset'))
from helpers.fence.dataset import FenceDataSet
token_length = 1024

# Token-length distribution of the raw training texts, to sanity-check the truncation length above.
# Lengths are counted from plain token-id lists (one batched call, no per-row GPU tensors).
train_token_lengths = [len(ids) for ids in tokenizer(train_raw['phi3_text'].tolist())['input_ids']]
px.histogram(pd.DataFrame({"num_tokens": train_token_lengths}), x = "num_tokens").show('colab')

def _tokenize_fixed_length(texts: list[str]):
    """Tokenize `texts`, truncating/padding every row to exactly `token_length` tokens, returned as tensors on `device`."""
    return tokenizer(texts, truncation = True, max_length = token_length, padding = 'max_length', return_tensors = 'pt').to(device)

train_tokens = _tokenize_fixed_length(train_raw['phi3_text'].tolist())
train_nosup_tokens = _tokenize_fixed_length(train_nosup_raw['phi3_text'].tolist())
test_tokens = _tokenize_fixed_length(test_raw['phi3_text'].tolist())

# First token id of the '<|assistant|>' marker; presumably FenceDataSet starts the position mask from it — confirm in helpers/fence/dataset.py
position_mask_start_token_id = tokenizer.encode('<|assistant|>')[0]
train_ds = FenceDataSet(train_tokens, fence_params.fence_dict, fence_params.D, train_feature_classifications, position_mask_start_token_id)
train_nosup_ds = FenceDataSet(train_nosup_tokens, fence_params.fence_dict, fence_params.D, train_nosup_feature_classifications, position_mask_start_token_id)
test_ds = FenceDataSet(test_tokens, fence_params.fence_dict, fence_params.D, test_feature_classifications, position_mask_start_token_id)

train_dl = DataLoader(train_ds, batch_size = 10, shuffle = True)
train_nosup_dl = DataLoader(train_nosup_ds, batch_size = 10, shuffle = True)
test_dl = DataLoader(test_ds, batch_size = 10, shuffle = True)

In [None]:
# Optionally freeze the embeddings, final RMSnorm and LM head so only parameters inside transformer blocks train.
# Freezing is currently DISABLED (flag is False): every parameter is trainable.
FREEZE_NON_BLOCK_PARAMS = False
FROZEN_PARAM_SUBSTRINGS = ('embed_tokens', 'model.norm', 'lm_head')

for name, param in my_model.named_parameters():
    is_frozen = FREEZE_NON_BLOCK_PARAMS and any(s in name for s in FROZEN_PARAM_SUBSTRINGS)
    param.requires_grad = not is_frozen

# Spot check: print trainability for everything outside the blocks plus the first block
for name, param in my_model.named_parameters():
    if 'layers' not in name or '.0.' in name:
        print(name, param.requires_grad)

# Drop the loop references to the last parameter
del name
del param

check_memory()
torch.cuda.empty_cache()

## Single Batch Train Test

In [None]:
# Grab one (shuffled) batch from the training loader for the single-batch smoke test below,
# moved to `device` the same way eval_fence and the training loop do
batch = next(iter(train_dl))

input_ids = batch['input_ids'].to(device)
mask = batch['attention_mask'].to(device)
feature_targets = batch['feature_targets'].to(device)
position_mask = batch['position_mask'].to(device)

print(input_ids, mask, feature_targets, position_mask)

In [None]:
# Single-batch smoke test of the manual FENCE forward pass and losses, without gradient tracking.
# Set force_fence = True to overwrite the FENCE dims with their targets at every tracked layer
# (the position losses are then trivially zero).
model = my_model
force_fence = False

with torch.no_grad():

    ##### Forward Pass ######
    embeds_output = model.model.embed_tokens(input_ids) # B x N x D
    hidden_state = embeds_output

    # B = batch size, N = token length, D = hidden dim
    # Df = FENCE target width, read from feature_targets exactly as eval_fence and the training loop do
    # Kfstart, Kfend = first/last transformer block (1-indexed, inclusive) included in FENCE; Kf = number of such blocks
    B, N, D = embeds_output.shape
    Kfstart = fence_params.Kfstart
    Kfend = fence_params.Kfend
    Kf = fence_params.Kf
    Df = feature_targets.shape[1]
    # NOTE(review): the `-Df:` slices below treat the FENCE dims as the LAST Df hidden dims. This notebook's fence_dict
    # uses non-contiguous dims 2980-3063 of 3072, so that only lines up if feature_targets spans all D dims — confirm
    # against FenceDataSet (and note force_fence would then overwrite the entire hidden state).

    # Prepare SA inputs
    position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs
    if model.model._attn_implementation == 'flash_attention_2':
        attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
    else: 
        attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)  # Non FA: Make a triangular attention mask to hide right context

    # Per-layer target magnitudes, one entry per layer Kfstart..Kfend
    hkr_target_values = torch.tensor(list(fence_params.hkr_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)
    hk_target_values = torch.tensor(list(fence_params.hk_target_values.values()), device = input_ids.device, dtype = torch.bfloat16)

    # Scale the feature targets by each layer's target value; no position masking here, so values are FIXED across N
    feature_targets_bkd = feature_targets.unsqueeze(1) # B x 1 x Df
    hkr_feature_targets = (hkr_target_values.view(1, Kf, 1) * feature_targets_bkd).unsqueeze(2).expand(B, Kf, N, Df) # B x Kf x N x Df
    hk_feature_targets = (hk_target_values.view(1, Kf, 1) * feature_targets_bkd).unsqueeze(2).expand(B, Kf, N, Df) # B x Kf x N x Df

    # saved_hkrs / saved_hks accumulate to B x Kf x N x Df
    saved_hkrs = None
    saved_hks = None
    for l, layer in enumerate(model.model.layers):
        
        # SA, summed back into the residual stream
        residual = hidden_state
        sa_input = layer.input_layernorm(hidden_state)
        sa_output = layer.self_attn(sa_input, attention_mask, position_ids)[0]
        hidden_state = residual + layer.resid_attn_dropout(sa_output)

        if Kfstart - 1 <= l <= Kfend - 1:
            if force_fence: # Forcibly set H_K^R
                # hkr_feature_targets[:, k] holds layer Kfstart + k; block l (0-indexed) is layer l + 1, so k = l + 1 - Kfstart
                hidden_state[:, :, -Df:] = hkr_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df

            this_hidden_state = hidden_state[:, :, -Df:].unsqueeze(dim = 1)  # Save B x 1 x N x Df
            saved_hkrs = this_hidden_state if saved_hkrs is None else torch.cat((saved_hkrs, this_hidden_state), dim = 1)

        # MLP, summed back into the residual stream
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        mlp_output = layer.mlp(hidden_state)
        hidden_state = residual + layer.resid_mlp_dropout(mlp_output)

        if Kfstart - 1 <= l <= Kfend - 1:
            if force_fence: # Forcibly set H_K
                hidden_state[:, :, -Df:] = hk_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df

            this_hidden_state = hidden_state[:, :, -Df:].unsqueeze(dim = 1)  # Save B x 1 x N x Df
            saved_hks = this_hidden_state if saved_hks is None else torch.cat((saved_hks, this_hidden_state), dim = 1)

    # RMS norm the final transformer layer output
    hidden_state = model.model.norm(hidden_state)

    # Run LM head
    logits = model.lm_head(hidden_state).float() # B x N x vocab_size
    #### End Forward Pass ######

    ##### Calculate loss #####
    # Mask loss anywhere where the input ids are pad tokens
    label_ids = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous() # Remove the last token from the sequence
    shift_labels = label_ids[..., 1:].contiguous() # Remove the first token from the sequence
    loss_fct = CrossEntropyLoss(ignore_index = -100)
    base_loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))

    # Position loss: MAPE against the per-layer targets, only where the position mask is 1
    position_mask_bknd = position_mask.unsqueeze(1).unsqueeze(-1).expand(B, Kf, N, Df) # B x Kf x N x Df (recycled across Kf and Df)
    hkr_scale = hkr_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0)  # 1 x Kf x 1 x 1
    hk_scale = hk_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0)    # 1 x Kf x 1 x 1

    position_loss_hkrs = torch.where(position_mask_bknd == 1, torch.abs((saved_hkrs - hkr_feature_targets)/hkr_scale), torch.tensor(0.0))
    position_loss_hks = torch.where(position_mask_bknd == 1, torch.abs((saved_hks - hk_feature_targets)/hk_scale), torch.tensor(0.0))
    
    position_loss_hkrs_by_k = position_loss_hkrs.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
    position_loss_hks_by_k = position_loss_hks.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
    
    position_loss_hkrs_by_dim = position_loss_hkrs.sum(dim = (0, 1, 2))/position_mask_bknd.sum(dim = (0, 1, 2))
    position_loss_hks_by_dim = position_loss_hks.sum(dim = (0, 1, 2))/position_mask_bknd.sum(dim = (0, 1, 2))

    position_loss_hkrs = position_loss_hkrs.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
    position_loss_hks = position_loss_hks.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))

    loss = base_loss + position_loss_hks + position_loss_hkrs
    ##### End loss calculation ######

# Drop the alias first, then release cached blocks and report memory
del model
torch.cuda.empty_cache()
check_memory()

In [None]:
# Spot-check: last 50 target dims for batch row 2, FENCE layer slot 4 (layer Kfstart + 4), token position 10.
# Indices assume the single-batch cell above ran with at least 5 rows and at least 11 tokens.
hk_feature_targets[2, 4, 10, -50:]

In [None]:
# Shape check: B x Kf x N x <target width>, per the expand() in the single-batch cell
hk_feature_targets.shape

In [None]:
# Shape check: should match hk_feature_targets (B x Kf x N x <target width>)
hkr_feature_targets.shape

In [None]:
# Last 50 per-dimension feature targets for batch row 4 (row index assumes a batch of at least 5)
feature_targets[4,-50:]

## Training

In [None]:
# Pre-train visualizations: block-output FENCE activations at one layer for a few chat prompts
test_prompts = [
    parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True),
    parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my friend?'}], True),
    parse_phi([{'role': 'user', 'content': 'Hey idiot, what\'s wrong with my dog?'}], True),
]
test_gens = [generate_fence(my_model, tokenizer, prompt = t, max_tokens = 16) for t in test_prompts]

# Same keyword call as the earlier visualization cell; colour range capped at the layer's H_K target
viz_layer = 30
test_plots = [
    visualize_fence(
        gen['text'],
        gen['hks'],
        [viz_layer],
        fence_params.fence_dict,
        start_dim = 2950, end_dim = 3072,
        min_range = 0, max_range = fence_params.hk_target_values[viz_layer]
    ).update_layout(title = f'H<sub>{viz_layer}</sub>', height = 300)
    for gen in test_gens
]

for p in test_plots:
    p.show('colab')

In [None]:
# Evaluation function
@torch.no_grad()
def eval_fence(
    model, 
    tokenizer,
    test_ds: FenceDataSet, 
    Kf_target_values: dict,
    Kfstart: int = 1, 
    Kfend: int = 32, 
    force_fence: bool = True, 
    batch_size: int = 10, 
    num_batches: int = 20,
    device = 'cuda'
    ):
    """
    Evaluation function to get test losses (LM loss plus FENCE position losses), averaged over up to `num_batches` batches

    Params:
        @model: The model to use; its Phi-3 submodules (embed_tokens, layers, norm, lm_head) are called directly
        @tokenizer: The tokenizer object; only `pad_token_id` is used, to mask padding out of the LM loss
        @test_ds: The test dataset object, with a FENCE feature dict object; items provide 'input_ids', 'attention_mask',
            'feature_targets' and 'position_mask'
        @Kf_target_values: The target values by layer, as {'hkrs': {layer: value, ...}, 'hks': {layer: value, ...}} with exactly one
            entry per layer Kfstart..Kfend in layer order, e.g. {'hkrs': fence_params.hkr_target_values, 'hks': fence_params.hk_target_values}
        @Kfstart: The starting layer index to track position losses (or force a FENCE) - starts at 1, not 0
        @Kfend: The ending layer index (inclusive) to track position losses
        @force_fence: Whether to force FENCE position indices, i.e. overwrite the FENCE dims with their targets at every tracked layer
        @batch_size: The batch size to use for eval
        @num_batches: The maximum number of batches to use for eval
        @device: The torch device

    Returns:
        A dict with 'input_count' (rows evaluated), batch-averaged 'base_loss', 'position_loss_hkrs', 'position_loss_hks',
        'position_loss_hkrs_by_k' / 'position_loss_hks_by_k' keyed by 1-indexed layer number, and
        'position_loss_hks_by_feature' keyed by feature name.

    Raises:
        ValueError: if Kf_target_values does not have exactly Kf entries per key, or if no batch was evaluated.
    """
    model.eval()
    Kf = Kfend - Kfstart + 1
    layer_keys = list(range(Kfstart, Kfend + 1))  # 1-indexed layer numbers for the per-layer result dicts

    # Per-layer target magnitudes are loop-invariant, so build them once
    hkr_target_values = torch.tensor(list(Kf_target_values['hkrs'].values()), device = device, dtype = torch.bfloat16)
    hk_target_values = torch.tensor(list(Kf_target_values['hks'].values()), device = device, dtype = torch.bfloat16)
    if hkr_target_values.numel() != Kf or hk_target_values.numel() != Kf:
        raise ValueError(f"Kf_target_values['hkrs'] and ['hks'] must each have {Kf} entries (layers {Kfstart}-{Kfend})")
    hkr_scale = hkr_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0)  # 1 x Kf x 1 x 1, MAPE denominator
    hk_scale = hk_target_values.unsqueeze(-1).unsqueeze(-1).unsqueeze(0)    # 1 x Kf x 1 x 1

    loss_fct = CrossEntropyLoss(ignore_index = -100)
    batch_results = []
    input_count = 0
    
    for ix, batch in enumerate(DataLoader(test_ds, batch_size = batch_size, shuffle = True)):
        
        if ix >= num_batches:
            break
            
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        feature_targets = batch['feature_targets'].to(device)
        position_mask = batch['position_mask'].to(device)
                
        ##### Forward Pass ######
        embeds_output = model.model.embed_tokens(input_ids) # B x N x D
        hidden_state = embeds_output

        B, N, D = embeds_output.shape
        Df = feature_targets.shape[1] # Total FENCE width
        # NOTE(review): the `-Df:` slices below treat the FENCE dims as the LAST Df hidden dims. This notebook's
        # fence_dict uses non-contiguous dims 2980-3063 of 3072, so that only lines up if feature_targets spans all
        # D dims — confirm against FenceDataSet.

        # Prepare SA inputs
        position_ids = torch.arange(0, N, dtype = torch.long, device = device).unsqueeze(0).view(-1, N) # Create position IDs
        if model.model._attn_implementation == 'flash_attention_2':
            attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
        else: 
            attention_mask = _prepare_4d_causal_attention_mask(None, (B, N), embeds_output, 0, sliding_window = model.model.config.sliding_window)  # Non FA: Make a triangular attention mask to hide right context

        # Scale the feature targets by each layer's target value
        feature_targets_bkd = feature_targets.unsqueeze(1) # B x 1 x Df
        hkr_feature_targets = (hkr_target_values.view(1, Kf, 1) * feature_targets_bkd).unsqueeze(2).expand(B, Kf, N, Df) # B x Kf x N x Df
        hk_feature_targets = (hk_target_values.view(1, Kf, 1) * feature_targets_bkd).unsqueeze(2).expand(B, Kf, N, Df) # B x Kf x N x Df

        # saved_hkrs / saved_hks accumulate to B x Kf x N x Df
        saved_hkrs = None
        saved_hks = None
        for l, layer in enumerate(model.model.layers):
            
            # SA, summed back into the residual stream
            residual = hidden_state
            sa_input = layer.input_layernorm(hidden_state)
            sa_output = layer.self_attn(sa_input, attention_mask, position_ids)[0]
            hidden_state = residual + layer.resid_attn_dropout(sa_output)    

            if Kfstart - 1 <= l <= Kfend - 1:
                if force_fence: # Forcibly set H_K^R
                    # hkr_feature_targets[:, k] holds layer Kfstart + k; block l (0-indexed) is layer l + 1, so k = l + 1 - Kfstart
                    hidden_state[:, :, -Df: ] = hkr_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df

                this_hidden_state = hidden_state[:, :, -Df:].unsqueeze(dim = 1)  # Save B x 1 x N x Df
                saved_hkrs = this_hidden_state if saved_hkrs is None else torch.cat((saved_hkrs, this_hidden_state), dim = 1)

            # MLP, summed back into the residual stream
            residual = hidden_state
            hidden_state = layer.post_attention_layernorm(hidden_state)
            mlp_output = layer.mlp(hidden_state)
            hidden_state = residual + layer.resid_mlp_dropout(mlp_output)

            if Kfstart - 1 <= l <= Kfend - 1:
                if force_fence: # Forcibly set H_K
                    hidden_state[:, :, -Df: ] = hk_feature_targets[:, l + 1 - Kfstart, :, :] # B x N x Df

                this_hidden_state = hidden_state[:, :, -Df:].unsqueeze(dim = 1)  # Save B x 1 x N x Df
                saved_hks = this_hidden_state if saved_hks is None else torch.cat((saved_hks, this_hidden_state), dim = 1)

        # RMS norm the final transformer layer output, then run the LM head
        hidden_state = model.model.norm(hidden_state)
        logits = model.lm_head(hidden_state).float() # B x N x vocab_size
        #### End Forward Pass ######

        ##### Calculate loss #####
        # Mask loss anywhere where the input ids are pad tokens
        label_ids = torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids)
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous() # Remove the last token from the sequence
        shift_labels = label_ids[..., 1:].contiguous() # Remove the first token from the sequence
        base_loss = loss_fct(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))

        # Position loss: MAPE against the per-layer targets, only where the position mask is 1
        position_mask_bknd = position_mask.unsqueeze(1).unsqueeze(-1).expand(B, Kf, N, Df) # B x Kf x N x Df (recycled across Kf and Df)
        position_loss_hkrs = torch.where(position_mask_bknd == 1, torch.abs((saved_hkrs - hkr_feature_targets)/hkr_scale), torch.tensor(0.0))
        position_loss_hks = torch.where(position_mask_bknd == 1, torch.abs((saved_hks - hk_feature_targets)/hk_scale), torch.tensor(0.0))
        
        position_loss_hkrs_by_k = position_loss_hkrs.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))
        position_loss_hks_by_k = position_loss_hks.sum(dim = (0, 2, 3))/position_mask_bknd.sum(dim = (0, 2, 3))

        position_loss_hks_by_dim = position_loss_hks.sum(dim = (0, 1, 2))/position_mask_bknd.sum(dim = (0, 1, 2))

        position_loss_hkrs = position_loss_hkrs.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
        position_loss_hks = position_loss_hks.sum(dim = (0, 1, 2, 3))/position_mask_bknd.sum(dim = (0, 1, 2, 3))
        ##### End loss calculation ######
        
        # Per-feature mean of the H_K position loss. Assumes the Df target dims are laid out feature by feature in
        # test_ds.feature_dict order; entries may be plain widths (int) or 1-indexed inclusive (start, end) spans.
        dim_ix = 0
        position_loss_hks_by_feature = {}
        for fname, fdim in test_ds.feature_dict.items():
            fwidth = fdim[1] - fdim[0] + 1 if isinstance(fdim, (tuple, list)) else fdim
            position_loss_hks_by_feature[fname] = position_loss_hks_by_dim[dim_ix : dim_ix + fwidth].mean().cpu()
            dim_ix = dim_ix + fwidth

        batch_results.append({
            'base_loss': base_loss.item(),
            'position_loss_hkrs': position_loss_hkrs.item(),
            'position_loss_hks': position_loss_hks.item(),
            'position_loss_hkrs_by_k': dict(zip(layer_keys, position_loss_hkrs_by_k.cpu().tolist())),
            'position_loss_hks_by_k': dict(zip(layer_keys, position_loss_hks_by_k.cpu().tolist())),
            'position_loss_hks_by_feature': position_loss_hks_by_feature
        })
        input_count = input_count + B

    if not batch_results:
        raise ValueError('eval_fence evaluated no batches (empty test_ds or num_batches <= 0)')

    return {
        'input_count': input_count,
        'base_loss': np.mean([b['base_loss'] for b in batch_results]),
        'position_loss_hkrs': np.mean([b['position_loss_hkrs'] for b in batch_results]),
        'position_loss_hks': np.mean([b['position_loss_hks'] for b in batch_results]),
        'position_loss_hkrs_by_k': {k: np.mean([b['position_loss_hkrs_by_k'][k] for b in batch_results]) for k in layer_keys},
        'position_loss_hks_by_k': {k: np.mean([b['position_loss_hks_by_k'][k] for b in batch_results]) for k in layer_keys},
        'position_loss_hks_by_feature': {
            fname: torch.stack([b['position_loss_hks_by_feature'][fname] for b in batch_results], dim = 0).mean().item()
            for fname in test_ds.feature_dict.keys()
        }
    }

# Per-layer targets in the {'hkrs': ..., 'hks': ...} form eval_fence expects (previously referenced but never defined)
Kf_target_values = {'hkrs': fence_params.hkr_target_values, 'hks': fence_params.hk_target_values}
eval_fence(my_model, tokenizer, test_ds, Kf_target_values, Kfstart = fence_params.Kfstart, Kfend = fence_params.Kfend, force_fence = False, batch_size = 10, num_batches = 5, device = device)

In [None]:
# Display the no-surprise DataLoader (the training loop draws from it on steps that train position loss)
train_nosup_dl

In [None]:
# Training
# Investigate lower LR
# Investigate partial position-loss targeting
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, my_model.parameters()), lr = 3e-5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size = 10000, gamma = 0.8)
train_with_force_fence = True

my_model.train()

step = 0
max_grad_norm = 5.0  # Set the value for gradient clipping

# STEPS
# 1-100: Nothing, logging purposes only (hidden states are only overwritten from step 100 on)
# 100-5k: Always force FENCE [no position loss] - This is just to get the model to adjust to force FENCE (control only!)
# 5k-10k: Force FENCE 90% of the time, remaining 10% trains position loss w/weight 2
# 10k-15k: Force FENCE 80% of the time, remaining 20% trains position loss w/weight 5
# 15k-20k: Force FENCE 70% of the time, remaining 30% trains position loss w/weight 10
# 20k-30k: Force FENCE 50% of the time, remaining 50% trains position loss w/weight 20
# 30k+: Force FENCE 50% of the time, remaining 50% trains position loss w/weight 25
#
# Encoded as rows of (first step, unforced slots out of every 10 steps, position-loss weight).
# A step is left unforced when step % 10 < unforced slots; a weight of 0 means base LM loss only.
FENCE_SCHEDULE = [
    (0,     0, 0),
    (5000,  1, 2),
    (10000, 2, 5),
    (15000, 3, 10),
    (20000, 5, 20),
    (30000, 5, 25),
]


def _fence_schedule(step):
    """Return (unforced_slots, position_loss_weight) from the last FENCE_SCHEDULE row whose first step <= step."""
    unforced_slots, weight = 0, 0
    for first_step, row_slots, row_weight in FENCE_SCHEDULE:
        if step >= first_step:
            unforced_slots, weight = row_slots, row_weight
    return unforced_slots, weight


def _relative_position_loss(saved, feature_targets, target_values, position_mask_bknd):
    """Masked mean of |saved - target| / (per-layer target value) over the FENCE slice.

    saved, feature_targets and position_mask_bknd are B x Kf x N x Df; target_values has Kf entries.
    Computed in float32. Denominators are clamped to >= 1, so a batch with no active positions
    yields 0 instead of NaN (which would otherwise poison the weights via backward).

    Returns (loss_by_k [Kf], loss [scalar]).
    """
    scale = target_values.float().view(1, -1, 1, 1)
    abs_rel_err = torch.abs((saved.float() - feature_targets.float()) / scale)
    abs_rel_err = abs_rel_err.masked_fill(position_mask_bknd != 1, 0.0)
    loss_by_k = abs_rel_err.sum(dim = (0, 2, 3)) / position_mask_bknd.sum(dim = (0, 2, 3)).clamp(min = 1)
    loss = abs_rel_err.sum() / position_mask_bknd.sum().clamp(min = 1)
    return loss_by_k, loss


# Loop-invariant FENCE configuration
# Kfstart, Kfend = starting and ending indices for transformer blocks to include in FENCE (indices start with 1, not 0)
# Kf = number of transformer blocks to include in FENCE
Kf = Kfend - Kfstart + 1
fence_layers = list(range(Kfstart, Kfend + 1))  # 1-indexed layer numbers, used as logging keys
hkr_target_values = torch.tensor(list(Kf_target_values['hkrs'].values()), device = device, dtype = torch.bfloat16)
hk_target_values = torch.tensor(list(Kf_target_values['hks'].values()), device = device, dtype = torch.bfloat16)
assert hkr_target_values.numel() == Kf and hk_target_values.numel() == Kf, 'Kf_target_values needs one entry per FENCE layer'

# Persistent iterator so unforced steps walk through the no-surprise data instead of
# re-creating an iterator (and re-reading its first batch) on every step
nosup_iter = iter(train_nosup_dl)

for epoch_ix in range(0, 100):
    
    for batch_ix, batch in enumerate(train_dl):

        unforced_slots, position_loss_weight = _fence_schedule(step)
        # If force FENCE (default), the FENCE slice is overwritten with its target, so the position loss is ~0
        force_fence = train_with_force_fence and step % 10 >= unforced_slots

        optimizer.zero_grad()
        if step < 4:
            check_memory()

        # If not force FENCE (i.e., there exists some position loss), then train with no-surprise data
        if not force_fence:
            try:
                batch = next(nosup_iter)
            except StopIteration:
                nosup_iter = iter(train_nosup_dl)
                batch = next(nosup_iter)

        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        feature_targets = batch['feature_targets'].to(device)
        position_mask = batch['position_mask'].to(device)

        ##### Forward Pass ######
        embeds_output = my_model.model.embed_tokens(input_ids) # B x N x D
        hidden_state = embeds_output

        # B = batch size, N = token length, D = token dim, Dh = token per-head dim, H = number of heads, K = # transformer blocks
        # Df = total FENCE dimension width
        B, N, D = embeds_output.shape
        H = 32
        Dh = int(D/H)
        K = 32 

        Df = feature_targets.shape[1] # Total FENCE width

        # Prepare SA inputs
        position_ids = torch.arange(N, dtype = torch.long, device = device).unsqueeze(0) # 1 x N
        if my_model.model._attn_implementation == 'flash_attention_2':
            attention_mask = mask if (mask is not None and 0 in mask) else None  # Flash attention = use default attention mask 2d
        else:
            # Non FA: triangular causal mask combined with the 2D padding mask, so left-pad tokens are hidden as well
            attention_mask = _prepare_4d_causal_attention_mask(mask, (B, N), embeds_output, 0, sliding_window = my_model.model.config.sliding_window)

        # Multiply the per-layer target values by the actual feature targets
        feature_targets_bkd = feature_targets.unsqueeze(1) # B x 1 x Df
        hkr_feature_targets = (hkr_target_values.view(1, Kf, 1) * feature_targets_bkd).unsqueeze(2).expand(B, Kf, N, Df) # B x Kf x N x Df
        hk_feature_targets = (hk_target_values.view(1, Kf, 1) * feature_targets_bkd).unsqueeze(2).expand(B, Kf, N, Df) # B x Kf x N x Df

        # FENCE slices per layer (each B x N x Df), stacked once after the layer loop
        hkr_slices = []
        hk_slices = []
        for l, layer in enumerate(my_model.model.layers):
            # Layer l (0-indexed) is 1-indexed layer l + 1; hkr_feature_targets[:, k] holds layer Kfstart + k,
            # so k = l + 1 - Kfstart
            in_fence = Kfstart - 1 <= l <= Kfend - 1
            k = l + 1 - Kfstart
            
            # SA
            residual = hidden_state
            hidden_state = layer.input_layernorm(hidden_state)
            hidden_state = layer.self_attn(hidden_state, attention_mask, position_ids)[0]
            hidden_state = residual + layer.resid_attn_dropout(hidden_state)    

            if in_fence:
                if force_fence and step >= 100: # Forcibly set H_K^R
                    hidden_state[:, :, -Df: ] = hkr_feature_targets[:, k, :, :] # B x N x Df
                hkr_slices.append(hidden_state[:, :, -Df:])

            # MLP
            residual = hidden_state
            hidden_state = layer.post_attention_layernorm(hidden_state)
            hidden_state = layer.mlp(hidden_state)
            hidden_state = residual + layer.resid_mlp_dropout(hidden_state)

            if in_fence:
                if force_fence and step >= 100: # Forcibly set H_K
                    hidden_state[:, :, -Df: ] = hk_feature_targets[:, k, :, :] # B x N x Df
                hk_slices.append(hidden_state[:, :, -Df:])

        saved_hkrs = torch.stack(hkr_slices, dim = 1) # B x Kf x N x Df
        saved_hks = torch.stack(hk_slices, dim = 1) # B x Kf x N x Df

        # RMS norm the final transformer layer output
        hidden_state = my_model.model.norm(hidden_state)

        # Run LM head
        logits = my_model.lm_head(hidden_state).float() # B x N x vocab
        #### End Forward Pass ######

        ##### Calculate loss #####
        # Mask loss anywhere where the input ids are pad tokens
        label_ids = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
        # Shift so that tokens < n predict n, then flatten tokens
        shift_logits = logits[..., :-1, :].contiguous().view(-1, my_model.config.vocab_size)
        shift_labels = label_ids[..., 1:].contiguous().view(-1)
        loss_fct = CrossEntropyLoss(ignore_index = -100)
        base_loss = loss_fct(shift_logits, shift_labels)

        # Position loss, restricted to positions flagged by position_mask (recycled across Kf and Df)
        position_mask_bknd = position_mask.unsqueeze(1).unsqueeze(-1).expand(B, Kf, N, Df) # B x Kf x N x Df
        position_loss_hkrs_by_k, position_loss_hkrs = _relative_position_loss(saved_hkrs, hkr_feature_targets, hkr_target_values, position_mask_bknd)
        position_loss_hks_by_k, position_loss_hks = _relative_position_loss(saved_hks, hk_feature_targets, hk_target_values, position_mask_bknd)

        if position_loss_weight == 0:
            loss = base_loss
        else:
            loss = base_loss + position_loss_weight * position_loss_hks + position_loss_weight * position_loss_hkrs
        ##### End loss calculation ######

        ##### Logging #####
        if step % 50 == 0:
            print(np.round(base_loss.detach().cpu().item(), 2), np.round(position_loss_hks.detach().cpu().item(), 2))
            
        logging_dict = {
            'epoch': epoch_ix,
            'step': step,
            'lr': optimizer.param_groups[0]['lr'],
            'train': {
                'base_loss': base_loss.detach().cpu().item(),
                'position_loss_hkrs': position_loss_hkrs.detach().cpu().item(),
                'position_loss_hks': position_loss_hks.detach().cpu().item(),
                'position_loss_hkrs_by_k': dict(zip(fence_layers, position_loss_hkrs_by_k.detach().cpu().tolist())),
                'position_loss_hks_by_k': dict(zip(fence_layers, position_loss_hks_by_k.detach().cpu().tolist())),
                'total_loss': loss.detach().cpu().item()
            }
        }
        
        if step % 100 == 0:
            logging_dict['test_forced'] = eval_fence(my_model, tokenizer, test_ds, Kf_target_values = Kf_target_values, Kfstart = Kfstart, Kfend = Kfend, force_fence = True, batch_size = 10, num_batches = 20, device = device)
            logging_dict['test_unforced'] = eval_fence(my_model, tokenizer, test_ds, Kf_target_values = Kf_target_values, Kfstart = Kfstart, Kfend = Kfend, force_fence = False, batch_size = 10, num_batches = 20, device = device)
            my_model.train()
            
        if USE_WANDB:
            wandb.log(logging_dict)
        ##### End Logging #####

        loss.backward()
        
        # Apply gradient clipping
        torch.nn.utils.clip_grad_norm_(my_model.parameters(), max_grad_norm)

        optimizer.step()
        scheduler.step()  # Step the scheduler to decay the learning rate
        step = step + 1

        # Drop per-step GPU tensors before the next batch to keep peak memory down
        del input_ids, mask, feature_targets, position_mask, embeds_output, hidden_state, logits, residual
        del label_ids, shift_logits, shift_labels, base_loss, position_mask_bknd, position_loss_hkrs, position_loss_hks, position_loss_hkrs_by_k, position_loss_hks_by_k
        del feature_targets_bkd, hkr_feature_targets, hk_feature_targets, saved_hkrs, saved_hks, hkr_slices, hk_slices
        del loss, position_ids, attention_mask
        torch.cuda.empty_cache()

    if epoch_ix % 2 == 0:
        torch.save(my_model.state_dict(), f"{SAVE_DIR}/e{str(epoch_ix + 1)}.pt")
    
    # nondog = visualize_fence(list(range(10, 20)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my friend?'}], True), train_ds.feature_dict, max_tokens = 16)    
    # dog = visualize_fence(list(range(10, 20)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True), train_ds.feature_dict, max_tokens = 16)
    # angrydog = visualize_fence(list(range(10, 20)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'My dogs and cats make me so MAD!'}], True), train_ds.feature_dict, max_tokens = 16)

    # nondog[2].write_html(f"{SAVE_DIR}/{str(epoch_ix + 1)}_nondog.html")
    # dog[2].write_html(f"{SAVE_DIR}/{str(epoch_ix + 1)}_yesdog.html")

    # angrydog[2].write_html(f"{SAVE_DIR}/{str(epoch_ix + 1)}_angrydog.html")

In [None]:
# Inspect the per-layer H_K^R target values as a tensor. The dict_values view must be
# materialised with list(): torch.tensor cannot infer a dtype from dict_values and raises.
torch.tensor(list(Kf_target_values['hkrs'].values()))

In [None]:
# Free GPU memory held by leftover training-cell globals and the optimizer/scheduler state.
import gc

to_delete = [
    'input_ids', 'mask', 'feature_targets', 'position_mask', 'embeds_output', 'hidden_state', 'this_hidden_state', 'logits', 'residual',
    'label_ids', 'shift_logits', 'shift_labels', 'base_loss', 'position_mask_bknd', 'position_loss_hkrs', 'position_loss_hks', 'position_loss_hkrs_by_k', 'position_loss_hks_by_k',
    'hkr_target_values', 'hk_target_values', 'feature_targets_bkd', 'hkr_feature_targets', 'hk_feature_targets',
    'saved_hkrs', 'saved_hks',
    'loss', 'position_ids', 'attention_mask',
    'optimizer', 'scheduler',
]

# Notebook cells run at module scope, so these names live in globals(); pop() skips names that were never defined
for var in to_delete:
    globals().pop(var, None)

# Collect any reference cycles before asking the CUDA caching allocator to release blocks
gc.collect()
torch.cuda.empty_cache()
check_memory()

In [None]:
# Visualize FENCE for layer list [20] on a non-dog vs. a dog prompt, then save both figures as HTML.
nondog = visualize_fence(list(range(20, 21)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling?'}], True), train_ds.feature_dict, max_tokens = 32)
nondog[2].show('colab')

dog = visualize_fence(list(range(20, 21)), my_model, tokenizer, parse_phi([{'role': 'user', 'content': 'Can you give me some tips for traveling with my dog?'}], True), train_ds.feature_dict, max_tokens = 32)
dog[2].show('colab')

nondog[2].write_html('./fig1.html')
dog[2].write_html('./fig2.html')

# `run` only exists when USE_WANDB is enabled; log the side-by-side table itself under a descriptive key
if USE_WANDB:
    table = wandb.Table(columns = ['Nondog', 'Dog'])
    table.add_data(wandb.Html('./fig1.html'), wandb.Html('./fig2.html'))
    run.log({'Trained - Nondog/dog, Layer 20': table})