Ridge probe: balanced class weighting + configurable dtype #32
Merged
Adds a `class_weight` parameter (default `"balanced"`) to the streaming ridge probe. When enabled, an extra label-only pass over the train loader computes sklearn-style per-class weights `w[c] = N / (n_classes * count[c])`, which are then applied to every sufficient-statistic accumulator (`A`, `B`, `s_h`, `s_h2`, `s_y`, `N`) so that the weighted-least-squares fit, weighted centering, and weighted standardization are all internally consistent. Regression silently ignores the parameter.

Verified end-to-end on REVE × chbmit (97/3 imbalance):

- unweighted: `test_balanced_accuracy = 0.5000` (chance)
- balanced: `test_balanced_accuracy = 0.8594`
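The weighting scheme can be sketched as below. The helper names (`balanced_class_weights`, `accumulate_batch`) and the dict-of-tensors layout are illustrative, not the PR's actual internals; the point is that the same per-sample weight multiplies every sufficient statistic, including the effective sample count `N`.

```python
import torch

def balanced_class_weights(labels: torch.Tensor) -> torch.Tensor:
    # sklearn-style "balanced" weights, w[c] = N / (n_classes * count[c]),
    # computable in a single label-only pass over the train loader.
    classes, counts = labels.unique(return_counts=True)
    w = labels.numel() / (classes.numel() * counts.double())
    lookup = torch.zeros(int(classes.max()) + 1, dtype=torch.float64)
    lookup[classes] = w
    return lookup  # index with a label to get its per-sample weight

def accumulate_batch(stats: dict, h: torch.Tensor, y: torch.Tensor,
                     w: torch.Tensor) -> dict:
    # Apply the per-sample weights w (shape (B,)) to every sufficient
    # statistic so the weighted fit, weighted centering, and weighted
    # standardization all see the same weighting.
    wh = w[:, None] * h
    stats["A"] += wh.T @ h                             # weighted Gram matrix
    stats["B"] += wh.T @ y                             # weighted cross-moment
    stats["s_h"] += wh.sum(dim=0)                      # weighted feature sums
    stats["s_h2"] += (w[:, None] * h * h).sum(dim=0)   # weighted squared sums
    stats["s_y"] += (w[:, None] * y).sum(dim=0)        # weighted target sums
    stats["N"] += w.sum()                              # effective sample count
    return stats
```

With `class_weight=None`, `w` is simply all ones and the accumulators reduce to the unweighted statistics.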
Adds a `dtype: Literal["float32", "float64"]` parameter (default `"float64"`) threaded through `RidgeProbingTraining`, `StreamingRidgeProbeLearner`, and `_fit_streaming_ridge`. All previously hardcoded float64 accumulators, eigendecomposition tensors, and predict paths now honor this dtype. `"float64"` remains the recommended precision; `"float32"` exists for devices that don't support double, notably Apple MPS.

To make MPS actually work, the eigendecomposition + Ws/biases construction (operating on `(D, D)` matrices with `D ≤ max_features`) now runs on CPU unconditionally: `torch.linalg.eigh` is not implemented on MPS, and these matrices are small enough that the CPU detour is free. The streaming backbone forward and statistics accumulation stay on the configured device.

Verified REVE × chbmit on MPS with `class_weight="balanced"`, `dtype="float32"`:

- `test_balanced_accuracy = 0.8594` (matches CPU/float64)
- fit_time 651 s (vs. 1235 s on CPU)
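The CPU detour for the solve step can be sketched as follows. The function name and the regularization argument `lam` are illustrative assumptions, but the pattern matches the description above: move the small `(D, D)` statistics to CPU float64, eigendecompose there, and move the resulting weights back to the configured device and dtype.

```python
import torch

def solve_ridge_via_cpu_eigh(A: torch.Tensor, B: torch.Tensor,
                             lam: float) -> torch.Tensor:
    # A: (D, D) Gram matrix, B: (D, K) cross-moment, D <= max_features.
    # torch.linalg.eigh is not implemented on MPS, and for small D the
    # CPU round-trip is essentially free, so the decomposition always
    # runs on CPU in float64 regardless of the configured device/dtype.
    A_cpu = A.to("cpu", torch.float64)
    B_cpu = B.to("cpu", torch.float64)
    evals, evecs = torch.linalg.eigh(A_cpu)
    # (A + lam*I)^{-1} B computed in the eigenbasis:
    # V @ diag(1 / (evals + lam)) @ V.T @ B
    W = evecs @ ((evecs.T @ B_cpu) / (evals + lam)[:, None])
    # Move the solution back to the caller's device and dtype
    # (e.g. MPS/float32), leaving only this step on CPU.
    return W.to(A.device, A.dtype)
```

Since `eigh` is reused across the regularization path, solving in the eigenbasis also lets every candidate `lam` share one decomposition; only the cheap diagonal rescale changes per value.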
CC @tomMoral
tomMoral reviewed Apr 28, 2026
Summary
- Adds `class_weight` to `RidgeProbingTraining` (`"balanced"` or `None`); default changed to `"balanced"` so imbalanced classification datasets are no longer collapsed to majority-class predictions.
- Adds `dtype` to `RidgeProbingTraining` (`"float32"` or `"float64"`, default `"float64"`). `"float32"` enables running on devices without double support, notably Apple MPS.
- Eigendecomposition and weight construction now run on CPU (the matrices are at most `max_features × max_features`, so the detour is free); `torch.linalg.eigh` is not implemented on MPS. Streaming forward + accumulation stay on the configured device.

Validation: REVE × chbmit (97/3 imbalance)
| `class_weight` | test_balanced_accuracy |
| --- | --- |
| `None` (previous default) | 0.5000 |
| `"balanced"` (new default) | 0.8594 |
| `"balanced"` (MPS, `dtype="float32"`) | 0.8594 |

Test plan
- `pytest tests/test_ridge_probe.py` (14/14 pass, including new `test_balanced_class_weight_recovers_minority` and `test_balanced_class_weight_noop_when_classes_balanced`)
- `pytest tests/test_default_configs.py` (44/44 pass)
- End-to-end REVE × chbmit runs with both `class_weight` settings
- MPS run with `class_weight="balanced"`
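The invariant behind the balanced-weight tests can be sketched as a simplified property check (this is not the repository's actual test code): under `w[c] = N / (n_classes * count[c])`, each class's total weight is exactly `N / n_classes`, so the minority class contributes as much mass to the weighted fit as the majority.

```python
import torch

def class_weight_totals(labels: torch.Tensor) -> torch.Tensor:
    # Total weight contributed by each class under the "balanced" scheme.
    classes, counts = labels.unique(return_counts=True)
    w = labels.numel() / (classes.numel() * counts.double())
    return w * counts  # every entry should equal N / n_classes

# 97/3 imbalance, mirroring the chbmit validation split in spirit:
labels = torch.tensor([0] * 97 + [1] * 3)
totals = class_weight_totals(labels)
assert torch.allclose(totals, torch.full_like(totals, 50.0))  # 100 / 2 classes
```

When the classes are already balanced, all weights are 1 and the weighted statistics coincide with the unweighted ones, which is the no-op property the second new test checks.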