In [1]:
import torch
from torch.nn import functional as F

In [2]:
from model import GptLanguageModel
from NanoGPTLangugageModel import NanoGPTLanguageModel
from common import encode, decode, GptConfig, lp_hyperparameters
from einops import rearrange, einsum

In [3]:
# Configuration object handed to both model constructors in this notebook and
# read (n_head, n_layer) by the weight-conversion helper.
hyperparameters = GptConfig()

In [4]:
# Build the einops-based model and restore its saved weights.
model = GptLanguageModel(hyperparameters)
# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only machine (tensors are remapped to the model's device on load).
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('model_weights.pth', map_location=model.device))
# This notebook only generates text, so switch to eval mode (disables dropout
# and other train-time behaviour). .to() and .eval() both return the module.
m = model.to(model.device).eval()

In [5]:
# Prompt with a single newline and sample model.block_size new tokens.
start_str = "\n"
prompt_ids = encode(start_str)
idx = torch.tensor(prompt_ids, dtype=torch.long, device=model.device)[None, ...]  # add a leading batch axis
sampled = m.generate(idx=idx, max_new_tokens=model.block_size)
print(decode(sampled[0].tolist()))


Sirrah, I would touch that murdered?

RICHARD:
Petir, and thy good sovereign.

RATCLIFF:
That friend from to-the inform worthy brother.

EDWARD:
Nay! Forguring and Edward cousin!

LADY GREY:
So you Hold Gaunt Somerset, grief; from they should
That muny fri


In [4]:
def convert_state_dict_to_nano(nano_model, gpt_state_dict):
  """
  Build a nano-model state_dict from the state_dict of the Einops model.

  The result starts as ``nano_model.state_dict()``; the embedding tables,
  ``lm_head`` and the per-block attention / feed-forward weights are then
  overwritten with tensors taken from ``gpt_state_dict``. Any entry that is
  not overwritten here keeps the value ``nano_model`` currently holds.

  The number of layers and heads is read from the notebook-level
  ``hyperparameters`` object, not from the arguments.

  Args:
    nano_model: The nano model whose state_dict layout is the target.
    gpt_state_dict: The state_dict from the Einops model.

  Returns:
    A dictionary containing the converted state_dict for the nano model.
  """
  converted = nano_model.state_dict()
  heads = hyperparameters.n_head

  # Embedding tables carry over unchanged.
  for table in ('token_embedding_table', 'position_embedding_table'):
    converted[table + '.weight'] = gpt_state_dict[table]
  # lm_head is transposed on the way over.
  converted['lm_head.weight'] = gpt_state_dict['lm_head'].T

  for layer in range(hyperparameters.n_layer):
    prefix = f'blocks.{layer}.'

    # Fuse the (h, d) axes, then split the leading axis into key / value / query.
    fused_kvq = rearrange(gpt_state_dict['attention_kvq'][layer], "s c h d -> s c (h d)", h=heads)
    k_w, v_w, q_w = fused_kvq

    converted.update({
      prefix + 'sa_heads.key.weight': k_w.T,
      prefix + 'sa_heads.value.weight': v_w.T,
      prefix + 'sa_heads.query.weight': q_w.T,
      # The output projection is copied as-is, without a transpose.
      prefix + 'sa_heads.proj.weight': gpt_state_dict['out_proj'][layer],
      # Feed-forward weights, transposed.
      prefix + 'ffwd.net.0.weight': gpt_state_dict['w_in'][layer].T,
      prefix + 'ffwd.net.2.weight': gpt_state_dict['w_out'][layer].T,
    })

  return converted

In [5]:
# Build the nano model and fill it with the einops checkpoint, converted to
# the nano parameter layout.
model = NanoGPTLanguageModel(hyperparameters)
# map_location makes a checkpoint that was saved on a GPU loadable on a
# CPU-only machine (tensors are remapped to the model's device on load).
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
gpt_weights = torch.load('model_weights.pth', map_location=model.device)
model.load_state_dict(convert_state_dict_to_nano(model, gpt_weights))
# This notebook only generates text, so switch to eval mode (disables dropout
# and other train-time behaviour). .to() and .eval() both return the module.
m = model.to(model.device).eval()

In [6]:
# Prompt with a single newline and sample model.block_size new tokens.
start_str = "\n"
prompt_ids = encode(start_str)
idx = torch.tensor(prompt_ids, dtype=torch.long, device=model.device)[None, ...]  # add a leading batch axis
sampled = m.generate(idx=idx, max_new_tokens=model.block_size)
print(decode(sampled[0].tolist()))


Secernatory with it it, and my candition
While we doubt it discontent it touche feeble.

Third Citizens:
We are time the confingel of proof this: and
to give stood write in this man belly,
That you do move the chance of the encounter's
Marriage.

FRIAR LAU


In [14]:
# Small integer tensors for trying out the einsum-based matrix multiply on a
# (B, H, T, D) input. Seeded so the values are repeatable.
torch.manual_seed(1337)
B, H, T, D = 1, 3, 4, 3
out = torch.randint(low=-10, high=10, size=(B, H, T, D))  # (B, H, T, D)
out_proj = torch.randint(low=-10, high=10, size=(H, D, H * D))  # (H, D, C) with C = H*D

