In [3]:
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import ModelArgs, Transformer
from tokenizer import Tokenizer

from collections import defaultdict

# Set to False to skip the Apple MPS backend even when it is available
# (only consulted in the `device` expression below).
use_mps = True

# -----------------------------------------------------------------------------
# Notebook-wide settings. Later cells read these names as globals.
out_dir = 'out' # directory joined with the checkpoint file name when the model is loaded
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 1 # number of samples to draw
max_new_tokens = 100 # number of tokens generated in each sample
# NOTE(review): the visible model.generate(...) calls pass literal temperatures
# (1.0 and 0.0) and never read this variable.
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
# Seeds torch (CPU and CUDA) only; Python's `random` module is not seeded in the visible cells.
seed = 1337
# Preference order: CUDA, then MPS (only when use_mps is True), then CPU.
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() and use_mps else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
dtype = "float32"
# NOTE(review): this name shadows the builtin `compile`, and no visible cell reads it.
compile = False # use PyTorch 2.0 to compile the model to be faster
#exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------


class Hook:
    """Callable recorder that keeps the latest activation seen for each layer.

    Instances are invoked as ``hook(layer_name, layer_id, activation)`` and
    file the value under the key ``"<layer_name>.<layer_id>"`` in
    ``self.activations``.
    """

    def __init__(self):
        # Begin with an empty store.
        self.clear()

    def clear(self):
        """Forget everything recorded so far by binding a brand-new dict.

        The previous dict object is not emptied, so a reference taken before
        calling ``clear`` still sees the old contents.
        """
        self.activations = dict()

    def __call__(self, layer_name, layer_id, activation):
        """Record ``activation``, overwriting any earlier entry for that layer."""
        key = "{}.{}".format(layer_name, layer_id)
        self.activations[key] = activation



# Single recorder instance handed to the model below. Presumably the model
# calls it as hook(layer_name, layer_id, activation) during forward passes —
# confirm against model.py.
hook = Hook()


# Seed torch's CPU and CUDA generators.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'mps' if 'mps' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast is only enabled on CUDA; every other device gets a no-op context.
ctx = nullcontext() if device_type != 'cuda' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init from a model saved in a specific directory
ckpt_path = os.path.join(out_dir, 'stories42M.pt')
# NOTE(review): torch.load unpickles the file (unless restricted with
# weights_only=True); only load checkpoints from a trusted source.
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = ModelArgs(**checkpoint['model_args'])
model = Transformer(gptconf, hook)
state_dict = checkpoint['model']
# Strip the '_orig_mod.' prefix from parameter names (presumably left behind by
# torch.compile — confirm). Iterating over a list copy allows mutating the dict.
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    # if k == "tok_embeddings.weight":
    #     print(k, v.shape)
    #     state_dict[k] = v.transpose(1, 0)
    #     print(state_dict[k].shape)
# NOTE(review): strict=False means missing or unexpected keys do not raise, and
# the returned report of such keys is discarded here.
model.load_state_dict(state_dict, strict=False)

# eval()/train() are not set here; later cells switch modes explicitly.
# model.eval()
# Last expression of the cell, so the module repr becomes the cell output.
model.to(device)

