In [None]:
# Imports
import scipy.io as sio
import matplotlib.pyplot as plt
from sklearn.metrics import cohen_kappa_score, accuracy_score, confusion_matrix
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sleepdetector_new import ImprovedSleepdetector
from tqdm import tqdm
import seaborn as sns
from sklearn.model_selection import train_test_split, KFold
from scipy.signal import welch
from imblearn.over_sampling import SMOTE
from torch.optim.lr_scheduler import ReduceLROnPlateau
import optuna
import os
import logging
import json
from torch.optim.lr_scheduler import SequentialLR
from torch.amp import autocast, GradScaler
from scipy.interpolate import CubicSpline
from torch_lr_finder import LRFinder
import torch.nn.functional as F
# from sleepdetector_newest import ImprovedSleepdetector
from sleepdetector_new import ImprovedSleepdetector
# from sleepdetector_old import ImprovedSleepdetector
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR, LinearLR
from torch.optim.lr_scheduler import SequentialLR
from sklearn.metrics import accuracy_score, cohen_kappa_score
import torch.nn.functional as F
import math
from collections import Counter

# Set up logging: INFO level, each record prefixed with timestamp and level name
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed the NumPy and PyTorch random generators and make cuDNN deterministic.

    :param seed: Integer seed applied to every generator (default 42).
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # Seed the current CUDA device, then all CUDA devices.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

set_seed()


In [None]:
class EnsembleModel(nn.Module):
    """Ensemble that averages the outputs of several ImprovedSleepdetector members."""

    def __init__(self, model_params, n_models=3):
        """Build ``n_models`` members, each constructed with the same keyword arguments.

        :param model_params: Mapping unpacked as keyword arguments for each member
        :param n_models: Number of ensemble members to create
        """
        super().__init__()
        members = []
        for _ in range(n_models):
            members.append(ImprovedSleepdetector(**model_params))
        self.models = nn.ModuleList(members)

    def forward(self, x, spectral_features):
        """Run every member on copies of the inputs and return the element-wise mean output."""
        member_outputs = []
        for member in self.models:
            # Each member gets its own clone of both inputs.
            member_outputs.append(member(x.clone(), spectral_features.clone()))
        stacked = torch.stack(member_outputs)
        return torch.mean(stacked, dim=0)

class DiverseEnsembleModel(nn.Module):
    """Ensemble whose members differ only in their dropout setting; outputs are averaged."""

    def __init__(self, model_params, n_models=3):
        """Build ``n_models`` members with dropout scaled to 1/n, 2/n, ..., n/n of the base value.

        :param model_params: Mapping of member keyword arguments; must contain 'dropout'
        :param n_models: Number of ensemble members to create
        """
        super().__init__()
        members = []
        for rank in range(1, n_models + 1):
            member_params = dict(model_params)
            # Member number `rank` uses base_dropout * rank / n_models.
            member_params['dropout'] = model_params['dropout'] * rank / n_models
            members.append(ImprovedSleepdetector(**member_params))
        self.models = nn.ModuleList(members)

    def forward(self, x, spectral_features):
        """Run every member on copies of the inputs and return the element-wise mean output."""
        member_outputs = []
        for member in self.models:
            member_outputs.append(member(x.clone(), spectral_features.clone()))
        stacked = torch.stack(member_outputs)
        return torch.mean(stacked, dim=0)

In [None]:
def load_best_params(file_path):
    """Read a JSON file and return the value stored under 'best_model_params'.

    :param file_path: Path of the JSON file to read
    :return: The object found under the 'best_model_params' key
    """
    with open(file_path, 'r') as handle:
        return json.load(handle)['best_model_params']



def load_data_and_params(config):
    """Load the preprocessed data object and the saved best model parameters.

    :param config: Mapping providing 'preprocessed_data_path', 'previous_model_path'
                   and 'best_params_name'
    :return: Tuple (object returned by torch.load, best model parameters)
    """
    data_dict = torch.load(config['preprocessed_data_path'])
    params_file = os.path.join(config['previous_model_path'], config['best_params_name'])
    return data_dict, load_best_params(params_file)


def print_model_structure(model):
    """Print one '<name>: <shape>' line for every named parameter of ``model``."""
    for entry in model.named_parameters():
        param_name, tensor = entry
        print(f"{param_name}: {tensor.shape}")

def load_data(filepath, add_dim=False):
    """Load the four signals and the labels from a .mat file as PyTorch tensors.

    :param filepath: Path of the .mat file holding 'sig1'..'sig4' and 'labels'
    :param add_dim: If True, insert an extra dimension at position 1 of the signal tensor
    :return: Tuple (x, y) with float signals and long labels; epochs labelled -1 are dropped
    :raises Exception: Any loading error is logged and re-raised unchanged
    """
    try:
        contents = sio.loadmat(filepath)

        # Stack the four signals along a new axis 1.
        signal_keys = ('sig1', 'sig2', 'sig3', 'sig4')
        stacked = np.stack(tuple(contents[key] for key in signal_keys), axis=1)
        x = torch.from_numpy(stacked).float()

        y = torch.from_numpy(contents['labels'].flatten()).long()

        # Keep only the epochs whose label is not -1.
        keep = y != -1
        x, y = x[keep], y[keep]

        # A 2-D signal tensor gets a singleton axis inserted at position 1.
        if x.dim() == 2:
            x = x.unsqueeze(1)

        if add_dim:
            x = x.unsqueeze(1)

        print(f"Loaded data shape: {x.shape}, Labels shape: {y.shape}")
        return x, y

    except Exception as e:
        logging.error(f"Error loading data: {e}")
        raise


# def extract_spectral_features(x):
#     features = []
#     for epoch in x:
#         epoch_features = []
#         for channel in epoch:
#             # Check if channel is a PyTorch tensor, if so convert to numpy array
#             if isinstance(channel, torch.Tensor):
#                 channel = channel.numpy()
#             f, psd = welch(channel, fs=100, nperseg=1000)
#             delta = np.sum(psd[(f >= 0.5) & (f <= 4)])
#             theta = np.sum(psd[(f > 4) & (f <= 8)])
#             alpha = np.sum(psd[(f > 8) & (f <= 13)])
#             beta = np.sum(psd[(f > 13) & (f <= 30)])
#             epoch_features.extend([delta, theta, alpha, beta])
#         features.append(epoch_features)
#     return np.array(features)

