In [1]:
import torch
from time import time
import numpy as np
from tqdm import tqdm


from models import ConditionOptimizer
from emg_processor import EMGDataset, EMGProcessor
from biomime_generator import *

In [2]:
# Select the compute device: use the first CUDA GPU when one is available and
# fall back to the CPU otherwise, so this cell does not fail on a machine
# without a GPU (same rule as the optimisation settings cell further down).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Locations of the generator configuration file and the model checkpoint.
config_path = './config/config.yaml'
model_checkpoint = './ckp/model_linear.pth'

# Build the MUAP generator; later cells pass it to ConditionOptimizer and call
# it directly on condition tensors.
biomime_gen = BiomimeMuapGenerator(
    config_path=config_path,
    model_checkpoint=model_checkpoint,
    device=device,
)


In [3]:
# Load the EMG recording stored in the .npz archive into an EMGDataset.
emg_data = EMGDataset('./data/emg_data_10s_100MUs.npz')
# Presumably converts the loaded arrays to torch tensors (later cells read
# emg_data.tensors[...]) — TODO confirm against EMGDataset.to_torch.
emg_data.to_torch()

In [4]:
# Repeat the LHS-probe + gradient-descent condition search `total_reps` times
# and keep the outputs of every repetition.
# `time` comes from the import cell at the top of the notebook, so the
# duplicate mid-notebook import that used to open this cell was dropped.
# NOTE(review): no random seed is set before the sampling below, so the logged
# losses are presumably not reproducible across runs — confirm whether
# ConditionOptimizer seeds its own sampling.

# Hyperparameters / settings
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
total_reps = 5                   # Number of times to run the entire pipeline
probe_points = 1002              # LHS samples
num_iterations = 20              # Gradient descent epochs
learning_rate = 1e-2
batch_size = 40
extension_factor = 10

# Initialize ConditionOptimizer ONCE; every repetition below reuses this object
condition_optimizer = ConditionOptimizer(
    biomime=biomime_gen,
    device=device,
    margin=1.0,
    indices=[0, 1, 2, 3, 4, 5],
    emg_data=emg_data,
    batch_size=batch_size,
    extension_factor=extension_factor
)

# Per-repetition results; each list gains exactly one entry per rep
all_muaps = []
all_conditions = []
all_kurtosis_histories = []

for k in range(total_reps):
    print(f"\n=== Starting Rep {k+1}/{total_reps} ===")
    start_time = time()

    # Use full_optimize() to run LHS + gradient descent in one go
    final_muaps, final_conditions, kurtosis_history = condition_optimizer.full_optimize(
        probe_points=probe_points,
        num_iterations=num_iterations,
        lr=learning_rate
    )

    # Store results for this rep
    all_muaps.append(final_muaps)                    # presumably [batch_size, ...]
    all_conditions.append(final_conditions)          # presumably [batch_size, num_params]
    all_kurtosis_histories.append(kurtosis_history)  # presumably one value per iteration

    # Wall-clock duration of this repetition (sampling + optimisation)
    elapsed = time() - start_time
    print(f"--- Rep {k+1} completed in {elapsed:.2f} sec ---")

print("\n=== All Reps Completed! ===")

# Each list should now hold `total_reps` entries
print(f"all_muaps length: {len(all_muaps)}")  # Should be 'total_reps'
print(f"all_conditions length: {len(all_conditions)}")
print(f"all_kurtosis_histories length: {len(all_kurtosis_histories)}")



=== Starting Rep 1/5 ===
>>> Starting Monte Carlo (LHS) sampling...


Evaluating LHS batches: 100%|██████████| 11/11 [00:01<00:00,  6.32it/s]


<<< LHS sampling done. Elapsed: 2.17 sec.
>>> Starting gradient-based optimization...
Epoch [10/20] | Loss: -0.100624 | Sum Kurt: -4.024940
Epoch [20/20] | Loss: -0.225692 | Sum Kurt: -9.027683
<<< Optimization done. Elapsed: 5.53 sec.
--- Rep 1 completed in 7.70 sec ---

=== Starting Rep 2/5 ===
>>> Starting Monte Carlo (LHS) sampling...


Evaluating LHS batches: 100%|██████████| 11/11 [00:01<00:00,  7.61it/s]


<<< LHS sampling done. Elapsed: 1.85 sec.
>>> Starting gradient-based optimization...
Epoch [10/20] | Loss: -0.051038 | Sum Kurt: -2.041532
Epoch [20/20] | Loss: -0.164774 | Sum Kurt: -6.590948
<<< Optimization done. Elapsed: 5.16 sec.
--- Rep 2 completed in 7.01 sec ---

=== Starting Rep 3/5 ===
>>> Starting Monte Carlo (LHS) sampling...


Evaluating LHS batches: 100%|██████████| 11/11 [00:01<00:00,  6.59it/s]


<<< LHS sampling done. Elapsed: 2.12 sec.
>>> Starting gradient-based optimization...
Epoch [10/20] | Loss: -0.073476 | Sum Kurt: -2.939049
Epoch [20/20] | Loss: -0.199691 | Sum Kurt: -7.987658
<<< Optimization done. Elapsed: 5.77 sec.
--- Rep 3 completed in 7.89 sec ---

=== Starting Rep 4/5 ===
>>> Starting Monte Carlo (LHS) sampling...


Evaluating LHS batches: 100%|██████████| 11/11 [00:01<00:00,  6.03it/s]


<<< LHS sampling done. Elapsed: 2.34 sec.
>>> Starting gradient-based optimization...
Epoch [10/20] | Loss: -0.055745 | Sum Kurt: -2.229819
Epoch [20/20] | Loss: -0.176116 | Sum Kurt: -7.044639
<<< Optimization done. Elapsed: 5.84 sec.
--- Rep 4 completed in 8.19 sec ---

=== Starting Rep 5/5 ===
>>> Starting Monte Carlo (LHS) sampling...


Evaluating LHS batches: 100%|██████████| 11/11 [00:01<00:00,  5.70it/s]


<<< LHS sampling done. Elapsed: 2.35 sec.
>>> Starting gradient-based optimization...
Epoch [10/20] | Loss: -0.094002 | Sum Kurt: -3.760073
Epoch [20/20] | Loss: -0.225015 | Sum Kurt: -9.000591
<<< Optimization done. Elapsed: 5.83 sec.
--- Rep 5 completed in 8.18 sec ---

=== All Reps Completed! ===
all_muaps length: 5
all_conditions length: 5
all_kurtosis_histories length: 5


In [5]:
# Display the per-repetition condition arrays collected by the optimisation loop.
# NOTE(review): this dumps every array in full; consider showing a slice instead.
all_conditions

