In [7]:
import torch
import os
import json
import numpy as np
from train_utils import get_default_configs, unconvert_dtypes, log, update_dir_name
from model_model import GranularModel, BinTransformer, CrossModel
from model_electrode_embedding import ElectrodeEmbedding_Learned, ElectrodeEmbedding_NoisyCoordinate, ElectrodeEmbedding_Learned_CoordinateInit

# Get default configs; they are used to locate the checkpoint directory and are then
# overwritten by the configs stored in the checkpoint (see the merge below).
training_config, model_config, cluster_config = get_default_configs(random_string="TEMP", wandb_project="")

# Machine-specific runtime settings for this notebook. They are re-applied after the merge
# below: if the saved cluster config contains these keys, `cluster_config.update(...)` would
# otherwise silently replace them with the values used at training time.
LOCAL_CLUSTER_OVERRIDES = {
    'num_workers_dataloaders': 2,
    'num_workers_eval': 2,
}
cluster_config.update(LOCAL_CLUSTER_OVERRIDES)
model_config['init_identity'] = False

dir_name = update_dir_name(model_config, training_config, cluster_config)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
log(f"Using device: {device}", priority=0)

# Specify the model directory and epoch to load
model_dir = f"models_data/{dir_name}"  # Replace with your actual model directory name
epoch_to_load = 20

# Load the saved model checkpoint
checkpoint_path = f"{model_dir}/model_epoch_{epoch_to_load}.pth"
log(f"Loading model from {checkpoint_path}", priority=0)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -- only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Configs saved alongside the weights
saved_training_config = unconvert_dtypes(checkpoint['training_config'])
saved_model_config = unconvert_dtypes(checkpoint['model_config'])
saved_cluster_config = unconvert_dtypes(checkpoint['cluster_config'])

# Merge configs: saved values take precedence over the defaults, then the local runtime overrides are restored
training_config.update(saved_training_config)
model_config.update(saved_model_config)
cluster_config.update(saved_cluster_config)
cluster_config.update(LOCAL_CLUSTER_OVERRIDES)

# Values shared by several components
SAMPLING_RATE = 2048  # overall sampling rate: samples per unit of `sample_timebin_size`
samples_per_timebin = int(model_config['sample_timebin_size'] * SAMPLING_RATE)
transformer_d_model = model_config['transformer']['d_model']
model_dtype = model_config['dtype']

# Initialize model components
# NOTE(review): the bin transformer's d_model / n_layers / n_heads are hard-coded rather than read
# from model_config; they must match the shapes stored in the checkpoint.
bin_transformer = BinTransformer(
    first_kernel=samples_per_timebin // 16,
    d_model=192,
    n_layers=2,
    n_heads=4,
    overall_sampling_rate=SAMPLING_RATE,
    sample_timebin_size=model_config['sample_timebin_size'],
    n_downsample_factor=1,
    identity_init=model_config['init_identity']
).to(device, dtype=model_dtype)

model = GranularModel(
    samples_per_timebin,
    transformer_d_model,
    n_layers=model_config['transformer']['n_layers_time'],
    n_heads=model_config['transformer']['n_heads'],
    identity_init=model_config['init_identity']
).to(device, dtype=model_dtype)

cross_model = CrossModel(
    samples_per_timebin,
    transformer_d_model,
    n_layers=model_config['transformer']['n_layers_electrode'],
    n_heads=model_config['transformer']['n_heads'],
).to(device, dtype=model_dtype)

# Initialize electrode embeddings based on the type
embedding_config = model_config['electrode_embedding']
embedding_type = embedding_config['type']
if embedding_type in ('learned', 'zero'):
    electrode_embeddings = ElectrodeEmbedding_Learned(
        transformer_d_model,
        embedding_dim=embedding_config['embedding_dim'],
        embedding_requires_grad=embedding_type != 'zero'  # gradients are disabled for the 'zero' type
    )
elif embedding_type == 'coordinate_init':
    electrode_embeddings = ElectrodeEmbedding_Learned_CoordinateInit(
        transformer_d_model,
        embedding_dim=embedding_config['embedding_dim']
    )
elif embedding_type == 'noisy_coordinate':
    electrode_embeddings = ElectrodeEmbedding_NoisyCoordinate(
        transformer_d_model,
        coordinate_noise_std=embedding_config['coordinate_noise_std'],
        embedding_dim=embedding_config['embedding_dim']
    )
