In [1]:
# Standard imports for loading model
import torch
import os
from checkpoint_utils import load_checkpoint

In [2]:
# Normal PyTorch loading: build the GPT architecture by hand, then restore the
# trained weights from a checkpoint file on disk.
from model_tbyt_3 import GPT, GPTConfig

# Config inferred from checkpoint filename:
# dec28_tbyt_without-pos-embedding_n_embd:64_1head_layers:2_vocab_size:128
# The path is resolved against the current working directory, so the notebook
# must be started from the directory that contains `saved_models/`.
checkpoint_path = os.path.join(os.getcwd(), 'saved_models/dec28_tbyt_without-pos-embedding_n_embd:64_1head_layers:2_vocab_size:128_itr:60000_checkpoint.pt')
device = 'cpu'

# Create config
# NOTE(review): block_size=32 is not encoded in the checkpoint filename above —
# confirm it matches the value used at training time.
config = GPTConfig(block_size=32, vocab_size=128, without_pos=True)
# These attributes are assigned after construction and before GPT(config) is
# built; presumably GPT reads them in its constructor — keep this ordering.
config.n_embd = 64
config.n_heads = 1
config.n_layers = 2

# Instantiate model
model = GPT(config)

# Load checkpoint
# SECURITY: weights_only=False allows torch.load to unpickle arbitrary Python
# objects, which can execute code. Only load checkpoint files you trust.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
# The checkpoint is indexed like a dict; the weights live under the 'model' key.
model.load_state_dict(checkpoint['model'])
# Switch to evaluation mode before running inference.
model.eval()

print(f"Model loaded successfully! Config: block_size={config.block_size}, vocab_size={config.vocab_size}")

# Globals read by get_batch at call time: vocab_n is both the exclusive upper
# bound of the sampled token ids and the separator token id.
# NOTE(review): the alternative loader cell uses `config.vocab_size - 1` here,
# whereas this cell makes the separator id equal to vocab_size — confirm that
# this GPT's embedding table has room for token id == vocab_size.
vocab_n = config.vocab_size
block_size = config.block_size

batch_size = 1
def get_batch(changing_num=-1, changing_index=-1, initial_sequence=None, batch_size=batch_size):
    """Build a batch of ``sequence | separator | sorted(sequence)`` token rows.

    Each row is a token sequence, followed by the separator token ``vocab_n``,
    followed by the same tokens in ascending order. ``vocab_n`` and
    ``block_size`` are read from the notebook globals at call time.

    Args:
        changing_num: If not -1, this value is written into the sequence
            before it is sorted.
        changing_index: Position overwritten by ``changing_num``. The sentinel
            -1 means position 0 (so -1 cannot be used to address the last
            element).
        initial_sequence: Optional 1-D tensor used for every row instead of a
            random draw. It is never modified; a copy is edited instead.
        batch_size: Number of rows to generate.

    Returns:
        Tensor of shape ``(batch_size, 2 * L + 1)`` where ``L`` is the length
        of the (random or supplied) sequence.
    """
    # Use vocab_n as separator
    separator = torch.tensor([vocab_n])

    def cat_sorted_tensor(x):
        if initial_sequence is not None:
            # Clone so the in-place write below cannot corrupt the caller's
            # tensor (previously the caller's sequence was silently mutated).
            x = initial_sequence.clone()
        if changing_num != -1:
            target = 0 if changing_index == -1 else changing_index
            x[target] = changing_num
        vals, _ = torch.sort(x)
        return torch.cat((x, separator, vals), dim=0)

    # A random permutation is still drawn for every row, even when
    # initial_sequence is supplied, so the global RNG stream is consumed
    # exactly as before.
    return torch.stack([cat_sorted_tensor(torch.randperm(vocab_n)[:block_size]) for _ in range(batch_size)])

Model loaded successfully! Config: block_size=32, vocab_size=128


