<a href="https://colab.research.google.com/github/kartheek0107/demo2/blob/main/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional
import math

class SEBlock1D(nn.Module):
    """Channel-wise squeeze-and-excitation gating for [B, C, L] tensors.

    The input is averaged over time, passed through a 1x1-conv bottleneck
    (C -> hidden -> C) with ReLU and a sigmoid, and the resulting per-channel
    gate in (0, 1) rescales the input.

    Args:
        channels: number of channels C (input and output).
        reduction: bottleneck ratio; hidden width is ``channels // reduction``,
            floored at 1.
    """
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        hidden = channels // reduction
        if hidden < 1:
            hidden = 1
        squeeze = nn.Conv1d(channels, hidden, 1, bias=False)
        excite = nn.Conv1d(hidden, channels, 1, bias=False)
        self.fc = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite, nn.Sigmoid())

    def forward(self, x: Tensor) -> Tensor:
        gate = self.fc(self.global_pool(x))  # [B, C, 1]
        return x * gate

class AdaptiveNorm1d(nn.Module):
    """Normalization that picks BatchNorm1d or GroupNorm from the sequence length.

    Inputs with fewer than ``min_length`` time steps (e.g. the length-1 output of a
    global-pooling branch) are normalized with GroupNorm; longer inputs use
    BatchNorm1d. The choice depends only on the input length - never on
    ``self.training`` - so the layer that is trained is the same one used at
    inference (where BatchNorm1d then uses its running statistics).

    Args:
        channels: number of channels C of the [B, C, L] input.
        min_length: shortest sequence length routed to BatchNorm1d.
    """
    def __init__(self, channels: int, min_length: int = 4):
        super().__init__()
        self.channels = channels
        self.min_length = min_length
        self.batch_norm = nn.BatchNorm1d(channels)
        # GroupNorm requires channels % num_groups == 0: use the largest divisor of
        # `channels` that is <= 8 (8 for multiples of 8, `channels` itself when < 8,
        # 5 for 20, ...).
        num_groups = next(g for g in range(min(8, channels), 0, -1) if channels % g == 0)
        self.group_norm = nn.GroupNorm(num_groups, channels)

    def forward(self, x: Tensor) -> Tensor:
        # Too few time steps for reliable batch statistics -> GroupNorm.
        if x.size(-1) < self.min_length:
            return self.group_norm(x)
        return self.batch_norm(x)

class ResidualBlock1D(nn.Module):
    """Pre-activation-free residual block with SE gating for [B, C, L] signals.

    Main path: conv3(stride) -> norm -> ReLU -> dropout -> conv3 -> norm -> SE.
    Shortcut: identity, or a strided 1x1 conv + norm whenever the stride or the
    channel count changes. The two paths are summed and passed through ReLU.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        stride: temporal stride of the first convolution (and of the shortcut).
        dropout: channel-dropout probability after the first activation (0 disables it).
    """
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1, dropout: float = 0.1):
        super().__init__()

        # Module creation order is kept as-is so parameter init and state_dict keys match.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = AdaptiveNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = AdaptiveNorm1d(out_channels)
        self.dropout = nn.Dropout1d(dropout) if dropout > 0 else nn.Identity()
        self.se = SEBlock1D(out_channels)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_channels != out_channels
        self.skip = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                AdaptiveNorm1d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x: Tensor) -> Tensor:
        shortcut = self.skip(x)
        h = self.dropout(self.relu(self.bn1(self.conv1(x))))
        h = self.se(self.bn2(self.conv2(h)))
        return self.relu(h + shortcut)

class ASPP1D(nn.Module):
    """Atrous Spatial Pyramid Pooling for [B, C, L] signals.

    Five parallel branches - a 1x1 conv, three 3-tap convs with dilation 2/4/8
    (length-preserving padding) and a global-average-pool branch broadcast back to
    length L - are concatenated and projected to ``out_channels`` (with ReLU and
    Dropout1d(0.1)).

    Args:
        in_channels: input channels.
        out_channels: width of every branch and of the projected output.
    """
    def __init__(self, in_channels: int, out_channels: int = 256):
        super().__init__()

        def conv_branch(kernel_size: int, dilation: int) -> nn.Sequential:
            # Padding dilation * (k - 1) // 2 keeps the length unchanged for odd k.
            return nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size,
                          padding=dilation * (kernel_size - 1) // 2,
                          dilation=dilation, bias=False),
                AdaptiveNorm1d(out_channels),
                nn.ReLU(inplace=True),
            )

        self.conv1 = conv_branch(1, 1)
        self.conv2 = conv_branch(3, 2)
        self.conv3 = conv_branch(3, 4)
        self.conv4 = conv_branch(3, 8)

        # Image-level context: pool to length 1, then 1x1 conv.
        self.global_pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(in_channels, out_channels, 1, bias=False),
            AdaptiveNorm1d(out_channels),
            nn.ReLU(inplace=True),
        )

        self.project = nn.Sequential(
            nn.Conv1d(out_channels * 5, out_channels, 1, bias=False),
            AdaptiveNorm1d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout1d(0.1),
        )

    def forward(self, x: Tensor) -> Tensor:
        length = x.size(-1)
        features = [branch(x) for branch in (self.conv1, self.conv2, self.conv3, self.conv4)]
        pooled = self.global_pool(x)
        features.append(F.interpolate(pooled, size=length, mode='linear', align_corners=False))
        return self.project(torch.cat(features, dim=1))