else:
    raise ValueError(f"Invalid electrode embedding type: {embedding_type}")
electrode_embeddings = electrode_embeddings.to(device, dtype=model_dtype)

# Load state dictionaries from checkpoint
model.load_state_dict(checkpoint['model_state_dict'])
bin_transformer.load_state_dict(checkpoint['bin_transformer_state_dict'])
cross_model.load_state_dict(checkpoint['cross_model_state_dict'])
electrode_embeddings.load_state_dict(checkpoint['electrode_embeddings_state_dict'])

# Set models to evaluation mode
model.eval()
bin_transformer.eval()
cross_model.eval()
electrode_embeddings.eval()


def _count_params(*modules):
    """Return the total number of parameters across the given modules."""
    return sum(p.numel() for m in modules for p in m.parameters())


# Print model information
n_model_params = _count_params(model, bin_transformer, cross_model)
n_embed_params = _count_params(electrode_embeddings)
log(f"Model parameters: {n_model_params:,}", priority=0)
log(f"Embedding parameters: {n_embed_params:,}", priority=0)
log(f"Total parameters: {n_model_params + n_embed_params:,}", priority=0)

# Print evaluation results stored in the checkpoint (older checkpoints may not have any)
log(f"Evaluation results at epoch {epoch_to_load}:", priority=0)
for key, value in checkpoint.get('eval_results', {}).items():
    if isinstance(value, (int, float)):
        log(f"  {key}: {value:.4f}", priority=0)
    else:
        log(f"  {key}: {value}", priority=0)

# Now the model is loaded and ready to use
log(f"Model successfully loaded from epoch {epoch_to_load}", priority=0)


[21:13:28 gpu 16.2G ram 2.6G] Using device: cuda
[21:13:28 gpu 16.2G ram 2.6G] Loading model from models_data/M_nst1_dm192_nh12_nl5_5_nes50_nf_nII_nSP_mxt8_eeL_fb1_cls_lr0.003_rTEMP_lrL_ws100/model_epoch_20.pth
[21:13:30 gpu 16.2G ram 2.6G] Model parameters: 5,576,962
[21:13:30 gpu 16.2G ram 2.6G] Embedding parameters: 19,200
[21:13:30 gpu 16.2G ram 2.6G] Total parameters: 5,596,162
[21:13:30 gpu 16.2G ram 2.6G] Evaluation results at epoch 20:
[21:13:30 gpu 16.2G ram 2.6G]   train_contrastive_x: 0.0084
[21:13:30 gpu 16.2G ram 2.6G]   train_accuracy_x: 0.0081
[21:13:30 gpu 16.2G ram 2.6G]   test_contrastive_x: 2.6661
[21:13:30 gpu 16.2G ram 2.6G]   test_accuracy_x: 0.4271
[21:13:30 gpu 16.2G ram 2.6G]   eval_auroc/average_gpt2_surprisal: 0.5774
[21:13:30 gpu 16.2G ram 2.6G]   eval_auroc/average_volume: 0.6154
[21:13:30 gpu 16.2G ram 2.6G]   eval_auroc/average_word_part_speech: 0.5197
[21:13:30 gpu 16.2G ram 2.6G]   eval_auroc/average_pitch: 0.5364
[21:13:30 gpu 16.2G ram 2.6G]   eval_au

In [None]:
from dataset import load_subjects
from evaluation_btbench import FrozenModelEvaluation_SS_SM
# Load subjects for evaluation
log("Loading subjects for evaluation...", priority=0)
all_subjects = load_subjects(training_config['train_subject_trials'], training_config['eval_subject_trials'], training_config['data_dtype'],
                             cache=cluster_config['cache_subjects'], allow_corrupted=False)

# Set electrode subset for each subject (temporal lobe electrodes):
# BTBench-lite electrodes whose label starts with 'T'
from btbench.btbench_config import BTBENCH_LITE_ELECTRODES
for subject_identifier, subject in all_subjects.items():
    electrode_subset = [electrode_label for electrode_label in BTBENCH_LITE_ELECTRODES[subject_identifier]
                        if electrode_label.startswith('T')]
    subject.set_electrode_subset(electrode_subset)
    log(f"Subject {subject_identifier} has {len(electrode_subset)} temporal lobe electrodes", priority=0)

