In [None]:
# Cell ID: ZAmkt7vA2Ayo
# Cell 1: This cell is essential for initial setup, connecting to Google Drive, installing libraries, and checking GPU availability.
# Step 1: Conectar con Google Drive
from google.colab import drive
drive.mount('/content/drive')

# Step 2: Instalar y actualizar las librerías
print("\nInstalando y actualizando librerías...")
!pip install --upgrade -q mne pytorch-lightning timm
print("✅ Librerías listas.")

# Step 3: Prueba explícita de control de la GPU
import torch
print("\n--- INICIANDO PRUEBA DE CONTROL DE GPU ---")
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"✅ GPU detectada: {torch.cuda.get_device_name(0)}")
    try:
        tensor_grande = torch.randn(1024, 1024, 512, device=device) # Asignar 2GB
        memoria_asignada = torch.cuda.memory_allocated(0) / 1024**3
        print(f"✅ ¡Éxito! Memoria asignada activamente: {memoria_asignada:.2f} GB")
        del tensor_grande
        torch.cuda.empty_cache()
        print("✅ Memoria liberada correctamente.")
        print("--- PRUEBA DE CONTROL DE GPU COMPLETADA EXITOSAMENTE ---")
    except Exception as e:
        print(f"❌ ¡ERROR DURANTE LA PRUEBA! No se pudo asignar memoria a la GPU: {e}")
else:
    print("❌ ¡ERROR! No se detectó ninguna GPU en este entorno de ejecución.")

Mounted at /content/drive