def extract_spectral_features(x):
    """Compute summed Welch PSD in four frequency bands for every row of ``x``.

    :param x: Tensor whose first axis indexes channels; each row is one signal
    :return: 1-D NumPy array ordered [delta, theta, alpha, beta] per channel, channels in order
    """
    band_powers = []
    for row in range(x.shape[0]):
        # scipy.signal.welch works on NumPy arrays.
        signal = x[row].cpu().numpy()
        freqs, power = welch(signal, fs=100, nperseg=min(1000, len(signal)))

        in_delta = (freqs >= 0.5) & (freqs <= 4)
        in_theta = (freqs > 4) & (freqs <= 8)
        in_alpha = (freqs > 8) & (freqs <= 13)
        in_beta = (freqs > 13) & (freqs <= 30)

        for band_mask in (in_delta, in_theta, in_alpha, in_beta):
            band_powers.append(np.sum(power[band_mask]))
    return np.array(band_powers)





def time_warp(x, sigma=0.2, knot=4):
    """Randomly warp each series of ``x`` along axis 1 using a smooth cubic-spline distortion.

    :param x: NumPy array indexed as (series, steps, dims); warping is applied along axis 1
    :param sigma: Standard deviation of the random warp factors (drawn around 1.0)
    :param knot: Number of interior spline knots; knot + 2 control points are used
    :return: Array with the same shape and dtype as ``x`` holding the warped series
    """
    n_series, n_steps, n_dims = x.shape[0], x.shape[1], x.shape[2]
    last_step = n_steps - 1
    orig_steps = np.arange(n_steps)

    # One random factor per (series, control point, dim).
    random_warps = np.random.normal(loc=1.0, scale=sigma, size=(n_series, knot + 2, n_dims))
    # Control points are evenly spaced over the step axis and shared by every dim.
    knot_positions = np.linspace(0, n_steps - 1., num=knot + 2)

    ret = np.zeros_like(x)
    for i in range(n_series):
        pat = x[i]
        for dim in range(n_dims):
            warped_knots = knot_positions * random_warps[i, :, dim]
            warped_axis = CubicSpline(knot_positions, warped_knots)(orig_steps)
            # Rescale so the warped axis ends on the last step, then clamp into range.
            scale = last_step / warped_axis[-1]
            sample_points = np.clip(scale * warped_axis, 0, last_step)
            ret[i, :, dim] = np.interp(orig_steps, sample_points, pat[:, dim])
    return ret




def augment_minority_classes(x, x_spectral, y, minority_classes):
    """Append one time-warped copy of every sample whose label is a minority class.

    Every original sample is kept. For each sample labelled with a class in
    ``minority_classes`` a warped copy is inserted directly after it.

    :param x: Tensor of epochs; 2-D epochs are treated as (channels, samples)
    :param x_spectral: Tensor of spectral features, one row per epoch
    :param y: Tensor of labels, one per epoch
    :param minority_classes: Collection of labels that should be augmented
    :return: Tuple (augmented_x, augmented_x_spectral, augmented_y) of stacked tensors
    """
    augmented_x = []
    augmented_x_spectral = []
    augmented_y = []
    for i in range(len(y)):
        sample = x[i]
        augmented_x.append(sample)
        augmented_x_spectral.append(x_spectral[i])
        augmented_y.append(y[i])
        if y[i] in minority_classes:
            if sample.dim() == 2:
                # time_warp warps along axis 1 of a (series, steps, dims) array.
                # An epoch is (channels, samples), so put the samples on the step
                # axis before warping; otherwise the "warp" would run across the
                # channels instead of along time. Transpose back afterwards.
                steps_first = sample.transpose(0, 1).unsqueeze(0).numpy()
                warped = time_warp(steps_first, sigma=0.3, knot=5)
                augmented = torch.from_numpy(warped).squeeze(0).transpose(0, 1).contiguous()
            else:
                # Other layouts are passed through unchanged, as before.
                augmented = torch.from_numpy(time_warp(sample.unsqueeze(0).numpy(), sigma=0.3, knot=5)).squeeze(0)
            augmented_x.append(augmented)
            # NOTE(review): the spectral features of the unwarped epoch are reused for
            # the warped copy; confirm whether they should be recomputed instead.
            augmented_x_spectral.append(x_spectral[i])
            augmented_y.append(y[i])
    return torch.stack(augmented_x), torch.stack(augmented_x_spectral), torch.tensor(augmented_y)


# def prepare_data(x, y, test_size=0.2, split=True):
#     """
#     Prepare data for training or testing.
    
#     :param x: Input data tensor
#     :param y: Labels tensor
#     :param test_size: Proportion of the dataset to include in the test split
#     :param split: If True, split the data into train and test sets. If False, process all data without splitting.
#     :return: Processed data tensors
#     """
#     if split:
#         X_train, X_test, y_train, y_test = train_test_split(x.numpy(), y.numpy(), test_size=test_size, stratify=y, random_state=42)
        
#         X_train_spectral = extract_spectral_features(torch.from_numpy(X_train))
#         X_test_spectral = extract_spectral_features(torch.from_numpy(X_test))
        
#         X_train_combined = np.concatenate([X_train.reshape(X_train.shape[0], -1), X_train_spectral], axis=1)
#         X_test_combined = np.concatenate([X_test.reshape(X_test.shape[0], -1), X_test_spectral], axis=1)
        
#         smote = SMOTE(random_state=42)
#         X_train_resampled, y_train_resampled = smote.fit_resample(X_train_combined, y_train)
        
#         original_shape = list(X_train.shape)
#         original_shape[0] = X_train_resampled.shape[0]
#         spectral_shape = (X_train_resampled.shape[0], X_train_spectral.shape[1])
        
