# 📦 Imports
This section includes necessary Python libraries for data processing, machine learning (PyTorch), visualization, metrics evaluation, and file handling.

In [1]:
import os

import numpy as np
import pandas as pd
import duckdb

import torch
import torch.fft as fft

from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose

import matplotlib.pyplot as plt

from datetime import datetime, timedelta

import torch.optim as optim
from tqdm import tqdm

from sklearn.metrics import accuracy_score, mean_squared_error, mean_absolute_error

import math
import argparse
import json

import h5py
from torch.utils.data import Dataset
from matplotlib import pyplot as plt

import pickle
import dill

# 📚 STORMAIDataset Definition
Defines a custom PyTorch Dataset class to load and preprocess STORMAI data from HDF5 files.

In [2]:
class STORMAIDataset(Dataset):
    def __init__(self):
        self.x = None
        self.y = None

        self.col_ranges = {
        'altitude': (0,1000),
        'ap_index_nT': (0,400),
        'f10.7_index': (63.4,250),
        'Lyman_alpha': (0.00588,0.010944),
        'Dst_index_nT': (-422,71),
        'BX_nT_GSE_GSM': (-40.8,34.8),
        'BY_nT_GSE': (-33.2,63.4),
        'BZ_nT_GSE': (-53.7,37.5),
        'SW_Proton_Density_N_cm3': (0.1,137.2),
        'SW_Plasma_Speed_km_s': (233,1189),
        'Magnetosonic_Mach_number': (0.6,14.3),
        'log_Lyman_alpha2': (4, 8),
        'f10.7_index2': (3969, 62500),
        'Lyman_alpha_f10.7': (0, 2.75),
        'ap_index_nT2': (0, 1.6e5),
        'ap_index_nT_f10.7': (0, 1e5),
        'log_xrsb_flux': (2.5, 5),
        'log_xrsb_flux2': (5,10),
        'log_xrsb_flux_Lyman_alpha': (5, 9)
        }

        self.sw_varlist = ['ap_index_nT', 'f10.7_index', 'Lyman_alpha', 'Dst_index_nT',
                'BX_nT_GSE_GSM', 'BY_nT_GSE', 'BZ_nT_GSE', 'SW_Proton_Density_N_cm3',
                'SW_Plasma_Speed_km_s', 'Magnetosonic_Mach_number', 'Lyman_alpha2',
                'f10.7_index2', 'Lyman_alpha_f10.7', 'ap_index_nT2', 'ap_index_nT_f10.7',
                'xrsb_flux', 'xrsb_flux2', 'xrsb_flux_Lyman_alpha']
        self.features = ['altitude', 'ap_index_nT', 'f10.7_index', 'Lyman_alpha', 'Dst_index_nT',
                'BX_nT_GSE_GSM', 'BY_nT_GSE', 'BZ_nT_GSE', 'SW_Proton_Density_N_cm3',
                'SW_Plasma_Speed_km_s', 'Magnetosonic_Mach_number', 'log_Lyman_alpha2',
                'f10.7_index2', 'Lyman_alpha_f10.7', 'ap_index_nT2', 'ap_index_nT_f10.7',
                'log_xrsb_flux', 'log_xrsb_flux2', 'log_xrsb_flux_Lyman_alpha']

    def load_hdf5(self, filelist):
        x = []
        y = []
        alt_array = []

        for path in filelist:
            with h5py.File(path, 'r') as f:
                tmpy = np.array(f['density'])[:,:,3]
                tmpx = np.array(f['space_weather'])
                tmp_alt_array = np.array(f['density'])[:,0,2]
                x.append(tmpx)
                y.append(tmpy)
                alt_array.append(tmp_alt_array)

        x = np.concatenate(x)
        y = np.concatenate(y)
        alt_array = np.concatenate(alt_array)

        # convert space-weather variables to log scale
        log_varlist = ['Lyman_alpha2', 'xrsb_flux', 'xrsb_flux2', 'xrsb_flux_Lyman_alpha']
        for v in log_varlist:
            idx = self.sw_varlist.index(v)
            mask = (x[:,:,idx]==0)
            x[:,:,idx] = x[:,:,idx]+(mask)*(10**(-self.col_ranges['log_'+v][0]))
            x[:,:,idx] = -np.log10(x[:,:,idx])

        # Add altitudes form density data to SW data
        altitudes = (np.ones((x.shape[1], x.shape[0]))*alt_array).T.reshape(x.shape[0], x.shape[1], 1)
        x = np.concatenate((altitudes, x), axis=2)

        # scale the features
        for idx, v in enumerate(self.features):
            x[:,:,idx] = (x[:,:,idx] - self.col_ranges[v][0]) / (self.col_ranges[v][1] - self.col_ranges[v][0])

        # convert density values to log scale
        y = (y<=0)*1e-15 + (y>0)*y

        # permute the last two dimensions of feature tensor
        self.x = torch.from_numpy(x).float().permute(0, 2, 1)
        self.y = torch.from_numpy(y).float()

