In [1]:
from spline_model import SplineGPT, GPTConfig 

config = GPTConfig(block_size=256, vocab_size=92, n_layer=4, n_head=4, n_embd=768, 
                   dropout=0.0, bias=True)

model = SplineGPT(config)

number of parameters: 30.79M


In [2]:
from spline_model import SplineGPT, SplineGPTConfig
import torch 

# Setup example input
batch_size = 4
seq_len = 16
vocab_size = 100
device = "mps"

# Create random input tokens
input_tokens = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

# Create model config
config = SplineGPTConfig(
    block_size=256,        # maximum sequence length
    vocab_size=vocab_size, # vocabulary size
    n_layer=4,            # number of transformer layers
    n_head=4,             # number of attention heads
    n_embd=128,          # embedding dimension
    dropout=0.1,
    bias=True,
    spline_control_layers=None
)

# Initialize model
model = SplineGPT(config).to(device)

# Forward pass
logits, loss = model(input_tokens, targets=input_tokens)

print(f"Input shape: {input_tokens.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"Loss value: {loss.item()}")

Total number of parameters: 0.87M
Control predictor number of parameters: 0.07M
Input shape: torch.Size([4, 16])
Output logits shape: torch.Size([4, 16, 100])
Loss value: 4.209731101989746


In [3]:
model.generate(input_tokens, 10)

tensor([[51, 39, 51, 67, 91, 15, 90, 49, 89, 62, 78,  6, 13, 25, 39, 50, 68, 27,
         51, 50, 52, 13, 68, 72, 90, 61],
        [35, 81, 52, 82,  7, 74, 60,  5, 87, 20, 87, 44, 27, 68, 29, 76, 84, 32,
         38, 71, 48, 51, 26, 92,  0, 80],
        [49,  6, 90, 27, 48, 40, 50, 30, 46, 75, 68, 73, 11,  6, 72, 66, 76, 26,
         58, 19, 27, 33, 16, 58, 18,  9],
        [17, 65,  7, 14, 38, 23, 45, 80, 17, 18, 69, 29, 52, 68, 54, 59, 81, 88,
         25, 11, 44, 23, 95, 62, 13, 23]], device='mps:0')

In [3]:
# save and load checkpoint 
# 1. save checkpoint 
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config,
}, 'spline_model_checkpoint.pt')

In [4]:
checkpoint = torch.load('spline_model_checkpoint.pt')
loaded_config = checkpoint['config']
loaded_model = SplineGPT(loaded_config).to(device)
loaded_model.load_state_dict(checkpoint['model_state_dict'])

Total number of parameters: 0.87M
Control predictor number of parameters: 0.07M


  checkpoint = torch.load('spline_model_checkpoint.pt')


<All keys matched successfully>