In [1]:
#i "nuget: C:/Users/ncgui/Desktop/PointProcessDecoder/artifacts/package/release"

In [2]:
#r "nuget: TorchSharp"
#r "nuget: PointProcessDecoder.Core"
#r "nuget: PointProcessDecoder.Cpu"
#r "nuget: PointProcessDecoder.Plot"
#r "nuget: PointProcessDecoder.Simulation"
#r "nuget: OxyPlot.Core, 2.1.2"

Loading extensions from `C:\Users\ncgui\.nuget\packages\skiasharp\2.88.8\interactive-extensions\dotnet\SkiaSharp.DotNet.Interactive.dll`

In [3]:
using System;
using System.IO;

using TorchSharp;
using static TorchSharp.torch;

using PointProcessDecoder.Core;
using PointProcessDecoder.Plot;
using PointProcessDecoder.Simulation;
using PointProcessDecoder.Core.Estimation;
using PointProcessDecoder.Core.Transitions;
using PointProcessDecoder.Core.Encoder;
using PointProcessDecoder.Core.Decoder;

In [4]:
/// <summary>
/// Wraps raw PNG bytes in an HTML &lt;img&gt; tag whose source is an inline base64 data URI,
/// so the image can be rendered directly in notebook output.
/// </summary>
/// <param name="data">PNG-encoded image bytes.</param>
/// <returns>An &lt;img&gt; element string embedding <paramref name="data"/>.</returns>
string EncodeToHtml(byte[] data) =>
    "<img src=\"data:image/png;base64," + Convert.ToBase64String(data) + "\">";

/// <summary>
/// Reads a raw binary file containing a packed array of 64-bit floating-point values
/// and returns it as a 1-D tensor.
/// </summary>
/// <param name="binary_file">Path to the file; its bytes are interpreted as consecutive <see cref="double"/> values.</param>
/// <param name="device">Device for the returned tensor; defaults to the CPU when null.</param>
/// <param name="scalarType">Element type of the returned tensor (values are converted from float64).</param>
/// <returns>A 1-D tensor with one element per 8 bytes of the file.</returns>
/// <exception cref="InvalidDataException">The file length is not a multiple of <c>sizeof(double)</c>.</exception>
static Tensor ReadBinaryFile(
    string binary_file,
    Device device = null,
    ScalarType scalarType = ScalarType.Float32
)
{
    device ??= CPU;
    byte[] fileBytes = File.ReadAllBytes(binary_file);

    // A trailing partial value means the file is truncated or not a float64 dump.
    // Without this check Buffer.BlockCopy would fail with an opaque "offset and length"
    // ArgumentException, because the destination holds fewer bytes than the source.
    if (fileBytes.Length % sizeof(double) != 0)
    {
        throw new InvalidDataException(
            $"File '{binary_file}' is {fileBytes.Length} bytes long, which is not a multiple of {sizeof(double)}; " +
            "expected a packed array of 64-bit doubles.");
    }

    double[] doubleArray = new double[fileBytes.Length / sizeof(double)];
    // BlockCopy reinterprets the bytes in the machine's native byte order.
    Buffer.BlockCopy(fileBytes, 0, doubleArray, 0, fileBytes.Length);
    return tensor(doubleArray, device: device, dtype: scalarType);
}

/// <summary>
/// Loads the recorded position and spike data from their binary files.
/// </summary>
/// <param name="positionFile">Path to the binary position file.</param>
/// <param name="spikesFile">Path to the binary spike file.</param>
/// <param name="device">Device on which both tensors are created.</param>
/// <param name="scalarType">Element type of both tensors.</param>
/// <returns>A tuple of (flat position tensor, flat spike tensor), in that order.</returns>
static (Tensor, Tensor) InitializeRealData(
    string positionFile,
    string spikesFile,
    Device device,
    ScalarType scalarType = ScalarType.Float32
) =>
    (
        ReadBinaryFile(positionFile, device, scalarType),
        ReadBinaryFile(spikesFile, device, scalarType)
    );

In [5]:
// Where tensors live and their element type.
ScalarType scalarType = ScalarType.Float32;
Device device = CPU;

