In [None]:
#i "nuget: C:/Users/ncgui/Desktop/PointProcessDecoder/artifacts/package/release"

In [None]:
#r "nuget: PointProcessDecoder.Core, 0.1.4"
#r "nuget: PointProcessDecoder.Cpu, 0.1.4"
#r "nuget: PointProcessDecoder.Plot, 0.1.4"
#r "nuget: PointProcessDecoder.Simulation, 0.1.4"

In [3]:
using System;
using System.IO;

using TorchSharp;
using static TorchSharp.torch;

using PointProcessDecoder.Core;
using PointProcessDecoder.Plot;
using PointProcessDecoder.Simulation;
using PointProcessDecoder.Core.Estimation;
using PointProcessDecoder.Core.Transitions;
using PointProcessDecoder.Core.Encoder;
using PointProcessDecoder.Core.Decoder;
using PointProcessDecoder.Core.StateSpace;
using PointProcessDecoder.Core.Likelihood;

In [4]:
/// <summary>
/// Wraps raw PNG bytes in an HTML <c>img</c> tag using a base64 data URI.
/// </summary>
/// <param name="data">PNG-encoded image bytes.</param>
/// <returns>An HTML snippet that displays the image inline.</returns>
string EncodeToHtml(byte[] data)
{
    var base64Png = Convert.ToBase64String(data);
    return $"<img src=\"data:image/png;base64,{base64Png}\">";
}

/// <summary>
/// Reads a file of packed 8-byte doubles and returns its contents as a flat 1-D tensor.
/// </summary>
/// <param name="binary_file">Path to a file containing raw doubles in the machine's native byte order.</param>
/// <param name="device">Device to create the tensor on; defaults to the CPU when null.</param>
/// <param name="scalarType">Element type of the resulting tensor (the doubles are passed as <c>dtype</c>).</param>
/// <returns>A 1-D tensor with one element per double stored in the file.</returns>
/// <exception cref="InvalidDataException">
/// The file length is not a multiple of <c>sizeof(double)</c>, i.e. it is truncated or not a packed double array.
/// </exception>
static Tensor ReadBinaryFile(
    string binary_file,
    Device device = null,
    ScalarType scalarType = ScalarType.Float32
)
{
    device ??= CPU;
    byte[] fileBytes = File.ReadAllBytes(binary_file);

    // A trailing partial value would make BlockCopy overrun the destination array and throw an
    // opaque bounds error; report the real problem instead.
    if (fileBytes.Length % sizeof(double) != 0)
    {
        throw new InvalidDataException(
            $"File '{binary_file}' is {fileBytes.Length} bytes long, which is not a multiple of {sizeof(double)} (the size of a double)."
        );
    }

    int elementCount = fileBytes.Length / sizeof(double);
    double[] doubleArray = new double[elementCount];
    // Raw byte copy: the bytes are reinterpreted as doubles using the machine's native byte order.
    Buffer.BlockCopy(fileBytes, 0, doubleArray, 0, fileBytes.Length);
    return tensor(doubleArray, device: device, dtype: scalarType);
}

/// <summary>
/// Loads the position and spike-count recordings as flat tensors on the same device and dtype.
/// </summary>
/// <param name="positionFile">Path to the raw binary position file.</param>
/// <param name="spikesFile">Path to the raw binary spike-count file.</param>
/// <param name="device">Device both tensors are created on.</param>
/// <param name="scalarType">Element type of both tensors.</param>
/// <returns>The (position, spikes) tensors, each still 1-D; callers reshape them.</returns>
static (Tensor, Tensor) InitializeRealData(
    string positionFile,
    string spikesFile,
    Device device,
    ScalarType scalarType = ScalarType.Float32
) => (
    ReadBinaryFile(positionFile, device, scalarType),
    ReadBinaryFile(spikesFile, device, scalarType)
);

/// <summary>
/// Renders a 2-D prediction as a heatmap, overlays the given positions, and saves it as a PNG.
/// </summary>
/// <param name="prediction">Prediction grid to draw.</param>
/// <param name="position2D">Positions passed to <c>Show</c> alongside the prediction.</param>
/// <param name="title">Plot title.</param>
/// <param name="xMin">Lower x bound of the plot.</param>
/// <param name="xMax">Upper x bound of the plot.</param>
/// <param name="yMin">Lower y bound of the plot.</param>
/// <param name="yMax">Upper y bound of the plot.</param>
/// <param name="outputDirectory">Sub-directory, appended to the heatmap's default output directory, to save into.</param>
/// <returns>The heatmap that was shown and saved.</returns>
Heatmap PlotPrediction(
    Tensor prediction,
    Tensor position2D,
    string title,
    double xMin,
    double xMax,
    double yMin,
    double yMax,
    string outputDirectory
)
{
    var heatmap = new Heatmap(xMin, xMax, yMin, yMax, title: title);

    // Nest the requested folder under whatever default directory the heatmap starts with.
    var defaultDirectory = heatmap.OutputDirectory;
    heatmap.OutputDirectory = Path.Combine(defaultDirectory, outputDirectory);

    // NOTE(review): Show<float> presumably requires Float32 tensors — matches scalarType in this notebook; confirm if that changes.
    heatmap.Show<float>(prediction, position2D);
    heatmap.Save(png: true);

    return heatmap;
}

