In [3]:
from tools.functions import *
from tools.classes import *
from tools.utils import *
from tools.config import CONFIG, device
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.utils.tensorboard import SummaryWriter as TorchSummaryWriter
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

set_seed()

In [4]:
# Ensure the model save directory exists
ensure_dir(CONFIG['new_model_path'])

# Load Data
try:
    x, y = load_data(CONFIG['data_path'])
    logging.info(f"Loaded data shape: {x.shape}, Labels shape: {y.shape}")
except Exception as e:
    logging.error(f"Error loading data: {str(e)}")
    raise

# Prepare the data (includes SMOTE)
X_train, X_train_spectral, y_train, X_val, X_val_spectral, y_val, X_test, X_test_spectral, y_test = prepare_data(x, y)

# Apply preprocessing
X_train, X_train_spectral = preprocess_data(X_train, X_train_spectral)
X_val, X_val_spectral = preprocess_data(X_val, X_val_spectral)
X_test, X_test_spectral = preprocess_data(X_test, X_test_spectral)

# Identify minority classes for augmentation
class_counts = Counter(y_train.numpy())
minority_classes = [cls for cls, count in class_counts.items() if count < len(y_train) / len(class_counts) * 0.5]

# Apply augmentation
X_train, X_train_spectral, y_train = augment_minority_classes(X_train, X_train_spectral, y_train, minority_classes)

2024-10-21 21:27:34,284 - INFO - Loaded data shape: torch.Size([1066, 4, 3000]), Labels shape: torch.Size([1066])


Loaded data shape: torch.Size([1066, 4, 3000]), Labels shape: torch.Size([1066])
Original train set class distribution:
Counter({1: 467, 3: 134, 4: 100, 2: 39, 0: 5})
Not enough samples in minority class for SMOTE. Using simple oversampling.
After simple oversampling train set class distribution:
Counter({2: 467, 1: 467, 4: 467, 3: 467, 0: 467})


In [5]:
run_tuning = True
start_with_config = True  # Set this to True to start with CONFIG parameters
fine_tune_lr = True  # Set this to True if you want to fine-tune the learning rate after hyperparameter tuning

if run_tuning:
    logging.info("Starting hyperparameter tuning...")
    best_params = run_hyperparameter_tuning(X_train, X_train_spectral, y_train, device, start_with_config=start_with_config)
    
    # Initialize model with best parameters
    model_params = {k: v for k, v in best_params.items() if k in ['n_filters', 'lstm_hidden', 'lstm_layers', 'dropout']}
    ensemble_model = EnsembleModel(model_params).to(device)
    
    # Create data loaders
    train_loader = create_data_loaders(X_train, X_train_spectral, y_train, batch_size=best_params['batch_size'], is_train=True)
    val_loader = create_data_loaders(X_val, X_val_spectral, y_val, batch_size=best_params['batch_size'], is_train=False)
    
    # Set up loss function
    class_weights = get_class_weights(y_train).to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights + 1e-6, label_smoothing=0.1)
    
    # Optionally, find best learning rate
    if fine_tune_lr:
        temp_optimizer = optim.AdamW(ensemble_model.parameters(), lr=best_params['lr'], weight_decay=1e-5)
        best_lr = find_lr(ensemble_model, train_loader, val_loader, temp_optimizer, criterion, device, start_lr=best_params['lr'])
        logging.info(f"Fine-tuned learning rate: {best_lr}")
    else:
        best_lr = best_params['lr']
    
    params = {
        'model_params': model_params,
        'train_params': {'lr': best_lr, 'batch_size': best_params['batch_size'], 'num_epochs': CONFIG['initial_params']['train_params']['num_epochs'], 'patience': CONFIG['initial_params']['train_params']['patience']}
    }
else:
    params = CONFIG['initial_params']
    ensemble_model, _ = initialize_model(device)

if CONFIG['use_pretrained_weights']:
    pretrained_path = os.path.join(CONFIG['old_model_path'], CONFIG['model_names']['ensemble'])
    ensemble_model.load_state_dict(torch.load(pretrained_path))
    logging.info(f"Loaded pretrained weights from {pretrained_path}")

# Save parameters
save_params(params, os.path.join(CONFIG['new_model_path'], 'tuned_params.json'))

# Set up training parameters
train_params = params['train_params']
train_loader = create_data_loaders(X_train, X_train_spectral, y_train, batch_size=train_params['batch_size'], is_train=True)
val_loader = create_data_loaders(X_val, X_val_spectral, y_val, batch_size=train_params['batch_size'], is_train=False)

