# PyTorch dataset template

by Andrés Muñoz-Jaramillo

This notebook is a template for creating a custom dataset based on a downstream application (DS) index.

It requires a DS index file that is combined with a HelioFM index. It also shows how to create a child dataset class based on HelioFM's dataset class, so that all the code related to the input data is handled transparently while the new code focuses exclusively on adding the DS information.

This template uses a flare forecasting dataset as an example, casting the problem as an X-ray flux regression problem.

In [None]:
import numpy as np
import sys
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import sunpy.visualization.colormaps as cm
import matplotlib.gridspec as gridspec

import yaml

# Append base path.  May need to be modified if the folder structure changes.
# It gives the notebook access to the workshop_infrastructure folder.
sys.path.append("../../")
 
# Append Surya path. May need to be modified if the folder structure changes.
# It gives the notebook access to surya's release code.
sys.path.append("../../Surya")

from surya.utils.data import build_scalers  # Data scaling utilities for Surya stacks

## Download scalers
Surya input data needs to be scaled properly for the model to work and this cell downloads the scaling information.

- If the cell below fails, try running the provided shell script directly in the terminal.
- Downloads sometimes fail because of network or server issues. If that happens, re-run the script until it completes successfully.

In [None]:
!sh download_scalers.sh

## Load configuration

Surya was designed to read a configuration file that defines many aspects of the model,
including the data it uses. We use this config file to set default values that do not
need to be modified, and also to define values specific to our downstream application.
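Because the cells below index directly into the config, a missing key only surfaces as a `KeyError` when that cell runs. The sketch below checks up front for the keys this notebook reads. The helper `find_missing_keys` and the placeholder values are hypothetical and only illustrate the idea; they are not part of Surya.

```python
# Key paths read later in this notebook (nested keys are written as tuples).
REQUIRED_KEYS = [
    ("data", "train_data_path"),
    ("data", "scalers_path"),
    ("data", "channels"),
    ("data", "time_delta_input_minutes"),
    ("data", "time_delta_target_minutes"),
    ("model", "time_embedding", "time_dim"),
    ("rollout_steps",),
    ("drop_hmi_probablity",),  # spelled this way where the notebook reads it
    ("use_latitude_in_learned_flow",),
]


def find_missing_keys(config, required=REQUIRED_KEYS):
    """Return the required key paths that are absent from a nested dict."""
    missing = []
    for path in required:
        node = config
        for key in path:
            if not isinstance(node, dict) or key not in node:
                missing.append(".".join(path))
                break
            node = node[key]
    return missing


# Toy config with placeholder values (not the real Surya settings).
example_config = {
    "data": {"train_data_path": "placeholder.csv", "channels": ["placeholder"]},
    "rollout_steps": 0,
}
missing = find_missing_keys(example_config)
print(missing)
```

Running the same check on the loaded `config` before building the dataset would list every absent key at once.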

In [None]:
# Configuration paths - modify these if your files are in different locations
config_path = "./configs/config.yaml"

# Load the main configuration, then the scaler statistics it points to
print("📋 Loading configuration...")
try:
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)
    with open(config["data"]["scalers_path"], "r") as f:
        config["data"]["scalers"] = yaml.safe_load(f)
    print("✅ Configuration loaded successfully!")
except FileNotFoundError as e:
    print(f"❌ Error: {e}")
    print(f"Make sure {config_path} and the scalers file it references exist")
    raise

scalers = build_scalers(info=config["data"]["scalers"])

## Define DS dataset

This child class takes as input all expected HelioFM parameters, plus additional parameters relevant to the downstream application. Here we focus in particular on the DS index and the parameters necessary to combine it with the HelioFM index.

Another important component of creating a dataset class for your DS is normalization. Here we apply a log normalization to the X-ray flux, which acts as the output target. The normalization makes log10(xray_flux) strictly positive and places 66% of its values between 0 and 1.

Since we are going to use this dataset moving forward, it is better to develop it as a script and not as a notebook.
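As a rough illustration of the log normalization, the sketch below shifts and rescales log10 of the flux and provides the inverse for turning predictions back into physical units. The constants `LOG_OFFSET` and `LOG_SCALE` are hypothetical placeholders, not the values used in `FlareDSDataset`; the real values would be chosen from the flux distribution of the training split.

```python
import numpy as np

# Hypothetical constants for illustration only.
LOG_OFFSET = 9.0  # added to log10(flux) so that typical values become positive
LOG_SCALE = 2.0   # divisor that compresses the range


def normalize_flux(flux):
    """Map X-ray flux to a shifted, rescaled log10 value."""
    flux = np.asarray(flux, dtype=float)
    return (np.log10(flux) + LOG_OFFSET) / LOG_SCALE


def denormalize_flux(normalized):
    """Invert normalize_flux to recover the flux."""
    normalized = np.asarray(normalized, dtype=float)
    return 10.0 ** (normalized * LOG_SCALE - LOG_OFFSET)


flux = np.array([1e-8, 1e-6, 1e-4])
normalized = normalize_flux(flux)
recovered = denormalize_flux(normalized)
print(normalized)
print(recovered)
```

Keeping the forward and inverse functions next to each other makes it harder for the two to drift apart when the constants change.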

In [None]:
from downstream_apps.template.datasets.template_dataset import FlareDSDataset

## Initialize class without Surya stacks

All the parameters that define a HelioFM dataset are contained in the test config file. The scalers used to normalize HelioFM's input data are also required.

**_Important: This first initialization is set not to return the full Surya stack, so that we can quickly verify that the target returns what we expect without pulling full stacks before we need them._**

_We do this by setting return_surya_stack=False._

Set return_surya_stack=True if you need the full Surya stack at this stage; otherwise we reinitialize the dataset with stacks later in this notebook.

**_Important: In this notebook we set max_number_of_samples=6 to avoid going through the whole dataset as we explore it. Keep this in mind in case the dataset later seems smaller than you expect._**
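To make the two switches concrete, here is a minimal sketch of the child-class pattern in plain Python. `ParentStackDataset`, `ToyDSDataset`, and the `"target"` key are hypothetical stand-ins rather than the actual `HelioNetCDFDataset` or `FlareDSDataset` code, and the real implementation may differ.

```python
class ParentStackDataset:
    """Hypothetical stand-in for the HelioFM parent dataset."""

    def __init__(self, n_samples):
        self.n_samples = n_samples

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # The real parent would load and normalize an SDO stack here.
        return {"ts": f"stack-{idx}"}


class ToyDSDataset(ParentStackDataset):
    """Hypothetical child that adds a downstream target to each item."""

    def __init__(self, n_samples, targets, return_stack=True,
                 max_number_of_samples=None):
        super().__init__(n_samples)
        self.targets = targets
        self.return_stack = return_stack
        self.max_number_of_samples = max_number_of_samples

    def __len__(self):
        n = super().__len__()
        if self.max_number_of_samples is not None:
            n = min(n, self.max_number_of_samples)
        return n

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError(idx)
        # Skip the expensive parent call when only the target is needed.
        item = super().__getitem__(idx) if self.return_stack else {}
        item["target"] = self.targets[idx]
        return item


quick = ToyDSDataset(100, targets=list(range(100)),
                     return_stack=False, max_number_of_samples=6)
full = ToyDSDataset(100, targets=list(range(100)), return_stack=True)
print(len(quick), quick[3])
print(len(full), full[3])
```

The child only overrides what it needs: the parent keeps ownership of the input stack, and the child decides whether to call it.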


In [None]:
train_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["train_data_path"],
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["drop_hmi_probablity"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="train",
    s3_use_simplecache = True,
    s3_cache_dir= "/tmp/helio_s3_cache",    
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=False,
    max_number_of_samples=6,
    ds_flare_index_path="./data/hek_flare_catalog.csv",
    ds_time_column="start_time",
    ds_time_tolerance = "4d",
    ds_match_direction = "forward"    
)
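The parameters `ds_time_column`, `ds_time_tolerance`, and `ds_match_direction` describe how the two indices are aligned in time. One plausible way to do this kind of matching is `pandas.merge_asof`, sketched below on toy data. Whether `FlareDSDataset` uses `merge_asof` internally is an assumption, and the column names `timestep` and `xray_flux` are hypothetical.

```python
import pandas as pd

# Toy HelioFM index: one row per available input time.
helio_index = pd.DataFrame({
    "timestep": pd.to_datetime([
        "2014-01-01 00:00", "2014-01-02 00:00", "2014-01-10 00:00",
    ]),
})

# Toy flare catalog: one row per flare.
flare_index = pd.DataFrame({
    "start_time": pd.to_datetime(["2014-01-01 12:00", "2014-01-03 06:00"]),
    "xray_flux": [1e-6, 5e-5],
})

# For each input time, take the first flare that starts at or after it,
# as long as it starts within the tolerance window.
matched = pd.merge_asof(
    helio_index.sort_values("timestep"),
    flare_index.sort_values("start_time"),
    left_on="timestep",
    right_on="start_time",
    direction="forward",
    tolerance=pd.Timedelta("4d"),
)

# Input times with no flare inside the window get NaN and are dropped here.
matched = matched.dropna(subset=["xray_flux"]).reset_index(drop=True)
print(matched)
```

In this toy example the third input time has no flare within four days, so only two samples survive the match.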

## Test length and structure

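Now we can test that the dataset is properly initialized and returns what is expected. In the case of the template, there are 6294 flares that take place during the training split. Because max_number_of_samples=6 is set above, the length reported below may be capped accordingly.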

In [None]:
train_dataset.__len__()

In [None]:
item = train_dataset.__getitem__(0)
item.keys()

Note that the dataset returns a single item per call, and the same index always returns the same item:

In [None]:
train_dataset.__getitem__(0)

In [None]:
train_dataset.__getitem__(3)

In [None]:
train_dataset.__getitem__(3)

## Define dataloader

With a working dataset we can define a dataloader. A dataloader is a wrapper around a dataset that adds a sampling strategy to turn the dataset into batches. When we request a batch, the dataloader returns a dictionary like the one the dataset returns, but the data inside has an additional batch dimension.
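The sketch below mimics that batching step with NumPy arrays in place of tensors: a list of same-keyed dictionaries is stacked key by key along a new leading axis. The `collate` helper, the `"target"` key, and the toy shapes are hypothetical, and PyTorch's own default collation handles more cases than this.

```python
import numpy as np


def collate(items):
    """Stack a list of same-keyed dicts along a new leading batch axis."""
    return {key: np.stack([item[key] for item in items]) for key in items[0]}


# Five toy items, each with a small [C, T, H, W] array and a scalar target.
items = [
    {"ts": np.zeros((3, 2, 8, 8), dtype=np.float32), "target": np.float32(i)}
    for i in range(5)
]

batch = collate(items)
print(batch["ts"].shape)      # (5, 3, 2, 8, 8)
print(batch["target"].shape)  # (5,)
```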

In [None]:
data_loader = DataLoader(
                dataset=train_dataset,
                batch_size=5,
                num_workers=8
            )

In [None]:
batch = next(iter(data_loader))
batch.keys()

Now the batch contains more than one item:

In [None]:
batch

Typically we set dataloaders to shuffle the data so that the model sees it in a different order during training. This means that, in general, we neither want nor expect consecutive calls to return the same sequence of events.

In [None]:
data_loader = DataLoader(
                dataset=train_dataset,
                batch_size=5,
                shuffle=True,
            )

next(iter(data_loader))
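Conceptually, shuffling draws a random permutation of the sample indices and then cuts it into batches. The sketch below reproduces that idea with NumPy for 6 samples and a batch size of 5, the sizes used in this notebook. The helper `shuffled_batches` is hypothetical and is not how `DataLoader` is implemented internally.

```python
import numpy as np


def shuffled_batches(n_samples, batch_size, seed):
    """Return lists of sample indices, shuffled and split into batches."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    return [
        order[start:start + batch_size].tolist()
        for start in range(0, n_samples, batch_size)
    ]


epoch_a = shuffled_batches(n_samples=6, batch_size=5, seed=0)
epoch_b = shuffled_batches(n_samples=6, batch_size=5, seed=0)
print(epoch_a)
print(epoch_a == epoch_b)
```

Every index appears exactly once per pass, and the last batch is smaller when the dataset size is not a multiple of the batch size. Fixing the seed makes the order repeatable, which helps when debugging.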

## Initialize class with Surya stacks

Now we initialize the dataset to return full Surya stacks, so that we can visualize them, by setting _return_surya_stack=True_.

In [None]:
train_dataset = FlareDSDataset(
    #### All these lines are required by the parent HelioNetCDFDataset class
    index_path=config["data"]["train_data_path"],
    time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
    time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
    n_input_timestamps=config["model"]["time_embedding"]["time_dim"],
    rollout_steps=config["rollout_steps"],
    channels=config["data"]["channels"],
    drop_hmi_probability=config["drop_hmi_probablity"],
    use_latitude_in_learned_flow=config["use_latitude_in_learned_flow"],
    scalers=scalers,
    phase="train",
    s3_use_simplecache = True,
    s3_cache_dir= "/tmp/helio_s3_cache",    
    #### Put your downstream (DS) specific parameters below this line
    return_surya_stack=True,
    max_number_of_samples=6,
    ds_flare_index_path="./data/hek_flare_catalog.csv",
    ds_time_column="start_time",
    ds_time_tolerance = "4d",
    ds_match_direction = "forward"    
)

In Surya's convention, an item contains the following elements:
- 'ts': input tensor.
- 'time_delta_input': minutes with respect to the present in the time dimension of the input tensor.
- 'forecast': target SDO stack.
- 'lead_time_delta': how many minutes into the future is the target stack with respect to the present.


In [None]:
item = train_dataset.__getitem__(0)
item.keys()

The shape of the input tensors has dimensions [C, T, H, W], where
- C: instrument channels.
- T: timestamps.
- H: Height.
- W: Width.
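The indexing used in the plotting section follows directly from this layout. The sketch below uses a zero-filled NumPy array with toy sizes (the real stacks are much larger) to show which axes remain after selecting a timestamp or a channel.

```python
import numpy as np

# Toy sizes for illustration; they do not match the real Surya stacks.
C, T, H, W = 3, 2, 8, 8
ts = np.zeros((C, T, H, W), dtype=np.float32)

first_timestamp = ts[:, 0, ...]  # all channels at the first timestamp
one_image = ts[1, 0]             # a single channel at a single timestamp

print(first_timestamp.shape)  # (3, 8, 8)
print(one_image.shape)        # (8, 8)
```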

In [None]:
item['ts'].shape

## Plotting input stack

Before plotting, it is necessary to undo the z-score and logarithmic normalization.
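To show what undoing the normalization involves, the sketch below applies a signed logarithm followed by a z-score and then inverts both steps in reverse order. The exact transform is defined by the Surya scalers and `inverse_transform_data`; the functional form and the `mean` and `std` values here are assumptions made for illustration.

```python
import numpy as np


def forward_transform(x, mean, std):
    """Signed log followed by a z-score (assumed form, for illustration)."""
    signed_log = np.sign(x) * np.log1p(np.abs(x))
    return (signed_log - mean) / std


def inverse_transform(z, mean, std):
    """Undo the z-score first, then the signed log."""
    signed_log = z * std + mean
    return np.sign(signed_log) * np.expm1(np.abs(signed_log))


# Hypothetical per-channel statistics and pixel values.
mean, std = 0.5, 2.0
pixels = np.array([-1000.0, -1.0, 0.0, 5.0, 2000.0])

normalized = forward_transform(pixels, mean, std)
restored = inverse_transform(normalized, mean, std)
print(normalized)
print(restored)
```

The signed form keeps negative values, such as magnetic field measurements, meaningful after the logarithm.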

In [None]:
unnormalized_ts = train_dataset.inverse_transform_data(item['ts'][:,0,...])
channel_order = config["data"]["channels"]
fig = plt.figure(figsize=np.array([4,4]), dpi=300)
gs = gridspec.GridSpec(4, 4, figure=fig, wspace=0, hspace=0)

limits = {}

for i in range(4):
    for j in range(4):
        n = i*4 + j
        if n < len(channel_order):

            ax = fig.add_subplot(gs[i,j])
            channel = channel_order[n]
            if 'hmi' not in channel:
                lim = np.percentile(unnormalized_ts[n,...][unnormalized_ts[n,...]!=0], 99)
                ax.imshow(unnormalized_ts[n,...], cmap=f'sdo{channel}', vmin=0, vmax=lim)
                font_color = 'w'

            else:
                font_color = 'k'
                lim = 1000  # symmetric color limits for HMI channels
                # "_v" channels use a diverging map; other HMI channels use the magnetogram map
                cmap = 'coolwarm' if "_v" in channel else 'hmimag'
                ax.imshow(unnormalized_ts[n,...], cmap=cmap, vmin=-lim, vmax=lim)

            ax.text(0.01, 0.99, channel, transform=ax.transAxes, horizontalalignment='left', verticalalignment='top', color=font_color, fontsize=5)  
            ax.set_xticks([])
            ax.set_yticks([])             

## Define dataloader

Now the dataloader also returns a Surya stack alongside our flare data.

In [None]:
data_loader = DataLoader(
                dataset=train_dataset,
                batch_size=2
            )

In [None]:
batch = next(iter(data_loader))
batch.keys()

And now it will also have a batch dimension of 2 (batch dimensions are typically the leftmost dimension in a tensor's shape)

In [None]:
batch['ts'].shape

## Conclusions

Once this notebook runs successfully, you have a dataset and dataloaders ready to train a DS application. The next step continues in 1_baseline_template.ipynb, which puts together a simple baseline, including metrics and a training loop, that can be compared with Surya.