# Register subjects with the electrode embeddings (presumably populates `embeddings_map`,
# which the evaluation reads), then move the embeddings to the model's device/dtype.
# NOTE(review): the embedding weights were loaded from the checkpoint in the previous cell; this
# assumes add_subject assigns the same indices (same subject order / electrode subset) as in training.
for subject in all_subjects.values():
    log(f"Adding subject {subject.subject_identifier} to electrode embeddings...", priority=0)
    electrode_embeddings.add_subject(subject)
electrode_embeddings = electrode_embeddings.to(device, dtype=model_config['dtype'])

# Define evaluation electrode subset if needed
eval_electrode_subset = {
    #'btbank3': ['T1cIe11'],
}

# Create evaluation object
eval_subject_trials = [(all_subjects[subject_identifier], trial_id) for subject_identifier, trial_id in training_config['eval_subject_trials']]

[21:07:53 gpu 0.0G ram 0.7G] Loading subjects for evaluation...
[21:07:53 gpu 0.0G ram 0.7G]     loading subject btbank3...
[21:07:53 gpu 0.0G ram 0.7G] Subject btbank3 has 21 temporal lobe electrodes
[21:07:53 gpu 0.0G ram 0.7G] Adding subject btbank3 to electrode embeddings...
[21:08:03 gpu 0.0G ram 1.5G] Running evaluation on loaded model...
[21:08:03 gpu 0.0G ram 1.5G]     evaluating on all metrics




[21:10:46 gpu 16.2G ram 2.6G]     done evaluating on all metrics
{'eval_auroc/average_gpt2_surprisal': 0.6562908496732025, 'eval_auroc/average_volume': 0.6157898224224754, 'eval_auroc/average_word_part_speech': 0.5658141739308217, 'eval_auroc/average_pitch': 0.5683403247806058, 'eval_auroc/average_speech': 0.8377959183673469}


In [11]:
from torch.utils.data import DataLoader
import sklearn.metrics
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
import numpy as np
from btbench.btbench_train_test_splits import generate_splits_SS_SM
from btbench.btbench_config import BTBENCH_LITE_ELECTRODES
from train_utils import log
import torch

