Skip to content

Commit

Permalink
add support for different test resolutions; comment on WTA/spatial conv
Browse files Browse the repository at this point in the history
  • Loading branch information
gengshan-y committed Jul 30, 2020
1 parent eb816cb commit 00c4bef
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/conv4d.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,10 @@ def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3,
#@profile
def forward(self,x):
b,c,u,v,h,w = x.shape
x = self.conv2(x.view(b,c,u,v,-1))
x = self.conv2(x.view(b,c,u,v,-1)) # WTA convolution over (u,v)
b,c,u,v,_ = x.shape
x = self.relu(x)
x = self.conv1(x.view(b,c,-1,h,w))
x = self.conv1(x.view(b,c,-1,h,w)) # spatial convolution over (x,y)
b,c,_,h,w = x.shape

if self.isproj:
Expand Down
5 changes: 5 additions & 0 deletions run_self.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
## point $datapath to the folder of your images
datapath=./dataset/IIW/
modelname=things
i=239999
# Run inference with the selected checkpoint.
# Quote every path expansion so directories containing spaces or glob
# characters do not get word-split by the shell.
CUDA_VISIBLE_DEVICES=1 python submission.py --dataset kitticlip --datapath "$datapath/" --outdir "./weights/$modelname/" --loadmodel "./weights/$modelname/finetune_$i.tar" --maxdisp 256 --fac 1
18 changes: 18 additions & 0 deletions submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,24 @@ def main():
imgL = np.transpose(imgL, [2,0,1])[np.newaxis]
imgR = np.transpose(imgR, [2,0,1])[np.newaxis]

# support for any resolution inputs
from models.VCN import WarpModule, flow_reg
if hasattr(model.module, 'flow_reg64'):
model.module.flow_reg64 = flow_reg([1,max_w//64,max_h//64], ent=model.module.flow_reg64.ent, maxdisp=model.module.flow_reg64.md, fac=model.module.flow_reg64.fac).cuda()
if hasattr(model.module, 'flow_reg32'):
model.module.flow_reg32 = flow_reg([1,max_w//64*2,max_h//64*2], ent=model.module.flow_reg32.ent, maxdisp=model.module.flow_reg32.md, fac=model.module.flow_reg32.fac).cuda()
if hasattr(model.module, 'flow_reg16'):
model.module.flow_reg16 = flow_reg([1,max_w//64*4,max_h//64*4], ent=model.module.flow_reg16.ent, maxdisp=model.module.flow_reg16.md, fac=model.module.flow_reg16.fac).cuda()
if hasattr(model.module, 'flow_reg8'):
model.module.flow_reg8 = flow_reg([1,max_w//64*8, max_h//64*8], ent=model.module.flow_reg8.ent, maxdisp=model.module.flow_reg8.md , fac = model.module.flow_reg8.fac).cuda()
if hasattr(model.module, 'flow_reg4'):
model.module.flow_reg4 = flow_reg([1,max_w//64*16, max_h//64*16 ], ent=model.module.flow_reg4.ent, maxdisp=model.module.flow_reg4.md , fac = model.module.flow_reg4.fac).cuda()
model.module.warp5 = WarpModule([1,max_w//32,max_h//32]).cuda()
model.module.warp4 = WarpModule([1,max_w//16,max_h//16]).cuda()
model.module.warp3 = WarpModule([1,max_w//8, max_h//8]).cuda()
model.module.warp2 = WarpModule([1,max_w//4, max_h//4]).cuda()
model.module.warpx = WarpModule([1,max_w, max_h]).cuda()

# forward
imgL = Variable(torch.FloatTensor(imgL).cuda())
imgR = Variable(torch.FloatTensor(imgR).cuda())
Expand Down

0 comments on commit 00c4bef

Please sign in to comment.