Skip to content

Commit

Permalink
for right code to fix
Browse files Browse the repository at this point in the history
  • Loading branch information
meteorshowers committed Mar 30, 2019
1 parent 976b785 commit e2cd703
Showing 1 changed file with 1 addition and 117 deletions.
118 changes: 1 addition & 117 deletions models/StereoNet8Xmulti.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,38 +84,7 @@ def forward(self, rgb_img):
output = block(output)
return self.conv_alone(output)

class EdgeAwareRefinement(nn.Module):
def __init__(self, in_channel):
super().__init__()
self.conv2d_feature = nn.Sequential(
convbn(in_channel, 32, kernel_size=3, stride=1, pad=1, dilation=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
self.residual_astrous_blocks = nn.ModuleList()
astrous_list = [1, 2, 4, 8 , 1 , 1]
for di in astrous_list:
self.residual_astrous_blocks.append(
BasicBlock(
32, 32, stride=1, downsample=None, pad=1, dilation=di))

self.conv2d_out = nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1)

def forward(self, low_disparity, corresponding_rgb):
output = torch.unsqueeze(low_disparity, dim=1)
twice_disparity = F.interpolate(
output,
size = corresponding_rgb.size()[-2:],
mode='bilinear',
align_corners=False)
if corresponding_rgb.size()[-1]/ low_disparity.size()[-1] >= 1.5:
twice_disparity *= 2 # ??????
# print(corresponding_rgb.size()[-1]// low_disparity.size()[-1])
output = self.conv2d_feature(
torch.cat([twice_disparity, corresponding_rgb], dim=1))
for astrous_block in self.residual_astrous_blocks:
output = astrous_block(output)

return nn.ReLU(inplace=True)(torch.squeeze(
twice_disparity + self.conv2d_out(output), dim=1))


class disparityregression(nn.Module):
def __init__(self, maxdisp):
Expand All @@ -128,91 +97,6 @@ def forward(self, x):
out = torch.sum(x * disp, 1)
return out


class StereoNet(nn.Module):
def __init__(self, k, r, maxdisp=192):
super().__init__()
self.maxdisp = maxdisp
self.k = k
self.r = r
self.feature_extraction = FeatureExtraction(k)
self.filter = nn.ModuleList()
for _ in range(4):
self.filter.append(
nn.Sequential(
convbn_3d(32, 32, kernel_size=3, stride=1, pad=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True)))
self.conv3d_alone = nn.Conv3d(
32, 1, kernel_size=3, stride=1, padding=1)

self.edge_aware_refinements = nn.ModuleList()
for _ in range(r):
self.edge_aware_refinements.append(EdgeAwareRefinement(4))

def forward(self, left, right):
disp = (self.maxdisp + 1) // pow(2, self.k)
refimg_feature = self.feature_extraction(left)
targetimg_feature = self.feature_extraction(right)

# matching
cost = torch.FloatTensor(refimg_feature.size()[0],
refimg_feature.size()[1],
disp,
refimg_feature.size()[2],
refimg_feature.size()[3]).zero_().cuda()
for i in range(disp):
if i > 0:
cost[:, :, i, :, i:] = refimg_feature[ :, :, :, i:] - targetimg_feature[:, :, :, :-i]
else:
cost[:, :, i, :, :] = refimg_feature - targetimg_feature
cost = cost.contiguous()

for f in self.filter:
cost = f(cost)
cost = self.conv3d_alone(cost)
cost = torch.squeeze(cost, 1)
pred = F.softmax(cost, dim=1)
pred = disparityregression(disp)(pred)



img_pyramid_list = []

for i in range(self.r):
img_pyramid_list.append(F.interpolate(
left,
scale_factor=1 / pow(2, i),
mode='bilinear',
align_corners=False))

img_pyramid_list.reverse()


pred_pyramid_list= [pred]

for i in range(self.r):
# start = datetime.datetime.now()
pred_pyramid_list.append(self.edge_aware_refinements[i](
pred_pyramid_list[i], img_pyramid_list[i]))

length_all = len(pred_pyramid_list)


for i in range(length_all):
pred_pyramid_list[i] = pred_pyramid_list[i]* (
left.size()[-1] / pred_pyramid_list[i].size()[-1])
pred_pyramid_list[i] = torch.squeeze(
F.interpolate(
torch.unsqueeze(pred_pyramid_list[i], dim=1),
size=left.size()[-2:],
mode='bilinear',
align_corners=False),
dim=1)

return pred_pyramid_list



if __name__ == '__main__':
model = StereoNet(k=3, r=3).cuda()
# model.eval()
Expand Down

0 comments on commit e2cd703

Please sign in to comment.