In [1]:
from train_model_single_electrode_new_lin import *

# Runtime configuration: train in float32 and prefer the GPU when one is visible.
dtype = torch.float32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log(f'Using device: {device}')

# Recording used for training: subject 3, trial 1 (previously trial 0).
subject_id = 3
trial_id = 1
window_size = 2048  # samples per training window (see the logged data shape below)
subject = BrainTreebankSubject(subject_id, cache=True)

log(f'Subject: {subject.subject_identifier}, Trial: {trial_id}, loading data...')
# Use every electrode of the subject, and remember where the evaluation electrode sits in that list.
electrode_subset = subject.electrode_labels
eval_electrode_index = electrode_subset.index('T1cIe11')
subject.set_electrode_subset(electrode_subset)
dataset = SubjectTrialDataset_SingleElectrode(
    subject,
    trial_id,
    window_size=window_size,
    dtype=dtype,
    unsqueeze_electrode_dimension=False,
    electrodes_subset=electrode_subset,
)
log("Data shape: " + str(dataset[0]['data'].shape))

# Trial 0 is the evaluation trial used by the next cell; load its neural data up front.
log("Loading the eval trial...")
subject.load_neural_data(0)
log("Done.")

[11:39:15 gpu 0.0G ram 0.5G] (0) Using device: cuda
[11:39:15 gpu 0.0G ram 0.5G] (0) Subject: btbank3, Trial: 1, loading data...
[11:40:26 gpu 0.0G ram 10.5G] (0) Data shape: torch.Size([2048])
[11:40:26 gpu 0.0G ram 10.5G] (0) Loading the eval trial...
[11:41:13 gpu 0.0G ram 17.0G] (0) Done.