In [3]:
class STORMAILoader(DataLoader):
    def __init__(self, x, y, **kwargs): # Add **kwargs here
        super(STORMAILoader, self).__init__(dataset=[(x_i, y_i) for x_i, y_i in zip(x, y)], **kwargs) # Pass **kwargs to super().__init__
        self.x = x
        self.y = y
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

In [4]:
class PositionalEncoding(nn.Module):
    """
    Positional Encoding for time series patches
    """
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]
        """
        x = x + self.pe[:x.size(1)]
        return x

class PatchEmbedding(nn.Module):
    """
    Split time series into patches and project to embedding dimension
    """
    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(patch_size * in_channels, embed_dim)

    def forward(self, x):
        # x: [batch_size, channels, seq_len]
        batch_size, channels, seq_len = x.shape

        # Pad sequence if needed
        if seq_len % self.patch_size != 0:
            pad_len = self.patch_size - (seq_len % self.patch_size)
            x = torch.cat([x, torch.zeros(batch_size, channels, pad_len, device=x.device)], dim=-1)
            seq_len += pad_len

        # Split into patches
        num_patches = seq_len // self.patch_size
        x = x.view(batch_size, channels, num_patches, self.patch_size)
        x = x.permute(0, 2, 1, 3)  # [batch_size, num_patches, channels, patch_size]
        x = x.reshape(batch_size, num_patches, -1)  # Flatten channels and patch_size

        # Project to embedding dimension
        x = self.proj(x)
        return x

class AdaptiveSpectralBlock(nn.Module):
    """
    Adaptive Spectral Block (ASB) that performs frequency domain processing
    with adaptive noise filtering
    """
    def __init__(self, embed_dim, threshold_quantile=0.9):
        super().__init__()
        self.embed_dim = embed_dim
        self.threshold_quantile = threshold_quantile

        # Learnable global and local filters
        self.global_filter = nn.Parameter(torch.randn(embed_dim, dtype=torch.cfloat))
        self.local_filter = nn.Parameter(torch.randn(embed_dim, dtype=torch.cfloat))

        # Learnable threshold parameter
        self.threshold = nn.Parameter(torch.tensor(0.5))

        # Layer normalization
        self.norm = nn.LayerNorm(embed_dim)

    def adaptive_high_freq_mask(self, x_fft):
        """
        Create adaptive mask for high frequency components
        """
        # Calculate power spectrum
        power = torch.abs(x_fft).pow(2)

        # Compute adaptive threshold
        threshold = torch.quantile(power, self.threshold_quantile)

        # Create mask (1 for frequencies to keep, 0 for those to filter)
        mask = (power > threshold * self.threshold).float()
        return mask

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        """
        batch_size, seq_len, embed_dim = x.shape

        # Apply FFT along the sequence dimension
        x_fft = fft.fft(x, dim=1)

        # Adaptive filtering
        mask = self.adaptive_high_freq_mask(x_fft)
        x_filtered = x_fft * mask

        # Apply global and local filters
        x_global = x_fft * self.global_filter
        x_local = x_filtered * self.local_filter

        # Combine and inverse FFT
        x_combined = x_global + x_local
        x_out = fft.ifft(x_combined, dim=1).real

        # Layer normalization
        x_out = self.norm(x_out)
        return x_out