Instalando y actualizando librerías...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m46.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m828.2/828.2 kB[0m [31m41.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m128.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m100.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m64.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.4 MB/s

In [None]:
# Cell ID: qVOWMBr42MvW
# Cell 2: This cell defines the Model Architecture (get_convnext_model).
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassCohenKappa
import numpy as np

# ==============================================================================
# DEFINICIÓN DE LA ARQUITECTURA DEL MODELO
# ==============================================================================
def get_convnext_model(num_classes=5, pretrained=True,
                       model_name='convnextv2_tiny.fcmae_ft_in22k_in1k'):
    """
    Build a ConvNeXt V2 model adapted to single-channel input for sleep-stage classification.

    The 3-channel stem convolution (``model.stem[0]``) is replaced by a 1-channel one with
    the same geometry. Its kernel is the original kernel averaged over the input-channel
    axis, and its bias (when present) is copied from the original, so pretrained stem
    weights are preserved. The classifier (``model.head.fc``) is replaced by a freshly
    initialised ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes.
        pretrained: Whether timm should load pretrained weights.
        model_name: timm model identifier. It must expose ``model.stem[0]`` (a Conv2d) and
            ``model.head.fc`` (a Linear).

    Returns:
        The adapted ``nn.Module``.
    """
    model = timm.create_model(model_name, pretrained=pretrained)

    original_conv = model.stem[0]
    new_first_conv = nn.Conv2d(1, original_conv.out_channels,
                               kernel_size=original_conv.kernel_size,
                               stride=original_conv.stride,
                               padding=original_conv.padding,
                               dilation=original_conv.dilation,
                               bias=(original_conv.bias is not None))

    with torch.no_grad():
        # (out, C, kh, kw) -> (out, 1, kh, kw): collapse the input channels into one filter.
        new_first_conv.weight.copy_(original_conv.weight.mean(dim=1, keepdim=True))
        if original_conv.bias is not None:
            # Keep the pretrained bias instead of the random init nn.Conv2d gives a new layer.
            new_first_conv.bias.copy_(original_conv.bias)

    model.stem[0] = new_first_conv

    num_ftrs = model.head.fc.in_features
    model.head.fc = nn.Linear(num_ftrs, num_classes)

    return model

print("✅ Funciones del modelo definidas exitosamente.")

✅ Funciones del modelo definidas exitosamente.


In [None]:
# Cell ID: MAvU9zdM2SNw OK
# Cell 3: This cell defines the PyTorch Lightning Module (SleepStageClassifierLightning), which is crucial for the training process.
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
import numpy as np
import torch.optim as optim
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping # Removed GradientClipByNorm
from pytorch_lightning.loggers import CSVLogger
from torch.optim.lr_scheduler import ReduceLROnPlateau # Import the scheduler
from torch.nn.utils import clip_grad_norm_ # Import clip_grad_norm_

# Reuse the get_convnext_model function from the previous cell
# from your_module import get_convnext_model # If get_convnext_model is in a separate file

# ==============================================================================
# DEFINICIÓN DEL CLASIFICADOR USANDO PYTORCH LIGHTNING TRAINER
# ==============================================================================
class SleepStageClassifierLightning(pl.LightningModule):
    """
    LightningModule wrapping the single-channel ConvNeXt V2 model and its training logic.

    Args:
        model_name: Recorded in the hyperparameters only. NOTE(review): it is not passed to
            ``get_convnext_model``, which decides the architecture on its own.
        num_classes: Number of output classes (sleep stages).
        learning_rate: AdamW learning rate.
        class_weights: Optional sequence of ``num_classes`` per-class weights for the
            cross-entropy loss; ``None`` means an unweighted loss.
        grad_clip_norm: Maximum gradient L2 norm applied before every optimizer step (see
            ``configure_gradient_clipping``); ``None`` disables it unless the Trainer sets
            ``gradient_clip_val``.
    """

    def __init__(self, model_name='convnextv2_tiny', num_classes=5, learning_rate=1e-4,
                 class_weights=None, grad_clip_norm=1.0):
        super().__init__()
        self.save_hyperparameters()
        self.model = get_convnext_model(num_classes=num_classes, pretrained=True)

        # Metrics. NOTE(review): torchmetrics' multiclass metrics default to average='macro'
        # (i.e. "acc" is class-balanced) — confirm that is the intended reporting.
        self.train_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.train_f1 = MulticlassF1Score(num_classes=num_classes)
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.val_f1 = MulticlassF1Score(num_classes=num_classes)

        # Non-persistent buffer: it follows the module across devices (no per-step .to()) but
        # stays out of state_dict, so checkpoints saved before this change still load strictly.
        weights = torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None
        self.register_buffer('weights', weights, persistent=False)
        self.loss_fn = F.cross_entropy

    def forward(self, x):
        """Return class logits for a batch of inputs."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the (optionally class-weighted) cross-entropy loss and log per-step metrics."""
        x, y_true = batch
        logits = self(x)
        # F.cross_entropy treats weight=None as an unweighted loss.
        loss = self.loss_fn(logits, y_true, weight=self.weights)
        preds = torch.argmax(logits, dim=1)

        self.train_accuracy(preds, y_true)
        self.train_f1(logits, y_true)
        self.log('train_loss', loss)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=False, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and metrics, warning when NaNs appear in inputs, logits or loss."""
        x, y_true = batch

        if torch.isnan(x).any():
            print(f"  !!! WARNING: NaN values detected in input batch x in validation step {batch_idx} !!!")

        logits = self(x)
        if torch.isnan(logits).any():
            print(f"  !!! WARNING: NaN values detected in predicted logits in validation step {batch_idx} !!!")

        loss = self.loss_fn(logits, y_true, weight=self.weights)
        if torch.isnan(loss):
            print(f"  !!! WARNING: NaN loss detected in validation step {batch_idx} !!!")

        preds = torch.argmax(logits, dim=1)
        self.val_accuracy(preds, y_true)
        self.val_f1(logits, y_true)
        self.log('val_loss', loss)
        self.log('val_acc', self.val_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True)

        return loss

    def configure_optimizers(self):
        """AdamW plus ReduceLROnPlateau on ``val_loss``, stepped once per epoch."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)

        # `verbose` is deprecated in recent PyTorch; log LR changes with a LearningRateMonitor instead.
        scheduler = {
            'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3),
            'interval': 'epoch',
            'frequency': 1,
            'monitor': 'val_loss',
            # Without a validation loader 'val_loss' never exists; warn and skip instead of raising.
            'strict': False,
        }

        return [optimizer], [scheduler]

    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        """
        Clip gradients right before each optimizer step.

        Lightning calls this hook after backward (and after gradient accumulation), which is
        where clipping has an effect. ``training_step_end`` runs too early for that, and
        Lightning 2.x no longer calls it. A Trainer-level ``gradient_clip_val`` takes
        precedence; otherwise ``hparams.grad_clip_norm`` is used, with the norm algorithm
        by default.
        """
        if gradient_clip_val is None:
            gradient_clip_val = self.hparams.grad_clip_norm
        self.clip_gradients(optimizer,
                            gradient_clip_val=gradient_clip_val,
                            gradient_clip_algorithm=gradient_clip_algorithm)


print("✅ SleepStageClassifierLightning defined with ReduceLROnPlateau scheduler and gradient clipping via configure_gradient_clipping.")

✅ SleepStageClassifierLightning defined with ReduceLROnPlateau scheduler and direct gradient clipping in training_step_end.


In [None]:
# Cell ID: 3Kaxttzn56u8
# Cell 4: Re-runs the library installation to ensure all necessary packages are available, especially if the runtime environment changes. It's good practice to keep this before the main training loop
# Re-run installation just in case
print("\nEnsuring PyTorch Lightning and other libs are installed...")
!pip install --upgrade -q pytorch-lightning timm
print("✅ Installation check complete.")


Ensuring PyTorch Lightning and other libs are installed...
✅ Installation check complete.


In [None]:
# Cell ID: sqz0XnS36jjF OK
# Cell 5: CombinedDataset Definition (Used for Chunked Training). This cell contains the CombinedDataset class, which is used by the training loop.

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from pathlib import Path
import logging
import numpy as np
import time # Import time for optional profiling within getitem

print("Cell 5: Defining CombinedDataset class...")


# --- Dataset that loads data from a list of file paths (for chunking) ---
class CombinedDataset(Dataset):
    """
    Map-style dataset over the labelled rows ("epochs") of a chunk of parquet files.

    Each file holds a ``label`` column. Every other column is treated as a flattened
    spectrogram value, and the values of a row are reshaped to ``SPECTROGRAM_SHAPE``
    (1 x 76 x 60). Only rows whose label is in ``VALID_LABELS`` are exposed. A file is
    dropped at construction time when it is listed in ``skip_files``, cannot be read, or
    has no valid row.

    Args:
        file_paths_chunk: ``pathlib.Path`` objects for the parquet files in this chunk.
        skip_files: File names (``Path.name``) to ignore without reading them.

    Attributes:
        original_file_paths: The list passed in.
        file_paths: Files that contributed at least one valid row, in input order.
        epochs_per_file: Number of valid rows of each entry of ``file_paths``.
        cumulative_epochs: Running sum of ``epochs_per_file``; maps a global index to a file.
        total_epochs: Total number of valid rows (the dataset length).

    Files are read lazily in ``__getitem__``. Each file is kept in ``_cache`` as a pair of
    float32 feature / int64 label arrays, holding only the valid rows.
    NOTE(review): with DataLoader workers, every worker process holds its own cache copy.
    """

    VALID_LABELS = (0, 1, 2, 3, 4)
    SPECTROGRAM_SHAPE = (1, 76, 60)

    def __init__(self, file_paths_chunk, skip_files=("shhs2-200820.parquet",)):
        print(f"CombinedDataset: Initializing with {len(file_paths_chunk)} files.")
        self.original_file_paths = file_paths_chunk
        self.file_paths = []  # only files that contributed at least one valid row

        logging.info(f"Dataset chunk initialized with {len(self.original_file_paths)} subjects.")
        print("CombinedDataset: Pre-calculating number of epochs per subject in chunk (this may take a moment)...")
        logging.info("Pre-calculating the number of epochs per subject in chunk (this may take a moment)...")

        skip_names = set(skip_files)
        epoch_counts = []  # kept aligned with self.file_paths
        n_files = len(self.original_file_paths)
        for i, f in enumerate(self.original_file_paths):
            print(f"CombinedDataset: Start processing file {i+1}/{n_files}: {f.name}", flush=True)

            if f.name in skip_names:
                print(f"CombinedDataset: Skipping problematic file: {f.name}", flush=True)
                logging.warning(f"Skipping problematic file: {f.name}")
                continue

            num_valid_epochs = self._count_valid_epochs(f)
            if num_valid_epochs > 0:
                self.file_paths.append(f)
                epoch_counts.append(num_valid_epochs)

        self.epochs_per_file = np.array(epoch_counts, dtype=np.int64)
        self.cumulative_epochs = np.cumsum(self.epochs_per_file)
        self.total_epochs = int(self.cumulative_epochs[-1]) if len(self.cumulative_epochs) > 0 else 0

        print(f"CombinedDataset: Finished pre-calculation for chunk. Processed {len(self.file_paths)} files. Total valid epochs: {self.total_epochs}")
        logging.info(f"Número final de épocas válidas para el chunk: {self.total_epochs}")

        # --- Caching mechanism ---
        self._cache = {}  # file path -> (features float32 [n, F], labels int64 [n]) for valid rows
        print("CombinedDataset: Caching mechanism initialized.")

    def _count_valid_epochs(self, f):
        """Return the number of rows of ``f`` with a valid label; 0 if unreadable or none."""
        try:
            df_labels = pd.read_parquet(f, columns=['label'])
            print(f"CombinedDataset: Successfully read file {f.name}", flush=True)
            num_valid_epochs = int(df_labels['label'].isin(self.VALID_LABELS).sum())
        except Exception as e:
            # Best effort: one bad file must not abort the whole chunk.
            print(f"CombinedDataset: ERROR processing file {f.name}. Reason: {e}. Skipping file.", flush=True)
            logging.warning(f"No se pudo leer o procesar el archivo {f.name}, se omitirá. Razón: {e}")
            return 0

        if num_valid_epochs > 0:
            print(f"CombinedDataset: Found {num_valid_epochs} valid epochs in {f.name}", flush=True)
        else:
            print(f"CombinedDataset: No valid epochs found in {f.name}. Skipping file.", flush=True)
            logging.info(f"No valid epochs found in {f.name}, skipping.")
        return num_valid_epochs

    def _load_file(self, file_path):
        """Return ``(features, labels)`` arrays for the valid rows of ``file_path``, reading it at most once."""
        cached = self._cache.get(file_path)
        if cached is None:
            try:
                df = pd.read_parquet(file_path)
            except Exception as e:
                print(f"__getitem__: ERROR reading file {file_path.name}: {e}. Skipping sample.", flush=True)
                raise
            # Filter once per file here instead of once per sample in __getitem__.
            df_valid = df[df['label'].isin(self.VALID_LABELS)]
            # Select features by excluding 'label' by name rather than assuming it is the last column.
            features = df_valid.drop(columns='label').to_numpy(dtype=np.float32)
            labels = df_valid['label'].to_numpy(dtype=np.int64)
            cached = (features, labels)
            self._cache[file_path] = cached
        return cached

    def __len__(self):
        return self.total_epochs

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)``: a float32 tensor of ``SPECTROGRAM_SHAPE`` and an int64 scalar tensor."""
        # First file whose cumulative count exceeds idx holds the sample.
        file_idx = int(np.searchsorted(self.cumulative_epochs, idx, side='right'))
        file_path = self.file_paths[file_idx]
        local_idx = idx if file_idx == 0 else idx - int(self.cumulative_epochs[file_idx - 1])

        features, labels = self._load_file(file_path)

        if local_idx >= len(labels):
            # Only possible if the file changed on disk after the pre-calculation in __init__.
            print(f"__getitem__: WARNING: Calculated local index {local_idx} is out of bounds for valid data in {file_path.name} (length {len(labels)}). This should not happen if pre-calculation is correct. Skipping sample.", flush=True)
            raise IndexError(f"Local index out of bounds for file {file_path.name}")

        # Copy so callers never hold a view into the shared cache.
        spectrogram_2d = features[local_idx].reshape(self.SPECTROGRAM_SHAPE).copy()
        return torch.from_numpy(spectrogram_2d), torch.tensor(labels[local_idx])

print("✅ CombinedDataset class defined with caching and improved error handling/logging.")

# NOTE: Data loading and Trainer setup logic is moved to the chunked training loop (Cell 6)

Cell 4: Defining CombinedDataset class...
✅ CombinedDataset class defined with caching and improved error handling/logging.


In [None]:
# Cell ID: 1ruaWyKQ7e6t
# Cell 6: This cell contains the complete logic for execution of the chunked training loop, including data discovery,
# chunking, model loading/initialization, trainer setup, and the training process with checkpointing.
# To resume after a dropped session, start the chunk loop below from the first unfinished chunk index (0-based).

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger
import numpy as np
import os
import glob
import pandas as pd # Ensure pandas is imported
from pathlib import Path # Ensure Path is imported

# Let float32 matmuls use reduced-precision Tensor Core paths (faster, slightly less exact).
torch.set_float32_matmul_precision('medium')

print("Cell 6: Starting chunked training loop execution...")

# --- Training parameters ---
epochs_per_chunk = 5  # epochs trained on each chunk before moving to the next one
batch_size = 256
num_workers = 4  # DataLoader worker processes
class_weights = [0.7, 3.5, 0.5, 1.5, 1.2]  # cross-entropy weights, one per label 0..4
learning_rate = 1e-4

# --- Paths ---
print("Cell 6: Defining paths and checking for staged data...")
sleep_edfx_processed_dir_local = Path('/content/processed_data/sleep_edfx_processed/')
shhs1_processed_dir_local = Path('/content/processed_data/shhs1_processed/')
shhs2_processed_dir_local = Path('/content/processed_data/shhs2_processed/')
drive_data_root = Path('/content/drive/MyDrive/')
final_checkpoint_dir = Path('/content/drive/MyDrive/final_model_checkpoint/')  # checkpoints always go to Drive
final_checkpoint_dir.mkdir(parents=True, exist_ok=True)
print(f"Cell 6: Checkpoint directory: {final_checkpoint_dir}")


def _pick_data_dir(local_dir):
    """Return ``local_dir`` if it holds staged files, otherwise the same-named folder on Drive."""
    if local_dir.exists() and any(local_dir.iterdir()):
        return local_dir
    drive_dir = drive_data_root / local_dir.name
    print(f"\nWARNING: No staged data in {local_dir}. Please run the data staging step first.")
    print(f"Discovering files directly from Drive instead (will be slower): {drive_dir}")
    return drive_dir


# --- Discover files ---
# Each dataset is checked on its own. Checking only Sleep-EDFx would silently drop every
# SHHS file whenever EDFx is staged locally but SHHS is not.
print("Cell 6: Discovering and chunking files from local staged data...")
sleep_edfx_processed_dir_source = _pick_data_dir(sleep_edfx_processed_dir_local)
shhs1_processed_dir_source = _pick_data_dir(shhs1_processed_dir_local)
shhs2_processed_dir_source = _pick_data_dir(shhs2_processed_dir_local)

# Sorted so the chunk boundaries are stable across sessions (needed for resuming).
edfx_files = sorted(sleep_edfx_processed_dir_source.glob('*.parquet'))
shhs1_files = sorted(shhs1_processed_dir_source.glob('**/*.parquet'))
shhs2_files = sorted(shhs2_processed_dir_source.glob('**/*.parquet'))
all_files = edfx_files + shhs1_files + shhs2_files

def _find_resume_checkpoint(checkpoint_dir, chunk_idx):
    """
    Return the newest checkpoint written for a chunk numbered ``chunk_idx`` or lower, or None.

    Checkpoint filenames use 1-based chunk numbers (chunk index ``i`` is saved as ``i + 1``),
    so the chunk before index ``chunk_idx`` is number ``chunk_idx``. For each number, the
    explicit ``latest_model_chunk_N.ckpt`` is preferred over that chunk's ModelCheckpoint
    files. The search walks back past missing numbers so a skipped chunk does not reset training.
    """
    for chunk_number in range(chunk_idx, 0, -1):
        explicit = checkpoint_dir / f'latest_model_chunk_{chunk_number}.ckpt'
        if explicit.exists():
            return str(explicit)
        best_of_chunk = sorted(glob.glob(str(checkpoint_dir / f'lightning-chunk-{chunk_number}-*.ckpt')))
        if best_of_chunk:
            return best_of_chunk[-1]
    return None


if not all_files:
    print("\nERROR: No data files found in the specified directories (local staged or Drive). Cannot proceed with training.")
else:
    print(f"Cell 6: Found {len(all_files)} total files.")

    # Files per chunk. An earlier cell may define `chunk_size`; otherwise default to 50.
    if 'chunk_size' not in globals():
        chunk_size = 50
        print(f"chunk_size not found from Cell 5, using default: {chunk_size}")

    # 0-based index of the first chunk to train; raise it to resume after a dropped session.
    start_chunk_idx = 0
    # Seed for the per-chunk train/val split so a resumed run reproduces the same split.
    split_seed = 42

    print(f"Cell 6: Dividing {len(all_files)} files into chunks of size {chunk_size}...")
    file_chunks = [all_files[i:i + chunk_size] for i in range(0, len(all_files), chunk_size)]
    print(f"Cell 6: Divided into {len(file_chunks)} chunks.")

    # One logger for the whole run. It writes to Drive (the directory the log-inspection cell
    # reads) so metrics survive a runtime reset; a relative "logs" dir lives on ephemeral disk.
    logger = CSVLogger("/content/drive/MyDrive/sleep_logs", name="sleep_stage_training")

    print(f"Cell 6: Starting training loop over {len(file_chunks)} chunks.")

    model_lightning = None          # model carried from chunk to chunk within this session
    latest_checkpoint_path = None   # most recent checkpoint loaded or written

    for chunk_idx in range(start_chunk_idx, len(file_chunks)):
        file_paths_chunk = file_chunks[chunk_idx]
        print(f"\n--- Processing Chunk {chunk_idx + 1}/{len(file_chunks)} ---")
        print(f"Number of files in current chunk: {len(file_paths_chunk)}")

        if not file_paths_chunk:
            print(f"Chunk {chunk_idx + 1} is empty. Skipping.")
            continue

        # a. Dataset for this chunk (files without valid labels are dropped inside).
        print(f"Initializing CombinedDataset for chunk {chunk_idx + 1}...")
        chunk_dataset = CombinedDataset(file_paths_chunk)
        print(f"Total valid epochs in chunk {chunk_idx + 1}: {len(chunk_dataset)}")

        if len(chunk_dataset) == 0:
            print(f"Chunk {chunk_idx + 1} contains no valid epochs. Skipping.")
            continue

        # b. 80/20 train/val split. For n >= 2, 1 <= int(0.8 * n) < n, so both parts are non-empty.
        print(f"Splitting dataset for chunk {chunk_idx + 1}...")
        if len(chunk_dataset) < 2:
            print(f"Chunk {chunk_idx + 1} has only {len(chunk_dataset)} valid epochs. Cannot split. Skipping.")
            continue

        train_size = int(0.8 * len(chunk_dataset))
        val_size = len(chunk_dataset) - train_size
        split_generator = torch.Generator().manual_seed(split_seed + chunk_idx)
        train_dataset, val_dataset = random_split(chunk_dataset, [train_size, val_size], generator=split_generator)
        print(f"Chunk {chunk_idx + 1}: Train size: {len(train_dataset)}, Val size: {len(val_dataset)}")

        # c. DataLoaders. persistent_workers is only valid when worker processes are used.
        print(f"Creating DataLoaders for chunk {chunk_idx + 1}...")
        loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, persistent_workers=num_workers > 0)
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
        print(f"DataLoaders created for chunk {chunk_idx + 1}.")

        # d. Model. Keep training the in-memory model from the previous chunk. Otherwise (first
        #    chunk of a session or a resume) load the newest checkpoint on Drive, or start fresh.
        if model_lightning is not None:
            print("Continuing training with the in-memory model from the previous chunk.")
        else:
            print(f"Attempting to find and load latest checkpoint from {final_checkpoint_dir}...")
            latest_checkpoint_path = _find_resume_checkpoint(final_checkpoint_dir, chunk_idx)
            if latest_checkpoint_path:
                print(f"Loading model from checkpoint: {latest_checkpoint_path}")
                try:
                    model_lightning = SleepStageClassifierLightning.load_from_checkpoint(
                        latest_checkpoint_path,
                        learning_rate=learning_rate,
                        class_weights=class_weights,
                    )
                    print("Model loaded successfully from checkpoint.")
                except Exception as e:
                    print(f"Error loading model from checkpoint {latest_checkpoint_path}: {e}")
                    latest_checkpoint_path = None
            if model_lightning is None:
                if chunk_idx > 0:
                    print(f"WARNING: No usable checkpoint found for chunks <= {chunk_idx}. "
                          "Training restarts from the pretrained backbone.")
                else:
                    print("No checkpoint found from previous chunk. Initializing a new model.")
                model_lightning = SleepStageClassifierLightning(learning_rate=learning_rate, class_weights=class_weights)

        # e. Keep this chunk's best epoch by val_loss. The filename uses the 1-based chunk number.
        #    save_last is off because the explicit save after fit() already stores the
        #    end-of-chunk state; a per-chunk last.ckpt would only duplicate it on Drive.
        checkpoint_callback = ModelCheckpoint(
            dirpath=final_checkpoint_dir,
            filename=f'lightning-chunk-{chunk_idx + 1}-{{epoch:02d}}-{{val_loss:.4f}}',
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            save_last=False,
        )
        print(f"ModelCheckpoint callback defined for chunk {chunk_idx + 1}.")

        # f. A fresh Trainer per chunk (new dataloaders); the logger is shared across chunks.
        #    NOTE(review): global_step restarts each chunk, so step numbers repeat in metrics.csv.
        print(f"Initializing Trainer for chunk {chunk_idx + 1}...")
        trainer = pl.Trainer(
            max_epochs=epochs_per_chunk,
            accelerator="gpu",
            devices=1,
            callbacks=[checkpoint_callback],
            logger=logger,
            precision="32",  # full float32 (not mixed precision)
            accumulate_grad_batches=4,  # one optimizer step every 4 batches
        )
        print(f"Trainer initialized for chunk {chunk_idx + 1}.")

        # g. Train on this chunk.
        print(f"Starting training for chunk {chunk_idx + 1}...")
        trainer.fit(model_lightning, train_loader, val_loader)
        print(f"Training completed for chunk {chunk_idx + 1}.")

        # h. Save the end-of-chunk state under a predictable name so a later session can resume.
        final_chunk_checkpoint_path = final_checkpoint_dir / f'latest_model_chunk_{chunk_idx + 1}.ckpt'
        print(f"Saving final model state for chunk {chunk_idx + 1} to {final_chunk_checkpoint_path}...")
        trainer.save_checkpoint(final_chunk_checkpoint_path)
        latest_checkpoint_path = str(final_chunk_checkpoint_path)
        print(f"Model state saved for chunk {chunk_idx + 1}.")

    print("\n========== CHUNKED TRAINING LOOP COMPLETED ==========")

Cell 6: Starting chunked training loop execution...
Cell 6: Defining paths and checking for staged data...
Cell 6: Checkpoint directory: /content/drive/MyDrive/final_model_checkpoint
Cell 6: Discovering and chunking files from local staged data...

Attempting to discover files directly from Drive instead (will be slower)...
Cell 6: Found 8628 total files.
chunk_size not found from Cell 5, using default: 50
Cell 6: Dividing 8628 files into chunks of size 50...
Cell 6: Divided into 173 chunks.
Cell 6: Starting training loop over 173 chunks.

--- Processing Chunk 126/173 ---
Number of files in current chunk: 50
Initializing CombinedDataset for chunk 126...
CombinedDataset: Initializing with 50 files.
CombinedDataset: Pre-calculating number of epochs per subject in chunk (this may take a moment)...
CombinedDataset: Start processing file 1/50: shhs2-200771.parquet
CombinedDataset: Successfully read file shhs2-200771.parquet
CombinedDataset: Found 1181 valid epochs in shhs2-200771.parquet
Co

KeyboardInterrupt: 

In [None]:
# Cell ID: 16c810ee
# Cell 7: Finding/display the latest log file and create a backup copy of it.
import pandas as pd
import glob
import os
import shutil # Import shutil for file copying
from pathlib import Path # Import Path

# Define the base directory for logs on Google Drive (where CSVLogger is configured to save)
log_dir_drive = '/content/drive/MyDrive/sleep_logs/sleep_stage_training/' # Updated to the expected logger path

# --- Find the latest version directory and the metrics.csv file ---
version_dirs = sorted(glob.glob(os.path.join(log_dir_drive, 'version_*')))

log_file_path = None
latest_version_dir = None

if not version_dirs:
    print(f"No version directories found in {log_dir_drive}. Please ensure training has started and logs are being generated.")
else:
    latest_version_dir = version_dirs[-1]
    print(f"Found latest log version directory: {latest_version_dir}")

    # Look for the metrics.csv file within the latest version directory
    metrics_file_path_candidate = os.path.join(latest_version_dir, 'metrics.csv') # CSVLogger default filename

    if os.path.exists(metrics_file_path_candidate):
        log_file_path = metrics_file_path_candidate
        print(f"Found log file: {log_file_path}")
    else:
        print(f"Could not find 'metrics.csv' in {latest_version_dir}. Please check the exact filename in your log directory.")
        # Fallback or additional checks could be added here if the filename might vary

# --- Perform backup and display metrics if log file was found ---
if log_file_path:
    # --- Define the path for the backup copy ---
    # You can choose a different name or directory for the backup
    backup_log_file_path = "/content/drive/MyDrive/sleep_logs/metrics_backup_for_analysis.csv" # Example backup path - UPDATE THIS if needed

    # Create backup copy
    try:
        shutil.copy2(log_file_path, backup_log_file_path)
        print(f"\nSuccessfully created backup of metrics.csv at: {backup_log_file_path}")
    except Exception as e:
        print(f"\nError creating backup copy: {e}")


    # --- Load and display metrics ---
    try:
        log_df = pd.read_csv(log_file_path)

        # Display relevant columns
        print("\nValidation Metrics from Logs:")
        if 'val_acc' in log_df.columns and 'val_f1' in log_df.columns:
             # Only show columns relevant to epoch-end validation metrics
             # These typically have non-NaN values only at the end of an epoch
             validation_metrics_df = log_df[['epoch', 'val_loss', 'val_acc', 'val_f1']].dropna(subset=['val_loss', 'val_acc', 'val_f1'])
             display(validation_metrics_df)
        else:
             print("Could not find 'val_acc' or 'val_f1' columns in the log file.")
             print("Available columns:", log_df.columns.tolist())


    except Exception as e:
        print(f"Error loading or processing log file {log_file_path}: {e}")

else:
    print("\nLog file not found. Cannot perform backup or display metrics.")