Transformer(
  (tok_embeddings): Linear(in_features=32000, out_features=512, bias=False)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=512, out_features=512, bias=False)
        (wk): Linear(in_features=512, out_features=512, bias=False)
        (wv): Linear(in_features=512, out_features=512, bias=False)
        (wo): Linear(in_features=512, out_features=512, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=512, out_features=1376, bias=False)
        (w2): Linear(in_features=1376, out_features=512, bias=False)
        (w3): Linear(in_features=512, out_features=1376, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
 

In [6]:
# Tokenizer from the local `tokenizer` module, built with its default arguments;
# later cells use enc.encode(...) / enc.decode(...).
enc = Tokenizer()

In [7]:
# Draw `num_samples` samples from an empty prompt, with activation recording
# switched off for the duration of the generation.
model.hook = lambda x, y, z: None
try:
    start_ids = enc.encode("", bos=True, eos=False)
    x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
    # run generation
    with torch.no_grad():
        model.eval()
        with ctx:
            for k in range(num_samples):
                y = model.generate(x, max_new_tokens, temperature=1.0, top_k=top_k)
                print(enc.decode(y[0].tolist()))
                print('---------------')
finally:
    # Re-attach the recorder even if generation raises or the cell is
    # interrupted; otherwise later cells would silently record nothing.
    model.hook = hook

Once upon a time, there was a big, red bus. The bus was very good at its job. The bus liked to help people get to places where they needed to go.
One day, the bus met a little boy named Tim. Tim wanted to go to the park to feed his dog. The bus said, "Okay, I will take you to the park to feed your dog." Tim was very happy because the bus was a dependable friend.
After Tim fed his
---------------


In [None]:
# Initial prompt whose tokens the search loop further down rewrites.
start = "! ! ! ! ! ! ! ! ! ! ! !"

# A value of the form "FILE:<path>" means the prompt text is read from that file.
if start[:5] == 'FILE:':
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()

# Tokenise with bos=True / eos=False, then add a leading batch axis of size 1.
start_ids = enc.encode(start, bos=True, eos=False)
print(start_ids)
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)
print(x)



In [None]:
# Scratch check: token ids for the word "computer", encoded with bos=False and eos=False.
enc.encode("computer", bos=False, eos=False)

In [None]:
# Number of token ids in the encoded prompt (the encode call used bos=True).
len(start_ids)

In [None]:
# Number of trailing prompt positions the search loop may rewrite: every token
# except the first (presumably the BOS token added by bos=True — confirm).
# Read as a global by loss() and by the search loop.
num_free_tokens = len(start_ids) - 1 #30
# num_free_tokens = 9

In [None]:
# start_ids = start_ids + [1738]*num_free_tokens

In [None]:
import random

In [None]:
def loss(x_idx, x, logits, activations, print_loss=True, n_free=None,
         likelihood_weight=0.02, activation_key="ffn.5", unit=32):
    """Score prompts by one recorded activation plus a weighted log-likelihood term.

    The objective returned is ``likelihood_weight * p + L`` where

    * ``L`` is ``activations[activation_key][:, -1, unit]`` (the recorded value
      at the last position for a single unit), and
    * ``p`` is the sum, over the last ``n_free`` positions, of the model's
      log-probability for the token encoded in ``x`` at that position. The
      logits used for position ``j`` are the ones emitted at ``j - 1``.

    Args:
        x_idx: Token ids of the prompts. Kept for call compatibility; not read.
        x: Prompt encoding indexed as (batch, position, vocab); the loop in
            this notebook passes a one-hot float tensor.
        logits: Model outputs indexed as (batch, position, vocab).
        activations: Mapping from hook key to a tensor indexed as
            (batch, position, unit).
        print_loss: Kept for call compatibility; not read.
        n_free: How many trailing positions contribute to ``p``. ``None``
            (the default) uses the notebook global ``num_free_tokens``.
        likelihood_weight: Multiplier applied to ``p`` in the objective.
        activation_key: Key looked up in ``activations``.
        unit: Index along the last axis of the recorded activation.

    Returns:
        Tuple ``(objective, p, L)``; each has one entry per batch row.

    Raises:
        ValueError: If ``n_free`` is negative or not smaller than the number
            of positions in ``x`` (position 0 has no preceding logits).
    """
    if n_free is None:
        # Backward-compatible default: fall back to the notebook-level setting.
        n_free = num_free_tokens

    seq_len = x.shape[1]
    if not 0 <= n_free < seq_len:
        raise ValueError(
            f"n_free must be in [0, {seq_len - 1}] for a sequence of length {seq_len}, got {n_free}"
        )

    L = activations[activation_key][:, -1, unit]

    # Logits at position j-1 predict the token at position j, hence the shift
    # by one relative to the slice taken from x.
    shifted_logits = logits[:, -n_free - 1:-1]
    log_probs = torch.log_softmax(shifted_logits, dim=-1)

    # An explicit start index is used instead of x[:, -n_free:], because
    # x[:, -0:] would select the whole sequence rather than nothing.
    free_tokens = x[:, seq_len - n_free:]

    # Pick out each position's log-probability for the encoded token, then add
    # the positions up.
    p = torch.sum(log_probs * free_tokens, dim=-1).sum(-1)

    return likelihood_weight * p + L, p, L

