In [1]:
import torch
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import RichProgressBar, Timer

# Make the project's `primesw` directory importable. The `data` and
# `prime_torch` modules imported below are expected to live in that directory,
# so the path must be appended before those imports run.
# NOTE(review): this is a hard-coded, user-specific absolute path; anyone else
# running the notebook has to point it at their own checkout.
import sys
sys.path.append("/glade/u/home/cobrien/prime/prime_lib/primesw")
from data import SWDataset, SWDataModule
from prime_torch import crps, SWRegressor

In [2]:
# Build a synthetic, perfectly linear dataset with one row every 100 seconds,
# so that the target is a known function of the input.
timestamps = pd.date_range(
    start=pd.to_datetime('20150902 00:00:00+0000'),
    end=pd.to_datetime('20250101 00:00:00+0000'),
    freq='100s',
)
row_number = np.arange(len(timestamps))

test_dataframe = pd.DataFrame({
    'time': timestamps,
    'a': row_number,            # Fake input
    'b': row_number * 2,        # Fake target: exactly twice the input
    'x_pos': row_number * 0.1,  # Fake position components, each a linear ramp
    'y_pos': row_number * 0.2,
    'z_pos': row_number * 0.3,
})

# Persist the frame so it can be read back from this file/key later on.
test_dataframe.to_hdf("~/data/prime/test.h5", key='lineartest')

In [3]:
# Preview the rows whose timestamp falls inside the window below.
# Both edges are inclusive (`between` defaults to a closed interval).
bounds = ['20151003 00:00:00+0000', '20151004 00:00:00+0000']
window_start, window_end = (pd.to_datetime(edge) for edge in bounds)
in_window = test_dataframe['time'].between(window_start, window_end)
test_dataframe[in_window]

Unnamed: 0,time,a,b,x_pos,y_pos,z_pos
26784,2015-10-03 00:00:00+00:00,26784,53568,2678.4,5356.8,8035.2
26785,2015-10-03 00:01:40+00:00,26785,53570,2678.5,5357.0,8035.5
26786,2015-10-03 00:03:20+00:00,26786,53572,2678.6,5357.2,8035.8
26787,2015-10-03 00:05:00+00:00,26787,53574,2678.7,5357.4,8036.1
26788,2015-10-03 00:06:40+00:00,26788,53576,2678.8,5357.6,8036.4
...,...,...,...,...,...,...
27644,2015-10-03 23:53:20+00:00,27644,55288,2764.4,5528.8,8293.2
27645,2015-10-03 23:55:00+00:00,27645,55290,2764.5,5529.0,8293.5
27646,2015-10-03 23:56:40+00:00,27646,55292,2764.6,5529.2,8293.8
27647,2015-10-03 23:58:20+00:00,27647,55294,2764.7,5529.4,8294.1


In [4]:
# One-day date windows for the train / test / validation splits.
# NOTE(review): consecutive windows share an endpoint; whether the boundary
# timestamp lands in both splits depends on how SWDataModule treats the edges.
trn_bounds = ['20151001 00:00:00+0000', '20151002 00:00:00+0000']
tst_bounds = ['20151002 00:00:00+0000', '20151003 00:00:00+0000']
val_bounds = ['20151003 00:00:00+0000', '20151004 00:00:00+0000']

datamodule_kwargs = dict(
    # Where the synthetic frame was written.
    datastore="~/data/prime/test.h5",
    key="lineartest",
    # Which columns play which role.
    input_features=['a'],
    target_features=['b'],
    position_features=['x_pos', 'y_pos', 'z_pos'],
    # Sampling settings; '100s' is the spacing the synthetic frame was generated with.
    # window / stride / interp_frac are passed straight through to SWDataModule.
    cadence='100s',
    window=100,
    stride=10,
    interp_frac=0.1,
    # Split windows.
    trn_bounds=trn_bounds,
    val_bounds=val_bounds,
    tst_bounds=tst_bounds,
)
datamodule = SWDataModule(**datamodule_kwargs)

# Run setup by hand so the dataset sizes are logged before any training starts.
datamodule.setup()

[32m2025-09-12 09:33:49.218[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m229[0m - [1mTrain dataloader is ready. Dataset size: 756[0m
[32m2025-09-12 09:33:49.307[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m247[0m - [1mValidation dataloader is ready. Dataset size: 756[0m
[32m2025-09-12 09:33:49.396[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m265[0m - [1mTest dataloader is ready. Dataset size: 756[0m


In [5]:
# A deliberately tiny regressor for the synthetic linear dataset.
model = SWRegressor(
    # Feature dimensions -- presumably one input ('a'), one target ('b') and
    # three position components; confirm against SWRegressor.
    in_dim=1,
    tar_dim=1,
    pos_dim=3,
    pos_encoding_size=4,
    # Encoder
    encoder_type='rnn',
    encoder_hidden_dim=4,
    encoder_num_layers=1,
    # Decoder
    decoder_type='linear',
    decoder_hidden_layers=[4],
    # Optimisation
    loss='mae',
    lr_scheduler='cosine',
)



In [6]:
# Single-epoch CPU run: a quick check that training executes, not a real fit.
trainer_callbacks = [Timer(), RichProgressBar()]
trainer = pl.Trainer(
    max_epochs=1,
    accelerator='cpu',
    callbacks=trainer_callbacks,
    # precision='16-true',  # Lower the precision to not blow up memory
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/glade/work/cobrien/conda-envs/pt212gpu_conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [7]:
# Train the model on the synthetic data supplied by the datamodule.
trainer.fit(model, datamodule=datamodule)

[32m2025-09-12 09:33:50.362[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m229[0m - [1mTrain dataloader is ready. Dataset size: 756[0m
[32m2025-09-12 09:33:50.449[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m247[0m - [1mValidation dataloader is ready. Dataset size: 756[0m
[32m2025-09-12 09:33:50.536[0m | [1mINFO    [0m | [36mdata[0m:[36msetup[0m:[36m265[0m - [1mTest dataloader is ready. Dataset size: 756[0m


Output()

`Trainer.fit` stopped: `max_epochs=1` reached.


In [None]:
# Smoke-test the model's forward pass and loss on random tensors.
# The trailing dimensions mirror the SWRegressor built earlier in this notebook
# (in_dim=1, pos_dim=3, tar_dim=1) and the datamodule's window of 100 steps.
# The earlier version fed 14 input features, which does not match in_dim=1.
# Assumes forward takes (batch, window, in_dim) inputs and (batch, pos_dim)
# positions, following the layout of the original tensors -- TODO confirm.
BATCH_SIZE = 50
WINDOW = 100
IN_DIM, POS_DIM, TAR_DIM = 1, 3, 1

in_test = torch.rand((BATCH_SIZE, WINDOW, IN_DIM))
tar_test = torch.rand((BATCH_SIZE, TAR_DIM))
pos_test = torch.rand((BATCH_SIZE, POS_DIM))

# Call the module itself rather than .forward() so registered hooks run.
out_test = model(in_test, pos_test)
model.loss_fn(out_test, tar_test)

torch.Size([50, 100, 128])