# Setup

In [63]:
import torch.nn as nn
import torch
import numpy as np
from helpers import Item, input_to_file,output_to_file,random_ndarray,random_tensor,build_sam_test


In [64]:
import uuid
import os 
import json
from functools import reduce

def set_nested_key(dct, keys, value):
    reduce(lambda d, k: d.setdefault(k, {}), keys[:-1], dct)[keys[-1]] = value

ignored_items =["act","num_heads","scale","use_rel_pos"]
linear_names = ["lin1", "lin2","qkv","proj","q_proj","k_proj","v_proj","out_proj"]
def input_to_file(file_name: str, model: nn.Module):
    json_data = {
        "metadata": {
            "float": "f32",
            "int": "i32",
            "format": "burn_core::record::file::FilePrettyJsonRecorder",
            "version": "0.6.0",
            "settings": "DebugRecordSettings"
        },
        "item": {
            key: None for key in ignored_items
        }
    }
    for name, param in model.named_parameters():
        keys = name.split('.')
        param_id = str(uuid.uuid4())
        param_shape = list(param.size())
        print(name)
        if name.replace(".weight","") in linear_names:
            print("transposing"+name)
            param = param.transpose(0,1)
        param_value = param.flatten().detach().cpu().numpy().tolist()
        param_data = {
            "id": param_id,
            "param": {
                "value": param_value,
                "shape": param_shape
            }
            
        }
        set_nested_key(json_data["item"], keys, param_data)
    
    path = "~/Documents/test-inputs/" + file_name + '.json'
    path = os.path.expanduser(path)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as json_file:
        json.dump(json_data, json_file, indent=2)


## Common
#### LayerNorm2d

In [65]:
from segment_anything.modeling.common import LayerNorm2d

layer_norm = LayerNorm2d(256,0.1)