class InteractiveConvolutionBlock(nn.Module):
    """
    Interactive Convolution Block (ICB) with parallel convolutions
    """
    def __init__(self, embed_dim, kernel_sizes=[3, 5]):
        super().__init__()
        self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_sizes[0], padding='same')
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_sizes[1], padding='same')
        self.conv3 = nn.Conv1d(embed_dim, embed_dim, 1)  # Final mixing convolution
        self.activation = nn.GELU()
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """
        Args:
            x: [batch_size, seq_len, embed_dim]
        """
        # Permute for convolution
        x_perm = x.permute(0, 2, 1)  # [batch_size, embed_dim, seq_len]

        # First convolution path
        conv1_out = self.conv1(x_perm)
        conv2_out = self.conv2(x_perm)

        # Interactive multiplication
        a1 = self.activation(conv1_out) * conv2_out
        a2 = self.activation(conv2_out) * conv1_out

        # Combine and final convolution
        combined = a1 + a2
        output = self.conv3(combined)

        # Permute back and normalize
        output = output.permute(0, 2, 1)
        output = self.norm(output + x)  # Residual connection
        return output

class TSLANetLayer(nn.Module):
    """
    A single TSLANet layer composed of ASB and ICB
    """
    def __init__(self, embed_dim, kernel_sizes=[3, 5]):
        super().__init__()
        self.asb = AdaptiveSpectralBlock(embed_dim)
        self.icb = InteractiveConvolutionBlock(embed_dim, kernel_sizes)

    def forward(self, x):
        x = self.asb(x)
        x = self.icb(x)
        return x

class TSLANet(nn.Module):
    """
    Complete TSLANet model
    """
    def __init__(self, in_channels, patch_size, embed_dim, num_layers,
                 num_classes=None, forecast_horizon=None, kernel_sizes=[3, 5]):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim

        # Patch embedding
        self.patch_embed = PatchEmbedding(patch_size, in_channels, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim)

        # TSLANet layers
        self.layers = nn.ModuleList([
            TSLANetLayer(embed_dim, kernel_sizes) for _ in range(num_layers)
        ])

        # Output heads
        self.num_classes = num_classes
        self.forecast_horizon = forecast_horizon

        if num_classes is not None:
            self.classifier = nn.Linear(embed_dim, num_classes)
        if forecast_horizon is not None:
            self.forecaster = nn.Linear(embed_dim, forecast_horizon)

    def forward(self, x):
        # x: [batch_size, channels, seq_len]

        # Patch embedding
        x = self.patch_embed(x)

        # Positional encoding
        x = self.pos_encoder(x)

        # TSLANet layers
        for layer in self.layers:
            x = layer(x)

        # Global average pooling
        x = x.mean(dim=1)  # [batch_size, embed_dim]

        # Output heads
        outputs = {}
        if self.num_classes is not None:
            outputs['classification'] = self.classifier(x)
        if self.forecast_horizon is not None:
            outputs['forecasting'] = self.forecaster(x)

        return outputs

In [5]:
rate = -np.log(1e-5)/431
propagation_weights = np.array([np.exp(-rate*i) for i in range(432)])

def MAE(true, pred, device='cpu'):
    weights_tensor = torch.ones((true.shape[0], 432))#*torch.tensor(propagation_weights)
    weights_tensor = weights_tensor.to(device)
    loss = torch.sum(torch.abs(true-pred)*weights_tensor, axis=1)
    return torch.mean(loss)

In [6]:
def DerivativeLoss(true, pred, device='cpu'):
    weights_tensor = torch.ones((true.shape[0], 432))#*torch.tensor(propagation_weights)
    weights_tensor = weights_tensor.to(device)
    weights_tensor = weights_tensor.reshape(true.shape[0], 8, 54).mean(dim=2).squeeze(-1)
    pred = pred.reshape(true.shape[0], 9, 48).mean(dim=2).squeeze(-1)
    true = true.reshape(true.shape[0], 9, 48).mean(dim=2).squeeze(-1)
    loss = torch.abs(torch.diff(true, dim=1)-torch.diff(pred, dim=1))*weights_tensor
    loss = torch.sum(loss, axis=1)
    return torch.mean(loss)