#         X_train_final = X_train_resampled[:, :-X_train_spectral.shape[1]].reshape(original_shape)
#         X_train_spectral_final = X_train_resampled[:, -X_train_spectral.shape[1]:].reshape(spectral_shape)
        
#         return (torch.from_numpy(X_train_final).float(),
#                 torch.from_numpy(X_train_spectral_final).float(),
#                 torch.from_numpy(y_train_resampled).long(),
#                 torch.from_numpy(X_test).float(),
#                 torch.from_numpy(X_test_spectral).float(),
#                 torch.from_numpy(y_test).long())
#     else:
#         X_spectral = extract_spectral_features(x)
        
#         return (x.float(),
#                 torch.from_numpy(X_spectral).float(),
#                 y.long())
    
def prepare_data(x, y, test_size=0.2):
    """Split the data, extract spectral features and oversample the training set with SMOTE.

    :param x: Input data tensor (3-D: epochs first)
    :param y: Labels tensor
    :param test_size: Proportion of the dataset to include in the test split
    :return: Tuple (X_train, X_train_spectral, y_train, X_test, X_test_spectral, y_test) of tensors;
             only the training part is resampled
    """
    # scikit-learn and SMOTE operate on NumPy arrays.
    features_np, labels_np = x.cpu().numpy(), y.cpu().numpy()

    # Stratified split with a fixed random state.
    X_train, X_test, y_train, y_test = train_test_split(
        features_np, labels_np, test_size=test_size, stratify=labels_np, random_state=42
    )

    def _spectral(batch):
        # One spectral feature vector per epoch.
        as_tensor = torch.from_numpy(batch).float()
        return np.array([extract_spectral_features(epoch) for epoch in as_tensor])

    X_train_spectral = _spectral(X_train)
    X_test_spectral = _spectral(X_test)

    print("Original train set class distribution:")
    print(Counter(y_train))

    # SMOTE needs one flat row per sample: raw signal columns followed by spectral columns.
    flat_raw = X_train.reshape(X_train.shape[0], -1)
    flat_spectral = X_train_spectral.reshape(X_train_spectral.shape[0], -1)
    combined = np.hstack((flat_raw, flat_spectral))

    oversampler = SMOTE(sampling_strategy='auto', random_state=42)
    X_resampled, y_resampled = oversampler.fit_resample(combined, y_train)

    print("After SMOTE train set class distribution:")
    print(Counter(y_resampled))

    # Split the flat rows back into their raw and spectral parts and restore the shapes.
    split_at = flat_raw.shape[1]
    raw_part = X_resampled[:, :split_at]
    spectral_part = X_resampled[:, split_at:]
    X_train_resampled = raw_part.reshape(-1, X_train.shape[1], X_train.shape[2])
    X_train_spectral_resampled = spectral_part.reshape(-1, X_train_spectral.shape[1])

    return (
        torch.from_numpy(X_train_resampled).float(),
        torch.from_numpy(X_train_spectral_resampled).float(),
        torch.from_numpy(y_resampled).long(),
        torch.from_numpy(X_test).float(),
        torch.from_numpy(X_test_spectral).float(),
        torch.from_numpy(y_test).long(),
    )






def get_scheduler(optimizer, num_warmup_steps, num_training_steps, min_lr=1e-6):
    """Create a LambdaLR schedule with linear warmup followed by linear decay.

    The factor rises linearly from 0 to 1 over ``num_warmup_steps`` and then
    decays linearly towards 0 at ``num_training_steps``. After warmup the
    learning rate is not allowed to fall below ``min_lr``.

    :param optimizer: Optimizer whose parameter groups are scheduled
    :param num_warmup_steps: Number of scheduler steps spent warming up
    :param num_training_steps: Step at which the linear decay reaches zero
    :param min_lr: Absolute lower bound for the learning rate after warmup
    :return: A torch.optim.lr_scheduler.LambdaLR instance
    """
    def make_lambda(base_lr):
        # LambdaLR multiplies the base learning rate by the returned factor, so the
        # absolute floor has to be expressed relative to each group's base rate.
        # The floor is capped at 1.0 so it never pushes the rate above the base rate.
        floor = min(1.0, min_lr / base_lr) if base_lr > 0 else 0.0

        def lr_lambda(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            decay = float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
            return max(floor, decay)

        return lr_lambda

    # One lambda per parameter group, since groups may have different base rates.
    lambdas = [
        make_lambda(group.get('initial_lr', group['lr']))
        for group in optimizer.param_groups
    ]
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambdas)


def initialize_model(config, best_model_params, device):
    """Create a 3-member EnsembleModel and optionally load compatible pre-trained weights.

    :param config: Mapping with 'use_pretrained' and, when that is truthy, 'pretrained_weights_path'
    :param best_model_params: Keyword arguments for the ensemble members
    :param device: Device the model is moved to and the weights are mapped to
    :return: The (possibly partially pre-loaded) ensemble model
    """
    ensemble_model = EnsembleModel(best_model_params, n_models=3).to(device)

    if not config['use_pretrained']:
        logging.info("Initializing with random weights.")
        return ensemble_model

    if not os.path.exists(config['pretrained_weights_path']):
        logging.warning(f"Pre-trained weights file not found at {config['pretrained_weights_path']}. Initializing with random weights.")
        return ensemble_model

    state_dict = torch.load(config['pretrained_weights_path'], map_location=device)

    # Keep only entries whose key exists in the model with an identical shape.
    model_dict = ensemble_model.state_dict()
    pretrained_dict = {}
    for key, value in state_dict.items():
        if key in model_dict and value.shape == model_dict[key].shape:
            pretrained_dict[key] = value

    # Overlay the accepted entries on the current state and load the result.
    model_dict.update(pretrained_dict)
    ensemble_model.load_state_dict(model_dict, strict=False)

    # Report what was loaded and which model entries kept their initial values.
    loaded_keys = set(pretrained_dict.keys())
    missing_keys = set(model_dict.keys()) - loaded_keys

    logging.info(f"Loaded pre-trained weights from {config['pretrained_weights_path']}")
    logging.info(f"Number of loaded parameters: {len(loaded_keys)}")
    logging.info(f"Number of missing parameters: {len(missing_keys)}")
    if missing_keys:
        logging.warning(f"Missing keys: {missing_keys}")

    return ensemble_model


