In [9]:
import torch
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import RichProgressBar, Timer

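# Add the primesw package directory (home of the `data` and `prime_torch` modules) to sys.path so they can be imported.
# NOTE: this is an absolute, user-specific path; adjust it when running elsewhere.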
import sys
sys.path.append("/glade/u/home/cobrien/prime/prime_lib/primesw")
from data import SWDataset, SWDataModule
from prime_torch import crps, SWRegressor

# Directory holding the combined MMS/Wind HDF5 store used below (tilde left unexpanded; trailing slash kept for concatenation).
DATAPATH = "~/data/"

In [None]:
# Load the 1-minute combined MMS/Wind table for ad-hoc inspection.
# NOTE(review): SWDataModule (next cell) reads the same store itself through its
# `datastore`/`key` arguments, so `data` is not consumed by the training pipeline
# shown here. Confirm no other cell needs it; if none does, this load only holds
# a second copy of the table in memory.
data = pd.read_hdf(DATAPATH + 'combined_data.h5', key='1min_mms_wind')

In [None]:
# Regression targets: MMS1 in-situ plasma and field measurements.
target_features = [
    'mms1_dis_bulkv_gse_fast_0', # V GSE X
    'mms1_dis_bulkv_gse_fast_1', # V GSE Y
    'mms1_dis_bulkv_gse_fast_2', # V GSE Z
    # 'mms1_dis_numberdensity_fast', # Ni
    # 'mms1_dis_temppara_fast', # Ti parallel to B
    # 'mms1_dis_tempperp_fast', # Ti perpendicular to B
    'mms1_des_numberdensity_fast', # Ne
    'mms1_fgm_b_gse_srvy_l2_0', # B GSE X
    'mms1_fgm_b_gse_srvy_l2_1', # B GSE Y
    'mms1_fgm_b_gse_srvy_l2_2', # B GSE Z
]
# Model inputs: Wind measurements, including Wind's own GSE position.
input_features = [
    'Np', # Ni
    'V_GSE_0', # V GSE X
    'V_GSE_1', # V GSE Y
    'V_GSE_2', # V GSE Z
    'THERMAL_SPD', # Vth
    'BGSE_0', # B GSE X
    'BGSE_1', # B GSE Y
    'BGSE_2', # B GSE Z
    'PGSE_0', # Wind Position GSE X
    'PGSE_1', # Wind Position GSE Y
    'PGSE_2', # Wind Position GSE Z
]
# MMS1 GSE position, passed to the datamodule separately as `position_features`.
position_features = [
    'mms1_mec_r_gse_0', # MMS Position GSE X
    'mms1_mec_r_gse_1', # MMS Position GSE Y
    'mms1_mec_r_gse_2', # MMS Position GSE Z
]

# Consecutive one-day splits: train 2015-10-01, test 2015-10-02, validation 2015-10-03.
trn_bounds = ['20151001 00:00:00+0000', '20151002 00:00:00+0000']
tst_bounds = ['20151002 00:00:00+0000', '20151003 00:00:00+0000']
val_bounds = ['20151003 00:00:00+0000', '20151004 00:00:00+0000']

# Fixed: the `region` argument previously lacked a trailing comma, which made
# this cell a SyntaxError. The datastore path is now derived from DATAPATH so
# it stays consistent with the load in the previous cell (same value as before).
datamodule = SWDataModule(
    target_features=target_features,
    input_features=input_features,
    position_features=position_features,
    region='solar wind',
    cadence='100s',
    window=100,
    stride=10,
    interp_frac=0.1,
    trn_bounds=trn_bounds,
    val_bounds=val_bounds,
    tst_bounds=tst_bounds,
    datastore=DATAPATH + 'combined_data.h5',
    key='1min_mms_wind',
)
# Build the datasets now so `datamodule.trn_ds` can be inspected in the next cell.
datamodule.setup()

In [8]:
# Inspect the `raw_data` table held by the training dataset (pandas truncates the display).
trn_raw = datamodule.trn_ds.raw_data
trn_raw

Unnamed: 0,Epoch,probe,ratio_max_width,ratio_high_low,norm_Btot,small_energy_mean,large_energy_mean,temp_total,r_gse_x,r_gse_y,...,BGSE_0,BGSE_1,BGSE_2,PGSM_0,PGSM_1,PGSM_2,PGSE_0,PGSE_1,PGSE_2,stable
22819,2015-10-01 05:13:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,3637.4941,43061.871405,21763.442406,...,-0.341594,0.418592,-3.803144,260.803131,5.141964,15.564111,260.803131,-1.799515,16.292426,1.0
22820,2015-10-01 05:14:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,3566.7576,43135.473788,21903.821297,...,-0.228631,0.707085,-3.854215,260.803146,5.151292,15.561420,260.803146,-1.802009,16.292527,1.0
22821,2015-10-01 05:15:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,3435.2864,43208.623544,22043.967322,...,-0.404843,0.621456,-3.872930,260.803162,5.160608,15.558728,260.803162,-1.804501,16.292628,1.0
22822,2015-10-01 05:16:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,3482.3137,43281.143842,22183.788352,...,-0.551774,0.643673,-3.860786,260.803162,5.169910,15.556034,260.803162,-1.806993,16.292729,1.0
22823,2015-10-01 05:17:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,3556.3467,43353.217158,22323.377268,...,-0.553716,0.758813,-3.845825,260.803177,5.179204,15.553338,260.803177,-1.809485,16.292830,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23530,2015-10-01 19:23:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,4451.7060,18301.807012,60707.123644,...,1.132437,11.481940,-0.905233,260.805695,1.028696,16.805570,260.805695,-3.912936,16.376032,1.0
23531,2015-10-01 19:24:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,6319.1465,18212.489010,60652.335301,...,0.787784,11.743395,-1.658357,260.805695,1.016980,16.806953,260.805695,-3.915417,16.376127,1.0
23532,2015-10-01 19:25:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,6764.8440,18123.084969,60597.265617,...,0.844871,11.864328,-1.840551,260.805695,1.005292,16.808329,260.805695,-3.917898,16.376223,1.0
23533,2015-10-01 19:26:00+00:00,mms1,0.0,0.0,0.0,0.0,0.0,6663.0464,18033.560499,60541.800894,...,0.874529,11.824957,-1.690531,260.805679,0.993642,16.809692,260.805679,-3.920378,16.376319,1.0


In [None]:
# Seed python's `random`, numpy and torch before building the model so the
# weight initialisation is reproducible every time this cell is re-run.
SEED = 42
pl.seed_everything(SEED)

# Small RNN-encoder / linear-decoder regressor; dimensions follow the feature lists above.
model = SWRegressor(
    in_dim=len(input_features),
    tar_dim=len(target_features),
    pos_dim=len(position_features),
    encoder_type='rnn',
    encoder_hidden_dim=4,
    encoder_num_layers=1,
    pos_encoding_size=4,
    decoder_type='linear',
    decoder_hidden_layers=[4],
    lr_scheduler='cosine',
    loss='mae',
)

In [5]:
# Single-epoch, CPU-only trainer with Lightning's Timer and RichProgressBar callbacks.
# Optional: pass precision='16-true' to lower the precision and reduce memory use.
trainer_callbacks = [Timer(), RichProgressBar()]
trainer = pl.Trainer(
    max_epochs=1,
    accelerator='cpu',
    callbacks=trainer_callbacks,
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/glade/work/cobrien/conda-envs/pt212gpu_conda/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
/glade/work/cobrien/conda-envs/pt212gpu_conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip inst

In [6]:
# Fit `model` using the dataloaders provided by `datamodule`.
# NOTE(review): the last recorded run of this cell raised
# "RuntimeError: CUDA unknown error" even though accelerator='cpu' was requested;
# this looks like an environment issue (e.g. CUDA_VISIBLE_DEVICES changed after start).
trainer.fit(model, datamodule=datamodule)

[32m2025-09-15 06:13:04.561[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m229[0m - [1mTrain dataloader is ready. Dataset size: 607[0m


[32m2025-09-15 06:13:04.678[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m247[0m - [1mValidation dataloader is ready. Dataset size: 738[0m
[32m2025-09-15 06:13:04.795[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m265[0m - [1mTest dataloader is ready. Dataset size: 738[0m


RuntimeError: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero.

In [None]:
# Smoke test: push one random batch through the model and evaluate the loss.
# Shapes are derived from the feature lists so they track the model config.
# The previous hard-coded 14 input features and 1 target no longer matched
# in_dim=11 / tar_dim=7.
# NOTE(review): assumes forward(x, pos) takes x as (batch, window, in_dim) and
# pos as (batch, pos_dim), and returns (batch, tar_dim). Confirm in SWRegressor.
SMOKE_BATCH = 50
SMOKE_WINDOW = 100  # same value as the datamodule's `window`

torch.manual_seed(0)
in_test = torch.rand((SMOKE_BATCH, SMOKE_WINDOW, len(input_features)))
tar_test = torch.rand((SMOKE_BATCH, len(target_features)))
pos_test = torch.rand((SMOKE_BATCH, len(position_features)))

# Call the module (not .forward) so nn.Module hooks run; no autograd graph is needed here.
with torch.no_grad():
    out_test = model(in_test, pos_test)
    smoke_loss = model.loss_fn(out_test, tar_test)
smoke_loss

torch.Size([50, 100, 128])