# Forward
input = random_tensor([2,256,16,16],1)
output = layer_norm(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("layer_norm_2d",items)
del layer_norm, input, output,items

#### MLPBlock

In [66]:
from segment_anything.modeling.common import MLPBlock
mlp_block = MLPBlock(256,256,nn.GELU)
input_to_file("mlp_block",mlp_block)

# Forward
input = random_tensor([256,256],5)
output = mlp_block(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("mlp_block",items)
del mlp_block, input, output,items

lin1.weight
transposinglin1.weight
lin1.bias
lin2.weight
transposinglin2.weight
lin2.bias
Linear forward torch.Size([256, 256]) torch.Size([256, 256]) torch.Size([256])
Linear forward torch.Size([256, 256]) torch.Size([256, 256]) torch.Size([256])


# Image encoder

#### PatchEmbeded

In [67]:
from segment_anything.modeling.image_encoder import PatchEmbed

patch_embed = PatchEmbed((16,16),(16,16),(0,0),3,320)
input_to_file("patch_embed",patch_embed)

# Forward
input = random_tensor([1,3,512,512],3)
output = patch_embed(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("patch_embed",items)
del patch_embed, input, output,items

proj.weight
transposingproj.weight
proj.bias


#### Attention

In [68]:
from segment_anything.modeling.image_encoder import get_rel_pos,add_decomposed_rel_pos

# Get rel pos
q_size = 32
k_size = 32
input = random_tensor([127,40],1)
output = get_rel_pos( q_size, k_size, input)
items = [Item("input",input,"TensorFloat"),Item("output", output, "TensorFloat")]
output_to_file("get_rel_pos",items)
del input, output

In [69]:
# Add decomposed rel pos
attn = random_tensor([200,49,49],2)
q = random_tensor([200,49,20],3)
relo_pos_h = random_tensor([20,20],4)
relo_pos_w = random_tensor([20,20],5)
q_size = (7,7)
k_size = (7,7)
output = add_decomposed_rel_pos(attn,q,relo_pos_h,relo_pos_w,q_size,k_size)
items = [Item("attn", attn, "TensorFloat"), Item("q", q, "TensorFloat"), Item("q_size", q_size, "Size"), Item("k_size", k_size, "Size"), Item("output", output, "TensorFloat")]
output_to_file("add_decomposed_rel_pos",items)
del attn,q,relo_pos_h,relo_pos_w,q_size,k_size,output,items

In [70]:
from segment_anything.modeling.image_encoder import Attention

# Attention
attention = Attention(320, 16 ,True ,True ,True, (14, 14))
input_to_file("attention",attention)

# Forward
input = random_tensor([25,14,14,320],1)
output = attention(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("attention",items)
del input
del output
del attention

rel_pos_h
rel_pos_w
qkv.weight
transposingqkv.weight
qkv.bias
proj.weight
transposingproj.weight
proj.bias
Linear forward torch.Size([25, 14, 14, 320]) torch.Size([960, 320]) torch.Size([960])
Linear forward torch.Size([25, 14, 14, 320]) torch.Size([320, 320]) torch.Size([320])


In [71]:
from segment_anything.modeling.image_encoder import  window_partition,window_unpartition

# Window partition
input = random_tensor([2,256,16,16],1)
output,size = window_partition(input,16)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat"), Item("size", size, "Size")]
output_to_file("window_partition",items)

# Window unpartition
input = random_tensor([2,256,16,16],2)
output = window_unpartition(input,16,(16,16),(14,14))
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("window_unpartition",items)
del input, output, items

#### Block

In [72]:
from segment_anything.modeling.image_encoder import Block

#Block
block = Block(320,16,4.0,True,nn.LayerNorm,nn.GELU,True,True,14,(64,64))
input_to_file("block",block)
#Forward
input = random_tensor([1,64,64,320],1)
output = block(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("block",items)
del block, input, output

norm1.weight
norm1.bias
attn.rel_pos_h
attn.rel_pos_w
attn.qkv.weight
attn.qkv.bias
attn.proj.weight
attn.proj.bias
norm2.weight
norm2.bias
mlp.lin1.weight
mlp.lin1.bias
mlp.lin2.weight
mlp.lin2.bias
Linear forward torch.Size([25, 14, 14, 320]) torch.Size([960, 320]) torch.Size([960])
Linear forward torch.Size([25, 14, 14, 320]) torch.Size([320, 320]) torch.Size([320])
Linear forward torch.Size([1, 64, 64, 320]) torch.Size([1280, 320]) torch.Size([1280])
Linear forward torch.Size([1, 64, 64, 1280]) torch.Size([320, 1280]) torch.Size([320])


#### Image encoderViT

In [73]:
from segment_anything.modeling.image_encoder import ImageEncoderViT

image_encoder = ImageEncoderViT(32,4,3,80,32,16,4.0,256,True,nn.LayerNorm,nn.GELU,True,True,True,14,[7,15,23,31])
input_to_file("image_encoder",image_encoder)

input = random_tensor([1,3,32,32],1)
output = image_encoder(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("image_encoder",items)

del image_encoder
del input
del output

pos_embed
patch_embed.proj.weight
patch_embed.proj.bias
blocks.0.norm1.weight
blocks.0.norm1.bias
blocks.0.attn.rel_pos_h
blocks.0.attn.rel_pos_w
blocks.0.attn.qkv.weight
blocks.0.attn.qkv.bias
blocks.0.attn.proj.weight
blocks.0.attn.proj.bias
blocks.0.norm2.weight
blocks.0.norm2.bias
blocks.0.mlp.lin1.weight
blocks.0.mlp.lin1.bias
blocks.0.mlp.lin2.weight
blocks.0.mlp.lin2.bias
blocks.1.norm1.weight
blocks.1.norm1.bias
blocks.1.attn.rel_pos_h
blocks.1.attn.rel_pos_w
blocks.1.attn.qkv.weight
blocks.1.attn.qkv.bias
blocks.1.attn.proj.weight
blocks.1.attn.proj.bias
blocks.1.norm2.weight
blocks.1.norm2.bias
blocks.1.mlp.lin1.weight
blocks.1.mlp.lin1.bias
blocks.1.mlp.lin2.weight
blocks.1.mlp.lin2.bias
blocks.2.norm1.weight
blocks.2.norm1.bias
blocks.2.attn.rel_pos_h
blocks.2.attn.rel_pos_w
blocks.2.attn.qkv.weight
blocks.2.attn.qkv.bias
blocks.2.attn.proj.weight
blocks.2.attn.proj.bias
blocks.2.norm2.weight
blocks.2.norm2.bias
blocks.2.mlp.lin1.weight
blocks.2.mlp.lin1.bias
blocks.2.mlp.l

## Transformer
#### Attention

In [74]:
from segment_anything.modeling.transformer import Attention

attention = Attention(32,8,1)
input_to_file("transformer_attention",attention)

#Forward
q = random_tensor([1,32,32],1)
k = random_tensor([1,32,32],2)
v = random_tensor([1,32,32],3)
output = attention.forward(q,k,v)
items = [Item("q", q, "TensorFloat"), Item("k", k, "TensorFloat"), Item("v", v, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("transformer_attention",items)


q_proj.weight
transposingq_proj.weight
q_proj.bias
k_proj.weight
transposingk_proj.weight
k_proj.bias
v_proj.weight
transposingv_proj.weight
v_proj.bias
out_proj.weight
transposingout_proj.weight
out_proj.bias
Linear forward torch.Size([1, 32, 32]) torch.Size([32, 32]) torch.Size([32])
Linear forward torch.Size([1, 32, 32]) torch.Size([32, 32]) torch.Size([32])
Linear forward torch.Size([1, 32, 32]) torch.Size([32, 32]) torch.Size([32])
Linear forward torch.Size([1, 32, 32]) torch.Size([32, 32]) torch.Size([32])


#### TwoWayAttention

In [75]:
from segment_anything.modeling.transformer import TwoWayAttentionBlock

block = TwoWayAttentionBlock(256,8,2048,nn.ReLU,2,False)
input_to_file("transformer_two_way_attention_block",block)

#Forward
queries = random_tensor([1,256,256],1)
keys = random_tensor([1,256,256],2)
query_pe = random_tensor([1,256,256],3)
key_pe = random_tensor([1,256,256],4)
out_queries,out_keys = block(queries,keys,query_pe,key_pe)
items = [Item("queries", queries, "TensorFloat"), Item("keys", keys, "TensorFloat"), Item("query_pe", query_pe, "TensorFloat"), Item("key_pe", key_pe, "TensorFloat"), Item("out_queries", out_queries, "TensorFloat"), Item("out_keys", out_keys, "TensorFloat")]
output_to_file("transformer_two_way_attention_block",items)


self_attn.q_proj.weight
self_attn.q_proj.bias
self_attn.k_proj.weight
self_attn.k_proj.bias
self_attn.v_proj.weight
self_attn.v_proj.bias
self_attn.out_proj.weight
self_attn.out_proj.bias
norm1.weight
norm1.bias
cross_attn_token_to_image.q_proj.weight
cross_attn_token_to_image.q_proj.bias
cross_attn_token_to_image.k_proj.weight
cross_attn_token_to_image.k_proj.bias
cross_attn_token_to_image.v_proj.weight
cross_attn_token_to_image.v_proj.bias
cross_attn_token_to_image.out_proj.weight
cross_attn_token_to_image.out_proj.bias
norm2.weight
norm2.bias
mlp.lin1.weight
mlp.lin1.bias
mlp.lin2.weight
mlp.lin2.bias
norm3.weight
norm3.bias
norm4.weight
norm4.bias
cross_attn_image_to_token.q_proj.weight
cross_attn_image_to_token.q_proj.bias
cross_attn_image_to_token.k_proj.weight
cross_attn_image_to_token.k_proj.bias
cross_attn_image_to_token.v_proj.weight
cross_attn_image_to_token.v_proj.bias
cross_attn_image_to_token.out_proj.weight
cross_attn_image_to_token.out_proj.bias


#### TwoWayTransformer

In [None]:
from segment_anything.modeling.transformer import TwoWayTransformer

transformer = TwoWayTransformer(2, 64, 4, 256, nn.ReLU, 2)
input_to_file("transformer_two_way_transformer",transformer)

# Forward
image_embedding = random_tensor([1,64,16,16],1)
image_pe = random_tensor([1,64,16,16],2)
point_embedding = random_tensor([16, 256, 64],3)
queries,keys = transformer(image_embedding,image_pe,point_embedding)
items = [Item("image_embedding", image_embedding, "TensorFloat"), Item("image_pe", image_pe, "TensorFloat"), Item("point_embedding", point_embedding, "TensorFloat"), Item("queries", queries, "TensorFloat"), Item("keys", keys, "TensorFloat")]
output_to_file("transformer_two_way_transformer",items)

Linear forward torch.Size([16, 256, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([1, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([1, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 32]) torch.Size([64, 32]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([16, 256, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([1, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 64]) torch.Size

## Mask decoder
#### MLP block


In [None]:
from segment_anything.modeling.mask_decoder import MLP

mlp = MLP(256,256,256,4,False)
input_to_file("mlp",mlp)

# Forward
input = random_tensor([1,256],1)
output = mlp(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("mlp",items)

Linear forward torch.Size([1, 256]) torch.Size([256, 256]) torch.Size([256])
Linear forward torch.Size([1, 256]) torch.Size([256, 256]) torch.Size([256])
Linear forward torch.Size([1, 256]) torch.Size([256, 256]) torch.Size([256])
Linear forward torch.Size([1, 256]) torch.Size([256, 256]) torch.Size([256])


#### Mask decoder

In [None]:
from segment_anything.modeling.mask_decoder import MaskDecoder

transformer = TwoWayTransformer(2, 64, 2, 512, nn.ReLU, 2)
mask_decoder = MaskDecoder(transformer_dim=64,transformer=transformer,num_multimask_outputs=3, activation=nn.GELU,iou_head_depth=3,iou_head_hidden_dim=64)
input_to_file("mask_decoder",mask_decoder)

# Forward
image_embedding = random_tensor([1,64,16,16],1)
image_pe = random_tensor([1,64,16,16],2)
sparse_prompt_embeddings = random_tensor([16, 2, 64],3)
dense_prompt_embeddings = random_tensor([16, 64, 16, 16],4)
masks, iou_pred = mask_decoder(image_embedding,image_pe,sparse_prompt_embeddings,dense_prompt_embeddings,True)
items = [Item("image_embedding", image_embedding, "TensorFloat"), Item("image_pe", image_pe, "TensorFloat"), Item("sparse_prompt_embeddings", sparse_prompt_embeddings, "TensorFloat"), Item("dense_prompt_embeddings", dense_prompt_embeddings, "TensorFloat"), Item("masks", masks, "TensorFloat"), Item("iou_pred", iou_pred, "TensorFloat")]
output_to_file("mask_decoder",items)

Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 7, 32]) torch.Size([64, 32]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([512, 64]) torch.Size([512])
Linear forward torch.Size([16, 7, 512]) torch.Size([64, 512]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 7, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 7, 64]) torch.Size([32, 64]) torch.

In [None]:
# Predict masks
transformer = TwoWayTransformer(2, 64, 2, 512, nn.ReLU, 2)
mask_decoder = MaskDecoder(transformer_dim=64,transformer=transformer,num_multimask_outputs=3, activation=nn.GELU,iou_head_depth=3,iou_head_hidden_dim=64)
input_to_file("mask_decoder_predict",mask_decoder)

# Predict masks
image_embedding = random_tensor([1,64,16,16],1)
image_pe = random_tensor([1,64,16,16],2)
sparse_prompt_embeddings = random_tensor([16, 2, 64],3)
dense_prompt_embeddings = random_tensor([16, 64, 16, 16],4)
masks, iou_pred = mask_decoder.predict_masks(image_embedding,image_pe,sparse_prompt_embeddings,dense_prompt_embeddings)
print("asas",masks.shape,iou_pred.shape)
items = [Item("image_embedding", image_embedding, "TensorFloat"), Item("image_pe", image_pe, "TensorFloat"), Item("sparse_prompt_embeddings", sparse_prompt_embeddings, "TensorFloat"), Item("dense_prompt_embeddings", dense_prompt_embeddings, "TensorFloat"), Item("masks", masks, "TensorFloat"), Item("iou_pred", iou_pred, "TensorFloat")]
output_to_file("mask_decoder_predict",items)


Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 7, 32]) torch.Size([64, 32]) torch.Size([64])
Linear forward torch.Size([16, 7, 64]) torch.Size([512, 64]) torch.Size([512])
Linear forward torch.Size([16, 7, 512]) torch.Size([64, 512]) torch.Size([64])
Linear forward torch.Size([16, 256, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 7, 64]) torch.Size([32, 64]) torch.Size([32])
Linear forward torch.Size([16, 7, 64]) torch.Size([32, 64]) torch.

## Prompt Encoder
#### Positional Embedding

In [None]:
from segment_anything.modeling.prompt_encoder import PositionEmbeddingRandom

# _pe_encoding 
position_embedding = PositionEmbeddingRandom(128, None)
input_to_file("position_embedding_random_pe_encoding",position_embedding)

input = random_tensor([64,2,2],1)
output = position_embedding._pe_encoding(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("position_embedding_random_pe_encoding",items)

In [None]:
# Forward
position_embedding = PositionEmbeddingRandom(128, None)
input_to_file("position_embedding_random_forward",position_embedding)

input= (64,64)
output = position_embedding.forward(input)
items = [Item("input", input, "Size"), Item("output", output, "TensorFloat")]
output_to_file("position_embedding_random_forward",items)

In [None]:
# Forward with coords
position_embedding = PositionEmbeddingRandom(128, None)
input_to_file("position_embedding_random_forward_with_coords",position_embedding)

input = random_tensor([64,2,2],1)
image_size  = (1024,1024)
output = position_embedding.forward_with_coords(input,image_size)
items = [Item("input", input, "TensorFloat"), Item("image_size", image_size, "Size"), Item("output", output, "TensorFloat")]
output_to_file("position_embedding_random_forward_with_coords",items)

#### Prompt Encoder

In [None]:
from segment_anything.modeling.prompt_encoder import PromptEncoder

mask_in_chans =8
embed_dim =128
def init_prompt_encoder():
    prompt_encoder = PromptEncoder(embed_dim,(32,32),(512,512),mask_in_chans,nn.GELU)
    input_to_file("prompt_encoder",prompt_encoder)
    return prompt_encoder

In [None]:
# Embed points
prompt_encoder = init_prompt_encoder()

points = random_tensor([32,1,2],1)
labels = random_tensor([32,1],2)
output = prompt_encoder._embed_points(points,labels,True)
items = [Item("points", points, "TensorFloat"), Item("labels", labels, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("prompt_encoder_embed_points",items)

In [None]:
# Embed boxes
prompt_encoder = init_prompt_encoder()
input_to_file("prompt_encoder_embed_boxes",prompt_encoder)

boxes = random_tensor([32,1,2],1)
output = prompt_encoder._embed_boxes(boxes)
items = [Item("boxes", boxes, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("prompt_encoder_embed_boxes",items)

In [None]:
# # Embed masks
prompt_encoder = init_prompt_encoder()
input_to_file("prompt_encoder_embed_masks",prompt_encoder)

masks = random_tensor([8,1,4,4],1)
output = prompt_encoder._embed_masks(masks)
items = [Item("masks", masks, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("prompt_encoder_embed_masks",items)

In [None]:
# forward 
prompt_encoder = init_prompt_encoder()
input_to_file("prompt_encoder_forward",prompt_encoder)

points = random_tensor([8,1,2],1),random_tensor([8,1],2)
boxes = None
masks = None
sparse,dense = prompt_encoder.forward(points,boxes,masks)
items = [Item("points", points[0], "TensorFloat"),Item("labels", points[1], "TensorFloat"), Item("sparse", sparse, "TensorFloat"), Item("dense", dense, "TensorFloat")]
output_to_file("prompt_encoder_forward",items)

## Utils
#### ResizeLongestSide

In [None]:
from segment_anything.utils.transforms import ResizeLongestSide

# Get Preprocess shape
resize = ResizeLongestSide(64)
output = resize.get_preprocess_shape(32,32,64)
items = [Item("output", output, "Size")]
output_to_file("resize_get_preprocess_shape",items)

In [None]:
# Apply image
resize = ResizeLongestSide(64)
input = random_tensor([120,180,3],1).mul(255).type(torch.uint8).numpy()
output = resize.apply_image(input)
items = [Item("input", input, "TensorUint8"), Item("output", output, "TensorUint8")]
output_to_file("resize_apply_image",items)

In [None]:
# Apply coords
resize = ResizeLongestSide(64)

input = random_tensor([1, 2, 2],1).numpy()
original_size = (1200,1800)
output = resize.apply_coords(input,original_size)
items = [Item("original_size", original_size, "Size"),Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("resize_apply_coords",items)

In [None]:
# Apply boxes
resize = ResizeLongestSide(64)

boxes=random_tensor([1, 4],1).numpy()
original_size = (1200,1800)
output = resize.apply_boxes(boxes,original_size)
items = [Item("original_size", original_size, "Size"),Item("boxes", boxes, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("resize_apply_boxes",items)

In [None]:
# Apply image torch
resize  = ResizeLongestSide(64)
input = random_tensor([1, 3, 32, 32],1)
output = resize.apply_image_torch(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("resize_apply_image_torch",items)

In [None]:
# Apply coords torch
resize  = ResizeLongestSide(64)
coords = random_tensor([32,32],1)
original_size = (32,32)
output = resize.apply_coords_torch(coords,original_size)
items = [Item("coords", coords, "TensorFloat"), Item("original_size", original_size, "Size"), Item("output", output, "TensorFloat")]
output_to_file("resize_apply_coords_torch",items)

In [None]:
# Apply boxes torch
resize  = ResizeLongestSide(64)
boxes = random_tensor([32,32],1)
original_size = (32,32)
output = resize.apply_boxes_torch(boxes,original_size)
items = [Item("boxes", boxes, "TensorFloat"), Item("original_size", original_size, "Size"), Item("output", output, "TensorFloat")]
output_to_file("resize_apply_boxes_torch",items)

## Build Sam

In [None]:
from segment_anything.build_sam import build_sam_vit_h,Sam,build_sam_vit_b,build_sam_vit_l

def get_items(sam:Sam):
    items = [
        Item("mask_threshold", sam.mask_threshold, "Float"),
        Item("image_format",sam.image_format, "String"),
        Item("pixel_mean", sam.pixel_mean, "TensorFloat"),
        Item("pixel_std", sam.pixel_std, "TensorFloat"),
        Item("mask_decoder.num_mask_tokens", sam.mask_decoder.num_mask_tokens, "Int"),
        Item("prompt_encoder.embed_dim", sam.prompt_encoder.embed_dim, "Int"),
        Item("prompt_encoder.input_image_size", sam.prompt_encoder.input_image_size, "Size"),
    ]
    return items

sam_vit_h = build_sam_vit_h()
output_to_file("sam_vit_h",get_items(sam_vit_h))

sam_vit_b = build_sam_vit_b()
output_to_file("sam_vit_b",get_items(sam_vit_b))

sam_vit_l = build_sam_vit_l()
output_to_file("sam_vit_l",get_items(sam_vit_l))

del sam_vit_h, sam_vit_b, sam_vit_l

## Sam

In [None]:
# Init
sam = build_sam_test()
input_to_file("sam_test",sam)

In [None]:
# Forward
sam = build_sam_test()

batched_input = [
     {
         'image': random_tensor([3,8,8],1),
         'boxes': random_tensor([4,4],1),
         'original_size': (100,200)
     },
     {
         'image': random_tensor([3,8,8],2),
         'boxes': random_tensor([4,4],2),
         'original_size': (50,80)
     }
]
output = sam.forward(batched_input,False)
items=[]
i=0
for x in output:
    masks = x['masks']
    items.append(Item("masks"+str(i), masks, "TensorBool"))

    iou_predictions = x['iou_predictions']
    items.append(Item("iou_predictions"+str(i), iou_predictions, "TensorFloat"))

    if 'low_res_logits' in x:
        low_res_masks = x['low_res_logits']
        items.append(Item("low_res_logits"+str(i), low_res_masks, "TensorFloat"))
    i+=1

    
output_to_file("sam_forward",items)
del sam

Linear forward torch.Size([50, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([50, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([2, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([2, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([50, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([50, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([2, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([2, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([2, 64, 64, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([2, 64, 64, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([2, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([2, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear

In [None]:
# Postprocess masks
sam = build_sam_test()

masks = random_tensor([4,1,256,256],1)
input_size = (684,1024)
original_size = (534,800)
output = sam.postprocess_masks(masks,input_size,original_size)
items = [Item("masks", masks, "TensorFloat"), Item("input_size", input_size, "Size"), Item("original_size", original_size, "Size"), Item("output", output, "TensorFloat")]
output_to_file("sam_postprocess_masks",items)
del sam

In [None]:
# Preprocess
sam = build_sam_test()

input = random_tensor([1,3,171,128],1)
output = sam.preprocess(input)
items = [Item("input", input, "TensorFloat"), Item("output", output, "TensorFloat")]
output_to_file("sam_preprocess",items)
del sam

## Sam Predictor

In [None]:
# Init
from segment_anything.predictor import SamPredictor

def get_predictor(with_set_image:bool=False):
    sam = build_sam_test()
    predictor = SamPredictor(sam)
    if with_set_image:
        image = random_tensor([120,180,3],1).mul(255).type(torch.uint8).numpy()
        predictor.set_image(image,"RGB")
    return predictor


In [None]:
# Set image 
predictor = get_predictor(True)

items=[
    Item("original_size",predictor.original_size,"Size"),
    Item("input_size",predictor.input_size,"Size"),
    Item("features",predictor.features,"TensorFloat"),
    Item("is_image_set",predictor.is_image_set,"Bool"),
]
output_to_file("predictor_set_image",items)
del predictor,

Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear

In [None]:
# Set torch image
predictor = get_predictor()

image = random_tensor([1, 3, 683, 1024],1)
original_size = (120, 180)
predictor.set_torch_image(image,original_size)
items=[
    Item("original_size",predictor.original_size,"Size"),
    Item("input_size",predictor.input_size,"Size"),
    Item("features",predictor.features,"TensorFloat"),
    Item("is_image_set",predictor.is_image_set,"Bool"),
]
output_to_file("predictor_set_torch_image",items)

Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear

In [None]:
# Predict
predictor = get_predictor(True)

point_coords = random_ndarray([1,2],1)
point_labels = random_tensor([1],1).mul(255).type(torch.int).numpy()

masks, iou_predictions, low_res_masks =predictor.predict(point_coords,point_labels,None,None,True,False)
items =[
    Item("masks", masks, "TensorBool"),
    Item("iou_predictions", iou_predictions, "TensorFloat"),
    Item("low_res_masks", low_res_masks, "TensorFloat"),
]
output_to_file("predictor_predict",items)
del point_coords, point_labels, masks, iou_predictions, low_res_masks,predictor

Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear

In [None]:
# Predict torch
predictor = get_predictor(True)

point_coords = random_tensor([1,1,2],1)
point_labels = random_tensor([1,1],1)
masks, iou_predictions, low_res_masks = predictor.predict_torch(point_coords,point_labels,None,None,True,False)
items =[
    Item("masks", masks, "TensorBool"),
    Item("iou_predictions", iou_predictions, "TensorFloat"),
    Item("low_res_masks", low_res_masks, "TensorFloat")
]
output_to_file("predictor_predict_torch",items)
del predictor,point_coords,point_labels,masks,iou_predictions,low_res_masks

Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([25, 14, 14, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([192, 64]) torch.Size([192])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([64, 64]) torch.Size([64])
Linear forward torch.Size([1, 64, 64, 64]) torch.Size([256, 64]) torch.Size([256])
Linear forward torch.Size([1, 64, 64, 256]) torch.Size([64, 256]) torch.Size([64])
Linear