#### train 4 encoding models:
#### 1) poisson GLM trained from scratch
#### 2) frozen pretrained CNN with poisson GLM readout
#### 3) unfrozen pretrained CNN with poisson GLM readout
#### 4) CNN trained from scratch with poisson GLM readout

In [None]:
def train_model(full_model, config):
    """Train ``full_model`` on each session in ``session_ids`` with a Poisson loss.

    For every session the function builds the datasets/loaders, creates a
    fresh Adam optimizer, trains for ``config["num_epochs"]`` epochs and, after
    each epoch, evaluates ``corr_to_avg`` and the mean test loss on the test
    loader. The per-neuron correlations of the epoch with the highest
    NaN-ignoring mean correlation are kept for that session.

    Args:
        full_model: model exposing ``named_parameters()``, ``state_dict()`` and
            a forward call on a stimulus batch.
        config: dict of settings. Keys read here: ``"wandb"``, ``"modality"``,
            ``"exp_var_thresholds"`` (indexed by session position),
            ``"stim_size"``, ``"win_size"``, ``"batch_size"``,
            ``"first_frame_only"``, ``"blur_sigma"``, ``"pos"``,
            ``"l2_weight"``, ``"lr"``, ``"num_epochs"`` and ``"save"``.
            ``config["session_id"]`` is overwritten in place for each session.

    Returns:
        None. Results are printed, optionally logged to wandb, and optionally
        saved as state dicts at ``f"{model_output_path}_{session_id}.pickle"``.

    Relies on module-level names: ``session_ids``, ``input_dir``,
    ``stim_dur_ms``, ``stimulus_dir``, ``device``, ``model_output_path``,
    ``get_datasets_and_loaders`` and ``corr_to_avg``.
    """
    # NOTE(review): the same ``full_model`` instance is trained on every
    # session in turn without re-initialisation, so weights carry over from
    # one session to the next -- confirm this is intended.
    corr_avgs = []

    for ses_idx, session_id in enumerate(session_ids):
        # Best mean correlation seen so far for this session and the
        # per-neuron correlations that produced it.
        sess_corr_avg = -1
        sess_corrs = []

        # Record the current session in the config (it is passed to wandb).
        config["session_id"] = session_id

        # One wandb run per session.
        if config["wandb"]:
            wandb.init(
                project=f'{config["modality"]}-baseline',
                config=config,
            )
            wandb.define_metric("corr_to_avg", summary="max")
            wandb.define_metric("test_loss", summary="min")

        # Load datasets and loaders for this session.
        train_dataset, test_dataset, train_loader, test_loader = get_datasets_and_loaders(
            input_dir,
            session_id,
            config["modality"],
            config["exp_var_thresholds"][ses_idx],
            stim_dur_ms,
            config["stim_size"],
            config["win_size"],
            stimulus_dir,
            config["batch_size"],
            config["first_frame_only"],
            blur_sigma=config["blur_sigma"],
            pos=config['pos'],
        )

        # Parameters whose name contains 'mu' or 'sigma' are excluded from L2
        # regularization; everything else gets weight decay.
        params_with_l2 = []
        params_without_l2 = []
        for name, param in full_model.named_parameters():
            if 'mu' in name or 'sigma' in name:
                params_without_l2.append(param)
            else:
                params_with_l2.append(param)

        # Fresh Adam optimizer per session. Each group sets its own
        # weight_decay, so no optimizer-level default is needed.
        optimizer = torch.optim.Adam(
            [
                {'params': params_with_l2, 'weight_decay': config['l2_weight']},
                {'params': params_without_l2, 'weight_decay': 0.0},
            ],
            lr=config["lr"],
        )

        # Poisson negative log-likelihood on rate-valued (non-log) predictions.
        loss_func = PoissonNLLLoss(log_input=False, full=True)

        for epoch in range(config["num_epochs"]):
            epoch_loss = 0
            for stimulus, targets in train_loader:
                stimulus = stimulus.to(device)
                targets = targets.to(device)

                optimizer.zero_grad()
                preds = full_model(stimulus)
                loss = loss_func(preds, targets)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            # Evaluate on the test loader without tracking gradients.
            # NOTE(review): the model's train/eval mode is never switched in
            # this function; if the network contains dropout or batch-norm
            # layers, confirm whether ``corr_to_avg`` handles this.
            with torch.no_grad():
                corr_avg = corr_to_avg(full_model, test_loader, modality=config["modality"], device=device)
                test_loss = 0
                for stimulus, targets in test_loader:
                    stimulus = stimulus.to(device)
                    targets = targets.to(device)
                    preds = full_model(stimulus)
                    test_loss += loss_func(preds, targets).item()

            # The loss is accumulated once per batch, so normalise by the
            # number of batches for both the logged and the printed value.
            mean_corr = np.nanmean(corr_avg)
            mean_train_loss = epoch_loss / len(train_loader)
            mean_test_loss = test_loss / len(test_loader)

            if config["wandb"]:
                wandb.log({"corr_to_avg": mean_corr, "train_loss": mean_train_loss, "test_loss": mean_test_loss})

            # Keep the per-neuron correlations of the best epoch so far.
            if mean_corr > sess_corr_avg:
                sess_corr_avg = mean_corr
                sess_corrs = corr_avg

            print('  epoch {} loss: {} corr: {}'.format(epoch + 1, mean_train_loss, mean_corr))
            print(f' num. neurons : {len(corr_avg)}')

        # Saves the weights after the final epoch (not the best epoch).
        if config["save"]:
            torch.save(full_model.state_dict(), f"{model_output_path}_{session_id}.pickle")

        corr_avgs.append(sess_corrs)

        if config["wandb"]:
            wandb.finish()

    # Extra summary run: log every best-epoch per-neuron correlation, across
    # all sessions, into the same project as the per-session runs.
    if config["wandb"]:
        wandb.init(
            project=f'{config["modality"]}-baseline',
            config=config,
        )
        for sess_corr in corr_avgs:
            for corr in sess_corr:
                wandb.log({"corr": corr})
        wandb.finish()