<a href="https://colab.research.google.com/github/dastias/Projeto-doutorado/blob/main/Codigo_Victor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install PyWavelets

Collecting PyWavelets
  Downloading pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.0 kB)
Downloading pywavelets-1.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.5/4.5 MB[0m [31m40.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: PyWavelets
Successfully installed PyWavelets-1.8.0


In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import pywt
import scipy.signal as signal
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import matplotlib.pyplot as plt
import seaborn as sns
import os
import math
from typing import Dict, List, Tuple, Optional
import logging
import requests
import zipfile
import io

In [None]:
# Root logging setup: INFO and above, timestamped "time - name - level - message" records
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-wide logger shared by every class defined in this notebook
logger = logging.getLogger(__name__)

class SignalPreprocessor:
    """Denoise 1-D signals and convert them to normalised spectrograms.

    Attributes:
        window_size: Number of samples per spectrogram segment (``nperseg``).
        overlap: Fraction of ``window_size`` shared by consecutive segments.
        sampling_rate: Sampling frequency in Hz, used for the band-pass design
            and passed to the spectrogram as ``fs``.
    """

    def __init__(self, window_size: int = 1024, overlap: float = 0.5, sampling_rate: float = 20000):
        self.window_size = window_size
        self.overlap = overlap
        self.sampling_rate = sampling_rate

    def remove_noise(self, signal: np.ndarray, method: str = 'wavelet') -> np.ndarray:
        """Return a denoised copy of ``signal``.

        Args:
            signal: Samples to clean.
            method: ``'wavelet'`` soft-thresholds the detail coefficients of a
                4-level ``db4`` decomposition; ``'butterworth'`` applies a
                zero-phase 4th-order 2-450 Hz band-pass.

        Returns:
            The filtered samples, with the same length as ``signal`` along the
            last axis.

        Raises:
            ValueError: If ``method`` is not one of the two supported names.
        """
        # The ``signal`` parameter hides the module-level ``scipy.signal`` alias
        # inside this method, so the module is imported here under another name.
        from scipy import signal as sp_signal

        samples = np.asarray(signal)

        if method == 'wavelet':
            wavelet = 'db4'
            level = 4
            coeffs = pywt.wavedec(samples, wavelet, level=level)
            # Noise scale estimated from the finest detail band
            # (median absolute value / 0.6745), used directly as the threshold.
            thresh = np.median(np.abs(coeffs[-1])) / 0.6745
            # coeffs[0] is the approximation band and is left untouched
            for i in range(1, len(coeffs)):
                coeffs[i] = pywt.threshold(coeffs[i], thresh, mode='soft')
            denoised = pywt.waverec(coeffs, wavelet)
            # waverec can return one extra sample for odd-length input; trim it
            # so the output always lines up with the input.
            return denoised[..., :samples.shape[-1]]

        if method == 'butterworth':
            nyquist = 0.5 * self.sampling_rate
            low = 2.0 / nyquist
            high = 450.0 / nyquist
            # Second-order sections: the (b, a) form is numerically fragile for
            # a band this narrow relative to the sampling rate.
            sos = sp_signal.butter(4, [low, high], btype='band', output='sos')
            return sp_signal.sosfiltfilt(sos, samples)

        raise ValueError(
            f"Unknown denoising method {method!r}; expected 'wavelet' or 'butterworth'"
        )

    def generate_spectrogram(self, signal_data: np.ndarray) -> np.ndarray:
        """Generate spectrogram from signal data.

        Uses a Hann window of ``window_size`` samples with
        ``int(window_size * overlap)`` samples of overlap.

        Returns:
            2-D array (frequency x time) of power in dB, min-max scaled
            to [0, 1].
        """
        f, t, Sxx = signal.spectrogram(signal_data,
                                       fs=self.sampling_rate,
                                       window='hann',
                                       nperseg=self.window_size,
                                       noverlap=int(self.window_size * self.overlap))

        # Convert to dB scale and normalize; the 1e-10 terms guard against
        # log(0) and against a zero range for a constant spectrogram.
        Sxx = 10 * np.log10(Sxx + 1e-10)
        Sxx = (Sxx - Sxx.min()) / (Sxx.max() - Sxx.min() + 1e-10)
        return Sxx

In [None]:
class VBLVA001Dataset(Dataset):
    """Custom dataset for VBL-VA001 vibration data.

    ``data_path`` must contain one sub-directory per machine condition, each
    holding CSV recordings. Sub-directories are sorted by name and numbered,
    then the labels are compacted so that conditions which contributed no
    readable file leave no gaps.

    Each item is ``(spectrogram, label)``: the stored signal is wavelet
    denoised, converted to a spectrogram, given a leading channel axis and
    turned into a float tensor, then passed through ``transform`` if one was
    supplied.

    Raises:
        ValueError: If no sample could be loaded from ``data_path``.
    """

    # Caps applied while loading, to bound load time and memory use.
    MAX_FILES_PER_CONDITION = 50
    MAX_SIGNAL_LENGTH = 50000

    def __init__(self,
                 data_path: str,
                 transform: Optional[transforms.Compose] = None):
        self.data_path = data_path
        self.transform = transform
        # NOTE(review): an earlier comment on this line said the sampling rate
        # is 25.6 kHz, while the code passes 20000 Hz - confirm the recording
        # rate against the dataset documentation.
        self.preprocessor = SignalPreprocessor(sampling_rate=20000)

        # Load VBL-VA001 dataset
        self.data, self.labels, self.label_mapping = self._load_vbl_va001_data()

        if len(self.data) == 0:
            raise ValueError(f"No data was loaded from {data_path}. Please check the dataset path and structure.")

        # Get the actual number of unique classes in the dataset
        self.num_classes = len(np.unique(self.labels))
        logger.info(f"Successfully loaded dataset with {len(self.data)} samples and {self.num_classes} unique classes")

    def _load_vbl_va001_data(self) -> Tuple[List, List, Dict]:
        """
        Load the VBL-VA001 dataset and create dynamic mapping of available condition folders.

        Returns:
            ``(signals, labels, label_mapping)``: a list of 1-D arrays, a
            parallel list of consecutive integer labels starting at 0, and a
            dict from label to condition directory name. All three are empty
            when nothing could be loaded; the reason is logged.
        """
        data = []
        labels = []

        try:
            # Check if data path exists
            if not os.path.exists(self.data_path):
                logger.error(f"Data path does not exist: {self.data_path}")
                return [], [], {}

            # Sub-directories are the candidate conditions; sorting keeps the
            # label assignment stable between runs.
            available_dirs = [d for d in os.listdir(self.data_path)
                              if os.path.isdir(os.path.join(self.data_path, d))]
            logger.info(f"Available directories: {available_dirs}")
            available_dirs.sort()

            conditions = {}
            label_mapping = {}  # label -> directory name, kept for reporting
            for i, directory in enumerate(available_dirs):
                conditions[directory] = i
                label_mapping[i] = directory
                logger.info(f"Assigned label {i} to condition: {directory}")

            if not conditions:
                logger.error("No valid condition directories found!")
                return [], [], {}

            # Process each condition directory
            for condition, label in conditions.items():
                condition_path = os.path.join(self.data_path, condition)
                logger.info(f"Processing condition: {condition}")

                # os.listdir order is filesystem dependent; sort so the capped
                # selection below picks the same files on every machine.
                csv_files = sorted(f for f in os.listdir(condition_path) if f.endswith('.csv'))
                logger.info(f"Found {len(csv_files)} CSV files in {condition_path}")

                for file in csv_files[:self.MAX_FILES_PER_CONDITION]:
                    try:
                        file_path = os.path.join(condition_path, file)

                        # Read CSV file - VBL-VA001 data format
                        df = pd.read_csv(file_path)

                        # Assumes column 0 is time and column 1 the first
                        # acceleration channel - TODO confirm against the files.
                        if len(df.columns) >= 2:
                            signal_data = df.iloc[:, 1].values

                            # Trim signal to control memory usage
                            if len(signal_data) > self.MAX_SIGNAL_LENGTH:
                                signal_data = signal_data[:self.MAX_SIGNAL_LENGTH]

                            data.append(signal_data)
                            labels.append(label)
                        else:
                            logger.warning(f"CSV file {file} does not have expected columns")

                    except Exception as e:
                        # One unreadable file must not abort the whole load
                        logger.error(f"Error loading file {file}: {str(e)}")

            if not data:
                logger.error("No data was loaded!")
                return [], [], {}

            # Compact the labels: directories that produced no sample would
            # otherwise leave holes in the 0..K-1 range.
            unique_labels = sorted(set(labels))
            remap = {old_label: new_label for new_label, old_label in enumerate(unique_labels)}
            new_label_mapping = {new_label: label_mapping[old_label]
                                 for old_label, new_label in remap.items()}
            remapped_labels = [remap[label] for label in labels]

            logger.info(f"Successfully loaded {len(data)} samples with remapped label distribution: {pd.Series(remapped_labels).value_counts().to_dict()}")
            logger.info(f"Label mapping: {new_label_mapping}")

            return data, remapped_labels, new_label_mapping

        except Exception as e:
            logger.error(f"Error during data loading: {str(e)}")
            return [], [], {}

    def __len__(self) -> int:
        return len(self.data)

    def _fallback_shape(self, n_samples: int) -> Tuple[int, int, int]:
        """Shape of the placeholder returned when preprocessing fails.

        Mirrors the (channel, frequency, time) shape that a spectrogram of an
        ``n_samples``-long signal gets with the preprocessor's window and
        overlap, so a failed sample keeps the same shape as its neighbours.
        """
        nperseg = self.preprocessor.window_size
        noverlap = int(nperseg * self.preprocessor.overlap)
        hop = nperseg - noverlap
        if hop <= 0 or n_samples < nperseg:
            # No meaningful spectrogram geometry; keep the legacy placeholder
            return (1, 129, 129)
        return (1, nperseg // 2 + 1, (n_samples - noverlap) // hop)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """Return ``(spectrogram_tensor, label)`` for sample ``idx``.

        If preprocessing raises, the error is logged and an all-zero tensor is
        returned with the sample's real label, so one bad sample does not stop
        an epoch.
        """
        raw_signal = self.data[idx]

        try:
            # Denoise first, then build the spectrogram from the cleaned signal
            denoised_signal = self.preprocessor.remove_noise(raw_signal)
            spectrogram = self.preprocessor.generate_spectrogram(denoised_signal)

            # Add channel dimension for CNN
            spectrogram = np.expand_dims(spectrogram, axis=0)

            # Convert to tensor
            spectrogram = torch.from_numpy(spectrogram).float()

            if self.transform:
                spectrogram = self.transform(spectrogram)

            return spectrogram, self.labels[idx]
        except Exception as e:
            logger.error(f"Error processing sample {idx}: {str(e)}")
            # Return a dummy sample for robustness
            dummy = torch.zeros(self._fallback_shape(len(raw_signal)))
            if self.transform:
                dummy = self.transform(dummy)
            return dummy, self.labels[idx]

In [None]:
class VibrationAnalysisCNN(nn.Module):
    """Small CNN classifier for single-channel 2-D inputs.

    Three Conv -> ReLU -> BatchNorm stages (32, 64, 128 channels); the first
    two are followed by 2x2 max pooling and the last by global average
    pooling. The resulting 128 features go through a 256-unit hidden layer
    with dropout to ``num_classes`` logits.
    """

    def __init__(self, num_classes: int):
        super().__init__()

        # Never build fewer than two outputs, whatever was requested
        self.num_classes = num_classes if num_classes > 2 else 2
        logger.info(f"Initializing model with {self.num_classes} output classes")

        stage_channels = ((1, 32), (32, 64), (64, 128))
        last_stage = len(stage_channels) - 1

        feature_layers = []
        for stage, (in_channels, out_channels) in enumerate(stage_channels):
            feature_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            feature_layers.append(nn.ReLU())
            feature_layers.append(nn.BatchNorm2d(out_channels))
            # Spatial reduction: halve twice, then collapse to 1x1
            if stage == last_stage:
                feature_layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                feature_layers.append(nn.MaxPool2d(2))
        self.features = nn.Sequential(*feature_layers)

        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.Linear(128, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, self.num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes)."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

In [None]:
class SimCLRModel(nn.Module):
    """Convolutional encoder with a SimCLR-style projection head.

    ``forward`` returns the projection used for contrastive pre-training;
    ``forward_classifier`` sends the same encoder features through a separate
    classification head for fine-tuning. ``input_shape`` is accepted but not
    used when building the layers.
    """

    def __init__(self, input_shape: Tuple[int, int], projection_dim: int = 128, num_classes: int = 2):
        super().__init__()

        # Stored so callers can inspect or update the class count later
        self.num_classes = num_classes

        encoder_layers = []
        for in_channels, out_channels in ((1, 32), (32, 64)):
            encoder_layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
        # Final stage pools globally so the encoder always emits 128 features
        encoder_layers += [
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.projection = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, projection_dim),
        )

        # Classification head used for fine-tuning
        self.classifier = nn.Sequential(
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, self.num_classes),
        )

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder and flatten its output to (batch, features)."""
        features = self.encoder(x)
        return features.view(features.size(0), -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projection-head output for ``x``."""
        return self.projection(self._encode(x))

    def forward_classifier(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for ``x``; meant for use after pre-training."""
        return self.classifier(self._encode(x))

In [None]:
class LoRALayer(nn.Module):
    """Low-rank additive update: ``scale * (x @ A @ B)``.

    ``A`` has shape (in_features, rank) and starts Kaiming-uniform; ``B`` has
    shape (rank, out_features) and starts at zero, so a new layer outputs
    zeros until ``B`` is trained.
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        down_projection = torch.empty(in_features, rank)
        nn.init.kaiming_uniform_(down_projection, a=math.sqrt(5))
        self.lora_A = nn.Parameter(down_projection)
        # The up-projection stays at zero until training moves it
        self.lora_B = nn.Parameter(torch.zeros(rank, out_features))
        self.scale = 0.01

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the scaled low-rank update for ``x``."""
        low_rank = torch.matmul(x, self.lora_A)
        update = torch.matmul(low_rank, self.lora_B)
        return self.scale * update

In [None]:
class ModelAgent:
    """Keeps a running score per model and picks the highest-scoring one."""

    def __init__(self, models: Dict[str, nn.Module]):
        self.models = models
        # Every model starts from the same score of 1.0
        self.performance_scores = dict.fromkeys(models.keys(), 1.0)

    def select_model(self, signal_features: Dict) -> str:
        """Return the name of the model with the highest current score.

        ``signal_features`` is accepted for future feature-based scoring but
        is not consulted yet. Ties go to the model registered first.
        """
        ranking = {name: self.performance_scores[name] for name in self.models}
        return max(ranking, key=ranking.get)

    def update_performance(self, model_name: str, metric_value: float):
        """Blend ``metric_value`` into the model's score (90% old, 10% new)."""
        smoothing = 0.9
        previous = self.performance_scores[model_name]
        self.performance_scores[model_name] = (smoothing * previous +
                                               (1 - smoothing) * metric_value)

In [None]:
class VibrationAnalyzer:
    def __init__(self, config: Dict):
        """Store the configuration and prepare empty training state.

        Args:
            config: Settings dictionary. A truthy ``'use_model_agent'`` entry
                makes the constructor build the candidate models and their
                ``ModelAgent`` right away.
        """
        self.config = config

        # Prefer the GPU whenever CUDA is usable
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        logger.info(f"Using device: {self.device}")

        # Filled in by train()
        self.metrics_history = []
        self.model = None
        self.label_mapping = None

        # Adaptive model selection is opt-in
        self.model_agent = None
        self.models = {}
        if config.get('use_model_agent'):
            logger.info("Initializing ModelAgent for adaptive model selection")
            self._initialize_models()

    def _initialize_models(self):
        """Build the candidate models and wrap them in a ModelAgent.

        Both models start with a two-class head; train() rebuilds the heads
        once the real class count is known.
        """
        # Baseline CNN is always available
        self.models['base_cnn'] = VibrationAnalysisCNN(num_classes=2)

        # SimCLR is added only when the config asks for it
        if self.config.get('use_simclr'):
            default_spectrogram_shape = (129, 129)  # placeholder; may be updated later
            self.models['simclr'] = SimCLRModel(input_shape=default_spectrogram_shape, num_classes=2)
            logger.info("SimCLR model initialized")

        self.model_agent = ModelAgent(self.models)
        logger.info(f"ModelAgent initialized with {len(self.models)} models")

    def train(self, train_loader: DataLoader, val_loader: DataLoader, num_classes: int, label_mapping: Dict = None):
        """Train the model with continuous monitoring.

        Sizes the network from the labels found in ``train_loader`` (or lets
        the model agent pick one of its candidates), then runs up to
        ``config['epochs']`` epochs with Adam at ``config['learning_rate']``,
        ReduceLROnPlateau on the validation loss and early stopping after 10
        epochs without improvement. The best weights are written to
        ``best_model.pth`` and restored at the end; per-epoch metrics are
        appended to ``self.metrics_history``.

        Args:
            train_loader: Yields ``(data, target)`` training batches.
            val_loader: Yields ``(data, target)`` validation batches.
            num_classes: Class count used when the training data yields no
                labels; values below 2 are raised to 2.
            label_mapping: Optional label -> name dict kept for reporting.
        """
        self.label_mapping = label_mapping

        # Safety check - use at least 2 classes
        num_classes = max(2, num_classes)
        logger.info(f"Training model with {num_classes} classes")

        # One pass over the training targets to see which labels occur.
        # Plain Python ints keep the later layer sizes and log output clean.
        actual_classes = set()
        for _, batch_targets in train_loader:
            actual_classes.update(int(label) for label in batch_targets.tolist())

        # Adjust num_classes based on actual data
        actual_num_classes = max(actual_classes) + 1 if actual_classes else num_classes
        logger.info(f"Detected {actual_num_classes} unique classes in training data, labels: {actual_classes}")

        if self.model_agent:
            # Give every candidate a head with the right number of outputs
            for name, model in self.models.items():
                if not hasattr(model, 'num_classes'):
                    continue
                model.num_classes = actual_num_classes

                if isinstance(model, SimCLRModel):
                    head_label = "SimCLR"
                elif isinstance(model, VibrationAnalysisCNN):
                    head_label = "CNN"
                else:
                    continue
                model.classifier = nn.Sequential(
                    nn.Linear(128, 256),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                    nn.Linear(256, actual_num_classes)
                ).to(self.device)
                logger.info(f"Rebuilt {head_label} classifier with {actual_num_classes} output classes")

            # Let the agent select the best model to start with
            signal_features = {'num_classes': actual_num_classes}
            selected_model = self.model_agent.select_model(signal_features)
            self.model = self.models[selected_model].to(self.device)
            logger.info(f"ModelAgent selected '{selected_model}' for training")
        else:
            # Use standard VibrationAnalysisCNN if no agent is specified
            self.model = VibrationAnalysisCNN(num_classes=actual_num_classes).to(self.device)

        # Nothing can be trained or validated with an empty loader; check once
        # up front instead of at the top of every epoch.
        if len(train_loader) == 0:
            logger.error("Training data loader is empty!")
            return
        if len(val_loader) == 0:
            logger.error("Validation data loader is empty!")
            return

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.config['learning_rate'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

        best_val_loss = float('inf')
        patience_counter = 0
        early_stop_patience = 10
        checkpoint_path = 'best_model.pth'
        # Model whose weights this run last wrote to checkpoint_path. Stays
        # None until a checkpoint is saved, so a file left over from an
        # earlier run is never loaded by mistake.
        checkpoint_owner = None

        for epoch in range(self.config['epochs']):
            try:
                # The agent may swap self.model further down; keep a handle on
                # the network that was trained and validated in this epoch.
                trained_model = self.model

                train_loss = self._train_epoch(trained_model, train_loader, criterion, optimizer)
                val_metrics = self._validate(trained_model, val_loader, criterion)

                # Update learning rate
                scheduler.step(val_metrics['val_loss'])

                # Update model agent scores if enabled
                if self.model_agent:
                    current_model_name = next(name for name, model in self.models.items()
                                              if model is trained_model)
                    self.model_agent.update_performance(current_model_name, val_metrics['accuracy'])

                    # Potentially switch models if another is predicted to perform better
                    if epoch > 0 and epoch % 5 == 0:  # Check every 5 epochs
                        signal_features = {'epoch': epoch}
                        best_model_name = self.model_agent.select_model(signal_features)
                        if best_model_name != current_model_name:
                            logger.info(f"Switching from {current_model_name} to {best_model_name} at epoch {epoch}")
                            self.model = self.models[best_model_name].to(self.device)
                            # Reset optimizer and scheduler for new model
                            optimizer = optim.Adam(self.model.parameters(),
                                                   lr=self.config['learning_rate'])
                            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

                # Checkpoint the network that produced this validation loss,
                # not whichever model the agent has just switched to.
                if val_metrics['val_loss'] < best_val_loss:
                    best_val_loss = val_metrics['val_loss']
                    patience_counter = 0
                    torch.save(trained_model.state_dict(), checkpoint_path)
                    checkpoint_owner = trained_model
                else:
                    patience_counter += 1

                # Record metrics before the early-stop check so the final
                # epoch is part of the history too.
                metrics = {
                    'epoch': epoch,
                    'train_loss': train_loss,
                    **val_metrics
                }
                self.metrics_history.append(metrics)

                logger.info(f"Epoch {epoch+1}/{self.config['epochs']}: "
                            f"train_loss={train_loss:.4f}, "
                            f"val_loss={val_metrics['val_loss']:.4f}, "
                            f"val_accuracy={val_metrics['accuracy']:.4f}, "
                            f"lr={optimizer.param_groups[0]['lr']:.6f}")

                if patience_counter >= early_stop_patience:
                    logger.info(f"Early stopping triggered after {epoch+1} epochs")
                    break

            except Exception as e:
                # Best effort: log the failure with its traceback and move on
                # to the next epoch.
                logger.exception(f"Error during epoch {epoch}: {str(e)}")
                continue

        # Restore the best weights only when this run saved them for the model
        # that is still active; a checkpoint from another architecture cannot
        # be loaded into it.
        if checkpoint_owner is not self.model or not os.path.exists(checkpoint_path):
            logger.info("Continuing with the current model state")
            return

        try:
            self.model.load_state_dict(torch.load(checkpoint_path, map_location=self.device))
            logger.info("Loaded best model for final evaluation")
        except Exception as e:
            logger.error(f"Error loading best model: {str(e)}")
            logger.info("Continuing with the current model state")

    # Single-epoch training loop with per-batch error handling and target filtering
    def _train_epoch(self, model: nn.Module,
                  train_loader: DataLoader,
                  criterion: nn.Module,
                  optimizer: optim.Optimizer) -> float:

        model.train()
        total_loss = 0
        num_batches = 0
        valid_batches = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            try:
                # Filter out samples with out-of-range target values
                valid_indices = target < model.num_classes
                if not torch.all(valid_indices):
                    invalid_targets = torch.unique(target[~valid_indices]).tolist()
                    logger.warning(f"Batch {batch_idx}: Filtering out {(~valid_indices).sum().item()} samples with invalid targets: {invalid_targets}")

                    # Skip batch if no valid samples remain
                    if not torch.any(valid_indices):
                        logger.warning(f"Skipping batch {batch_idx}: No valid targets")
                        continue

                    # Keep only valid samples
                    data = data[valid_indices]
                    target = target[valid_indices]

                # Move data to device
                data, target = data.to(self.device), target.to(self.device)

                # Double-check target values (defensive)
                if torch.max(target).item() >= model.num_classes:
                    logger.error(f"Target values still out of range after filtering: {torch.unique(target)}")
                    continue

                optimizer.zero_grad()

                # Forward pass based on model type
                if isinstance(model, SimCLRModel):
                    output = model.forward_classifier(data)
                else:
                    output = model(data)

                loss = criterion(output, target)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
                valid_batches += 1

                # Print progress every 10 batches
                if batch_idx % 10 == 0:
                    logger.info(f"Train Batch {batch_idx}/{len(train_loader)}: Loss {loss.item():.4f}")

            except Exception as e:
                logger.error(f"Error in training batch {batch_idx}: {str(e)}")
                continue

        logger.info(f"Completed training epoch: {valid_batches}/{len(train_loader)} valid batches processed")
        return total_loss / max(1, num_batches)  # Avoid division by zero

        def _validate(self, model: nn.Module, val_loader: DataLoader, criterion: nn.Module) -> Dict:
          """Validate model and compute metrics."""
        model.eval()
        val_loss = 0
        predictions = []
        targets = []
        valid_batches = 0

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(val_loader):
                try:
                    # Filter out samples with out-of-range target values
                    valid_indices = target < model.num_classes
                    if not torch.all(valid_indices):
                        invalid_targets = torch.unique(target[~valid_indices]).tolist()
                        logger.warning(f"Validation batch {batch_idx}: Filtering out {(~valid_indices).sum().item()} samples with invalid targets: {invalid_targets}")

                        # Skip batch if no valid samples remain
                        if not torch.any(valid_indices):
                            logger.warning(f"Skipping validation batch {batch_idx}: No valid targets")
                            continue

                        # Keep only valid samples
                        data = data[valid_indices]
                        target = target[valid_indices]

                    data, target = data.to(self.device), target.to(self.device)

                    # Double-check target values (defensive)
                    if torch.max(target).item() >= model.num_classes:
                        logger.error(f"Validation target values still out of range after filtering: {torch.unique(target)}")
                        continue

                    # Forward pass based on model type
                    if isinstance(model, SimCLRModel):
                        output = model.forward_classifier(data)
                    else:
                        output = model(data)

                    val_loss += criterion(output, target).item()

                    pred = output.argmax(dim=1)
                    predictions.extend(pred.cpu().numpy())
                    targets.extend(target.cpu().numpy())
                    valid_batches += 1

                except Exception as e:
                    logger.error(f"Error in validation batch {batch_idx}: {str(e)}")
                    continue

        if valid_batches > 0:  # Avoid division by zero
            val_loss /= valid_batches

        logger.info(f"Completed validation: {valid_batches}/{len(val_loader)} valid batches processed")

        # Compute metrics
        if len(targets) > 0:  # Ensure we have predictions to evaluate
            accuracy = accuracy_score(targets, predictions)
            precision, recall, f1, _ = precision_recall_fscore_support(
                targets, predictions, average='weighted', zero_division=0
            )

            return {
                'val_loss': val_loss,
                'accuracy': accuracy,
                'precision': precision,
                'recall': recall,
                'f1': f1
            }
        else:
            logger.warning("No validation predictions generated")
            return {
                'val_loss': val_loss,
                'accuracy': 0,
                'precision': 0,
                'recall': 0,
                'f1': 0
            }

    # Test-set evaluation: confusion matrix and classification report
    def evaluate(self, test_loader: DataLoader, num_classes: int):
        """Evaluate model on test set and generate confusion matrix.

        The trained model is used as is. Test samples whose label falls
        outside the model's output range are skipped and reported in a
        warning; the classifier head is never rebuilt here, because a fresh
        head would carry untrained weights and make the metrics meaningless.

        Saves the confusion matrix to ``confusion_matrix_vbl_va001.png`` and
        prints a classification report.

        Args:
            test_loader: Yields ``(data, target)`` test batches.
            num_classes: Output count assumed when the model has no
                ``num_classes`` attribute.
        """
        if self.model is None:
            logger.error("Model not trained yet")
            return

        # Ensure model is in evaluation mode
        self.model.eval()
        predictions = []
        targets = []
        skipped_samples = 0

        model_classes = getattr(self.model, 'num_classes', num_classes)

        with torch.no_grad():
            for data, target in test_loader:
                try:
                    # Filter out samples with out-of-range target values
                    valid_indices = (target >= 0) & (target < model_classes)
                    if not torch.all(valid_indices):
                        skipped_samples += int((~valid_indices).sum().item())

                        # Skip if no valid samples remain
                        if not torch.any(valid_indices):
                            continue

                        # Keep only valid samples
                        data = data[valid_indices]
                        target = target[valid_indices]

                    data, target = data.to(self.device), target.to(self.device)

                    # Forward pass based on model type
                    if isinstance(self.model, SimCLRModel):
                        output = self.model.forward_classifier(data)
                    else:
                        output = self.model(data)

                    pred = output.argmax(dim=1)
                    predictions.extend(pred.cpu().tolist())
                    targets.extend(target.cpu().tolist())
                except Exception as e:
                    logger.error(f"Error during evaluation: {str(e)}")
                    continue

        if skipped_samples:
            logger.warning(f"Model has {model_classes} output classes; skipped {skipped_samples} "
                           f"test samples whose labels fall outside that range")

        if len(targets) == 0:
            logger.error("No test predictions generated - cannot create evaluation metrics")
            return

        from sklearn.metrics import confusion_matrix, classification_report

        # Report only the classes that actually occur in targets or predictions
        unique_classes = sorted(set(targets) | set(predictions))

        # Generate class names for reporting
        class_names = []
        for i in unique_classes:
            if self.label_mapping and i in self.label_mapping:
                class_names.append(self.label_mapping[i])
            else:
                class_names.append(f"Class {i}")

        # Compute confusion matrix
        cm = confusion_matrix(targets, predictions, labels=unique_classes)

        # Plot confusion matrix
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names, ax=ax)
        ax.set_xlabel('Predicted')
        ax.set_ylabel('True')
        ax.set_title('Confusion Matrix - VBL-VA001 Dataset')
        fig.tight_layout()
        fig.savefig('confusion_matrix_vbl_va001.png')
        plt.close(fig)  # Close to avoid display issues

        # Print classification report
        report = classification_report(targets, predictions,
                                       labels=unique_classes,
                                       target_names=class_names,
                                       zero_division=0)
        print("\nClassification Report:")
        print(report)

        logger.info(f"Evaluation completed on {len(targets)} samples")

    def plot_metrics(self):
        """Draw the recorded training history as a 2x2 grid of curves.

        ``self.metrics_history`` is loaded into a DataFrame and the columns
        ``train_loss``/``val_loss``, ``accuracy``, ``precision``/``recall``
        and ``f1`` are plotted against the ``epoch`` column. The figure is
        saved as ``training_metrics_vbl_va001.png`` and then displayed.

        If no history has been recorded an error is logged and the method
        returns without drawing anything.
        """
        history = self.metrics_history
        if not history:
            logger.error("No metrics to plot. Train the model first.")
            return

        frame = pd.DataFrame(history)

        fig, grid = plt.subplots(2, 2, figsize=(15, 10))
        epochs = frame['epoch']

        # One entry per panel:
        #   (axes, title, y-axis label or None, [(column, legend label or None), ...])
        # Panels without a y-axis label / legend label get no ylabel / legend.
        panels = [
            (grid[0, 0], 'Loss', 'Loss',
             [('train_loss', 'Train'), ('val_loss', 'Validation')]),
            (grid[0, 1], 'Validation Accuracy', 'Accuracy',
             [('accuracy', None)]),
            (grid[1, 0], 'Precision and Recall', None,
             [('precision', 'Precision'), ('recall', 'Recall')]),
            (grid[1, 1], 'F1 Score', None,
             [('f1', None)]),
        ]

        for ax, title, y_label, curves in panels:
            needs_legend = False
            for column, legend_label in curves:
                if legend_label is None:
                    ax.plot(epochs, frame[column])
                else:
                    ax.plot(epochs, frame[column], label=legend_label)
                    needs_legend = True

            ax.set_title(title)
            ax.set_xlabel('Epoch')
            if y_label is not None:
                ax.set_ylabel(y_label)
            if needs_legend:
                ax.legend()

        plt.tight_layout()
        plt.savefig('training_metrics_vbl_va001.png')
        plt.show()

    def train_simclr(self, train_loader: DataLoader, temperature: float = 0.5, epochs: int = 20):
        """Pre-train the SimCLR model with the NT-Xent contrastive loss.

        For every batch, two augmented views are generated, embedded by the
        model, L2-normalised and compared through a temperature-scaled
        similarity matrix. Each of the ``2 * batch_size`` views has to pick
        out its partner (the other augmentation of the same sample) among all
        remaining views of the batch.

        Args:
            train_loader: Loader yielding ``(data, label)`` batches; the
                labels are ignored.
            temperature: Divisor applied to the similarities before the
                softmax inside the cross-entropy.
            epochs: Number of passes over ``train_loader``.

        Returns:
            None. An error is logged and the method returns early when
            ``self.models`` has no ``'simclr'`` entry or when the loader
            yields no batches.
        """
        if 'simclr' not in self.models:
            logger.error("SimCLR model not initialized")
            return

        model = self.models['simclr'].to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), lr=self.config.get('simclr_lr', 0.001))
        criterion = torch.nn.CrossEntropyLoss()

        # Simple augmentations - can be replaced with more sophisticated transforms.
        # Built once: the pipeline holds no per-batch state.
        # NOTE(review): the pipeline is applied to the whole batch tensor, so each
        # call presumably draws one flip/rotation shared by all samples of that
        # view - confirm whether per-sample augmentation is wanted.
        augment = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.Normalize(mean=[0.485], std=[0.229])
        ])

        logger.info("Starting SimCLR pre-training...")

        for epoch in range(epochs):
            model.train()
            total_loss = 0.0
            num_batches = 0

            for batch_idx, (data, _) in enumerate(train_loader):
                batch_size = data.size(0)

                # Two independently augmented views of the same batch
                augmented1 = augment(data)
                augmented2 = augment(data)
                augmented1, augmented2 = augmented1.to(self.device), augmented2.to(self.device)

                # Embed both views (presumably encoder + projection head - the
                # model definition is outside this method) and L2-normalise so
                # the dot product below is a cosine similarity.
                z1 = torch.nn.functional.normalize(model(augmented1), dim=1)
                z2 = torch.nn.functional.normalize(model(augmented2), dim=1)

                # Rows 0..B-1 are view 1, rows B..2B-1 are view 2
                representations = torch.cat([z1, z2], dim=0)
                logits = torch.matmul(representations, representations.T) / temperature

                # Exclude self-similarity only: a view must never be matched with
                # itself, but every other view stays in the softmax as a negative.
                self_mask = torch.zeros((2 * batch_size, 2 * batch_size), dtype=torch.bool,
                                        device=logits.device)
                self_mask.fill_diagonal_(True)
                logits = logits.masked_fill(self_mask, float('-inf'))

                # The positive for row i is the other view of the same sample:
                # i + B for the first half, i - B for the second half.
                labels = (torch.arange(2 * batch_size, device=logits.device) + batch_size) % (2 * batch_size)

                # NT-Xent = cross-entropy of each row against its positive column
                loss = criterion(logits, labels)

                # Optimization step
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                if batch_idx % 10 == 0:
                    logger.info(f"SimCLR Training: Epoch {epoch+1}/{epochs}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")

            if num_batches == 0:
                # Nothing to average and nothing was learned; avoid dividing by zero.
                logger.error("SimCLR pre-training aborted: train_loader yielded no batches")
                return

            avg_loss = total_loss / num_batches
            logger.info(f"SimCLR Epoch {epoch+1}/{epochs}: Average Loss: {avg_loss:.4f}")

        logger.info("SimCLR pre-training completed")