In [3]:
# --- Training configuration ---------------------------------------------------------------
d_embed = 128
resolution = 10
n_steps = 1000
batch_size = 128
# Log, evaluate and checkpoint every 100 steps, or every 10% of training for short runs.
# max(1, ...) keeps this positive when n_steps < 10 (otherwise `step % 0` would raise).
log_every_step = max(1, min(100, n_steps // 10))
seed = 42  # seeds `random` (inverter samples) and torch (model init, DataLoader shuffling)

save_dir = "eval_results/juno/"
os.makedirs(save_dir, exist_ok=True)

filename = f'{subject.subject_identifier}_{trial_id}_embed{d_embed}_resolution{resolution}.json'

log(f'Creating models...')
import itertools
import json

# NOTE(review): n_samples_per_bin, n_samples_inverter and mean_collapse_factor are not defined in
# this notebook; presumably they come from the star import in the first cell (or from a deleted
# cell -- the execution counts jump from In [1] to In [3]). Confirm on a fresh kernel.
assert window_size % n_samples_per_bin == 0, (
    f'window_size={window_size} must be divisible by n_samples_per_bin={n_samples_per_bin}')

random.seed(seed)
torch.manual_seed(seed)

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
if len(train_loader) == 0:
    # With drop_last=True a dataset smaller than one batch yields nothing, and training would silently do nothing.
    raise ValueError(f'Dataset has {len(dataset)} windows, fewer than batch_size={batch_size}; '
                     'no full batch can be formed.')


def _reshuffling_batches(loader):
    """Yield batches from `loader` forever, starting a fresh pass over it every epoch.

    itertools.cycle would cache the first epoch and replay the same batches in the same order
    (while holding all of them in memory). Re-iterating the DataLoader instead draws a new
    shuffle order each epoch.
    """
    while True:
        yield from loader


dataloader = _reshuffling_batches(train_loader)

embed = EmbedderDiscretized(d_model=d_embed, resolution=resolution, range=(-3, 3))
unembed = EmbedderDiscretized(d_model=d_embed, resolution=resolution, range=(-3, 3))

model = ContrastiveModel(d_input=n_samples_per_bin, embed=embed, unembed=unembed).to(device, dtype=dtype)
masker = NoneMasker()

# Fit the distribution inverter on the flattened data of `n_samples_inverter` randomly chosen windows.
samples = torch.cat([dataset[random.randint(0, len(dataset)-1)]['data'].flatten() for _ in range(n_samples_inverter)])
inverter = DistributionInverter(samples=samples).to(device, dtype=dtype)

# Evaluate on trial 0 (loaded in the first cell) for the "speech" and "gpt2_surprisal" tasks.
evaluation = ModelEvaluation_BTBench(model, inverter, [(subject, 0)], ["speech", "gpt2_surprisal"], feature_aggregation_method='concat',
                                        mean_collapse_factor=mean_collapse_factor, eval_electrode_index=eval_electrode_index)

log(f'Training model...')
initial_lr = 0.003
use_muon = True
optimizers = []
schedulers = []  # LR schedules are currently disabled (see commented lines); the loop below tolerates an empty list
if use_muon:
    from muon import Muon
    # Muon handles the matrix-shaped parameters; biases/gains (ndim < 2) go to AdamW.
    all_params = list(model.parameters())
    matrix_params = [p for p in all_params if p.ndim >= 2]
    other_params = [p for p in all_params if p.ndim < 2]
    optimizers.append(Muon(matrix_params, lr=initial_lr, momentum=0.95, nesterov=True, backend='newtonschulz5', backend_steps=5))
    if len(other_params) > 0:
        optimizers.append(torch.optim.AdamW(other_params, lr=initial_lr, betas=(0.9, 0.95)))
    #schedulers.append(None)  # Muon doesn't support schedulers
    #schedulers.append(torch.optim.lr_scheduler.LinearLR(optimizers[1], start_factor=1.0, end_factor=0.0, total_iters=n_steps))
else:
    optimizers = [torch.optim.AdamW(model.parameters(), lr=initial_lr, betas=(0.9, 0.95))]
    #schedulers = [torch.optim.lr_scheduler.LinearLR(optimizers[0], start_factor=1.0, end_factor=0.0, total_iters=n_steps)]

log("Evaluating the model before training...")
evaluation_results = evaluation.evaluate()
log(evaluation_results, indent=2)
evaluation_results['step'] = 0
evaluation_results['train_loss'] = -1  # sentinel: there is no training loss before step 1
training_logs = [evaluation_results]
save_path = os.path.join(save_dir, filename)

# NOTE(review): evaluation.evaluate() runs between training steps; if it switches the model to
# eval mode, confirm it restores train mode afterwards.
for step in range(1, n_steps + 1):
    batch = next(dataloader)
    for optimizer in optimizers:
        optimizer.zero_grad()

    # (batch_size, window_size) -> (batch_size, window_size // n_samples_per_bin, n_samples_per_bin)
    batch_data = batch['data'].to(device, dtype=dtype).reshape(batch_size, window_size//n_samples_per_bin, n_samples_per_bin)
    batch_data = inverter(batch_data)
    masked_x, mask = masker.forward(batch_data)

    loss = model.calculate_loss(masked_x.unsqueeze(-2), mask=mask)
    loss.backward()
    for optimizer in optimizers:
        optimizer.step()

    # Step the schedulers (no-op while `schedulers` is empty)
    for scheduler in schedulers:
        if scheduler is not None:
            scheduler.step()

    loss_value = loss.item()
    log_dict = {
        'train_loss': loss_value,
        'step': step,
    }

    if step % log_every_step == 0:
        current_lr = optimizers[-1].param_groups[0]['lr']
        log(f"Step {step}, Loss: {loss_value:.4f}, LR: {current_lr:.6f}")

        # Attach periodic evaluation results to this step's log entry
        evaluation_results = evaluation.evaluate()
        log_dict.update(evaluation_results)
        log(log_dict, indent=2)

    training_logs.append(log_dict)

    # Checkpoint the full log history periodically and at the final step
    if step % log_every_step == 0 or step == n_steps:
        with open(save_path, 'w') as f:
            json.dump(training_logs, f, indent=2)
        log(f"Saved training logs to {save_path}")

[11:49:37 gpu 10.3G ram 18.8G] (0) Creating models...
[11:49:40 gpu 10.3G ram 18.8G] (0) Training model...
[11:49:40 gpu 10.3G ram 18.8G] (0) Evaluating the model before training...
[11:51:08 gpu 10.3G ram 18.8G] (0)         {'eval_auroc/btbank3_0_speech': 0.7899697076353438, 'eval_acc/btbank3_0_speech': 0.7166997847325716, 'eval_auroc/average_speech': 0.7899697076353438, 'eval_acc/average_speech': 0.7166997847325716, 'eval_auroc/btbank3_0_gpt2_surprisal': 0.6275235639695034, 'eval_acc/btbank3_0_gpt2_surprisal': 0.5971399671479208, 'eval_auroc/average_gpt2_surprisal': 0.6275235639695034, 'eval_acc/average_gpt2_surprisal': 0.5971399671479208}
[11:52:07 gpu 15.1G ram 18.8G] (0) Step 100, Loss: 3.4815, LR: 0.003000
[11:53:31 gpu 15.1G ram 18.8G] (0)         {'train_loss': 3.4815049171447754, 'step': 100, 'eval_auroc/btbank3_0_speech': 0.7930816221395186, 'eval_acc/btbank3_0_speech': 0.7189699644617678, 'eval_auroc/average_speech': 0.7930816221395186, 'eval_acc/average_speech': 0.718969964