In [1]:
import os

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer and the causal-LM weights for the Gemma checkpoint.
#
# The Hugging Face access token is taken from the HF_TOKEN environment
# variable instead of a string literal, so no credential has to be typed into
# (or accidentally committed with) the notebook. When the variable is unset or
# empty, None is passed, which lets the hub client fall back to credentials
# cached locally (e.g. by `huggingface-cli login`).
access_token = os.environ.get("HF_TOKEN") or None
model_name = "google/gemma-3-270m"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
llm = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)

In [None]:
# --- Baseline forward pass: hand the model token IDs ---
input_text = "Hello, world!"
encoded_prompt = tokenizer(input_text, return_tensors="pt")
input_ids = encoded_prompt["input_ids"]
# Given only IDs, the model performs the embedding lookup itself.
outputs_standard = llm(input_ids=input_ids)


In [None]:
# --- Alternative forward pass: supply embeddings instead of token IDs ---
# Fetch the model's input-embedding module and apply it to the token IDs by
# hand; the recorded output below shows the result has shape [1, 5, 640].
embedding_layer = llm.get_input_embeddings()
text_embeddings = embedding_layer(input_ids)
# Passing `inputs_embeds` (rather than `input_ids`) feeds these vectors to the
# model directly.
outputs_advanced = llm(inputs_embeds=text_embeddings)

In [14]:
# Display the dimensions of the manually computed embeddings.
text_embeddings.size()

torch.Size([1, 5, 640])

In [None]:
# Print the embedding module's `embedding_dim`; the recorded value (640)
# matches the last dimension of `text_embeddings` shown earlier.
embedding_layer = llm.get_input_embeddings()
hidden_size = embedding_layer.embedding_dim
print(hidden_size)

640


In [None]:
from brainaudio.models.e2e import E2EModel
from brainaudio.inference.inference_utils import load_model


# Build the encoder via `load_model` from an output directory, a YAML config
# path and a device string, then combine it with the LLM and tokenizer in an
# E2EModel on the same device.
# NOTE(review): these are absolute, machine-specific paths — consider moving
# them into a config cell.
encoder_output_dir = "/data2/brain2text/b2t_24/outputs/neurips_gru_nonoverlapping_4_4_768_seed_0"
encoder_config_yaml = "/home3/lionehlhu/brainaudio/src/brainaudio/training/utils/custom_configs/neurips_gru_nonoverlapping_4_4_768_seed_0.yaml"
target_device = "cuda:1"

encoder = load_model(encoder_output_dir, encoder_config_yaml, target_device)
# NOTE(review): 512 is passed positionally; its meaning is defined by
# E2EModel's signature, which is not visible here — TODO confirm.
model = E2EModel(encoder, 512, llm, tokenizer, target_device)

Loading custom YAML args from: /home3/lionehlhu/brainaudio/src/brainaudio/training/utils/custom_configs/neurips_gru_nonoverlapping_4_4_768_seed_0.yaml


In [None]:
from brainaudio.datasets.loading_data import getDatasetLoaders

# Create the train/validation loaders (plus the loaded data object) for the
# dataset directory, requesting alignments as well.
# NOTE(review): the second positional argument (1) is presumably the batch
# size — confirm against getDatasetLoaders' signature.
dataset_dirs = ["/data2/brain2text/b2t_24/brain2text24_with_fa"]
trainLoaders, valLoaders, loadedData = getDatasetLoaders(dataset_dirs, 1, return_alignments=True)

[tensor([[[-0.0088, -0.4234, -0.7253,  ...,  0.6295, -0.9338,  0.0572],
         [-0.0088, -0.4234,  1.0446,  ..., -1.2442,  0.2305,  1.3231],
         [-1.0800, -0.4234, -0.7253,  ..., -0.4961, -0.5319,  0.2987],
         ...,
         [-1.0800, -0.4234, -0.7253,  ..., -1.3404, -0.5700,  0.2358],
         [-1.0800,  1.8947, -0.7253,  ..., -1.5471, -0.8085, -0.2412],
         [-1.0800, -0.4234, -0.7253,  ...,  0.2723, -0.9532,  0.3059]]]), tensor([[36, 17,  8, 40, 17, 38, 40,  3, 21,  5,  9, 40, 31, 34, 40,  7, 18, 40,
         28, 17, 35, 25, 20, 31, 40,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  

In [8]:
# Pull a single batch from the first training loader for inspection.
first_train_loader = trainLoaders[0]
batch = next(iter(first_train_loader))
# NOTE(review): this prints every tensor in full; printing shapes/dtypes
# would keep the notebook output readable.
print(batch)

[tensor([[[-0.7449, -0.0380, -0.6573,  ..., -0.3152, -1.0477, -0.5842],
         [-0.7449, -0.0380, -0.6573,  ..., -0.8353, -1.0524,  0.5722],
         [-0.7449, -0.0380, -0.6573,  ..., -0.8269, -1.3402,  0.3839],
         ...,
         [-0.7449, -1.0911, -0.6573,  ..., -1.9963, -1.6023,  2.6769],
         [-0.7449, -1.0911, -0.6573,  ..., -0.8191, -1.0094,  1.1178],
         [-0.7449, -1.0911,  1.0947,  ..., -0.6907, -1.7220, -0.9082]]]), tensor([[ 6, 40, 20,  2, 23, 40, 23,  1, 31, 40, 28, 17, 22, 11, 22,  7, 12, 40,
         16, 17, 38, 40, 23, 13, 22, 40, 23,  5, 40,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  