In [1]:
# Cell 1: Initial setup, connecting to Google Drive, installing libraries, and checking GPU availability.
from google.colab import drive
drive.mount('/content/drive')

# Step 2: Instalar y actualizar las librer√≠as
print("\nInstalando y actualizando librer√≠as...")
!pip install --upgrade -q mne pytorch-lightning timm
print("‚úÖ Librer√≠as listas.")

# Step 3: Prueba expl√≠cita de control de la GPU
import torch
print("\n--- INICIANDO PRUEBA DE CONTROL DE GPU ---")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"‚úÖ GPU detectada: {torch.cuda.get_device_name(0)}")
    try:
        tensor_grande = torch.randn(1024, 1024, 512, device=device) # Asignar 2GB
        memoria_asignada = torch.cuda.memory_allocated(0) / 1024**3
        print(f"‚úÖ ¬°√âxito! Memoria asignada activamente: {memoria_asignada:.2f} GB")
        del tensor_grande
        torch.cuda.empty_cache()
        print("‚úÖ Memoria liberada correctamente.")
        print("--- PRUEBA DE CONTROL DE GPU COMPLETADA EXITOSAMENTE ---")
    except Exception as e:
        print(f"‚ùå ¬°ERROR DURANTE LA PRUEBA! No se pudo asignar memoria a la GPU: {e}")
else:
    print("‚ùå ¬°ERROR! No se detect√≥ ninguna GPU en este entorno de ejecuci√≥n.")

Mounted at /content/drive