# Evaluation class for Same Subject Same Movie (SS-SM), on btbench evals
class FrozenModelEvaluation_SS_SM():
    """Linear-probe evaluation of frozen model features on BTBench Same-Subject Same-Movie splits.

    For every (eval_name, subject, trial_id) combination, features are extracted from the frozen
    model, a logistic regression is fit on each train split and scored (AUROC, accuracy) on the
    matching test split.
    """

    def __init__(self, eval_names, subject_trials, dtype, batch_size, embeddings_map,
                 num_workers_eval=4, prefetch_factor=2,
                 feature_aggregation_method='concat', # 'mean', 'concat'
                 # regression parameters
                 regression_random_state=42,  regression_solver='lbfgs', 
                 regression_tol=1e-3,
                 regression_max_iter=10000,
                 lite=True, electrode_subset=None):
        """
        Args:
            eval_names (list): Evaluation task names (e.g. ["volume", "word_gap"]).
            subject_trials (list): Tuples (subject, trial_id); subject is a BrainTreebankSubject
                and trial_id an integer.
            dtype (torch.dtype): Data type passed to the split generator.
            batch_size (int): Batch size used when extracting frozen features.
            embeddings_map (dict): Maps (subject_identifier, electrode_label) to an electrode
                embedding index. Electrodes missing from the map are skipped.
            num_workers_eval (int): DataLoader workers; also used as `n_jobs` of the regression.
            prefetch_factor (int): DataLoader prefetch factor (only used when num_workers_eval > 0).
            feature_aggregation_method (str): 'mean' or 'concat', forwarded to
                `model.generate_frozen_evaluation_features`.
            regression_random_state, regression_solver, regression_tol, regression_max_iter:
                Forwarded to sklearn's LogisticRegression.
            lite (bool): Use the BTBench-lite electrode set and splits.
            electrode_subset (dict or None): Optional {subject_identifier: [electrode_label, ...]}
                restricting which electrodes are fed to the model during evaluation.
        """
        self.eval_names = eval_names
        self.subject_trials = subject_trials
        self.all_subjects = set([subject for subject, trial_id in self.subject_trials])
        self.all_subject_identifiers = set([subject.subject_identifier for subject in self.all_subjects])
        self.dtype = dtype
        self.batch_size = batch_size
        self.lite = lite
        
        self.feature_aggregation_method = feature_aggregation_method

        self.regression_max_iter = regression_max_iter
        self.regression_random_state = regression_random_state
        self.regression_solver = regression_solver
        self.regression_tol = regression_tol
        self.num_workers_eval = num_workers_eval
        self.prefetch_factor = prefetch_factor

        # (eval_name, subject_identifier, trial_id) -> splits from generate_splits_SS_SM
        self.evaluation_datasets = {}
        for eval_name in self.eval_names:
            for subject, trial_id in self.subject_trials:
                splits = generate_splits_SS_SM(subject, trial_id, eval_name, dtype=self.dtype, lite=self.lite, start_neural_data_before_word_onset=0, end_neural_data_after_word_onset=2048)
                self.evaluation_datasets[(eval_name, subject.subject_identifier, trial_id)] = splits

        # subject_identifier -> tensor of embedding indices, one per electrode known to the embeddings
        self.all_subject_electrode_indices = {}
        for subject in self.all_subjects:
            electrode_labels = BTBENCH_LITE_ELECTRODES[subject.subject_identifier] if self.lite else subject.get_electrode_labels()
            indices = []
            for electrode_label in electrode_labels:
                key = (subject.subject_identifier, electrode_label)
                if key in embeddings_map: # If the electrodes were subset to exclude this one, ignore it
                    indices.append(embeddings_map[key])
            self.all_subject_electrode_indices[subject.subject_identifier] = torch.tensor(indices)

        # Optional subject_identifier -> positions within all_subject_electrode_indices[subject_identifier],
        # used to slice the electrode axis of the input and of the embedding indices.
        self.all_subject_electrode_subset_indices = None
        if electrode_subset is not None:
            self.all_subject_electrode_subset_indices = {}
            for subject in self.all_subjects:
                subject_identifier = subject.subject_identifier
                if subject_identifier not in electrode_subset:
                    continue
                embedding_indices = self.all_subject_electrode_indices[subject_identifier].tolist()
                positions = []
                for electrode_label in electrode_subset[subject_identifier]:
                    key = (subject_identifier, electrode_label)
                    if key in embeddings_map: # If the electrodes were subset to exclude this one, ignore it
                        positions.append(embedding_indices.index(embeddings_map[key]))
                self.all_subject_electrode_subset_indices[subject_identifier] = torch.tensor(positions)

    def _make_dataloader(self, dataset):
        """Build a non-shuffling DataLoader; `prefetch_factor` is only valid with worker processes."""
        loader_kwargs = dict(batch_size=self.batch_size, shuffle=False,
                             num_workers=self.num_workers_eval, pin_memory=True)
        if self.num_workers_eval > 0:
            loader_kwargs['prefetch_factor'] = self.prefetch_factor
        return DataLoader(dataset, **loader_kwargs)

    def _generate_frozen_features(self, model, bin_transformer, electrode_embeddings, subject_identifier, dataloader, log_priority=0):
        """Run the frozen model over `dataloader`; return per-batch lists (features, labels) as numpy arrays."""
        device, dtype = model.device, model.dtype
        subset_indices = None
        if self.all_subject_electrode_subset_indices is not None:
            subset_indices = self.all_subject_electrode_subset_indices.get(subject_identifier)
        # The embedding indices are the same for every batch of this subject
        subject_electrode_indices = self.all_subject_electrode_indices[subject_identifier].to(device, dtype=torch.long, non_blocking=True)

        features_list, labels_list = [], []
        n_batches = len(dataloader)
        # Features only feed a downstream linear probe: no autograd graph is needed
        with torch.no_grad():
            for i, (batch_input, batch_label) in enumerate(dataloader):
                log(f'generating frozen features for batch {i} of {n_batches}', priority=log_priority, indent=3)
                batch_input = batch_input.to(device, dtype=dtype, non_blocking=True) # shape (batch_size, n_electrodes, n_samples)
                if subset_indices is not None:
                    batch_input = batch_input[:, subset_indices, :]

                # Per-batch normalization: statistics over batch and time, separately per electrode.
                # NOTE(review): the denominator uses `+ 1` rather than a small epsilon.
                batch_input = batch_input - torch.mean(batch_input, dim=[0, 2], keepdim=True)
                batch_input = batch_input / (torch.std(batch_input, dim=[0, 2], keepdim=True) + 1)
                electrode_data = bin_transformer(batch_input)

                # Add the batch dimension to the electrode indices
                electrode_indices = subject_electrode_indices.unsqueeze(0).expand(batch_input.shape[0], -1)
                if subset_indices is not None:
                    electrode_indices = electrode_indices[:, subset_indices]
                embeddings = electrode_embeddings.forward(electrode_indices)

                features = model.generate_frozen_evaluation_features(electrode_data, embeddings, feature_aggregation_method=self.feature_aggregation_method)
                log(f'done generating frozen features for batch {i} of {n_batches}', priority=log_priority, indent=3)
                features_list.append(features.detach().cpu().float().numpy())
                labels_list.append(batch_label.numpy())
        return features_list, labels_list

    def _evaluate_on_dataset(self, model, bin_transformer, electrode_embeddings, subject, train_dataset, test_dataset, log_priority=0):
        """Fit a logistic-regression probe on frozen train features and score it on the test split.

        Returns:
            (auroc, accuracy): AUROC over the classes seen in training (test samples of unseen
            classes are excluded; macro one-vs-rest when there are more than two classes), and
            accuracy on the full test split.
        """
        subject_identifier = subject.subject_identifier

        log('generating frozen train features', priority=log_priority, indent=2)
        X_train, y_train = self._generate_frozen_features(model, bin_transformer, electrode_embeddings, subject_identifier,
                                                          self._make_dataloader(train_dataset), log_priority=log_priority)
        log('generating frozen test features', priority=log_priority, indent=2)
        X_test, y_test = self._generate_frozen_features(model, bin_transformer, electrode_embeddings, subject_identifier,
                                                        self._make_dataloader(test_dataset), log_priority=log_priority)
        log('done generating frozen features', priority=log_priority, indent=2)

        log("creating numpy arrays", priority=log_priority, indent=2)
        X_train = np.concatenate(X_train)
        y_train = np.concatenate(y_train)
        X_test = np.concatenate(X_test)
        y_test = np.concatenate(y_test)
        log("done creating numpy arrays", priority=log_priority, indent=2)

        regressor = LogisticRegression(
            random_state=self.regression_random_state, 
            max_iter=self.regression_max_iter, 
            n_jobs=self.num_workers_eval, 
            solver=self.regression_solver, 
            tol=self.regression_tol
        )

        # Standardize the features (statistics from the train split only)
        log('standardizing features', priority=log_priority, indent=2)
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        log('fitting regressor', priority=log_priority, indent=2)
        regressor.fit(X_train, y_train)
        log('done fitting regressor', priority=log_priority, indent=2)

        test_probs = regressor.predict_proba(X_test)

        # Filter test samples to only include classes that were in training
        valid_class_mask = np.isin(y_test, regressor.classes_)
        y_test_filtered = y_test[valid_class_mask]
        test_probs_filtered = test_probs[valid_class_mask]

        # One-hot encode against regressor.classes_, whose order matches predict_proba's columns
        y_test_onehot = (y_test_filtered[:, None] == regressor.classes_[None, :]).astype(np.float64)

        # Calculate ROC AUC based on number of classes
        n_classes = len(regressor.classes_)
        if n_classes > 2:
            auroc = sklearn.metrics.roc_auc_score(y_test_onehot, test_probs_filtered, multi_class='ovr', average='macro')
        else:
            auroc = sklearn.metrics.roc_auc_score(y_test_onehot, test_probs_filtered)

        accuracy = regressor.score(X_test, y_test)
        log('done evaluating', priority=log_priority, indent=2)
        return auroc, accuracy

    def _evaluate_on_metric_cv(self, model, bin_transformer, electrode_embeddings, subject, train_datasets, test_datasets, log_priority=0, quick_eval=False):
        """Average (auroc, accuracy) over paired train/test folds; with quick_eval only the first fold is used."""
        auroc_list, accuracy_list = [], []
        for train_dataset, test_dataset in zip(train_datasets, test_datasets):
            auroc, accuracy = self._evaluate_on_dataset(model, bin_transformer, electrode_embeddings, subject, train_dataset, test_dataset, log_priority=log_priority)
            auroc_list.append(auroc)
            accuracy_list.append(accuracy)
            if quick_eval: break
        return np.mean(auroc_list), np.mean(accuracy_list)

    def evaluate_on_all_metrics(self, model, bin_transformer, electrode_embeddings, log_priority=0, quick_eval=False, only_keys_containing=None):
        """Evaluate every (subject, eval_name, trial) and return a flat {metric_name: value} dict.

        If `only_keys_containing` is given, only result keys containing that substring are returned.
        """
        log('evaluating on all metrics', priority=log_priority, indent=1)
        evaluation_results = {}
        for subject in self.all_subjects:
            # Trials of this subject do not depend on the eval task
            trial_ids = [trial_id for _subject, trial_id in self.subject_trials if _subject.subject_identifier == subject.subject_identifier]
            for eval_name in self.eval_names:
                for trial_id in trial_ids:
                    splits = self.evaluation_datasets[(eval_name, subject.subject_identifier, trial_id)]
                    auroc, accuracy = self._evaluate_on_metric_cv(model, bin_transformer, electrode_embeddings, subject, splits[0], splits[1], log_priority=log_priority+1, quick_eval=quick_eval)
                    evaluation_results[(eval_name, subject.subject_identifier, trial_id)] = (auroc, accuracy)

        evaluation_results_strings = self._format_evaluation_results_strings(evaluation_results)
        log('done evaluating on all metrics', priority=log_priority, indent=1)

        if only_keys_containing is not None:
            evaluation_results_strings = {k: v for k, v in evaluation_results_strings.items() if only_keys_containing in k}
        return evaluation_results_strings

    def _format_evaluation_results_strings(self, evaluation_results):
        """Flatten per-trial results into named metrics plus per-task averages (trials averaged per subject first)."""
        evaluation_results_strings = {}
        for eval_name in self.eval_names:
            auroc_values = []
            acc_values = []
            subject_aurocs = {}
            subject_accs = {}
            for (_, subject_identifier, trial_id) in [key for key in evaluation_results.keys() if key[0] == eval_name]:
                if subject_identifier not in subject_aurocs:
                    subject_aurocs[subject_identifier] = []
                    subject_accs[subject_identifier] = []
                auroc, accuracy = evaluation_results[(eval_name, subject_identifier, trial_id)]
                auroc, accuracy = auroc.item(), accuracy.item()

                subject_aurocs[subject_identifier].append(auroc)
                subject_accs[subject_identifier].append(accuracy)
                evaluation_results_strings[f"eval_auroc/{subject_identifier}_{trial_id}_{eval_name}"] = auroc
                evaluation_results_strings[f"eval_acc/{subject_identifier}_{trial_id}_{eval_name}"] = accuracy
            for subject_identifier in subject_aurocs:
                auroc_values.append(np.mean(subject_aurocs[subject_identifier]).item())
                acc_values.append(np.mean(subject_accs[subject_identifier]).item())
            if len(auroc_values) > 0:
                evaluation_results_strings[f"eval_auroc/average_{eval_name}"] = np.mean(auroc_values).item()
                evaluation_results_strings[f"eval_acc/average_{eval_name}"] = np.mean(acc_values).item()
        return evaluation_results_strings

