In [None]:
%matplotlib widget
import torch
from torch import nn
from hydra import initialize, compose
from hydra.core.global_hydra import GlobalHydra

from astrotime.config.context import astrotime_initialize
from astrotime.encoders.spectral import SpectralProjection, embedding_space
from astrotime.loaders.sinusoid import SinusoidElementLoader
from astrotime.models.cnn.cnn_baseline import get_model_from_cfg
from astrotime.plot.analysis import RawDatasetPlot
from astrotime.plot.analysis import EvaluatorPlot
from astrotime.plot.base import SignalPlotFigure
from astrotime.trainers.iterative_trainer import IterativeTrainer
from astrotime.util.series import TSet

# Notebook selectors.
#   version: Hydra config name; also passed to astrotime_initialize and used in plot titles.
#   mtype:   used only as a label in the evaluator plot title.
version, mtype = "sinusoid_period", "cnn"

In [None]:
# Compose the Hydra config named by `version`, then initialise astrotime with it.
# Overrides: 'platform.gpu=-1' (presumably selects CPU - confirm in config/platform)
# and a batch size of 1 for per-element plotting.
overrides = [ 'platform.gpu=-1', 'data.batch_size=1' ]
# hydra's `initialize` raises if Hydra is already initialised, so re-running this
# cell in the same kernel would fail; clearing the global instance first makes the
# cell idempotent while leaving Hydra initialised afterwards, exactly as before.
GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../../../config")
cfg = compose( config_name=version, overrides=overrides )
device: torch.device = astrotime_initialize(cfg,version+".plot")

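Documentation

Loads the `sinusoid_period` configuration and builds a spectral-projection embedding and a CNN period model. It then shows the raw training lightcurves next to the model's period analysis in a single interactive figure.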


In [None]:
# Build the sinusoid lightcurve loader (Train split, use_batches=False) and its raw-data plot panel.
# The same loader instance is shared by the raw plot and the evaluator below.
data_loader = SinusoidElementLoader( cfg.data, TSet.Train, use_batches=False )
dplot = RawDatasetPlot( f"{version}: Lightcurves", data_loader )

# Spectral embedding: embedding_space() returns two values; only the second
# (embedding_space_tensor) is passed on to SpectralProjection here.
embedding_space_array, embedding_space_tensor = embedding_space(cfg.transform, device)
embedding = SpectralProjection( cfg.transform, embedding_space_tensor, device )
# Model is built from the config around the embedding, then moved to `device`.
model: nn.Module = get_model_from_cfg( cfg,  embedding ).to(device)

# The trainer is used here for evaluation/plotting only; no training is run in this cell.
evaluator = IterativeTrainer( cfg, device, data_loader, model, embedding )
# NOTE(review): `init_evel` looks like a possible typo of `init_eval`; confirm the method
# name on IterativeTrainer before renaming anything.
evaluator.init_evel(version)
wplot = EvaluatorPlot( f"Lightcurve Period Analysis, model=spectral_{mtype}", evaluator )

# Combine the raw-data and evaluator panels into one figure and display it.
fig = SignalPlotFigure( [dplot, wplot] )
fig.show()