In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, LinearLR
from torch.optim.lr_scheduler import SequentialLR
from torch.amp import autocast, GradScaler
import numpy as np
from sklearn.metrics import accuracy_score, cohen_kappa_score
import logging
from tqdm import tqdm
import json
import os
from sleepdetector_new import ImprovedSleepdetector
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from tqdm import tqdm

# Configure root logging once for the whole notebook: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)



# ImprovedSleepdetector (the ensemble member architecture) comes from the project-local sleepdetector_new module imported above.

class EnsembleModel(nn.Module):
    """Average the outputs of ``n_models`` identically configured ImprovedSleepdetector members.

    Args:
        model_params: keyword arguments forwarded unchanged to every member.
        n_models: number of ensemble members (default 3).
    """

    def __init__(self, model_params, n_models=3):
        super().__init__()
        members = [ImprovedSleepdetector(**model_params) for _ in range(n_models)]
        self.models = nn.ModuleList(members)

    def forward(self, x, spectral_features):
        """Return the element-wise mean of every member's output for the same batch.

        Each member receives its own copy of the inputs, so an in-place operation inside
        one member cannot change what the next member sees.
        """
        stacked = torch.stack(
            [member(x.clone(), spectral_features.clone()) for member in self.models]
        )
        return stacked.mean(dim=0)

class DiverseEnsembleModel(nn.Module):
    """Ensemble of ImprovedSleepdetector members that differ only in their dropout rate.

    Member ``idx`` (0-based) is built with ``model_params['dropout'] * (idx + 1) / n_models``,
    so dropout grows linearly up to the configured value for the last member.

    Args:
        model_params: keyword arguments for ImprovedSleepdetector; must contain 'dropout'.
        n_models: number of members (default 3).
    """

    def __init__(self, model_params, n_models=3):
        super().__init__()
        members = []
        for idx in range(n_models):
            member_params = dict(model_params)
            member_params['dropout'] = model_params['dropout'] * (idx + 1) / n_models
            members.append(ImprovedSleepdetector(**member_params))
        self.models = nn.ModuleList(members)

    def forward(self, x, spectral_features):
        """Return the mean over members of their outputs; each member gets cloned inputs."""
        per_member = []
        for member in self.models:
            per_member.append(member(x.clone(), spectral_features.clone()))
        return torch.stack(per_member).mean(dim=0)

def load_best_params(file_path):
    """Return the ``'best_model_params'`` entry of a JSON hyperparameter file.

    Args:
        file_path: path to a JSON object that contains a ``'best_model_params'`` key.

    Returns:
        The value stored under ``'best_model_params'``.

    Raises:
        KeyError: if the file has no ``'best_model_params'`` key.
    """
    # JSON text is UTF-8 (RFC 8259); don't let the platform's locale encoding decide.
    with open(file_path, 'r', encoding='utf-8') as f:
        params = json.load(f)
    return params['best_model_params']

def print_model_structure(model):
    """Print one ``name: shape`` line per registered parameter of ``model``, in registration order."""
    entries = ((label, tensor.shape) for label, tensor in model.named_parameters())
    for label, shape in entries:
        print(f"{label}: {shape}")

def load_data_and_params(config):
    """Load the preprocessed data file and the tuned hyperparameters named in ``config``.

    Reads ``config['preprocessed_data_path']`` with ``torch.load`` and the JSON file
    ``config['best_params_name']`` located in ``config['previous_model_path']``.

    Returns:
        ``(data_dict, best_params)``.
    """
    # Map to CPU so a file saved from GPU tensors still opens on a CPU-only machine;
    # the training/evaluation code in this notebook moves tensors to the device itself.
    data_dict = torch.load(config['preprocessed_data_path'], map_location='cpu')
    best_params_path = os.path.join(config['previous_model_path'], config['best_params_name'])
    best_params = load_best_params(best_params_path)
    return data_dict, best_params

# def get_scheduler(optimizer, num_warmup_steps, num_training_steps):
#     return SequentialLR(
#         optimizer,
#         schedulers=[
#             LinearLR(optimizer, start_factor=0.1, total_iters=num_warmup_steps),
#             CosineAnnealingLR(optimizer, T_max=num_training_steps - num_warmup_steps)
#         ],
#         milestones=[num_warmup_steps]
#     )

