In [None]:
# ==============================================================================
# 1. SETUP AND DEPENDENCY INSTALLATION
# ==============================================================================
print("Ensuring PyTorch Lightning and other libraries are installed...")
# IPython shell escape: this line only works inside a notebook cell, not as plain Python.
# Quietly (-q) installs/upgrades the third-party packages used below; pandas and
# pyarrow are pinned to exact versions, and gcsfs provides Google Cloud Storage access.
!pip install --upgrade -q pytorch-lightning timm "pandas==2.2.2" "pyarrow==19.0.0" gcsfs
print("âœ… Installation check complete.")

# ==============================================================================
# 2. IMPORTS AND INITIAL CONFIGURATION
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score
import numpy as np
import pandas as pd
from pathlib import Path
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import os

# Process-wide setting: per the PyTorch docs, 'medium' lets float32 matrix
# multiplications use lower-precision internals in exchange for speed.
# The original author targets A100/H100-class GPUs with this setting.
torch.set_float32_matmul_precision('medium')
print("âœ… Libraries imported and configuration set.")

# ==============================================================================
# 3. AUTHENTICATE FOR GOOGLE CLOUD STORAGE (GCS)
# ==============================================================================
from google.colab import auth
import gcsfs

print("Authenticating to Google Cloud...")
# Colab-only helper: triggers an interactive pop-up to authenticate the user account.
# Presumably the gcsfs / pandas GCS calls later in this file rely on these
# credentials -- confirm if this is run outside Colab.
auth.authenticate_user()
print("âœ… Authentication complete.")