def load_params_and_initialize_model(config, device):
    """Read the saved parameter JSON and build the ensemble model from it.

    :param config: Mapping with 'new_model_path' and 'best_params_name' (plus whatever
                   initialize_model reads)
    :param device: Device passed on to initialize_model
    :return: Tuple (ensemble_model, best_model_params, best_train_params)
    :raises FileNotFoundError: If the parameter file does not exist
    :raises json.JSONDecodeError: If the parameter file is not valid JSON
    :raises KeyError: If an expected key is missing
    """
    params_path = os.path.join(config['new_model_path'], config['best_params_name'])

    try:
        with open(params_path, 'r') as handle:
            saved = json.load(handle)

        model_kwargs = saved['best_model_params']
        train_kwargs = saved['best_train_params']

        print("Parameters loaded successfully.")
        print(f"Best model parameters: {model_kwargs}")
        print(f"Best training parameters: {train_kwargs}")

        return initialize_model(config, model_kwargs, device), model_kwargs, train_kwargs

    # Each failure is reported on stdout and then re-raised for the caller.
    except FileNotFoundError:
        print(f"Error: The file {params_path} was not found.")
        raise
    except json.JSONDecodeError:
        print(f"Error: The file {params_path} is not a valid JSON file.")
        raise
    except KeyError as e:
        print(f"Error: The key {e} was not found in the loaded parameters.")
        raise



def find_lr(model, train_loader, optimizer, criterion, device, num_iter=100, start_lr=1e-8, end_lr=1):
    """Run an exponential learning-rate sweep and suggest a learning rate.

    The learning rate is multiplied by a constant factor after every batch so that
    it moves from ``start_lr`` to ``end_lr`` in ``num_iter`` steps. The smoothed loss
    is recorded per step, a plot is written to 'lr_finder_plot.png', and the rate at
    the steepest loss decrease (scaled by 0.1) is returned.

    The sweep performs real optimizer steps, so ``model`` and ``optimizer`` are
    modified by this call.

    :param model: Model called as model(inputs, spectral_features)
    :param train_loader: Iterable of (inputs, spectral_features, targets) batches
    :param optimizer: Optimizer whose learning rate is swept (all parameter groups)
    :param criterion: Loss function called as criterion(outputs, targets)
    :param device: Device the batches are moved to
    :param num_iter: Maximum number of sweep steps
    :param start_lr: First learning rate of the sweep
    :param end_lr: Learning rate the sweep would reach after ``num_iter`` steps
    :return: Suggested learning rate; ``start_lr`` if fewer than two points were recorded
    :raises ValueError: If ``train_loader`` yields no batches
    """
    logging.info("Starting learning rate finder...")
    model.train()
    update_step = (end_lr / start_lr) ** (1 / num_iter)
    lr = start_lr

    def _apply_lr(value):
        # Sweep every parameter group, not only the first one.
        for group in optimizer.param_groups:
            group["lr"] = value

    _apply_lr(lr)
    running_loss = 0
    best_loss = float('inf')
    batch_num = 0
    losses = []
    log_lrs = []

    # Keep a single iterator alive so consecutive steps see consecutive batches;
    # it is only recreated once the loader is exhausted.
    batch_iter = iter(train_loader)

    progress_bar = tqdm(range(num_iter), desc="Finding best LR")
    for _ in progress_bar:
        try:
            batch = next(batch_iter)
        except StopIteration:
            batch_iter = iter(train_loader)
            try:
                batch = next(batch_iter)
            except StopIteration:
                raise ValueError("train_loader yielded no batches") from None

        inputs, spectral_features, targets = batch
        inputs, spectral_features, targets = inputs.to(device), spectral_features.to(device), targets.to(device)
        batch_num += 1

        optimizer.zero_grad()
        outputs = model(inputs, spectral_features)
        loss = criterion(outputs, targets)

        # Exponentially smoothed loss with bias correction
        running_loss = 0.98 * running_loss + 0.02 * loss.item()
        smoothed_loss = running_loss / (1 - 0.98**batch_num)

        # Record the best loss
        if smoothed_loss < best_loss:
            best_loss = smoothed_loss

        # Stop if the loss is exploding (or is no longer a finite number)
        if not math.isfinite(smoothed_loss) or (batch_num > 1 and smoothed_loss > 4 * best_loss):
            logging.info(f"Loss is exploding, stopping early at lr={lr:.2e}")
            break

        # Store the values
        losses.append(smoothed_loss)
        log_lrs.append(math.log10(lr))

        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        lr *= update_step
        _apply_lr(lr)

        progress_bar.set_postfix({'loss': f'{smoothed_loss:.4f}', 'lr': f'{lr:.2e}'})

    # Drop the noisy first 10 and last 5 points when enough were recorded;
    # otherwise use everything so a short sweep does not leave an empty curve.
    trim_head, trim_tail = 10, 5
    if len(losses) > trim_head + trim_tail + 1:
        curve_lrs = log_lrs[trim_head:-trim_tail]
        curve_losses = losses[trim_head:-trim_tail]
    else:
        curve_lrs, curve_losses = log_lrs, losses

    plt.figure(figsize=(10, 6))
    plt.plot(curve_lrs, curve_losses)
    plt.xlabel("Log Learning Rate")
    plt.ylabel("Loss")
    plt.title("Learning Rate vs. Loss")
    plt.savefig('lr_finder_plot.png')
    plt.close()

    if len(curve_losses) < 2:
        # A slope needs at least two points; fall back to the sweep's starting rate.
        logging.warning("Learning rate finder recorded fewer than two points; returning start_lr.")
        return start_lr

    # Find the learning rate with the steepest negative gradient
    smoothed_losses = np.array(curve_losses)
    smoothed_lrs = np.array(curve_lrs)
    gradients = (smoothed_losses[1:] - smoothed_losses[:-1]) / (smoothed_lrs[1:] - smoothed_lrs[:-1])
    best_lr = 10 ** smoothed_lrs[np.argmin(gradients)]

    # Adjust the learning rate to be slightly lower than the one with steepest gradient
    best_lr *= 0.1

    logging.info(f"Learning rate finder completed. Suggested Learning Rate: {best_lr:.2e}")
    logging.info("Learning rate vs. loss plot saved as 'lr_finder_plot.png'")
    return best_lr


