# Setup

In [1]:
import json
import os
import torch
import torch
from torch import nn

class Item:
    def __init__(self, key, value, type:str):
        self.key = key
        self.type = type
        if type.startswith("Tensor"):
            self.value = {"size":value.size(),"values":value.flatten().tolist()}
        else:
            self.value = value

    def to_dict(self):
        return {self.key: {self.type: self.value}}


def to_file(name:str,items:list):
    path = "test-files/"+name+".json"
    values = {}
    for item in items:
        values.update(item.to_dict())
    output = {"values": values}
    
    data = json.dumps(output, indent=4)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write(data)

def random_tensor(shape:list,seed:int=0):
    n = 1 
    for dim in shape:
        n*=dim

    a = 3
    c = 23
    m = 2**32
    
    result = []
    x = seed
    for _ in range(n):
        x = (a * x + c) % m
        result.append(x / m)  # Normalize the result to [0, 1]

    return torch.tensor(result).view(shape)

def mock_linear(linear:nn.Linear )->nn.Linear:
    linear.weight.data = random_tensor(linear.weight.size(),1)
    linear.bias.data = random_tensor(linear.bias.size(),2)

def mock_layer_norm(layer_norm: nn.LayerNorm)->nn.LayerNorm:
    layer_norm.weight.data = random_tensor(layer_norm.weight.size(),1)
    layer_norm.bias.data = random_tensor(layer_norm.bias.size(),2)

def mock_conv2d(conv2d:nn.Conv2d)->nn.Conv2d:
    conv2d.weight.data = random_tensor(conv2d.weight.size(),1)
    conv2d.bias.data = random_tensor(conv2d.bias.size(),2)

def mock_embedding(embedding:nn.Embedding)->nn.Embedding:
    embedding.weight.data = random_tensor(embedding.weight.size(),1)

def mock_conv_transpose2d(conv: nn.ConvTranspose2d)->nn.ConvTranspose2d:
    conv.weight.data = random_tensor(conv.weight.size(),1)
    conv.bias.data = random_tensor(conv.bias.size(),2)

def mock_tensor(tensor:torch.Tensor)->torch.Tensor:
    tensor.data = random_tensor(tensor.size(),1)

## Common
#### LayerNorm2d

In [2]:
from segment_anything.modeling.common import LayerNorm2d

layer_norm = LayerNorm2d(256,0.1)
items = [Item("weight", layer_norm.weight, "TensorFloat"), Item("bias", layer_norm.bias, "TensorFloat"), Item("eps", layer_norm.eps, "Float")]
to_file("layer_norm_2d",items)

