In [3]:
from model_tbyt_3 import GPT, GPTConfig
import torch
import os

def remap_state_dict_keys(state_dict):
    """Return a copy of ``state_dict`` with notebook-style keys renamed for model_tbyt_3.py.

    Every ``'.attn.'`` segment in a key becomes ``'.c_attn.'`` and every
    ``'.mlp.'`` segment becomes ``'.c_fc.'``. Keys containing neither segment
    are carried over unchanged. Values are passed through as-is (not copied),
    and the input mapping itself is not modified.

    Args:
        state_dict: Mapping from parameter name (str) to its value.

    Returns:
        A new dict with the renamed keys, in the same iteration order as the input.
    """
    # Applied in order: the attention rename first, then the MLP rename.
    substitutions = (('.attn.', '.c_attn.'), ('.mlp.', '.c_fc.'))

    def rename(key):
        for old, new in substitutions:
            key = key.replace(old, new)
        return key

    return {rename(name): value for name, value in state_dict.items()}

# Model / checkpoint configuration.
itr_num = 60000
block_size = 16
vocab_size = 256
device = 'cpu'
config = GPTConfig(block_size=block_size, vocab_size=vocab_size)
model = GPT(config)

# The iteration count in the file name is taken from itr_num so the two cannot drift apart.
checkpoint_path = os.path.join(
    os.getcwd(),
    f'saved_models/Final_N256_K16_L2_H1_E32_r8over1_npos1_mlp1_dup0_testK16_iters{itr_num}.pt',
)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model_state_dict = torch.load(checkpoint_path, map_location=device)['model']
model_state_dict = remap_state_dict_keys(model_state_dict)

# Handle wpe size mismatch manually (checkpoint has 33, model has 65)
wpe_key = 'transformer.wpe.weight'
wpe_copied = wpe_key in model_state_dict
if wpe_copied:
    checkpoint_wpe = model_state_dict.pop(wpe_key)  # Remove from state_dict
    with torch.no_grad():
        # Fill only the leading rows of the model's table; the remaining rows
        # keep whatever values the freshly constructed model gave them.
        model.transformer.wpe.weight[:checkpoint_wpe.size(0), :] = checkpoint_wpe

# strict=False tolerates key mismatches, so inspect the result instead of
# discarding it: a key that failed to match leaves that parameter at its
# random initialisation without any error.
load_result = model.load_state_dict(model_state_dict, strict=False)
# wpe was copied by hand above, so its absence from the state dict is expected.
missing_keys = [k for k in load_result.missing_keys if not (wpe_copied and k == wpe_key)]
unexpected_keys = list(load_result.unexpected_keys)
if missing_keys or unexpected_keys:
    print(f'WARNING: checkpoint/model key mismatch: missing={missing_keys}, unexpected={unexpected_keys}')

# Last expression of the cell: moves the model to `device` and displays its repr.
model.to(device=device)

GPT(
  (transformer): ModuleDict(
    (wte): Embedding(257, 32)
    (wpe): Embedding(65, 32)
    (h): ModuleList(
      (0-1): 2 x Block(
        (c_attn): CasualSelfAttention(
          (c_attn): Linear(in_features=32, out_features=96, bias=True)
          (c_proj): Linear(in_features=32, out_features=32, bias=True)
        )
        (c_fc): MLP(
          (fc_1): Linear(in_features=32, out_features=96, bias=True)
          (gelu): GELU(approximate='tanh')
          (fc_2): Linear(in_features=96, out_features=32, bias=True)
        )
        (ln_1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        (ln_2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_f): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=32, out_features=256, bias=False)
)

In [9]:
batch_size = 1


def get_batch(changing_num=-1, changing_index=-1, initial_sequence=None, batch_size=batch_size):
    """Build a batch of ``[sequence, separator, sorted(sequence)]`` rows.

    Each row is the concatenation of an input sequence, a single separator
    token equal to ``vocab_size``, and the same sequence sorted in ascending
    order.

    Args:
        changing_num: Value to write into the sequence before sorting.
            ``-1`` is a sentinel meaning "leave the sequence unchanged".
        changing_index: Position that receives ``changing_num``. ``-1`` is a
            sentinel that maps to position 0 (not the last element).
        initial_sequence: Optional 1-D tensor (or list) used for every row
            instead of a random draw. It is copied, never modified.
        batch_size: Number of rows to generate; defaults to the module-level
            ``batch_size`` at definition time.

    Returns:
        A tensor of shape ``(batch_size, 2 * len(sequence) + 1)``. Without
        ``initial_sequence`` each sequence is the first ``block_size`` entries
        of a random permutation of ``range(vocab_size)``.
    """

    def cat_sorted_tensor(x):
        """Turn one sequence into ``[sequence, vocab_size, sorted(sequence)]``."""
        if initial_sequence is not None:
            # Work on a private copy: the in-place write below must not leak
            # into the tensor the caller passed in.
            x = torch.as_tensor(initial_sequence).clone()
        if changing_num != -1:
            target = 0 if changing_index == -1 else changing_index
            x[target] = changing_num
        vals, _ = torch.sort(x)
        return torch.cat((x, torch.tensor([vocab_size]), vals), dim=0)

    # A random sequence is drawn for every row even when initial_sequence
    # overrides it, so the amount of RNG state consumed per call is unchanged.
    return torch.stack(
        [cat_sorted_tensor(torch.randperm(vocab_size)[:block_size]) for _ in range(batch_size)]
    )

In [None]:
# Smoke test: push one freshly generated batch through the loaded model.
idx = get_batch()
print('idx dim is ', idx.shape)
# The model is called with the input only (no separate targets argument) and
# returns a (logits, loss) pair.
# NOTE(review): the loss recorded for this cell (~5.62) is close to
# ln(256) ≈ 5.55, i.e. roughly uniform guessing over a 256-way output —
# confirm the checkpoint weights were actually loaded into the model.
logits, loss = model(idx)
print('loss is ', loss.item())
print(f'idx is: {idx}')
# argmax over the last dimension of the logits gives the predicted token per position.
print('model output is ', torch.argmax(logits, dim=-1))

idx dim is  torch.Size([1, 33])
layer_n is  0
layer_n is  1
loss is  5.624738693237305
idx is: tensor([[132,  95, 187, 143, 204,  64, 173, 172,  83,  87, 190,  26, 149,  32,
          67,  19, 256,  19,  26,  32,  64,  67,  83,  87,  95, 132, 143, 149,
         172, 173, 187, 190, 204]])
