In [None]:
#i "nuget: /home/nicholas/PointProcessDecoder/artifacts/package/release"

In [None]:
#r "nuget: PointProcessDecoder.Core, 0.4.0-preview"
#r "nuget: PointProcessDecoder.Cpu, 0.4.0-preview"
#r "nuget: PointProcessDecoder.Plot, 0.4.0-preview"
#r "nuget: PointProcessDecoder.Plot.Linux, 0.4.0-preview"
#r "nuget: PointProcessDecoder.Simulation, 0.4.0-preview"

In [None]:
using System;
using System.IO;

using TorchSharp;
using static TorchSharp.torch;

using PointProcessDecoder.Core;
using PointProcessDecoder.Plot;
using PointProcessDecoder.Simulation;
using PointProcessDecoder.Core.Estimation;
using PointProcessDecoder.Core.Transitions;
using PointProcessDecoder.Core.Encoder;
using PointProcessDecoder.Core.Decoder;
using PointProcessDecoder.Core.StateSpace;
using PointProcessDecoder.Core.Likelihood;

In [None]:
/// <summary>
/// Reads a headerless binary file of packed 64-bit floating-point values and returns them
/// as a one-dimensional tensor.
/// </summary>
/// <param name="binary_file">
/// Path to the file. Its bytes are reinterpreted as <c>double</c> values in the machine's native byte order.
/// </param>
/// <param name="device">Device to create the tensor on; the CPU is used when <c>null</c>.</param>
/// <param name="scalarType">Element type of the returned tensor; the doubles are converted to it.</param>
/// <returns>A 1-D tensor with one element per 8 bytes of the file.</returns>
/// <exception cref="InvalidDataException">
/// The file length is not a multiple of <c>sizeof(double)</c>, i.e. it cannot be a packed array of doubles.
/// </exception>
static Tensor ReadBinaryFile(
    string binary_file,
    Device device = null,
    ScalarType scalarType = ScalarType.Float32
)
{
    device ??= CPU;
    byte[] fileBytes = File.ReadAllBytes(binary_file);

    // A trailing partial value means the file is truncated or not a packed double array.
    // Report that clearly instead of letting Buffer.BlockCopy fail with an opaque argument error.
    if (fileBytes.Length % sizeof(double) != 0)
    {
        throw new InvalidDataException(
            $"File '{binary_file}' is {fileBytes.Length} bytes long, which is not a multiple of {sizeof(double)} (the size of a double).");
    }

    int elementCount = fileBytes.Length / sizeof(double);
    double[] doubleArray = new double[elementCount];

    // Copy exactly the bytes that fit in the destination array.
    Buffer.BlockCopy(fileBytes, 0, doubleArray, 0, elementCount * sizeof(double));
    return tensor(doubleArray, device: device, dtype: scalarType);
}

/// <summary>
/// Loads the position and spike-count recordings from their binary files.
/// </summary>
/// <param name="positionFile">Path to the packed-double position file (read via ReadBinaryFile).</param>
/// <param name="spikesFile">Path to the packed-double spike-count file (read via ReadBinaryFile).</param>
/// <param name="device">Device both tensors are created on.</param>
/// <param name="scalarType">Element type both tensors are converted to.</param>
/// <returns>
/// The (position, spikes) tensors exactly as read, still one-dimensional; reshaping is left to the caller.
/// </returns>
static (Tensor, Tensor) InitializeRealData(
    string positionFile,
    string spikesFile,
    Device device,
    ScalarType scalarType = ScalarType.Float32
) => (
    ReadBinaryFile(positionFile, device, scalarType),
    ReadBinaryFile(spikesFile, device, scalarType)
);

