## Preliminary Selection

In [1]:
with open("results/new_prompt_lca_pca.txt", "r") as f:
    data = f.read()

In [2]:
import sys, os
sys.path.append(os.path.abspath(".."))
from utils import load_config
config = load_config("../config/models.yaml")

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
model = config["models"]["perplexed"]
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
ppl_tokenizer = AutoTokenizer.from_pretrained(model)
ppl_model = AutoModelForCausalLM.from_pretrained(
    model,
    quantization_config=quantization_config,
    device_map="auto"
)
ppl_model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRM

In [41]:
import numpy as np

def preliminary_sgd(prompt, token_to_optimize_idx, w1=30.0, b_candidates=50):
    if len(prompt.split()[token_to_optimize_idx]) < 3:
        return
    ppl_model.eval()
    ppl_model.zero_grad()
    inputs = ppl_tokenizer(prompt, return_tensors="pt").to("cuda")
    input_ids = inputs.input_ids
    inputs_embeds = ppl_model.get_input_embeddings()(input_ids)
    inputs_embeds.requires_grad_(True)
    inputs_embeds.retain_grad()
    outputs = ppl_model(inputs_embeds=inputs_embeds, labels=input_ids)
    loss = outputs.loss
    logits = outputs.logits
    loss.backward()
    token_grad = inputs_embeds.grad[0, token_to_optimize_idx]
    word_embedding_matrix = ppl_model.get_input_embeddings().weight
    r_obj = token_grad @ word_embedding_matrix.T
    r_int = logits[0, token_to_optimize_idx - 1]
    combined_score = (w1 * r_obj) + r_int
    combined_score = combined_score.detach().cpu().numpy()
    top_b_idx = np.argsort(combined_score)[-b_candidates:]
    cand_tokens = ppl_tokenizer.convert_ids_to_tokens(top_b_idx)
    return cand_tokens
    
    
    
    

In [42]:
tokens_data = data.split()

In [43]:
token_to_optimize = tokens_data[3]

In [44]:
token_to_optimize

'report'

In [45]:
cands = preliminary_sgd(prompt=data, token_to_optimize_idx = 3)

In [None]:
cands

### Fine Selection

In [4]:
from simulated_annealing import calc_perplexity
def fine_selection(prompt, cands, token_to_optimize_idx, maximize_ppl=False):
    best_token = None
    if maximize_ppl: best_cost = -float('inf')
    else: best_cost = float('inf')
    tokens = prompt.split()
    original_token = tokens[token_to_optimize_idx]
    for c in cands:
        clean_token = c.replace('Ä ', '')
        tokens[token_to_optimize_idx] = clean_token
        new_prompt = " ".join(tokens)
        curr_cost = calc_perplexity(new_prompt)
        if maximize_ppl:
            if curr_cost > best_cost:
                best_cost = curr_cost
                best_token = c
        else:
            if curr_cost < best_cost:
                best_cost = curr_cost
                best_token = c

        tokens[token_to_optimize_idx] = original_token
    return best_token, best_cost
        
    

ModuleNotFoundError: No module named 'simulated_annealing'

In [8]:
prompt_tokens = data.tokens()
for i, token in enumarate(prompt_tokens):
    cands = preliminary_sgd(data, i)
    if not cands: 
        continue
    fine_selection(data, cands, i, True)