In [3]:
# Alternative: Load using load_checkpoint utility
# This cell rebinds `model`, `config`, `vocab_n` and `block_size`, replacing the
# values set by the previous loading cell; later cells use whichever loader
# cell ran last. The path is resolved against the current working directory.
checkpoint_path_alt = os.path.join(os.getcwd(), 'Grid_training_without_duplicates/Final_N256_K16_L2_H1_E32_r8over1_npos1_mlp1_dup0_testK16_iters60000.pt')
# load_checkpoint returns a (model, config) pair.
# NOTE(review): whether it also puts the model in eval mode is not visible
# here — confirm before relying on inference results.
model, config = load_checkpoint(checkpoint_path_alt, device='cpu')
# vocab_n is one less than the configured vocabulary size; get_batch uses it as
# the exclusive upper bound of sampled token ids and as the separator token id.
vocab_n = config.vocab_size - 1
block_size = config.block_size

batch_size = 1
def get_batch(changing_num=-1, changing_index=-1, initial_sequence=None, batch_size=batch_size):
    """Build a batch of ``sequence | separator | sorted(sequence)`` token rows.

    This definition shadows the ``get_batch`` defined in the earlier loading
    cell. Each row is a token sequence, followed by the separator token
    ``vocab_n``, followed by the same tokens in ascending order. ``vocab_n``
    and ``block_size`` are read from the notebook globals at call time.

    Args:
        changing_num: If not -1, this value is written into the sequence
            before it is sorted.
        changing_index: Position overwritten by ``changing_num``. The sentinel
            -1 means position 0 (so -1 cannot be used to address the last
            element).
        initial_sequence: Optional 1-D tensor used for every row instead of a
            random draw. It is never modified; a copy is edited instead.
        batch_size: Number of rows to generate.

    Returns:
        Tensor of shape ``(batch_size, 2 * L + 1)`` where ``L`` is the length
        of the (random or supplied) sequence.
    """
    # Use vocab_n as separator
    separator = torch.tensor([vocab_n])

    def cat_sorted_tensor(x):
        if initial_sequence is not None:
            # Clone so the in-place write below cannot corrupt the caller's
            # tensor (previously the caller's sequence was silently mutated).
            x = initial_sequence.clone()
        if changing_num != -1:
            target = 0 if changing_index == -1 else changing_index
            x[target] = changing_num
        vals, _ = torch.sort(x)
        return torch.cat((x, separator, vals), dim=0)

    # A random permutation is still drawn for every row, even when
    # initial_sequence is supplied, so the global RNG stream is consumed
    # exactly as before.
    return torch.stack([cat_sorted_tensor(torch.randperm(vocab_n)[:block_size]) for _ in range(batch_size)])

In [4]:
# UPDATED: Inference with new model
# Make sure the model is in evaluation mode regardless of which loader cell
# produced it (a no-op if it is already in eval mode).
model.eval()
idx = get_batch()
print('idx dim is ', idx.shape)
# This cell only reads predictions, so run the forward pass without recording
# an autograd graph.
with torch.no_grad():
    # The model returns a (logits, loss) pair; loss may be None.
    logits, loss = model(idx)
if loss is not None:
    print('loss is ', loss.item())
print(f'idx is: {idx}')
# Greedy prediction: most likely token id at every position.
print('model output is ', torch.argmax(logits, dim=-1))

idx dim is  torch.Size([1, 33])
loss is  0.0005935349618084729
idx is: tensor([[124,  68, 159, 220, 183, 239, 133, 166,  78,  33,  49, 123,   3,   9,
         130, 195, 256,   3,   9,  33,  49,  68,  78, 123, 124, 130, 133, 159,
         166, 183, 195, 220, 239]])
model output is  tensor([[124,  81, 159, 220, 183, 239, 133, 166, 124,  33,  49, 124,   3,   9,
         131, 195,   3,   9,  33,  49,  68,  78, 123, 124, 130, 133, 159, 166,
         183, 195, 220, 239, 239]])
