Skip to content

Sample-adaptive regression collection + multi-output adaptive linear regressor#11

Merged
cboulay merged 16 commits into
devfrom
sample_adapt_regress
May 14, 2026
Merged

Sample-adaptive regression collection + multi-output adaptive linear regressor#11
cboulay merged 16 commits into
devfrom
sample_adapt_regress

Conversation

@cboulay
Copy link
Copy Markdown
Member

@cboulay cboulay commented May 14, 2026

Summary

Adds a new SampleAdaptRegressor collection that wires resampling, trigger-driven sequence-to-sequence sampling, optional time-lag windowing, and an adaptive linear regressor into a single reusable graph. Along the way the underlying adaptive regressor gains multi-output river support, a new SeqSeqSampler unit and a learn-side Flatten wrapper are introduced, and SelfSupervisedRegression/LRR get an automatic block-based channel clustering option plus an identity pass-through before fitting.

Changes

New

  • ezmsg.learn.collection.SampleAdaptRegressor (src/ezmsg/learn/collection/sample_adapt_regressor.py): collection composing ResampleUnit, SeqSeqSamplerUnit, optional Window + Flatten, and AdaptiveLinearRegressorUnit. Routes labels through resampling against the signal time base, drives training from triggers, and switches between direct inference and windowed/flattened inference based on decode_window_dur.
  • ezmsg.learn.process.seqseqsampler (src/ezmsg/learn/process/seqseqsampler.py): buffers continuous signal and value streams keyed on a shared time base, then emits per-trigger AxisArray samples whose attrs["trigger"].value carries the matching label window. Handles off-by-one mismatches by picking the best-aligned subwindow (bounded by MAX_ONE_TO_ONE_SAMPLE_MISMATCH) and drops/defers triggers that fall outside the buffered span. Includes max_buffer_dur cap on per-stream memory.
  • ezmsg.learn.process.flatten (src/ezmsg/learn/process/flatten.py): thin wrapper around ezmsg.sigproc.flatten.FlattenTransformer that detects the (win, time, ch) windowed-feature case and injects a structured lag CoordinateAxis (with integer lag and "t-i" label fields) into the inner sample dim before delegating, so the merged-axis struct in the output carries real lag metadata ("t-2/c0" style labels). Outside the lag case it delegates unchanged.

AdaptiveLinearRegressor

  • Multi-output river path: when the trigger target has multiple channels the transformer builds a per-label dict of river GLMs and routes learn_many/predict_many per output, preserving label order from the target axis.
  • Loads pickled multi-output models (dicts of river GLMs), applies l2/optimizer overrides across all sub-models, and deep-copies optimizers so each sub-model has independent state.
  • New _normalize_axis_label helper handles structured-dtype channel axes (extracting "label" sub-fields and unwrapping numpy scalars) so dataframe construction and prediction-template labels work with composed axes coming out of the new Flatten.
  • Inference now runs as a pass-through (returns empty AxisArray) until a model exists rather than requiring template to be set first, and builds a prediction template from the signal when no fit-time template is available.

SelfSupervisedRegression / LRR (src/ezmsg/learn/process/ssr.py)

  • New block_size setting on SelfSupervisedRegressionSettings: when channel_clusters is None, auto-generates contiguous block clusters of this size. Threaded through _validate_clusters, the SSR compute path, and LRRTransformer.
  • LRRTransformer._process no longer raises before fit — it lazily initialises an identity AffineTransformTransformer (respecting block_size/min_cluster_size) so signal flows unchanged until weights are learned.

Dependency bump and API updates

  • pyproject.toml: ezmsg>=3.9.0, ezmsg-baseproc>=1.7.0, ezmsg-sigproc>=2.23.0.
  • Drop zero_copy=True from RNNUnit, TorchSimpleUnit, TorchModelUnit, and TransformerUnit signal subscribers to match the new base interface.
  • tests/integration/conftest.py: NoiseSrc.OUTPUT_SIGNAL switched from ez.OutputStream to ez.OutputTopic.

Tests

  • New tests/unit/test_sample_adapt_regressor.py, tests/unit/test_seqseqsampler.py, tests/unit/test_flatten.py.
  • Expanded tests/unit/test_adaptive_linear_regressor.py covering multi-output river training/prediction, structured-axis label handling, and pre-fit pass-through.
  • Expanded tests/unit/test_ssr.py covering block_size clustering and LRR identity pass-through before fit.

@cboulay cboulay merged commit d99db11 into dev May 14, 2026
8 checks passed
@cboulay cboulay deleted the sample_adapt_regress branch May 14, 2026 23:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants