In [1]:
import torch
from hmr.model.smpl_head.transformer import SMPLTransformerDecoderHead
from torchinfo import summary

In [2]:
# Channel dimension of the feature map fed to the head; it must match dim 1 of `inp`
# built in the next cell (presumably the backbone's embedding width — TODO confirm).
context_dim = 1280

# Seed the global torch RNG before building the head. No checkpoint is loaded here,
# so layers such as nn.Linear keep PyTorch's random default init. Seeding makes those
# weights, and the random input drawn later, identical on every Restart & Run All.
torch.manual_seed(0)

smpl_head = SMPLTransformerDecoderHead(
    smpl_mean_params_path="./data/smpl_mean_params.npz",
    depth=6,
    heads=6,
    mlp_dim=1024,
    dim_head=64,
    dropout=0.0,
    emb_dropout=0.0,
    norm="layer",
    context_dim=context_dim,
    smpl_num_body_joints=23,
)
# eval() returns the module itself. The trailing `;` hides its repr without rebinding
# `_`, which would otherwise shadow IPython's last-output variable.
smpl_head.eval();

In [3]:
# Spatial size (height, width) of the dummy feature map passed to the head.
# Presumably the backbone's patch grid — TODO confirm against the real backbone output.
H, W = (16, 12)
# Batch of one; dims are (batch, context_dim, H, W), matching the summary's input shape.
# Use H, W rather than repeating the literals, so editing them actually takes effect.
inp = torch.randn(1, context_dim, H, W)

In [7]:
# Inference-only forward pass on the eval-mode head. no_grad skips building an
# autograd graph that is never back-propagated.
with torch.no_grad():
    # The third output is not used in this notebook (TODO: document what it is).
    # It is named explicitly so `_` (IPython's last-output variable) is not shadowed.
    pred_smpl_params, pred_cam, _extra = smpl_head(inp)

global_orient = pred_smpl_params["global_orient"]
body_pose = pred_smpl_params["body_pose"]
betas = pred_smpl_params["betas"]

# Every prediction should carry the same leading batch dimension as the input.
batch_size = inp.shape[0]
assert all(
    t.shape[0] == batch_size for t in (global_orient, body_pose, betas, pred_cam)
), "prediction batch size does not match input"

print(f"{global_orient.shape=}")
print(f"{body_pose.shape=}")
print(f"{betas.shape=}")
# Use the `=` specifier like the lines above, so this output is labelled too.
print(f"{pred_cam.shape=}")

global_orient.shape=torch.Size([1, 1, 3, 3])
body_pose.shape=torch.Size([1, 23, 3, 3])
betas.shape=torch.Size([1, 10])
torch.Size([1, 3])


In [8]:
# Layer-by-layer torchinfo report for the head, driven by the dummy input `inp`.
# depth=100 is simply "deep enough" to expand every nested submodule in the tree.
summary(
    smpl_head,
    depth=100,
    input_data=inp,
    col_names=("input_size", "output_size", "num_params"),
)

Layer (type:depth-idx)                                  Input Shape               Output Shape              Param #
SMPLTransformerDecoderHead                              [1, 1280, 16, 12]         [1, 1, 3, 3]              --
├─TransformerDecoder: 1-1                               [1, 1, 1]                 [1, 1, 1024]              1,024
│    └─Linear: 2-1                                      [1, 1, 1]                 [1, 1, 1024]              2,048
│    └─DropTokenDropout: 2-2                            [1, 1, 1024]              [1, 1, 1024]              --
│    └─TransformerCrossAttn: 2-3                        [1, 1, 1024]              [1, 1, 1024]              --
│    │    └─ModuleList: 3-1                             --                        --                        --
│    │    │    └─ModuleList: 4-1                        --                        --                        --
│    │    │    │    └─PreNorm: 5-1                      [1, 1, 1024]              [1, 1, 1024]       