Support operations on single dimension #1

Merged · 3 commits · Aug 26, 2020

Conversation

@davidnvq (Owner) commented Aug 26, 2020

1. Redimension Functionality

The function redim is added. It re-dimensions elements along one specific dimension, e.g., for chunking or reordering.
The operation is described by the pattern inside the brackets.

Note: only a single bracketed pattern is supported so far, and there must be a single group/element on that dimension.

def redim(tensor, pattern, **axes_lengths: int):
    """Args:
        tensor: (torch.Tensor or np.ndarray) input tensor
        pattern: (str) pattern to redimension, e.g., "B H W [(r g b) -> (b g r)]"
        **axes_lengths: optional lengths of axes, e.g., B=10, H=300, W=600, b=1, g=1, r=1
    Returns:
        a re-dimensioned tensor
    """

Example: Reordering

>>> image = np.random.randn(30, 40, 3)  # RGB
# Change the channel order from RGB to BGR.
# Specifying the lengths of the other axes is optional; they are only checked via `assert`.
# When an element's length is 1, or can be inferred from the context, it does not need to be specified either.

>>> image = redim(image, "height width [(r g b) -> (b g r)]", height=30, width=40, r=1, g=1, b=1)
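
For comparison, the channel reordering above corresponds to plain NumPy indexing; the sketch below only illustrates the intended result and is not part of the implementation:

>>> import numpy as np
>>> image = np.random.randn(30, 40, 3)   # RGB
>>> bgr = image[:, :, ::-1]              # reverse the channel axis: RGB -> BGR
>>> bgr.shape
(30, 40, 3)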

Example: Chunking

# Split the dataset into train and validation sets
>>> train_set = redim(dataset, "[(train valid) -> train] H W", train=800, valid=200)
>>> valid_set = redim(dataset, "[(train valid) -> valid] H W", train=800, valid=200)

# Remove the alpha channel
>>> image = np.random.randn(30, 40, 4)  # RGBA
>>> image = redim(image, "H W [(rgb a) -> rgb]", rgb=3)  # or, equivalently:
>>> image = redim(image, "H W [(r g b a) -> (r g b)]")

# Crop the image
>>> image = redim(image, "[(top down) -> top] W", top=20)
>>> image = redim(image, "H [left right]", left=10)
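
As a reading aid, the chunking and cropping patterns above are meant to behave like ordinary slicing; a rough NumPy equivalent (my interpretation of the patterns, not part of the implementation) would be:

>>> dataset = np.random.randn(1000, 30, 40)     # hypothetical dataset of 1000 images
>>> train_set = dataset[:800]                   # "[(train valid) -> train] H W" with train=800
>>> valid_set = dataset[800:]                   # "[(train valid) -> valid] H W" with valid=200
>>> rgb = np.random.randn(30, 40, 4)[:, :, :3]  # "H W [(rgb a) -> rgb]" with rgb=3
>>> top = rgb[:20]                              # "[(top down) -> top] W" with top=20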

2. Concatenation Functionality

The function concat is added. It concatenates a list of tensors along one axis.

Note that:

  1. Except for the concatenated axis, the lengths of all other axes must match.
  2. Along the concatenated axis, the tensors may have different lengths.
  3. We don't need to specify all the dimension lengths.
def concat(tensor_list, pattern, **axes_lengths: int):
    """Args:
        tensor_list: (List[torch.Tensor] or List[np.ndarray]) tensors with equal lengths on all axes except the concatenated one
        pattern: (str) pattern to concatenate along, e.g., "batch seq [dx dy dz -> (dx dy dz)]"
        **axes_lengths: optional lengths of axes, e.g., B=10, H=300, W=600
    Returns:
        a concatenated tensor
    """

Example: concatenate

>>> x = torch.randn(2, 10, 512)
>>> y = torch.randn(2, 10, 128)
>>> z = torch.randn(2, 10, 256)
>>> h = concat([x, y, z], "batch seq [dx dy dz -> (dx dy dz)]", batch=2, seq=10, dx=512, dy=128, dz=256)
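
Assuming the bracketed pattern above selects the last (feature) axis, the result should match a plain torch.cat along that axis; the check below is only a sanity sketch, not part of the PR:

>>> import torch
>>> expected = torch.cat([x, y, z], dim=-1)   # 512 + 128 + 256 = 896 features
>>> expected.shape
torch.Size([2, 10, 896])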

Other note:
An ellipsis can be used when we don't want to list all the dimension names along the concatenated axis.

Example: concatenate with ellipsis
>>> h = concat([x, y, z], "batch seq [... -> ...]")
>>> h = concat([x, y, z], "batch seq [... -> d]")
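
Assuming both ellipsis forms target the same last axis as the explicit pattern, the result shape should match the explicit example above:

>>> h.shape
torch.Size([2, 10, 896])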

This pull request is based on arogozhnikov#56, arogozhnikov#50, and arogozhnikov#20.
