
Commit

WIP
Eric Cox committed Jan 15, 2024
1 parent 3e7a40f commit 5afcdf0
Showing 6 changed files with 127 additions and 1 deletion.
7 changes: 7 additions & 0 deletions nobuco/node_converters/activation.py
@@ -15,6 +15,13 @@
from nobuco.converters.tensor import dim_pytorch2keras


@converter(torch.nn.functional.logsigmoid, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def logsigmoid(input: torch.Tensor):
    def func(input):
        return tf.math.log_sigmoid(input)
    return func
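For reference, torch.nn.functional.logsigmoid and tf.math.log_sigmoid both compute elementwise log(sigmoid(x)), which is what the converter above relies on; a minimal standalone sanity-check sketch (illustrative only):

import numpy as np
import torch
import tensorflow as tf

x = np.random.randn(2, 3).astype(np.float32)
out_pt = torch.nn.functional.logsigmoid(torch.from_numpy(x)).numpy()
out_tf = tf.math.log_sigmoid(tf.constant(x)).numpy()
assert np.allclose(out_pt, out_tf, atol=1e-6)  # same elementwise log(sigmoid(x))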


def hard_sigmoid_pytorch_compatible(x):
    x = tf.clip_by_value(x/6 + 1/2, clip_value_min=0, clip_value_max=1)
    return x
5 changes: 5 additions & 0 deletions nobuco/node_converters/boolean.py
@@ -20,6 +20,11 @@ def func(input, other):
        return tf.logical_or(input, other)
    return func

@converter(torch.Tensor.__ior__, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_t_ior(self, other):
    def func(self, other):
        return tf.logical_or(self, other)
    return func
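As with the __or__ converter above it, this one covers the in-place form on boolean masks; a small illustration of the PyTorch pattern it targets (for reference only):

import torch

mask = torch.tensor([True, False, False])
mask |= torch.tensor([False, False, True])   # dispatches to torch.Tensor.__ior__
# mask is now tensor([ True, False,  True]), i.e. an elementwise logical OR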

@converter(torch.Tensor.__invert__, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_invert(input: Tensor):
13 changes: 13 additions & 0 deletions nobuco/node_converters/boolean_mask.py
@@ -33,3 +33,16 @@ def func(condition, input=None, other=None):
        else:
            return tf.where(condition)[..., 0]
    return func

# @converter(torch.where, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
# def converter_where(condition, x=None, y=None):
#     def func(condition, x=None, y=None):
#         if x is not None and y is not None:
#             # Element-wise selection: torch.where(condition, x, y)
#             return tf.where(condition, x, y)
#         else:
#             # Condition-based indexing: torch.where(condition)
#             # Adjusting the shape to match PyTorch's output (tuple of 1D tensors)
#             indices = tf.where(condition)
#             return [indices[:, i] for i in range(indices.shape[1])]
#     return func
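For context, these are the two torch.where call forms the commented-out converter distinguishes: element-wise selection versus condition-only indexing, which returns one 1-D index tensor per dimension (illustration only):

import torch

cond = torch.tensor([[True, False],
                     [False, True]])
x = torch.ones(2, 2)
y = torch.zeros(2, 2)

torch.where(cond, x, y)         # element-wise: tensor([[1., 0.], [0., 1.]])
rows, cols = torch.where(cond)  # condition-only: rows=tensor([0, 1]), cols=tensor([0, 1])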
6 changes: 6 additions & 0 deletions nobuco/node_converters/math.py
@@ -36,6 +36,12 @@ def func(input, other, *args, **kwargs):
        return input + other
    return func

@converter(torch.neg, torch.Tensor.neg, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
def converter_t_neg(self):
    def func(self):
        return tf.negative(self)

    return func

@converter(torch.sub, torch.subtract, torch.Tensor.sub, torch.Tensor.sub_, torch.Tensor.__sub__, torch.Tensor.__isub__, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS_OR_PYTORCH, autocast=True)
def converter_sub(input: Union[Tensor, Number], other: Union[Tensor, Number], *, alpha: Optional[Number]=1, out: Optional[Tensor]=None):
9 changes: 9 additions & 0 deletions nobuco/node_converters/tensor_creation.py
@@ -234,3 +234,12 @@ def converter_linspace(start, end, steps: _int, *, out: Optional[Tensor] = None,
    def func(start, end, steps, *, out = None, dtype = None, layout = None, device = None, pin_memory = False, requires_grad = False):
        return tf.linspace(start, end, steps)
    return func

# @converter(torch.Tensor.new_tensor, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
# def converter_new_tensor(tensor, data, dtype=None, device=None, requires_grad=False):
#     def func(tensor, data, dtype=None, device=None, requires_grad=False):
#         # Create a TensorFlow constant with the same dtype as the input tensor
#         # TensorFlow can handle Python scalars, so no need to check explicitly
#         return tf.constant(data, dtype=tensor.dtype if dtype is None else dtype)

#     return func
88 changes: 87 additions & 1 deletion nobuco/node_converters/tensor_manipulation.py
@@ -258,6 +258,25 @@ def func(input: Tensor, dim, sizes):
        return tf.reshape(input, (*start_shape, *sizes, *end_shape))
    return func

@converter(torch.Tensor.unflatten, torch.unflatten, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
def converter_t_unflatten(self, dim, sizes):
    def func(self, dim, sizes):
        n_dims = len(self.shape)
        dim = _dim_make_positive(dim, n_dims)

        # Shapes before and after the dimension to unflatten
        shape_before = self.shape[:dim]
        shape_after = self.shape[dim + 1:]

        # Expand the specified dimension into multiple dimensions
        unflatten_shape = sizes if isinstance(sizes, tuple) else (sizes,)

        # The new shape combines the unchanged dimensions with the expanded ones
        new_shape = (*shape_before, *unflatten_shape, *shape_after)

        return tf.reshape(self, new_shape)

    return func
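The PyTorch behaviour the reshape above reproduces, as a standalone illustration:

import torch

x = torch.arange(24).reshape(2, 12)
y = x.unflatten(1, (3, 4))       # split dim 1 (size 12) into (3, 4)
assert y.shape == (2, 3, 4)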

@converter(torch.Tensor.narrow, channel_ordering_strategy=ChannelOrderingStrategy.MANUAL)
def converter_narrow(self, dimension, start, length):
@@ -328,7 +347,7 @@ def torch_gather(x, indices, gather_axis):
    return reshaped


-@converter(torch.gather, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
+@converter(torch.gather, torch.Tensor.gather, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
def converter_gather(input: Tensor, dim, index: Tensor, *, sparse_grad: _bool=False, out: Optional[Tensor]=None):
    n_dims = input.dim()

@@ -337,3 +356,70 @@ def func(input, dim, index, *, sparse_grad=False, out=None):
            dim = dim_pytorch2keras(dim, n_dims)
        return torch_gather(input, index, dim)
    return func
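With torch.Tensor.gather added to the registration above, the method form and the functional form should now route to the same converter; in PyTorch the two spellings are equivalent (illustration only):

import torch

x = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])
assert torch.equal(x.gather(1, idx), torch.gather(x, 1, idx))  # tensor([[1, 1], [4, 3]])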


@converter(torch.Tensor.repeat_interleave, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
def converter_t_repeat_interleave(self, repeats, dim=None):
    def func(self, repeats, dim=None):
        # Handling the dimension argument
        if dim is None:
            flat_input = tf.reshape(self, [-1])
            return tf.repeat(flat_input, repeats)

        n_dims = len(self.shape)
        dim = _dim_make_positive(dim, n_dims)
        shape_before = self.shape[:dim]
        shape_after = self.shape[dim + 1:]

        # Expanding the tensor before repeating along the specified dimension
        expanded_tensor = tf.reshape(self, (*shape_before, -1, *shape_after))
        repeated_tensor = tf.repeat(expanded_tensor, repeats, axis=dim)

        return repeated_tensor

    return func
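The PyTorch semantics being mirrored, for both branches of the converter (dim=None flattens first; with a dim, elements repeat along that axis, which tf.repeat with an axis also does):

import torch

x = torch.tensor([[1, 2], [3, 4]])
x.repeat_interleave(2)           # dim=None: tensor([1, 1, 2, 2, 3, 3, 4, 4])
x.repeat_interleave(2, dim=0)    # tensor([[1, 2], [1, 2], [3, 4], [3, 4]])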


# @converter(torch.Tensor.gather, channel_ordering_strategy=ChannelOrderingStrategy.MINIMUM_TRANSPOSITIONS)
# def converter_t_gather(self, dim, index):
#     def func(self, dim, index):
#         # Convert PyTorch dimension to TensorFlow dimension if necessary
#         if get_channel_order(self) == ChannelOrder.TENSORFLOW:
#             dim = dim_pytorch2keras(dim, tf.rank(self))

#         return tf.gather(self, index, axis=dim)
#     return func

# @converter(torch.Tensor.index_select, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
# def converter_t_index_select(input, dim, index):
#     def func(input, dim, index):
#         # TensorFlow's advanced indexing can be used directly
#         # First, handle the case where `dim` might be negative
#         dim = dim if dim >= 0 else len(input.shape) + dim

#         # TensorFlow's advanced indexing works differently from PyTorch's index_select
#         # We need to create a meshgrid and use it for gathering
#         indices_shape = [-1 if i == dim else 1 for i in range(len(input.shape))]
#         index = tf.reshape(index, indices_shape)
#         mesh = tf.meshgrid(*[tf.range(d) for d in input.shape], indexing='ij')
#         mesh[dim] = index
#         return tf.gather_nd(input, tf.stack(mesh, -1))

#     return func

@converter(torch.Tensor.index_select, channel_ordering_strategy=ChannelOrderingStrategy.FORCE_PYTORCH_ORDER)
def converter_t_index_select(input, dim, index):
    def func(input, dim, index):
        # Handle the case where `dim` might be negative
        dim = dim if dim >= 0 else len(input.shape) + dim

        # TensorFlow's advanced indexing works differently from PyTorch's index_select:
        # build coordinate grids over the output shape, keeping the full range on every
        # axis except `dim`, which enumerates the selected indices, then gather_nd.
        # (Replacing only mesh[dim] after a full-shape meshgrid would give mismatched
        # shapes in tf.stack.)
        ranges = [tf.range(d) for d in input.shape]
        ranges[dim] = tf.cast(index, tf.int32)
        mesh = tf.meshgrid(*ranges, indexing='ij')
        return tf.gather_nd(input, tf.stack(mesh, -1))

    return func
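For reference, the semantics being reproduced; the same result can also be obtained more directly with tf.gather along an axis, which may be worth considering as an alternative (illustration only):

import numpy as np
import torch
import tensorflow as tf

x = np.arange(12, dtype=np.float32).reshape(3, 4)
idx = np.array([2, 0])

out_pt = torch.index_select(torch.from_numpy(x), 0, torch.from_numpy(idx)).numpy()
out_tf = tf.gather(tf.constant(x), tf.constant(idx), axis=0).numpy()
assert np.array_equal(out_pt, out_tf)  # rows 2 and 0, in that order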
