In [1]:
import torch
import torch.nn as nn

In [7]:
# Define the DownSample2x2BlockFix class
class DownSample2x2BlockFix(nn.Module):
    """Token downsampler that folds each 2x2 grid neighbourhood into the channel axis.

    The actual merging is delegated to the module-level ``flat_square_2x2``
    helper; this wrapper only converts between the flat token sequence and
    the 4-D grid layout that helper works on.
    """

    def forward(self, x):
        """Downsample a sequence of patch embeddings.

        Args:
            x (Tensor): Expected to be ``(batch, tokens, dim)``; ``tokens`` is
                interpreted as a square grid of side ``int(tokens ** 0.5)``.

        Returns:
            Tensor: 3-D tensor whose token axis enumerates the merged grid
            cells and whose last axis holds the merged channels.
        """
        print(f"forward {x.size()=}")
        side = int(x.shape[1] ** 0.5)
        # Unflatten the token axis into a (side, side) grid; -1 absorbs the channels.
        grid = x.reshape(x.shape[0], side, side, -1)
        merged = flat_square_2x2(grid)
        # Collapse both grid axes back into a single token axis.
        return merged.reshape(merged.shape[0], -1, merged.shape[-1])

# Define the flat_square_2x2 function
def flat_square_2x2(x):
    """Merge every 2x2 neighbourhood of a 4-D grid tensor into the channel axis.

    An odd-sized grid axis is first extended by one trailing slice of zeros so
    that both grid axes are even.

    Args:
        x (Tensor): 4-D tensor whose sizes are unpacked as ``(n, w, h, c)``.

    Returns:
        Tensor: Contiguous tensor of shape
        ``(n, ceil(w / 2), ceil(h / 2), 4 * c)`` with the dtype and device of ``x``.
    """
    print(f"flat_square_2x2 {x.size()=}")
    n, w, h, c = x.size()
    if w % 2 == 1:
        # Allocate the padding directly with x's dtype and device instead of
        # creating it on the default device and copying it over with `.to()`.
        pad = torch.zeros((n, 1, h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, pad], dim=1)
        w += 1
    if h % 2 == 1:
        pad = torch.zeros((n, w, 1, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, pad], dim=2)
        h += 1
    # `view` below requires a contiguous layout.
    x = x.contiguous()
    # Fold pairs along axis 2 into channels: (n, w, h, c) -> (n, w, h/2, 2c)
    x = x.view(n, w, h // 2, c * 2)
    # Swap the grid axes so the remaining axis to fold sits next to the channels.
    x = x.permute(0, 2, 1, 3).contiguous()
    # Fold pairs along the other grid axis: (n, h/2, w, 2c) -> (n, h/2, w/2, 4c)
    x = x.view(n, h // 2, w // 2, c * 4)
    # Restore the original axis order: (n, w/2, h/2, 4c)
    x = x.permute(0, 2, 1, 3).contiguous()
    return x

# Toy dimensions; the patch count is a perfect square so it folds into a square grid
batch_size, num_patches, embedding_dim = 4, 16, 64

# Random stand-in for a batch of ViT patch embeddings
vit_embeds = torch.randn(batch_size, num_patches, embedding_dim)

# Build the downsampling block and run the batch through it
downsample_block = DownSample2x2BlockFix()
downsampled_embeds = downsample_block(vit_embeds)

# Compare the shapes before and after downsampling
print("Original shape:", vit_embeds.shape)
print("Downsampled shape:", downsampled_embeds.shape)

forward x.size()=torch.Size([4, 16, 64])
flat_square_2x2 x.size()=torch.Size([4, 4, 4, 64])
Original shape: torch.Size([4, 16, 64])
Downsampled shape: torch.Size([4, 4, 256])


In [13]:
# Scratch arithmetic: 16 multiplied by the square of int(1 / 0.5), i.e. 16 * 2 ** 2
16 * int(1/0.5)**2

64

In [12]:
class ImageProcessor:
    """Holds image geometry and derives the number of image tokens from it.

    Attributes set in ``__init__``:
        image_size, patch_size, config: stored exactly as passed in.
        num_image_token: ``(image_size // patch_size) ** 2`` scaled by
            ``config.downsample_ratio ** 2`` and truncated to ``int``.
    """

    def __init__(self, image_size, patch_size, config):
        self.image_size = image_size
        self.patch_size = patch_size
        self.config = config

        patches_per_side = image_size // patch_size
        ratio_squared = config.downsample_ratio ** 2
        # The ratio is squared because it applies to each side of the patch grid.
        self.num_image_token = int(patches_per_side ** 2 * ratio_squared)

# Example configuration class
class Config:
    """Minimal configuration object exposing a single ``downsample_ratio`` attribute."""

    def __init__(self, downsample_ratio):
        # Stored untouched; no validation or conversion is applied.
        self.downsample_ratio = downsample_ratio


# Toy geometry: a 256x256 image cut into 16x16 patches, with a downsample ratio of 0.5
image_size, patch_size = 256, 16
downsample_ratio = 0.5

print("Num patches:", (image_size / patch_size) ** 2)

# Wrap the ratio in a config object and let the processor derive the token count
config = Config(downsample_ratio)
processor = ImageProcessor(image_size, patch_size, config)

print("Number of image tokens:", processor.num_image_token)

Num patches: 256.0
Number of image tokens: 64


# InternVL - pixel_shuffle

In [None]:
def flat_square(x, kernel_size):
    """
    Folds ``kernel_size x kernel_size`` neighbourhoods of a grid tensor into channels.

    No padding is applied here, so both grid axes are expected to be multiples
    of ``kernel_size``.

    Args:
        x (Tensor): 4-D input tensor whose sizes are unpacked as (n, w, h, c).
        kernel_size (int): Side length of the neighbourhood merged into channels.

    Returns:
        Tensor: Contiguous tensor of shape
        (n, w / kernel_size, h / kernel_size, c * kernel_size ** 2).
    """
    batch, width, height, channels = x.size()
    out_h = int(height / kernel_size)
    out_w = int(width / kernel_size)

    # First fold: groups of `kernel_size` entries along axis 2 move into channels.
    folded = x.view(batch, width, out_h, int(channels * kernel_size))
    folded = folded.permute(0, 2, 1, 3).contiguous()

    # Second fold: same along the other grid axis, then restore the axis order.
    folded = folded.view(batch, out_h, out_w, int(channels * kernel_size ** 2))
    return folded.permute(0, 2, 1, 3).contiguous()

def pixel_shuffle(x, scale_factor=0.5):
    """Trade spatial resolution for channel width on a 4-D grid tensor.

    Args:
        x (Tensor): 4-D tensor whose sizes are unpacked as (N, W, H, C).
        scale_factor (float): Factor applied to each grid axis; the channel
            axis is divided by ``scale_factor ** 2``. The default of 0.5
            halves both grid axes and quadruples the channels.

    Returns:
        Tensor: Contiguous tensor of shape
        (N, W * scale, H * scale, C / scale ** 2).
    """
    batch, width, height, channels = x.size()
    scaled_h = int(height * scale_factor)
    scaled_w = int(width * scale_factor)

    # (N, W, H, C) -> (N, W, H * scale, C / scale)
    shuffled = x.view(batch, width, scaled_h, int(channels / scale_factor))
    # -> (N, H * scale, W, C / scale)
    shuffled = shuffled.permute(0, 2, 1, 3).contiguous()
    # -> (N, H * scale, W * scale, C / scale ** 2)
    shuffled = shuffled.view(
        batch, scaled_h, scaled_w, int(channels / (scale_factor * scale_factor))
    )
    # -> (N, W * scale, H * scale, C / scale ** 2)
    return shuffled.permute(0, 2, 1, 3).contiguous()

def extract_feature(pixel_values):
    """Turn ``pixel_values`` into projected visual tokens.

    Pipeline: vision model -> drop the first token -> reshape the tokens into a
    square grid -> pixel shuffle -> flatten back to a sequence -> ``mlp1``.

    NOTE(review): every step goes through ``self`` (``self.vision_model``,
    ``self.select_layer``, ``self.pixel_shuffle``, ``self.downsample_ratio``,
    ``self.mlp1``), but ``self`` is not a parameter of this function. As
    written at module level, calling it raises ``NameError`` unless a global
    named ``self`` exists. This looks like a method copied out of a model
    class -- confirm, and restore it as a method of that class.

    Args:
        pixel_values: Passed to ``self.vision_model`` as ``pixel_values``.
            Presumably an image batch -- TODO confirm the expected shape.

    Returns:
        The output of ``self.mlp1`` applied to the downsampled token sequence.
    """
    if self.select_layer == -1:
        # -1 means "use the final hidden state", so the per-layer states are not requested.
        vit_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=False,
            return_dict=True).last_hidden_state
    else:
        # Any other value indexes into the list of per-layer hidden states.
        vit_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
            return_dict=True).hidden_states[self.select_layer]
    # Drop the first entry along the token axis (presumably the class token -- confirm).
    vit_embeds = vit_embeds[:, 1:, :]

    # Treat the remaining tokens as a square grid of side int(sqrt(num_tokens)).
    h = w = int(vit_embeds.shape[1] ** 0.5)
    vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
    vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
    # Flatten the grid axes back into one token axis before the projection.
    vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
    vit_embeds = self.mlp1(vit_embeds)
    return vit_embeds