def get_class_weights(y):
    """Return inverse-frequency class weights normalised to sum to 1.

    Class indices that never occur in ``y`` receive weight 0 instead of the
    infinite weight that 1/0 would produce (which would turn the normalised
    vector into NaNs and zeros).

    :param y: 1-D tensor of non-negative integer class labels
    :return: Float tensor with one weight per class index 0..max(y)
    """
    class_counts = torch.bincount(y).float()
    present = class_counts > 0

    class_weights = torch.zeros_like(class_counts)
    class_weights[present] = 1. / class_counts[present]

    total = class_weights.sum()
    if total > 0:
        class_weights = class_weights / total
    return class_weights



# def train_model(model, train_loader, val_data, optimizer, scheduler, criterion, device, epochs=100):
#     best_accuracy = 0
#     best_model_state = None
    
#     for epoch in tqdm(range(epochs), desc="Training Progress"):
#         model.train()
#         running_loss = 0.0
#         for batch_x, batch_x_spectral, batch_y in train_loader:
#             batch_x, batch_x_spectral, batch_y = batch_x.to(device), batch_x_spectral.to(device), batch_y.to(device)
#             optimizer.zero_grad()
#             outputs = model(batch_x, batch_x_spectral)
#             loss = criterion(outputs, batch_y)
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#             optimizer.step()
#             running_loss += loss.item()
        
#         accuracy = evaluate_model(model, val_data, device)
        
#         if accuracy > best_accuracy:
#             best_accuracy = accuracy
#             best_model_state = model.state_dict()
        
#         scheduler.step(accuracy)
        
#         if (epoch + 1) % 10 == 0:
#             logging.info(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}")
    
#     return best_model_state, best_accuracy

# def train_model(model, train_loader, val_data, optimizer, scheduler, criterion, device, epochs=100):
#     best_accuracy = 0
#     best_model_state = None
    
#     for epoch in tqdm(range(epochs), desc="Training Progress"):
#         model.train()
#         running_loss = 0.0
#         for batch_x, batch_x_spectral, batch_y in train_loader:
#             batch_x, batch_x_spectral, batch_y = batch_x.to(device), batch_x_spectral.to(device), batch_y.to(device)
#             optimizer.zero_grad()
#             outputs = model(batch_x, batch_x_spectral)
#             loss = criterion(outputs, batch_y)
#             loss.backward()
#             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#             optimizer.step()
#             scheduler.step()  # Step the scheduler after each batch for OneCycleLR
#             running_loss += loss.item()
        
#         # Evaluate every epoch
#         accuracy = evaluate_model(model, val_data, device)
        
#         if accuracy > best_accuracy:
#             best_accuracy = accuracy
#             best_model_state = model.state_dict()
#             torch.save(best_model_state, f"checkpoint_epoch_{epoch}.pth")
        
#         logging.info(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}")
    
#     return best_model_state, best_accuracy



def train_model(model, train_loader, val_data, optimizer, scheduler, criterion, device, epochs=100):
    """Train ``model`` with mixed precision and return a snapshot of its best weights.

    After every epoch the model is scored with evaluate_model on ``val_data``; the
    weights of the epoch with the highest accuracy are kept. The scheduler is
    stepped once per batch.

    :param model: Model called as model(batch_x, batch_x_spectral)
    :param train_loader: Iterable of (batch_x, batch_x_spectral, batch_y) batches
    :param val_data: Tuple (X, X_spectral, y) passed to evaluate_model
    :param optimizer: Optimizer updating the model parameters
    :param scheduler: Learning-rate scheduler stepped after every batch
    :param criterion: Loss function called as criterion(outputs, batch_y)
    :param device: torch.device used for the batches and for autocast
    :param epochs: Number of training epochs
    :return: Tuple (best_model_state, best_accuracy); (None, 0) if NaN outputs or a
             non-finite gradient norm are encountered
    """
    # Gradient scaling is only enabled when training on CUDA.
    scaler = GradScaler(enabled=(device.type == "cuda"))
    best_accuracy = 0
    best_model_state = None

    for epoch in tqdm(range(epochs), desc="Training Progress"):
        model.train()
        running_loss = 0.0
        for batch_x, batch_x_spectral, batch_y in train_loader:
            batch_x, batch_x_spectral, batch_y = batch_x.to(device), batch_x_spectral.to(device), batch_y.to(device)

            optimizer.zero_grad()

            with autocast(device_type=device.type):
                outputs = model(batch_x, batch_x_spectral)

                if torch.isnan(outputs).any():
                    # Dump input and parameter ranges to help locate the source of the NaNs.
                    print("NaNs detected in model outputs!")
                    print(f"batch_x range: {batch_x.min().item()} to {batch_x.max().item()}")
                    print(f"batch_x_spectral range: {batch_x_spectral.min().item()} to {batch_x_spectral.max().item()}")
                    print(f"Model parameters:")
                    for name, param in model.named_parameters():
                        print(f"{name}: {param.data.min().item()} to {param.data.max().item()}")
                    return None, 0

                loss = criterion(outputs, batch_y)

            scaler.scale(loss).backward()

            # Unscale first so clipping operates on the true gradient magnitudes.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)

            # NOTE(review): with loss scaling enabled, an occasional overflow is
            # normally handled by scaler.step skipping the update; here training is
            # aborted instead. Confirm this is the intended behaviour.
            if torch.isnan(grad_norm) or torch.isinf(grad_norm):
                print(f"Gradient norm is NaN or Inf: {grad_norm}")
                return None, 0

            scaler.step(optimizer)
            scaler.update()

            scheduler.step()

            running_loss += loss.item()

        accuracy = evaluate_model(model, val_data, device)

        print(f"Epoch {epoch+1}/{epochs} - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}, Grad norm: {grad_norm:.4f}")

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # state_dict() returns references to the live tensors, which later
            # optimizer steps overwrite in place. Clone them so the snapshot really
            # holds the weights of the best epoch.
            best_model_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }

    return best_model_state, best_accuracy

