In [None]:
import pandas as pd

In [None]:
# Load the pickled dataset; later cells index it by split name ('train') and then by integer position.
# NOTE: unpickling can execute arbitrary code, so only load files that come from a trusted source.
char_dataset = pd.read_pickle("/data2/brain2text/b2t_24/brain2text24_with_fa_char")

In [None]:
# Inspect which fields the first 'train' entry exposes.
char_dataset['train'][0].keys()

In [None]:
# Id sequence for item 30 of 'train' entry 10; the decoding loop further down iterates over it.
# NOTE(review): presumably these are integer indices into character_units — confirm.
characters = char_dataset['train'][10]["textC"][30]

# Keep the matching transcription as the LAST expression of the cell: Jupyter only
# renders the final expression, so placing it before the assignment displayed nothing.
char_dataset['train'][10]["transcriptions"][30]

In [None]:
# Character inventory: the space token first, then five punctuation marks
# (apostrophe included), then the lowercase letters 'a'..'z'.
CHAR_VOCAB = ["<sp>", "!", ",", ".", "?", "'"] + list(map(chr, range(ord('a'), ord('z') + 1)))

# Forward and reverse lookup tables between characters and their vocabulary positions.
_CHAR_TO_ID = dict(zip(CHAR_VOCAB, range(len(CHAR_VOCAB))))
_ID_TO_CHAR = dict(enumerate(CHAR_VOCAB))

# Position of the space token in CHAR_VOCAB.
SPACE_ID = _CHAR_TO_ID["<sp>"]

print(CHAR_VOCAB)


# Display units: "-" occupies slot 0, followed by the vocabulary with "<sp>" rendered as "|".
# Every CHAR_VOCAB entry therefore sits one position later here than in _CHAR_TO_ID.
character_units = ["-"] + ["|" if token == "<sp>" else token for token in CHAR_VOCAB]
        
# Decode the id sequence into readable text: id 1 becomes a literal space,
# ids above 1 are looked up in character_units, and ids below 1
# (such as the "-" slot at index 0) are dropped.
string = "".join(
    " " if unit_id == 1 else character_units[unit_id]
    for unit_id in characters
    if unit_id >= 1
)
        

In [None]:
import pickle
# Pickle file holding the results for the "transformer_short_training_fixed" run.
file_path = "/data2/brain2text/b2t_24/saved_val_results/transformer_short_training_fixed.pkl"

# NOTE: pickle.load can execute arbitrary code; only unpickle files from a trusted source.
with open(file_path, mode="rb") as results_file:
    data = pickle.load(results_file)



In [None]:
# Show the entry stored under the run name (the same stem as the pickle file name in file_path).
data['transformer_short_training_fixed']

In [None]:
import torch
from brainaudio.models.transformer_chunking import TransformerModel
import yaml

# Measure peak GPU memory for one full training step (forward, backward, optimizer step)
# of the Transformer described by the YAML training config.
config_path = "tm_transformer_b2t24_log_dynchunk.yaml"
config_file = f"../src/brainaudio/training/utils/custom_configs/{config_path}"

with open(config_file, 'r') as f:
    config = yaml.safe_load(f)

# The config nests per-architecture arguments under config['model'][<modelType>].
model_type = config['modelType']
model_args = config['model'][model_type]

model = TransformerModel(
    features_list=model_args['features_list'],
    samples_per_patch=model_args['samples_per_patch'],
    dim=model_args['d_model'],
    depth=model_args['depth'],
    heads=model_args['n_heads'],
    mlp_dim_ratio=model_args['mlp_dim_ratio'],
    dim_head=model_args['dim_head'],
    dropout=config['dropout'],
    input_dropout=config['input_dropout'],
    nClasses=config['nClasses'],
    max_mask_pct=config['max_mask_pct'],
    num_masks=config['num_masks'],
    num_participants=len(model_args['features_list']),
    return_final_layer=False,
    chunked_attention=model_args["chunked_attention"],
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The device selection above falls back to CPU, so everything CUDA-specific below
# (fused optimizer kernels, CUDA memory statistics) is gated on this flag instead of
# being called unconditionally.
use_cuda = device.type == "cuda"
model.to(device)

# --- 1. Initialize the Optimizer FIRST ---
# Creating it before the measurement means its state allocation (made on the first
# optimizer.step()) is counted in the peak.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config['learning_rate'],
    weight_decay=config['l2_decay'],
    eps=config['eps'],
    betas=(config['beta1'], config['beta2']),
    fused=use_cuda,  # only request the fused implementation when running on CUDA
)

# --- 2. Check PEAK memory during a forward/backward/step pass ---
if use_cuda:
    torch.cuda.reset_peak_memory_stats()

# --- Create plausible dummy data for your Transformer ---
batch_size = 64
seq_len = 1500
d_model = model_args['d_model']

# NOTE(review): the trailing feature size of 256 is hard-coded; presumably it must match
# the input width the model expects — confirm against the config.
dummy_input = torch.randn(batch_size, seq_len, 256, device=device)
dummy_X_len = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
dummy_participant_idx = 0
dummy_day_idx = 1

# --- Simulate a FULL training step ---
optimizer.zero_grad()  # Clear old gradients
output = model(dummy_input, dummy_X_len, dummy_participant_idx, dummy_day_idx)
loss = output.sum()  # Dummy loss
loss.backward()  # Calculate gradients
optimizer.step()  # Update weights AND allocate optimizer state
# ---------------------------------

# Report the peak; without CUDA there is no VRAM figure to read, so say so explicitly
# rather than crashing or printing a misleading number.
if use_cuda:
    peak_memory_bytes = torch.cuda.max_memory_allocated()
    peak_memory_mb = peak_memory_bytes / (1024 * 1024)
    print(f"Peak VRAM during a full training step: {peak_memory_mb:.2f} MB")
else:
    peak_memory_bytes = None
    peak_memory_mb = None
    print("CUDA is not available; skipping peak VRAM measurement.")

Peak VRAM during a full training step: 5250.92 MB
