Description
The LogCumSumExp operation available in PyTorch (torch.logcumsumexp) is currently missing in TorchSharp. This function provides a numerically stable way to compute the cumulative log-sum-exp along a specified dimension, which is essential for log-domain computations in probabilistic models and sequence processing. I found a reference implementation by agadetsk.
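For context, the numerical-stability point can be illustrated without torch: the standard log-sum-exp trick shifts the inputs by their maximum before exponentiating, so `exp()` never overflows even for large values. A minimal pure-Python sketch (the `logsumexp` helper name here is mine, for illustration only):

```python
import math

def logsumexp(xs):
    # Stable log(sum(exp(x))): shift by the max so exp() stays bounded by 1.
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))

# A naive sum of exponentials would overflow here (math.exp(1000) raises
# OverflowError), but the shifted form evaluates to 1000 + log(2):
print(logsumexp([1000.0, 1000.0]))
```

`logcumsumexp` applies the same idea to every prefix along a dimension, which is why a naive cumulative sum of `exp` values is not an acceptable substitute.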
Python Code:
def logcumsumexp(x, dim):
    # Slow implementation, but ok for now.
    # Note: the reference code used `or` in this condition, which is always
    # true; `and` transposes only when `dim` is not already the last dim.
    if (dim != -1) and (dim != x.ndimension() - 1):
        x = x.transpose(dim, -1)
    out = []
    for i in range(1, x.size(-1) + 1):
        out.append(torch.logsumexp(x[..., :i], dim=-1, keepdim=True))
    out = torch.cat(out, dim=-1)
    if (dim != -1) and (dim != x.ndimension() - 1):
        out = out.transpose(-1, dim)
    return out
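As an aside, the prefix loop above recomputes each `logsumexp` from scratch, which is O(n²) along the scanned dimension. A linear-time alternative maintains a running maximum and a running sum rescaled to that maximum. A pure-Python sketch of that recurrence for a 1-D sequence (names are illustrative, not from the reference code):

```python
import math

def logcumsumexp_1d(xs):
    # Online recurrence: m is the running max, s is the running sum of
    # exp(x - m). When a new maximum arrives, rescale s to the new reference.
    out = []
    m = -math.inf
    s = 0.0
    for x in xs:
        if x > m:
            s = s * math.exp(m - x) + 1.0  # exp(-inf) == 0.0 on the first step
            m = x
        else:
            s += math.exp(x - m)
        out.append(m + math.log(s))
    return out
```

This is only a sketch; the quoted O(n²) form is simpler to port and is fine for an initial TorchSharp implementation.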
I've reimplemented this functionality in C#, achieving equivalent behavior.
C# Code:
public static Tensor LogCumSumExp(Tensor x, long dim)
{
    int ndim = (int)x.ndim;
    int lastDim = ndim - 1;
    bool needTranspose = (dim != -1) && (dim != lastDim);
    if (needTranspose)
    {
        x = x.transpose((int)dim, lastDim);
    }
    int size = (int)x.size(lastDim);
    var outputs = new List<Tensor>();
    for (int i = 1; i <= size; i++)
    {
        // logsumexp over the growing prefix x[..., :i]
        Tensor slice = x.slice(lastDim, 0, i, 1);
        Tensor lse = torch.logsumexp(slice, dim: lastDim, keepdim: true);
        outputs.Add(lse);
    }
    Tensor result = torch.cat(outputs.ToArray(), dim: lastDim);
    if (needTranspose)
    {
        result = result.transpose(lastDim, (int)dim);
    }
    return result;
}
I would like this implementation to be considered for inclusion in TorchSharp.