def evaluate_model(model, data, device):
    """Return the accuracy of ``model`` on ``data`` using a single forward pass.

    The model is switched to eval mode and left in that mode.

    :param model: Model called as model(X, X_spectral)
    :param data: Tuple (X, X_spectral, y) of tensors
    :param device: Device the inputs are moved to
    :return: Accuracy as computed by sklearn's accuracy_score
    """
    model.eval()
    inputs, spectral, labels = data
    with torch.no_grad():
        logits = model(inputs.to(device), spectral.to(device))
        # The predicted class is the index of the largest output along dim 1.
        predicted = torch.max(logits, 1)[1]
        true_np = labels.cpu().numpy()
        pred_np = predicted.cpu().numpy()
        accuracy = accuracy_score(true_np, pred_np)
    return accuracy

def distill_knowledge(teacher_model, student_model, train_loader, val_data, device, num_epochs=50, log_interval=5):
    """Train ``student_model`` to match the temperature-softened outputs of ``teacher_model``.

    The loss is the KL divergence between the student's log-softmax and the teacher's
    softmax, both computed on outputs divided by a fixed temperature of 2.0. The
    labels in each batch are moved to the device but not used in the loss.

    :param teacher_model: Model providing target outputs (put into eval mode)
    :param student_model: Model being trained
    :param train_loader: Iterable of (batch_x, batch_x_spectral, batch_y) batches
    :param val_data: Tuple (X, X_spectral, y) passed to evaluate_model
    :param device: Device the batches are moved to
    :param num_epochs: Number of distillation epochs
    :param log_interval: Evaluate and log every this many epochs (and after the last one)
    :return: The trained student model
    """
    steps_per_epoch = len(train_loader)
    optimizer = optim.AdamW(student_model.parameters(), lr=1e-5, weight_decay=1e-2)
    scheduler = get_scheduler(
        optimizer,
        num_warmup_steps=steps_per_epoch * 5,
        num_training_steps=steps_per_epoch * num_epochs,
    )
    criterion = nn.KLDivLoss(reduction='batchmean')
    temperature = 2.0
    epsilon = 1e-8  # small constant added to the scaled outputs before (log-)softmax

    teacher_model.eval()
    overall_progress = tqdm(total=num_epochs, desc="Overall Distillation Progress", position=0)

    for epoch in range(num_epochs):
        student_model.train()
        running_loss = 0.0

        epoch_progress = tqdm(train_loader, desc=f"Distillation Epoch {epoch+1}/{num_epochs}", position=1, leave=False)
        for batch_x, batch_x_spectral, batch_y in epoch_progress:
            batch_x = batch_x.to(device)
            batch_x_spectral = batch_x_spectral.to(device)
            batch_y = batch_y.to(device)

            # Report (but do not stop on) non-finite values in either input.
            raw_bad = torch.isnan(batch_x).any() or torch.isinf(batch_x).any()
            if raw_bad:
                print("NaNs or Infs detected in input data!")

            spectral_bad = torch.isnan(batch_x_spectral).any() or torch.isinf(batch_x_spectral).any()
            if spectral_bad:
                print("NaNs or Infs detected in spectral input data!")

            # Teacher targets are computed without tracking gradients.
            with torch.no_grad():
                teacher_outputs = teacher_model(batch_x, batch_x_spectral)

            student_outputs = student_model(batch_x, batch_x_spectral)

            student_log_probs = F.log_softmax(student_outputs / temperature + epsilon, dim=1)
            teacher_probs = F.softmax(teacher_outputs / temperature + epsilon, dim=1)
            loss = criterion(student_log_probs, teacher_probs)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

            current_lr = optimizer.param_groups[0]["lr"]
            epoch_progress.set_postfix({'loss': f'{running_loss/(epoch_progress.n+1):.4f}', 'lr': f'{current_lr:.2e}'})

        # Evaluate and log every log_interval epochs, and always after the final epoch.
        is_last_epoch = epoch == num_epochs - 1
        if (epoch + 1) % log_interval == 0 or is_last_epoch:
            accuracy = evaluate_model(student_model, val_data, device)
            logging.info(f"Distillation Epoch {epoch+1}/{num_epochs} - Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}")

        overall_progress.update(1)

    overall_progress.close()
    return student_model


# Load Data

In [None]:
# Relative path of the preprocessed .mat file that load_data reads
preprocessed_file_name = './preprocessing/preprocessed_data.mat'

In [None]:
# Load the signals and labels (load_data already prints the loaded shapes once)
x, y = load_data(preprocessed_file_name)
print(f"Loaded data shape: {x.shape}, Labels shape: {y.shape}")

# Split into train/test, extract spectral features and oversample the training set with SMOTE
X_train, X_train_spectral, y_train, X_test, X_test_spectral, y_test = prepare_data(x, y)

print("After SMOTE:")
print(f"X_train shape: {X_train.shape}")
print(f"X_train_spectral shape: {X_train_spectral.shape}")
print(f"y_train shape: {y_train.shape}")

# Identify minority classes for augmentation: classes with fewer than half of the
# average per-class sample count.
# NOTE(review): prepare_data runs SMOTE with sampling_strategy='auto', which presumably
# leaves the classes balanced, so this list may always be empty here — confirm.
class_counts = Counter(y_train.numpy())
minority_classes = [cls for cls, count in class_counts.items() if count < len(y_train) / len(class_counts) * 0.5]

# Add one time-warped copy for every training sample of a minority class
X_train, X_train_spectral, y_train = augment_minority_classes(X_train, X_train_spectral, y_train, minority_classes)

