# PyTorch Lightning baseline template

This notebook is a template for training and using a simple regression model as a baseline against which a downstream (DS) application can be compared.

It focuses on defining a PyTorch model, a PyTorch Lightning training loop, and performance metrics.

This notebook assumes familiarity with the concepts of datasets and dataloaders covered in **_0_dataset_dataloader_template.ipynb_**.

## Set your CUDA visible device

**IMPORTANT:** Since we are sharing resources, please make sure that the CUDA visible device you set here is the one assigned to your team and your machine.

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

Here we initialize the variables related to Weights and Biases, our online logging system, to ensure they are user specific.

In [None]:
# Make sure wandb logs are stored in a user-specific directory
# Set writable directories
os.environ["WANDB_DIR"] = "./wandb/wandb_logs"
os.environ["WANDB_CACHE_DIR"] = "./wandb/wandb_cache"
os.environ["WANDB_CONFIG_DIR"] = "./wandb/wandb_config"
# Optional:
os.environ["TMPDIR"] = "./wandb/wandb_tmp"

# Ensure directories exist (optional, wandb usually creates them)
os.makedirs(os.environ["WANDB_DIR"], exist_ok=True)
os.makedirs(os.environ["WANDB_CACHE_DIR"], exist_ok=True)
os.makedirs(os.environ["WANDB_CONFIG_DIR"], exist_ok=True)
os.makedirs(os.environ["TMPDIR"], exist_ok=True)
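The paths above are relative to the notebook folder, so they are user specific only if every user works in their own copy of the repository. As a minimal sketch, assuming you would rather key the folders on a user name, the helper below builds the same three variables under a per-user folder. `user_wandb_dirs` is a hypothetical name and not part of the workshop code.

```python
import os
import tempfile


def user_wandb_dirs(base_dir, user=None):
    """Return per-user paths for the W&B environment variables used above."""
    # Fall back to a placeholder when the USER variable is not set.
    user = user or os.environ.get("USER", "unknown_user")
    root = os.path.join(base_dir, user)
    return {
        "WANDB_DIR": os.path.join(root, "wandb_logs"),
        "WANDB_CACHE_DIR": os.path.join(root, "wandb_cache"),
        "WANDB_CONFIG_DIR": os.path.join(root, "wandb_config"),
    }


# Demonstrate on a throwaway folder so nothing in the repository is touched.
demo_base = tempfile.mkdtemp()
demo_dirs = user_wandb_dirs(demo_base, user="alice")
for path in demo_dirs.values():
    os.makedirs(path, exist_ok=True)
```

To use it for real, you would copy the returned values into `os.environ` before importing `wandb`.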

In [None]:

import sys
from torch.utils.data import DataLoader

import torch
import yaml

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, WandbLogger

# Append base path.  May need to be modified if the folder structure changes.
# It gives the notebook access to the workshop_infrastructure folder.
sys.path.append("../../")
 
# Append Surya path. May need to be modified if the folder structure changes.
# It gives the notebook access to surya's release code.
sys.path.append("../../Surya")

from surya.utils.data import build_scalers  # Data scaling utilities for Surya stacks

torch.set_float32_matmul_precision('medium')



## Download scalers
Surya input data needs to be scaled properly for the model to work, and this cell downloads the scaling information.

- If the cell below fails, try running the provided shell script directly in the terminal.
- Sometimes the download fails due to network or server issues. If that happens, simply re-run the script a few times until it completes successfully.

In [None]:
!sh download_scalers.sh

## Load configuration

Surya was designed to read a configuration file that defines many aspects of the model,
including the data it uses. We use this config file to set default values that do not
need to be modified, but also to define values specific to our downstream application.

In [None]:
# Configuration paths - modify these if your files are in different locations
config_path = "./configs/config.yaml"

# Load configuration (context managers make sure the files are closed)
print("📋 Loading configuration...")
try:
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    with open(config["data"]["scalers_path"], "r") as f:
        config["data"]["scalers"] = yaml.safe_load(f)
    print("✅ Configuration loaded successfully!")
except FileNotFoundError as e:
    print(f"❌ Error: {e}")
    print(f"Make sure {config_path} and the scalers file it points to exist")
    raise

scalers = build_scalers(info=config["data"]["scalers"])
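A missing entry in the configuration usually shows up later as a `KeyError` in the middle of building the datasets. As a minimal sketch, the hypothetical helper `missing_keys` below checks up front for some of the key paths this notebook reads. The toy dictionary uses made-up values and stands in for the loaded `config`.

```python
def missing_keys(config, required):
    """Return the dotted key paths in `required` that are absent from `config`."""
    missing = []
    for dotted in required:
        node = config
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# Some of the key paths read later in this notebook.
REQUIRED_KEYS = [
    "data.train_data_path",
    "data.valid_data_path",
    "data.scalers_path",
    "data.channels",
    "model.time_embedding.time_dim",
    "rollout_steps",
]

# Toy configuration with made-up values, standing in for config.yaml.
toy_config = {
    "data": {
        "train_data_path": "train.csv",
        "scalers_path": "scalers.yaml",
        "channels": ["a"],
    },
    "model": {"time_embedding": {"time_dim": 2}},
}

absent = missing_keys(toy_config, REQUIRED_KEYS)
print(absent)
```

Calling `missing_keys(config, REQUIRED_KEYS)` on the real configuration should return an empty list.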

## Define Downstream (DS) datasets

This child class takes as input all expected HelioFM parameters, plus additional parameters relevant to the downstream application.  Here we focus in particular on the DS index and the parameters necessary to combine it with the HelioFM index.

Another important component of creating a dataset class for your DS is normalization.  Here we apply a log normalization to the X-ray flux that acts as the output target.  The normalization makes the normalized log10(xray_flux) strictly positive, with 66% of its values between 0 and 1.
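To make the idea concrete, here is a minimal sketch of a shifted and scaled log10 normalization together with its inverse. The constants `LOG_FLOOR` and `LOG_SCALE` are hypothetical values chosen for illustration. They are not the ones used by `FlareDSDataset`, so check the dataset module for the actual transformation.

```python
import math

# Hypothetical constants for illustration only (not taken from FlareDSDataset).
LOG_FLOOR = -9.0  # log10 of an assumed lower bound on the flux
LOG_SCALE = 3.0   # assumed divisor that compresses the range


def normalize_flux(flux):
    """Map a strictly positive flux to a shifted and scaled log10 value."""
    return (math.log10(flux) - LOG_FLOOR) / LOG_SCALE


def denormalize_flux(value):
    """Invert normalize_flux and recover the flux."""
    return 10.0 ** (value * LOG_SCALE + LOG_FLOOR)


example_flux = 1e-5
normalized = normalize_flux(example_flux)
recovered = denormalize_flux(normalized)
print(normalized, recovered)
```

Any flux above the assumed floor maps to a positive number, and the inverse recovers the physical quantity whenever a model needs it.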

In this case we will define both a training and a validation dataset using the indices pointed to in the config.

**_Important:  In this notebook we set max_number_of_samples=6 so that we do not have to go through the whole dataset while exploring it.  Keep this in mind later if the dataset seems smaller than you expect._**


In [None]:
from downstream_apps.template.datasets.template_dataset import FlareDSDataset

In [None]:
train_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["train_data_path"],
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["drop_hmi_probablity"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="train",
    s3_use_simplecache = True,
    s3_cache_dir= "/tmp/helio_s3_cache",    
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=6,
    ds_flare_index_path="./data/hek_flare_catalog.csv",
    ds_time_column="start_time",
    ds_time_tolerance = "4d",
    ds_match_direction = "forward"    
)

# The Validation dataset changes the index we read
val_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["valid_data_path"],  #<---------------- different index path
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["drop_hmi_probablity"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    scalers=scalers,
    s3_use_simplecache = True,
    s3_cache_dir= "/tmp/helio_s3_cache",    
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=6,
    ds_flare_index_path="./data/hek_flare_catalog.csv",
    ds_time_column="start_time",
    ds_time_tolerance = "4d",
    ds_match_direction = "forward"    
)

We also initialize separate training and validation dataloaders.   Since we are working in a shared environment, using multiprocessing_context="spawn" helps avoid lockups.

In [None]:
batch_size = 2

train_data_loader = DataLoader(
                dataset=train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=4,
                multiprocessing_context="spawn",
                persistent_workers=True,
                pin_memory=True,
            )

val_data_loader = DataLoader(
                dataset=val_dataset,
                batch_size=batch_size,
                num_workers=4,
                multiprocessing_context="spawn",
                persistent_workers=True,
                pin_memory=True,
            )

## Define simple baseline model

Defining a simple baseline is important to understand what value the AI model brings to the problem.

It is always good to have a very simple baseline model, ideally one that cannot overfit the data.  This is a good way of measuring the value added by complex models.   Classical machine learning excels here:

- Regressions and logistic regressions.
- Climatological averages.
- Persistence.
- Simple transformations.

Simple models help avoid excessively optimistic assessments of the capabilities of complex models, and for many problems they are remarkably hard to beat.
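As a minimal sketch of two of the baselines listed above, the snippet below compares a climatological average with persistence on a short, made-up series. The numbers are purely illustrative and are not flare data.

```python
def mse(predictions, targets):
    """Mean squared error between two equally long sequences."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


# Made-up series: the first four values play the role of training data.
series = [1.0, 1.2, 0.9, 1.5, 1.4, 1.1]
train, targets = series[:4], series[4:]

# Climatology: always predict the mean of the training period.
climatology = sum(train) / len(train)
climatology_predictions = [climatology] * len(targets)

# Persistence: predict that each value equals the one observed just before it.
persistence_predictions = series[3:5]

climatology_mse = mse(climatology_predictions, targets)
persistence_mse = mse(persistence_predictions, targets)
print(climatology_mse, persistence_mse)
```

Neither baseline has trainable parameters, so a complex model that cannot beat both on held-out data is not adding value yet.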

In this example we define a simple regression acting on the intensity of each channel.  Note that we invert the normalization to deal with strictly positive quantities.  As with the dataset, we import the model from a module so that we can use it within training scripts later on.

In [None]:
from downstream_apps.template.models.simple_baseline import RegressionFlareModel

We can now test that this model manipulates a batch as expected and returns an estimate of flare intensity.

Note that the simple regression model definition requires knowing the number of channels and timesteps, so here we pull that information from the configuration before initializing the model.

In [None]:
n_input_timestamps = config["model"]["time_embedding"]["time_dim"]
n_channels = len(config["data"]["channels"])

model = RegressionFlareModel(n_input_timestamps*n_channels, config["data"]["channels"], scalers)

Now we can pass the input stack 'ts' to the model to transform it into our regression output.   Note that since this model has not been trained and was initialized randomly, the output here has no real meaning.  It only acts as a test that our model's forward pass doesn't have dimension problems.

Dimension problems are the dominant source of error in this kind of work.

Note that our output now has the size of our batch.

In [None]:
batch = next(iter(train_data_loader))
output = model.forward(batch['ts'])[:,0]  # Get rid of singleton dimension
output
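The same kind of shape check can be run without the dataset by feeding random tensors to a toy model. The sketch below is not `RegressionFlareModel`. It is a hypothetical stand-in that assumes an input layout of (batch, channels, time, height, width), averages each channel and timestep over space, and applies a single linear layer.

```python
import torch
from torch import nn


class TinyRegression(nn.Module):
    """Hypothetical stand-in: spatial mean per channel/timestep, then a linear layer."""

    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x):
        # Assumed layout: (batch, channels, time, height, width).
        features = x.mean(dim=(-2, -1))           # -> (batch, channels, time)
        features = features.flatten(start_dim=1)  # -> (batch, channels * time)
        return self.linear(features)              # -> (batch, 1)


torch.manual_seed(0)
toy_batch_size, toy_channels, toy_timestamps = 2, 3, 2
fake_ts = torch.randn(toy_batch_size, toy_channels, toy_timestamps, 8, 8)

tiny_model = TinyRegression(toy_channels * toy_timestamps)
tiny_output = tiny_model(fake_ts)[:, 0]  # drop the singleton dimension
print(tiny_output.shape)
```

If the number of features passed to the constructor does not match channels times timesteps, the linear layer raises a shape error, which is exactly the class of problem this test is meant to catch early.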

## Define your metrics

Metrics are a very important part of training AI models.   They provide your models with the quantification of error, which in turn shifts the weights towards better performing models.  They also provide a way for you to monitor performance, identify overfitting, and quantify value added.

We now initialize the metrics class, which lets you control which metrics are used as "loss" (i.e. the metrics that backpropagate through your model) and which ones are used for monitoring performance.  As with other components, this takes the form of a loaded module that can later be used in a training script.

In [None]:
from downstream_apps.template.metrics.template_metrics import FlareMetrics

In [None]:
train_loss_metrics = FlareMetrics("train_loss")
train_evaluation_metrics = FlareMetrics("train_metrics")
validation_evaluation_metrics = FlareMetrics("val_metrics")
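The distinction between a loss and a monitoring metric comes down to whether the value stays attached to the autograd graph. The sketch below illustrates this with a hypothetical `SketchMetrics` class. It only mimics the mode-based interface used above, and the real `FlareMetrics` may return different structures.

```python
import torch


class SketchMetrics:
    """Hypothetical stand-in for FlareMetrics; the real class may differ."""

    def __init__(self, mode):
        self.mode = mode

    def __call__(self, preds, target):
        mse = torch.mean((preds - target) ** 2)
        if self.mode == "train_loss":
            # Keep the autograd graph so the value can be backpropagated.
            return {"mse": mse}
        # Monitoring values are detached: they are reported, never optimized.
        return {"mse": mse.detach()}


toy_preds = torch.tensor([0.5, 1.5], requires_grad=True)
toy_target = torch.tensor([1.0, 2.0])

toy_loss = SketchMetrics("train_loss")(toy_preds, toy_target)["mse"]
toy_loss.backward()  # gradients reach toy_preds through the loss

toy_monitor = SketchMetrics("val_metrics")(toy_preds, toy_target)["mse"]
print(toy_loss.item(), toy_monitor.requires_grad)
```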

Now they can be evaluated on our model's output and our ground truth.   First the loss that will actually backpropagate, in this case Mean Squared Error (MSE).

In [None]:
train_loss_metrics(output, batch["forecast"])

Then a training evaluation that will not backpropagate and inform our model, but that we can keep an eye on. Note that reporting lots of metrics during training will slow the training process.  I'm including it here as an example, but it is often better to put the diagnostics only in the validation evaluation metrics.

Here we are calculating the Root Relative Squared Error (RRSE): https://lightning.ai/docs/torchmetrics/stable/regression/rse.html

A value below one means the prediction is better than always predicting the average of the targets.  It is unlikely that this metric will be lower than one with a randomly initialized model.

In [None]:
train_evaluation_metrics(output, batch["forecast"])
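To see why one is the reference value, here is a minimal pure-Python sketch of the RRSE on made-up numbers: the squared error of the predictions divided by the squared error of always predicting the target mean, followed by a square root. It is written by hand for illustration and is not the torchmetrics implementation.

```python
import math


def rrse(predictions, targets):
    """Root relative squared error of predictions with respect to the target mean."""
    mean_target = sum(targets) / len(targets)
    model_error = sum((t - p) ** 2 for p, t in zip(predictions, targets))
    mean_error = sum((t - mean_target) ** 2 for t in targets)
    return math.sqrt(model_error / mean_error)


toy_targets = [1.0, 2.0, 3.0, 4.0]

# Always predicting the mean gives exactly one.
mean_prediction_rrse = rrse([2.5, 2.5, 2.5, 2.5], toy_targets)

# Predictions close to the targets give a value well below one.
good_prediction_rrse = rrse([1.1, 1.9, 3.2, 3.8], toy_targets)

print(mean_prediction_rrse, good_prediction_rrse)
```

Note that the denominator is zero when all targets in a batch are identical, so the value is not defined for such a batch.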

In the validation evaluation metrics we report both MSE and RRSE.

In [None]:
validation_evaluation_metrics(output, batch["forecast"])

## Define your PyTorch Lightning module

In this workshop we will use PyTorch Lightning to train our models.  PyTorch Lightning reduces the amount of code required to implement a training loop in comparison to plain PyTorch (at the expense of some control and versatility).

Opening the FlareLightningModule shows a simple Lightning model implementation.  It consists of:

- An initialization of the class (metrics, model, and learning rate).
- The forward code that runs evaluation of the model.
- Training and validation steps.
- Configuration of optimizers.

In [None]:
from downstream_apps.template.lightning_modules.pl_simple_baseline import FlareLightningModule

## Set your global seeds

Since training AI models generally uses stochastic gradient descent, it is a good idea to fix your random seeds so that your training exercise is reproducible.    

In [None]:
pl.seed_everything(42, workers=True)
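`pl.seed_everything` seeds Python's `random` module, NumPy, and PyTorch in one call. The effect of fixing a seed can be seen with the standard library alone, as in this small sketch:

```python
import random


def draw(seed, n=3):
    """Seed Python's generator and return n pseudo-random numbers."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


first = draw(42)
second = draw(42)  # same seed, same sequence
other = draw(7)    # different seed, different sequence

print(first == second, first == other)
```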

## Initialize Lightning module

Now we properly initialize the Lightning module to enable training, including passing the dictionary of metrics.

In [None]:
metrics = {'train_loss': train_loss_metrics,
           'train_metrics': train_evaluation_metrics,
           'val_metrics': validation_evaluation_metrics}

learning_rate = 1e-3
lit_model = FlareLightningModule(model, metrics, lr=learning_rate, batch_size=batch_size)

## Logging

To compare experiments against each other, it is very useful to log evaluation metrics in a place where they can be compared against other training runs.  In this workshop we will use Weights and Biases (WandB).

The first time you run WandB on a machine it will ask you to log in.  You should have received an invitation to our project.  To log in you must:

- Select option 2 (existing account).   In VSCode the dialog opens a box at the top of your screen.
- Click on get API Key (this will open a browser).
- Generate API Key.
- Paste it in the dialog box at the top of your VSCode window.

In [None]:
project_name = "template_flare_regression"
run_name = "baseline_experiment_1"

wandb_logger = WandbLogger(
    entity="surya_handson",
    project=project_name,
    name=run_name,
    log_model=False,
    save_dir="./wandb/wandb_tmp",
)

csv_logger = CSVLogger("runs", name="simple_flare")

## Initialize trainer

With the loggers done, the trainer now needs to be defined.  The trainer defines several properties of your training run. Here we define:

- The max number of epochs (one epoch represents your model seeing your entire training dataset).
- Where the training run will take place (auto uses the GPU if possible, otherwise the CPU).
- The loggers.
- The callbacks (here we save the model with the lowest validation loss).
- The logging frequency (because we are working with a small dataset, the number of steps between logs needs to be small).

In [None]:
max_epochs = 2

# -------------------------------------------------------------------------
# Trainer
# -------------------------------------------------------------------------
trainer = pl.Trainer(
    max_epochs=max_epochs,
    accelerator="auto",
    devices="auto",
    logger=[wandb_logger, csv_logger],
    callbacks=[
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
    ],
    log_every_n_steps=2,
)

## Fit the model

Finally we fit the model.  We pass the Lightning module and our dataloaders.

In [None]:
trainer.fit(lit_model, train_data_loader, val_data_loader)
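Besides WandB, the `CSVLogger` writes a `metrics.csv` file inside a versioned folder under `runs/simple_flare`, which can be inspected without any extra tooling. The sketch below parses a made-up excerpt with the standard library. The column names `train_loss` and `val_loss` are assumptions, since the actual columns depend on what the Lightning module logs.

```python
import csv
import io

# Made-up excerpt in the layout CSVLogger uses; the metric columns are assumptions.
sample_csv = """epoch,step,train_loss,val_loss
0,1,0.90,
0,2,,0.80
1,3,0.70,
1,5,,0.65
"""


def best_epoch(csv_text, column="val_loss"):
    """Return (lowest value, epoch) for a metric, skipping rows where it is empty."""
    rows = csv.DictReader(io.StringIO(csv_text))
    scored = [(float(row[column]), int(row["epoch"])) for row in rows if row[column]]
    return min(scored)


best_value, best_at = best_epoch(sample_csv)
print(best_value, best_at)
```

For a real run you would pass the contents of the `metrics.csv` file instead of `sample_csv`. Rows are sparse because training and validation metrics are logged at different steps, which is why empty cells are skipped.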

## Conclusion

With this we have now integrated our dataset, dataloaders, metrics, and baseline into an end-to-end training loop.  The next step is to substitute the simple model with Surya.