// Latent space: a 2-D region [xMin, xMax] x [yMin, yMax], evaluated on an
// evaluationSteps[0] x evaluationSteps[1] grid.
int numDimensions = 2;
double xMin = 0.0;
double xMax = 120.0;
double yMin = 0.0;
double yMax = 120.0;
long[] evaluationSteps = [50, 50];

// Per-dimension kernel bandwidth for the estimator.
double[] bandwidth = [2.0, 2.0];
// Per-dimension sigma for the random-walk transition model (passed as sigmaRandomWalk).
double[] sigma = [1.0, 1.0];
// Passed to the model as distanceThreshold; presumably the kernel-compression merge distance — TODO confirm.
double distanceThreshold = 1.5;

// The leading fractionTraining share of samples is encoded; the remainder is decoded.
double fractionTraining = 0.8;
int trainingBatchSize = 10000;
int evaluationBatchSize = 10;

// Sub-directory (under the heatmap's default output directory) that receives saved plots.
string pointProcessModelDirectory = "PointProcessModelRandomWalkCompressionRealData2D";

In [6]:
// Input recordings, relative to the notebook's working directory.
string positionFile = "../data/positions_2D.bin";
string spikesFile = "../data/spike_times.bin";

var (position, spikingData) = InitializeRealData(positionFile, spikesFile, device, scalarType);

// Positions arrive as a flat stream; regroup into one (x, y) row per sample.
var position2D = position.reshape(-1, 2);

// Spikes share the sample axis with the positions; the remaining axis is one column per neuron.
// The Bool conversion maps non-zero entries to true.
var spikesBySample = spikingData.reshape(position2D.shape[0], -1);
spikingData = spikesBySample.to_type(ScalarType.Bool);
var numNeurons = (int)spikingData.shape[1];

In [7]:
// Sorted-spike point-process model: compressed kernel density estimates for encoding
// and a random-walk prior over the 2-D latent space for decoding.
PointProcessModel pointProcessModel = new(
    EstimationMethod.KernelCompression,
    TransitionsType.RandomWalk,
    EncoderType.SortedSpikeEncoder,
    DecoderType.SortedSpikeDecoder,
    [xMin, yMin],          // lower corner of the latent space
    [xMax, yMax],          // upper corner of the latent space
    evaluationSteps,       // grid resolution per dimension
    bandwidth,             // kernel bandwidth per dimension
    latentDimensions: numDimensions,
    nUnits: numNeurons,
    sigmaRandomWalk: sigma,
    distanceThreshold: distanceThreshold,
    device: device
);

# Encode

In [8]:
// Encode: fit the model on the leading fractionTraining share of the samples,
// in consecutive chunks of at most trainingBatchSize rows.
int nTraining = (int)(position2D.shape[0] * fractionTraining);

// Ceiling division so the reported total is exact even when nTraining is a
// multiple of trainingBatchSize (floor + 1 over-counts by one in that case).
int nTrainingBatches = (nTraining + trainingBatchSize - 1) / trainingBatchSize;

for (int i = 0; i < nTraining; i += trainingBatchSize)
{
    Console.WriteLine($"Training batch {i / trainingBatchSize + 1} of {nTrainingBatches}");
    var end = Math.Min(i + trainingBatchSize, nTraining);
    pointProcessModel.Encode(
        position2D[TensorIndex.Slice(i, end)],
        spikingData[TensorIndex.Slice(i, end)]
    );
}

Training batch 1 of 37
Training batch 2 of 37
Training batch 3 of 37
Training batch 4 of 37
Training batch 5 of 37
Training batch 6 of 37
Training batch 7 of 37
Training batch 8 of 37
Training batch 9 of 37
Training batch 10 of 37
Training batch 11 of 37
Training batch 12 of 37
Training batch 13 of 37
Training batch 14 of 37
Training batch 15 of 37
Training batch 16 of 37
Training batch 17 of 37
Training batch 18 of 37
Training batch 19 of 37
Training batch 20 of 37
Training batch 21 of 37
Training batch 22 of 37
Training batch 23 of 37
Training batch 24 of 37
Training batch 25 of 37
Training batch 26 of 37
Training batch 27 of 37
Training batch 28 of 37
Training batch 29 of 37
Training batch 30 of 37
Training batch 31 of 37
Training batch 32 of 37
Training batch 33 of 37
Training batch 34 of 37
Training batch 35 of 37
Training batch 36 of 37
Training batch 37 of 37