print("After augmentation:")
print(f"X_train shape: {X_train.shape}")
print(f"X_train_spectral shape: {X_train_spectral.shape}")
print(f"y_train shape: {y_train.shape}")
print("Final class distribution:")
print(Counter(y_train.numpy()))

# Hyperparameter Optimization

In [None]:
# Paths and file names used when loading parameters/weights and saving results
config = {
    # Directory of the earlier model; source of the optional pre-trained weights
    'previous_model_path': './models/original/',
    # Directory the parameter JSON is read from and the trained model is written to
    'new_model_path': './models/new/',
    'best_model_name': 'best_ensemble_model.pth',
    'best_params_name': 'best_params_ensemble.json',
    'test_data_name': 'test_data.json',
    'confusion_matrix_norm': 'confusion_matrix_normalized.png',
    'confusion_matrix_non_norm': 'confusion_matrix_non_normalized.png',
    # 'initial_weights_name': 'best_ensemble_model.pth',
    'initial_weights_name': 'best_ensemble_model.pth',
    'use_pretrained': False,  # Set to True to use previous weights
}

# Ensure the model save directory exists

os.makedirs(config['new_model_path'], exist_ok=True)
# Full path of the pre-trained weights: previous model directory + best model file name
config['pretrained_weights_path'] = os.path.join(config['previous_model_path'], config['best_model_name'])

In [None]:
# Read the saved best parameters from config['new_model_path'] and build the ensemble on `device`
ensemble_model, best_model_params, best_train_params = load_params_and_initialize_model(config, device)

In [None]:
# def print_model_structure(model):
#     for name, param in model.named_parameters():
#         print(f"{name}: {param.shape}")

# # Before loading weights
# print("Current model structure:")
# print_model_structure(ensemble_model)

# # After loading weights
# print("\nLoaded model structure:")
# state_dict = torch.load(config['pretrained_weights_path'], map_location=device)
# for name, param in state_dict.items():
#     print(f"{name}: {param.shape}")

In [None]:
# # Ensure the new model save directory exists
# os.makedirs(config['new_model_path'], exist_ok=True)

# # Set the full path for the pretrained weights
# config['pretrained_weights_path'] = os.path.join(config['previous_model_path'], config['best_model_name'])

# # Save test data
# test_data = {
#     'X_test': X_test.tolist(),
#     'X_test_spectral': X_test_spectral.tolist(),
#     'y_test': y_test.tolist()
# }

# with open(os.path.join(config['new_model_path'], config['test_data_name']), 'w') as f:
#     json.dump(test_data, f)

# logging.info("Test data saved successfully")

# # Hyperparameter optimization
# study = optuna.create_study(direction='maximize')
# try:
#     study.optimize(lambda trial: objective(trial, X_train, X_train_spectral, y_train, X_test, X_test_spectral, y_test, device), 
#                n_trials=100, 
#                callbacks=[OptunaPruneCallback()],
#                show_progress_bar=True)
# except Exception as e:
#     logging.error(f"An error occurred during optimization: {e}")
#     raise

# logging.info(f"Best trial: {study.best_trial.number}")
# logging.info(f"Best value: {study.best_value:.4f}")


# best_params = study.best_params
# best_model_params = {k: v for k, v in best_params.items() if k in ['n_filters', 'lstm_hidden', 'lstm_layers', 'dropout']}
# best_train_params = {k: v for k, v in best_params.items() if k in ['lr', 'batch_size']}

# # Save the best parameters
# params_to_save = {
#     'best_params': best_params,
#     'best_model_params': best_model_params,
#     'best_train_params': best_train_params
# }

# with open(os.path.join(config['new_model_path'], config['best_params_name']), 'w') as f:
#     json.dump(params_to_save, f, indent=4)

# logging.info(f"Best parameters saved to {os.path.join(config['new_model_path'], config['best_params_name'])}")



# TRAINING ENSEMBLE

In [None]:
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Batch sampler that draws the same number of indices from every class.

    Each batch holds ``batch_size // n_classes`` indices per class. When a class
    cannot supply another full share, its index pool is reshuffled and reused from
    the start.
    """

    def __init__(self, labels, batch_size):
        """
        :param labels: 1-D array-like of class labels, one per dataset sample
        :param batch_size: Nominal batch size; must be at least the number of classes
        :raises ValueError: If ``batch_size`` is smaller than the number of classes
        """
        labels = np.asarray(labels)
        classes = np.unique(labels).tolist()
        if classes and batch_size < len(classes):
            # A per-class share of zero would only ever produce empty batches.
            raise ValueError(
                f"batch_size ({batch_size}) must be at least the number of classes ({len(classes)})"
            )
        self.labels = labels
        self.batch_size = batch_size
        # Positions of the samples belonging to each class.
        self.label_to_indices = {label: np.where(labels == label)[0] for label in classes}
        # Per class: how many of its indices were handed out since the last reshuffle.
        self.used_label_indices_count = {label: 0 for label in classes}
        self.count = 0
        self.n_classes = len(classes)
        self.n_samples = len(labels)

    def __iter__(self):
        """Yield lists of dataset indices; exactly ``len(self)`` batches per pass."""
        self.count = 0
        if self.n_classes == 0:
            return
        per_class = self.batch_size // self.n_classes
        # `<=` (not `<`) so the number of yielded batches equals __len__ even when
        # n_samples is an exact multiple of batch_size.
        while self.count + self.batch_size <= self.n_samples:
            indices = []
            for class_, class_indices in self.label_to_indices.items():
                start = self.used_label_indices_count[class_]
                indices.extend(class_indices[start:start + per_class].tolist())
                self.used_label_indices_count[class_] = start + per_class
                # Not enough unused indices left for another share: reshuffle and restart.
                if self.used_label_indices_count[class_] + per_class > len(class_indices):
                    np.random.shuffle(class_indices)
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Number of batches produced by one pass of ``__iter__``."""
        return self.n_samples // self.batch_size