In [None]:
import matplotlib.pyplot as plt
from IPython.display import display, HTML, clear_output

In [None]:
# Per-iteration histories appended to by the search loop (objective, likelihood
# term, activation term). Re-run this cell to reset them before a new search.
losses, p_losses, L_losses = [], [], []

In [None]:
# Number of candidate tokens proposed per free position in the search loop.
# NOTE(review): this rebinds the global `top_k` that the config cell set to 300,
# so any later model.generate(..., top_k=top_k) call sees 10.
top_k = 10
# num_free_tokens = 30
# Number of candidate prompts scored per search iteration.
batch_size = 50

# The loop takes batch_size candidates out of num_free_tokens * top_k proposals.
assert batch_size <= num_free_tokens * top_k


In [None]:
# Gradient-guided search over the prompt tokens in `start_ids`.
#
# Each iteration:
#   1. runs the model on a one-hot encoding of the current prompt and
#      back-propagates the objective from loss() to that encoding;
#   2. for each of the last `num_free_tokens` positions, proposes the `top_k`
#      vocabulary entries with the largest gradient;
#   3. scores a random subset of `batch_size` single-token substitutions;
#   4. writes the best-scoring substitution into `start_ids` (mutated in place).
#
# Step 4 always adopts the best candidate, even when its objective is lower
# than the current prompt's.
model.train()

for i in range(1000):
    model.train()

    x_idx = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    x = torch.nn.functional.one_hot(x_idx, num_classes=model.params.vocab_size).float()

    # Gradients with respect to the one-hot input rank the candidate tokens.
    x.requires_grad = True

    hook.clear()
    logits = model(x, x_idx)
    activations = hook.activations
    # clear() rebinds hook.activations to a new dict, so `activations` still
    # holds what this forward pass recorded.
    hook.clear()

    l, p_loss, L_loss = loss(x_idx, x, logits, activations)

    losses.append(l.item())
    p_losses.append(p_loss.item())
    L_losses.append(L_loss.item())

    # Live progress plot of the three histories.
    fig, axs = plt.subplots(2, 2)
    fig.set_size_inches(15, 10)
    axs[0, 0].plot(losses)
    axs[0, 0].set_title("Objective")

    axs[0, 1].plot(p_losses)
    axs[0, 1].set_title("Likelihood")

    axs[1, 0].plot(L_losses)
    axs[1, 0].set_title("Activation")
    plt.show()
    # A fresh figure is created every iteration; release it so open figures
    # do not pile up on backends that keep them alive after show().
    plt.close(fig)

    print(enc.decode(start_ids))
    print(f"objective = {l.item()}")

    # wait=True defers the clear until the next output arrives, so the current
    # plot stays visible until the following iteration draws its own.
    clear_output(wait=True)

    l.backward()

    # Propose (offset-from-end, token id) pairs: for each free position, the
    # top_k vocabulary entries with the largest gradient of the objective.
    replacement_tokens = []
    for t in range(1, num_free_tokens + 1):
        _, top_k_indices = torch.topk(x.grad[0, -t], top_k)
        for ik in top_k_indices:
            replacement_tokens.append((t, ik.item()))

    random.shuffle(replacement_tokens)

    # Build one candidate prompt per proposal, each differing from the current
    # prompt in exactly one position.
    batch = []
    for t, new_token in replacement_tokens[:batch_size]:
        start_ids_repl = start_ids.copy()
        start_ids_repl[-t] = new_token
        batch.append(start_ids_repl)

    x_repl_idx = torch.tensor(batch, dtype=torch.long, device=device)
    x_repl = torch.nn.functional.one_hot(x_repl_idx, num_classes=model.params.vocab_size).float()

    with torch.no_grad():
        model.eval()

        hook.clear()
        logits = model(x_repl, x_repl_idx)
        activations = hook.activations
        hook.clear()

        l1, _, _ = loss(x_repl_idx, x_repl, logits, activations, print_loss=False)

        # Row b of the batch came from replacement_tokens[b], so the arg-max
        # row index maps straight back to its proposal.
        best_loss, best_index = torch.max(l1, dim=0)
        best_replacement = replacement_tokens[best_index.item()]

    best_t, best_token = best_replacement
    start_ids[-best_t] = best_token

    # backward() also filled the model parameters' gradients; drop them.
    model.zero_grad()