In [7]:
def PropagationLoss(true, pred, device='cpu'):
    loss = MAE(true, pred, device)+DerivativeLoss(true, pred, device)
    #loss = DrivativeLoss(true, pred, device)
    return loss

In [46]:
def train_tslanet(model, train_loader, val_loader, task_type='forecasting',
                 num_epochs=100, learning_rate=1e-3, weight_decay=1e-4,
                 patience=10, device='cuda', pretrain_epochs=0,
                 pretrain_loader=None, pretrain_mask_ratio=0.15):
    """
    Train the TSLANet model with optional self-supervised pretraining

    Args:
        model: TSLANet model instance
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        task_type: 'classification', 'forecasting', or 'both'
        num_epochs: Number of training epochs
        learning_rate: Initial learning rate
        weight_decay: Weight decay for optimizer
        patience: Early stopping patience
        device: Device to train on ('cuda' or 'cpu')
        pretrain_epochs: Number of self-supervised pretraining epochs
        pretrain_loader: DataLoader for pretraining (if None, uses train_loader)
        pretrain_mask_ratio: Ratio of patches to mask during pretraining

    Returns:
        Tuple of (trained_model, training_history, best_val_metric)
    """
    # Move model to device
    model = model.to(device)

    # Initialize optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if task_type != 'classification' else 'max', patience=patience//2)

    # Loss functions
    if task_type == 'classification' or task_type == 'both':
        cls_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    if task_type == 'forecasting' or task_type == 'both':
        #forecast_criterion = nn.MSELoss()
        #forecast_criterion = MAE
        forecast_criterion = PropagationLoss

    pretrain_criterion = nn.MSELoss()  # For self-supervised pretraining

    # Training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_metric': [],
        'val_metric': [],
        'pretrain_loss': []
    }

    best_val_metric = -np.inf if task_type == 'classification' else np.inf
    best_model_state = None
    epochs_without_improvement = 0

    # Self-supervised pretraining phase
    if pretrain_epochs > 0:
        print(f"Starting self-supervised pretraining for {pretrain_epochs} epochs...")
        pretrain_loader = pretrain_loader if pretrain_loader is not None else train_loader

        for epoch in range(pretrain_epochs):
            model.train()
            epoch_pretrain_loss = 0.0

            for batch in tqdm(pretrain_loader, desc=f"Pretrain Epoch {epoch+1}/{pretrain_epochs}"):
                x = batch[0].to(device)  # Input time series

                # Create masked version for pretraining
                batch_size, channels, seq_len = x.shape
                num_patches = seq_len // model.patch_size

                # Generate random mask (1 = keep, 0 = mask)
                mask = torch.ones(batch_size, num_patches, device=device)
                num_masked = int(pretrain_mask_ratio * num_patches)

                # Mask the input
                masked_x = x.clone()
                for i in range(batch_size):
                    # Reshape to patches
                    patches = x[i].unfold(-1, model.patch_size, model.patch_size)  # shape: (in_channels, num_patches, patch_size)

                    # Apply mask - Reshape mask to match patches dimension
                    masking = mask[i, :patches.shape[1]].view(-1, 1, 1).repeat(1, patches.shape[0], patches.shape[2])
                    patches = patches * masking

                    # Reshape back
                    masked_x[i] = patches.reshape(channels, -1)

                # Forward pass
                optimizer.zero_grad()
                outputs = model(masked_x)

                # Get reconstruction target (only masked patches)
                target = x.clone()
                for i in range(batch_size):
                    patches = target[i].unfold(-1, model.patch_size, model.patch_size)
                    patches = patches * (1 - mask[i].view(-1, 1, 1))  # Only masked patches
                    target[i] = patches.reshape(channels, -1)

                # Calculate loss and backprop
                loss = pretrain_criterion(outputs['forecasting'], target, device=device)
                loss.backward()
                optimizer.step()

                epoch_pretrain_loss += loss.item()

            avg_pretrain_loss = epoch_pretrain_loss / len(pretrain_loader)
            history['pretrain_loss'].append(avg_pretrain_loss)
            print(f"Pretrain Epoch {epoch+1} Loss: {avg_pretrain_loss:.4f}")

    # Main training phase
    print(f"Starting main training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        epoch_train_metric = 0.0

        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch+1}/{num_epochs}"):
            x = batch[0].to(device)
            optimizer.zero_grad()

            # Forward pass
            outputs = model(x)

            # Calculate loss based on task type
            if task_type == 'classification':
                targets = batch[1].to(device)
                loss = cls_criterion(outputs['classification'], targets)
            elif task_type == 'forecasting':
                targets = batch[1].to(device)
                loss = forecast_criterion(outputs['forecasting'], targets, device=device)
            elif task_type == 'both':
                cls_targets = batch[1].to(device)
                forecast_targets = batch[2].to(device)
                loss = cls_criterion(outputs['classification'], cls_targets) + \
                       forecast_criterion(outputs['forecasting'], forecast_targets, device=device)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item()

            # Calculate training metric
            with torch.no_grad():
                if task_type == 'classification':
                    preds = torch.argmax(outputs['classification'], dim=1)
                    epoch_train_metric += accuracy_score(batch[1].cpu().numpy(), preds.cpu().numpy())
                elif task_type == 'forecasting':
                    epoch_train_metric += PropagationLoss(batch[1].cpu(), outputs['forecasting'].cpu()).item()

                elif task_type == 'both':
                    # For 'both' mode, we track classification accuracy as the primary metric
                    preds = torch.argmax(outputs['classification'], dim=1)
                    epoch_train_metric += accuracy_score(batch[1].cpu().numpy(), preds.cpu().numpy())

        avg_train_loss = epoch_train_loss / len(train_loader)
        avg_train_metric = epoch_train_metric / len(train_loader)
        history['train_loss'].append(avg_train_loss)
        history['train_metric'].append(avg_train_metric)

        # Validation phase
        model.eval()
        epoch_val_loss = 0.0
        epoch_val_metric = 0.0

        with torch.no_grad():
            for batch in val_loader:
                x = batch[0].to(device)
                outputs = model(x)

                if task_type == 'classification':
                    targets = batch[1].to(device)
                    loss = cls_criterion(outputs['classification'], targets)
                    preds = torch.argmax(outputs['classification'], dim=1)
                    epoch_val_metric += accuracy_score(batch[1].cpu().numpy(), preds.cpu().numpy())
                elif task_type == 'forecasting':
                    targets = batch[1].to(device)
                    loss = forecast_criterion(outputs['forecasting'], targets, device=device)
                    epoch_val_metric += PropagationLoss(batch[1].cpu(), outputs['forecasting'].cpu()).item()
                elif task_type == 'both':
                    cls_targets = batch[1].to(device)
                    forecast_targets = batch[2].to(device)
                    loss = cls_criterion(outputs['classification'], cls_targets) + \
                           forecast_criterion(outputs['forecasting'], forecast_targets)
                    preds = torch.argmax(outputs['classification'], dim=1)
                    epoch_val_metric += accuracy_score(batch[1].cpu().numpy(), preds.cpu().numpy())

                epoch_val_loss += loss.item()

        avg_val_loss = epoch_val_loss / len(val_loader)
        avg_val_metric = epoch_val_metric / len(val_loader)
        history['val_loss'].append(avg_val_loss)
        history['val_metric'].append(avg_val_metric)

        # Update learning rate
        scheduler.step(avg_val_metric if task_type == 'classification' else avg_val_loss)

        # Print epoch summary
        if task_type == 'classification':
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"Train Acc: {avg_train_metric:.4f}, Val Acc: {avg_val_metric:.4f}")
        elif task_type == 'forecasting':
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"Train PMSE: {avg_train_metric:.4f}, Val PMSE: {avg_val_metric:.4f}")
        else:  # both
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
            print(f"Train Acc: {avg_train_metric:.4f}, Val Acc: {avg_val_metric:.4f}")

        # Check for early stopping and model saving
        if ((task_type == 'classification' and avg_val_metric > best_val_metric) or
            (task_type != 'classification' and avg_val_metric < best_val_metric)):
            best_val_metric = avg_val_metric
            best_model_state = model.state_dict()
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Load best model state
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, history, best_val_metric

