Commit 0d507fe

Replace matching loss with feature rendering loss; Fix bugs in LBS; Stabilize optimization.

Gengshan Yang committed Apr 11, 2022
1 parent 58d8932 commit 0d507fe
Showing 17 changed files with 184 additions and 277 deletions.
25 changes: 11 additions & 14 deletions README.md
@@ -2,6 +2,7 @@
#### [[Webpage]](https://banmo-www.github.io/) [[Latest preprint (02/14/2022)]](https://banmo-www.github.io/banmo-2-14.pdf) [[Arxiv]](https://arxiv.org/abs/2112.12761)

### Changelog
- **04/11**: Replace matching loss with feature rendering loss; Fix bugs in LBS; Stabilize optimization.
- **03/20**: Add mesh color option (canonical mapping vs radiance) during surface extraction. See `--ce_color` flag.
- **02/23**: Improve NVS with fourier light code, improve uncertainty MLP, add long schedule, minor speed up.
- **02/17**: Add adaptation to a new video, optimization with known root poses, and pose code visualization.
@@ -90,31 +91,29 @@ For more examples, see [here](./scripts/README.md).

<details><summary>Hardware/time for running the demo</summary>

By default, it takes 8 hours on 2 V100 GPUs and 15 hours on 1 V100 GPU.
The [short schedule](./scripts/template-short.sh) takes 4 hours on 2 V100 GPUs (+SSD storage).
To reach higher quality, the [full schedule](./scripts/template.sh) takes 12 hours.
We provide a [script](./scripts/template-accu.sh) that uses gradient accumulation
to support experiments on fewer GPUs / GPUs with less memory.
</details>

<details><summary>Setting good hyper-parameter for your own videos</summary>
<details><summary>Setting good hyper-parameters for videos of various lengths</summary>

When optimizing your own videos, a rule of thumb is to set
"num gpus" x "batch size" x "accu steps" ~= num frames (the default of 512 suits cat-pikachiu and human-cap)
When optimizing videos of different lengths, we found it useful to scale the batch size with the number of frames.
A rule of thumb is to set "num gpus" x "batch size" x "accu steps" ~= num frames.
This means more video frames need more GPU memory but the same optimization time.
</details>
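The rule of thumb above can be turned into a small calculation. The helper below is a hypothetical sketch (the function name is ours, not part of banmo, where these values are set via shell-script flags) that picks the smallest number of gradient-accumulation steps so that "num gpus" x "batch size" x "accu steps" covers the frame count:

```python
import math

def accu_steps_needed(num_frames, num_gpus, batch_size):
    # Smallest accu_steps with num_gpus * batch_size * accu_steps >= num_frames.
    # Hypothetical helper illustrating the rule of thumb above.
    return max(1, math.ceil(num_frames / (num_gpus * batch_size)))

# e.g. 512 frames on 2 GPUs with batch size 256 need 1 accumulation step,
# while 2048 frames on 1 GPU with batch size 256 need 8 steps.
```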

<details><summary>Try pre-optimized models</summary>

We provide pre-optimized models and scripts to run mesh extraction and novel view synthesis.

|seqname | download link |
|---|---|
|cat-pikachiu|[.npy](https://www.dropbox.com/s/nc2aawnwrmil8jr/cat-pikachiu.npy), [.pth](https://www.dropbox.com/s/i8sjlgbom5eoy0j/cat-pikachiu.pth)|
|cat-coco|[.npy](https://www.dropbox.com/s/fwf8il8bt9c812f/cat-coco.npy), [.pth](https://www.dropbox.com/s/4g0w6z4xec4f88g/cat-coco.pth)|
We provide [pre-optimized models](https://www.dropbox.com/sh/5ue6tpsqmt6gstw/AAB9FD6on0UZDnThr6GEde46a?dl=0)
and scripts to run mesh extraction and novel view synthesis.

```
# download pre-optimized models
mkdir -p tmp && cd "$_"
wget https://www.dropbox.com/s/nc2aawnwrmil8jr/cat-pikachiu.npy
wget https://www.dropbox.com/s/i8sjlgbom5eoy0j/cat-pikachiu.pth
wget https://www.dropbox.com/s/qzwuqxp0mzdot6c/cat-pikachiu.npy
wget https://www.dropbox.com/s/dnob0r8zzjbn28a/cat-pikachiu.pth
cd ../
seqname=cat-pikachiu
@@ -184,8 +183,6 @@ https://user-images.githubusercontent.com/13134872/154554210-3bb0a439-fe46-4ea3-

</details>

Use more iterations for better color rendering and novel view synthesis results; see `scripts/template-long.sh`.

#### 2. Visualization tools
<details><summary>[Tensorboard]</summary>

79 changes: 26 additions & 53 deletions nnutils/banmo.py
@@ -60,7 +60,7 @@
flags.DEFINE_bool('symm_shape', False, 'whether to set geometry to x-symmetry')
flags.DEFINE_bool('env_code', True, 'whether to use environment code for each video')
flags.DEFINE_bool('env_fourier', True, 'whether to use fourier basis for env')
flags.DEFINE_bool('use_unc',True, 'whether to use uncertainty sampling')
flags.DEFINE_bool('use_unc',False, 'whether to use uncertainty sampling')
flags.DEFINE_bool('nerf_vis', True, 'use visibility volume')
flags.DEFINE_bool('anneal_freq', False, 'whether to use frequency annealing')
flags.DEFINE_integer('alpha', 10, 'maximum frequency for fourier features')
@@ -97,7 +97,7 @@
flags.DEFINE_float('dskin_steps', 0.8, 'steps to add delta skinning weights')
flags.DEFINE_float('init_beta', 0.1, 'initial value for transparency beta')
flags.DEFINE_bool('reset_beta', False, 'reset volsdf beta to 0.1')
flags.DEFINE_float('fine_steps', 0.8, 'by default, not using fine samples')
flags.DEFINE_float('fine_steps', 1.1, 'by default, not using fine samples')
flags.DEFINE_float('nf_reset', 0.5, 'by default, start reseting near-far plane at 50%')
flags.DEFINE_float('bound_reset', 0.5, 'by default, start reseting bound from 50%')
flags.DEFINE_float('bound_factor', 2, 'by default, use a loose bound')
@@ -112,6 +112,7 @@
# optimization: fine-tuning
flags.DEFINE_bool('keep_pose_basis', True, 'keep pose basis when loading models at train time')
flags.DEFINE_bool('freeze_coarse', False, 'whether to freeze coarse pose code of MLP')
flags.DEFINE_bool('freeze_root', False, 'whether to freeze root body pose')
flags.DEFINE_bool('root_stab', True, 'whether to stabilize root at ft')
flags.DEFINE_bool('freeze_cvf', False, 'whether to freeze canonical features')
flags.DEFINE_bool('freeze_shape',False, 'whether to freeze canonical shape')
@@ -148,7 +149,8 @@
flags.DEFINE_float('total_wt', 1, 'by default, multiply total loss by 1')
flags.DEFINE_float('sil_wt', 0.1, 'weight for silhouette loss')
flags.DEFINE_float('img_wt', 0.1, 'weight for image loss')
flags.DEFINE_float('feat_wt', 0.2, 'by default, multiply feat loss by 1')
flags.DEFINE_float('feat_wt', 0., 'by default, multiply feat loss by 1')
flags.DEFINE_float('frnd_wt', 1., 'by default, multiply feature rendering loss by 1')
flags.DEFINE_float('proj_wt', 0.02, 'by default, multiply proj loss by 1')
flags.DEFINE_float('flow_wt', 1, 'by default, multiply flow loss by 1')
flags.DEFINE_float('cyc_wt', 1, 'by default, multiply cyc loss by 1')
@@ -420,7 +422,8 @@ def __init__(self, opts, data_info):
# input, (x,y,t)+code, output, (1)
vid_code_dim=32 # add video-specific code
self.vid_code = embed_net(self.num_vid, vid_code_dim)
self.nerf_unc = NeRFUnc(in_channels_xyz=in_channels_xyz, D=5, W=128,
#self.nerf_unc = NeRFUnc(in_channels_xyz=in_channels_xyz, D=5, W=128,
self.nerf_unc = NeRFUnc(in_channels_xyz=in_channels_xyz, D=8, W=256,
out_channels=1,in_channels_dir=vid_code_dim, raw_feat=True, init_beta=1.)
self.nerf_models['nerf_unc'] = self.nerf_unc

@@ -498,6 +501,8 @@ def forward_default(self, batch):
print('%d removed from sil'%(invalid_idx.sum()))

img_loss_samp = opts.img_wt*rendered['img_loss_samp']
if opts.loss_flt:
img_loss_samp[invalid_idx] *= 0
img_loss = img_loss_samp
if opts.rm_novp:
img_loss = img_loss * rendered['sil_coarse'].detach()
@@ -508,27 +513,22 @@ def forward_default(self, batch):
aux_out['img_loss'] = img_loss
total_loss = img_loss
total_loss = total_loss + sil_loss


# feat rnd loss
frnd_loss_samp = opts.frnd_wt*rendered['frnd_loss_samp']
if opts.loss_flt:
frnd_loss_samp[invalid_idx] *= 0
if opts.rm_novp:
frnd_loss_samp = frnd_loss_samp * rendered['sil_coarse'].detach()
feat_rnd_loss = frnd_loss_samp[sil_at_samp[...,0]>0].mean() # eval on valid pts
aux_out['feat_rnd_loss'] = feat_rnd_loss
total_loss = total_loss + feat_rnd_loss

# flow loss
if opts.use_corresp:
if opts.loss_flt:
# find flow window
dframe = (self.frameid.view(2,-1).flip(0).reshape(-1) - \
self.frameid).abs()
didxs = dframe.log2().long()
for didx in range(6):
subidx = didxs==didx
flo_err, invalid_idx = loss_filter(self.latest_vars['flo_err'][:,didx],
rendered['flo_loss_samp'][subidx],
sil_at_samp_flo[subidx], scale_factor=20)
self.latest_vars['flo_err'][self.errid.long()[subidx],didx] = flo_err
if self.progress > (opts.warmup_steps):
#print('%d removed from flow'%(invalid_idx.sum()))
flo_loss_samp_sub = rendered['flo_loss_samp'][subidx]
flo_loss_samp_sub[invalid_idx] *= 0.
rendered['flo_loss_samp'][subidx] = flo_loss_samp_sub

flo_loss_samp = rendered['flo_loss_samp']
if opts.loss_flt:
flo_loss_samp[invalid_idx] *= 0
if opts.rm_novp:
flo_loss_samp = flo_loss_samp * rendered['sil_coarse'].detach()
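The filtering pattern repeated in this hunk — zero out per-sample losses that the outlier filter marked invalid, then optionally gate by the detached rendered silhouette — can be sketched in isolation. This is an illustrative NumPy sketch with assumed names and shapes, not the repo's batched PyTorch code:

```python
import numpy as np

def mask_loss_samples(loss_samp, invalid_idx, sil_mask=None):
    # Zero per-sample losses flagged invalid by the loss filter;
    # optionally keep only samples covered by the rendered silhouette
    # (analogous to multiplying by sil_coarse.detach() under rm_novp).
    out = np.array(loss_samp, dtype=float)
    out[np.asarray(invalid_idx, dtype=bool)] = 0.0
    if sil_mask is not None:
        out = out * sil_mask
    return out
```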

@@ -548,21 +548,7 @@ def forward_default(self, batch):
if opts.use_embed:
feat_err_samp = rendered['feat_err']* opts.feat_wt
if opts.loss_flt:
if opts.lineload:
invalid_idx = loss_filter_line(self.latest_vars['fp_err'][:,0],
self.errid.long(),self.frameid.long(),
feat_err_samp * sil_at_samp,
opts.img_size, scale_factor=10)
else:
# loss filter
feat_err, invalid_idx = loss_filter(self.latest_vars['fp_err'][:,0],
feat_err_samp,
sil_at_samp>0)
self.latest_vars['fp_err'][self.errid.long(),0] = feat_err
if self.progress > (opts.warmup_steps):
feat_err_samp[invalid_idx] *= 0.
if invalid_idx.sum()>0:
print('%d removed from feat'%(invalid_idx.sum()))
feat_err_samp[invalid_idx] *= 0

feat_loss = feat_err_samp
if opts.rm_novp:
@@ -576,20 +562,7 @@ def forward_default(self, batch):
if opts.use_proj:
proj_err_samp = rendered['proj_err']* opts.proj_wt
if opts.loss_flt:
if opts.lineload:
invalid_idx = loss_filter_line(self.latest_vars['fp_err'][:,1],
self.errid.long(),self.frameid.long(),
proj_err_samp * sil_at_samp,
opts.img_size, scale_factor=10)
else:
proj_err, invalid_idx = loss_filter(self.latest_vars['fp_err'][:,1],
proj_err_samp,
sil_at_samp>0)
self.latest_vars['fp_err'][self.errid.long(),1] = proj_err
if self.progress > (opts.warmup_steps):
proj_err_samp[invalid_idx] *= 0.
if invalid_idx.sum()>0:
print('%d removed from proj'%(invalid_idx.sum()))
proj_err_samp[invalid_idx] *= 0

proj_loss = proj_err_samp[sil_at_samp>0].mean()
aux_out['proj_loss'] = proj_loss
@@ -677,8 +650,8 @@ def forward_default(self, batch):
unc_sil = sil_loss_samp[...,0]
#unc_accumulated = unc_feat + unc_proj
#unc_accumulated = unc_feat + unc_proj + unc_rgb*0.1
unc_accumulated = unc_feat + unc_proj + unc_rgb
# unc_accumulated = unc_rgb
# unc_accumulated = unc_feat + unc_proj + unc_rgb
unc_accumulated = unc_rgb
# unc_accumulated = unc_rgb + unc_sil

unc_loss = (unc_accumulated.detach() - unc_pred[...,0]).pow(2)
63 changes: 33 additions & 30 deletions nnutils/geom_utils.py
@@ -51,7 +51,7 @@ def bone_transform(bones_in, rts, is_vec=False):
"""
bones_in: 1,B,10 - B gaussian ellipsoids of bone coordinates
rts: ...,B,3,4 - B rigid transforms
rts are transformations applied to bone coordinates (right multiply)
rts are applied to bone coordinate transforms (left multiply)
is_vec: whether rts are stored as r1...9,t1...3 vector form
"""
B = bones_in.shape[-2]
@@ -72,9 +72,10 @@
Rmat = rts[:,:,:3,:3]
Tmat = rts[:,:,:3,3:4]

center = transforms.quaternion_to_matrix(orient).matmul(Tmat)[...,0]+center
# move bone coordinates (left multiply)
center = Rmat.matmul(center[...,None])[...,0]+Tmat[...,0]
Rquat = transforms.matrix_to_quaternion(Rmat)
orient = transforms.quaternion_multiply(orient, Rquat)
orient = transforms.quaternion_multiply(Rquat, orient)

scale = scale.repeat(bs,1,1)
bones = torch.cat([center,orient,scale],-1)
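The fix in this hunk switches bone updates to a left-multiply convention: the rigid transform (Rmat, Tmat) moves the bone's placement in world coordinates, so the center maps as c' = R c + T and the orientation composes as q' = q_R ⊗ q. A minimal NumPy sketch of that convention (real-first quaternions; the function names are ours, the repo uses pytorch3d-style `transforms` utilities on batched tensors):

```python
import numpy as np

def quat_to_mat(q):
    # real-first quaternion (w, x, y, z) -> 3x3 rotation matrix
    w, x, y, z = q
    return np.array([
        [1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
        [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
        [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]])

def quat_mul(a, b):
    # Hamilton product of real-first quaternions
    aw, ax, ay, az = a
    bw, bx, by, bz = b
    return np.array([
        aw*bw - ax*bx - ay*by - az*bz,
        aw*bx + ax*bw + ay*bz - az*by,
        aw*by - ax*bz + ay*bw + az*bx,
        aw*bz + ax*by - ay*bx + az*bw])

def transform_bone(center, orient, Rmat, Tmat, Rquat):
    # Left-multiply convention: (Rmat, Tmat) moves the bone's frame
    # in world coordinates.
    new_center = Rmat @ center + Tmat        # c' = R c + T
    new_orient = quat_mul(Rquat, orient)     # q' = q_R * q
    return new_center, new_orient
```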
@@ -177,7 +178,8 @@ def gauss_mlp_skinning(xyz, embedding_xyz, bones,
pose_code: ...,1, nchannel
"""
N_rays = xyz.shape[0]
if pose_code.dim() == 2:
#TODO hacky way to make code compatible with noqueryfw
if pose_code.dim() == 2 and pose_code.shape[0]!=N_rays:
pose_code = pose_code[None].repeat(N_rays, 1,1)

xyz_embedded = embedding_xyz(xyz)
@@ -275,46 +277,46 @@ def blend_skinning_chunk(bones, rts, skin, pts):
#def blend_skinning(bones, rts, skin, pts):
"""
bone: bs,B,10 - B gaussian ellipsoids
rts: bs,B,3,4 - B rigid transforms, applied to bone coordinates
rts: bs,B,3,4 - B rigid transforms, applied to bone coordinates (points attached to bones in world coords)
pts: bs,N,3 - N 3d points
skin: bs,N,B - skinning matrix
apply rts to bone coordinates, while computing blending globally
"""
B = rts.shape[-3]
N = pts.shape[-2]
bones = bones.view(-1,B,10)
pts = pts.view(-1,N,3)
rts = rts.view(-1,B,3,4)
Rmat = rts[:,:,:3,:3] # bs, B, 3,3
Tmat = rts[:,:,:3,3]
device = Tmat.device

# convert from bone to root transforms
bs = Rmat.shape[0]
center = bones[:,:,:3]
orient = bones[:,:,3:7] # real first
orient = F.normalize(orient, 2,-1)
orient = transforms.quaternion_to_matrix(orient) # real first
gmat = torch.eye(4)[None,None].repeat(bs, B, 1, 1).to(device)

# root to bone
gmat_r2b = gmat.clone()
gmat_r2b[:,:,:3,:3] = orient.permute(0,1,3,2)
gmat_r2b[:,:,:3,3] = -orient.permute(0,1,3,2).matmul(center[...,None])[...,0]
## convert from bone to root transforms
#bones = bones.view(-1,B,10)
#bs = Rmat.shape[0]
#center = bones[:,:,:3]
#orient = bones[:,:,3:7] # real first
#orient = F.normalize(orient, 2,-1)
#orient = transforms.quaternion_to_matrix(orient) # real first
#gmat = torch.eye(4)[None,None].repeat(bs, B, 1, 1).to(device)
#
## root to bone
#gmat_r2b = gmat.clone()
#gmat_r2b[:,:,:3,:3] = orient.permute(0,1,3,2)
#gmat_r2b[:,:,:3,3] = -orient.permute(0,1,3,2).matmul(center[...,None])[...,0]

# bone to root
gmat_b2r = gmat.clone()
gmat_b2r[:,:,:3,:3] = orient
gmat_b2r[:,:,:3,3] = center

# bone to bone
gmat_b = gmat.clone()
gmat_b[:,:,:3,:3] = Rmat
gmat_b[:,:,:3,3] = Tmat
## bone to root
#gmat_b2r = gmat.clone()
#gmat_b2r[:,:,:3,:3] = orient
#gmat_b2r[:,:,:3,3] = center

## bone to bone
#gmat_b = gmat.clone()
#gmat_b[:,:,:3,:3] = Rmat
#gmat_b[:,:,:3,3] = Tmat

gmat = gmat_b2r.matmul(gmat_b.matmul(gmat_r2b))
Rmat = gmat[:,:,:3,:3]
Tmat = gmat[:,:,:3,3]
#gmat = gmat_b2r.matmul(gmat_b.matmul(gmat_r2b))
#Rmat = gmat[:,:,:3,:3]
#Tmat = gmat[:,:,:3,3]

# Gi=sum(wbGb), V=RV+T
Rmat_w = (skin[...,None,None] * Rmat[:,None]).sum(2) # bs,N,B,3
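For context, the blending that survives this hunk follows the standard linear blend skinning equation: each point gets a skin-weighted sum of the per-bone rigid transforms, V' = (Σ_b w_b R_b) V + Σ_b w_b T_b. A standalone NumPy sketch under assumed unbatched shapes (the repo operates on batched torch tensors):

```python
import numpy as np

def blend_skinning(Rmat, Tmat, skin, pts):
    # Rmat: (B,3,3) per-bone rotations, Tmat: (B,3) per-bone translations,
    # skin: (N,B) skinning weights (rows sum to 1), pts: (N,3) points.
    Rw = np.einsum('nb,bij->nij', skin, Rmat)   # blended rotation per point
    Tw = skin @ Tmat                            # blended translation per point
    return np.einsum('nij,nj->ni', Rw, pts) + Tw
```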
@@ -1512,3 +1514,4 @@ def fid_reindex(fid, num_vids, vid_offset):
#tid[assign] = 2*(tid[assign] / doffset)-1
#tid[assign] = (tid[assign] - doffset/2)/1000.
return vid, tid

4 changes: 2 additions & 2 deletions nnutils/nerf.py
@@ -351,8 +351,8 @@ class RTExpMLP(nn.Module):
"""
def __init__(self, max_t, num_freqs, t_embed_dim, data_offset, delta=False):
super(RTExpMLP, self).__init__()
self.root_code = nn.Embedding(max_t, t_embed_dim)
#self.root_code = FrameCode(num_freqs, t_embed_dim, data_offset)
#self.root_code = nn.Embedding(max_t, t_embed_dim)
self.root_code = FrameCode(num_freqs, t_embed_dim, data_offset, scale=0.1)

self.base_rt = RTExplicit(max_t, delta=delta,rand=False)
#self.base_rt = RTHead(use_quat=True,