/// <summary>
/// Draws a decoded posterior as a heatmap over the given 2-D extent and saves it as a PNG.
/// </summary>
/// <param name="posteriorPrediction">Posterior passed to <c>Heatmap.Show</c> as the heatmap data.</param>
/// <param name="position2D">
/// Actual positions, also passed to <c>Heatmap.Show</c> (presumably drawn as an overlay — confirm in Plot package).
/// </param>
/// <param name="title">Plot title; also used by the plot when naming its output.</param>
/// <param name="xMin">Lower x bound of the plotted area.</param>
/// <param name="xMax">Upper x bound of the plotted area.</param>
/// <param name="yMin">Lower y bound of the plotted area.</param>
/// <param name="yMax">Upper y bound of the plotted area.</param>
/// <param name="outputDirectory">Sub-directory appended to the plot's default output directory.</param>
/// <returns>The heatmap after it has been shown and saved.</returns>
Heatmap PlotPosteriorPrediction(
    Tensor posteriorPrediction,
    Tensor position2D,
    string title,
    double xMin,
    double xMax,
    double yMin,
    double yMax,
    string outputDirectory
)
{
    var heatmap = new Heatmap(xMin, xMax, yMin, yMax, title: title);

    // Nest this run's figures under the plot's default output location.
    var defaultDirectory = heatmap.OutputDirectory;
    heatmap.OutputDirectory = Path.Combine(defaultDirectory, outputDirectory);

    heatmap.Show(posteriorPrediction, position2D);
    heatmap.Save(png: true);
    return heatmap;
}

/// <summary>
/// Plots each column of <paramref name="stateProbabilities"/> as a line series against its row index,
/// then saves the figure as a PNG.
/// </summary>
/// <param name="stateProbabilities">
/// 2-D tensor: rows are plotted along the x axis (0..rows-1), and each column becomes one series.
/// Presumably (time, state) probabilities from the classifier — confirm against ClassifierData.
/// </param>
/// <param name="title">Plot title.</param>
/// <param name="outputDirectory">Sub-directory appended to the plot's default output directory.</param>
/// <returns>The scatter plot after all series have been shown and the figure saved.</returns>
static ScatterPlot PlotStatePrediction(
    Tensor stateProbabilities,
    string title,
    string outputDirectory
)
{
    var lengthOfData = stateProbabilities.size(0);
    var numSeries = stateProbabilities.size(1);

    ScatterPlot plotStatePrediction = new(
        0,
        lengthOfData,
        -0.1,
        1.1,
        title: title
    );
    plotStatePrediction.OutputDirectory = Path.Combine(plotStatePrediction.OutputDirectory, outputDirectory);

    // Build the x axis on the same device and with the same dtype as the probabilities.
    // A default arange() is an Int64 CPU tensor, and stack() fails when the model runs on another device.
    var time = arange(0, lengthOfData, 1, dtype: stateProbabilities.dtype, device: stateProbabilities.device);

    OxyPlot.OxyColor[] colors = [
        OxyPlot.OxyColors.Red,
        OxyPlot.OxyColors.Green,
        OxyPlot.OxyColors.Blue
    ];

    string[] labels = [
        "Stationary",
        "Continuous",
        "Fragmented"
    ];

    for (var i = 0; i < numSeries; i++)
    {
        var statePrediction = stack([time, stateProbabilities[TensorIndex.Colon, i]], 1);

        // Cycle colors and fall back to a generic label if there are more columns than named states,
        // instead of throwing IndexOutOfRangeException.
        var color = colors[i % colors.Length];
        var label = i < labels.Length ? labels[i] : $"State {i}";

        plotStatePrediction.Show(
            statePrediction,
            color: color,
            addLine: true,
            seriesLabel: label
        );
    }

    plotStatePrediction.Save(png: true);

    return plotStatePrediction;
}

In [None]:
// Load the recorded position and spike-count data (paths are relative to the notebook's working directory).
var positionFile = "../data/position.bin";
var spikesFile = "../data/spike_counts.bin";

var device = CPU;
var scalarType = ScalarType.Float32;

var (position, spikingData) = InitializeRealData(positionFile, spikesFile, device, scalarType);

// Two values per sample (presumably x, y) -> one row per sample.
position = position.reshape(-1, 2);

// One row per position sample; the remaining axis is the per-neuron spike count, stored as Int32.
spikingData = spikingData
    .reshape(position.shape[0], -1)
    .to_type(ScalarType.Int32);

int numNeurons = (int)spikingData.shape[1];

In [None]:
// Restrict both recordings to the same leading 100000 samples (presumably to keep this run small).
var leadingSamples = TensorIndex.Slice(0, 100000);
position = position[leadingSamples];
spikingData = spikingData[leadingSamples];