class AttentionGate1D(nn.Module):
    """Additive attention gate that re-weights encoder skip features.

    psi = sigmoid(conv1x1(relu(W_g(gate) + W_s(skip)))) is a single-channel map in
    (0, 1) that is broadcast over the skip channels; the gate signal is linearly
    resampled to the skip length if the two differ.

    Args:
        gate_channels: channels of the decoder gating signal.
        skip_channels: channels of the encoder skip connection.
        inter_channels: width of the intermediate projection; defaults to
            ``skip_channels // 2`` (at least 1).
    """
    def __init__(self, gate_channels: int, skip_channels: int, inter_channels: Optional[int] = None):
        super().__init__()

        if inter_channels is None:
            # Floor at 1: a zero-width projection would make psi independent of the inputs.
            inter_channels = max(skip_channels // 2, 1)

        self.gate_conv = nn.Conv1d(gate_channels, inter_channels, 1, bias=False)
        self.skip_conv = nn.Conv1d(skip_channels, inter_channels, 1, bias=False)
        self.psi_conv = nn.Conv1d(inter_channels, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, gate: Tensor, skip: Tensor) -> Tensor:
        # gate: gating signal from the decoder; skip: skip connection from the encoder.
        g = self.gate_conv(gate)
        s = self.skip_conv(skip)

        # Resample the gate projection to the skip length if needed.
        if g.size(-1) != s.size(-1):
            g = F.interpolate(g, size=s.size(-1), mode='linear', align_corners=False)

        psi = self.relu(g + s)
        psi = torch.sigmoid(self.psi_conv(psi))  # [B, 1, L_skip]

        return skip * psi

class DecoderBlock1D(nn.Module):
    """One decoder stage: 2x transposed-conv upsampling, optional attention gating
    of the encoder skip, channel concatenation, then two residual blocks.

    Args:
        in_channels: channels of the incoming (coarser) decoder features.
        skip_channels: channels of the encoder skip connection.
        out_channels: channels after upsampling and of the block output.
        use_attention: gate the skip with ``AttentionGate1D`` before concatenation.
    """
    def __init__(self, in_channels: int, skip_channels: int, out_channels: int,
                 use_attention: bool = True):
        super().__init__()

        self.upsample = nn.ConvTranspose1d(in_channels, out_channels,
                                           kernel_size=2, stride=2, bias=False)

        self.use_attention = use_attention
        if use_attention:
            self.attention = AttentionGate1D(out_channels, skip_channels)

        merged_channels = out_channels + skip_channels
        self.conv_block = nn.Sequential(
            ResidualBlock1D(merged_channels, out_channels),
            ResidualBlock1D(out_channels, out_channels),
        )

    def forward(self, x: Tensor, skip: Tensor) -> Tensor:
        up = self.upsample(x)

        # Pad (or, if negative, crop) so the upsampled length matches the skip,
        # splitting the difference between the left and right ends.
        missing = skip.size(-1) - up.size(-1)
        if missing != 0:
            left = missing // 2
            up = F.pad(up, (left, missing - left))

        gated_skip = self.attention(up, skip) if self.use_attention else skip
        return self.conv_block(torch.cat([up, gated_skip], dim=1))

class ResUNetPP1D(nn.Module):
    """1D ResUNet++ with a reconstruction head for sequence-to-sequence reconstruction.

    Layout: 7-tap input stem -> ``depth`` encoder stages (two residual blocks each,
    then 2x max-pooling) -> bottleneck (residual block, ASPP, residual block) ->
    ``depth`` attention-gated decoder stages -> small convolutional head.
    Channel widths are ``base_filters * 2**level``. The output is returned at the
    input's temporal length.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        base_filters: width of the first stage.
        depth: number of encoder/decoder stages (each halves/doubles the length).
        dropout: dropout used in encoder/bottleneck residual blocks and the head.
    """

    def __init__(self,
                 in_channels: int = 20,
                 out_channels: int = 20,
                 base_filters: int = 64,
                 depth: int = 3,  # Reduced depth to prevent sequence length becoming 1
                 dropout: float = 0.1):
        super().__init__()

        self.depth = depth
        widths = [base_filters * 2 ** level for level in range(depth + 1)]

        # Stem
        self.input_conv = nn.Sequential(
            nn.Conv1d(in_channels, widths[0], kernel_size=7, padding=3, bias=False),
            AdaptiveNorm1d(widths[0]),
            nn.ReLU(inplace=True),
        )

        # Encoder: stage 0 keeps the stem width, later stages double it.
        self.encoders = nn.ModuleList()
        self.pools = nn.ModuleList()
        for level in range(depth):
            stage_in = widths[level - 1] if level > 0 else widths[0]
            self.encoders.append(nn.Sequential(
                ResidualBlock1D(stage_in, widths[level], dropout=dropout),
                ResidualBlock1D(widths[level], widths[level], dropout=dropout),
            ))
            self.pools.append(nn.MaxPool1d(kernel_size=2, stride=2))

        # Bottleneck with ASPP (ASPP runs at half the bottleneck width)
        top = widths[depth]
        self.bottleneck = nn.Sequential(
            ResidualBlock1D(widths[depth - 1], top, dropout=dropout),
            ASPP1D(top, top // 2),
            ResidualBlock1D(top // 2, top, dropout=dropout),
        )

        # Decoder mirrors the encoder: widths[depth] -> ... -> widths[0]
        self.decoders = nn.ModuleList(
            DecoderBlock1D(widths[level + 1], widths[level], widths[level], use_attention=True)
            for level in reversed(range(depth))
        )

        # Reconstruction head
        half = widths[0] // 2
        self.reconstruction_head = nn.Sequential(
            nn.Conv1d(widths[0], half, kernel_size=3, padding=1),
            AdaptiveNorm1d(half),
            nn.ReLU(inplace=True),
            nn.Dropout1d(dropout),
            nn.Conv1d(half, out_channels, kernel_size=1),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) weights for Conv1d/ConvTranspose1d, zero Conv1d
        biases, and weight=1 / bias=0 for BatchNorm1d and GroupNorm."""
        def kaiming(weight: Tensor) -> None:
            nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='relu')

        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                kaiming(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, (nn.BatchNorm1d, nn.GroupNorm)):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.ConvTranspose1d):
                kaiming(module.weight)

    def forward(self, x: Tensor) -> Tensor:
        target_length = x.size(-1)

        h = self.input_conv(x)

        # Encoder: keep each stage's pre-pooling output for the skip connections.
        skips = []
        for encode, pool in zip(self.encoders, self.pools):
            h = encode(h)
            skips.append(h)
            h = pool(h)

        h = self.bottleneck(h)

        # Decoder consumes the skips deepest-first.
        for decode, skip in zip(self.decoders, reversed(skips)):
            h = decode(h, skip)

        h = self.reconstruction_head(h)

        # Safety net: make the output length match the input length.
        if h.size(-1) != target_length:
            h = F.interpolate(h, size=target_length, mode='linear', align_corners=False)
        return h

# Sliding window processing for long sequences
class SlidingWindowProcessor:
    """Run a fixed-window model over long sequences with weighted overlap-add.

    The input is cut into windows of ``window_size`` samples starting every
    ``stride`` samples. Each window's output is multiplied by a weight profile with
    linear fades at both ends, accumulated, and finally divided by the accumulated
    weights.
    """

    def __init__(self, model: "ResUNetPP1D", window_size: int = 1024, stride: int = 128):
        """
        Args:
            model: module mapping [B, C_in, window_size] -> [B, C_out, window_size]
                (assumes the model preserves length - TODO confirm for non-ResUNet models).
            window_size: samples per window.
            stride: hop between window starts; must satisfy 0 < stride <= window_size
                so every sample is covered by at least one window.

        Raises:
            ValueError: for a non-positive window_size or a stride outside (0, window_size].
        """
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")
        if not 0 < stride <= window_size:
            raise ValueError(f"stride must be in (0, window_size={window_size}], got {stride}")

        self.model = model
        self.window_size = window_size
        self.stride = stride
        self.overlap = window_size - stride

        # Create blending weights for overlap regions
        self.blend_weights = self._create_blend_weights()

    def _create_blend_weights(self) -> Tensor:
        """Return a [1, 1, window_size] weight profile with linear fades at both ends.

        The fade length is capped at half the window so the fade-in and fade-out
        never overwrite each other, and the ramp excludes 0 so the first and last
        samples of a sequence (each covered by a single window) keep a non-zero weight.
        """
        weights = torch.ones(self.window_size)
        fade_length = min(self.overlap, self.window_size // 2)
        if fade_length > 0:
            ramp = torch.arange(1, fade_length + 1, dtype=weights.dtype) / (fade_length + 1)
            weights[:fade_length] = ramp
            weights[-fade_length:] = ramp.flip(0)
        return weights.view(1, 1, -1)

    def process(self, x: Tensor) -> Tensor:
        """Run the model over ``x`` ([B, C, L]) and return the blended output ([B, C_out, L]).

        Inference runs in eval mode under ``torch.no_grad()`` for both short and long
        inputs; the model's previous train/eval mode is restored afterwards.
        """
        _, _, length = x.shape
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                if length <= self.window_size:
                    return self.model(x)
                return self._overlap_add(x)
        finally:
            self.model.train(was_training)

    def _overlap_add(self, x: Tensor) -> Tensor:
        """Blend windowed model outputs for a sequence longer than one window."""
        length = x.size(-1)

        # Reflect-pad on the right so the last window ends exactly at the padded end.
        remainder = (length - self.window_size) % self.stride
        pad_length = self.stride - remainder if remainder else 0
        if pad_length:
            x = F.pad(x, (0, pad_length), mode='reflect')
        padded_length = length + pad_length

        blend = self.blend_weights.to(device=x.device, dtype=x.dtype)
        weight_sum = x.new_zeros(1, 1, padded_length)
        output = None

        for start in range(0, padded_length - self.window_size + 1, self.stride):
            end = start + self.window_size
            window_out = self.model(x[:, :, start:end])
            if output is None:
                # Sized from the first output so C_out may differ from C_in.
                output = window_out.new_zeros(window_out.size(0), window_out.size(1), padded_length)
            output[:, :, start:end] += window_out * blend
            weight_sum[:, :, start:end] += blend

        # All blend weights are strictly positive; the clamp is only a safeguard.
        output = output / weight_sum.clamp_min(1e-8)

        # Remove padding
        return output[:, :, :length]

# Training utilities
def reconstruction_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Mean squared error between a reconstruction and its target, averaged over all elements."""
    return F.mse_loss(input=pred, target=target, reduction='mean')

def create_model_and_processor(in_channels: int = 20,
                             out_channels: int = 20,
                             window_size: int = 1024,
                             stride: int = 128) -> tuple:
    """Build a default ``ResUNetPP1D`` and a ``SlidingWindowProcessor`` that wraps it.

    Returns:
        tuple: ``(model, processor)``; the processor holds a reference to the same model.
    """
    network = ResUNetPP1D(in_channels=in_channels, out_channels=out_channels)
    return network, SlidingWindowProcessor(network, window_size=window_size, stride=stride)

# Example usage and testing
if __name__ == "__main__":
    # Fixed seed so the printed output/loss below are reproducible.
    torch.manual_seed(0)

    print("Creating ResUNet++ 1D model...")
    model, processor = create_model_and_processor(in_channels=20, out_channels=20)

    # Print model info
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Inference smoke test: eval() disables dropout and keeps normalization layers
    # from updating their running statistics with random test data.
    model.eval()

    # Test with window size input
    print("\nTesting with window size input (1024)...")
    test_input = torch.randn(1, 20, 1024)
    with torch.no_grad():
        output = model(test_input)
    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {output.shape}")

    # Test with long sequence using sliding window
    print("\nTesting with long sequence (5000) using sliding window...")
    long_input = torch.randn(1, 20, 5000)
    with torch.no_grad():
        long_output = processor.process(long_input)
    print(f"Long input shape: {long_input.shape}")
    print(f"Long output shape: {long_output.shape}")

    # Test loss computation
    target = torch.randn_like(output)
    loss = reconstruction_loss(output, target)
    print(f"Reconstruction loss: {loss.item():.6f}")

    print("\nModel created and tested successfully!")
    print("Ready for training with MSE reconstruction loss.")

Creating ResUNet++ 1D model...
Total parameters: 7,394,583
Trainable parameters: 7,394,583

Testing with window size input (1024)...
Input shape: torch.Size([1, 20, 1024])
Output shape: torch.Size([1, 20, 1024])

Testing with long sequence (5000) using sliding window...
Long input shape: torch.Size([1, 20, 5000])
Long output shape: torch.Size([1, 20, 5000])
Reconstruction loss: 2.282301

Model created and tested successfully!
Ready for training with MSE reconstruction loss.


In [None]:
# Quiet Window Reconstruction Training for ResUNet++ 1D
# Specifically designed for your multi-class labeled space weather data

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta
import os
import json
from tqdm import tqdm
import glob
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Mount Google Drive so the paths under /content/drive/MyDrive used below
# (the .npz data_dir and the save_dir for checkpoints/logs) are reachable.
# Colab-only: google.colab is provided by the Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

class QuietWindowDataset(Dataset):
    """Single-class windows from labelled ``.npz`` files, served as reconstruction pairs.

    Each ``.npz`` file in ``data_dir`` is read for ``windows`` (N, time, channels),
    ``labels`` (N,) and ``feature_names``; only windows whose label equals
    ``target_class`` are kept and transposed to (channels, time). Items are
    ``(input, target)`` float32 tensors where ``input`` is ``target`` plus Gaussian
    noise when ``add_noise`` is enabled (fresh noise on every access).
    """

    def __init__(self,
                 data_dir: str,
                 target_class: int = 0,  # 0 = quiet, 1 = sep, 2 = cme
                 add_noise: bool = True,
                 noise_std: float = 0.1,
                 normalize: bool = True,
                 validation_split: bool = False):
        """
        Args:
            data_dir: Directory containing .npz files (searched non-recursively)
            target_class: Class to use for reconstruction (0=quiet, 1=sep, 2=cme)
            add_noise: Whether to add noise to input (for denoising task)
            noise_std: Standard deviation of noise to add
            normalize: Whether to normalize features (per channel)
            validation_split: Currently unused; kept for backward compatibility

        Raises:
            ValueError: if ``data_dir`` does not exist, contains no .npz files, or no
                file contains windows of ``target_class``.
        """
        self.data_dir = data_dir
        self.target_class = target_class
        self.add_noise = add_noise
        self.noise_std = noise_std
        self.normalize = normalize
        self.class_names = ["quiet", "sep", "cme"]

        print(f"🎯 Loading {self.class_names[target_class]} windows for reconstruction...")

        # Distinguish "wrong path" (e.g. Drive not mounted / typo) from "no files".
        if not os.path.isdir(data_dir):
            raise ValueError(f"Data directory does not exist: {data_dir}")

        # Sorted so window order (and therefore split indices) does not depend on
        # filesystem listing order.
        self.npz_files = sorted(glob.glob(os.path.join(data_dir, "*.npz")))
        if not self.npz_files:
            raise ValueError(f"No .npz files found in {data_dir}")

        print(f"📁 Found {len(self.npz_files)} .npz files")

        # Load and filter data
        self.windows, self.labels, self.feature_names = self._load_and_filter_data()

        # Normalize if requested
        if self.normalize:
            self._normalize_data()

        print(f"✅ Dataset ready:")
        print(f"  - {len(self.windows)} {self.class_names[target_class]} windows")
        print(f"  - Shape: {self.windows.shape}")
        print(f"  - Features: {len(self.feature_names)}")
        print(f"  - Add noise: {add_noise} (std={noise_std})")

    def _load_and_filter_data(self):
        """Load every file, keep ``target_class`` windows, and return
        ``(windows[N, C, T] float32, labels[N], feature_names)``.

        Files that fail to load (or lack an expected key) are reported and skipped.
        """
        all_windows = []
        all_labels = []
        feature_names = None

        for file_path in tqdm(self.npz_files, desc="Loading files"):
            try:
                # The NpzFile keeps the archive open; the context manager closes it.
                # NOTE(review): allow_pickle=True can execute code embedded in a
                # malicious file - only load trusted data.
                with np.load(file_path, allow_pickle=True) as data:
                    windows = data["windows"]  # (N, 1024, 20)
                    labels = data["labels"]    # (N,)

                    if feature_names is None:
                        feature_names = data["feature_names"]

                # Filter for target class
                target_mask = labels == self.target_class
                target_windows = windows[target_mask]
                target_labels = labels[target_mask]

                if len(target_windows) > 0:
                    all_windows.append(target_windows)
                    all_labels.append(target_labels)

                print(f"  {os.path.basename(file_path)}: {len(target_windows)}/{len(windows)} {self.class_names[self.target_class]} windows")

            except Exception as e:
                # Best-effort loading: report and skip unreadable files.
                print(f"⚠️ Error loading {file_path}: {e}")
                continue

        if not all_windows:
            raise ValueError(f"No {self.class_names[self.target_class]} windows found!")

        # Combine all data
        combined_windows = np.concatenate(all_windows, axis=0).astype(np.float32)
        combined_labels = np.concatenate(all_labels, axis=0)

        # Transpose to (N, channels, time) for PyTorch
        combined_windows = combined_windows.transpose(0, 2, 1)  # (N, 20, 1024)

        return combined_windows, combined_labels, feature_names

    def _normalize_data(self):
        """Standardize each channel with its mean/std over all windows and time steps.

        NOTE(review): statistics are computed on the whole dataset, i.e. before any
        train/validation split, so validation windows contribute to them.
        """
        self.data_mean = np.mean(self.windows, axis=(0, 2), keepdims=True)  # (1, C, 1)
        self.data_std = np.std(self.windows, axis=(0, 2), keepdims=True) + 1e-8

        self.windows = (self.windows - self.data_mean) / self.data_std

        print(f"📊 Data normalized:")
        print(f"  Mean range: [{self.data_mean.min():.4f}, {self.data_mean.max():.4f}]")
        print(f"  Std range: [{self.data_std.min():.4f}, {self.data_std.max():.4f}]")

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        """Return ``(input, target)``; ``target`` is a copy of the stored window."""
        target = torch.tensor(self.windows[idx])  # (C, T)

        if self.add_noise:
            noise = torch.randn_like(target) * self.noise_std
            input_data = target + noise
        else:
            input_data = target.clone()

        return input_data, target

class ReconstructionTrainer:
    """Trainer for (denoising) window reconstruction.

    Optimizes ``model`` with AdamW using three parameter groups (encoder/bottleneck,
    everything else, reconstruction head), reduces the learning rates when the
    validation loss plateaus, logs to TensorBoard, and writes checkpoints and plots
    to ``save_dir``.
    """

    def __init__(self,
                 model: nn.Module,
                 train_loader: DataLoader,
                 val_loader: DataLoader,
                 device: str = 'cuda',
                 learning_rate: float = 1e-3,
                 save_dir: str = '/content/drive/MyDrive/quiet_reconstruction'):
        """
        Args:
            model: network to train; moved to ``device``.
            train_loader: loader yielding ``(input, target)`` training batches.
            val_loader: loader yielding ``(input, target)`` validation batches.
            device: torch device string.
            learning_rate: base LR; the head uses it as-is, the other groups a fraction.
            save_dir: directory for checkpoints, plots and TensorBoard logs (created).

        Raises:
            ValueError: if either loader has no batches (the per-epoch averages would
                otherwise divide by zero).
        """
        if len(train_loader) == 0:
            raise ValueError("train_loader yields no batches")
        if len(val_loader) == 0:
            raise ValueError("val_loader yields no batches")

        self.model = model.to(device)
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.save_dir = save_dir

        os.makedirs(save_dir, exist_ok=True)

        # Group parameters by name (substring match). Anything that is not under
        # 'reconstruction_head', 'encoder*' or 'bottleneck' - e.g. 'input_conv' and
        # 'decoders' - lands in the middle ("decoder") group.
        encoder_params = []
        decoder_params = []
        head_params = []

        for name, param in model.named_parameters():
            if 'reconstruction_head' in name:
                head_params.append(param)
            elif any(x in name for x in ['encoder', 'bottleneck']):
                encoder_params.append(param)
            else:
                decoder_params.append(param)

        # Higher learning rate for reconstruction head
        self.optimizer = optim.AdamW([
            {'params': encoder_params, 'lr': learning_rate * 0.1},  # Lower LR for encoder
            {'params': decoder_params, 'lr': learning_rate * 0.5},  # Medium LR for decoder
            {'params': head_params, 'lr': learning_rate}            # Full LR for head
        ], weight_decay=1e-5)

        # `verbose=` is deprecated for LR schedulers in recent PyTorch releases;
        # LR reductions are reported by train() instead.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.7, patience=15
        )

        # Loss functions
        self.mse_loss = nn.MSELoss()
        self.l1_loss = nn.L1Loss()

        # Tracking
        self.train_losses = []
        self.val_losses = []
        self.best_val_loss = float('inf')

        self.writer = SummaryWriter(os.path.join(save_dir, 'logs'))

        print(f"🏋️ Trainer initialized:")
        print(f"  Device: {device}")
        print(f"  Model parameters: {sum(p.numel() for p in model.parameters()):,}")
        print(f"  Save directory: {save_dir}")

    def reconstruction_loss(self, pred, target):
        """Weighted sum 0.7 * MSE + 0.3 * L1 between ``pred`` and ``target``."""
        mse = self.mse_loss(pred, target)
        l1 = self.l1_loss(pred, target)
        return 0.7 * mse + 0.3 * l1

    def train_epoch(self):
        """Run one training epoch; return the mean per-batch loss."""
        self.model.train()
        total_loss = 0.0

        pbar = tqdm(self.train_loader, desc='Training', leave=False)
        for data, target in pbar:
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            loss = self.reconstruction_loss(self.model(data), target)

            loss.backward()
            # Clip gradients to a global L2 norm of 1.0.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            batch_loss = loss.item()  # single host sync per batch
            total_loss += batch_loss
            pbar.set_postfix({'Loss': f'{batch_loss:.6f}'})

        return total_loss / len(self.train_loader)

    def validate(self):
        """Evaluate on the validation loader; return the mean per-batch loss."""
        self.model.eval()
        total_loss = 0.0

        with torch.no_grad():
            for data, target in tqdm(self.val_loader, desc='Validation', leave=False):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                total_loss += self.reconstruction_loss(output, target).item()

        return total_loss / len(self.val_loader)

    def visualize_reconstruction(self, num_samples=3):
        """Plot target vs. reconstruction (first 5 channels) for the first validation batch.

        Saves ``reconstruction_samples.png`` in ``save_dir``, shows the figure and
        then closes it so repeated calls do not accumulate open figures.
        """
        self.model.eval()
        fig, axes = plt.subplots(num_samples, 1, figsize=(15, 4 * num_samples))
        if num_samples == 1:
            axes = [axes]

        with torch.no_grad():
            data, target = next(iter(self.val_loader))
            data, target = data.to(self.device), target.to(self.device)

            output = self.model(data)

            for i in range(min(num_samples, data.size(0))):
                channels_to_plot = min(5, data.size(1))
                for ch in range(channels_to_plot):
                    axes[i].plot(target[i, ch].cpu().numpy(),
                                 label=f'Target Ch{ch}', alpha=0.7, linewidth=1)
                    axes[i].plot(output[i, ch].cpu().numpy(),
                                 label=f'Reconstructed Ch{ch}', alpha=0.7, linewidth=1, linestyle='--')

                axes[i].set_title(f'Sample {i+1} - Reconstruction vs Target')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(self.save_dir, 'reconstruction_samples.png'), dpi=150)
        plt.show()
        plt.close(fig)

    def save_checkpoint(self, epoch, is_best=False):
        """Write ``latest_checkpoint.pt`` (and ``best_checkpoint.pt`` when ``is_best``)."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            # Needed to resume with the same LR-plateau bookkeeping.
            'scheduler_state_dict': self.scheduler.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_val_loss': self.best_val_loss
        }

        torch.save(checkpoint, os.path.join(self.save_dir, 'latest_checkpoint.pt'))

        if is_best:
            torch.save(checkpoint, os.path.join(self.save_dir, 'best_checkpoint.pt'))
            print(f"🏆 New best model! Val Loss: {self.best_val_loss:.6f}")

    def train(self, num_epochs):
        """Main training loop; the TensorBoard writer is closed even if training fails."""
        print(f"🚀 Training quiet window reconstruction for {num_epochs} epochs...")

        try:
            for epoch in range(num_epochs):
                print(f"\n📊 Epoch {epoch + 1}/{num_epochs}")

                train_loss = self.train_epoch()
                val_loss = self.validate()

                self.train_losses.append(train_loss)
                self.val_losses.append(val_loss)

                lrs_before = [group['lr'] for group in self.optimizer.param_groups]
                self.scheduler.step(val_loss)
                lrs_after = [group['lr'] for group in self.optimizer.param_groups]
                if lrs_after != lrs_before:
                    print("📉 Learning rates reduced to: "
                          + ", ".join(f"{lr:.2e}" for lr in lrs_after))

                self.writer.add_scalar('Loss/Train', train_loss, epoch)
                self.writer.add_scalar('Loss/Val', val_loss, epoch)

                print(f"Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

                is_best = val_loss < self.best_val_loss
                if is_best:
                    self.best_val_loss = val_loss
                self.save_checkpoint(epoch + 1, is_best)

                # Visualize every 20 epochs
                if (epoch + 1) % 20 == 0:
                    self.visualize_reconstruction()
                    self.plot_losses()

            print(f"🎉 Training completed! Best Val Loss: {self.best_val_loss:.6f}")
        finally:
            self.writer.close()

    def plot_losses(self):
        """Plot full loss curves and the last (up to) 20 validation losses; save and close."""
        fig = plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(self.train_losses, label='Training Loss', alpha=0.8)
        plt.plot(self.val_losses, label='Validation Loss', alpha=0.8)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Reconstruction Training Progress')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 2, 2)
        recent_epochs = min(20, len(self.val_losses))
        plt.plot(self.val_losses[-recent_epochs:], 'o-', color='orange')
        plt.xlabel('Recent Epochs')
        plt.ylabel('Validation Loss')
        plt.title('Recent Validation Loss')
        plt.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(self.save_dir, 'training_progress.png'), dpi=150)
        plt.show()
        plt.close(fig)

def setup_quiet_reconstruction():
    """Build dataset, loaders, model and trainer for quiet-window reconstruction, then train.

    All tunables live in ``config``; the final config (plus the detected data
    dimensions) is written to ``<save_dir>/config.json``.

    Returns:
        tuple: ``(trainer, model, full_dataset)`` - callers must unpack all three.
    """

    # Configuration
    config = {
        'data_dir': '/content/drive/MyDrive/labelled_windows_multiclass',
        'target_class': 0,        # 0=quiet, 1=sep, 2=cme
        'batch_size': 4,         # Adjust based on GPU memory
        'learning_rate': 1e-3,
        'num_epochs': 80,
        'val_split': 0.2,
        'add_noise': True,        # Add noise for denoising task
        'noise_std': 0.1,         # Noise standard deviation
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'save_dir': '/content/drive/MyDrive/quiet_reconstruction_model',
        'seed': 42,               # Seeds torch's global RNG and the train/val split
    }

    print("🎯 QUIET WINDOW RECONSTRUCTION TRAINING")
    print("=" * 50)
    for key, value in config.items():
        print(f"  {key}: {value}")
    print("=" * 50)

    # Model init and input noise draw from torch's global RNG.
    torch.manual_seed(config['seed'])

    # Create dataset
    print("\n📊 Creating dataset...")
    full_dataset = QuietWindowDataset(
        data_dir=config['data_dir'],
        target_class=config['target_class'],
        add_noise=config['add_noise'],
        noise_std=config['noise_std'],
        normalize=True
    )

    # Get sample to determine dimensions
    sample_input, sample_target = full_dataset[0]
    n_channels = sample_input.shape[0]  # Should be 20
    seq_length = sample_input.shape[1]   # Should be 1024

    print(f"📏 Data dimensions: {n_channels} channels, {seq_length} time points")

    # Split dataset. NOTE(review): the validation subset shares the same dataset
    # object, so it also receives (fresh) input noise and was included in the
    # normalization statistics.
    n_total = len(full_dataset)
    val_size = int(n_total * config['val_split'])
    if n_total >= 2:
        # Keep at least one window on each side; an empty loader cannot be averaged.
        val_size = min(max(val_size, 1), n_total - 1)
    train_size = n_total - val_size
    split_generator = torch.Generator().manual_seed(config['seed'])
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size],
                                              generator=split_generator)

    # pin_memory only benefits host-to-CUDA copies.
    pin_memory = config['device'] == 'cuda'
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'],
                              shuffle=True, num_workers=2, pin_memory=pin_memory)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'],
                            shuffle=False, num_workers=2, pin_memory=pin_memory)

    print(f"✅ Data loaders created: {len(train_dataset)} train, {len(val_dataset)} val")

    # Create model
    print(f"\n🧠 Creating ResUNet++ model...")
    model = ResUNetPP1D(
        in_channels=n_channels,
        out_channels=n_channels,
        base_filters=64,
        depth=3,
        dropout=0.15
    )

    # Create trainer (also creates save_dir)
    trainer = ReconstructionTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=config['device'],
        learning_rate=config['learning_rate'],
        save_dir=config['save_dir']
    )

    # Save config
    config['n_channels'] = n_channels
    config['seq_length'] = seq_length
    with open(os.path.join(config['save_dir'], 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # Start training
    trainer.train(config['num_epochs'])

    return trainer, model, full_dataset

# 🚀 READY TO RUN!
# Summary of what setup_quiet_reconstruction() will do; no training happens here.
print("🎯 Quiet Window Reconstruction Training Setup Ready!")
print("\nThis will:")
print(
    "  ✅ Load only QUIET windows from your data",
    "  ✅ Add noise for denoising reconstruction task",
    "  ✅ Train ResUNet++ to reconstruct clean quiet windows",
    "  ✅ Focus training on the reconstruction head",
    "  ✅ Save best model for quiet window reconstruction",
    sep="\n",
)
print("\n🚀 Run: trainer, model, dataset = setup_quiet_reconstruction()")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
🎯 Quiet Window Reconstruction Training Setup Ready!

This will:
  ✅ Load only QUIET windows from your data
  ✅ Add noise for denoising reconstruction task
  ✅ Train ResUNet++ to reconstruct clean quiet windows
  ✅ Focus training on the reconstruction head
  ✅ Save best model for quiet window reconstruction

🚀 Run: trainer, model, dataset = setup_quiet_reconstruction()


In [None]:
# Start training with your data. setup_quiet_reconstruction() returns
# (trainer, model, full_dataset), so all three values must be unpacked.
trainer, model, dataset = setup_quiet_reconstruction()

🎯 QUIET WINDOW RECONSTRUCTION TRAINING
  data_dir: /content/drive/MyDrive/labelled_windows_multiclass
  target_class: 0
  batch_size: 4
  learning_rate: 0.001
  num_epochs: 80
  val_split: 0.2
  add_noise: True
  noise_std: 0.1
  device: cuda
  save_dir: /content/drive/MyDrive/quiet_reconstruction_model

📊 Creating dataset...
🎯 Loading quiet windows for reconstruction...


ValueError: No .npz files found in /content/drive/MyDrive/labelled_windows_multiclass