# ==============================================================================
# 4. MODEL ARCHITECTURE DEFINITION (MULTI-MODEL SUPPORT)
# ==============================================================================
def _to_single_channel_conv(conv):
    """Build a 1-input-channel copy of ``conv`` that keeps its learned parameters.

    The RGB kernels are averaged over the input-channel axis, and the bias
    (which does not depend on the number of input channels) is copied as is.
    """
    single = nn.Conv2d(
        1,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        # Only collapse the kernels when the source really is a 3-channel conv;
        # otherwise the fresh initialisation of the new layer is left in place.
        if conv.weight.shape[1] == 3:
            single.weight.copy_(conv.weight.mean(dim=1, keepdim=True))
        # Without this copy the new layer would start from a random bias and
        # silently discard the pretrained one.
        if conv.bias is not None:
            single.bias.copy_(conv.bias)
    return single


def get_model(model_name='convnext_base', num_classes=5, pretrained=True):
    """
    Creates a model adapted for sleep stage classification.

    The selected timm backbone is modified in two places: its first
    convolution is replaced by a 1-input-channel equivalent, and its
    classification head is replaced by a fresh ``nn.Linear`` with
    ``num_classes`` outputs.

    Args:
        model_name (str): The name of the model architecture to use.
                          Options: 'convnext_base', 'vit_base'.
        num_classes (int): The number of output classes.
        pretrained (bool): Whether to load pre-trained weights.

    Returns:
        A PyTorch nn.Module representing the adapted model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported options.
    """
    if model_name == 'convnext_base':
        # --- ConvNeXt V2 Base backbone ---
        model = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k', pretrained=pretrained)

        # First stem layer is the input convolution: make it accept 1 channel.
        model.stem[0] = _to_single_channel_conv(model.stem[0])

        # Replace the classifier with one sized for our label set.
        model.head.fc = nn.Linear(model.head.fc.in_features, num_classes)
        print("âœ… ConvNeXT Base model created.")

    elif model_name == 'vit_base':
        # --- Vision Transformer Base backbone ---
        # Note: ViT is sensitive to image size. The model is built for our
        # native 76x60 input rather than the 224x224 pre-training size.
        model = timm.create_model('vit_base_patch16_224.augreg_in21k', pretrained=pretrained, img_size=(76, 60))

        # The patch-embedding projection is the input convolution here.
        model.patch_embed.proj = _to_single_channel_conv(model.patch_embed.proj)

        # Replace the classifier with one sized for our label set.
        model.head = nn.Linear(model.head.in_features, num_classes)
        print("âœ… Vision Transformer (ViT) Base model created.")

    else:
        raise ValueError(f"Model '{model_name}' not supported. Choose 'convnext_base' or 'vit_base'.")

    return model

print("âœ… `get_model` function defined with multi-architecture support.")

# ==============================================================================
# 5. PYTORCH LIGHTNING MODULE
# ==============================================================================
class SleepStageClassifierLightning(pl.LightningModule):
    """
    PyTorch Lightning module for sleep stage classification.

    Wraps the backbone returned by ``get_model`` with a (optionally
    class-weighted) cross-entropy loss, accuracy / macro-F1 metrics for the
    training and validation loops, and an AdamW optimizer whose learning rate
    is reduced when ``val_loss`` plateaus.

    Args:
        model_name (str): Architecture name forwarded to ``get_model``.
        learning_rate (float): Initial learning rate for AdamW.
        class_weights (sequence of float, optional): Per-class weights for the
            cross-entropy loss. ``None`` means unweighted.
        num_classes (int): Number of output classes (defaults to 5).

    Raises:
        ValueError: If ``class_weights`` is given but its length differs from
            ``num_classes``.
    """
    def __init__(self, model_name, learning_rate=1e-5, class_weights=None, num_classes=5):
        super().__init__()
        # Fail at construction time instead of at the first loss computation.
        if class_weights is not None and len(class_weights) != num_classes:
            raise ValueError(
                f"class_weights has {len(class_weights)} entries but num_classes is {num_classes}."
            )
        # Stores all constructor arguments under self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.model = get_model(
            model_name=self.hparams.model_name,
            num_classes=self.hparams.num_classes,
            pretrained=True,
        )
        # Separate metric objects per loop so their accumulated state never mixes.
        self.train_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.train_f1 = MulticlassF1Score(num_classes=num_classes, average='macro')
        self.val_f1 = MulticlassF1Score(num_classes=num_classes, average='macro')
        self.weights = torch.as_tensor(class_weights, dtype=torch.float) if class_weights is not None else None
        self.loss_fn = nn.CrossEntropyLoss(weight=self.weights)

    def forward(self, x):
        """Return the backbone's raw class logits for input batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch and log loss/accuracy/F1."""
        x, y_true = batch
        y_pred_logits = self(x)
        loss = self.loss_fn(y_pred_logits, y_true)
        # Update the metric state first, then log the Metric *objects*:
        # Lightning then calls compute()/reset() at epoch end. Logging the
        # per-batch tensor instead would average batch-level scores, which is
        # not the epoch-level value for macro-averaged metrics.
        self.train_accuracy(y_pred_logits, y_true)
        self.train_f1(y_pred_logits, y_true)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss for one batch and log loss/accuracy/F1."""
        x, y_true = batch
        y_pred_logits = self(x)
        loss = self.loss_fn(y_pred_logits, y_true)
        # Same pattern as in training_step: accumulate, then log the objects.
        self.val_accuracy(y_pred_logits, y_true)
        self.val_f1(y_pred_logits, y_true)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', self.val_accuracy, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_f1', self.val_f1, on_epoch=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Return AdamW plus a ReduceLROnPlateau scheduler driven by ``val_loss``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = {
            # Cut the learning rate 10x after 3 epochs without val_loss improvement.
            'scheduler': ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3),
            'monitor': 'val_loss',
            'interval': 'epoch',
            'frequency': 1,
        }
        return [optimizer], [scheduler]

print("âœ… `SleepStageClassifierLightning` module defined.")

# ==============================================================================
# 6. OPTIMIZED CUSTOM DATASET DEFINITION WITH METADATA CACHING
# ==============================================================================
class OptimizedCombinedDataset(Dataset):
    """
    Map-style dataset over the valid epochs (rows) of many parquet files.

    The number of valid rows per file is determined once and stored in a CSV
    metadata file so later runs can skip the scan. Samples are addressed by a
    single global index that is mapped to (file, row-within-file) through a
    cumulative-count table.

    File paths are passed verbatim to ``pd.read_parquet``, so remote files must
    be given as full URLs including their scheme (e.g. ``gs://bucket/key``).

    Args:
        file_paths: Paths/URLs of the parquet files. Each file must have a
            'label' column; every other column is treated as a feature.
        metadata_path: Location of the CSV cache with columns 'filepath' and
            'epoch_count', read and written through gcsfs.
        max_cached_files: Optional cap on how many filtered DataFrames are kept
            in memory per process. ``None`` (default) keeps every file that was
            ever read; ``0`` disables caching. When the cap is reached the
            oldest-loaded file is evicted.
    """
    # Rows whose label is not one of these are ignored everywhere.
    VALID_LABELS = (0, 1, 2, 3, 4)
    # Feature columns of one row are reshaped to (channels, height, width).
    SPECTROGRAM_SHAPE = (1, 76, 60)

    def __init__(self, file_paths, metadata_path, max_cached_files=None):
        self.file_paths = list(file_paths)
        self.metadata_path = metadata_path
        self.max_cached_files = max_cached_files
        self._cache = {}
        self.fs = gcsfs.GCSFileSystem()

        if self.fs.exists(self.metadata_path):
            self.epochs_per_file = self._load_epoch_counts()
        else:
            self.epochs_per_file = self._scan_epoch_counts()

        # cumulative_epochs[i] == number of samples contained in files 0..i.
        self.cumulative_epochs = np.cumsum(self.epochs_per_file)
        self.total_epochs = int(self.cumulative_epochs[-1]) if len(self.cumulative_epochs) > 0 else 0
        print(f"âœ… Dataset initialized. Total valid epochs: {self.total_epochs}")

    def _load_epoch_counts(self):
        """Read per-file epoch counts from the existing metadata CSV."""
        print(f"Found metadata file at {self.metadata_path}. Loading epoch counts...")
        with self.fs.open(self.metadata_path, 'r') as f:
            metadata_df = pd.read_csv(f)
        epoch_counts_map = dict(zip(metadata_df['filepath'], metadata_df['epoch_count']))
        # Files that are not listed still count as empty, but say so loudly:
        # this happens for files added after the scan or that failed during it.
        missing = [fp for fp in self.file_paths if fp not in epoch_counts_map]
        if missing:
            print(
                f"Warning: {len(missing)} of {len(self.file_paths)} files have no entry in the "
                f"metadata file and will contribute 0 epochs. Delete {self.metadata_path} to force a rescan."
            )
        counts = [int(epoch_counts_map.get(fp, 0)) for fp in self.file_paths]
        print("âœ… Epoch counts loaded from metadata file.")
        return counts

    def _scan_epoch_counts(self):
        """Count valid rows in every file and persist the result as metadata CSV."""
        print(f"Metadata file not found. Performing one-time scan of {len(self.file_paths)} files...")
        counts = []
        epoch_data = []
        for f_path in self.file_paths:
            try:
                # Only the label column is needed to count valid rows.
                df_labels = pd.read_parquet(f_path, columns=['label'])
                num_valid = int(df_labels['label'].isin(list(self.VALID_LABELS)).sum())
            except Exception as e:
                # Deliberate best-effort: an unreadable file is skipped, not fatal.
                file_name = Path(f_path).name
                print(f"Warning: Could not process {file_name}. Skipping. Error: {e}")
                counts.append(0)
            else:
                counts.append(num_valid)
                epoch_data.append({'filepath': f_path, 'epoch_count': num_valid})
        print("âœ… One-time scan complete. Saving metadata file for future runs...")
        # Explicit columns keep the CSV header even when no file could be read;
        # otherwise the next run would fail while parsing a column-less file.
        metadata_df = pd.DataFrame(epoch_data, columns=['filepath', 'epoch_count'])
        with self.fs.open(self.metadata_path, 'w') as f:
            metadata_df.to_csv(f, index=False)
        print(f"âœ… Metadata saved to {self.metadata_path}.")
        return counts

    def _get_frame(self, file_path):
        """Return the label-filtered DataFrame for ``file_path``, using the cache."""
        if file_path in self._cache:
            return self._cache[file_path]
        try:
            df = pd.read_parquet(file_path)
            frame = df[df['label'].isin(list(self.VALID_LABELS))].reset_index(drop=True)
        except Exception as e:
            raise IOError(f"Error reading file {Path(file_path).name} in __getitem__: {e}") from e
        limit = self.max_cached_files
        if limit is None:
            self._cache[file_path] = frame
        elif limit > 0:
            # dicts preserve insertion order, so the first key is the oldest entry.
            while len(self._cache) >= limit:
                del self._cache[next(iter(self._cache))]
            self._cache[file_path] = frame
        return frame

    def __len__(self):
        return self.total_epochs

    def __getitem__(self, idx):
        """Return ``(spectrogram, label)`` for global sample index ``idx``.

        The spectrogram is a float32 tensor of shape ``SPECTROGRAM_SHAPE``,
        standardised per sample; the label is a 0-dim int64 tensor.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            IOError: If the backing parquet file cannot be read.
        """
        if idx < 0:
            # Sequence-style negative indexing instead of silently reading the
            # wrong file's last row.
            idx += self.total_epochs
        if idx < 0 or idx >= self.total_epochs:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.total_epochs}")

        # side='right' skips files with zero valid rows (repeated cumulative values).
        file_idx = int(np.searchsorted(self.cumulative_epochs, idx, side='right'))
        local_idx = idx - (self.cumulative_epochs[file_idx - 1] if file_idx > 0 else 0)
        row = self._get_frame(self.file_paths[file_idx]).iloc[local_idx]

        label = np.int64(row['label'])
        spectrogram_flat = row.drop('label').values.astype(np.float32)
        # Per-sample standardisation; the epsilon guards against a constant row.
        mean, std = spectrogram_flat.mean(), spectrogram_flat.std()
        spectrogram_normalized = (spectrogram_flat - mean) / (std + 1e-6)
        spectrogram_2d = spectrogram_normalized.reshape(self.SPECTROGRAM_SHAPE)
        return torch.from_numpy(spectrogram_2d), torch.tensor(label)

print("âœ… `OptimizedCombinedDataset` class defined.")

# ==============================================================================
# 7. TRAINING EXECUTION FOR FULL DATASET ON COLAB ENTERPRISE
# ==============================================================================
print("\n--- Starting Full Dataset Training Run with Optimal Settings on GCS ---")

# --- General Parameters ---
# --- CHOOSE YOUR MODEL HERE ---
MODEL_TO_TRAIN = 'convnext_base' # Options: 'convnext_base', 'vit_base'

EPOCHS = 40
BATCH_SIZE = 256
NUM_WORKERS = 8
CLASS_WEIGHTS = [0.7, 3.5, 0.5, 1.5, 1.2]
LEARNING_RATE = 5e-5
# Fixed seed so the train/validation split is identical across runs; without
# it, validation metrics from different runs are not comparable.
SPLIT_SEED = 42

# --- Paths and File Identification (using GCS) ---
GCS_BUCKET_PATH = "gs://shhs-sleepedfx-data-bucket"
# Built like the other URLs below rather than with os.path.join, which would
# use the local OS path separator.
METADATA_FILE_PATH = f"{GCS_BUCKET_PATH}/dataset_metadata.csv"

shhs1_processed_dir = f"{GCS_BUCKET_PATH}/shhs1_processed"
shhs2_processed_dir = f"{GCS_BUCKET_PATH}/shhs2_processed"

# --- Load ALL files from GCS ---
print("Listing all files in GCS bucket...")
fs = gcsfs.GCSFileSystem()
# fsspec's glob() returns protocol-stripped paths ("bucket/key"). The dataset
# reads files with pd.read_parquet, which would treat such a path as a local
# file, so the "gs://" scheme is restored here. Sorting keeps the file order
# (and therefore the seeded split) deterministic.
shhs1_files = sorted(
    p if p.startswith("gs://") else f"gs://{p}"
    for p in fs.glob(f"{shhs1_processed_dir}/*.parquet")
)
shhs2_files = sorted(
    p if p.startswith("gs://") else f"gs://{p}"
    for p in fs.glob(f"{shhs2_processed_dir}/*.parquet")
)
specific_shhs_file_paths = shhs1_files + shhs2_files

# --- Main Training Logic ---
if not specific_shhs_file_paths:
    print("\nERROR: No valid .parquet files were found in GCS. Aborting training.")
else:
    print(f"\nFound {len(specific_shhs_file_paths)} total files for training.")

    full_dataset = OptimizedCombinedDataset(
        file_paths=specific_shhs_file_paths,
        metadata_path=METADATA_FILE_PATH
    )

    if len(full_dataset) > 1:
        # 80/20 split over individual samples.
        # NOTE(review): because the split is per sample, rows from the same
        # parquet file can end up in both subsets; consider splitting by file
        # if validation must reflect unseen recordings.
        train_size = int(0.8 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        train_dataset, val_dataset = random_split(
            full_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(SPLIT_SEED),
        )

        print(f"Dataset split: {len(train_dataset)} training samples, {len(val_dataset)} validation samples.")

        # persistent_workers is only legal with at least one worker process.
        keep_workers_alive = NUM_WORKERS > 0
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, persistent_workers=keep_workers_alive)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, persistent_workers=keep_workers_alive)

        model = SleepStageClassifierLightning(
            model_name=MODEL_TO_TRAIN,
            learning_rate=LEARNING_RATE,
            class_weights=CLASS_WEIGHTS
        )

        experiment_name = f"{MODEL_TO_TRAIN}_full_dataset_gcs"
        csv_logger = CSVLogger(f"{GCS_BUCKET_PATH}/training_logs/", name=experiment_name)

        # Keep only the single checkpoint with the lowest validation loss.
        # The doubled braces leave {epoch}/{val_loss} for Lightning to fill in.
        checkpoint_callback = ModelCheckpoint(
            monitor='val_loss',
            dirpath=f"{GCS_BUCKET_PATH}/model_checkpoints/",
            filename=f"sleep-stage-model-{experiment_name}-{{epoch:02d}}-{{val_loss:.4f}}",
            save_top_k=1,
            mode='min'
        )

        # Stop once val_loss has not improved for 7 consecutive epochs.
        early_stop_callback = EarlyStopping(
            monitor='val_loss',
            patience=7,
            verbose=True,
            mode='min'
        )

        trainer = pl.Trainer(
            max_epochs=EPOCHS,
            accelerator="gpu",
            devices=1,
            logger=csv_logger,
            callbacks=[checkpoint_callback, early_stop_callback],
            precision="bf16-mixed",
            gradient_clip_val=1.0
        )

        print(f"\nðŸš€ Starting model training for {MODEL_TO_TRAIN}...")
        trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
        print("\nâœ… Training complete!")
        print(f"Best model saved at: {checkpoint_callback.best_model_path}")

    else:
        print("Dataset is too small to split. Aborting training.")