In [None]:
// Configure the decoder: sorted-spike encoder with kernel-compression estimation, Poisson likelihood,
// random-walk transitions and a hybrid state-space classifier over a discrete uniform 2-D state space.
PointProcessModel pointProcessModel = new(
    // Model components.
    estimationMethod: EstimationMethod.KernelCompression,
    encoderType: EncoderType.SortedSpikes,
    likelihoodType: LikelihoodType.Poisson,
    decoderType: DecoderType.HybridStateSpaceClassifier,
    transitionsType: TransitionsType.RandomWalk,
    // State space: a 50 x 50 grid (stepsStateSpace) spanning [0, 120] on each of the 2 axes.
    stateSpaceType: StateSpaceType.DiscreteUniform,
    stateSpaceDimensions: 2,
    minStateSpace: [0, 0],
    maxStateSpace: [120, 120],
    stepsStateSpace: [50, 50],
    // Encoding and transition hyperparameters.
    observationBandwidth: [2, 2],
    sigmaRandomWalk: 1,
    distanceThreshold: 1.5,
    ignoreNoSpikes: false,
    nUnits: numNeurons,
    // Tensor placement and precision.
    device: device,
    scalarType: scalarType
);

# Encode

In [None]:
// Train/test split and batch sizes.
var fractionTraining = 0.8;      // leading fraction of samples used for encoding
var trainingBatchSize = 10_000;  // samples per Encode call
var testingBatchSize = 100;      // samples per Decode call (one pair of plots per batch)

In [None]:
// Encode the leading `fractionTraining` of the samples in chunks of `trainingBatchSize`.
int nTraining = (int)(position.shape[0] * fractionTraining);

// Ceiling division, so the reported total matches the number of loop iterations exactly.
int nTrainingBatches = (nTraining + trainingBatchSize - 1) / trainingBatchSize;

// Strict `<` bound: the old `< nTraining + 1` bound encoded an empty slice whenever
// nTraining was an exact multiple of the batch size.
for (int i = 0; i < nTraining; i += trainingBatchSize)
{
    Console.WriteLine($"Training batch {i / trainingBatchSize + 1} of {nTrainingBatches}");
    var end = Math.Min(i + trainingBatchSize, nTraining);
    pointProcessModel.Encode(
        position[TensorIndex.Slice(i, end)],
        spikingData[TensorIndex.Slice(i, end)]
    );
}

# Decode

In [None]:
// Decode the held-out samples in chunks of `testingBatchSize`. For each chunk, save a posterior
// heatmap and a state-probability plot.
int nTesting = (int)position.shape[0] - nTraining;

// Ceiling division, so the reported total matches the number of loop iterations exactly.
int nTestingBatches = (nTesting + testingBatchSize - 1) / testingBatchSize;
string plotOutputDirectory = "20250408_SortedUnits_43765200";

// Strict `<` bound: the old `< nTraining + nTesting + 1` bound decoded and plotted an empty
// slice whenever nTesting was an exact multiple of the batch size.
for (int i = nTraining; i < nTraining + nTesting; i += testingBatchSize)
{
    Console.WriteLine($"Testing batch {(i - nTraining) / testingBatchSize + 1} of {nTestingBatches}");
    var end = Math.Min(i + testingBatchSize, nTraining + nTesting);

    var posteriorDecoded = pointProcessModel.Decode(spikingData[TensorIndex.Slice(i, end)]);
    var prediction = new ClassifierData(pointProcessModel.StateSpace, posteriorDecoded);

    // Average the posterior over dimension 0 (presumably the samples of this batch — TODO confirm).
    var posterior = prediction.DecoderData.Posterior.mean([0]);
    var positionSampled = position[TensorIndex.Slice(i, end)];

    PlotPosteriorPrediction(
        posterior,
        positionSampled,
        $"Prediction2D_{i}-{end}",
        0,
        120,
        0,
        120,
        plotOutputDirectory
    );
    PlotStatePrediction(
        prediction.StateProbabilities,
        $"StatePrediction_{i}-{end}",
        plotOutputDirectory
    );
}

# Encode and Decode