
Does TokenLearner only square inputs supported? #279

Closed
leijue222 opened this issue Apr 19, 2022 · 7 comments
@leijue222

TokenLearner has two versions, v1.0 and v1.1.

if h * h != hw:
raise ValueError('Only square inputs supported.')

The v1.1 code says only square inputs are supported.
Is the v1.0 version also restricted to square inputs? Why?
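The check quoted above can be reproduced in a few lines. This is an illustrative sketch of why a flattened `[bs, hw, c]` tensor forces the restriction: the module must recover `h` and `w` from `hw` alone, so it assumes `h == w` (the name `infer_hw` is made up here, not the repo's API):

```python
import math

def infer_hw(hw):
    """Recover (h, w) from a flattened token count by assuming h == w,
    mirroring the square-input check quoted above."""
    h = int(math.sqrt(hw))
    if h * h != hw:
        raise ValueError('Only square inputs supported.')
    return h, h

print(infer_hw(196))  # a 14x14 ViT patch grid -> (14, 14)
# infer_hw(64 * 48) raises ValueError: 3072 is not a perfect square
```

With a 4D `[bs, h, w, c]` input there is nothing to infer, which is why the 4D path has no such restriction.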

@leijue222
Author

leijue222 commented Apr 21, 2022

I want to convert your code to PyTorch, but I don't know much about JAX.
Could you give me a sample of how to use TokenLearner & TokenFuser in JAX, like this:

if __name__ == '__main__':
    input_shape = [10, 96, 64, 48]
    original_shape = [10, 96, 64, 48]

    x = torch.rand(10, 96, 64, 48)          # [10, 64, 48, 96]
    tklr = TokenLearner(input_shape=input_shape, DIM_MODEL=96, n_token=8)        
    selected = tklr(x)                      # [10, 8, 96] B, N, C

    tkfr = TokenFuser(selected.shape, original_shape) 
    tokenFuser = tkfr(selected, x)          # [10, 64, 48, 3]

Or could you add the same tensor-shape annotations that class TokenLearner has to class TokenFuser?

@mryoo

mryoo commented Apr 21, 2022

Hi, I updated model.py so that V1.1 supports inputs with non-square shapes. V1 only supports square inputs at this point.

In JAX, we follow the channels-last format. The input to the TokenLearner V1.1 module is [B, H, W, C] or simply [B, HW, C]. For 224x224 images with ViT, this typically would be [B, 14, 14, C] or [B, 196, C].
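The 14x14 / 196-token figure follows from standard ViT patching. A quick sanity check (patch size 16 is the usual ViT-Base assumption, which the thread does not state explicitly):

```python
# For a 224x224 image split into 16x16 patches (ViT-Base assumption),
# the token grid is 14x14, i.e. 196 tokens before TokenLearner.
image_size, patch_size = 224, 16
h = w = image_size // patch_size
print(h, w, h * w)  # 14 14 196
```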

@mryoo

mryoo commented Apr 21, 2022

To further clarify:

V1.1 module supports non-square 4D tensor inputs with [B, H, W, C] and any 3D tensor inputs with [B, HW, C].

V1 module also supports non-square 4D tensor inputs with [B, H, W, C]. However, for 3D tensor inputs with [B, HW, C], it expects HW to be a perfect square for now.
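A hedged sketch of how a module can accept both layouts, flattening a channels-last 4D tensor to [B, HW, C] and remembering the spatial grid; this mirrors the behavior described above but is not the repo's code (`to_tokens` is an invented helper name):

```python
import torch

def to_tokens(x):
    """Flatten channels-last [B, H, W, C] to [B, HW, C]; pass 3D inputs
    through unchanged. For 3D inputs the spatial layout is unknown."""
    if x.ndim == 4:
        b, h, w, c = x.shape
        return x.reshape(b, h * w, c), (h, w)
    return x, None

x4 = torch.rand(2, 14, 10, 96)   # non-square spatial grid is fine in 4D
tokens, hw = to_tokens(x4)
print(tokens.shape, hw)          # torch.Size([2, 140, 96]) (14, 10)
```

This is exactly why only the 3D path needs the perfect-square assumption: the 4D path carries `h` and `w` explicitly.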

@leijue222
Author

leijue222 commented Apr 22, 2022

> Hi, I updated model.py so that V1.1 supports inputs with non-square shapes. V1 only supports square inputs at this point.
>
> In JAX, we follow the channels-last format. The input to the TokenLearner V1.1 module is [B, H, W, C] or simply [B, HW, C]. For 224x224 images with ViT, this typically would be [B, 14, 14, C] or [B, 196, C].

Thanks, my purpose is to apply TokenLearner to Human Pose Estimation, where the width and height of the patch grid are not equal.

I can read the TokenLearner code even though I don't know JAX, because each step has a detailed shape annotation, like this:

selected = jnp.transpose(selected, [0, 2, 1]) # Shape: [bs, n_token, h*w].

Could you add shape annotations to the TokenFuser class as well?
It seems that this repository lacks an __init__.py file, so I can't run it and step through to see the shape at each stage.

class TokenFuser(nn.Module):

Thanks a lot!

@leijue222
Author

ImageNet, Kinetics-400, Kinetics-600, Charades, and AViD are all classification tasks.
I would like to know whether you have used TokenLearner on regression tasks
(such as image segmentation, heatmap estimation, and so on).
I would be interested to see how reducing the number of tokens affects performance on a more fine-grained regression task.

@mryoo

mryoo commented Apr 22, 2022

We will update the TokenFuser documentation soon.

We have not tried this on segmentation but we tried it on other types of regression tasks for robotics, and it worked fine in our case.

@leijue222
Author

leijue222 commented May 11, 2022

PyTorch version

import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPBlock(nn.Module):
    """Transformer MLP / feed-forward block.
    https://github.com/google-research/scenic/blob/5b5a78da05855dc8111aaaa68bd6e71c783e1422/scenic/model_lib/layers/attention_layers.py#L393
    """
    def __init__(self, mlp_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        actual_out_dim = mlp_dim if out_dim is None else out_dim
        self.layer_in = nn.Linear(mlp_dim, mlp_dim)
        self.activation_fn = F.gelu
        self.drop1 = nn.Dropout(dropout_rate)
        
        self.layer_out = nn.Linear(mlp_dim, actual_out_dim)
        self.drop2 = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        x = self.layer_in(x)
        x = self.activation_fn(x)
        x = self.drop1(x)
        x = self.layer_out(x)
        x = self.drop2(x)
        
        return x


class TokenLearner(nn.Module):
    def __init__(self, d_model, n_token=8):
        super().__init__()
        self.dropout_rate = 0.1
        self.num_tokens = n_token
        self.MlpBlock  = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.layerNorm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(-1)
        

    def forward(self, inputs):
        """Applies learnable tokenization to the 2D inputs.
        Args:
        inputs: Inputs of shape `[bs, h, w, c]`.

        Returns:
        Output of shape `[bs, n_token, c]`.
        """
        bs, h, w, c = inputs.shape
        inputs = inputs.reshape(bs, h*w, c)
        
        feature_shape = inputs.shape

        selected = inputs
        
        selected = self.layerNorm(selected)
        selected = self.MlpBlock(selected)
        selected = selected.reshape(feature_shape[0], -1, self.num_tokens) # Shape: [bs, h*w, n_token].
        selected = selected.permute(0, 2, 1)  # Shape: [bs, n_token, h*w].
        selected = self.softmax(selected)
        
        feat = inputs
        feat = torch.einsum('...si,...id->...sd', selected, feat)
        return feat


class TokenFuser(nn.Module):
    def __init__(self, d_model, num_tokens, use_normalization=True):
        super().__init__()
        self.num_tokens = num_tokens
        self.dropout_rate = 0.
        self.use_normalization = use_normalization
        self.fuser_mix_norm1 = nn.LayerNorm(d_model)
        self.layer_inputs = nn.Linear(num_tokens, num_tokens)
        self.fuser_mix_norm2 = nn.LayerNorm(d_model)
        self.original_norm = nn.LayerNorm(d_model)
        
        self.MlpBlock  = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.drop_inputs = nn.Dropout(self.dropout_rate)
        
    def forward(self, inputs, original):
        """Applies token fusion to generate the 2D outputs.
        Args:
        inputs: Inputs of shape `[bs, n_token, c]`.
        original: Inputs of shape `[bs, hw, c]` or `[bs, h, w, c]`.
        
        Returns:
        Output tensor with the shape identical to `original`.
        """
        was_4d = original.ndim == 4
        if was_4d:
            n, h, w, c = original.shape
            original = original.reshape(n, h*w, c)
        
        if self.use_normalization:
            inputs = self.fuser_mix_norm1(inputs)
        
        inputs = inputs.permute(0, 2, 1)  # Shape: [bs, c, n_token].
        inputs = self.layer_inputs(inputs)
        inputs = inputs.permute(0, 2, 1)  # Shape: [bs, n_token, c].
        
        if self.use_normalization:
            inputs = self.fuser_mix_norm2(inputs)

        original = self.original_norm(original)
        mix = self.MlpBlock(original)     # Shape: [bs, h*w, n_token].
        mix = self.sigmoid(mix)
        
        inputs = torch.einsum('...sc,...hs->...hc', inputs, mix)    # Shape: [bs, h*w, c].
        inputs = self.drop_inputs(inputs)
        
        if was_4d:
            # Restore the original [bs, h, w, c] layout; 3D inputs stay [bs, hw, c].
            inputs = inputs.reshape(n, h, w, -1)
            
        return inputs
    
if __name__ == '__main__':    
    x = torch.rand(10, 64, 48, 96)          # [bs, H, W, C]
    tklr = TokenLearner(d_model=96, n_token=8)        
    tklr_res = tklr(x)                  # torch.Size([10, 8, 96]) B, N, C
    print('tklr_res shape: ', tklr_res.shape)
    
    tkfr = TokenFuser(96, 8) 
    tkfr_res = tkfr(tklr_res, x)      # torch.Size([10, 64, 48, 96])
    print('tkfr_res shape: ', tkfr_res.shape)