In [9]:
dataset = STORMAIDataset()
data_list = ['../data/transformer_data/dev1.h5',
             '../data/transformer_data/dev3.h5',
             '../data/transformer_data/dev5.h5',
             '../data/transformer_data/dev6.h5',
             '../data/transformer_data/dev7.h5',
             '../data/transformer_data/dev8.h5',
            '../data/transformer_data/dev9.h5',
             '../data/transformer_data/data_large_v2.h5']
dataset.load_hdf5(data_list)
sample_size = len(dataset.y)
train_size = int(sample_size*0.9)
rnidx = np.random.permutation(sample_size)
train_idx = rnidx[:train_size]
val_idx = rnidx[train_size:]

In [10]:
dataset.x = torch.nan_to_num(dataset.x, 0)
dataset.y = torch.nan_to_num(dataset.y, 0)

In [None]:
X = dataset.x
Y = torch.log(dataset.y)
Y = (dataset.y-dataset.y[:,0].reshape(dataset.y.shape[0], 1))/dataset.y[:,0].reshape(dataset.y.shape[0], 1)
Y = torch.nan_to_num(Y, 1)

In [None]:
dataset_test = STORMAIDataset()
dataset_list = ['../data/transformer_data/train.h5',
                '../data/transformer_data/phase1p1.h5',
                '../data/transformer_data/phase1p2.h5']
