In [None]:
%matplotlib widget
from astrotime.loaders.MIT import MITElementLoader
from astrotime.encoders.spectral import SpectralProjection, embedding_space
from astrotime.models.cnn.cnn_baseline import get_model_from_cfg
from astrotime.plot.analysis import RawDatasetPlot
from astrotime.plot.analysis import ClassificationEvalPlot
from astrotime.config.context import astrotime_initialize
from astrotime.plot.base import SignalPlotFigure
from astrotime.trainers.octave_classification import OctaveClassificationTrainer
from torch import nn
from astrotime.util.series import TSet
import torch
from hydra import initialize, compose

# Name of the Hydra config composed in the next cell; the same string is reused
# below as the label passed to astrotime_initialize / init_eval and in the plot title.
version = "MIT_period_cnn.classification.octaves"

In [None]:
# Hydra keeps a process-wide singleton: calling initialize() a second time
# (e.g. re-running this cell without restarting the kernel) raises
# "GlobalHydra is already initialized". Clearing it first makes the cell
# safe to re-execute; on a fresh kernel the clear() is a no-op.
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
initialize(version_base=None, config_path="../../../config")

# Compose the configuration selected by `version`.
cfg = compose(config_name=version)

# Project-level setup; the returned torch device is used by the cells below.
device: torch.device = astrotime_initialize(cfg, version+".oplot")

## Plot Documentation
- In the following plot, the black lines represent the target period, the green lines represent the period produced by the model, and the yellow line represents the output of the peakfinder algorithm.
- The vertical lines in the upper plot illustrate the sizes of the periods for two cases.
- Use the file/element sliders to choose different data elements.
- If you shift-right-click on an extremum of the upper plot, it will align the period markers to that location.

In [None]:
# --- Data: loader positioned on the training set, plus the raw-data panel ---
data_loader = MITElementLoader(cfg.data)
data_loader.init_epoch(TSet.Train)
dplot = RawDatasetPlot(f"{version}: Lightcurves", data_loader)

# --- Model: spectral embedding feeding the model built from the config ---
embedding_space_array, embedding_space_tensor = embedding_space(cfg.transform, device)
embedding = SpectralProjection(cfg.transform, embedding_space_tensor, device)
model: nn.Module = get_model_from_cfg(cfg, embedding).to(device)

# --- Evaluation: the trainer is used here in evaluation mode only ---
evaluator = OctaveClassificationTrainer(cfg, device, data_loader, model, embedding)
evaluator.init_eval(version)
wplot = ClassificationEvalPlot("Lightcurve Period Analysis, model=spectral_cnn", evaluator)

# --- Display: combine both panels into a single interactive figure ---
fig = SignalPlotFigure([dplot, wplot])
fig.show()