In [1]:
import math
import functools

import jax
import jax.numpy as jnp 
from jax import random
from jax.tree_util import tree_map

import flax
import flax.linen as nn

import numpy as np

from general.utils import ModelConfig
from jax_impl.model import *

In [2]:
# Root PRNG key for the notebook; the next cell splits it into sub-keys.
key = random.PRNGKey(42)

# 'gpt-nano' preset with vocab_size=1024 and context_window=64.
model_config = ModelConfig(
    model_type='gpt-nano',
    vocab_size=1024,
    context_window=64,
)
model_config  # echo the resolved configuration as the cell output

ModelConfig(vocab_size=1024, context_window=64, n_embd=48, n_head=3, n_layer=3, model_type='gpt-nano', attn_pdrop=0.1, recid_pdrop=0.1, embd_pdrop=0.1)

In [3]:
# Split the root key into sub-keys for the input sample (rx), parameter
# initialisation (rp) and dropout (rd); the fourth piece becomes the new root key.
rx, rp, rd, key = random.split(key, 4)
init_rngs = {'params': rp, 'dropout': rd}

# Random normal input of shape (10, 4, 48); the last dimension equals the
# n_embd=48 shown in the config output above.
x = random.normal(rx, shape=(10, 4, 48))

# Smoke test for Attention: initialise, run one apply() call with a dropout
# rng supplied, and echo the output shape.
attn = Attention(model_config)
params = attn.init(init_rngs, x)

out = attn.apply(
    params,
    x,
    rngs={'dropout': rd},
)
out.shape

(10, 4, 48)

In [4]:
# Smoke test for Block, reusing the input `x` and the rng dict `init_rngs`
# created in the previous cell.
blk = Block(model_config)
params = blk.init(init_rngs, x)

output = blk.apply(
    params,
    x,
    rngs={'dropout': rd},
)
output.shape  # echo the shape of the Block output

(10, 4, 48)

In [5]:
# Draw a fresh sub-key for the token ids. `rx` was already consumed by
# random.normal when `x` was sampled, and a JAX PRNG key should be used once.
r_idx, key = random.split(key)

# Integer token ids of shape (10, 4), with values from 0 up to (but excluding)
# vocab_size. A concrete jnp.int32 replaces the abstract `jnp.integer`, which is
# not a real dtype and relies on NumPy's deprecated abstract-type conversion.
idx = random.randint(r_idx, (10, 4), 0, maxval=model_config.vocab_size, dtype=jnp.int32)

# Smoke test for Embedding: initialise, run one apply() call with a dropout
# rng supplied, and echo the output shape.
emb = Embedding(model_config)
params = emb.init(init_rngs, idx)

output = emb.apply(params, idx, rngs={'dropout': rd})
output.shape

(10, 4, 48)

In [6]:
# Smoke test for the full GPT model on the integer token ids `idx`.
gpt = GPT(model_config)
params = gpt.init(init_rngs, idx)

output = gpt.apply(
    params,
    idx,
    rngs={'dropout': rd},
)
# The echoed shape below has a last dimension of 1024, the configured vocab_size.
output.shape

(10, 4, 1024)

In [7]:
# Build the model summary with the same rngs and inputs used for init, then
# print it (instead of echoing it) so its line breaks render as a table.
summary_table = gpt.tabulate(init_rngs, idx)
print(summary_table)