dataset_test.load_hdf5(data_list)
sample_size = len(dataset_test.y)
dataset_test.x = torch.nan_to_num(dataset_test.x, 0)
dataset_test.y = torch.nan_to_num(dataset_test.y, 0)

In [47]:
# Configuration
in_channels = 19  # Univariate time series
patch_size = 48
embed_dim = 64
num_layers = 5
num_classes = None  # For classification
forecast_horizon = 432  # For forecasting

# Create model
model = TSLANet(
    in_channels=in_channels,
    patch_size=patch_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_classes=num_classes,
    forecast_horizon=forecast_horizon
)

model.load_state_dict(torch.load('tslanet_v5.pkl', weights_only=True))

<All keys matched successfully>

In [None]:
train_loader = STORMAILoader(X[train_idx], Y[train_idx], batch_size=512)
val_loader = STORMAILoader(X[val_idx], Y[val_idx], batch_size=512)
test_loader = STORMAILoader(dataset_test.x, dataset_test.y, batch_size=512)

In [None]:
# Train the model (classification + forecasting)
trained_model, history, best_metric = train_tslanet(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    task_type='forecasting',
    num_epochs=100,
    learning_rate=1e-4,
    weight_decay=1e-4,
    patience=10,
    device='cuda:0',
    pretrain_epochs=0  # Optional pretraining
)

print(f"Training complete. Best validation metric: {best_metric:.6f}")

In [None]:
device = 'cuda:0'

n = 0
rmse = 0
with torch.no_grad():
    for ibatch, batch in enumerate(val_loader):
        x = batch[0].to(device)
        y = batch[1].to(device)
        outputs = model(x)
        truth = y.cpu().numpy()
        pred = outputs['forecasting'].cpu().numpy()
        truth = np.exp(truth)
        pred = np.exp(pred)
        rmse += np.mean(((truth-pred)**2)*propagation_weights)

rmse = np.sqrt(rmse)
print(rmse)

3.582897862584793e-13


In [None]:
tarray = np.arange(432)
with torch.no_grad():
    for ibatch, batch in enumerate(test_loader):
        x = batch[0].to(device)
        y = batch[1].to(device)
        outputs = model(x)
        truth = y.cpu()
        pred = outputs['forecasting'].cpu()
        truth = np.exp(truth)
        pred = np.exp(pred)
        for j in range(pred.shape[0]):
            plt.clf()
            plt.figure(figsize=(9,3))
            plt.plot(tarray, pred[j], label='Preds')
            plt.plot(tarray, truth[j], label='True')
            plt.xlabel('Timestamp', fontsize=10)
            plt.ylabel('Orbit Mean Density (kg/m$^3$)')
            plt.legend()
            plt.show()
            plt.close()

            if j>2: break
        if ibatch>10: break

In [None]:
torch.save(model.state_dict(), 'tslanet.pkl')