In [None]:
logging.info("Starting training process...")
# Progress bar over the planned pipeline stages.
# NOTE(review): the diverse-ensemble and distillation cells below are commented out,
# so this bar appears to be advanced fewer than 4 times — confirm the intended total.
overall_steps = 4  # LR finding, Ensemble training, Diverse Ensemble training, Knowledge Distillation
overall_progress = tqdm(total=overall_steps, desc="Overall Training Progress", position=0)

In [None]:
# Class-balanced batches: the sampler decides which dataset indices form each batch
balanced_sampler = BalancedBatchSampler(y_train.numpy(), batch_size=best_train_params['batch_size'])
# Each batch yields (signals, spectral features, labels)
train_loader = DataLoader(TensorDataset(X_train, X_train_spectral, y_train), batch_sampler=balanced_sampler)

In [None]:
# Set up loss function: cross-entropy weighted by inverse class frequency
# (a small constant is added to every class weight)
class_weights = get_class_weights(y_train).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights + 1e-6)

# Find best learning rate
# NOTE(review): the find_lr call below is commented out, so no sweep is run and
# `initial_optimizer` is not used by any code in this cell.
initial_optimizer = optim.Adam(ensemble_model.parameters(), lr=best_train_params['lr'], weight_decay=1e-5)
logging.info("Finding best learning rate...")
# best_lr = find_lr(ensemble_model, train_loader, initial_optimizer, criterion, device, num_iter=100, start_lr=best_train_params['lr'], end_lr=1)
overall_progress.update(1)




In [None]:
# Sanity checks: the training inputs must not contain NaN or Inf values
assert not torch.isnan(X_train).any(), "NaN values found in X_train"
assert not torch.isinf(X_train).any(), "Inf values found in X_train"
assert not torch.isnan(X_train_spectral).any(), "NaN values found in X_train_spectral"
assert not torch.isinf(X_train_spectral).any(), "Inf values found in X_train_spectral"

In [None]:
num_epochs = 1000  # Adjust as needed
# Step counts for a warmup schedule; only the commented-out diverse-ensemble cell refers to them
num_warmup_steps = len(train_loader) * 5  # 5 epochs of warmup
num_training_steps = len(train_loader) * num_epochs

# Set up optimizer and scheduler with best learning rate
# best_lr = min(best_lr, best_train_params['lr']) * 0.1  # Reduce the LR slightly
# The LR finder is disabled, so use one tenth of the saved learning rate
best_lr = best_train_params['lr'] * 0.1  # Reduce the LR slightly

optimizer = optim.Adam(ensemble_model.parameters(), lr=best_lr, weight_decay=1e-5)
# One-cycle schedule sized for len(train_loader) scheduler steps per epoch
scheduler = OneCycleLR(optimizer, max_lr=best_lr, steps_per_epoch=len(train_loader), epochs=num_epochs)

# Train model
# The held-out test split doubles as the validation data used to pick the best epoch
logging.info("Training ensemble model...")
best_model_state, best_accuracy = train_model(
    ensemble_model, train_loader, (X_test, X_test_spectral, y_test),
    optimizer, scheduler, criterion, device, epochs=num_epochs
)
overall_progress.update(1)

# Save best model
# train_model returns None as the state when it aborted early
if best_model_state is not None:
    torch.save(best_model_state, os.path.join(config['new_model_path'], config['best_model_name']))
    logging.info(f"Best ensemble model saved. Final accuracy: {best_accuracy:.4f}")
else:
    logging.error("Training failed due to NaN loss.")

In [None]:
# Evaluate the model
# train_model returns None as the state when it aborted early; loading None would
# raise, so the final evaluation is skipped in that case.
if best_model_state is not None:
    ensemble_model.load_state_dict(best_model_state)
    final_accuracy = evaluate_model(ensemble_model, (X_test, X_test_spectral, y_test), device)
    logging.info(f"Ensemble Model - Final Test Accuracy: {final_accuracy:.4f}")
else:
    final_accuracy = None
    logging.error("Skipping final evaluation: training returned no model state.")

overall_progress.update(1)
overall_progress.close()

# Diverse Ensemble Training

In [None]:
# # Train diverse ensemble
# diverse_ensemble = DiverseEnsembleModel(best_params).to(device)
# diverse_optimizer = optim.AdamW(diverse_ensemble.parameters(), lr=best_lr, weight_decay=1e-2)
# diverse_scheduler = get_scheduler(diverse_optimizer, num_warmup_steps, num_training_steps)


# logging.info("Training diverse ensemble model...")
# diverse_best_state, diverse_accuracy = train_model(
#     diverse_ensemble, train_loader, (X_test, X_test_spectral, y_test),
#     diverse_optimizer, diverse_scheduler, criterion, device, epochs=num_epochs
# )
# overall_progress.update(1)


# torch.save(diverse_best_state, os.path.join(config['new_model_path'], 'best_diverse_ensemble_model.pth'))
# logging.info(f"Best diverse ensemble model saved. Final accuracy: {diverse_accuracy:.4f}")

# # Distill knowledge
# single_model = ImprovedSleepdetector(**best_params).to(device)

# logging.info("Performing knowledge distillation...")
# distilled_model = distill_knowledge(ensemble_model, single_model, train_loader, (X_test, X_test_spectral, y_test), device)
# overall_progress.update(1)


# torch.save(distilled_model.state_dict(), os.path.join(config['new_model_path'], 'distilled_model.pth'))
# overall_progress.close()

In [None]:
# # Final evaluation
# diverse_ensemble.load_state_dict(diverse_best_state)
# diverse_final_accuracy = evaluate_model(diverse_ensemble, (X_test, X_test_spectral, y_test), device)

# distilled_accuracy = evaluate_model(distilled_model, (X_test, X_test_spectral, y_test), device)


# logging.info(f"Training completed. Best accuracy: {best_accuracy:.4f}")
# logging.info(f"Ensemble Model - Final Test Accuracy: {final_accuracy:.4f}")
# logging.info(f"Diverse Ensemble Model - Final Test Accuracy: {diverse_final_accuracy:.4f}")
# logging.info(f"Distilled Model - Final Test Accuracy: {distilled_accuracy:.4f}")