# Forward
input = random_tensor([2,256,16,16])
output = layer_norm(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("layer_norm_2d_forward",items)
del layer_norm, input, output,items

#### MLPBlock

In [3]:
from segment_anything.modeling.common import MLPBlock

mlp_block = MLPBlock(256,256,nn.GELU)
items=[Item("lin1_size", mlp_block.lin1.weight.size(), "List"), Item("lin2_size", mlp_block.lin2.weight.size(), "List")]
to_file("mlp_block",items)

#Mocking 
def mock_mlp_block(mlp_block:MLPBlock)->MLPBlock:
    mock_linear(mlp_block.lin1)
    mock_linear(mlp_block.lin2)
mock_mlp_block(mlp_block)

# Forward
input = random_tensor([256,256],5)
output = mlp_block(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("mlp_block_forward",items)
del mlp_block, input, output,items

#### Activation

In [4]:
# Gelu
gelu = nn.GELU()
input = random_tensor([256,256])
output = gelu(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("activation_gelu",items)

# ReLU
relu = nn.ReLU()
input = random_tensor([256,256])
output = relu(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("activation_relu",items)
del input,output

# Image encoder

#### PatchEmbeded

In [5]:
from segment_anything.modeling.image_encoder import PatchEmbed

patch_embed = PatchEmbed((16,16),(16,16),(0,0),3,320)
items=[Item("proj_size", patch_embed.proj.weight.size(), "List")]
to_file("patch_embed",items)

# Mocking 
def mock_patch_embed(patch_embed:PatchEmbed)->PatchEmbed:
    mock_conv2d(patch_embed.proj)
mock_patch_embed(patch_embed)

# Forward
input = random_tensor([1,3,512,512],3)
output = patch_embed(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("patch_embed_forward",items)
del patch_embed, input, output,items

#### Attention

In [6]:
from segment_anything.modeling.image_encoder import get_rel_pos,add_decomposed_rel_pos

# Get rel pos
q_size = 32
k_size = 32
input = random_tensor([127,40],1)
output = get_rel_pos( q_size, k_size, input)
items = [Item("input",input,"TensorFloat"),Item("output", output, "TensorFloat")]
to_file("get_rel_pos",items)
del input, output


# Add decomposed rel pos
attn = random_tensor([200,49,49],2)
q = random_tensor([200,49,20],3)
relo_pos_h = random_tensor([20,20],4)
relo_pos_w = random_tensor([20,20],5)
q_size = (7,7)
k_size = (7,7)
output = add_decomposed_rel_pos(attn,q,relo_pos_h,relo_pos_w,q_size,k_size)
items = [Item("attn", attn, "TensorFloat"), Item("q", q, "TensorFloat"), Item("q_size", q_size, "Size"), Item("k_size", k_size, "Size"), Item("output", output, "TensorFloat")]
to_file("add_decomposed_rel_pos",items)
del attn,q,relo_pos_h,relo_pos_w,q_size,k_size,output,items

In [7]:
from segment_anything.modeling.image_encoder import Attention

# Attention
attention = Attention(320, 16 ,True ,True ,True, (14, 14))
items =[Item("num_heads", attention.num_heads, "Int"), Item("scale", attention.scale, "Float"),  Item("use_rel_pos", attention.use_rel_pos, "Bool")]
to_file("attention",items)

#Mocking 
def mock_attention(attention:Attention)->Attention:
    mock_linear(attention.qkv)
    mock_linear(attention.proj)
mock_attention(attention)

# Forward
input = random_tensor([25,14,14,320],1)
output = attention(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("attention_forward",items)
del input
del output
del attention

In [8]:
from segment_anything.modeling.image_encoder import  window_partition,window_unpartition

# Window partition
input = random_tensor([2,256,16,16],1)
output,size = window_partition(input,16)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat"), Item("size", size, "Size")]
to_file("window_partition",items)

# Window unpartition
input = random_tensor([2,256,16,16],2)
output = window_unpartition(input,16,(16,16),(14,14))
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("window_unpartition",items)
del input, output, items

#### Block

In [9]:
from segment_anything.modeling.image_encoder import Block

#Block
block = Block(320,16,4.0,True,nn.LayerNorm,nn.GELU,True,True,14,(64,64))
items=[Item("window_size", block.window_size, "Int")]
to_file("block",items)

#Mocking 
def mock_block(block:Block)->Block:
    mock_layer_norm(block.norm1)
    mock_layer_norm(block.norm2)
    mock_attention(block.attn)
    mock_mlp_block(block.mlp )
mock_block(block)
#Forward
input = random_tensor([1,64,64,320],1)
output = block(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("block_forward",items)
del block, input, output

#### Image encoderViT

In [10]:
from segment_anything.modeling.image_encoder import ImageEncoderViT

image_encoder = ImageEncoderViT(128,4,3,320,32,16,4.0,256,True,nn.LayerNorm,nn.GELU,True,True,True,14,[7,15,23,31])
items =[Item("img_size", image_encoder.img_size,"Int")]
to_file("image_encoder",items)

input = random_tensor([1,3,128,128],1)
output = image_encoder(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("image_encoder_forward",items)

del image_encoder
del input
del output

## Transformer
#### Attention

In [11]:
from segment_anything.modeling.transformer import Attention

attention = Attention(256,8,1)
items = [Item("embedding_dim", attention.embedding_dim, "Int"),
          Item("internal_dim", attention.internal_dim, "Int"),
          Item("num_heads", attention.num_heads, "Int"),
          Item("q_proj_size", attention.q_proj.weight.size(), "List"),
          Item("k_proj_size", attention.k_proj.weight.size(), "List"),
          Item("v_proj_size", attention.v_proj.weight.size(), "List"),
          Item("out_proj_size", attention.out_proj.weight.size(), "List"),]
to_file("transformer_attention",items)

#Mocking
def mock_transformer_attention(attention:Attention)->Attention:
    mock_linear(attention.q_proj)
    mock_linear(attention.k_proj)
    mock_linear(attention.v_proj)
    mock_linear(attention.out_proj)
mock_transformer_attention(attention)

#Forward
q = random_tensor([1,256,256],1)
k = random_tensor([1,256,256],2)
v = random_tensor([1,256,256],3)
output = attention.forward(q,k,v)
items = [Item("q", q, "TensorFloat"), Item("k", k, "TensorFloat"), Item("v", v, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("transformer_attention_forward",items)


#### TwoWayAttention

In [12]:
from segment_anything.modeling.transformer import TwoWayAttentionBlock

block = TwoWayAttentionBlock(256,8,2048,nn.ReLU,2,False)
items = [
    Item("norm1_size", block.norm1.weight.size(), "List"),
    Item("norm2_size", block.norm2.weight.size(), "List"),
    Item("norm3_size", block.norm3.weight.size(), "List"),
    Item("norm4_size", block.norm4.weight.size(), "List"),
    Item("skip_first_layer_pe", block.skip_first_layer_pe, "Bool"),
]
to_file("transformer_two_way_attention_block",items)

#Mocking
def mock_transformer_two_way_attention_block(block:TwoWayAttentionBlock)->TwoWayAttentionBlock:
    mock_layer_norm(block.norm1)
    mock_layer_norm(block.norm2)
    mock_layer_norm(block.norm3)
    mock_layer_norm(block.norm4)
    mock_transformer_attention(block.cross_attn_image_to_token)
    mock_transformer_attention(block.cross_attn_token_to_image)
    mock_transformer_attention(block.self_attn)
    mock_mlp_block(block.mlp)
mock_transformer_two_way_attention_block(block)

#Forward
queries = random_tensor([1,256,256],1)
keys = random_tensor([1,256,256],2)
query_pe = random_tensor([1,256,256],3)
key_pe = random_tensor([1,256,256],4)
out_queries,out_keys = block(queries,keys,query_pe,key_pe)
items = [Item("queries", queries, "TensorFloat"), Item("keys", keys, "TensorFloat"), Item("query_pe", query_pe, "TensorFloat"), Item("key_pe", key_pe, "TensorFloat"), Item("out_queries", out_queries, "TensorFloat"), Item("out_keys", out_keys, "TensorFloat")]
to_file("transformer_two_way_attention_block_forward",items)


#### TwoWayTransformer

In [13]:
from segment_anything.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(2, 64, 4, 256, nn.ReLU, 2)
items =[
    Item("depth", transformer.depth, "Int"),
    Item("embedding_dim", transformer.embedding_dim, "Int"),
    Item("num_heads", transformer.num_heads, "Int"),
    Item("mlp_dim", transformer.mlp_dim, "Int"),
    Item("layers_len", len(transformer.layers), "Int"),
]
to_file("transformer_two_way_transformer",items)

# Mocking
def mock_transformer_two_way_transformer(transformer:TwoWayTransformer)->TwoWayTransformer:
    for i in range(len(transformer.layers)):
        mock_transformer_two_way_attention_block(transformer.layers[i])
    mock_transformer_attention(transformer.final_attn_token_to_image)
    mock_layer_norm(transformer.norm_final_attn)
mock_transformer_two_way_transformer(transformer)

# Forward

image_embedding = random_tensor([1,64,16,16],1)
image_pe = random_tensor([1,64,16,16],2)
point_embedding = random_tensor([16, 256, 64],3)
queries,keys = transformer(image_embedding,image_pe,point_embedding)
items = [Item("image_embedding", image_embedding, "TensorFloat"), Item("image_pe", image_pe, "TensorFloat"), Item("point_embedding", point_embedding, "TensorFloat"), Item("queries", queries, "TensorFloat"), Item("keys", keys, "TensorFloat")]
to_file("transformer_two_way_transformer_forward",items)

## Mask decoder
#### MLP block


In [14]:
from segment_anything.modeling.mask_decoder import MLP

mlp = MLP(256,256,256,4,False)
items = [Item("num_layers", mlp.num_layers, "Int"), Item("sigmoid_output", mlp.sigmoid_output, "Bool"), Item("layers_len",len(mlp.layers),"Int")]
for i in range(len(mlp.layers)):
    items.append(Item("layer"+str(i), mlp.layers[i].weight.size(), "List"))
to_file("mlp",items)

# Mocking
def mock_mlp(mlp:MLP)->MLP:
    for i in range(len(mlp.layers)):
        mock_linear(mlp.layers[i])
mock_mlp(mlp)
# Forward
input = random_tensor([1,256],1)
output = mlp(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("mlp_forward",items)

#### Mask decoder

In [15]:
from segment_anything.modeling.mask_decoder import MaskDecoder

transformer = TwoWayTransformer(2, 64, 2, 512, nn.ReLU, 2)
mask_decoder = MaskDecoder(transformer_dim=64,transformer=transformer,num_multimask_outputs=3, activation=nn.GELU,iou_head_depth=3,iou_head_hidden_dim=64)
items = [
    Item("transformer_dim", mask_decoder.transformer_dim, "Int"),
    Item("num_multimask_outputs", mask_decoder.num_multimask_outputs, "Int"),
    Item("num_mask_tokens", mask_decoder.num_mask_tokens, "Int"),
]
to_file("mask_decoder",items)

# Mocking
def mock_mask_decoder(mask_decoder:MaskDecoder)->MaskDecoder:
    mock_transformer_two_way_transformer(mask_decoder.transformer)
    mock_embedding(mask_decoder.iou_token)
    mock_embedding(mask_decoder.mask_tokens)
    for i in range(len(mask_decoder.output_hypernetworks_mlps)):
        mock_mlp(mask_decoder.output_hypernetworks_mlps[i])
    mock_mlp(mask_decoder.iou_prediction_head)
    conv = nn.ConvTranspose2d(mask_decoder.transformer_dim, mask_decoder.transformer_dim // 4, kernel_size=2, stride=2)
    conv2 = nn.ConvTranspose2d(mask_decoder.transformer_dim // 4, mask_decoder.transformer_dim // 8, kernel_size=2, stride=2)
    mock_conv_transpose2d(conv)
    mock_conv_transpose2d(conv2)
    mask_decoder.output_upscaling = nn.Sequential(
            conv,
            LayerNorm2d(mask_decoder.transformer_dim // 4),
            nn.GELU(),
            conv2,
            nn.GELU(),
        )
mock_mask_decoder(mask_decoder)

# Forward
image_embedding = random_tensor([1,64,16,16],1)
image_pe = random_tensor([1,64,16,16],2)
sparse_prompt_embeddings = random_tensor([16, 2, 64],3)
dense_prompt_embeddings = random_tensor([16, 64, 16, 16],4)
masks, iou_pred = mask_decoder(image_embedding,image_pe,sparse_prompt_embeddings,dense_prompt_embeddings,True)
items = [Item("image_embedding", image_embedding, "TensorFloat"), Item("image_pe", image_pe, "TensorFloat"), Item("sparse_prompt_embeddings", sparse_prompt_embeddings, "TensorFloat"), Item("dense_prompt_embeddings", dense_prompt_embeddings, "TensorFloat"), Item("masks", masks, "TensorFloat"), Item("iou_pred", iou_pred, "TensorFloat")]
to_file("mask_decoder_forward",items)

In [16]:
# Predict masks
transformer = TwoWayTransformer(2, 64, 2, 512, nn.ReLU, 2)
mask_decoder = MaskDecoder(transformer_dim=64,transformer=transformer,num_multimask_outputs=3, activation=nn.GELU,iou_head_depth=3,iou_head_hidden_dim=64)

# Mocking
mock_mask_decoder(mask_decoder)

# Predict masks
image_embedding = random_tensor([1,64,16,16],1)
image_pe = random_tensor([1,64,16,16],2)
sparse_prompt_embeddings = random_tensor([16, 2, 64],3)
dense_prompt_embeddings = random_tensor([16, 64, 16, 16],4)
masks, iou_pred = mask_decoder.predict_masks(image_embedding,image_pe,sparse_prompt_embeddings,dense_prompt_embeddings)
items = [Item("image_embedding", image_embedding, "TensorFloat"), Item("image_pe", image_pe, "TensorFloat"), Item("sparse_prompt_embeddings", sparse_prompt_embeddings, "TensorFloat"), Item("dense_prompt_embeddings", dense_prompt_embeddings, "TensorFloat"), Item("masks", masks, "TensorFloat"), Item("iou_pred", iou_pred, "TensorFloat")]
to_file("mask_decoder_predict",items)


## Prompt Encoder
#### Positional Embedding

In [17]:
from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom

position_embedding = PositionEmbeddingRandom(128, None)
items = [Item("gaussian_matrix", position_embedding.positional_encoding_gaussian_matrix.size(), "List")]
to_file("position_embedding_random",items)

def mock_position_embedding_random(position_embedding:PositionEmbeddingRandom)->PositionEmbeddingRandom:
    mock_tensor(position_embedding.positional_encoding_gaussian_matrix)
mock_position_embedding_random(position_embedding)


In [18]:
# _pe_encoding 
position_embedding = PositionEmbeddingRandom(128, None)
mock_position_embedding_random(position_embedding)

input = random_tensor([64,2,2],1)
output = position_embedding._pe_encoding(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("position_embedding_random_pe_encoding",items)

In [19]:
# Forward
position_embedding = PositionEmbeddingRandom(128, None)
mock_position_embedding_random(position_embedding)

input= (64,64)
output = position_embedding.forward(input)
items = [Item("input", input, "Size"), Item("output", output, "TensorFloat")]
to_file("position_embedding_random_forward",items)

In [20]:
# Forward with coords
position_embedding = PositionEmbeddingRandom(128, None)
mock_position_embedding_random(position_embedding)

input = random_tensor([64,2,2],1)
image_size  = (1024,1024)
output = position_embedding.forward_with_coords(input,image_size)
items = [Item("input", input, "TensorFloat"), Item("image_size", image_size, "Size"), Item("output", output, "TensorFloat")]
to_file("position_embedding_random_forward_with_coords",items)

#### Prompt Encoder

In [21]:
from segment_anything.modeling.prompt_encoder import PromptEncoder

mask_in_chans =16
embed_dim =256
prompt_encoder = PromptEncoder(256,(64,64),(1024,1024),mask_in_chans,nn.GELU)

items = [
    Item("embed_dim", prompt_encoder.embed_dim, "Int"), 
    Item("input_image_size", prompt_encoder.input_image_size, "Size"), 
    Item("image_embedding_size", prompt_encoder.image_embedding_size, "Size"),
    Item("num_point_embeddings", prompt_encoder.num_point_embeddings, "Int"),
    Item("mask_input_size", prompt_encoder.mask_input_size, "Size"),
    ]
to_file("prompt_encoder",items)

def mock_prompt_encoder(encoder:PromptEncoder):
    mock_position_embedding_random(encoder.pe_layer)
    for i in range(len(encoder.point_embeddings)):
        mock_embedding(encoder.point_embeddings[i])
    mock_embedding(encoder.no_mask_embed)
    mock_embedding(encoder.not_a_point_embed)
    conv1 = nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2)
    conv2 = nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2)
    conv3 = nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1)
    mock_conv2d(conv1)
    mock_conv2d(conv2)
    mock_conv2d(conv3)
    encoder.mask_downscaling = nn.Sequential(
            conv1,
            LayerNorm2d(mask_in_chans // 4),
            nn.GELU(),
            conv2,
            LayerNorm2d(mask_in_chans),
            nn.GELU(),
            conv3,
        )

In [22]:
# Embed points
prompt_encoder = PromptEncoder(256,(64,64),(1024,1024),mask_in_chans,nn.GELU)
mock_prompt_encoder(prompt_encoder)

points = random_tensor([64,1,2],1)
labels = random_tensor([64,1],2)
output = prompt_encoder._embed_points(points,labels,True)
items = [Item("points", points, "TensorFloat"), Item("labels", labels, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("prompt_encoder_embed_points",items)

In [23]:
# Embed boxes
prompt_encoder = PromptEncoder(256,(64,64),(1024,1024),mask_in_chans,nn.GELU)
mock_prompt_encoder(prompt_encoder)

boxes = random_tensor([64,1,2],1)
output = prompt_encoder._embed_boxes(boxes)
items = [Item("boxes", boxes, "TensorFloat"), Item("output", output, "TensorFloat")]
to_file("prompt_encoder_embed_boxes",items)

torch.Size([32, 2, 256]) torch.Size([1, 256]) torch.Size([1, 256])
torch.Size([32, 2, 256])


In [24]:
# # Embed masks
# prompt_encoder = PromptEncoder(256,(64,64),(1024,1024),mask_in_chans,nn.GELU)
# mock_prompt_encoder(prompt_encoder)

# masks = random_tensor([2,2,2,1],1)
# output = prompt_encoder._embed_masks(masks)
# items = [Item("masks", masks, "TensorFloat"), Item("output", output, "TensorFloat")]
# to_file("prompt_encoder_embed_masks",items)

In [25]:
# forward 
prompt_encoder = PromptEncoder(256,(64,64),(1024,1024),mask_in_chans,nn.GELU)
mock_prompt_encoder(prompt_encoder)

points = random_tensor([16,1,2],1),random_tensor([16,1],2)
boxes = None
masks = None
sparse,dense = prompt_encoder.forward(points,boxes,masks)
items = [Item("points", points[0], "TensorFloat"),Item("labels", points[1], "TensorFloat"), Item("sparse", sparse, "TensorFloat"), Item("dense", dense, "TensorFloat")]
to_file("prompt_encoder_forward",items)

In [37]:
# Test
labels = torch.rand([16,2])
point_embeddings = nn.Embedding(1, 256), nn.Embedding(1, 256)
not_a_point_embed = nn.Embedding(1, 256)

val1 = random_tensor([16,2,256],1)
val1[labels == -1] = 0.0
val1[labels == -1] += not_a_point_embed.weight
val1[labels == 0] += point_embeddings[0].weight
val1[labels == 1] += point_embeddings[1].weight
print(val1.shape) # torch.Size([16, 2, 256])

torch.Size([16, 2, 256])


In [38]:
val2 = random_tensor([16,2,256],1)
# Create masks for each condition
mask_minus_one = labels.eq(-1)
mask_zero = labels.eq(0)
mask_one = labels.eq(1)

# Update point_embedding based on the masks
val2 = torch.where(mask_minus_one.unsqueeze(-1), torch.zeros_like(val2), val2)
val2 = torch.where(mask_minus_one.unsqueeze(-1), val2 + not_a_point_embed.weight, val2)
val2 = torch.where(mask_zero.unsqueeze(-1), val2 + point_embeddings[0].weight, val2)
val2 = torch.where(mask_one.unsqueeze(-1), val2 + point_embeddings[1].weight, val2)
print(val2.shape) # torch.Size([16, 2, 256])
torch.equal(val1,val2)

torch.Size([16, 2, 256])


True

In [62]:
corner_embedding = random_tensor([32, 2, 256],1)
point_embeddings = nn.Embedding(1, 256), nn.Embedding(1, 256), nn.Embedding(1, 256), nn.Embedding(1, 256)

corner_embedding[:, 0, :] += point_embeddings[2].weight
corner_embedding[:, 1, :] += point_embeddings[3].weight
print(corner_embedding.shape) # torch.Size([32, 2, 256])

torch.Size([32, 2, 256])


In [66]:
corner_embedding2 = random_tensor([32, 2, 256],1)

corner_embedding2[:, 0, :] = corner_embedding2[:, 0, :]+ point_embeddings[2].weight
corner_embedding2[:, 1, :] =corner_embedding2[:, 1, :]+ point_embeddings[3].weight
print(corner_embedding.shape) # torch.Size([32, 2, 256])
torch.equal(corner_embedding,corner_embedding2)

torch.Size([32, 2, 256])


True

In [64]:
corner_embedding3 = random_tensor([32, 2, 256],1)

# Select and update specific slices along dimension 1
corner_embedding_0 = torch.narrow(corner_embedding3, 1, 0, 1)
corner_embedding_1 = torch.narrow(corner_embedding3, 1, 1, 1)

# Use regular addition and assign the result to new tensors
updated_corner_embedding_0 = corner_embedding_0 + point_embeddings[2].weight.squeeze(0).expand_as(corner_embedding_0)
updated_corner_embedding_1 = corner_embedding_1 + point_embeddings[3].weight.squeeze(0).expand_as(corner_embedding_1)

# Combine the updated slices back into the original tensor
corner_embedding3 = torch.cat((updated_corner_embedding_0, updated_corner_embedding_1), dim=1)

print(corner_embedding3.shape) # torch.Size([32, 2, 256]

torch.eq(corner_embedding,corner_embedding3)

torch.Size([32, 2, 256])


tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        ...,

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])