In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load model and tokenizer
access_token = ""
model_name = "google/gemma-3-270m"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
llm = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)

In [None]:
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState


class CustomLM(CTCDecoderLM):
    """Create a Python wrapper around `language_model` to feed to the decoder."""

    def __init__(self, language_model: torch.nn.Module):
        CTCDecoderLM.__init__(self)
        self.language_model = language_model
        self.device = language_model.device
        
        self.sil = torch.tensor([[-1]])  # index for silent token in the language model

        # self.states will store {CTCDecoderLMState: (log_probs_vector, kv_cache)}
        self.states = {}
        language_model.eval()

    def start(self, start_with_nothing: bool = False):
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.language_model(self.sil)

        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        outstate = state.child(token_index)
        if outstate not in self.states:
            score = self.language_model(token_index)
            self.states[outstate] = score
        score = self.states[outstate]

        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        return self.score(state, self.sil)

In [None]:
custom_llm = CustomLM(llm)

custom_llm.start(start_with_nothing=True)

IndexError: index out of range in self

In [11]:
import torch
import torch.nn.functional as F
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
from transformers import AutoTokenizer, AutoModelForCausalLM

class LMWrapper(CTCDecoderLM):
    def __init__(self, language_model, tokenizer):
        CTCDecoderLM.__init__(self)
        self.language_model = language_model
        self.tokenizer = tokenizer
        self.device = language_model.device
        
        # Use the actual BOS token ID from the tokenizer
        self.bos_id = tokenizer.bos_token_id
        
        # self.states will store {CTCDecoderLMState: (log_probs_vector, kv_cache)}
        self.states = {}
        language_model.eval()

    def start(self, start_with_nothing: bool = False):
        state = CTCDecoderLMState()
        
        # 1. Create the initial input tensor (BOS token)
        # Shape: (batch_size, sequence_length) -> (1, 1)
        input_ids = torch.tensor([[self.bos_id]], device=self.device)
        
        with torch.no_grad():
            # 2. Run the model
            outputs = self.language_model(input_ids)
            
            # 3. Get logits for the *next* token
            # outputs.logits shape is (1, 1, vocab_size)
            # We want the last token's logits: shape (vocab_size,)
            next_token_logits = outputs.logits[0, -1, :]
            
            # 4. Convert to log-probabilities
            log_probs = F.log_softmax(next_token_logits, dim=-1)
            
            # 5. Get the KV cache for stateful generation
            kv_cache = outputs.past_key_values

        # Store the full distribution and the KV cache
        self.states[state] = (log_probs, kv_cache)
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        # 1. Get the parent state's predictions and KV cache
        # If state (parent) is not in self.states, it's an error (should be handled)
        if state not in self.states:
             # This might happen if start() wasn't called, handle gracefully
             raise ValueError("Parent state not found. Make sure 'start()' is called.")
             
        prev_log_probs, prev_kv_cache = self.states[state]

        # 2. The "score" of this token is its probability from the *previous* step
        # This is what the decoder uses for fusion
        score = prev_log_probs[token_index]

        # 3. Create the new state for this token
        outstate = state.child(token_index)

        # 4. Now, *predict the next token* to populate the new state
        if outstate not in self.states:
            # Create the input for this step: just the new token
            input_ids = torch.tensor([[token_index]], device=self.device)
            
            with torch.no_grad():
                # Run the model using the *previous* KV cache
                outputs = self.language_model(
                    input_ids,
                    past_key_values=prev_kv_cache
                )
                
                # Get the new logits and new KV cache
                next_token_logits = outputs.logits[0, -1, :]
                log_probs = F.log_softmax(next_token_logits, dim=-1)
                kv_cache = outputs.past_key_values
            
            self.states[outstate] = (log_probs, kv_cache)

        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        # On finish, score the End-of-Sequence token
        eos_id = self.tokenizer.eos_token_id
        
        # Get the log-probs from the final state
        if state not in self.states:
             # Handle case where state is empty or invalid
             return state, 0.0 # Return a neutral score
             
        log_probs, _ = self.states[state]
        score = log_probs[eos_id]
        
        return state, score