def get_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr=1e-6):
    """Build a step-based LR schedule: linear warmup from 0, then linear decay with an absolute floor.

    For each param group with initial learning rate ``base_lr``:

    * steps ``[0, num_warmup_steps)``: ``lr = base_lr * step / num_warmup_steps``;
    * afterwards: ``lr = base_lr * (num_training_steps - step) / (num_training_steps - num_warmup_steps)``,
      never going below ``min_lr``.

    ``LambdaLR`` multiplies each group's initial LR by the factor the lambda returns, so
    ``min_lr`` (an absolute learning rate) is converted to the per-group factor
    ``min_lr / base_lr``. It is capped at 1 so the floor never lifts the LR above ``base_lr``.

    Args:
        optimizer: optimizer whose param groups are scheduled.
        num_warmup_steps: number of ``scheduler.step()`` calls spent warming up.
        num_training_steps: step at which the linear decay reaches zero (before the floor).
        min_lr: absolute lower bound on the learning rate after warmup.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` with one lambda per param group.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def _make_lambda(base_lr):
        # Factory avoids the late-binding closure pitfall when building one lambda per group.
        floor = min(1.0, min_lr / base_lr) if base_lr > 0 else 0.0

        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / warmup_span
            return max(floor, float(num_training_steps - current_step) / decay_span)

        return lr_lambda

    # LambdaLR scales 'initial_lr' (set from 'lr' on first construction), so read it the same way.
    base_lrs = [group.get('initial_lr', group['lr']) for group in optimizer.param_groups]
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [_make_lambda(lr) for lr in base_lrs])

def _trim_lr_curve(log_lrs, losses, skip_start=10, skip_end=5):
    """Drop the noisy head/tail of an LR range-test curve, or keep it whole if it is too short to trim."""
    curve_lrs = np.asarray(log_lrs, dtype=float)
    curve_losses = np.asarray(losses, dtype=float)
    # Trimming only when at least two points survive matches the original [10:-5] slice.
    if len(curve_lrs) > skip_start + skip_end + 1:
        stop = len(curve_lrs) - skip_end
        return curve_lrs[skip_start:stop], curve_losses[skip_start:stop]
    return curve_lrs, curve_losses


def find_lr(model, train_loader, optimizer, criterion, device, num_iter=100, start_lr=1e-8, end_lr=1):
    """Run an exponential learning-rate range test and return a suggested learning rate.

    The LR of every param group grows geometrically from ``start_lr`` to ``end_lr`` over
    ``num_iter`` batches. At each batch the bias-corrected, exponentially smoothed
    (beta = 0.98) training loss is recorded. The sweep stops early if the smoothed loss
    becomes non-finite or exceeds 4x the best value seen.

    The sweep trains at extreme learning rates, so the model's state dict and the
    optimizer's state are snapshotted first and restored afterwards. The caller gets the
    model back unchanged.

    A plot of smoothed loss vs. log10(LR) is saved to ``'lr_finder_plot.png'``.

    Returns:
        0.1 times the LR at which the (trimmed) smoothed-loss curve falls most steeply.

    Raises:
        ValueError: if fewer than two loss points were recorded.
        StopIteration: if ``train_loader`` yields no batches.
    """
    import copy

    logging.info("Starting learning rate finder...")

    # Snapshot on CPU to avoid doubling GPU memory; load_state_dict copies back to the device.
    saved_model_state = {name: t.detach().cpu().clone() for name, t in model.state_dict().items()}
    saved_optimizer_state = copy.deepcopy(optimizer.state_dict())

    model.train()
    update_step = (end_lr / start_lr) ** (1 / num_iter)
    lr = start_lr
    for group in optimizer.param_groups:
        group["lr"] = lr
    running_loss = 0.0
    best_loss = float('inf')
    batch_num = 0
    losses = []
    log_lrs = []

    data_iter = iter(train_loader)
    progress_bar = tqdm(range(num_iter), desc="Finding best LR")
    try:
        for _ in progress_bar:
            try:
                inputs, spectral_features, targets = next(data_iter)
            except StopIteration:
                # Loader exhausted: start a fresh pass rather than rebuilding it every batch.
                data_iter = iter(train_loader)
                inputs, spectral_features, targets = next(data_iter)

            inputs, spectral_features, targets = inputs.to(device), spectral_features.to(device), targets.to(device)
            batch_num += 1

            optimizer.zero_grad()
            outputs = model(inputs, spectral_features)
            loss = criterion(outputs, targets)

            # Exponential moving average with bias correction for the early steps.
            running_loss = 0.98 * running_loss + 0.02 * loss.item()
            smoothed_loss = running_loss / (1 - 0.98 ** batch_num)

            if smoothed_loss < best_loss:
                best_loss = smoothed_loss

            # Stop on divergence; a NaN/inf loss would otherwise poison the curve silently.
            if not math.isfinite(smoothed_loss) or (batch_num > 1 and smoothed_loss > 4 * best_loss):
                logging.info(f"Loss is exploding, stopping early at lr={lr:.2e}")
                break

            losses.append(smoothed_loss)
            log_lrs.append(math.log10(lr))

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            lr *= update_step
            for group in optimizer.param_groups:
                group["lr"] = lr

            progress_bar.set_postfix({'loss': f'{smoothed_loss:.4f}', 'lr': f'{lr:.2e}'})
    finally:
        progress_bar.close()
        # Undo the sweep: weights, buffers and optimizer state (including LRs) go back to their entry values.
        model.load_state_dict(saved_model_state)
        optimizer.load_state_dict(saved_optimizer_state)

    curve_lrs, curve_losses = _trim_lr_curve(log_lrs, losses)

    plt.figure(figsize=(10, 6))
    plt.plot(curve_lrs, curve_losses)
    plt.xlabel("Log Learning Rate")
    plt.ylabel("Loss")
    plt.title("Learning Rate vs. Loss")
    plt.savefig('lr_finder_plot.png')
    plt.close()

    if len(curve_lrs) < 2:
        raise ValueError(
            f"LR finder recorded only {len(curve_lrs)} point(s); cannot suggest a learning rate"
        )

    # Steepest negative slope of loss w.r.t. log10(lr).
    gradients = np.diff(curve_losses) / np.diff(curve_lrs)
    best_lr = 10 ** curve_lrs[np.argmin(gradients)]

    # Back off one order of magnitude from the steepest point for stability.
    best_lr *= 0.1

    logging.info(f"Learning rate finder completed. Suggested Learning Rate: {best_lr:.2e}")
    logging.info("Learning rate vs. loss plot saved as 'lr_finder_plot.png'")
    return best_lr

def train_model(model, train_loader, val_data, optimizer, scheduler, criterion, device, epochs=100, accumulation_steps=4, log_interval=5):
    """Train ``model`` with autocast and gradient accumulation, keeping the best validation checkpoint.

    Args:
        model: module called as ``model(batch_x, batch_x_spectral)``.
        train_loader: iterable of ``(batch_x, batch_x_spectral, batch_y)`` batches with a ``len``.
        val_data: ``(X, X_spectral, y)`` tuple handed to ``evaluate_model``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler, stepped once per training batch. The schedules built in
            this notebook size warmup/decay as ``len(train_loader) * epochs``.
        criterion: loss function taking ``(outputs, targets)``.
        device: ``torch.device``; loss scaling is only enabled when ``device.type == 'cuda'``.
        epochs: number of passes over ``train_loader``.
        accumulation_steps: batches whose gradients are summed before each optimizer update.
            A trailing partial group at the end of an epoch is still applied.
        log_interval: evaluate and log every this many epochs, and always after the last one.

    Returns:
        ``(best_state, best_accuracy)``. ``best_state`` is a detached copy of the state dict
        that achieved the highest validation accuracy, or ``None`` if none beat 0.
        Returns ``(None, 0)`` immediately if a NaN loss is encountered.
    """
    use_cuda = device.type == 'cuda'
    # Loss scaling is only needed for float16 autocast on CUDA; when disabled the scaler is a pass-through.
    scaler = GradScaler(enabled=use_cuda)
    best_accuracy = 0
    best_model_state = None
    num_batches = len(train_loader)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        # Start each epoch from clean gradients (e.g. leftovers from an LR sweep on the same parameters).
        optimizer.zero_grad()

        epoch_progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
        for i, (batch_x, batch_x_spectral, batch_y) in enumerate(epoch_progress):
            batch_x, batch_x_spectral, batch_y = batch_x.to(device), batch_x_spectral.to(device), batch_y.to(device)

            with autocast(device_type='cuda' if use_cuda else 'cpu'):
                outputs = model(batch_x, batch_x_spectral)
                loss = criterion(outputs, batch_y)
                loss = loss / accumulation_steps

            # Check for NaN loss
            if torch.isnan(loss).any():
                logging.error(f"NaN loss detected at epoch {epoch+1}, batch {i+1}")
                return None, 0

            scaler.scale(loss).backward()

            if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
                # Unscale first so max_norm applies to the true gradients, not the loss-scaled ones.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

            # Warmup/decay lengths are counted in batches by the callers, so advance once per batch.
            scheduler.step()

            running_loss += loss.item() * accumulation_steps

            epoch_progress.set_postfix({'loss': f'{running_loss/(i+1):.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'})

        if (epoch + 1) % log_interval == 0 or epoch == epochs - 1:
            accuracy = evaluate_model(model, val_data, device)
            logging.info(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}")

            if accuracy > best_accuracy:
                best_accuracy = accuracy
                # state_dict() returns the live tensors, which keep training; store a copy.
                best_model_state = {name: t.detach().clone() for name, t in model.state_dict().items()}
                logging.info(f"New best model saved with accuracy: {best_accuracy:.4f}")

    return best_model_state, best_accuracy

def evaluate_model(model, data, device):
    """Return the classification accuracy of ``model`` on a full ``(X, X_spectral, y)`` tuple.

    The whole tuple goes through the model in a single forward pass (no batching), in
    eval mode and without gradient tracking. The predicted class is the index of the
    maximum along dimension 1 of the output. The model is left in eval mode.
    """
    model.eval()
    features, spectral, labels = data
    with torch.no_grad():
        logits = model(features.to(device), spectral.to(device))
        predictions = torch.max(logits, 1).indices
        return accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())

def distill_knowledge(teacher_model, student_model, train_loader, val_data, device, num_epochs=50, log_interval=5,
                      temperature=2.0, alpha=0.0, lr=1e-3, weight_decay=1e-2):
    """Distill ``teacher_model`` into ``student_model`` by matching temperature-softened outputs.

    The soft loss is ``KLDivLoss(log_softmax(student / T), softmax(teacher / T))``
    (batch-mean), multiplied by ``T**2``. This keeps its gradient scale independent of the
    temperature and comparable to the hard-label loss (Hinton et al., 2015). With
    ``alpha > 0`` the loss becomes ``alpha * CE(student, y) + (1 - alpha) * soft``.
    Defaults match the previously hard-coded setup: T = 2, pure soft targets, AdamW with
    lr 1e-3 and weight decay 1e-2, and 5 epochs of warmup counted in batches.

    Args:
        teacher_model: frozen model; put in eval mode and run under ``torch.no_grad``.
        student_model: model trained in place.
        train_loader: iterable of ``(batch_x, batch_x_spectral, batch_y)`` batches with a ``len``.
        val_data: ``(X, X_spectral, y)`` tuple handed to ``evaluate_model``.
        device: device the batches are moved to.
        num_epochs: number of passes over ``train_loader``.
        log_interval: evaluate/log every this many epochs, and always after the last one.
        temperature: softmax temperature T applied to both teacher and student logits.
        alpha: weight of the hard-label cross-entropy term in [0, 1].
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        ``student_model`` after training.
    """
    optimizer = optim.AdamW(student_model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = get_scheduler(optimizer, num_warmup_steps=len(train_loader) * 5, num_training_steps=len(train_loader) * num_epochs)
    soft_criterion = nn.KLDivLoss(reduction='batchmean')
    hard_criterion = nn.CrossEntropyLoss()

    teacher_model.eval()
    overall_progress = tqdm(total=num_epochs, desc="Overall Distillation Progress", position=0)

    for epoch in range(num_epochs):
        student_model.train()
        running_loss = 0.0

        epoch_progress = tqdm(train_loader, desc=f"Distillation Epoch {epoch+1}/{num_epochs}", position=1, leave=False)
        for step, (batch_x, batch_x_spectral, batch_y) in enumerate(epoch_progress, start=1):
            batch_x, batch_x_spectral, batch_y = batch_x.to(device), batch_x_spectral.to(device), batch_y.to(device)

            with torch.no_grad():
                teacher_outputs = teacher_model(batch_x, batch_x_spectral)

            student_outputs = student_model(batch_x, batch_x_spectral)

            soft_loss = soft_criterion(
                F.log_softmax(student_outputs / temperature, dim=1),
                F.softmax(teacher_outputs / temperature, dim=1),
            ) * (temperature ** 2)
            if alpha > 0:
                loss = alpha * hard_criterion(student_outputs, batch_y) + (1 - alpha) * soft_loss
            else:
                loss = soft_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

            # Average over the batches seen so far; tqdm refreshes its .n counter lazily,
            # so it is not a reliable batch count inside the loop.
            epoch_progress.set_postfix({'loss': f'{running_loss/step:.4f}', 'lr': f'{optimizer.param_groups[0]["lr"]:.2e}'})

        # Evaluate and log every log_interval epochs
        if (epoch + 1) % log_interval == 0 or epoch == num_epochs - 1:
            accuracy = evaluate_model(student_model, val_data, device)
            logging.info(f"Distillation Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}")

        overall_progress.update(1)

    overall_progress.close()
    return student_model


In [None]:
# Paths and switches for this run: where the tuned hyperparameters and original
# checkpoint live, where new models are written, and where the preprocessed data is.
config = dict(
    previous_model_path='./models/original/',
    new_model_path='./models/new/',
    best_params_name='best_params_ensemble.json',
    best_model_name='best_ensemble_model.pth',
    use_pretrained=True,
    pretrained_weights_path='./models/original/best_ensemble_model.pth',
    preprocessed_data_path='./preprocessed_data.pt',
)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Load data and parameters
data_dict, best_params = load_data_and_params(config)

# Extract train/test splits: inputs, spectral features, labels
X_train, X_train_spectral, y_train = data_dict['X_train'], data_dict['X_train_spectral'], data_dict['y_train']
X_test, X_test_spectral, y_test = data_dict['X_test'], data_dict['X_test_spectral'], data_dict['y_test']

# Print shapes for verification
# print(f"X_train shape: {X_train.shape}")
# print(f"X_train_spectral shape: {X_train_spectral.shape}")
# print(f"y_train shape: {y_train.shape}")
# print(f"X_test shape: {X_test.shape}")
# print(f"X_test_spectral shape: {X_test_spectral.shape}")
# print(f"y_test shape: {y_test.shape}")

# Initialize model with best parameters
ensemble_model = EnsembleModel(best_params).to(device)

# Print the structure of the new model
# print("\nModel Structure:")
# print_model_structure(ensemble_model)

# Load pre-trained weights if available
if config['use_pretrained'] and os.path.exists(config['pretrained_weights_path']):
    # Read the checkpoint outside the try so a failed read reports its real error
    # (previously the except branch then hit a NameError on `state_dict`).
    # map_location lets a GPU-saved checkpoint open on a CPU-only machine.
    state_dict = torch.load(config['pretrained_weights_path'], map_location=device)
    try:
        ensemble_model.load_state_dict(state_dict)
        logging.info("Successfully loaded pre-trained weights")
    except RuntimeError as e:
        # Key mismatch: retry non-strictly and report which keys did not line up.
        logging.error(f"Error loading pre-trained weights: {e}")
        logging.info("Attempting to load with strict=False")
        incompatible_keys = ensemble_model.load_state_dict(state_dict, strict=False)
        logging.info(f"Missing Keys: {incompatible_keys.missing_keys}")
        logging.info(f"Unexpected Keys: {incompatible_keys.unexpected_keys}")
else:
    logging.info("Starting with fresh weights")



In [None]:
logging.info("Starting training process...")
# One tick per stage: LR finding, ensemble training, diverse-ensemble training, knowledge distillation.
overall_steps = 4
overall_progress = tqdm(desc="Overall Training Progress", total=overall_steps, position=0)

In [None]:

# Shuffled mini-batch loader over the training tensors
# (batch size fixed at 32; consider deriving it from best_params).
train_loader = DataLoader(
    dataset=TensorDataset(X_train, X_train_spectral, y_train),
    batch_size=32,
    shuffle=True,
)

# LR range test: the optimizer starts at the sweep's lower bound and find_lr ramps it up.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(ensemble_model.parameters(), lr=1e-8, weight_decay=1e-2)

logging.info("Finding best learning rate...")
best_lr = find_lr(
    ensemble_model, train_loader, optimizer, criterion, device,
    num_iter=100, start_lr=1e-8, end_lr=1,
)
overall_progress.update(1)



In [None]:
# Set up optimizer and scheduler with best learning rate
optimizer = optim.AdamW(ensemble_model.parameters(), lr=best_lr, weight_decay=1e-2)
num_epochs = 100  # Adjust as needed
num_warmup_steps = len(train_loader) * 5  # 5 epochs of warmup, counted in batches
num_training_steps = len(train_loader) * num_epochs
scheduler = get_scheduler(optimizer, num_warmup_steps, num_training_steps)

# Train model
logging.info("Training ensemble model...")
best_model_state, best_accuracy = train_model(
    ensemble_model, train_loader, (X_test, X_test_spectral, y_test),
    optimizer, scheduler, criterion, device, epochs=num_epochs
)
overall_progress.update(1)

# Save best model (torch.save does not create missing parent directories)
if best_model_state is not None:
    os.makedirs(config['new_model_path'], exist_ok=True)
    torch.save(best_model_state, os.path.join(config['new_model_path'], config['best_model_name']))
    logging.info(f"Best ensemble model saved. Final accuracy: {best_accuracy:.4f}")
else:
    logging.error("Training failed due to NaN loss.")


In [None]:
# Train diverse ensemble
diverse_ensemble = DiverseEnsembleModel(best_params).to(device)
diverse_optimizer = optim.AdamW(diverse_ensemble.parameters(), lr=best_lr, weight_decay=1e-2)
diverse_scheduler = get_scheduler(diverse_optimizer, num_warmup_steps, num_training_steps)


logging.info("Training diverse ensemble model...")
diverse_best_state, diverse_accuracy = train_model(
    diverse_ensemble, train_loader, (X_test, X_test_spectral, y_test),
    diverse_optimizer, diverse_scheduler, criterion, device, epochs=num_epochs
)
overall_progress.update(1)

# Only persist a real checkpoint; train_model returns None when it hit a NaN loss.
if diverse_best_state is not None:
    os.makedirs(config['new_model_path'], exist_ok=True)
    torch.save(diverse_best_state, os.path.join(config['new_model_path'], 'best_diverse_ensemble_model.pth'))
    logging.info(f"Best diverse ensemble model saved. Final accuracy: {diverse_accuracy:.4f}")
else:
    logging.error("Diverse ensemble training failed due to NaN loss.")

# Distill from the best ensemble checkpoint rather than whatever the last epoch left behind.
if best_model_state is not None:
    ensemble_model.load_state_dict(best_model_state)
single_model = ImprovedSleepdetector(**best_params).to(device)

logging.info("Performing knowledge distillation...")
distilled_model = distill_knowledge(ensemble_model, single_model, train_loader, (X_test, X_test_spectral, y_test), device)
overall_progress.update(1)


os.makedirs(config['new_model_path'], exist_ok=True)
torch.save(distilled_model.state_dict(), os.path.join(config['new_model_path'], 'distilled_model.pth'))
overall_progress.close()

In [None]:
# Final evaluation: restore each ensemble's best checkpoint (when training produced one) before scoring.
if best_model_state is not None:
    ensemble_model.load_state_dict(best_model_state)
else:
    logging.warning("No best ensemble checkpoint available; evaluating current ensemble weights.")
final_accuracy = evaluate_model(ensemble_model, (X_test, X_test_spectral, y_test), device)

if diverse_best_state is not None:
    diverse_ensemble.load_state_dict(diverse_best_state)
else:
    logging.warning("No best diverse ensemble checkpoint available; evaluating current diverse ensemble weights.")
diverse_final_accuracy = evaluate_model(diverse_ensemble, (X_test, X_test_spectral, y_test), device)

distilled_accuracy = evaluate_model(distilled_model, (X_test, X_test_spectral, y_test), device)


logging.info(f"Training completed. Best accuracy: {best_accuracy:.4f}")
logging.info(f"Ensemble Model - Final Test Accuracy: {final_accuracy:.4f}")
logging.info(f"Diverse Ensemble Model - Final Test Accuracy: {diverse_final_accuracy:.4f}")
logging.info(f"Distilled Model - Final Test Accuracy: {distilled_accuracy:.4f}")