In [15]:
# K_cache is another name for `out` here (plain assignment, no copy); the bare
# expression on the last line displays it.
K_cache = out
K_cache

tensor([[[[  5,   7,   2],
          [-10,   5,   3],
          [  5,   0,   4],
          [  0,   2, -10]],

         [[  7,  -4,   0],
          [  8,   1,  -6],
          [ -1,   5,   3],
          [ -4,  -8,   0]],

         [[ -8,  -9,   6],
          [  5,  -1,  -6],
          [  5,   9,  -4],
          [ -1,   9,   9]]]])

In [16]:
# Trim the T axis (second-to-last) so only the most recent block_size positions remain.
block_size = 3
K_cache = K_cache[..., -block_size:, :]
K_cache

tensor([[[[  5,   7,   2],
          [-10,   5,   3],
          [  5,   0,   4],
          [  0,   2, -10]],

         [[  7,  -4,   0],
          [  8,   1,  -6],
          [ -1,   5,   3],
          [ -4,  -8,   0]],

         [[ -8,  -9,   6],
          [  5,  -1,  -6],
          [  5,   9,  -4],
          [ -1,   9,   9]]]])

In [28]:
out_proj_T = rearrange(out_proj, 'h d c -> c h d') # axis permutation that moves c to the front; not a 2-D matrix transpose
einsum(out, out_proj_T, 'b h t d, c h d -> b t c') # sums over h and d -> (B, T, C)

tensor([[[-154,  224,  175,   75,  -45,   15,  224,  -18, -100],
         [-118,   61,   -8,  180,  -28,   -3,  -33,  -33,  -73],
         [ 137, -124,  -38,  -95,  -55,   49, -115,  113,   70],
         [  97,    7, -171,   -8,   91,   64,  123, -133, -206]]])

In [29]:
# Flatten the head axes so the projection becomes an ordinary 2-D matrix.
new_out = rearrange(out, 'b h t d -> b t (h d)') # (B, T, C)
new_out_proj = rearrange(out_proj, 'h d c -> (h d) c') # (C, C)
new_out_proj_T = new_out_proj.T

# F.linear gets the matrix as-is while the einsum gets its transpose; the
# element-wise comparison on the last line shows whether the two agree.
expected = F.linear(new_out, new_out_proj, bias=None) # (B, T, C)
actual = einsum(new_out, new_out_proj_T, 'b t c1, c1 c2  -> b t c2') # (B, T, C)
actual, expected == actual

(tensor([[[  49,  -97,   53,   34, -107, -146,  158,  -85,  -36],
          [ -72,  200,  -70,  175, -162, -143, -112, -187,    6],
          [  91,   55,  -38,   38,   30,   39,  -87,   61,   -4],
          [  50,   18,   86, -165,   72,  210,  146,  -25, -115]]]),
 tensor([[[True, True, True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True, True, True],
          [True, True, True, True, True, True, True, True, True]]]))

In [30]:
# Toy integer feed-forward weights: w_in maps C -> 4C, w_out maps 4C -> C.
C = H * D
w_in = torch.randint(low=-10, high=10, size=(C, 4 * C))  # (C, 4C)
w_out = torch.randint(low=-10, high=10, size=(4 * C, C))  # (4C, C)
w_in, w_out

(tensor([[  7,   3,  -2,   0,  -9,  -8,  -3,   5,   0,  -4,  -1,  -8,   4,   9,
           -9,   8, -10,  -7,   1,  -3,  -1,   6,  -9,   5,  -6,   6,   1,   1,
            5,  -9,   4,  -1,   2,   5, -10,   3],
         [  2,   3,   6,   1,  -5,   4,   9,   5,   0,   4,   5,   4,   9,   5,
           -7,  -5,  -4,   1,   4,   4,   3,   6,  -1,  -2,   5,   1,   8,  -6,
          -10,  -1,   4,   7,  -4,   1,  -5,   0],
         [-10,   4,  -4,  -7,   3,   3,   8,   6,   2,  -5,  -1,   0,   7,  -3,
           -9,  -8,  -1,   8,   4,  -7,  -7,  -9,   8,  -5,   1,   5,  -6,  -3,
            5,   1,  -2,   4,   2,   0,   4,  -5],
         [  2,   8,  -2,   4,  -6,   8,   4,   5,  -8,  -4,  -5, -10,  -5,   4,
           -2,   0, -10,   2,   2,  -1,   5,   4,   3,   1,   4,   0,  -1,  -9,
            1,   9,   3,  -5,   8,   1,   2,  -3],
         [  2,   8,   7,   9,  -4,  -1,  -4,   8,  -4,   3,  -3,  -9,   1,  -2,
          -10,   6,   0,   7,  -8,   6,  -6,  -9,  -9,  -9,  -1,  -9,   9,  

In [31]:
# F.linear is given w_in.T while the einsum takes w_in directly; w_in is
# (C, 4C), so both results have a last axis of size 4C.
expected_mlp = F.linear(actual, w_in.T, bias=None) # (B, T, 4C)
actual_mlp = einsum(actual, w_in, 'b t c1, c1 c2  -> b t c2') # (B, T, 4C)