In [12]:
eval_tasks = ['gpt2_surprisal', 'volume', 'speech']
# NOTE: this uses the FrozenModelEvaluation_SS_SM defined in the previous cell, which shadows
# the one imported from evaluation_btbench earlier in the notebook.
evaluation = FrozenModelEvaluation_SS_SM(
    eval_tasks, eval_subject_trials, 
    training_config['data_dtype'], training_config['batch_size'],
    electrode_embeddings.embeddings_map,
    num_workers_eval=cluster_config['num_workers_eval'],
    prefetch_factor=cluster_config['prefetch_factor'],
    feature_aggregation_method=cluster_config['eval_aggregation_method'],
    electrode_subset=eval_electrode_subset
)

# Run evaluation. The frozen features only feed a linear probe, so no autograd graph is needed.
log("Running evaluation on loaded model...", priority=0)
with torch.no_grad():
    evaluation_results = evaluation.evaluate_on_all_metrics(
        model, 
        bin_transformer, 
        electrode_embeddings, 
        log_priority=1, 
        quick_eval=cluster_config['quick_eval'], 
        only_keys_containing='auroc/average'
    )
print(evaluation_results)

[21:22:25 gpu 16.4G ram 2.7G] Running evaluation on loaded model...
[21:22:25 gpu 16.4G ram 2.7G]     evaluating on all metrics




[21:23:31 gpu 16.4G ram 2.8G]     done evaluating on all metrics
{'eval_auroc/average_gpt2_surprisal': 0.5773692810457516, 'eval_auroc/average_volume': 0.6153839782666313, 'eval_auroc/average_speech': 0.9285877551020408}