Instalando y actualizando librer√≠as...
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m7.4/7.4 MB[0m [31m68.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m828.2/828.2 kB[0m [31m58.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m983.0/983.0 kB[0m [31m61.2 MB/s[0m eta [36m0:00:00[0m
[?25h‚úÖ Librer√≠as listas.

--- INICIANDO PRUEBA DE CONTROL DE GPU ---
‚úÖ GPU detectada: NVIDIA A100-SXM4-40GB
‚úÖ ¬°√âxito! Memoria asignada activamente: 2.00 GB
‚úÖ Memoria liberada correctamente.
--- PRUEBA DE CONTROL DE GPU COMPLETADA EXITOSAMENTE ---


In [None]:
# ==============================================================================
# 1. SETUP AND DEPENDENCY INSTALLATION
# ==============================================================================
print("Ensuring PyTorch Lightning and other libraries are installed...")
# Install the necessary libraries with pinned versions to avoid conflicts
!pip install --upgrade -q pytorch-lightning timm "pandas==2.2.2" "pyarrow==19.0.0"
print("‚úÖ Installation check complete.")

# ==============================================================================
# 2. IMPORTS AND INITIAL CONFIGURATION
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall
import numpy as np
import pandas as pd
from pathlib import Path
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os

# Set matrix multiplication precision for A100/H100 GPUs for better performance
torch.set_float32_matmul_precision('medium')
print("‚úÖ Libraries imported and configuration set.")

# ==============================================================================
# 3. MODEL ARCHITECTURE DEFINITION
# ==============================================================================
def get_model(model_name='swin_base', num_classes=5, pretrained=True):
    """
    Creates a model adapted for sleep stage classification using timm's built-in helpers.
    """
    if model_name == 'swin_base':
        model = timm.create_model(
            'swin_base_patch4_window7_224.ms_in22k',
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1,
            img_size=(76, 60)
        )
        print(f"‚úÖ Swin Transformer Base model created.")
    else:
        raise ValueError(f"Model '{model_name}' not supported for this script.")

    return model

print("‚úÖ `get_model` function defined.")

# ==============================================================================
# 4. PYTORCH LIGHTNING MODULE
# ==============================================================================
class SleepStageClassifierLightning(pl.LightningModule):
    def __init__(self, model_name, learning_rate=1e-5, class_weights=None):
        super().__init__()
        self.save_hyperparameters()
        self.model = get_model(model_name=self.hparams.model_name, num_classes=5, pretrained=True)
        self.train_accuracy = MulticlassAccuracy(num_classes=5)
        self.val_accuracy = MulticlassAccuracy(num_classes=5)
        self.weights = torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None
        self.loss_fn = F.cross_entropy

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y_true = batch
        y_pred_logits = self(x)
        loss = self.loss_fn(y_pred_logits, y_true, weight=self.weights.to(self.device) if self.weights is not None else None)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', self.train_accuracy(y_pred_logits, y_true), on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y_true = batch
        y_pred_logits = self(x)
        loss = self.loss_fn(y_pred_logits, y_true, weight=self.weights.to(self.device) if self.weights is not None else None)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', self.val_accuracy(y_pred_logits, y_true), on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = {
            'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [scheduler]

    def on_fit_end(self):
        print("\n" + "="*80)
        print("Generating Final Performance Metrics on the Validation Set...")
        self.model.eval()
        all_preds, all_labels = [], []
        if not self.trainer.val_dataloaders:
            print("Validation dataloader not available. Skipping report generation.")
            return
        with torch.no_grad():
            for batch in self.trainer.val_dataloaders:
                x, y = batch
                x = x.to(self.device)
                logits = self.model(x)
                preds = torch.argmax(logits, dim=1)
                all_preds.append(preds.cpu())
                all_labels.append(y.cpu())
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        num_classes = 5
        precision_metric = MulticlassPrecision(num_classes=num_classes, average=None).to(self.device)
        recall_metric = MulticlassRecall(num_classes=num_classes, average=None).to(self.device)
        f1_metric = MulticlassF1Score(num_classes=num_classes, average=None).to(self.device)
        accuracy_metric = MulticlassAccuracy(num_classes=num_classes, average='micro').to(self.device)
        precisions = precision_metric(all_preds, all_labels)
        recalls = recall_metric(all_preds, all_labels)
        f1_scores = f1_metric(all_preds, all_labels)
        accuracy = accuracy_metric(all_preds, all_labels)
        stage_map = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
        print("\n--- Sleep Stage Classification Report ---")
        print(f"{'Stage':<10} | {'Precision':<10} | {'Recall':<10} | {'F1-Score':<10}")
        print("-" * 50)
        for i in range(num_classes):
            stage_name = stage_map[i]
            precision, recall, f1 = precisions[i].item(), recalls[i].item(), f1_scores[i].item()
            print(f"{stage_name:<10} | {precision:<10.4f} | {recall:<10.4f} | {f1:<10.4f}")
        print("-" * 50)
        print(f"\nOverall Accuracy: {accuracy.item():.4f}")
        print("="*80 + "\n")

print("‚úÖ `SleepStageClassifierLightning` module defined.")

# ==============================================================================
# 5. OPTIMIZED CUSTOM DATASET DEFINITION WITH METADATA CACHING
# ==============================================================================
class OptimizedCombinedDataset(Dataset):
    def __init__(self, file_paths, metadata_path):
        self.file_paths = file_paths
        self.metadata_path = metadata_path
        self._cache = {}

        # --- MODIFICATION: Incremental and Resumable Scanning ---
        processed_files = set()
        if os.path.exists(self.metadata_path):
            print(f"Found existing metadata file at {self.metadata_path}. Checking for unscanned files...")
            metadata_df = pd.read_csv(self.metadata_path)
            processed_files = set(metadata_df['filepath'].apply(str))

        all_file_paths_str = {str(p) for p in self.file_paths}
        files_to_scan = [Path(p) for p in all_file_paths_str - processed_files]

        if files_to_scan:
            print(f"Found {len(files_to_scan)} new or unscanned files. Scanning in batches...")

            batch_size = 100 # Scan 100 files at a time
            for i in range(0, len(files_to_scan), batch_size):
                batch_paths = files_to_scan[i:i + batch_size]
                print(f"  -> Scanning batch {i//batch_size + 1}/{-(-len(files_to_scan)//batch_size)}...")

                epoch_data = []
                for f_path in batch_paths:
                    try:
                        df_labels = pd.read_parquet(f_path, columns=['label'])
                        num_valid = df_labels['label'].isin([0, 1, 2, 3, 4]).sum()
                        epoch_data.append({'filepath': str(f_path), 'epoch_count': num_valid})
                    except Exception as e:
                        print(f"    - Warning: Could not process {f_path.name}. Skipping. Error: {e}")

                # Append results of this batch to the metadata file
                if epoch_data:
                    batch_df = pd.DataFrame(epoch_data)
                    # Use append mode and don't write header if file already exists
                    batch_df.to_csv(self.metadata_path, mode='a', header=not os.path.exists(self.metadata_path), index=False)
                    print(f"     ‚úÖ Saved progress for {len(batch_paths)} files.")

        print("‚úÖ Scan complete. Loading final metadata...")
        final_metadata_df = pd.read_csv(self.metadata_path)
        final_metadata_df['filepath'] = final_metadata_df['filepath'].apply(Path)
        epoch_counts_map = dict(zip(final_metadata_df['filepath'], final_metadata_df['epoch_count']))
        self.epochs_per_file = [epoch_counts_map.get(fp, 0) for fp in self.file_paths]

        self.cumulative_epochs = np.cumsum(self.epochs_per_file)
        self.total_epochs = self.cumulative_epochs[-1] if len(self.cumulative_epochs) > 0 else 0
        print(f"‚úÖ Dataset initialized. Total valid epochs: {self.total_epochs}")

    def __len__(self):
        return self.total_epochs

    def __getitem__(self, idx):
        file_idx = np.searchsorted(self.cumulative_epochs, idx, side='right')
        local_idx = idx - (self.cumulative_epochs[file_idx - 1] if file_idx > 0 else 0)
        file_path = self.file_paths[file_idx]
        if file_path not in self._cache:
            try:
                df = pd.read_parquet(file_path)
                self._cache[file_path] = df[df['label'].isin([0, 1, 2, 3, 4])].reset_index(drop=True)
            except Exception as e:
                raise IOError(f"Error reading file {file_path.name} in __getitem__: {e}")
        row = self._cache[file_path].iloc[local_idx]
        label = np.int64(row['label'])
        spectrogram_flat = row.drop('label').values.astype(np.float32)
        mean, std = spectrogram_flat.mean(), spectrogram_flat.std()
        spectrogram_normalized = (spectrogram_flat - mean) / (std + 1e-6)
        spectrogram_2d = spectrogram_normalized.reshape(1, 76, 60)
        return torch.from_numpy(spectrogram_2d), torch.tensor(label)

print("‚úÖ `OptimizedCombinedDataset` class defined.")

# ==============================================================================
# 6. TRAINING EXECUTION
# ==============================================================================
print("\n--- Starting Swin Transformer Experiment (1000 Files, Resumable) ---")

# --- General Parameters ---
MODEL_TO_TRAIN = 'swin_base'
EPOCHS = 40
BATCH_SIZE = 256
NUM_WORKERS = 0
CLASS_WEIGHTS = [0.7, 3.5, 0.5, 1.5, 1.2]
LEARNING_RATE = 2e-5

# --- Paths and File Identification (using Google Drive) ---
shhs1_processed_dir_base = Path('/content/drive/MyDrive/shhs1_processed')
shhs2_processed_dir_base = Path('/content/drive/MyDrive/shhs2_processed')
METADATA_PATH = Path('/content/drive/MyDrive/dataset_metadata_1000_files.csv')

# Using 1000 files for a fair comparison
shhs1_files = list(shhs1_processed_dir_base.glob('*.parquet'))[:500]
shhs2_files = list(shhs2_processed_dir_base.glob('*.parquet'))[:500]
specific_shhs_file_paths = shhs1_files + shhs2_files

# --- Main Experiment Logic ---
if not specific_shhs_file_paths:
     print("\nERROR: No valid .parquet files were found. Aborting experiment.")
else:
    print(f"\nFound {len(specific_shhs_file_paths)} specific files for training.")

    full_dataset = OptimizedCombinedDataset(
        file_paths=specific_shhs_file_paths,
        metadata_path=METADATA_PATH
    )

    if len(full_dataset) > 1:
        train_size = int(0.8 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

        print(f"Dataset split: {len(train_dataset)} training samples, {len(val_dataset)} validation samples.")

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, persistent_workers= (NUM_WORKERS > 0) )
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, persistent_workers= (NUM_WORKERS > 0) )

        model = SleepStageClassifierLightning(
            model_name=MODEL_TO_TRAIN,
            learning_rate=LEARNING_RATE,
            class_weights=CLASS_WEIGHTS
        )

        experiment_name = f"{MODEL_TO_TRAIN}_1000_files_resumable_lr_2e-5"
        checkpoint_dir = Path('/content/drive/MyDrive/final_model_checkpoint/') / experiment_name
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        csv_logger = CSVLogger("/content/drive/MyDrive/sleep_logs/", name=experiment_name)

        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=checkpoint_dir,
            filename=f"sleep-stage-model-{{epoch:02d}}-{{val_loss:.4f}}",
            save_top_k=1,
            mode='min',
            save_last=True
        )

        early_stop_callback = EarlyStopping(
           monitor='val_loss',
           patience=7,
           verbose=True,
           mode='min'
        )

        trainer = pl.Trainer(
            max_epochs=EPOCHS,
            accelerator="gpu",
            devices=1,
            logger=csv_logger,
            callbacks=[checkpoint_callback, early_stop_callback],
            precision="bf16-mixed",
            gradient_clip_val=1.0
        )

        resume_checkpoint_path = checkpoint_dir / "last.ckpt"
        if resume_checkpoint_path.exists():
            print(f"üöÄ Resuming training from checkpoint: {resume_checkpoint_path}")
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=resume_checkpoint_path)
        else:
            print(f"üöÄ Starting new training run for {MODEL_TO_TRAIN.upper()}...")
            trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

        print(f"‚úÖ Training complete for {MODEL_TO_TRAIN.upper()}!")
        print(f"Best model for this run saved at: {checkpoint_callback.best_model_path}")

    else:
        print("Dataset is too small to split. Aborting experiment.")

Ensuring PyTorch Lightning and other libraries are installed...
‚úÖ Installation check complete.
‚úÖ Libraries imported and configuration set.
‚úÖ `get_model` function defined.
‚úÖ `SleepStageClassifierLightning` module defined.
‚úÖ `OptimizedCombinedDataset` class defined.

--- Starting Swin Transformer Experiment (1000 Files, Resumable) ---

Found 1000 specific files for training.
Found 1000 new or unscanned files. Scanning in batches...
  -> Scanning batch 1/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 2/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 3/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 4/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 5/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 6/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 7/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 8/10...
     ‚úÖ Saved progress for 100 files.
  -> Scanning batch 

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

  idx_right = torch.bucketize(x, p)
  numerator += self.values[as_s] * \
INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


‚úÖ Swin Transformer Base model created.
üöÄ Resuming training from checkpoint: /content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5/last.ckpt


/usr/local/lib/python3.12/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:701: Checkpoint directory /content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5 exists and is not empty.
INFO:pytorch_lightning.utilities.rank_zero:Restoring states from the checkpoint path at /content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5/last.ckpt
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.12/dist-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.
INFO:pytorch_lightning.callbacks.model_summary:
  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | model          | SwinTransformer    | 86.7 M | train
1 | train_accuracy | MulticlassAccuracy | 0     

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.12/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/usr/local/lib/python3.12/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Monitored metric val_loss did not improve in the last 8 records. Best score: 0.595. Signaling Trainer to stop.



Generating Final Performance Metrics on the Validation Set...


In [6]:
# ==============================================================================
# SCRIPT TO GENERATE PERFORMANCE REPORT FROM A SAVED CHECKPOINT
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall
import numpy as np
import pandas as pd
from pathlib import Path
import os

# --- Ensure dependencies are installed ---
# MODIFICATION: Pinned pyarrow to a compatible version
!pip install --upgrade -q pytorch-lightning timm "pandas==2.2.2" "pyarrow==19.0.0"

# ==============================================================================
# 1. DEFINE THE MODEL AND DATASET CLASSES
#    (These must match the training script exactly)
# ==============================================================================

def get_model(model_name='swin_base', num_classes=5, pretrained=True):
    if model_name == 'swin_base':
        model = timm.create_model(
            'swin_base_patch4_window7_224.ms_in22k',
            pretrained=pretrained,
            num_classes=num_classes,
            in_chans=1,
            img_size=(76, 60)
        )
    else:
        raise ValueError(f"Model '{model_name}' not supported.")
    return model

class SleepStageClassifierLightning(pl.LightningModule):
    def __init__(self, model_name, learning_rate=1e-5, class_weights=None):
        super().__init__()
        self.save_hyperparameters()
        self.model = get_model(model_name=self.hparams.model_name, num_classes=5, pretrained=False) # Pretrained=False for loading local weights
        self.train_accuracy = MulticlassAccuracy(num_classes=5)
        self.val_accuracy = MulticlassAccuracy(num_classes=5)
        self.weights = torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None
        self.loss_fn = F.cross_entropy
    def forward(self, x):
        return self.model(x)

# --- MODIFICATION: Use the fast, metadata-caching dataset ---
class OptimizedCombinedDataset(Dataset):
    def __init__(self, file_paths, metadata_path):
        self.file_paths = file_paths
        self.metadata_path = metadata_path
        self._cache = {}

        if os.path.exists(self.metadata_path):
            print(f"Found metadata file at {self.metadata_path}. Loading epoch counts...")
            metadata_df = pd.read_csv(self.metadata_path)
            metadata_df['filepath'] = metadata_df['filepath'].apply(Path)
            epoch_counts_map = dict(zip(metadata_df['filepath'], metadata_df['epoch_count']))
            self.epochs_per_file = [epoch_counts_map.get(fp, 0) for fp in self.file_paths]
            print("‚úÖ Epoch counts loaded from metadata file.")
        else:
            raise FileNotFoundError(f"Metadata file not found at {self.metadata_path}. Please run the training script first to generate it.")

        self.cumulative_epochs = np.cumsum(self.epochs_per_file)
        self.total_epochs = self.cumulative_epochs[-1] if len(self.cumulative_epochs) > 0 else 0
        print(f"‚úÖ Dataset initialized. Total valid epochs: {self.total_epochs}")

    def __len__(self):
        return self.total_epochs

    def __getitem__(self, idx):
        file_idx = np.searchsorted(self.cumulative_epochs, idx, side='right')
        local_idx = idx - (self.cumulative_epochs[file_idx - 1] if file_idx > 0 else 0)
        file_path = self.file_paths[file_idx]
        if file_path not in self._cache:
            try:
                df = pd.read_parquet(file_path)
                self._cache[file_path] = df[df['label'].isin([0, 1, 2, 3, 4])].reset_index(drop=True)
            except Exception as e:
                raise IOError(f"Error reading file {file_path.name} in __getitem__: {e}")
        row = self._cache[file_path].iloc[local_idx]
        label = np.int64(row['label'])
        spectrogram_flat = row.drop('label').values.astype(np.float32)
        mean, std = spectrogram_flat.mean(), spectrogram_flat.std()
        spectrogram_normalized = (spectrogram_flat - mean) / (std + 1e-6)
        spectrogram_2d = spectrogram_normalized.reshape(1, 76, 60)
        return torch.from_numpy(spectrogram_2d), torch.tensor(label)

# ==============================================================================
# 2. LOAD MODEL AND DATA, THEN GENERATE REPORT
# ==============================================================================

# --- CONFIGURATION ---
# IMPORTANT: Update this path to point to the BEST checkpoint file from the Swin Base run
CHECKPOINT_PATH = "/content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5/sleep-stage-model-epoch=03-val_loss=0.5951.ckpt"
METADATA_PATH = Path('/content/drive/MyDrive/dataset_metadata_1000_files.csv')
# -------------------

print(f"üß† Loading model from: {CHECKPOINT_PATH}")
if not os.path.exists(CHECKPOINT_PATH):
    print(f"‚ùå ERROR: Checkpoint file not found at the specified path.")
else:
    model = SleepStageClassifierLightning.load_from_checkpoint(CHECKPOINT_PATH)
    model.eval()
    model.cuda() # Move model to GPU
    print("‚úÖ Model loaded successfully.")

    # --- Load the dataset (needed for the validation set) ---
    shhs1_processed_dir_base = Path('/content/drive/MyDrive/shhs1_processed')
    shhs2_processed_dir_base = Path('/content/drive/MyDrive/shhs2_processed')
    shhs1_files = list(shhs1_processed_dir_base.glob('*.parquet'))[:500]
    shhs2_files = list(shhs2_processed_dir_base.glob('*.parquet'))[:500]
    specific_shhs_file_paths = shhs1_files + shhs2_files

    full_dataset = OptimizedCombinedDataset(
        file_paths=specific_shhs_file_paths,
        metadata_path=METADATA_PATH
    )
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    _, val_dataset = random_split(full_dataset, [train_size, val_size])
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=0)
    print("‚úÖ Validation data loaded.")

    # --- Generate the report ---
    print("\n" + "="*80)
    print("Generating Final Performance Metrics on the Validation Set...")
    all_preds, all_labels = [], []
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            x = x.to(model.device)
            logits = model(x)
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(y.cpu())
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    num_classes = 5
    precision_metric = MulticlassPrecision(num_classes=num_classes, average=None)
    recall_metric = MulticlassRecall(num_classes=num_classes, average=None)
    f1_metric = MulticlassF1Score(num_classes=num_classes, average=None)
    accuracy_metric = MulticlassAccuracy(num_classes=num_classes, average='micro')

    precisions = precision_metric(all_preds, all_labels)
    recalls = recall_metric(all_preds, all_labels)
    f1_scores = f1_metric(all_preds, all_labels)
    accuracy = accuracy_metric(all_preds, all_labels)

    stage_map = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
    print("\n--- Sleep Stage Classification Report ---")
    print(f"{'Stage':<10} | {'Precision':<10} | {'Recall':<10} | {'F1-Score':<10}")
    print("-" * 50)
    for i in range(num_classes):
        stage_name = stage_map[i]
        precision, recall, f1 = precisions[i].item(), recalls[i].item(), f1_scores[i].item()
        print(f"{stage_name:<10} | {precision:<10.4f} | {recall:<10.4f} | {f1:<10.4f}")
    print("-" * 50)
    print(f"\nOverall Accuracy: {accuracy.item():.4f}")
    print("="*80 + "\n")

üß† Loading model from: /content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5/sleep-stage-model-epoch=03-val_loss=0.5951.ckpt
‚úÖ Model loaded successfully.
Found metadata file at /content/drive/MyDrive/dataset_metadata_1000_files.csv. Loading epoch counts...
‚úÖ Epoch counts loaded from metadata file.
‚úÖ Dataset initialized. Total valid epochs: 1093021
‚úÖ Validation data loaded.

Generating Final Performance Metrics on the Validation Set...


OSError: Error reading file shhs2-204275.parquet in __getitem__: module 'pyarrow.lib' has no attribute 'Decimal32Type'

# --- CONFIGURATION ---
# IMPORTANT: Update this path to point to the BEST checkpoint file from the Swin Base run
CHECKPOINT_PATH = "/content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5/sleep-stage-model-epoch=03-val_loss=0.5951.ckpt"
# -------------------