# Decode

In [11]:
/// <summary>
/// Draws a decoded prediction as a heatmap over [xMin, xMax] x [yMin, yMax] and saves it as a PNG.
/// The image goes into the <c>pointProcessModelDirectory</c> sub-folder of the heatmap's default output directory.
/// </summary>
/// <param name="prediction">Prediction on the evaluation grid, passed to <c>Show&lt;float&gt;</c>.</param>
/// <param name="position2D">Sampled positions for the same batch, passed alongside the prediction (presumably overlaid — confirm).</param>
/// <param name="title">Plot title.</param>
/// <returns>The configured, already-saved heatmap.</returns>
Heatmap PlotPrediction(
    Tensor prediction,
    Tensor position2D,
    string title
)
{
    var heatmap = new Heatmap(xMin, xMax, yMin, yMax, title: title);

    var targetDirectory = Path.Combine(heatmap.OutputDirectory, pointProcessModelDirectory);
    heatmap.OutputDirectory = targetDirectory;

    heatmap.Show<float>(prediction, position2D);
    heatmap.Save(png: true);

    return heatmap;
}

In [12]:
// Decode: run the held-out tail of the recording (everything after the training
// samples) through the model in chunks of evaluationBatchSize, saving one heatmap per chunk.
int nTesting = (int)position2D.shape[0] - nTraining;

// Ceiling division so the reported total is exact even when nTesting is a
// multiple of evaluationBatchSize (floor + 1 over-counts by one in that case).
int nTestingBatches = (nTesting + evaluationBatchSize - 1) / evaluationBatchSize;

for (int i = nTraining; i < nTraining + nTesting; i += evaluationBatchSize)
{
    Console.WriteLine($"Testing batch {(i - nTraining) / evaluationBatchSize + 1} of {nTestingBatches}");
    var end = Math.Min(i + evaluationBatchSize, nTraining + nTesting);
    var prediction = pointProcessModel.Decode(spikingData[TensorIndex.Slice(i, end)]);

    // Collapse dim 0 (presumably the samples within the batch), normalize by the grand
    // total so the map sums to 1, and reshape onto the evaluation grid.
    prediction = (prediction.sum(dim: 0) / prediction.sum()).reshape(evaluationSteps);

    var title = $"Prediction2D_{i}-{end}";
    var positionSampled = position2D[TensorIndex.Slice(i, end)];
    // PlotPrediction saves the image itself; the returned heatmap is not needed here.
    PlotPrediction(prediction, positionSampled, title);
}

Testing batch 1 of 9047
Testing batch 2 of 9047
Testing batch 3 of 9047
Testing batch 4 of 9047
Testing batch 5 of 9047
Testing batch 6 of 9047
Testing batch 7 of 9047
Testing batch 8 of 9047
Testing batch 9 of 9047
Testing batch 10 of 9047
Testing batch 11 of 9047
Testing batch 12 of 9047
Testing batch 13 of 9047
Testing batch 14 of 9047
Testing batch 15 of 9047
Testing batch 16 of 9047
Testing batch 17 of 9047
Testing batch 18 of 9047
Testing batch 19 of 9047
Testing batch 20 of 9047
Testing batch 21 of 9047
Testing batch 22 of 9047
Testing batch 23 of 9047
Testing batch 24 of 9047
Testing batch 25 of 9047
Testing batch 26 of 9047
Testing batch 27 of 9047
Testing batch 28 of 9047
Testing batch 29 of 9047
Testing batch 30 of 9047
Testing batch 31 of 9047
Testing batch 32 of 9047
Testing batch 33 of 9047
Testing batch 34 of 9047
Testing batch 35 of 9047
Testing batch 36 of 9047
Testing batch 37 of 9047
Testing batch 38 of 9047
Testing batch 39 of 9047
Testing batch 40 of 9047
Testing b