In [220]:
import os
import pickle
from contextlib import nullcontext
import torch
import tiktoken
from model import ModelArgs, Transformer
from tokenizer import Tokenizer

use_mps = True

# -----------------------------------------------------------------------------
out_dir = 'out' # ignored if init_from is not 'resume'
start = "" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt"
num_samples = 1 # number of samples to draw
max_new_tokens = 100 # number of tokens generated in each sample
temperature = 1.0 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 300 # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() and use_mps else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
#dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
dtype = "float32"
compile = False # use PyTorch 2.0 to compile the model to be faster
#exec(open('configurator.py').read()) # overrides from command line or config file
# -----------------------------------------------------------------------------



torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'mps' if 'mps' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type != 'cuda' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# init from a model saved in a specific directory
ckpt_path = os.path.join(out_dir, 'stories42M.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = ModelArgs(**checkpoint['model_args'])
model = Transformer(gptconf)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    # if k == "tok_embeddings.weight":
    #     print(k, v.shape)
    #     state_dict[k] = v.transpose(1, 0)
    #     print(state_dict[k].shape)
model.load_state_dict(state_dict, strict=False)

# model.eval()
model.to(device)

Transformer(
  (tok_embeddings): Linear(in_features=32000, out_features=512, bias=False)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-7): 8 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=512, out_features=512, bias=False)
        (wk): Linear(in_features=512, out_features=512, bias=False)
        (wv): Linear(in_features=512, out_features=512, bias=False)
        (wo): Linear(in_features=512, out_features=512, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=512, out_features=1376, bias=False)
        (w2): Linear(in_features=1376, out_features=512, bias=False)
        (w3): Linear(in_features=512, out_features=1376, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
 

In [453]:
# load the tokenizer
enc = Tokenizer()

start = "! ! ! ! ! ! ! ! ! ! ! ! ! ! !"

# encode the beginning of the prompt
if start.startswith('FILE:'):
    with open(start[5:], 'r', encoding='utf-8') as f:
        start = f.read()
start_ids = enc.encode(start, bos=True, eos=False)
print(start_ids)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
print(x)



[1, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738]
tensor([[   1, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738, 1738,
         1738, 1738, 1738, 1738]], device='mps:0')


In [454]:
# x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
# # run generation
# with torch.no_grad():
#     model.eval()
#     with ctx:
#         for k in range(num_samples):
#             y = model.generate(x, max_new_tokens, temperature=0.0, top_k=top_k)
#             print(enc.decode(y[0].tolist()))
#             print('---------------')

In [455]:
start_ids

[1,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738,
 1738]

In [456]:
# num_free_tokens = len(start_ids) - 1 #30
num_free_tokens = 15

In [457]:
# start_ids = start_ids + [1738]*num_free_tokens

In [458]:
import random

In [459]:
def loss(x_idx, x, logits, activations, print_loss=True):
    act = activations[4][:, -1]
    L = torch.sum(act * act, dim=-1)

    # logits = logits[:, :-1].clone()
    # x_idx = x_idx[:, 1:].clone()
    
    # log_p = -torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), x_idx.view(-1))
    # alpha = 0.0
    # beta = 1.0
    # return alpha * L + beta * log_p
    # return L
    # 

    logits = logits[:, -num_free_tokens-1:-1]
    
    probs = torch.log_softmax(logits, dim=-1)
    x = x[:, -num_free_tokens:]

    # print(x.shape)
    
    p = torch.sum(probs * x, dim=-1).sum(-1)
    # print(p.shape)

    # if print_loss:
    #     print(f"L = {L.item()}")
    #     print(f"p = {p.item()}")
    
    # return 0.5 * p + 0.5 * L
    # return L
    return p

In [460]:
model.train()

top_k = 30
# num_free_tokens = 30
batch_size = 60

assert batch_size <= num_free_tokens * top_k

for i in range(1000):
    model.train()
    # print(f"iter {i}")
    # print(start_ids)
    print(enc.decode(start_ids))
    x_idx = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    x = torch.nn.functional.one_hot(x_idx, num_classes=model.params.vocab_size).float()
    
    x.requires_grad = True

    logits, activations = model(x, x_idx, return_activations=True)
    l = loss(x_idx, x, logits, activations)
    # log_probs = torch.log_softmax(logits, axis=-1)

    # logits_cont = logits[:, -(len(continuation_enc)+1):-1]
    # print(logits_cont.shape)
    # l = -torch.nn.functional.cross_entropy(logits_cont.view(-1, logits_cont.size(-1)), continuation_tensor.view(-1))
    
    print(f"objective = {l.item()}")

    l.backward()

    replacement_tokens = []
    for t in range(1, num_free_tokens + 1):
        _, top_k_indices = torch.topk(x.grad[0, -t], top_k)
        # print(top_k_indices)
        # print(logits[0, -t].argmax())
        for ik in top_k_indices:
            replacement_tokens.append((t, ik.item()))

    random.shuffle(replacement_tokens)

    batch = []
    for t, new_token in replacement_tokens[:batch_size]:
        start_ids_repl = start_ids.copy()
        start_ids_repl[-t] = new_token
        batch.append(start_ids_repl)

    
    x_repl_idx = torch.tensor(batch, dtype=torch.long, device=device)
    x_repl = torch.nn.functional.one_hot(x_repl_idx, num_classes=model.params.vocab_size).float()

    max_obj = -1e9
    # max_obj = l[0].item()
    best_replacement = None
    with torch.no_grad():
        model.eval()
        logits, activations = model(x_repl, x_repl_idx, return_activations=True)
        l1 = loss(x_repl_idx, x, logits, activations, print_loss=False)
        
        # logits_cont = logits[:, -(len(continuation_enc)+1):-1]
        # print(logits_cont.shape)
        # print(logits_cont.argmax())
        for i in range(batch_size):
            loss_i = l1[i].item()
            if loss_i > max_obj:
                best_replacement = replacement_tokens[i]
                max_obj = loss_i

    # print(f"max_obj = {max_obj}")

    if best_replacement is not None:
        best_t, best_token = best_replacement
        start_ids[-best_t] = best_token
        
    model.zero_grad()
    # print(max_token)
    # print(enc.decode([max_token]))

! ! ! ! ! ! ! ! ! ! ! ! ! ! !
objective = -156.90789794921875
! ! ! !... ! ! ! ! ! ! ! ! ! !
objective = -153.88275146484375
! ! ! bla... ! ! ! ! ! ! ! ! ! !
objective = -148.08917236328125
! ! ! bla... ! ! ! !osh ! ! ! ! !
objective = -145.60289001464844
... ! ! bla... ! ! ! !osh ! ! ! ! !
objective = -141.79090881347656
... ! ! boot... ! ! ! !osh ! ! ! ! !
objective = -140.93704223632812
... !ton boot... ! ! ! !osh ! ! ! ! !
objective = -138.7862548828125
... !ton boot... !osh ! !osh ! ! ! ! !
objective = -133.75059509277344
... !ton boot... !osh ! !osh ! ! !iting !
objective = -132.04266357421875
... !ton boot... !osh ! !osh !O !iting !
objective = -127.2388916015625
...Eton boot... !osh ! !osh !O !iting !
objective = -120.1210708618164
...Eton boot... !osh ! !W !O !iting !
objective = -121.9192123413086
...Eton boot... !osh ! !W !O !iting down
objective = -118.85062408447266
...Eton boot... !osh ! !W !ck !iting down
objective = -121.3896484375
...Eton boot... !osh ! !ack !ck !iting

KeyboardInterrupt: 

In [None]:
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
#run generation
with torch.no_grad():
    model.eval()
    with ctx:
        for k in range(num_samples):
            y = model.generate(x, max_new_tokens, temperature=0.0, top_k=top_k)
            print(enc.decode(y[0].tolist()))
            print('---------------')

In [819]:
x

tensor([[    1, 14084,  5271,  5381, 10395, 17845,   413,  1424,   961,  1145,
          7155,  4446,   907,  4983,  7933,  2038,  9644,  2020,   968,  1670,
          1310]], device='mps:0')

In [848]:
with torch.no_grad():
    model.eval()
    x_onehot = torch.nn.functional.one_hot(x, num_classes=model.params.vocab_size).float()
    logits = model(x_onehot, x)

In [849]:
logits.shape

torch.Size([1, 21, 32000])

In [850]:
torch.topk(torch.softmax(logits[:, -1], axis=-1), axis=-1, k=5)

torch.return_types.topk(
values=tensor([[0.5858, 0.0666, 0.0427, 0.0418, 0.0324]], device='mps:0'),
indices=tensor([[1260,  372, 1183,  540,  310]], device='mps:0'))

In [804]:
torch.softmax(logits[:, -1], axis=-1)[0, 1260]

tensor(0.0001, device='mps:0')

In [121]:
probs[0, -1].argmax()

tensor(368)

In [122]:
enc.decode([368])

'ly'

In [114]:
l = torch.log(probs[0, -1, 7870])

In [115]:
probs.shape

torch.Size([1, 4, 32000])

In [116]:
l.backward()

In [117]:
(x.grad[0, -1]).max()

tensor(9.7334)

In [118]:
torch.topk(x.grad[0, -1], 5)

torch.return_types.topk(
values=tensor([9.7334, 8.6131, 8.1965, 7.7321, 7.2178]),
indices=tensor([  611,   509, 20285,  1557,  2233]))

In [10]:
enc.encode("Tim", bos=False,eos=False)

[7870]

In [119]:
enc.decode([611])

'ma'