In [13]:
custom_llm = LMWrapper(llm, tokenizer)

custom_llm.start(start_with_nothing=True)

<torchaudio.models.decoder._ctc_decoder.CTCDecoderLMState at 0x7fca0e83ee10>

In [None]:
# --- Standard Way (using input_ids) ---
input_text = "Hello, world!"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids
# The model will look up the embeddings for these IDs
outputs_standard = llm(input_ids=input_ids)


In [None]:
embedding_layer = llm.get_input_embeddings()
text_embeddings = embedding_layer(input_ids)
outputs_advanced = llm(inputs_embeds=text_embeddings)

In [14]:
text_embeddings.shape

torch.Size([1, 5, 640])

In [None]:
embedding_layer = llm.get_input_embeddings()
print(embedding_layer.embedding_dim)

640


In [None]:
from brainaudio.models.e2e import E2EModel
from brainaudio.inference.inference_utils import load_model


encoder = load_model("/data2/brain2text/b2t_24/outputs/neurips_gru_nonoverlapping_4_4_768_seed_0", "/home3/lionehlhu/brainaudio/src/brainaudio/training/utils/custom_configs/neurips_gru_nonoverlapping_4_4_768_seed_0.yaml", "cuda:1")
model = E2EModel(encoder, 512, llm, tokenizer, 'cuda:1')

Loading custom YAML args from: /home3/lionehlhu/brainaudio/src/brainaudio/training/utils/custom_configs/neurips_gru_nonoverlapping_4_4_768_seed_0.yaml


In [None]:
from brainaudio.datasets.loading_data import getDatasetLoaders

trainLoaders, valLoaders, loadedData = getDatasetLoaders(
        ["/data2/brain2text/b2t_24/brain2text24_with_fa"],
        1, 
        return_alignments=True
    )

[tensor([[[-0.0088, -0.4234, -0.7253,  ...,  0.6295, -0.9338,  0.0572],
         [-0.0088, -0.4234,  1.0446,  ..., -1.2442,  0.2305,  1.3231],
         [-1.0800, -0.4234, -0.7253,  ..., -0.4961, -0.5319,  0.2987],
         ...,
         [-1.0800, -0.4234, -0.7253,  ..., -1.3404, -0.5700,  0.2358],
         [-1.0800,  1.8947, -0.7253,  ..., -1.5471, -0.8085, -0.2412],
         [-1.0800, -0.4234, -0.7253,  ...,  0.2723, -0.9532,  0.3059]]]), tensor([[36, 17,  8, 40, 17, 38, 40,  3, 21,  5,  9, 40, 31, 34, 40,  7, 18, 40,
         28, 17, 35, 25, 20, 31, 40,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  

In [8]:
batch = next(iter(trainLoaders[0]))
print(batch)

[tensor([[[-0.7449, -0.0380, -0.6573,  ..., -0.3152, -1.0477, -0.5842],
         [-0.7449, -0.0380, -0.6573,  ..., -0.8353, -1.0524,  0.5722],
         [-0.7449, -0.0380, -0.6573,  ..., -0.8269, -1.3402,  0.3839],
         ...,
         [-0.7449, -1.0911, -0.6573,  ..., -1.9963, -1.6023,  2.6769],
         [-0.7449, -1.0911, -0.6573,  ..., -0.8191, -1.0094,  1.1178],
         [-0.7449, -1.0911,  1.0947,  ..., -0.6907, -1.7220, -0.9082]]]), tensor([[ 6, 40, 20,  2, 23, 40, 23,  1, 31, 40, 28, 17, 22, 11, 22,  7, 12, 40,
         16, 17, 38, 40, 23, 13, 22, 40, 23,  5, 40,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  