In [None]:
# Create an instance of the class that contains the pixel_shuffle and extract_feature methods
# NOTE(review): `YourModel` is a placeholder -- it is not defined anywhere in the visible
# notebook, and this cell has no execution count, so it does not appear to have been run.
model = YourModel()

# Example usage of pixel_shuffle
input_tensor = torch.randn(1, 16, 16, 3)  # batch_size, height, width, channels
output_tensor = model.pixel_shuffle(input_tensor, scale_factor=0.5)
print(output_tensor.shape)

# Example usage of extract_feature
# NOTE(review): the result depends entirely on the vision model wrapped by `model`,
# which is not shown here -- confirm the expected input resolution.
input_tensor = torch.randn(1, 3, 224, 224)  # batch_size, channels, height, width
output_tensor = model.extract_feature(input_tensor)
print(output_tensor.shape)

# NVILA - flat square

In [36]:
class DownSample2x2BlockFix(nn.Module):
    """Fixed 2x2 token downsampler built on the module-level ``flat_square_2x2``."""

    def forward(self, x):
        """Fold 2x2 grid neighbourhoods of ``x`` into channels.

        Args:
            x (Tensor): Expected to be ``(batch, tokens, dim)``; the token axis
                is read as a square grid of side ``int(tokens ** 0.5)``.

        Returns:
            Tensor: 3-D tensor with the merged grid cells as the token axis.
        """
        grid_side = int(x.shape[1] ** 0.5)
        # Token sequence -> (batch, side, side, channels)
        as_grid = x.reshape(x.shape[0], grid_side, grid_side, -1)
        pooled = flat_square_2x2(as_grid)
        # Grid -> token sequence again
        batch, channels = pooled.shape[0], pooled.shape[-1]
        return pooled.reshape(batch, -1, channels)


