In [1]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

# Single source of truth for the checkpoint name, so the tokenizer and the
# model can never be loaded from two different checkpoints by accident.
MODEL_ID = "google/gemma-3-1b-it"

# Load the tokenizer and the causal-LM model for that same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

  from .autonotebook import tqdm as notebook_tqdm


In [29]:
# Put the model in evaluation mode. `eval()` hands back the module itself, so
# using the call as the cell's last expression also displays the architecture.
model.eval()

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 1152, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=1152, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1152, out_features=256, bias=False)
          (v_proj): Linear(in_features=1152, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1152, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (up_proj): Linear(in_features=1152, out_features=6912, bias=False)
          (down_proj): Linear(in_features=6912, out_features=1152, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((11

In [25]:
# Turkish prompt that the model is asked to continue.
prompt = "Fransanın başkenti Paris, Fransa'da yer alır ve yaklaşık 170 kilometre (112 mil)"

# Encode the prompt as a PyTorch tensor of token ids and display it.
inputs = tokenizer.encode(prompt, return_tensors="pt")
inputs

tensor([[     2,   4255,    743,  58434,  15274, 236767,  12135,   9079, 236764,
         204411, 236789,   1926,  23115, 224506,   1381, 137104, 236743, 236770,
         236832, 236771, 151409,    568, 236770, 236770, 236778,   4743, 236768]])

In [26]:
# Generate a continuation of the prompt. The recorded result holds 47 ids for a
# 27-id prompt (20 new tokens), so that length is requested explicitly instead
# of being left to the library's implicit default.
outputs = model.generate(inputs, max_new_tokens=20)


In [27]:
# Display the generated ids: the prompt ids followed by the newly generated ones.
outputs

tensor([[     2,   4255,    743,  58434,  15274, 236767,  12135,   9079, 236764,
         204411, 236789,   1926,  23115, 224506,   1381, 137104, 236743, 236770,
         236832, 236771, 151409,    568, 236770, 236770, 236778,   4743, 236768,
         119556,   3482,  35618,  81332, 128693, 113012, 236761,   9079, 236764,
         120411, 236959,   2921, 227623, 236764,  23276,    980,  97230, 236752,
         173787,   1381]])

In [30]:
# Take the first (and only) sequence of the batch and map each id back to its
# token string, so individual tokens can be inspected one by one.
generated_ids = outputs[0]
tokenizer.convert_ids_to_tokens(generated_ids)

['<bos>',
 'Fr',
 'ans',
 'anın',
 '▁baş',
 'k',
 'enti',
 '▁Paris',
 ',',
 '▁Fransa',
 "'",
 'da',
 '▁yer',
 '▁alır',
 '▁ve',
 '▁yaklaşık',
 '▁',
 '1',
 '7',
 '0',
 '▁kilometre',
 '▁(',
 '1',
 '1',
 '2',
 '▁mil',
 ')',
 'ব্যাপী',
 '▁bir',
 '▁alan',
 '▁üzerine',
 '▁kurul',
 'udur',
 '.',
 '▁Paris',
 ',',
 '▁kült',
 'ü',
 'rel',
 '▁miras',
 ',',
 '▁mim',
 'ari',
 '▁çeşit',
 'l',
 'iliği',
 '▁ve']

In [32]:
# Look up the id(s) the tokenizer assigns to the Turkish word " geniş" (with a
# leading space), to compare against the unexpected token in the generation.
# The result starts with id 2, the <bos> token.
tokenizer.encode(" geniş")

[2, 124766]

In [28]:
# Turn the first generated sequence back into a single string and display it.
decoded_text = tokenizer.decode(outputs[0])
decoded_text

"<bos>Fransanın başkenti Paris, Fransa'da yer alır ve yaklaşık 170 kilometre (112 mil)ব্যাপী bir alan üzerine kuruludur. Paris, kültürel miras, mimari çeşitliliği ve"

## Burada ilginç bir durum oldu: model, bir sonraki token (next token) olarak anlamca yakın ama başka bir dile ait bir tokeni seçti. Bu da daha önce tartıştığımız, tokenlerin anlamsal olarak diğer dillere çevrilebilirliği konusuyla ilgili. Üretilen metin: "<bos>Fransanın başkenti Paris, Fransa'da yer alır ve yaklaşık 170 kilometre (112 mil)ব্যাপী bir alan üzerine kuruludur. Paris, kültürel miras, mimari çeşitliliği ve" — buradaki ব্যাপী kelimesi Bengalce'de "geniş" anlamına geliyor.

In [12]:
from config import GemmaConfig, Architecture, AttentionType, get_config_for_1b


# Hand-built miniature configuration (1 layer, 1 attention head, hidden size 3).
# It is kept under its own name: `test_config` is rebound to the 1B
# configuration right below, which previously discarded this object silently.
tiny_config = GemmaConfig(
    architecture=Architecture.GEMMA_3,
    num_hidden_layers=1,
    num_attention_heads=1,
    num_key_value_heads=1,
    hidden_size=3,
    intermediate_size=3,
    use_pre_ffw_norm=True,
    use_post_ffw_norm=True,
    head_dim=3,
    sliding_window_size=4,
    rope_wave_length={
        AttentionType.LOCAL_SLIDING: 10,
        AttentionType.GLOBAL: 100,
    },
    vocab_size=tokenizer.vocab_size,
    max_position_embeddings=12,
    tokenizer=tokenizer,
    use_qk_norm=True,
    vision_config=None,
)
print(tiny_config.vocab_size)

# Configuration that the following cells actually use to build the model.
test_config = get_config_for_1b(dtype='float32')
print(test_config.vocab_size)

# Override the tokenizer and vocabulary size with those of the tokenizer in scope.
# NOTE(review): the recorded output shows 33 here, which suggests a different
# `tokenizer` object was in scope when this cell last ran — confirm.
test_config.tokenizer = tokenizer
test_config.vocab_size = tokenizer.vocab_size
print(test_config.vocab_size)


33
262144
33


In [13]:
import torch

# Prefer CUDA when PyTorch reports it as available; otherwise stay on the CPU.
# An MPS branch used to sit here, disabled by wrapping it in a string literal
# (which left a no-op string statement behind). It stays disabled and is kept
# only as this note:
#   elif torch.backends.mps.is_available(): device = 'mps'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

cpu


In [14]:
from model import GemmaForCausalLM
# Build the project's own Gemma implementation from `test_config` and move it
# to the selected device in a single expression.
# Assumes `.to()` returns the module itself, as torch.nn.Module.to does.
# NOTE(review): this rebinds `model`, which earlier held the Hugging Face model.
model = GemmaForCausalLM(test_config).to(device)
model

GemmaForCausalLM(
  (embedder): Embedding()
  (model): GemmaModel(
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): GemmaAttention(
          (qkv_proj): Linear()
          (o_proj): Linear()
          (query_norm): RMSNorm()
          (key_norm): RMSNorm()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear()
          (up_proj): Linear()
          (down_proj): Linear()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
        (pre_feedforward_layernorm): RMSNorm()
        (post_feedforward_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (sampler): Sampler()
)

In [15]:
# Smoke-test the custom implementation with the prompt "a" and output_len=1
# (presumably a single generated token — confirm against the model code).
# The `x:` / `freqs_cis:` lines in the recorded output are not printed by this
# cell itself, so they must come from code that is not visible here.
model.generate("a", device=device, output_len=1)

x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 4, 256])
freqs_cis:  torch.Size([1, 128])
x:  torch.Size([1, 1, 1, 256])
freqs_cis

'j'