v1.4.0

github-actions released this 13 Mar 16:52

Highlights:

  • new kd.cli tool

  • py:: flag syntax for konfig

  • expanded kd.ktyping with @typechecked context managers / dataclass /
    generator support

  • meta-configs with @dataclasses.dataclass

  • new Nnx wrapper API

  • many quality-of-life improvements across data, evals, metrics, and checkpointing.

  • kd.cli:

    • [New] kd.cli: New kauldron CLI tool — a :trainer_cli binary
      automatically created by kauldron_binary, using noun-verb command style
      (e.g. kd data element_spec), with each command mirrored as a Python
      function.
  • kd.konfig:

    • [New] py:: flag value parsing — specify Python objects directly from CLI
      flags (e.g. --cfg.xxx="py::my_module.MyObject(x=1)"), with Lark grammar
      and alias resolution.
    • [New] Meta-configs: @dataclasses.dataclass-based config declaration with
      __args__ CLI overrides and lazy config building.
    • [New] konfig.export(): serialize Python objects (dataclasses, arrays,
      dicts) to a dict representation.
    • [New] unfreeze() function for unfreezing ImmutableDicts.
    • [Extended] Support (named) tuples as dictionary keys in serialization.
    • [Extended] DEFINE_config_file accepts a required argument.
    • [Extended] konfig.resolve highlights where the original ConfigDict was
      created in tracebacks.
    • [Extended] Better error messages for config resolution failures and
      FieldReference with path tracking.
    • [Extended] Allow konfig.restricted() without specifying type.
    • [Changed] Always use literal evals when parsing flags — --cfg.xxx=None is
      now None rather than 'None'.
    • [Changed] Two-stage resolution is now the default for DEFINE_config_file.
    • [Changed] Deprecated konfig property now raises an error.
    • [Fix] Fix JSON args parsing, list arguments from CLI, unfreeze bug,
      dynamic resolve trigger, temporary_imports() thread-safety, BaseConfig
      hash, and resolve errors.
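The literal-eval change above can be illustrated with a minimal, standard-library sketch. This is not konfig's actual parser: the helper name `parse_flag_value` is hypothetical, and the fallback to the raw string for non-literals is an assumption made for illustration.

```python
import ast


def parse_flag_value(raw: str):
    """Parse a raw CLI flag string as a Python literal when possible.

    Hypothetical helper for illustration only, not part of konfig.
    """
    try:
        # literal_eval only accepts literals (None, numbers, strings,
        # lists, dicts, ...), so it never executes arbitrary code.
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError):
        # Assumption: anything that is not a valid literal stays a string.
        return raw
```

Under this scheme, a flag value of `None` parses to the `None` object, while an explicitly quoted `'None'` stays a string.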
  • kd.ktyping:

    • [New] ArraySpec, ElementSpec, PRNGKeyLike types.
    • [New] kt.isinstance function for bool-returning type checks.
    • [New] Basic PyTree[T] annotation with runtime checking and path-aware
      errors.
    • [New] PyTree structure specs.
    • [New] Per-module config system.
    • [New] Warnings when mixing ktyping and jaxtyping.
    • [Extended] @typechecked now supports: context managers
      (with typechecked():), nested context managers, dataclasses, generator
      functions, and methods / class methods / static methods.
    • [Extended] Improved shape inference for binary operations (e.g.
      Array["a+1"]).
    • [Extended] TensorFlow and XArray type support.
    • [Breaking] Rename get_shape() to shape() and kt.dims to kt.dim.
    • [Fix] Fix shape checking with TF Tensors, PRNGKey dtype for new-style JAX
      keys, Scalar type checking, array type union checking, broadcastable dims,
      typeguard 4.5.0 compatibility.
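To show what "path-aware errors" for a pytree check can look like, here is a small pure-Python sketch. It does not use ktyping and does not reflect its implementation; `check_leaves` is a hypothetical helper that walks nested dicts, lists, and tuples and reports the path of the first leaf with the wrong type.

```python
def check_leaves(tree, leaf_type, path="tree"):
    """Raise TypeError naming the path of the first leaf that is not `leaf_type`.

    Hypothetical illustration: only dicts, lists, and tuples are treated
    as containers; everything else is a leaf.
    """
    if isinstance(tree, dict):
        for key, child in tree.items():
            check_leaves(child, leaf_type, f"{path}[{key!r}]")
    elif isinstance(tree, (list, tuple)):
        for index, child in enumerate(tree):
            check_leaves(child, leaf_type, f"{path}[{index}]")
    elif not isinstance(tree, leaf_type):
        raise TypeError(
            f"{path}: expected {leaf_type.__name__}, got {type(tree).__name__}"
        )
```

Including the path in the message points directly at the offending leaf instead of only reporting that the tree as a whole failed the check.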
  • kd.nn:

    • [Changed] New Nnx wrapper API — natively compatible with kontext keys,
      supports catching intermediates.
    • [New] Nnx wrapper documentation.
  • kd.data:

    • [New] LazyBagDataSource for lazy loading of bag data.
    • [New] SelectFromDatasets for dataset mixtures with user-defined selection.
    • [New] shard_by_process to control dataset sharding behavior.
    • [New] AddBias transform.
    • [New] Random transforms for PyGrain pipelines.
    • [New] Padding batches with batch_drop_remainder='pad'.
    • [Extended] CenterCrop supports nD arrays.
    • [Extended] Resize supports min/max size targets.
    • [Extended] RepeatFrames works with both TF and NumPy/JAX arrays.
    • [Changed] Default Resize method for float inputs is now bilinear for
      JAX/NumPy (remains area for TF).
    • [Fix] Fix ElementWiseRandomTransform, grain.shuffle seed range,
      unknown-length datasets, Tfds.decoders with ImmutableDict, Resize
      device transfer, element_spec global vs device-local, filter transform
      type checking, walrus operator breaking TF autograph.
  • kd.kontext:

    • [Extended] set_by_path returns the list of concrete modified paths when
      using glob patterns (**, *).
    • [Fix] Fix kontext.imports() errors in docs and CONFIG_IMPORT
      placeholder for Colab.
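The glob behaviour described for set_by_path can be sketched on plain nested dicts. This is an illustrative reimplementation under simplifying assumptions, not kontext's code: the name `set_by_glob` is hypothetical, paths are plain dot-separated keys, `*` is taken to match exactly one key, and `**` zero or more nested levels.

```python
def set_by_glob(tree: dict, pattern: str, value) -> list:
    """Set `value` at every path of nested dict `tree` matching `pattern`.

    Returns the concrete dotted paths that were modified.
    Hypothetical sketch: dict-only trees, dot-separated keys.
    """
    modified = []

    def visit(node, parts, prefix):
        if not isinstance(node, dict):
            return
        head, rest = parts[0], parts[1:]
        if head == "**":
            # `**` matching zero levels: try the remaining pattern here.
            if rest:
                visit(node, rest, prefix)
            # `**` matching one or more levels: descend with `**` still active.
            for key, child in node.items():
                visit(child, parts, prefix + [key])
            return
        if head == "*":
            keys = list(node)
        else:
            keys = [head] if head in node else []
        for key in keys:
            if rest:
                visit(node[key], rest, prefix + [key])
            else:
                node[key] = value
                modified.append(".".join(prefix + [key]))

    visit(tree, pattern.split("."), [])
    return modified
```

Returning the concrete paths lets the caller verify that a pattern matched what was intended, for example by failing loudly when the returned list is empty.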
  • kd.train:

    • [New] NoOpTrainStep for use cases skipping training.
    • [New] checkify support on TrainStep.init and Evaluator.evaluate.
    • [New] Expose KDMetricWriter, Orchestrator, DirectoryBuilder as public
      APIs for subclassing.
    • [New] konfig_freeze option to skip immutabledict conversion.
    • [Extended] MultiTrainStep subupdates for better logging.
    • [Extended] Device-to-host transfer for checkify error checking.
    • [Breaking] Rename ShardingStrategy.ds to ShardingStrategy.batch.
    • [Removed] Deprecate CollectingState.
    • [Fix] Fix sweeps bug with default config_args, partial_updates with
      integer keys, transfer_guard with jax_debug_nans, MultiTrainStep
      hashability, ml_python+adhoc error, FSDPSharding type annotation.
  • kd.evals:

    • [New] CheckpointedEvaluator for resumable evaluations.
    • [New] Skip initial step 0 option.
    • [New] NoopExporter.
    • [New] eval_step added to Evaluator.
    • [Extended] Allow skipping checkpointing in TrainEvaluator.
    • [Changed] NoOpCheckpointer is default for SamplingEvaluator.
    • [Changed] _ConcatContainer speeds up concat_field aggregation.
    • [Fix] Fix non unresponsive with custom dataset in eval, duplicated
      job_group in eval_only.
  • kd.metrics:

    • [New] finalize method for metric states.
    • [Extended] Support predicted labels (not just logits) in Accuracy.
    • [Extended] min_field, max_field for AutoState.
    • [Extended] Pytree support for auto_state.sum_field.
    • [Extended] Better error reporting for merging / finalizing / computing.
    • [Fix] Fix one-hot class count in segmentation metrics, finalize() bugs,
      CollectingState.merge performance.
  • kd.summaries / kd.vizual:

    • [New] Confusion matrix summary.
    • [Extended] ShowSegmentations: palette, edge, and hard options.
    • [Extended] ShowImages: cmap option.
    • [Extended] ImageGrid convenience method.
    • [Fix] Fix ShowDifferenceImages type-check and JAX/numpy mismatch,
      ShowImages RGB output with NaN values, bfloat16 support, integer arrays
      in ShowSegmentations.
  • kd.optim:

    • [New] ema_weights wrapper for EMA weight tracking.
    • [Fix] Fix debias logic in ema_params.
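For readers unfamiliar with EMA debiasing, the sketch below shows the standard bias correction for a zero-initialised exponential moving average (the same correction used for Adam's moment estimates). It is a generic illustration, not kd.optim's code, and the notes do not state what the fixed bug was; the function names are hypothetical.

```python
def ema_update(ema: float, value: float, decay: float) -> float:
    """One EMA step: blend the running average with the new value."""
    return decay * ema + (1.0 - decay) * value


def debias(ema: float, decay: float, step: int) -> float:
    """Correct the bias towards zero of a zero-initialised EMA.

    After `step` updates (step >= 1) the accumulated weights sum to
    1 - decay**step, so dividing by that factor rescales the average.
    """
    return ema / (1.0 - decay ** step)
```

With a constant input v and a zero start, the raw EMA after t steps is v * (1 - decay**t), so the debiased value recovers v exactly; without the correction, early averages are pulled towards zero.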
  • kd.ckpts:

    • [New] Custom Orbax preservation policy support.
    • [Removed] Remove deprecated AbstractPartialLoader alias.
    • [Fix] Fix EMA params loading for frozen params, checkpoint loading, snapshot
      directory race conditions, named tuple compatibility in parameter paths.
  • kd.contrib:

    • [New] NpzWriter: metric writer saving array summaries to .npz files.
    • [New] TreeUnflattenForKey PyGrain transform.
    • [New] GifVideoWriter and ShowVideosAsGif for GIF video summaries.
    • [New] NNX-to-Linen wrapper linen_from_nnx().
    • [New] Model exporter for JAX export.
    • [New] Online Mean+Covariance estimation state, merge_field in auto-state.
    • [Extended] concat_field works with pytrees.
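As background for the online mean and covariance state listed above, here is a minimal sketch of a mergeable state for the covariance of two scalar streams, using the pairwise update of Chan et al. It is not the kd.contrib implementation: the class `CovState` and its field names are hypothetical, and a real state would presumably handle vector inputs and a full covariance matrix.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CovState:
    """Running count, means, and co-moment of paired samples (x, y)."""

    n: int = 0
    mean_x: float = 0.0
    mean_y: float = 0.0
    co_moment: float = 0.0  # sum of (x - mean_x) * (y - mean_y)

    @classmethod
    def from_batch(cls, xs, ys):
        """Build a state from one non-empty batch of paired samples."""
        n = len(xs)
        mean_x = sum(xs) / n
        mean_y = sum(ys) / n
        co = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
        return cls(n, mean_x, mean_y, co)

    def merge(self, other):
        """Combine two states without revisiting the raw samples."""
        if self.n == 0:
            return other
        if other.n == 0:
            return self
        n = self.n + other.n
        dx = other.mean_x - self.mean_x
        dy = other.mean_y - self.mean_y
        return CovState(
            n=n,
            mean_x=self.mean_x + dx * other.n / n,
            mean_y=self.mean_y + dy * other.n / n,
            co_moment=(
                self.co_moment + other.co_moment
                + dx * dy * self.n * other.n / n
            ),
        )

    def compute(self):
        """Sample covariance (requires at least two samples)."""
        return self.co_moment / (self.n - 1)
```

Because `merge` only needs the summary statistics, per-batch or per-device states can be combined in any order and give the same result as a single pass over all samples, up to floating-point error.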
  • kd.contrib.millstone:

    • [New] New doc.
    • [Extended] Custom Borg runtime, eval dataset support, troubleshooting guide.
    • [Removed] Delete deprecated Millstone API.
    • [Fix] Fix Pathways server termination during eval.
  • kd.xm:

    • [New] jax_log_compiles configurable via xp.debug.jax_log_compiles.
    • [Extended] Launch configargs support.
    • [Fix] Fix cuda_compress flag for non-GPU builds, duplicated job_group.
  • kd.random:

    • [Changed] Truncation of the seed to uint32 now happens inside as_seed().