[array([[ 0.17238152,  0.8784257 ,  0.15869375,  0.47223657,  0.9700714 ,
          1.0048656 ],
        [ 0.0674331 ,  0.67993236,  0.07561096,  0.82840586,  0.07996786,
          0.7172803 ],
        [ 0.03803804,  0.6183676 ,  0.30035716,  0.31508762,  0.57329684,
          0.35131457],
        [ 0.02723485,  0.91422206,  0.26567563,  0.5349531 ,  0.14440772,
          0.79823864],
        [ 0.03460067,  0.24588466,  0.1460819 ,  0.54492366,  0.63938564,
          0.13898897],
        [ 0.05285158,  0.30830514,  0.23775055,  0.43394637,  0.06994335,
          0.853135  ],
        [ 0.0606031 ,  0.72056264,  0.14379044,  0.66508937,  0.48143643,
          0.39333108],
        [ 0.06752861,  0.73904836,  0.14923917,  0.6647987 ,  0.5609234 ,
          0.7873349 ],
        [ 0.0192902 ,  0.5845917 ,  0.2980065 ,  0.30644813,  0.44942945,
          0.09597326],
        [ 0.04887617,  0.82458746,  0.35783178,  0.7194818 ,  0.92654574,
          0.33966675],
        [ 0.02690525,  0.84589

In [None]:
# Merge the per-repetition condition arrays into a single 2-D array with one
# row per optimised condition vector. The 6 is the number of values per row
# (it matches the six-entry `indices` list handed to ConditionOptimizer).
all_conds = np.stack(all_conditions).reshape(-1, 6)

# Regenerate the MUAPs for every optimised condition vector. Gradients are not
# needed here, so autograd tracking is disabled.
with torch.no_grad():
    # Move the input with `.to(device)` instead of a hard-coded `.cuda()` so it
    # lands on the same device as the tensors prepared below.
    muaps_from_conds = biomime_gen(torch.tensor(all_conds).to(device))

# Inputs for the source-estimation step, placed on the same device.
training_data = emg_data.tensors['extended_emg'].to(device)
cov = emg_data.tensors['covariance_matrix'].to(device)


In [26]:
# NOTE(review): R equals the extension_factor (10) given to ConditionOptimizer;
# presumably the two must agree — confirm before changing either.
R = 10

# Copy the generated MUAPs without autograd history, then reshape and swap the
# last two axes. `torch.as_tensor(...).clone().detach()` replaces
# `torch.tensor(<tensor>)`, which emits PyTorch's copy-construct UserWarning
# when its argument is already a tensor.
# NOTE(review): the meaning of the 96 and 320 axes is not visible here —
# presumably channels and samples; confirm against the generator.
muaps_tmp = (
    torch.as_tensor(muaps_from_conds)
    .clone()
    .detach()
    .reshape(-1, 96, 320)
    .permute(0, 2, 1)
    .to(torch.float32)
)

# Separation vectors for every MUAP, moved to the shared `device` (instead of
# a hard-coded `.cuda()`) so they sit next to `cov` and `training_data`.
filters = EMGProcessor.get_separation_vectors_torch(muaps_tmp, R=R).to(device)

# Project the training data through the filters, then hand the result to NumPy
# for the metric functions used afterwards.
sources_estimated = filters.T @ cov @ training_data
sources_estimated = sources_estimated.detach().cpu().numpy()

  muaps_tmp = torch.tensor(muaps_from_conds.reshape(-1, 96, 320)).permute(0, 2, 1).to(torch.float32)


In [30]:
# Sanity check: display the shape of the estimated-sources array.
sources_estimated.shape

(200, 20009)

In [27]:
from metrics import *

def calculate_silhouette(sources_estimated, n_components=2, max_samples=10000):
    """
    Calculate silhouette scores for the first estimated sources.

    Each of the first ``n_components`` sources is truncated to its first
    ``max_samples`` samples and passed to ``calculate_kmeans_and_silhouette``
    (with ``init_method='maxmin'``); the ``'silhouette_score'`` entry of each
    result is collected.

    Parameters:
    - sources_estimated (sequence of 1-D sequences, e.g. numpy.ndarray):
      The estimated sources, indexed by source along the first axis.
    - n_components (int): Number of leading sources to score. Must be at
      least 1 and no larger than ``len(sources_estimated)``.
    - max_samples (int): Number of leading samples of each source that are
      used for clustering. Defaults to 10000.

    Returns:
    - avg_silhouette (float): Average silhouette score over the scored sources.
    - sil_scores (list): List of silhouette scores, one per scored source.

    Raises:
    - ValueError: If ``n_components`` is smaller than 1.
    - IndexError: If ``n_components`` exceeds the number of available sources.
    """
    # Validate up front so a bad request fails before any clustering is run.
    if n_components < 1:
        raise ValueError(f"n_components must be at least 1, got {n_components}")
    n_available = len(sources_estimated)
    if n_components > n_available:
        raise IndexError(
            f"n_components={n_components} exceeds the {n_available} available sources"
        )

    # Release PyTorch's cached, unused GPU memory before the clustering calls.
    torch.cuda.empty_cache()

    sil_scores = [
        calculate_kmeans_and_silhouette(
            sources_estimated[i][:max_samples], init_method='maxmin'
        )['silhouette_score']
        for i in range(n_components)
    ]

    # Start the sum from 0.0 so the additions happen in source order on a float.
    return sum(sil_scores, 0.0) / n_components, sil_scores

# Score the first two estimated sources; the cell displays (average, per-source scores).
calculate_silhouette(sources_estimated, n_components=2)


(0.9271552264690399, [0.9313795, 0.92293096])

In [32]:
# Match every estimated source to a ground-truth spike train and score the
# resulting matches.
np.random.seed(42)  # For reproducibility
S_true = emg_data.data_dict['spike_trains']

# Thresholds forwarded to match_estimates_to_true.
threshold_est = 0.3
threshold_true = 0.9
S_est = normalize_sources(sources_estimated)

matches, recall_matrix, thresholds_est, thresholds_true = match_estimates_to_true(
    S_est,
    S_true,
    threshold_est=threshold_est,
    # Pass the variable rather than repeating the literal 0.9, so changing
    # `threshold_true` above actually changes the matching.
    threshold_true=threshold_true,
)

precision_list, recall_list, mean_precision, mean_recall = evaluate_matches(
    S_est,
    S_true,
    matches,
    thresholds_est,
    thresholds_true,
)


In [44]:
# Report which true units were matched and how close the optimised conditions
# are to the true conditions of the matched units.
print("Matches:")
print(matches)
print(f"Number of unique matches: {np.unique(matches).shape[0]}")

# Number of distinct matched units among the first 1, 2, ... batches of
# estimates. Every slice starts at index 0, so the counts are cumulative; the
# label below says so (it used to read "per batch").
unique_counts = [
    len(np.unique(matches[:end]))
    for end in range(batch_size, len(matches) + batch_size, batch_size)
]
print("Cumulative unique counts after each batch:")
print(unique_counts)

# Compare estimated and true conditions, restricted to estimates whose recall
# exceeds 0.9.
conds_true = emg_data.tensors['base_conditions'][matches]
# np.asarray keeps the comparison element-wise even if recall_list is a plain
# Python list.
correct_muaps = np.asarray(recall_list) > 0.9
conds_true_np = conds_true.cpu().numpy()[correct_muaps]
all_conds_np = all_conds[correct_muaps]

squared_error = (conds_true_np - all_conds_np) ** 2

# Overall RMSE across every parameter of every retained estimate
rmse = np.sqrt(np.mean(squared_error))
print(f"RMSE: {rmse}")

# Per-estimate RMSE across its parameters. This is a root of a mean square,
# so the report calls it RMSE rather than MSE.
rmse_per_muap = np.sqrt(np.mean(squared_error, axis=1))
mse_per_muap = rmse_per_muap  # previous variable name, kept as an alias
print(f"Number of correct MUAPs: {correct_muaps.sum()}")
print(f"Number of MUAPs with RMSE < 0.01: {np.sum(rmse_per_muap < 0.01)}")

Matches:
[ 5  9 85 74 82 48 12 12 85 38 95 53 85  5 32 33 91 14 33 33 20 38  9 38
 89 31 60 98 19 54 38 31 80 43 44 35 79 16 13 14 77 95  8 24 95 97 54 15
 38  5 53  9  5 81 55 38 41 90 50 13 61 34 52 70 82 69 48 73 97 87 17 15
  0 94 35 50 58  7 45 47  9 37  9 82 13 12 70 35 71 34 31  6 60 16 31 35
  8  6 42 61  6 41 12  5 19 31 63 34 70 88 81 79 31 90 66 38  9  6  3 51
 78 34 44 50 91 60 59 88 21 95 70  4 90 22  9 77 13 23 80 81 33 22 81 38
  5 60  6 70 88 33 11 80 33 19 31 54 34 83 33 99 38 22 31 70 11 95 78 31
 16 33 92 29 31 31 70 34 34 97  9 97 12 85 38 63 55 78 34 73 54 12  5 51
 56 64 58 21 71 53  6 46]
Number of unique matches: 75
Unique counts per batch:
[28, 52, 61, 70, 75]
RMSE: 0.09797072410583496
Number of correct MUAPs: 157
Number of MUAPs with MSE < 0.01: 3