In [None]:
# Generate from the prompt currently held in `start_ids`.
# temperature=0.0 is passed to model.generate (presumably greedy decoding —
# confirm against model.py).
x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)

model.eval()
with torch.no_grad(), ctx:
    for k in range(num_samples):
        y = model.generate(x, max_new_tokens, temperature=0.0, top_k=top_k)
        print(enc.decode(y[0].tolist()), '---------------', sep='\n')

In [None]:
# Display the prompt tensor currently bound to `x`.
x

In [None]:
# Gradient-free forward pass over the prompt in `x`, keeping the logits for the
# inspection cells that follow.
model.eval()
with torch.no_grad():
    x_onehot = torch.nn.functional.one_hot(x, num_classes=model.params.vocab_size).to(torch.float32)
    logits = model(x_onehot, x)

In [849]:
# Shape of the logits tensor; the saved output shows (1, 21, 32000).
logits.shape

torch.Size([1, 21, 32000])

In [850]:
# Five largest softmax probabilities (and their token ids) at the last position.
torch.topk(torch.softmax(logits[:, -1], axis=-1), axis=-1, k=5)

torch.return_types.topk(
values=tensor([[0.5858, 0.0666, 0.0427, 0.0418, 0.0324]], device='mps:0'),
indices=tensor([[1260,  372, 1183,  540,  310]], device='mps:0'))

In [804]:
# Softmax probability of token id 1260 at the last position.
# NOTE(review): this cell's execution count (804) is lower than the top-k
# cell's (850), so the two saved outputs come from different kernel states.
torch.softmax(logits[:, -1], axis=-1)[0, 1260]

tensor(0.0001, device='mps:0')

In [121]:
# NOTE(review): no visible cell binds a global `probs` (it only exists as a
# local inside loss()); this cell depends on state from an earlier session.
probs[0, -1].argmax()

tensor(368)

In [122]:
# Text for token id 368 (the arg-max id shown in the previous saved output).
enc.decode([368])

'ly'

In [114]:
# Log of the last-position entry for token id 7870, which is the id the saved
# enc.encode("Tim", ...) output reports.
# NOTE(review): relies on a global `probs` that no visible cell defines, and it
# rebinds `l`, the name the search loop uses for its objective.
l = torch.log(probs[0, -1, 7870])

In [115]:
# Shape of the leftover `probs` tensor; the saved output shows (1, 4, 32000).
probs.shape

torch.Size([1, 4, 32000])

In [116]:
# Back-propagate the scalar `l` computed two cells earlier.
l.backward()

In [117]:
# Largest gradient entry at the last position.
# NOTE(review): needs `x` to be a float tensor with requires_grad set; the `x`
# bound by the visible generation cells is an integer tensor, so this depends
# on state from an earlier session.
(x.grad[0, -1]).max()

tensor(9.7334)

In [118]:
# Five largest gradient entries at the last position, with their token ids.
torch.topk(x.grad[0, -1], 5)

torch.return_types.topk(
values=tensor([9.7334, 8.6131, 8.1965, 7.7321, 7.2178]),
indices=tensor([  611,   509, 20285,  1557,  2233]))

In [10]:
# Token ids for "Tim" with bos=False and eos=False; the saved output is [7870].
enc.encode("Tim", bos=False,eos=False)

[7870]

In [119]:
# Text for token id 611, the top entry in the saved gradient top-k output.
enc.decode([611])

'ma'