In [None]:
def download_vbl_va001_dataset(download_path: str = './data'):
    """Download the VBL-VA001 archive from Zenodo and extract it.

    A completed ``VBL-VA001.zip`` already present in ``download_path`` is
    reused instead of being downloaded again. New downloads are streamed to a
    temporary ``.part`` file and renamed only once complete, so an interrupted
    transfer never leaves a truncated archive under the final name.

    Args:
        download_path: Directory that receives the zip file and into which
            the archive is extracted. Created if it does not exist.

    Returns:
        ``os.path.join(download_path, "VBL-VA001")`` on success, or ``None``
        when the download, the extraction or any other step fails (the error
        is logged, not raised).
    """
    try:
        # Create download directory if it doesn't exist
        os.makedirs(download_path, exist_ok=True)

        # Zenodo URL for the VBL-VA001 dataset
        zenodo_url = "https://zenodo.org/records/7006575/files/VBL-VA001.zip"

        zip_path = os.path.join(download_path, "VBL-VA001.zip")
        extract_path = os.path.join(download_path, "VBL-VA001")

        if os.path.isfile(zip_path):
            logger.info(f"Found existing archive at {zip_path}; skipping download")
        else:
            logger.info(f"Downloading VBL-VA001 dataset from {zenodo_url}")
            logger.info("This may take some time depending on your internet connection...")

            partial_path = zip_path + ".part"
            try:
                # (connect, read) timeouts keep a stalled server from hanging the
                # notebook; the context manager releases the streamed connection.
                with requests.get(zenodo_url, stream=True, timeout=(10, 60)) as response:
                    response.raise_for_status()  # Raise exception for HTTP errors

                    with open(partial_path, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            if chunk:  # skip keep-alive chunks
                                f.write(chunk)

                # Publish the archive under its final name only when complete
                os.replace(partial_path, zip_path)
            finally:
                # Left behind only when the download failed part-way
                if os.path.exists(partial_path):
                    os.remove(partial_path)

            logger.info(f"Download completed to {zip_path}")

        # Extract the zip file
        logger.info(f"Extracting dataset to {extract_path}")
        try:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(download_path)
        except zipfile.BadZipFile:
            # Drop the corrupt archive so the next call downloads a fresh copy
            # instead of reusing it forever.
            os.remove(zip_path)
            raise

        if not os.path.isdir(extract_path):
            # The archive's top-level folder name is not verified here; tell the
            # user where to look if it differs from the expected one.
            logger.warning(f"Expected folder {extract_path} not found after extraction; "
                           f"check the contents of {download_path}")

        logger.info(f"Dataset extracted successfully to {extract_path}")
        return extract_path

    except requests.exceptions.RequestException as e:
        logger.error(f"Error downloading dataset: {str(e)}")
        return None
    except zipfile.BadZipFile as e:
        logger.error(f"Error extracting dataset: {str(e)}")
        return None
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        return None

In [None]:
def main():
    """Run the end-to-end VBL-VA001 pipeline: download, split, train, evaluate.

    Steps:
        1. Make sure the dataset exists under ``./data/VBL-VA001``,
           downloading it when missing.
        2. Build ``VBLVA001Dataset`` and split it into train/validation/test
           subsets according to ``val_split``/``test_split`` (fixed seed 42).
        3. Optionally pre-train with SimCLR, then run supervised training.
        4. Evaluate on the test subset and plot the metric history.

    Exceptions are logged together with a traceback and are not re-raised;
    the function always returns ``None``.
    """
    import traceback

    try:
        # Set configuration parameters
        config = {
            'learning_rate': 0.001,
            'batch_size': 32,
            'epochs': 50,
            'use_model_agent': True,
            'use_simclr': True,
            'test_split': 0.2,
            'val_split': 0.1
        }

        # Download and extract dataset if needed
        data_path = './data/VBL-VA001'
        if not os.path.exists(data_path):
            logger.info("Dataset not found locally. Downloading...")
            data_path = download_vbl_va001_dataset()
            if data_path is None:
                logger.error("Failed to download dataset. Exiting.")
                return

        # Define transformations
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        # Load dataset
        logger.info(f"Loading dataset from {data_path}")
        try:
            full_dataset = VBLVA001Dataset(data_path=data_path, transform=transform)
            logger.info(f"Dataset loaded with {len(full_dataset)} samples and {full_dataset.num_classes} classes")

            # Split dataset into train, validation, and test
            dataset_size = len(full_dataset)
            test_size = int(config['test_split'] * dataset_size)
            val_size = int(config['val_split'] * dataset_size)
            train_size = dataset_size - test_size - val_size

            if train_size <= 0:
                # A shuffled loader over an empty subset cannot be built; fail
                # with a clear message instead of an obscure sampler error.
                logger.error(f"Dataset too small to split: {dataset_size} samples leave no training data")
                return

            train_dataset, val_dataset, test_dataset = random_split(
                full_dataset, [train_size, val_size, test_size],
                generator=torch.Generator().manual_seed(42)
            )

            # Loader settings shared by all three splits. 'num_workers' may be
            # added to config to override the default of 4. Pinned memory only
            # helps host-to-GPU copies, so it is enabled only when CUDA is present.
            num_workers = config.get('num_workers', 4)
            pin_memory = torch.cuda.is_available()

            def make_loader(subset, shuffle: bool) -> DataLoader:
                """Build a DataLoader for one split with the shared settings."""
                return DataLoader(
                    subset,
                    batch_size=config['batch_size'],
                    shuffle=shuffle,
                    num_workers=num_workers,
                    pin_memory=pin_memory
                )

            # Only the training split is shuffled
            train_loader = make_loader(train_dataset, shuffle=True)
            val_loader = make_loader(val_dataset, shuffle=False)
            test_loader = make_loader(test_dataset, shuffle=False)

            logger.info(f"Data split: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")

            # Initialize and train the model
            analyzer = VibrationAnalyzer(config)

            # Train SimCLR model first if enabled
            if config['use_simclr'] and hasattr(analyzer, 'train_simclr'):
                logger.info("Starting self-supervised pre-training with SimCLR")
                analyzer.train_simclr(train_loader)

            # Train the main model
            logger.info("Starting supervised training")
            analyzer.train(
                train_loader=train_loader,
                val_loader=val_loader,
                num_classes=full_dataset.num_classes,
                label_mapping=full_dataset.label_mapping
            )

            # Evaluate model
            logger.info("Evaluating model on test set")
            analyzer.evaluate(test_loader, full_dataset.num_classes)

            # Plot metrics
            analyzer.plot_metrics()

            logger.info("Training and evaluation completed successfully!")

        except Exception as e:
            # Failures from dataset loading onwards
            logger.error(f"Error in main execution: {str(e)}")
            traceback.print_exc()

    except Exception as e:
        # Failures during configuration / download, before the dataset is loaded
        logger.error(f"Unhandled exception in main: {str(e)}")
        traceback.print_exc()

In [None]:
def visualize_sample(dataset_path, sample_idx=0):
    """Plot one dataset sample as raw signal, denoised signal and spectrogram.

    Loads ``VBLVA001Dataset`` from ``dataset_path``, draws three stacked
    panels for the sample at ``sample_idx``, saves the figure as
    ``sample_<idx>_visualization.png`` and shows it. An index at or beyond the
    dataset length is reported through the logger, as is any exception raised
    along the way; nothing is re-raised.
    """

    def label_axes(ax, title, x_label, y_label):
        """Apply title and both axis labels to one panel."""
        ax.set_title(title)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)

    try:
        # Load dataset
        dataset = VBLVA001Dataset(data_path=dataset_path)

        total = len(dataset)
        if sample_idx >= total:
            logger.error(f"Sample index {sample_idx} out of range (0-{total-1})")
            return

        # Processed item plus the untouched signal it was derived from
        spectrogram, label = dataset[sample_idx]
        raw_signal = dataset.data[sample_idx]

        # Human-readable class name, with a fallback for unmapped labels
        class_name = dataset.label_mapping.get(label, f"Unknown ({label})")

        # Three stacked panels
        fig, (ax_raw, ax_denoised, ax_spec) = plt.subplots(3, 1, figsize=(10, 15))

        # Panel 1: raw signal
        ax_raw.plot(raw_signal)
        label_axes(ax_raw, f"Raw Vibration Signal - Class: {class_name}", "Sample", "Amplitude")

        # Panel 2: signal after the dataset's own noise-removal step
        denoised = dataset.preprocessor.remove_noise(raw_signal)
        ax_denoised.plot(denoised)
        label_axes(ax_denoised, "Denoised Signal", "Sample", "Amplitude")

        # Panel 3: spectrogram as an image with a colour scale
        spec_image = spectrogram.squeeze().numpy()
        im = ax_spec.imshow(spec_image, aspect='auto', origin='lower', cmap='viridis')
        label_axes(ax_spec, "Spectrogram", "Time", "Frequency")
        plt.colorbar(im, ax=ax_spec)

        out_file = f"sample_{sample_idx}_visualization.png"
        plt.tight_layout()
        plt.savefig(out_file)
        plt.show()

        logger.info(f"Sample {sample_idx} visualization saved to {out_file}")

    except Exception as e:
        logger.error(f"Error visualizing sample: {str(e)}")



