diff --git a/src/PointProcessDecoder.Core/Configuration/PointProcessModelConfiguration.cs b/src/PointProcessDecoder.Core/Configuration/PointProcessModelConfiguration.cs index 2325383..5454099 100644 --- a/src/PointProcessDecoder.Core/Configuration/PointProcessModelConfiguration.cs +++ b/src/PointProcessDecoder.Core/Configuration/PointProcessModelConfiguration.cs @@ -97,6 +97,11 @@ public class PointProcessModelConfiguration /// public double? SigmaRandomWalk { get; set; } + /// + /// The kernel limit of the model. + /// + public int? KernelLimit { get; set; } + /// /// The scalar type of the model. /// diff --git a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs index 1149482..e228a59 100644 --- a/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/ClusterlessMarkEncoder.cs @@ -34,7 +34,7 @@ public class ClusterlessMarkEncoder : ModelComponent, IEncoder private Tensor _markIntensities = empty(0); private Tensor _channelIntensities = empty(0); private Tensor _observationDensity = empty(0); - private Tensor[] _channelEstimates = []; + private readonly Tensor[] _channelEstimates = []; private Tensor _spikeCounts = empty(0); private Tensor _samples = empty(0); @@ -98,6 +98,7 @@ public ClusterlessMarkEncoder( _observationEstimation = new KernelDensity( bandwidth: observationBandwidth, dimensions: _stateSpace.Dimensions, + kernelLimit: kernelLimit, device: device, scalarType: scalarType ); @@ -106,7 +107,8 @@ public ClusterlessMarkEncoder( { _markEstimation[i] = new KernelDensity( bandwidth: bandwidth, - dimensions: jointDimensions, + dimensions: jointDimensions, + kernelLimit: kernelLimit, device: device, scalarType: scalarType ); diff --git a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs index a205c00..d00e4f5 100644 --- a/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs +++ b/src/PointProcessDecoder.Core/Encoder/SortedSpikeEncoder.cs @@ -112,6 +112,7 @@ private static IEstimation GetEstimationMethod( EstimationMethod.KernelDensity => new KernelDensity( bandwidth: bandwidth, dimensions: dimensions, + kernelLimit: kernelLimit, device: device, scalarType: scalarType ), diff --git a/src/PointProcessDecoder.Core/PointProcessModel.cs b/src/PointProcessDecoder.Core/PointProcessModel.cs index fb92527..32e1ff8 100644 --- a/src/PointProcessDecoder.Core/PointProcessModel.cs +++ b/src/PointProcessDecoder.Core/PointProcessModel.cs @@ -181,6 +181,7 @@ public PointProcessModel( DistanceThreshold = distanceThreshold, IgnoreNoSpikes = ignoreNoSpikes, SigmaRandomWalk = sigmaRandomWalk, + KernelLimit = kernelLimit, ScalarType = scalarType }; } @@ -272,6 +273,7 @@ public override void Save(string basePath) distanceThreshold: configuration.DistanceThreshold, ignoreNoSpikes: configuration.IgnoreNoSpikes, sigmaRandomWalk: configuration.SigmaRandomWalk, + kernelLimit: configuration.KernelLimit, device: device, scalarType: configuration.ScalarType );