Skip to content

Commit

Permalink
fix wrong parameter name in plotting function and change value tensor…
Browse files Browse the repository at this point in the history
… in pytorch sparse tensor to float type
  • Loading branch information
biphasic committed Jul 21, 2021
1 parent 59a731c commit 85b6968
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
41 changes: 28 additions & 13 deletions tonic/functional/to_sparse_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ def get_indices_values(events, sensor_size, ordering, merge_polarities):
Parameters:
merge_polarities (bool): flag that decides whether to combine polarities into a single channel
or split them into separate channels. If True, the number of channels for
indices is 1, otherwise it's the number of different polarities. Regardless
indices is 1, otherwise it's the number of different polarities. Regardless
of this flag, all values assigned to indices will be 1, which signify a spike.
Returns:
sparse tensor in TxCxWxH format, where T is timesteps, C is the number of channels for each polarity,
and W and H are always the size of the sensor.
"""
"""
assert "x" and "t" and "p" in ordering
x_index = ordering.find("x")
t_index = ordering.find("t")
Expand All @@ -30,8 +30,9 @@ def get_indices_values(events, sensor_size, ordering, merge_polarities):
values = np.ones(events.shape[0])

# prevents polarities used as indices that are not 0
if len(np.unique(events[:, p_index])) == 1: merge_polarities = True

if len(np.unique(events[:, p_index])) == 1:
merge_polarities = True

if merge_polarities: # the indices need to start at 0
events[:, p_index] = 0
n_channels = 1
Expand All @@ -40,35 +41,49 @@ def get_indices_values(events, sensor_size, ordering, merge_polarities):
n_channels = len(np.unique(events[:, p_index]))

max_time = int(max(events[:, t_index]) + 1)
if "y" in ordering:

if "y" in ordering:
y_index = ordering.find("y")
indices = events[:, [t_index, p_index, x_index, y_index]]
else:
indices = events[:, [t_index, p_index, x_index]]

return indices, values, max_time, n_channels


def to_sparse_tensor_pytorch(events, sensor_size, ordering, merge_polarities):
try:
import torch
except ImportError:
raise ImportError('The sparse tensor transform needs PyTorch installed. Please install a stable version ' +
'of PyTorch or alternatively install Tonic with optional PyTorch dependencies.')
indices, values, max_time, n_channels = get_indices_values(events, sensor_size, ordering, merge_polarities)
raise ImportError(
"The sparse tensor transform needs PyTorch installed. Please install a"
" stable version "
+ "of PyTorch or alternatively install Tonic with optional PyTorch"
" dependencies."
)
indices, values, max_time, n_channels = get_indices_values(
events, sensor_size, ordering, merge_polarities
)
indices = torch.LongTensor(indices).T
values = torch.LongTensor(values)
values = torch.FloatTensor(values)
return torch.sparse.FloatTensor(
indices, values, torch.Size([max_time, n_channels, *sensor_size])
)


def to_sparse_tensor_tensorflow(events, sensor_size, ordering, merge_polarities):
indices, values, max_time, n_channels = get_indices_values(events, sensor_size, ordering, merge_polarities)
indices, values, max_time, n_channels = get_indices_values(
events, sensor_size, ordering, merge_polarities
)
try:
import tensorflow
except ImportError:
raise ImportError('The sparse tensor transform needs PyTorch installed. Please install a stable version ' +
'of PyTorch or alternatively install Tonic with optional PyTorch dependencies.')
raise ImportError(
"The sparse tensor transform needs PyTorch installed. Please install a"
" stable version "
+ "of PyTorch or alternatively install Tonic with optional PyTorch"
" dependencies."
)
return tensorflow.sparse.SparseTensor(
indices, values, torch.Size([max_time, n_channels, *sensor_size])
)
8 changes: 5 additions & 3 deletions tonic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ def plot_event_grid(events, ordering, axis_array=(1, 3), plot_frame_number=False
import matplotlib.pyplot as plt
except ImportError:
raise ImportError(
"Please install the matplotlib package to plot events. This is an optional dependency."
"Please install the matplotlib package to plot events. This is an optional"
" dependency."
)

events = events.squeeze()
events = np.array(events)
transform = transforms.Compose(
[transforms.ToVoxelGrid(num_time_bins=np.product(axis_array))]
[transforms.ToVoxelGrid(n_time_bins=np.product(axis_array))]
)
x_index = ordering.find("x")
y_index = ordering.find("y")
Expand Down Expand Up @@ -71,7 +72,8 @@ def pad_tensors(batch):

if not isinstance(batch[0][0], torch.Tensor):
print(
"tonic.utils.pad_tensors expects a PyTorch Tensor of events. Please use ToSparseTensor or similar transform to convert the events."
"tonic.utils.pad_tensors expects a PyTorch Tensor of events. Please use"
" ToSparseTensor or similar transform to convert the events."
)
return None, None
max_length = max([sample.size()[0] for sample, target in batch])
Expand Down

0 comments on commit 85b6968

Please sign in to comment.