In [None]:
# ==============================================================================
# SCRIPT TO GENERATE PERFORMANCE REPORT FROM A SAVED CHECKPOINT
# ==============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import Dataset, DataLoader, random_split
import pytorch_lightning as pl
from torchmetrics.classification import MulticlassAccuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall
import numpy as np
import pandas as pd
from pathlib import Path
import os

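# --- Ensure dependencies are installed ---
# NOTE: this install line sits *after* the imports above, so on a fresh runtime
# the imports fail first (e.g. "No module named 'pytorch_lightning'"). Run the
# install in an earlier cell, or move it above the imports, before executing them.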
!pip install --upgrade -q pytorch-lightning timm "pandas==2.2.2" "pyarrow==1.9.0"

# ==============================================================================
# 1. DEFINE THE MODEL AND DATASET CLASSES
#    (These must match the training script exactly)
# ==============================================================================

def get_model(model_name='swin_base', num_classes=5, pretrained=True):
    """Build the timm backbone used for sleep-stage classification.

    Only ``'swin_base'`` is supported. It maps to the timm architecture
    ``swin_base_patch4_window7_224.ms_in22k``, configured for single-channel
    input with ``img_size=(76, 60)``.

    Args:
        model_name: Short name of the architecture to build.
        num_classes: Number of output classes for the classification head.
        pretrained: Forwarded to ``timm.create_model``.

    Returns:
        The model object returned by ``timm.create_model``.

    Raises:
        ValueError: If ``model_name`` is anything other than ``'swin_base'``.
    """
    # Guard clause: reject unknown architectures before touching timm.
    if model_name != 'swin_base':
        raise ValueError(f"Model '{model_name}' not supported.")

    swin_kwargs = dict(
        pretrained=pretrained,
        num_classes=num_classes,
        in_chans=1,
        img_size=(76, 60),
    )
    return timm.create_model('swin_base_patch4_window7_224.ms_in22k', **swin_kwargs)

class SleepStageClassifierLightning(pl.LightningModule):
    """Lightning module wrapping the 5-class backbone built by ``get_model``.

    Only construction and ``forward`` are defined here, which is what is
    needed to restore a checkpoint and run inference.

    Args:
        model_name: Architecture name forwarded to ``get_model``.
        learning_rate: Stored as a hyperparameter; not otherwise used in this
            class.
        class_weights: Optional sequence of per-class weights; converted to a
            float tensor and stored on ``self.weights`` (``None`` if omitted).
    """

    def __init__(self, model_name, learning_rate=1e-5, class_weights=None):
        super().__init__()
        # Capture the constructor arguments into self.hparams; model_name is
        # read back from there just below.
        self.save_hyperparameters()

        # pretrained=False: weights are expected to come from a local
        # checkpoint rather than from a download.
        self.model = get_model(
            model_name=self.hparams.model_name,
            num_classes=5,
            pretrained=False,
        )

        self.train_accuracy = MulticlassAccuracy(num_classes=5)
        self.val_accuracy = MulticlassAccuracy(num_classes=5)

        if class_weights is None:
            self.weights = None
        else:
            self.weights = torch.tensor(class_weights, dtype=torch.float)

        self.loss_fn = F.cross_entropy

    def forward(self, x):
        """Run ``x`` through the wrapped backbone and return its output."""
        return self.model(x)

class CombinedDataset(Dataset):
    """Map-style dataset that concatenates labelled rows from parquet files.

    Each file is expected to hold one row per sample with a ``label`` column
    plus feature columns. Only rows whose label is one of 0-4 are exposed.
    A global index is mapped to (file, row-within-file) through a cumulative
    count of valid rows per file.

    Args:
        file_paths_chunk: Sequence of parquet file paths to combine.
    """

    # Labels that count as valid samples; every other row is ignored.
    _VALID_LABELS = (0, 1, 2, 3, 4)

    def __init__(self, file_paths_chunk):
        import warnings  # local import so the block is self-contained

        self.file_paths = file_paths_chunk
        self.epochs_per_file = []
        # file path -> filtered DataFrame, filled lazily by __getitem__.
        self._cache = {}

        for f_path in self.file_paths:
            try:
                df_labels = pd.read_parquet(f_path, columns=['label'])
                num_valid = int(df_labels['label'].isin(self._VALID_LABELS).sum())
            except Exception as exc:
                # Best effort: an unreadable file contributes zero samples, but
                # say so instead of silently shrinking the dataset.
                warnings.warn(
                    f"CombinedDataset: skipping unreadable file {f_path}: {exc!r}",
                    stacklevel=2,
                )
                num_valid = 0
            self.epochs_per_file.append(num_valid)

        self.cumulative_epochs = np.cumsum(self.epochs_per_file)
        # Plain int so that len() and arithmetic on it never see a NumPy scalar.
        self.total_epochs = (
            int(self.cumulative_epochs[-1]) if len(self.cumulative_epochs) > 0 else 0
        )

    def __len__(self):
        """Return the total number of valid rows across all files."""
        return self.total_epochs

    def __getitem__(self, idx):
        """Return ``(features, label)`` for global sample index ``idx``.

        The feature values of the row are z-score normalised per sample and
        reshaped to ``(1, 76, 60)``; the label is returned as a scalar tensor.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Support Python-style negative indices; without this, searchsorted
        # maps any negative idx to the first file and returns the wrong row.
        if idx < 0:
            idx += self.total_epochs
        if not 0 <= idx < self.total_epochs:
            raise IndexError(
                f"index out of range for CombinedDataset of length {self.total_epochs}"
            )

        # side='right' skips files with zero valid rows (repeated cumsum values).
        file_idx = np.searchsorted(self.cumulative_epochs, idx, side='right')
        local_idx = idx - (self.cumulative_epochs[file_idx - 1] if file_idx > 0 else 0)
        file_path = self.file_paths[file_idx]

        if file_path not in self._cache:
            df = pd.read_parquet(file_path)
            self._cache[file_path] = (
                df[df['label'].isin(self._VALID_LABELS)].reset_index(drop=True)
            )

        row = self._cache[file_path].iloc[local_idx]
        label = np.int64(row['label'])
        spectrogram_flat = row.drop('label').values.astype(np.float32)

        # Per-sample standardisation; the epsilon guards against zero variance.
        mean, std = spectrogram_flat.mean(), spectrogram_flat.std()
        spectrogram_normalized = (spectrogram_flat - mean) / (std + 1e-6)
        spectrogram_2d = spectrogram_normalized.reshape(1, 76, 60)
        return torch.from_numpy(spectrogram_2d), torch.tensor(label)

# ==============================================================================
# 2. LOAD MODEL AND DATA, THEN GENERATE REPORT
# ==============================================================================

# --- CONFIGURATION ---
# IMPORTANT: Update this path to point to the BEST checkpoint file from the Swin Base run.
# NOTE(review): the path below ends in "last.ckpt", i.e. the most recent
# checkpoint, which is not necessarily the best one — confirm before reporting.
CHECKPOINT_PATH = "/content/drive/MyDrive/final_model_checkpoint/swin_base_1000_files_resumable_lr_2e-5/last.ckpt"
# -------------------

print(f"üß† Loading model from: {CHECKPOINT_PATH}")
if not os.path.exists(CHECKPOINT_PATH):
    print("‚ùå ERROR: Checkpoint file not found at the specified path.")
else:
    # Use the GPU when one is present, otherwise fall back to CPU instead of
    # crashing on an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SleepStageClassifierLightning.load_from_checkpoint(
        CHECKPOINT_PATH, map_location=device
    )
    model.eval()
    model.to(device)
    print("‚úÖ Model loaded successfully.")

    # --- Load the dataset (needed for the validation set) ---
    shhs1_processed_dir_base = Path('/content/drive/MyDrive/shhs1_processed')
    shhs2_processed_dir_base = Path('/content/drive/MyDrive/shhs2_processed')
    # NOTE(review): glob() order is filesystem-dependent, so "[:500]" may pick
    # a different set of files than the training run did — verify.
    shhs1_files = list(shhs1_processed_dir_base.glob('*.parquet'))[:500]
    shhs2_files = list(shhs2_processed_dir_base.glob('*.parquet'))[:500]
    specific_shhs_file_paths = shhs1_files + shhs2_files

    full_dataset = CombinedDataset(specific_shhs_file_paths)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    # NOTE(review): this split is not seeded here, so the "validation" subset
    # changes between runs and is not guaranteed to match the split used
    # during training (it may contain training samples) — confirm the seed /
    # generator used by the training script and reuse it.
    _, val_dataset = random_split(full_dataset, [train_size, val_size])
    if len(val_dataset) == 0:
        # Fail with a clear message rather than an opaque torch.cat error later.
        raise RuntimeError(
            "Validation set is empty: no valid samples were found in "
            f"{shhs1_processed_dir_base} or {shhs2_processed_dir_base}."
        )
    val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=0)
    print("‚úÖ Validation data loaded.")

    # --- Generate the report ---
    print("\n" + "="*80)
    print("Generating Final Performance Metrics on the Validation Set...")
    all_preds, all_labels = [], []
    with torch.no_grad():
        for x, y in val_loader:
            logits = model(x.to(device))
            preds = torch.argmax(logits, dim=1)
            # Collect on CPU so the metrics below run on CPU tensors.
            all_preds.append(preds.cpu())
            all_labels.append(y.cpu())
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    num_classes = 5
    # average=None -> one value per class; accuracy uses micro averaging.
    precision_metric = MulticlassPrecision(num_classes=num_classes, average=None)
    recall_metric = MulticlassRecall(num_classes=num_classes, average=None)
    f1_metric = MulticlassF1Score(num_classes=num_classes, average=None)
    accuracy_metric = MulticlassAccuracy(num_classes=num_classes, average='micro')

    precisions = precision_metric(all_preds, all_labels)
    recalls = recall_metric(all_preds, all_labels)
    f1_scores = f1_metric(all_preds, all_labels)
    accuracy = accuracy_metric(all_preds, all_labels)

    stage_map = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
    print("\n--- Sleep Stage Classification Report ---")
    print(f"{'Stage':<10} | {'Precision':<10} | {'Recall':<10} | {'F1-Score':<10}")
    print("-" * 50)
    for i in range(num_classes):
        stage_name = stage_map[i]
        precision, recall, f1 = precisions[i].item(), recalls[i].item(), f1_scores[i].item()
        print(f"{stage_name:<10} | {precision:<10.4f} | {recall:<10.4f} | {f1:<10.4f}")
    print("-" * 50)
    print(f"\nOverall Accuracy: {accuracy.item():.4f}")
    print("="*80 + "\n")

ModuleNotFoundError: No module named 'pytorch_lightning'