In [None]:
# Entry point: run the full pipeline when this cell (or the exported script)
# is executed directly rather than imported as a module.
if __name__ == "__main__":
    main()

ERROR:__main__:Error during epoch 0: 'VibrationAnalyzer' object has no attribute '_validate'
Traceback (most recent call last):
  File "<ipython-input-11-0b80e77870ff>", line 107, in train
    val_metrics = self._validate(self.model, val_loader, criterion)
                  ^^^^^^^^^^^^^^
AttributeError: 'VibrationAnalyzer' object has no attribute '_validate'
ERROR:__main__:Error during epoch 1: 'VibrationAnalyzer' object has no attribute '_validate'
Traceback (most recent call last):
  File "<ipython-input-11-0b80e77870ff>", line 107, in train
    val_metrics = self._validate(self.model, val_loader, criterion)
                  ^^^^^^^^^^^^^^
AttributeError: 'VibrationAnalyzer' object has no attribute '_validate'
ERROR:__main__:Error during epoch 2: 'VibrationAnalyzer' object has no attribute '_validate'
Traceback (most recent call last):
  File "<ipython-input-11-0b80e77870ff>", line 107, in train
    val_metrics = self._validate(self.model, val_loader, criterion)
                  ^


Classification Report:
              precision    recall  f1-score   support

     bearing       1.00      1.00      1.00        10
misalignment       0.60      1.00      0.75         9
      normal       1.00      1.00      1.00         7
   unbalance       1.00      0.57      0.73        14

    accuracy                           0.85        40
   macro avg       0.90      0.89      0.87        40
weighted avg       0.91      0.85      0.85        40

