In [1]:
import torch
from transformers import AutoTokenizer
from models.modeling_xgemma import XGemmaForCausalLM, XGemmaConfig

# NOTE(review): `dist_cp` is not used by any visible cell, and
# `torch.distributed._shard` is a private PyTorch namespace (the public
# counterpart is `torch.distributed.checkpoint`) — consider dropping it.
# Previously this line was indented, which is an IndentationError at cell level.
import torch.distributed._shard.checkpoint as dist_cp


In [28]:
# Configuration for a randomly initialised XGemma student. The architecture
# printed below is GemmaModel / GemmaDecoderLayer.
# NOTE(review): 2048 hidden / 16384 FFN / 18 layers / 8 heads / 1 KV head is the
# original Gemma 2B layout, not a 1B model — confirm the intended size.
config = XGemmaConfig(
    vocab_size=256000,             # Gemma tokenizer vocabulary (256k); <xRAG> is added later as id 256000
    hidden_size=2048,              # Model (residual stream) width
    intermediate_size=16384,       # FFN intermediate size
    num_hidden_layers=18,          # Number of transformer decoder layers
    num_attention_heads=8,         # 8 query heads x head_dim 256 = 2048 (q_proj out_features)
    num_key_value_heads=1,         # One shared K/V head (multi-query attention): k/v_proj out_features = 256
    head_dim=256,                  # Dimension per attention head
    max_position_embeddings=8192,  # Maximum context length
    rms_norm_eps=1e-6,
    rope_theta=10000.0,
    attention_bias=False,
    attention_dropout=0.0,
    # Custom XGemma (xRAG) parameters
    projector_type='mlp2x_gelu',   # presumably selects the retrieval-embedding projector; see models.modeling_xgemma
    retriever_hidden_size=4096,    # expected to match the last dim of `retrieval_embeds` (the dataloader yields [1, 4096])
)
    

In [3]:
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [29]:
# Seed PyTorch's RNG so the random weight initialisation (and therefore the
# smoke-test losses further down) is reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

# nn.Module.to() moves the parameters and returns the module itself,
# so construction and device placement can be chained.
model = XGemmaForCausalLM(config).to(device)

In [30]:
# Show the module tree to sanity-check the architecture (the saved output is truncated).
model

XGemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): GemmaRMSNorm((2048,), eps=1

In [31]:
# Register the <xRAG> placeholder id: the first id past the 256000-token base
# vocabulary. A later cell confirms the tokenizer assigns "<xRAG>" -> 256000.
# NOTE(review): the embedding table has only 256000 rows until
# `model.resize_token_embeddings(...)` runs further down.
model.set_xrag_token_id(256000)

In [20]:
# Hand-written toy batch for a quick forward-pass smoke test.
# NOTE(review): these ids look like they came from a different tokenizer than
# the Gemma one loaded below. In particular, 32001 is not the <xRAG> id
# registered via `model.set_xrag_token_id(256000)`, so this batch may not
# exercise the retrieval-embedding path — confirm before relying on it.
toy_input_ids = torch.tensor([[
    733, 16289, 28793, 24316, 28747, 32001, 28725, 690, 835, 2825,
    28747, 28792, 28748, 16289, 28793, 415, 2990, 302, 9143, 403,
    16783, 23799, 356, 4117, 28705, 28740, 28787, 28725, 28705, 28740,
    28787, 28750, 28750, 28725, 304, 403, 10806, 1987, 9143, 8897
]])

# Labels copy the inputs with the prompt masked out. -100 is PyTorch's default
# cross-entropy ignore_index, so only the answer tokens contribute to the loss.
n_prompt_tokens = 15
toy_labels = toy_input_ids.clone()
toy_labels[:, :n_prompt_tokens] = -100

batch = {
    'input_ids': toy_input_ids.to(device),
    # Derived from input_ids so the mask length cannot drift from the sequence length.
    'attention_mask': torch.ones(toy_input_ids.shape).to(device),
    # Random retrieval embedding, sized from the config (previously 128, which
    # disagrees with retriever_hidden_size=4096 and with the real dataloader batches).
    'retrieval_embeds': torch.randn(1, config.retriever_hidden_size).to(device),
    'labels': toy_labels.to(device),
}

In [None]:
# Forward pass on the toy batch. Its keys (input_ids, attention_mask,
# retrieval_embeds, labels) map one-to-one onto the model's keyword arguments.
outputs = model(**batch)

In [6]:
from configs.datasets import dataset as DatasetConfig
from configs.training import train_config as TrainConfig
from configs.distillation import distillation_config as DistillationConfig
from configs.fsdp import fsdp_config as FsdpConfig

In [7]:
dataset_config = DatasetConfig()
train_config = TrainConfig()
distill_config = DistillationConfig()
fsdp_config = FsdpConfig()  # FSDP settings come from FsdpConfig (imported above), not DistillationConfig

# NOTE(review): the student is named as Gemma-3 1B here, but the student built
# above uses a 256000-token Gemma vocabulary and the tokenizer below is loaded
# from "google/gemma-2-2b" — confirm which checkpoint is intended.
train_config.model_name = "google/gemma-3-1b-it"
distill_config.model_name = "Hannibal046/xrag-7b"  # presumably the teacher checkpoint
train_config.batch_size_training = 1
train_config.num_workers_dataloader = 1
dataset_config.file = "data/loaders/squad-v2-sampled.py"

In [8]:
# NOTE(review): AutoTokenizer is already imported in the first cell; this re-import is redundant.
from transformers import AutoTokenizer
# Gemma tokenizer with a 256000-token base vocabulary (len becomes 256001 once <xRAG> is added),
# matching config.vocab_size. NOTE(review): differs from train_config.model_name — confirm.
xgemma_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")

In [9]:
# Register <xRAG> as a special token; the return value is the number of tokens added (1 here).
xgemma_tokenizer.add_special_tokens({"additional_special_tokens": ["<xRAG>"]})

1

In [32]:
# Grow the embedding table to len(tokenizer) (256001) so the <xRAG> id 256000 has a row.
model.resize_token_embeddings(len(xgemma_tokenizer))

Embedding(256001, 2048, padding_idx=0)

In [11]:
from data.data_utils import get_dataloader
# Build the student's train/eval dataloaders with the Gemma tokenizer.
# NOTE(review): the positional `0` is presumably the process rank — confirm against data.data_utils.get_dataloader.
student_train_dataloader, student_eval_dataloader = get_dataloader(dataset_config, train_config, xgemma_tokenizer, 0, distill_config)

--> Training Set Length = 1000
--> Validation Set Length = 250


In [12]:
# Explicit iterator so successive `next(iter1)` calls step through training batches.
iter1 = iter(student_train_dataloader)

In [25]:
# Fetch the next training batch: a dict of tensors keyed like the model's forward kwargs.
item = next(iter1)

In [33]:
# Inspect the raw batch dict (input_ids, attention_mask, retrieval_embeds, labels).
item

{'input_ids': tensor([[   106,   1645,    108,   7266, 235292, 235248, 256000, 235269,    948,
           1170,   3454, 235292,    107, 235274, 235265,  45029,    840,  98615,
            109, 235284, 235265,  11976,  98615,    109, 235304, 235265,  76759,
           1706,    109, 235310, 235265,  76759,  11827,    109, 235308, 235265,
          76759,   6181,    109, 235318, 235265,  76759,  12776,    109, 235324,
         235265,  76759,  46002,    109, 235321, 235265,  76759,  51625,    109,
         235315, 235265,  76759,  19967,    109, 235274, 235276, 235265,  76759,
         113061,    109, 235274, 235274, 235265,  76759,   5368,    109, 235274,
         235284, 235265,  76759, 127403,    109, 235274, 235304, 235265,  76759,
           5239,    109, 235274, 235310, 235265,  76759,  22241,    109, 235274,
         235308, 235265,  76759,  46002,    604,  38823,    109, 235274, 235318,
         235265,  76759,  46002,    604,  61348,    109, 235274, 235324, 235265,
          7675

In [16]:
# Decode the first sequence to check the prompt template and that <xRAG> survives as one token.
xgemma_tokenizer.decode(item['input_ids'][0])

'<start_of_turn>user\nBackground: <xRAG>, which also means:<end_of_turn>Internet service provider'

In [17]:
# Confirm the tokenizer's <xRAG> id equals the id passed to model.set_xrag_token_id (256000).
xgemma_tokenizer.convert_tokens_to_ids(["<xRAG>"])

[256000]

In [None]:
# NOTE(review): never executed (In [None]); the moved tensor is discarded and the
# next cell performs the same transfer — candidate for deletion.
item['retrieval_embeds'].to(device)

In [36]:
# Move each tensor the forward pass needs onto the model's device, in a fixed order.
input_ids, attention_mask, retrieval_embeds, labels = (
    item[key].to(device)
    for key in ('input_ids', 'attention_mask', 'retrieval_embeds', 'labels')
)

In [None]:
# NOTE(review): never executed (In [None]); ad-hoc inspection of the batch ids.
input_ids

In [37]:
# Last dim should equal config.retriever_hidden_size (4096).
retrieval_embeds.shape

torch.Size([1, 4096])

In [38]:
# Forward pass on the real batch; supplying labels makes the model return a loss.
outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                retrieval_embeds=retrieval_embeds, labels=labels)

In [39]:
# Loss of the randomly initialised student. For reference, a uniform predictor
# over 256001 tokens gives ln(256001) ≈ 12.45.
outputs.loss

tensor(12.8053, device='cuda:0', grad_fn=<NllLossBackward0>)

In [41]:
# Last dim is the resized vocabulary (256001, including <xRAG>).
outputs.logits.shape

torch.Size([1, 149, 256001])