In [5]:
// Raw recordings, relative to the notebook's working directory.
string positionFile = "../data/positions_2D.bin";
string spikesFile = "../data/spike_counts.bin";

Device device = CPU;
ScalarType scalarType = ScalarType.Float32;

// Both files come back as flat 1-D tensors of the chosen dtype on the chosen device.
var (position, spikingData) = InitializeRealData(positionFile, spikesFile, device, scalarType);

// Consecutive value pairs become rows: one 2-D position (presumably x, y) per time bin.
position = position.reshape(-1, 2);

// Spike data shares the time axis with the positions; the remaining axis indexes units.
// Casting to Bool turns each count into a spike / no-spike flag.
spikingData = spikingData.reshape(position.shape[0], -1);
spikingData = spikingData.to_type(ScalarType.Bool);

int numNeurons = (int)spikingData.shape[1];

In [6]:
// Keep only the leading time bins of both recordings so they stay aligned.
int maxTimeBins = 100_000;
position = position[TensorIndex.Slice(0, maxTimeBins)];
spikingData = spikingData[TensorIndex.Slice(0, maxTimeBins)];

In [7]:
// Build the decoding model: kernel-compression estimation, random-walk transitions,
// a sorted-spike encoder, a state-space decoder over a discrete uniform state space,
// and a Poisson likelihood.
// The 2-D state space runs from (0, 0) to (120, 120) in 50 steps per dimension. The decoding
// cell reshapes predictions to [50, 50] and plots over [0, 120], so keep those in sync.
// NOTE(review): units of the bounds and observationBandwidth presumably match positions_2D.bin — confirm.
var pointProcessModel = new PointProcessModel(
    EstimationMethod.KernelCompression,
    TransitionsType.RandomWalk,
    EncoderType.SortedSpikeEncoder,
    DecoderType.StateSpaceDecoder,
    StateSpaceType.DiscreteUniformStateSpace,
    LikelihoodType.Poisson,
    minStateSpace: [0, 0],
    maxStateSpace: [120, 120],
    stepsStateSpace: [50, 50],
    observationBandwidth: [5, 5],
    stateSpaceDimensions: 2,
    // One unit per spike-data column.
    nUnits: numNeurons,
    // NOTE(review): presumably the merge threshold used by kernel compression — confirm.
    distanceThreshold: 1.5,
    sigmaRandomWalk: 1,
    device: device,
    scalarType: scalarType
);

# Encode

In [8]:
// Leading fraction of time bins used for encoding; the rest are decoded.
var fractionTraining = 0.8;
// Number of time bins handed to each Encode / Decode call.
var trainingBatchSize = 1000;
var testingBatchSize = 60;

In [None]:
// Encode the first `fractionTraining` of the time bins, one batch at a time.
int nTraining = (int)(position.shape[0] * fractionTraining);
// Ceiling division: a final partial batch is still a batch. The previous `n / size + 1`
// over-reported by one whenever n divided evenly (e.g. "of 81" for 80 batches).
int nTrainingBatches = (nTraining + trainingBatchSize - 1) / trainingBatchSize;

for (int i = 0; i < nTraining; i += trainingBatchSize)
{
    Console.WriteLine($"Training batch {i / trainingBatchSize + 1} of {nTrainingBatches}");
    var end = Math.Min(i + trainingBatchSize, nTraining);
    // Encode is called once per batch on aligned position / spike slices.
    pointProcessModel.Encode(
        position[TensorIndex.Slice(i, end)],
        spikingData[TensorIndex.Slice(i, end)]
    );
}

# Decode

In [None]:
// Decode the held-out time bins after the training range, one batch at a time,
// and save a heatmap of each batch's aggregated prediction.
int nTesting = (int)position.shape[0] - nTraining;
// Ceiling division so the reported total equals the number of loop iterations.
int nTestingBatches = (nTesting + testingBatchSize - 1) / testingBatchSize;

for (int i = nTraining; i < nTraining + nTesting; i += testingBatchSize)
{
    Console.WriteLine($"Testing batch {(i - nTraining) / testingBatchSize + 1} of {nTestingBatches}");
    var end = Math.Min(i + testingBatchSize, nTraining + nTesting);

    var prediction = pointProcessModel.Decode(spikingData[TensorIndex.Slice(i, end)]);
    // Sum over dim 0 (presumably the time bins of this batch — confirm), normalise so the
    // result sums to 1, and lay it out on the grid. [50, 50] must match stepsStateSpace.
    prediction = (prediction.sum(dim: 0) / prediction.sum()).reshape([50, 50]);

    var title = $"Prediction2D_{i}-{end}";
    var positionSampled = position[TensorIndex.Slice(i, end)];
    // Plot bounds mirror minStateSpace / maxStateSpace used to build the model.
    PlotPrediction(
        prediction,
        positionSampled,
        title,
        0,
        120,
        0,
        120,
        "20250125_SortedUnits_latest"
    );
}