# 

In [1]:
from functools import partial

import joblib
import rootutils
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
path_to_root = rootutils.find_root(indicator=".project-root")

In [3]:
def print_number_of_trainable_model_parameters(model: AutoModelForCausalLM) -> str:
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        all_model_params += param.numel()
        if param.requires_grad:
            trainable_model_params += param.numel()

    return (
        f"\ntrainable model parameters: {trainable_model_params}"
        f"\nall model parameters: {all_model_params}"
        f"\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"
    )

## Experiments with tokenizer

In [4]:
tokenizer_gemma = AutoTokenizer.from_pretrained("google/gemma-3-270m")
tokenizer_modified = AutoTokenizer.from_pretrained("transhumanist-already-exists/tereshchenkoblue-tokenizer")

In [5]:
# token -> id
vocab_gemma = tokenizer_gemma.get_vocab()
vocab_modified = tokenizer_modified.get_vocab()

print("Sample tokens from Gemma tokenizer:", list(vocab_gemma.items())[:5])
print(f"Vocab sizes: Gemma = {len(vocab_gemma)}, Modiified = {len(vocab_modified)}")

Sample tokens from Gemma tokenizer: [('▁pede', 139534), ('▁disproportion', 58670), ('▁mindset', 39927), ('xt', 903), ('isque', 105598)]
Vocab sizes: Gemma = 262145, Modiified = 262145


In [6]:
# Inverse maps id -> token
inv_gemma = {v: k for k, v in vocab_gemma.items()}
inv_modified = {v: k for k, v in vocab_modified.items()}

tokens_gemma = set(vocab_gemma.keys())
tokens_modified = set(vocab_modified.keys())

In [7]:
inv_gemma[0], inv_gemma[262144], inv_modified[0], inv_modified[262144]

('<pad>', '<image_soft_token>', '<pad>', '<image_soft_token>')

In [8]:
# Overlap analysis
shared_tokens = tokens_gemma & tokens_modified
only_in_gemma = tokens_gemma - tokens_modified
only_in_modified = tokens_modified - tokens_gemma

In [9]:
same_token_same_id = 0
same_token_diff_id = 0
for t in shared_tokens:
    id_gemma = vocab_gemma[t]
    id_modified = vocab_modified[t]
    if id_gemma == id_modified:
        same_token_same_id += 1
    else:
        same_token_diff_id += 1

# same numeric id that maps to different tokens
shared_ids = set(inv_gemma.keys()) & set(inv_modified.keys())
same_id_diff_token = sum(1 for i in shared_ids if inv_gemma[i] != inv_modified[i])
same_id_same_token = sum(1 for i in shared_ids if inv_gemma[i] == inv_modified[i])

# Summary
summary = {
    "vocab_size_1": len(vocab_gemma),
    "vocab_size_2": len(vocab_modified),
    "shared_token_count": len(shared_tokens),
    "same_token_same_id": same_token_same_id,
    "same_token_diff_id": same_token_diff_id,
    "same_id_count": len(shared_ids),
    "same_id_same_token": same_id_same_token,
    "same_id_diff_token": same_id_diff_token,
    "only_in_1": len(only_in_gemma),
    "only_in_2": len(only_in_modified),
}

print("SUMMARY:")
for k, v in summary.items():
    print(f"  {k}: {v}")

SUMMARY:
  vocab_size_1: 262145
  vocab_size_2: 262145
  shared_token_count: 187598
  same_token_same_id: 180656
  same_token_diff_id: 6942
  same_id_count: 262145
  same_id_same_token: 180656
  same_id_diff_token: 81489
  only_in_1: 74547
  only_in_2: 74547


In [10]:
# tokens_same_id_and_same_token = [inv_gemma[i] for i in shared_ids if inv_gemma[i] == inv_modified[i]]
ids_same_id_and_same_token = [i for i in shared_ids if inv_gemma[i] == inv_modified[i]]
print(
    f"Same id and same token: {ids_same_id_and_same_token[:3] + ids_same_id_and_same_token[-3:]}"
    f" (total {len(ids_same_id_and_same_token)})"
)

Same id and same token: [0, 1, 2, 256002, 256003, 262144] (total 180656)


## Model Experiments

In [11]:
model_gemma = AutoModelForCausalLM.from_pretrained("google/gemma-3-270m")
print(print_number_of_trainable_model_parameters(model_gemma))


trainable model parameters: 268098176
all model parameters: 268098176
percentage of trainable model parameters: 100.00%


In [12]:
model_gemma