# Set up optimizer and scheduler with the selected learning rate
optimizer = optim.AdamW(ensemble_model.parameters(), lr=train_params['lr'], weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)

# Train model
best_model_state, best_accuracy = train_model(
    ensemble_model, train_loader, (X_val, X_val_spectral, y_val),
    optimizer, scheduler, criterion, device, epochs=train_params['num_epochs'], patience=train_params['patience']
)

# Save best model
if best_model_state is not None:
    save_model(ensemble_model, os.path.join(CONFIG['new_model_path'], CONFIG['model_names']['ensemble']))
    logging.info(f"Best ensemble model saved. Final validation accuracy: {best_accuracy:.4f}")

    # Evaluate on test set
    ensemble_model.load_state_dict(best_model_state)
    test_loss, test_accuracy, test_predictions = evaluate_model(ensemble_model, (X_test, X_test_spectral, y_test), criterion, device)
    logging.info(f"Ensemble Model - Final Test Accuracy: {test_accuracy:.4f}")

    # Generate and save confusion matrix
    cm = confusion_matrix(y_test.cpu().numpy(), test_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig(os.path.join(CONFIG['new_model_path'], 'confusion_matrix.png'))

    # Generate classification report
    report = classification_report(y_test.cpu().numpy(), test_predictions)
    logging.info(f"Classification Report:\n{report}")

2024-10-21 21:27:42,538 - INFO - Starting hyperparameter tuning...
[I 2024-10-21 21:27:42,540] A new study created in memory with name: no-name-95b6821d-aa04-4c89-900d-d383110fd2e5
Training Progress:   0%|          | 0/1 [01:39<?, ?it/s]
[W 2024-10-21 21:29:29,382] Trial 0 failed with parameters: {'n_filters': [32, 64, 128], 'lstm_hidden': 264, 'lstm_layers': 2, 'dropout': 0.22931168779815797, 'batch_size': 32, 'lr': 0.0031208472635093423} because of the following error: KeyboardInterrupt().
Traceback (most recent call last):
  File "/userdata/jkrolik/miniconda3/envs/myenv/lib/python3.10/site-packages/optuna/study/_optimize.py", line 197, in _run_trial
    value_or_values = func(trial)
  File "/userdata/jkrolik/eeg-sleepstage-classifier/tools/functions.py", line 450, in <lambda>
    study.optimize(lambda trial: objective(trial, X, X_spectral, y, device, start_with_config=start_with_config), n_trials=n_trials)
  File "/userdata/jkrolik/eeg-sleepstage-classifier/tools/functions.py", line

KeyboardInterrupt: 

In [5]:
# Modify the diverse ensemble training section
diverse_ensemble = DiverseEnsembleModel(CONFIG['initial_params']['model_params']).to(device)
diverse_optimizer = optim.AdamW(diverse_ensemble.parameters(), lr=1e-3, weight_decay=1e-4)
diverse_scheduler = CosineAnnealingWarmRestarts(diverse_optimizer, T_0=10, T_mult=2, eta_min=1e-6)

# Set up TensorBoard
writer = TorchSummaryWriter(log_dir=os.path.join(CONFIG['new_model_path'], 'tensorboard_logs2'))

logging.info("Training diverse ensemble model...")
for epoch in range(train_params['num_epochs']):
    diverse_ensemble.train()
    epoch_loss = 0
    for batch_idx, (data, spectral_features, target) in enumerate(train_loader):
        data, spectral_features, target = data.to(device), spectral_features.to(device), target.to(device)
        diverse_optimizer.zero_grad()
        output = diverse_ensemble(data, spectral_features)
        loss = criterion(output, target)
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(diverse_ensemble.parameters(), max_norm=1.0)
        
        diverse_optimizer.step()
        epoch_loss += loss.item()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    
    # Calculate average training loss for the epoch
    avg_train_loss = epoch_loss / len(train_loader)
    
    # Validation
    diverse_ensemble.eval()
    val_loss = 0
    correct = 0
    with torch.no_grad():
        for data, spectral_features, target in val_loader:  # Use val_loader instead of zipping X_val, X_val_spectral, y_val
            data, spectral_features, target = data.to(device), spectral_features.to(device), target.to(device)
            output = diverse_ensemble(data, spectral_features)
            val_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    val_loss /= len(val_loader)
    accuracy = correct / len(val_loader.dataset)

    print(f'Validation set: Average loss: {val_loss:.4f}, Accuracy: {correct}/{len(val_loader.dataset)} ({accuracy:.2f}%)')

    # Log to TensorBoard
    writer.add_scalar('Loss/train', avg_train_loss, epoch)
    writer.add_scalar('Loss/validation', val_loss, epoch)
    writer.add_scalar('Accuracy/validation', accuracy, epoch)
    writer.add_scalar('Learning Rate', diverse_optimizer.param_groups[0]['lr'], epoch)

    diverse_scheduler.step()

    # Check for improvement and save the best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        diverse_best_state = diverse_ensemble.state_dict()
        save_model(diverse_ensemble, os.path.join(CONFIG['new_model_path'], CONFIG['model_names']['diverse2']))
        logging.info(f"New best diverse ensemble model saved. Accuracy: {accuracy:.4f}")

# Close TensorBoard writer
writer.close()

logging.info(f"Diverse ensemble training completed. Best accuracy: {best_accuracy:.4f}")

# Evaluate the best diverse ensemble model on the test set
diverse_ensemble.load_state_dict(diverse_best_state)
test_loss, test_accuracy, test_predictions = evaluate_model(diverse_ensemble, (X_test, X_test_spectral, y_test), criterion, device)
logging.info(f"Diverse Ensemble Model - Final Test Accuracy: {test_accuracy:.4f}")

# Generate and save confusion matrix for diverse ensemble
cm = confusion_matrix(y_test.cpu().numpy(), test_predictions)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.title('Confusion Matrix - Diverse Ensemble')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.savefig(os.path.join(CONFIG['new_model_path'], 'diverse_ensemble_confusion_matrix2.png'))
plt.close()

# Generate classification report for diverse ensemble
report = classification_report(y_test.cpu().numpy(), test_predictions)
logging.info(f"Diverse Ensemble Classification Report:\n{report}")

2024-10-10 13:33:59,552 - INFO - Training diverse ensemble model...


Validation set: Average loss: 1.5474, Accuracy: 34/107 (0.32%)
Validation set: Average loss: 0.9783, Accuracy: 76/107 (0.71%)
Validation set: Average loss: 1.0700, Accuracy: 75/107 (0.70%)
Validation set: Average loss: 1.5015, Accuracy: 38/107 (0.36%)
Validation set: Average loss: 2.0008, Accuracy: 22/107 (0.21%)
Validation set: Average loss: 1.1069, Accuracy: 74/107 (0.69%)
Validation set: Average loss: 0.9498, Accuracy: 83/107 (0.78%)
Validation set: Average loss: 0.9906, Accuracy: 81/107 (0.76%)
Validation set: Average loss: 1.2587, Accuracy: 67/107 (0.63%)
Validation set: Average loss: 1.6496, Accuracy: 35/107 (0.33%)
Validation set: Average loss: 1.2008, Accuracy: 76/107 (0.71%)
Validation set: Average loss: 1.3121, Accuracy: 69/107 (0.64%)
Validation set: Average loss: 1.4315, Accuracy: 52/107 (0.49%)
Validation set: Average loss: 1.2661, Accuracy: 64/107 (0.60%)
Validation set: Average loss: 1.0636, Accuracy: 76/107 (0.71%)
Validation set: Average loss: 1.6936, Accuracy: 43/107 

2024-10-10 13:58:03,068 - INFO - New best diverse ensemble model saved. Accuracy: 0.9065


Validation set: Average loss: 0.7205, Accuracy: 89/107 (0.83%)
Validation set: Average loss: 0.7797, Accuracy: 88/107 (0.82%)
Validation set: Average loss: 0.6538, Accuracy: 94/107 (0.88%)
Validation set: Average loss: 0.6700, Accuracy: 91/107 (0.85%)
Validation set: Average loss: 0.6892, Accuracy: 93/107 (0.87%)
Validation set: Average loss: 0.6523, Accuracy: 95/107 (0.89%)
Validation set: Average loss: 0.9589, Accuracy: 78/107 (0.73%)
Validation set: Average loss: 0.8330, Accuracy: 85/107 (0.79%)
Validation set: Average loss: 0.8865, Accuracy: 82/107 (0.77%)
Validation set: Average loss: 0.7513, Accuracy: 91/107 (0.85%)
Validation set: Average loss: 0.7986, Accuracy: 87/107 (0.81%)
Validation set: Average loss: 0.9080, Accuracy: 80/107 (0.75%)
Validation set: Average loss: 0.7492, Accuracy: 94/107 (0.88%)
Validation set: Average loss: 0.8238, Accuracy: 87/107 (0.81%)
Validation set: Average loss: 0.8047, Accuracy: 90/107 (0.84%)
Validation set: Average loss: 0.6593, Accuracy: 95/107 

2024-10-10 16:58:35,878 - INFO - Diverse ensemble training completed. Best accuracy: 0.9065


Validation set: Average loss: 0.7012, Accuracy: 97/107 (0.91%)


2024-10-10 16:58:36,251 - INFO - Diverse Ensemble Model - Final Test Accuracy: 0.8692
2024-10-10 16:58:36,995 - INFO - Diverse Ensemble Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         1
           1       0.94      0.94      0.94       134
           2       0.67      0.17      0.27        12
           3       0.71      0.89      0.79        38
           4       0.86      0.83      0.84        29

    accuracy                           0.87       214
   macro avg       0.63      0.57      0.57       214
weighted avg       0.87      0.87      0.86       214



In [6]:
# # Train diverse ensemble
# diverse_ensemble = DiverseEnsembleModel(CONFIG['initial_params']['model_params']).to(device)
# diverse_optimizer = optim.AdamW(diverse_ensemble.parameters(), lr=train_params['lr'], weight_decay=1e-2)
# diverse_scheduler = get_scheduler(diverse_optimizer, num_warmup_steps=len(train_loader)*5, num_training_steps=len(train_loader)*train_params['num_epochs'])

# logging.info("Training diverse ensemble model...")
# diverse_best_state, diverse_accuracy = train_model(
#     diverse_ensemble, train_loader, (X_val, X_val_spectral, y_val),
#     diverse_optimizer, diverse_scheduler, criterion, device, epochs=train_params['num_epochs']
# )

# # save_model(diverse_ensemble, os.path.join(CONFIG['new_model_path'], CONFIG['model_names']['diverse']))
# logging.info(f"Best diverse ensemble model saved. Final accuracy: {diverse_accuracy:.4f}")

# Distill knowledge
single_model = ImprovedSleepdetector(**CONFIG['initial_params']['model_params']).to(device)

logging.info("Performing knowledge distillation...")
distilled_model = distill_knowledge(ensemble_model, single_model, train_loader, (X_val, X_val_spectral, y_val), device)

save_model(distilled_model, os.path.join(CONFIG['new_model_path'], CONFIG['model_names']['distilled2']))

# Final evaluation
_, ensemble_accuracy, _ = evaluate_model(ensemble_model, (X_test, X_test_spectral, y_test), criterion, device)
_, diverse_accuracy, _ = evaluate_model(diverse_ensemble, (X_test, X_test_spectral, y_test), criterion, device)
_, distilled_accuracy, _ = evaluate_model(distilled_model, (X_test, X_test_spectral, y_test), criterion, device)

logging.info(f"Training completed.")
logging.info(f"Ensemble Model - Final Test Accuracy: {ensemble_accuracy:.4f}")
logging.info(f"Diverse Ensemble Model - Final Test Accuracy: {diverse_accuracy:.4f}")
logging.info(f"Distilled Model - Final Test Accuracy: {distilled_accuracy:.4f}")

2024-10-10 16:58:37,542 - INFO - Performing knowledge distillation...
Overall Distillation Progress:   8%|▊         | 4/50 [00:33<06:27,  8.42s/it]2024-10-10 16:59:19,898 - INFO - Distillation Epoch 5/50 - Loss: 1.3472, Accuracy: 0.1402, LR: 1.00e-05
Overall Distillation Progress:  18%|█▊        | 9/50 [01:15<05:39,  8.27s/it]2024-10-10 17:00:01,403 - INFO - Distillation Epoch 10/50 - Loss: 0.5650, Accuracy: 0.7570, LR: 8.89e-06
Overall Distillation Progress:  28%|██▊       | 14/50 [01:57<05:00,  8.35s/it]2024-10-10 17:00:43,133 - INFO - Distillation Epoch 15/50 - Loss: 0.1576, Accuracy: 0.8224, LR: 7.78e-06
Overall Distillation Progress:  38%|███▊      | 19/50 [02:37<04:11,  8.13s/it]2024-10-10 17:01:23,865 - INFO - Distillation Epoch 20/50 - Loss: 0.1195, Accuracy: 0.8505, LR: 6.67e-06
Overall Distillation Progress:  48%|████▊     | 24/50 [03:19<03:34,  8.26s/it]2024-10-10 17:02:05,368 - INFO - Distillation Epoch 25/50 - Loss: 0.1096, Accuracy: 0.6262, LR: 5.56e-06
Overall Distillati