def flat_square_2x2(x):
    """Merge every 2x2 neighbourhood of a 4-D grid tensor into the channel axis.

    An odd-sized grid axis is first extended by one trailing slice of zeros so
    that both grid axes are even.

    Args:
        x (Tensor): 4-D tensor whose sizes are unpacked as ``(n, w, h, c)``.

    Returns:
        Tensor: Contiguous tensor of shape
        ``(n, ceil(w / 2), ceil(h / 2), 4 * c)`` with the dtype and device of ``x``.
    """
    n, w, h, c = x.size()
    if w % 2 == 1:
        # Create the padding with x's dtype and device in one step; the previous
        # form built it on the default device and then copied it with `.to()`.
        zeros_w = torch.zeros((n, 1, h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_w], dim=1)
        w += 1
    if h % 2 == 1:
        zeros_h = torch.zeros((n, w, 1, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_h], dim=2)
        h += 1
    # `view` below requires a contiguous layout.
    x = x.contiguous()
    # (n, w, h, c) -> (n, w, h/2, 2c): fold pairs along axis 2 into channels
    x = x.view(n, w, h // 2, c * 2)
    x = x.permute(0, 2, 1, 3).contiguous()
    # (n, h/2, w, 2c) -> (n, h/2, w/2, 4c): fold pairs along the other grid axis
    x = x.view(n, h // 2, w // 2, c * 4)
    # Back to the original axis order: (n, w/2, h/2, 4c)
    x = x.permute(0, 2, 1, 3).contiguous()
    return x

def flat_square_downsample(x: torch.Tensor, down_sample_ratio: int):
    """
    Folds ``down_sample_ratio x down_sample_ratio`` neighbourhoods of a square
    token grid into the channel axis.

    NOTE: this name is defined a second time further down in the same cell;
    the later definition is the one left bound after the cell runs.

    Args:
        x (Tensor): Tensor of shape (batch, tokens, dim). ``tokens`` must be a
            perfect square; it is interpreted as a ``side x side`` grid.
        down_sample_ratio (int): Side length of the neighbourhood merged into
            channels. Grid axes that are not a multiple of it are zero-padded
            at the end.
    Returns:
        Tensor: Tensor of shape
        (batch, ceil(side / ratio) ** 2, dim * ratio ** 2).
    Raises:
        ValueError: If ``tokens`` is not a perfect square. Without this check
            the ``-1`` in the reshape could absorb the leftover tokens into the
            channel axis and silently scramble the grid.
    """
    seq_len = x.shape[1]
    side = round(seq_len ** 0.5)
    if side * side != seq_len:
        raise ValueError(
            f"flat_square_downsample expects a square number of tokens, got {seq_len}"
        )
    x = x.reshape(x.shape[0], side, side, -1)

    n, w, h, c = x.size()
    # Pad width and height up to the next multiple of the ratio if necessary.
    pad_w = -w % down_sample_ratio
    if pad_w:
        # Allocate directly with x's dtype/device (no default-device tensor + copy).
        zeros_w = torch.zeros((n, pad_w, h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_w], dim=1)
        w += pad_w
    pad_h = -h % down_sample_ratio
    if pad_h:
        zeros_h = torch.zeros((n, w, pad_h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_h], dim=2)
        h += pad_h
    # `view` needs a contiguous layout; the reshape above may return a
    # non-contiguous view when the caller passes a non-contiguous tensor.
    x = x.contiguous()

    # Reshape and rearrange
    x = x.view(n, w, h // down_sample_ratio, c * down_sample_ratio)
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, h // down_sample_ratio, w // down_sample_ratio, c * down_sample_ratio ** 2)
    x = x.permute(0, 2, 1, 3).contiguous()

    return x.reshape(x.shape[0], -1, x.shape[-1])

class DownSample3x3BlockFix(nn.Module):
    """Fixed 3x3 token downsampler built on the module-level ``flat_square_3x3``."""

    def forward(self, x):
        """Fold 3x3 grid neighbourhoods of ``x`` into channels.

        Args:
            x (Tensor): Expected to be ``(batch, tokens, dim)``; the token axis
                is read as a square grid of side ``int(tokens ** 0.5)``.

        Returns:
            Tensor: 3-D tensor with the merged grid cells as the token axis.
        """
        batch = x.shape[0]
        edge = int(x.shape[1] ** 0.5)
        # Token sequence -> (batch, edge, edge, channels)
        tiles = flat_square_3x3(x.reshape(batch, edge, edge, -1))
        # Merged grid -> token sequence
        return tiles.reshape(tiles.shape[0], -1, tiles.shape[-1])


def flat_square_3x3(x):
    """Merge every 3x3 neighbourhood of a 4-D grid tensor into the channel axis.

    Grid axes that are not a multiple of 3 are first extended with trailing
    zeros up to the next multiple.

    Args:
        x (Tensor): 4-D tensor whose sizes are unpacked as ``(n, w, h, c)``.

    Returns:
        Tensor: Contiguous tensor of shape
        ``(n, ceil(w / 3), ceil(h / 3), 9 * c)`` with the dtype and device of ``x``.
    """
    n, w, h, c = x.size()
    pad_w = -w % 3
    if pad_w:
        # Allocate the padding with x's dtype and device in one step instead of
        # building it on the default device and copying it with `.to()`.
        zeros_w = torch.zeros((n, pad_w, h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_w], dim=1)
        w += pad_w
    pad_h = -h % 3
    if pad_h:
        zeros_h = torch.zeros((n, w, pad_h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_h], dim=2)
        h += pad_h
    # `view` below requires a contiguous layout.
    x = x.contiguous()
    # (n, w, h, c) -> (n, w, h/3, 3c): fold triples along axis 2 into channels
    x = x.view(n, w, h // 3, c * 3)
    x = x.permute(0, 2, 1, 3).contiguous()
    # (n, h/3, w, 3c) -> (n, h/3, w/3, 9c): fold triples along the other grid axis
    x = x.view(n, h // 3, w // 3, c * 9)
    # Back to the original axis order: (n, w/3, h/3, 9c)
    x = x.permute(0, 2, 1, 3).contiguous()
    return x

def flat_square(x, kernel_size):
    """
    Folds ``kernel_size x kernel_size`` neighbourhoods of a grid tensor into channels.

    Grid axes that are not a multiple of ``kernel_size`` are zero-padded at
    the end up to the next multiple before folding.

    Args:
        x (Tensor): 4-D input tensor whose sizes are unpacked as (n, w, h, c).
            It does not need to be contiguous.
        kernel_size (int): Side length of the neighbourhood merged into channels.

    Returns:
        Tensor: Contiguous tensor of shape
        (n, ceil(w / kernel_size), ceil(h / kernel_size), c * kernel_size ** 2)
        with the dtype and device of ``x``.
    """
    n, w, h, c = x.size()

    # Pad width and height if necessary
    pad_w = -w % kernel_size
    if pad_w:
        # Allocate directly with x's dtype/device (no default-device tensor + copy).
        zeros_w = torch.zeros((n, pad_w, h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_w], dim=1)
        w += pad_w

    pad_h = -h % kernel_size
    if pad_h:
        zeros_h = torch.zeros((n, w, pad_h, c), dtype=x.dtype, device=x.device)
        x = torch.concat([x, zeros_h], dim=2)
        h += pad_h

    # `view` requires a contiguous layout; previously this was only guaranteed
    # when one of the padding branches happened to run.
    x = x.contiguous()

    # Reshape and rearrange
    x = x.view(n, w, h // kernel_size, c * kernel_size)
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, h // kernel_size, w // kernel_size, c * kernel_size ** 2)
    x = x.permute(0, 2, 1, 3).contiguous()

    return x

class DownSampleBlock(nn.Module):
    """Parameter-free module that folds ``k x k`` token neighbourhoods into channels.

    ``k`` is the ``down_sample_ratio`` given at construction time, stored as
    ``kernel_size``.
    """

    def __init__(self, down_sample_ratio: int):
        # nn.Module has to be initialised before the instance can be called:
        # without it, `module(x)` fails because the module's internal
        # registries (hooks, parameters, ...) were never created.
        super().__init__()
        self.kernel_size = down_sample_ratio

    def _flat_square(self, x):
        """
        Reshapes and rearranges the input tensor to prepare it for downsampling.

        Grid axes that are not a multiple of ``self.kernel_size`` are
        zero-padded at the end up to the next multiple.

        Args:
            x (Tensor): 4-D input tensor whose sizes are unpacked as (n, w, h, c).
        Returns:
            Tensor: Contiguous tensor of shape
            (n, ceil(w / k), ceil(h / k), c * k ** 2), with k = ``self.kernel_size``.
        """
        k = self.kernel_size
        n, w, h, c = x.size()

        # Pad width and height if necessary
        pad_w = -w % k
        if pad_w:
            # Allocate directly with x's dtype/device (no default-device tensor + copy).
            zeros_w = torch.zeros((n, pad_w, h, c), dtype=x.dtype, device=x.device)
            x = torch.concat([x, zeros_w], dim=1)
            w += pad_w

        pad_h = -h % k
        if pad_h:
            zeros_h = torch.zeros((n, w, pad_h, c), dtype=x.dtype, device=x.device)
            x = torch.concat([x, zeros_h], dim=2)
            h += pad_h

        # `view` requires a contiguous layout, which the caller's tensor may not have.
        x = x.contiguous()

        # Reshape and rearrange
        x = x.view(n, w, h // k, c * k)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(n, h // k, w // k, c * k ** 2)
        x = x.permute(0, 2, 1, 3).contiguous()

        return x

    def forward(self, x):
        """Downsample a ``(batch, tokens, dim)`` sequence of patch embeddings.

        Args:
            x (Tensor): Tensor of shape (batch, tokens, dim); ``tokens`` must
                be a perfect square and is read as a ``side x side`` grid.

        Returns:
            Tensor: Tensor of shape (batch, ceil(side / k) ** 2, dim * k ** 2).

        Raises:
            ValueError: If ``tokens`` is not a perfect square. Without this
                check the ``-1`` in the reshape could absorb leftover tokens
                into the channel axis and silently scramble the grid.
        """
        vit_embeds = x
        seq_len = vit_embeds.shape[1]
        h = w = round(seq_len ** 0.5)
        if h * w != seq_len:
            raise ValueError(
                f"DownSampleBlock expects a square number of tokens, got {seq_len}"
            )
        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)

        vit_embeds = self._flat_square(vit_embeds)

        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
        return vit_embeds
    

def flat_square_downsample(x: torch.Tensor, down_sample_ratio: int):
    """
    Folds ``down_sample_ratio x down_sample_ratio`` neighbourhoods of a square
    token grid into the channel axis.

    NOTE: a function with the same name is also defined earlier in this cell;
    this later definition is the one left bound after the cell runs.

    Args:
        x (Tensor): Tensor of shape (batch, tokens, dim). ``tokens`` must be a
            perfect square; it is interpreted as a ``side x side`` grid.
        down_sample_ratio (int): Side length of the neighbourhood merged into
            channels. Grid axes that are not a multiple of it are zero-padded
            at the end.
    Returns:
        Tensor: Tensor of shape
        (batch, ceil(side / ratio) ** 2, dim * ratio ** 2).
    Raises:
        ValueError: If ``tokens`` is not a perfect square. Without this check
            the ``-1`` in the reshape could absorb the leftover tokens into the
            channel axis and silently scramble the grid.
    """
    ratio = down_sample_ratio
    num_tokens = x.shape[1]
    side = round(num_tokens ** 0.5)
    if side * side != num_tokens:
        raise ValueError(
            f"flat_square_downsample expects a square number of tokens, got {num_tokens}"
        )
    x = x.reshape(x.shape[0], side, side, -1)
    n, w, h, c = x.size()

    # Pad width and height up to the next multiple of the ratio if necessary.
    extra_w = -w % ratio
    if extra_w:
        # Allocate directly with x's dtype/device (no default-device tensor + copy).
        x = torch.concat(
            [x, torch.zeros((n, extra_w, h, c), dtype=x.dtype, device=x.device)], dim=1
        )
        w += extra_w
    extra_h = -h % ratio
    if extra_h:
        x = torch.concat(
            [x, torch.zeros((n, w, extra_h, c), dtype=x.dtype, device=x.device)], dim=2
        )
        h += extra_h

    # `view` needs a contiguous layout; the reshape above may return a
    # non-contiguous view when the caller passes a non-contiguous tensor.
    x = x.contiguous()

    # Reshape and rearrange
    x = x.view(n, w, h // ratio, c * ratio)
    x = x.permute(0, 2, 1, 3).contiguous()
    x = x.view(n, h // ratio, w // ratio, c * ratio ** 2)
    x = x.permute(0, 2, 1, 3).contiguous()
    return x.reshape(x.shape[0], -1, x.shape[-1])


In [40]:
# Create a random input tensor: 8 sequences of 256 tokens (a 16x16 grid), 1024 channels each
x = torch.randn(8, 256, 1024)
# Create an instance of DownSample2x2BlockFix (the fixed 2x2 implementation, used as the reference)
downsample_block = DownSample2x2BlockFix()
# Run the input through the block
output_block = downsample_block(x)
# Run the same input through the generic flat_square_downsample function with ratio 2
output_function = flat_square_downsample(x, 2)
# Check if the outputs are equal: print both shapes, then compare the values with allclose
print(f"{output_block.shape=}")
print(f"{output_function.shape=}")
print(torch.allclose(output_block, output_function))

output_block.shape=torch.Size([8, 64, 4096])
output_function.shape=torch.Size([8, 64, 4096])
True


In [29]:
def test_flat_square_2x2():
    """Print whether ``DownSampleBlock(2)._flat_square`` matches ``flat_square_2x2``.

    Uses a random 10x10 grid with 3 channels, so both grid axes are even.
    """
    sample = torch.randn(1, 10, 10, 3)
    reference = flat_square_2x2(sample)

    candidate = DownSampleBlock(down_sample_ratio=2)._flat_square(sample)
    print(torch.allclose(reference, candidate))
    
def test_flat_square_3x3():
    """Print whether ``flat_square(x, 3)`` matches ``flat_square_3x3``.

    Uses a random 12x12 grid with 3 channels, so both grid axes are multiples of 3.
    """
    sample = torch.randn(1, 12, 12, 3)
    reference = flat_square_3x3(sample)
    candidate = flat_square(sample, 3)
    print(torch.allclose(reference, candidate))
    
# Run both comparisons; each prints the result of torch.allclose between the
# fixed-size reference and the generic implementation.
test_flat_square_2x2()
test_flat_square_3x3()

True
True


# MLP Projector

In [37]:
class MLPResampler(nn.Module):
    """Downsample a square grid of modality tokens and project it with an MLP.

    Each ``scale_factor x scale_factor`` neighbourhood of tokens is folded into
    the channel axis, so the token count shrinks by ``scale_factor ** 2`` while
    the channel width grows by the same factor. The folded tokens then pass
    through ``LayerNorm -> Linear -> GELU -> Linear`` and a learned ``eoi``
    embedding (presumably "end of image" -- confirm) is appended as one extra
    token at the end of every sequence.

    Args:
        feature_dim: Channel width of the incoming tokens.
        embedding_size: Channel width of the produced tokens.
        scale_factor: Integer side length of the neighbourhood folded into
            channels. This is a *downsampling* factor (2 halves each grid
            axis), not the fractional convention used by ``pixel_shuffle``.
    """

    def __init__(
        self,
        feature_dim: int,
        embedding_size: int,
        scale_factor: int = 2
    ):
        super().__init__()
        self.embedding_size = embedding_size

        self.scale_factor = scale_factor

        # TODO: positional embedding and its resizing, eoi embedding resizing,
        # output norm, and uniform embedding variance (average token RMS).
        self.pos_emb = None

        # Folding a scale x scale neighbourhood multiplies the channel width by scale ** 2.
        inner_dim = feature_dim * int(self.scale_factor ** 2)

        self.mlp = nn.Sequential(
            nn.LayerNorm(inner_dim),
            nn.Linear(inner_dim, self.embedding_size),
            nn.GELU(),
            nn.Linear(self.embedding_size, self.embedding_size)
        )

        # EOI token, initialised from a normal distribution scaled by 1 / sqrt(embedding_size)
        init_scale = 1 / torch.sqrt(
            torch.tensor(embedding_size, dtype=torch.float32, requires_grad=False)
        )
        self.eoi = nn.Parameter(torch.randn(1, 1, embedding_size) * init_scale)

    def _channel_reshuffle(self, x):
        """
        Applies spatial-to-channel reshuffling to the input tensor.

        Neighbourhoods of ``scale x scale`` grid cells are merged into the
        channel axis:
        bsz, w, h, dim --> bsz, w // scale, h // scale, dim * (scale ** 2)

        Args:
            x (torch.Tensor): Input tensor with shape (N, W, H, C). It does not
                need to be contiguous.
        Returns:
            torch.Tensor: Contiguous tensor with shape
            (N, W // scale, H // scale, C * (scale ** 2))
        """
        scale = self.scale_factor
        bsz, w, h, dim = x.size()

        # `view` requires a contiguous layout, which a reshaped view of a
        # non-contiguous caller tensor does not have.
        x = x.contiguous()

        # bsz, w, h, dim --> bsz, w, h // scale, dim * scale
        x = x.view(bsz, w, h // scale, dim * scale)

        # bsz, w, h // scale, dim * scale --> bsz, h // scale, w, dim * scale
        x = x.permute(0, 2, 1, 3).contiguous()

        # bsz, h // scale, w, dim * scale --> bsz, h // scale, w // scale, dim * scale ** 2
        x = x.view(bsz, h // scale, w // scale, dim * scale * scale)

        # --> bsz, w // scale, h // scale, dim * scale ** 2
        x = x.permute(0, 2, 1, 3).contiguous()

        return x

    def forward(self, modality_embs: torch.Tensor):
        """Downsample, project and append the EOI token.

        Args:
            modality_embs: Tensor of shape (bsz, seqlen, dim); ``seqlen`` must
                be a perfect square and is read as a square grid.

        Returns:
            Tensor of shape (bsz, seqlen // scale ** 2 + 1, embedding_size).

        Raises:
            ValueError: If ``seqlen`` is not a perfect square. Without this
                check the ``-1`` in the reshape could absorb leftover tokens
                into the channel axis and silently scramble the grid.
        """
        bsz, seqlen, _ = modality_embs.shape

        print(f"{modality_embs.shape = }")

        h = w = round(seqlen ** 0.5)
        if h * w != seqlen:
            raise ValueError(
                f"MLPResampler expects a square number of tokens, got {seqlen}"
            )

        modality_embs = modality_embs.reshape(modality_embs.shape[0], h, w, -1) # (bsz, h, w, dim)
        print(f"1. {modality_embs.shape = }")

        modality_embs = self._channel_reshuffle(modality_embs) # bsz, w // scale, h // scale, dim * scale ** 2
        print(f"2. {modality_embs.shape = }")

        modality_embs = modality_embs.reshape(modality_embs.shape[0], -1, modality_embs.shape[-1]) # bsz, w * h // scale ** 2, dim * scale ** 2
        print(f"3. {modality_embs.shape = }")

        modality_embs = self.mlp(modality_embs) # bsz, w * h // scale ** 2, embedding_size
        print(f"4. {modality_embs.shape = }")

        # Append the learned EOI token to every sequence in the batch.
        modality_embs = torch.cat([modality_embs, self.eoi.repeat(bsz, 1, 1)], dim=1)

        # TODO: apply the output norm and the average-token-RMS scaling once they exist.

        return modality_embs

In [47]:
# Test input: 2 sequences of 16 tokens (a 4x4 grid) with 128 channels each
batch_size, patch_sequence_length, embedding_dim = 2, 16, 128
x = torch.randn(batch_size, patch_sequence_length, embedding_dim)
print(f"{x.shape = }")

# Build the resampler; scale_factor=2 folds 2x2 token neighbourhoods into channels
# resampler = MLPResampler(feature_dim=embedding_dim, embedding_size=4096, scale_factor=0.5)
resampler = MLPResampler(
    feature_dim=embedding_dim,
    embedding_size=4096,
    scale_factor=2,
)

# Forward pass, then report the output shape
output = resampler(x)
print(f"{output.shape = }")


x.shape = torch.Size([2, 16, 128])
modality_embs.shape = torch.Size([2, 16, 128])
1. modality_embs.shape = torch.Size([2, 4, 4, 128])
2. modality_embs.shape = torch.Size([2, 2, 2, 512])
3. modality_embs.shape = torch.Size([2, 4, 512])
4. modality_embs.shape = torch.Size([2, 4, 4096])
output.shape = torch.Size([2, 5, 4096])


# VE pos_emb

In [78]:
from typing import Optional 

def interpolate_position_embedding(
    pos_embedding: torch.Tensor,
    target_image_size: int,
    patch_size: int
):
    """Bilinearly resize a flattened square grid of position embeddings.

    Args:
        pos_embedding: Tensor of shape (n * n, hidden_dim); the first axis is
            read as an ``n x n`` grid with ``n = int(sqrt(n * n))``.
        target_image_size: Image side length the embeddings are resized for.
        patch_size: Patch side length; the target grid has
            ``target_image_size // patch_size`` patches per side.

    Returns:
        Tensor of shape (target_n_patches ** 2, hidden_dim).

    The intermediate shapes are printed at every step.
    """
    hidden_dim = pos_embedding.shape[-1]
    original_n_patches = int(pos_embedding.shape[0] ** 0.5)
    target_n_patches = target_image_size // patch_size
    print(
        f"Interpolating pos embedding from n patches: {original_n_patches} to {target_n_patches}. hidden dim: {hidden_dim}"
    )

    # Flat (n * n, hidden) table -> square (n, n, hidden) grid
    source_grid = (original_n_patches, original_n_patches, -1)
    pos_embedding = pos_embedding.reshape(*source_grid)
    print(f"1. {pos_embedding.shape=}")

    # Add a leading batch axis of 1 and move channels first: (1, hidden, n, n)
    pos_embedding = pos_embedding[None].permute(0, 3, 1, 2)
    print(f"2. {pos_embedding.shape=}")

    # Resize the two grid axes to the target resolution
    target_grid = (target_n_patches, target_n_patches)
    pos_embedding = torch.nn.functional.interpolate(
        pos_embedding, size=target_grid, mode="bilinear"
    )
    print(f"3. {pos_embedding.shape=}")

    # Drop the batch axis, move channels last and flatten the grid again
    pos_embedding = pos_embedding[0].permute(1, 2, 0)
    pos_embedding = pos_embedding.reshape(target_n_patches * target_n_patches, hidden_dim)
    print(f"4. {pos_embedding.shape=}")

    print(f"5. {pos_embedding.shape=}")

    return pos_embedding

# Stand-in embedding table: 256 positions (a 16x16 grid) with hidden size 1024
pos_embedding = torch.randn(256, 1024)
print(f"original {pos_embedding.shape=}")

# Target geometry: 336 // 14 = 24 patches per side
target_image_size, patch_size = 336, 14

# Resize the table for the target geometry
interpolated_pos_embedding = interpolate_position_embedding(
    pos_embedding, target_image_size, patch_size
)

# Shape we expect back: (patches_per_side ** 2, hidden_dim)
expected_shape = ((target_image_size // patch_size) ** 2, pos_embedding.shape[-1])
print(f"{interpolated_pos_embedding.shape=},  {expected_shape=}")

original pos_embedding.shape=torch.Size([256, 1024])
Interpolating pos embedding from n patches: 16 to 24. hidden dim: 1024
1. pos_embedding.shape=torch.Size([16, 16, 1024])
2. pos_embedding.shape=torch.Size([1, 1024, 16, 16])
3. pos_embedding.shape=torch.Size([1, 1024, 24, 24])
4. pos_embedding.shape=torch.Size([576, 1024])
5. pos_embedding.shape=torch.Size([576, 1024])
interpolated_pos_embedding.shape=torch.Size([576, 1024]),  expected_shape=(576, 1024)


Interpolating pos embedding from n patches: 4 to 8. hidden dim: 128
interpolated_pos_embedding.shape=torch.Size([64, 128]),  expected_shape=(64, 128)


In [80]:
# Inspect the shape of the interpolated table produced by the cell above
interpolated_pos_embedding.shape

torch.Size([576, 1024])

In [81]:
# Add a leading axis of size 1 in front of the (positions, hidden_dim) table.
# Note that this rebinds the notebook-level name `x` used by earlier cells.
x = interpolated_pos_embedding.unsqueeze(0)
x.shape

torch.Size([1, 576, 1024])

In [66]:
# NOTE(review): `s` is not assigned anywhere in the visible notebook -- this cell relies on
# kernel state from a cell that is not shown, so it will fail on Restart & Run All.
s

'{"a": {"b": 1}}'

In [67]:
# NOTE(review): `d` is not assigned anywhere in the visible notebook -- presumably the parsed
# form of `s` from the previous cell; confirm and add the defining cell.
d['a']

{'b': 1}

In [70]:
# Scratch check: shapes of a 2-D tensor and of a 3-D tensor with a leading axis of size 1
torch.randn(16, 10).shape, torch.randn(1, 16, 10).shape

(torch.Size([16, 10]), torch.Size([1, 16, 10]))