Gemma3ForCausalLM(
  (model): Gemma3TextModel(
    (embed_tokens): Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x Gemma3DecoderLayer(
        (self_attn): Gemma3Attention(
          (q_proj): Linear(in_features=640, out_features=1024, bias=False)
          (k_proj): Linear(in_features=640, out_features=256, bias=False)
          (v_proj): Linear(in_features=640, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=640, bias=False)
          (q_norm): Gemma3RMSNorm((256,), eps=1e-06)
          (k_norm): Gemma3RMSNorm((256,), eps=1e-06)
        )
        (mlp): Gemma3MLP(
          (gate_proj): Linear(in_features=640, out_features=2048, bias=False)
          (up_proj): Linear(in_features=640, out_features=2048, bias=False)
          (down_proj): Linear(in_features=2048, out_features=640, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma3RMSNorm((640,), eps

In [13]:
model_gemma.get_input_embeddings().weight.shape

torch.Size([262144, 640])

In [15]:
t = torch.tensor(sorted(ids_same_id_and_same_token), dtype=torch.long)

valid_mask = t < model_gemma.get_input_embeddings().weight.shape[0]
invalid_ids = t[~valid_mask]
ids_to_freeze = t[valid_mask]

joblib.dump(ids_to_freeze, path_to_root / "data" / "freeze_ids_gemma_tereshchenko.joblib")

print(f"Number of valid IDs: {len(ids_to_freeze)}")
print(f"Number of invalid IDs: {len(invalid_ids)}")

Number of valid IDs: 180655
Number of invalid IDs: 1


In [16]:
def _hook(grad, ids):
    if grad is None:
        return None

    # Make a copy of the gradient tensor to avoid in-place modification issues
    # But probably this is not necessary
    grad = grad.clone()
    grad[ids] = 0
    return grad


hook_handle = model_gemma.get_input_embeddings().weight.register_hook(partial(_hook, ids=ids_to_freeze))

In [17]:
model_gemma.get_input_embeddings().weight[0][0].requires_grad

True

In [18]:
model_gemma.get_input_embeddings(), model_gemma.get_input_embeddings().weight.shape

(Gemma3TextScaledWordEmbedding(262144, 640, padding_idx=0),
 torch.Size([262144, 640]))

## Quick smoke test

+ run a forward+backward and confirm grads were zeroed for those ids

In [19]:
model_gemma.train()

input_text = "Hello world, мене звати Роман, я хочу дещо протестувати."
enc = tokenizer_gemma(input_text, return_tensors="pt")
outputs = model_gemma(**enc, labels=enc["input_ids"])

loss = outputs.loss
loss.backward()

emb = model_gemma.get_input_embeddings()

if emb.weight.grad is None or ids_to_freeze.numel() == 0:
    total_grad_l1 = 0.0
else:
    row_grad_l1 = emb.weight.grad.abs().sum(dim=1).detach().cpu()
    ids_cpu = ids_to_freeze.cpu().to(torch.long)
    freezed_grad_l1 = float(row_grad_l1[ids_cpu].sum().item())
    total_grad_l1 = float(row_grad_l1.sum().item())

print(f"Total grad L1: {total_grad_l1}, freezed grad L1: {freezed_grad_l1}")

It is strongly recommended to train Gemma3 models with the `eager` attention implementation instead of `sdpa`. Use `eager` with `AutoModelForCausalLM.from_pretrained('<path-to-checkpoint>', attn_implementation='eager')`.


Total grad L1: 2780.26708984375, freezed grad L1: 0.0


In [20]:
print_number_of_trainable_model_parameters(model_gemma)

'\ntrainable model parameters: 268098176\nall model parameters: 268098176\npercentage of trainable model parameters: 100.00%'

## We can't just set `requires_grad=False` for specific rows

In [21]:
emb = model_gemma.get_input_embeddings()
print("param is leaf:", emb.weight.is_leaf)
print("param.requires_grad:", emb.weight.requires_grad)

# indexing returns a view (non-leaf) that inherits requires_grad
x = emb.weight[0]
print("view is leaf:", x.is_leaf)
print("view.requires_grad:", x.requires_grad)

param is leaf: True
param.requires_grad: True
view is leaf: False
view.requires_grad: True


In [22]:
try:
    view = emb.weight[ids_to_freeze[0]]
    view.requires_grad_(False)
except Exception as e:
    print("Attempt to set requires_grad_ on view failed as expected:\n-", repr(e))

Attempt to set requires_grad_ on view failed as expected:
- RuntimeError("you can only change requires_grad flags of leaf variables. If you want to use a computed variable in a subgraph that doesn't require differentiation use var_no_grad = var.detach().")
