In [41]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class CostNet(nn.Module):

    def __init__(self, k=32):
        super().__init__()

        self.cnn = CNN(k=k)
        self.spp = SPP()
        self.fusion = nn.Sequential(
                Conv2dBn(in_channels=320, out_channels=128, kernel_size=3, stride=1, padding=1, use_relu=True),
                nn.Conv2d(in_channels=128, out_channels=32, kernel_size=1, stride=1, padding=0, bias=False)
            )

    def forward(self, inputs):
        conv2_out, conv4_out = self.cnn(inputs)           # [B, 64, 1/4H, 1/4W], [B, 128, 1/4H, 1/4W]

        spp_out = self.spp(conv4_out)                    # [B, 128, 1/4H, 1/4W]
        print("costnet2", conv2_out.shape, conv4_out.shape, spp_out.shape)
        
        out = torch.cat([conv2_out, conv4_out, spp_out], dim=1)  # [B, 320, 1/4H, 1/4W]
        print("costnet3",out.shape)
        out = self.fusion(out)                            # [B, 32, 1/4H, 1/4W]
        print("fusion out", out.shape)
        return out


class SPP(nn.Module):

    def __init__(self):
        super().__init__()

        self.branch1 = self.__make_branch(kernel_size=64, stride=64)
        self.branch2 = self.__make_branch(kernel_size=32, stride=32)
        self.branch3 = self.__make_branch(kernel_size=16, stride=16)
        self.branch4 = self.__make_branch(kernel_size=8, stride=8)

    def forward(self, inputs):

        out_size = inputs.size(2), inputs.size(3)
        branch1_out = F.upsample(self.branch1(inputs), size=out_size, mode='bilinear')  # [B, 32, 1/4H, 1/4W]
        # print('branch1_out')
        # print(branch1_out[0, 0, :3, :3])
        branch2_out = F.upsample(self.branch2(inputs), size=out_size, mode='bilinear')  # [B, 32, 1/4H, 1/4W]
        branch3_out = F.upsample(self.branch3(inputs), size=out_size, mode='bilinear')  # [B, 32, 1/4H, 1/4W]
        branch4_out = F.upsample(self.branch4(inputs), size=out_size, mode='bilinear')  # [B, 32, 1/4H, 1/4W]
        out = torch.cat([branch4_out, branch3_out, branch2_out, branch1_out], dim=1)  # [B, 128, 1/4H, 1/4W]

        return out

    @staticmethod
    def __make_branch(kernel_size, stride):
        branch = nn.Sequential(
                nn.AvgPool2d(kernel_size, stride),
                Conv2dBn(in_channels=128, out_channels=32, kernel_size=3, stride=1, padding=1, use_relu=True)  # kernel size maybe 1
            )
        return branch

class Squeeze_excitation_layer(nn.Module):
    def __init__(self, filters, se_ratio=4):
        super(Squeeze_excitation_layer, self).__init__()
        reduction = filters // se_ratio
        self.se = nn.Sequential(nn.Conv2d(filters, reduction, kernel_size=1, bias=True),
                                nn.SiLU(),
                                nn.Conv2d(reduction, filters, kernel_size=1, bias=True),
                                nn.Sigmoid())
    def forward(self, inputs):
        x = self.se(inputs)
        return torch.multiply(inputs, x)

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout, kernel_size = 3, padding = 1, bias=False):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel_size, padding=padding, groups=nin, bias=bias)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out

class MBConv2d_block(nn.Module):
    def __init__(self, in_channels, out_channels, k=1):
        super(MBConv2d_block, self).__init__()

        self.net = nn.Sequential(nn.Conv2d(in_channels, out_channels * k, kernel_size=(1, 1), stride=(1, 1), padding="valid", bias=False),
               nn.BatchNorm2d(out_channels * k),
               nn.SiLU(),
               depthwise_separable_conv(out_channels * k, out_channels * k, kernel_size = 3, padding ="same", bias=False),
               nn.BatchNorm2d(out_channels * k),
               nn.SiLU(),
               Squeeze_excitation_layer(filters=out_channels * k, se_ratio=4),
               nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), stride=(1, 1), padding="valid", bias=False),
               nn.BatchNorm2d(out_channels * k),
               nn.Dropout(p=0.2))

    def forward(self, inputs):
        x = self.net(inputs)
        return torch.add(inputs, x)
  

class CNN(nn.Module):

    def __init__(self, k=32):
        super().__init__()

        self.conv0 = nn.Sequential(
                Conv2dBn(in_channels=3, out_channels=k, kernel_size=3, stride=2, padding=1, use_relu=True),  # downsample
                Conv2dBn(in_channels=k, out_channels=k, kernel_size=3, stride=1, padding=1, use_relu=True),
                Conv2dBn(in_channels=k, out_channels=k, kernel_size=3, stride=1, padding=1, use_relu=True)
            )
        self.mbconv0 = MBConv2d_block(in_channels=k, out_channels=k, k=1)
        self.conv1 = StackedBlocks(n_blocks=3, in_channels=k, out_channels=k, kernel_size=3, stride=1, padding=1, dilation=1)
        self.mbconv1 = MBConv2d_block(in_channels=k, out_channels=k, k=1)
        self.conv2 = StackedBlocks(n_blocks=3, in_channels=k, out_channels=k*2, kernel_size=3, stride=2, padding=1, dilation=1)  # downsample
        self.mbconv2 = MBConv2d_block(in_channels=k*2, out_channels=k*2, k=1)
        self.conv3 = StackedBlocks(n_blocks=3, in_channels=k*2, out_channels=k*4, kernel_size=3, stride=1, padding=2, dilation=2)  # dilated
        self.mbconv3 = MBConv2d_block(in_channels=k*4, out_channels=k*4, k=1)
        self.conv4 = StackedBlocks(n_blocks=3, in_channels=k*4, out_channels=k*4, kernel_size=3, stride=1, padding=4, dilation=4)  # dilated

    def forward(self, inputs):
        conv0_out = self.mbconv0(self.conv0(inputs))
        conv1_out = self.mbconv1(self.conv1(conv0_out))  # [B, 32, 1/2H, 1/2W]
        conv2_out = self.mbconv2(self.conv2(conv1_out))  # [B, 64, 1/4H, 1/4W]
        conv3_out = self.mbconv3(self.conv3(conv2_out))  # [B, 128, 1/4H, 1/4W]
        conv4_out = self.conv4(conv3_out)  # [B, 128, 1/4H, 1/4W]
        print("inp, c0 c1", inputs.shape, conv0_out.shape, conv1_out.shape)
        print("c2 c3 c4", conv2_out.shape, conv3_out.shape, conv4_out.shape)
        return conv2_out, conv4_out
"""
class CNN(nn.Module):

    def __init__(self, k=64):
        super().__init__()
        self.conv0 = nn.Sequential(
                Conv2dBn(in_channels=3, out_channels=k, kernel_size=3, stride=2, padding=1, use_relu=True),  # downsample
                Conv2dBn(in_channels=k, out_channels=k, kernel_size=3, stride=1, padding=1, use_relu=True),
                Conv2dBn(in_channels=k, out_channels=k, kernel_size=3, stride=1, padding=1, use_relu=True),
                Squeeze_excitation_layer(filters=k, se_ratio=4)
                )
        self.conv1 = nn.Sequential(
                StackedBlocks(n_blocks=3, in_channels=k, out_channels=k*2, kernel_size=3, stride=1, padding=1, dilation=1),
                MBConv2d_block(in_channels=k*2, out_channels=k*2, k=1)
                )
        self.conv2 = nn.Sequential(
                StackedBlocks(n_blocks=3, in_channels=k*2, out_channels=k*4, kernel_size=3, stride=2, padding=1, dilation=1),
                MBConv2d_block(in_channels=k*4, out_channels=k*4, k=1)
                )
        #self.conv2 = StackedBlocks(n_blocks=16, in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1)  # downsample
        self.conv3 = nn.Sequential(
                StackedBlocks(n_blocks=3, in_channels=k*4, out_channels=k*8, kernel_size=3, stride=1, padding=2, dilation=2),  # dilated
                MBConv2d_block(in_channels=k*8, out_channels=k*8, k=1),
                MBConv2d_block(in_channels=k*8, out_channels=k*8, k=1)
                )
    def forward(self, inputs): # inputs is [4, 3, 384, 1280]
        conv0_out = self.conv0(inputs) # [4, 32, 192, 640]
        conv1_out = self.conv1(conv0_out)  # [B, 32, 1/2H, 1/2W] # [4, 32, 192, 640]
        conv2_out = self.conv2(conv1_out)  # [B, 64, 1/4H, 1/4W] # [4, 64, 96, 320]
        conv3_out = self.conv3(conv2_out)  # [B, 128, 1/4H, 1/4W] # [4, 128, 96, 320]
        print("inp, c0 c1", inputs.shape, conv0_out.shape, conv1_out.shape)
        print("c2 c3 c4", conv2_out.shape, conv3_out.shape)
        return conv2_out, conv3_out
"""

class StackedBlocks(nn.Module):

    def __init__(self, n_blocks, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()

        if stride == 1 and in_channels == out_channels:
            downsample = False
        else:
            downsample = True
        net = [ResidualBlock(in_channels, out_channels, kernel_size, stride, padding, dilation, downsample)]

        for i in range(n_blocks - 1):
            net.append(ResidualBlock(out_channels, out_channels, kernel_size, 1, padding, dilation, downsample=False))
        self.net = nn.Sequential(*net)

    def forward(self, inputs):
        out = self.net(inputs)
        return out


class ResidualBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, downsample=False):
        super().__init__()

        self.net = nn.Sequential(
                Conv2dBn(in_channels, out_channels, kernel_size, stride, padding, dilation, use_relu=True),
                Conv2dBn(out_channels, out_channels, kernel_size, 1, padding, dilation, use_relu=False)
            )

        self.downsample = None
        if downsample:
            self.downsample = Conv2dBn(in_channels, out_channels, 1, stride, use_relu=False)

    def forward(self, inputs):
        out = self.net(inputs)
        if self.downsample:
            inputs = self.downsample(inputs)
        out = out + inputs
        return out
    

class Conv3dBn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, use_relu=True):
        super().__init__()

        net = [nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=False),
               nn.BatchNorm3d(out_channels)]
        if use_relu:
            net.append(nn.SiLU(inplace=True))

        self.net = nn.Sequential(*net)

    def forward(self, inputs):
        out = self.net(inputs)
        return out
    


 
class Conv2dBn(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, use_relu=True):
        super().__init__()

        net = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=False),
               nn.BatchNorm2d(out_channels)]
        if use_relu:
            net.append(nn.SiLU(inplace=True))
        self.net = nn.Sequential(*net)

    def forward(self, inputs):
        out = self.net(inputs)
        return out
  
class StackedHourglass(nn.Module):
    '''
    inputs --- [B, 64, 1/4D, 1/4H, 1/4W]
    '''

    def __init__(self, max_disp):
        super().__init__()

        self.conv0 = nn.Sequential(
            Conv3dBn(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True),
            Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True)
        )
        self.conv1 = nn.Sequential(
            Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True),
            Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=False)
        )
        self.hourglass1 = Hourglass()
        self.hourglass2 = Hourglass()
        self.hourglass3 = Hourglass()

        self.out1 = nn.Sequential(
            Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True),
            nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        )
        self.out2 = nn.Sequential(
            Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True),
            nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        )
        self.out3 = nn.Sequential(
            Conv3dBn(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True),
            nn.Conv3d(in_channels=32, out_channels=1, kernel_size=3, stride=1, padding=1, dilation=1, bias=False)
        )

        self.regression = DisparityRegression(max_disp)

    def forward(self, inputs, out_size):

        conv0_out = self.conv0(inputs)     # [B, 32, 1/4D, 1/4H, 1/4W]
        conv1_out = self.conv1(conv0_out)
        conv1_out = conv0_out + conv1_out  # [B, 32, 1/4D, 1/4H, 1/4W]

        hourglass1_out1, hourglass1_out3, hourglass1_out4 = self.hourglass1(conv1_out, scale1=None, scale2=None, scale3=conv1_out)
        hourglass2_out1, hourglass2_out3, hourglass2_out4 = self.hourglass2(hourglass1_out4, scale1=hourglass1_out3, scale2=hourglass1_out1, scale3=conv1_out)
        hourglass3_out1, hourglass3_out3, hourglass3_out4 = self.hourglass3(hourglass2_out4, scale1=hourglass2_out3, scale2=hourglass1_out1, scale3=conv1_out)

        out1 = self.out1(hourglass1_out4)  # [B, 1, 1/4D, 1/4H, 1/4W]
        out2 = self.out2(hourglass2_out4) + out1
        out3 = self.out3(hourglass3_out4) + out2

        cost1 = F.upsample(out1, size=out_size, mode='trilinear').squeeze(dim=1)  # [B, D, H, W]
        cost2 = F.upsample(out2, size=out_size, mode='trilinear').squeeze(dim=1)  # [B, D, H, W]
        cost3 = F.upsample(out3, size=out_size, mode='trilinear').squeeze(dim=1)  # [B, D, H, W]

        prob1 = F.softmax(-cost1, dim=1)  # [B, D, H, W]
        prob2 = F.softmax(-cost2, dim=1)
        prob3 = F.softmax(-cost3, dim=1)

        disp1 = self.regression(prob1)
        disp2 = self.regression(prob2)
        disp3 = self.regression(prob3)

        return disp1, disp2, disp3


class DisparityRegression(nn.Module):

    def __init__(self, max_disp):
        super().__init__()

        self.disp_score = torch.range(0, max_disp - 1)  # [D]
        self.disp_score = self.disp_score.unsqueeze(0).unsqueeze(2).unsqueeze(3)  # [1, D, 1, 1]

    def forward(self, prob):
        disp_score = self.disp_score.expand_as(prob).type_as(prob)  # [B, D, H, W]
        out = torch.sum(disp_score * prob, dim=1)  # [B, H, W]
        return out


class Hourglass(nn.Module):

    def __init__(self):
        super().__init__()

        self.net1 = nn.Sequential(
            Conv3dBn(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1, use_relu=True),
            Conv3dBn(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=False)
        )
        self.net2 = nn.Sequential(
            Conv3dBn(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, dilation=1, use_relu=True),
            Conv3dBn(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1, use_relu=True)
        )
        self.net3 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm3d(num_features=64)
            # nn.ReLU(inplace=True)
        )
        self.net4 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
            nn.BatchNorm3d(num_features=32)
        )

    def forward(self, inputs, scale1=None, scale2=None, scale3=None):
        net1_out = self.net1(inputs)  # [B, 64, 1/8D, 1/8H, 1/8W]

        if scale1 is not None:
            net1_out = F.relu(net1_out + scale1, inplace=True)
        else:
            net1_out = F.relu(net1_out, inplace=True)

        net2_out = self.net2(net1_out)  # [B, 64, 1/16D, 1/16H, 1/16W]
        net3_out = self.net3(net2_out)  # [B, 64, 1/8D, 1/8H, 1/8W]

        if scale2 is not None:
            net3_out = F.relu(net3_out + scale2, inplace=True)
        else:
            net3_out = F.relu(net3_out + net1_out, inplace=True)

        net4_out = self.net4(net3_out)

        if scale3 is not None:
            net4_out = net4_out + scale3

        return net1_out, net3_out, net4_out




class GCNetPlus(nn.Module):

    def __init__(self, max_disp, k=32):
        super().__init__()

        self.cost_net = CostNet(k=k)
        self.stackedhourglass = StackedHourglass(max_disp)
        self.D = max_disp


    def forward(self, left_img, right_img):
        original_size = [self.D, left_img.size(2), left_img.size(3)]

        left_cost = self.cost_net(left_img)  # [B, 32, 1/4H, 1/4W]
        right_cost = self.cost_net(right_img)  # [B, 32, 1/4H, 1/4W]
        #cost = torch.cat([left_cost, right_cost], dim=1)  # [B, 64, 1/4H, 1/4W]
        # B, C, H, W = cost.size()

        print("lcost, rcost", left_cost.shape, right_cost.shape) # torch.Size([4, 32, 96, 320]) torch.Size([4, 32, 96, 320])
        # print(left_cost[0, 0, :3, :3])

        B, C, H, W = left_cost.size()
        print("B, C, H, W", B, C, H, W)
        cost_volume = torch.zeros(B, C * 2, self.D // 4, H, W).type_as(left_cost)  # [B, 64, D, 1/4H, 1/4W]
        print("cost_volume1", cost_volume.shape)
        # for i in range(self.D // 4):
        #     cost_volume[:, :, i, :, i:] = cost[:, :, :, i:]
        #"""
        for i in range(self.D // 4):
            if i > 0:
                cost_volume[:, :C, i, :, i:] = left_cost[:, :, :, i:]
                cost_volume[:, C:, i, :, i:] = right_cost[:, :, :, :-i]
            else:
                cost_volume[:, :C, i, :, :] = left_cost
                cost_volume[:, C:, i, :, :] = right_cost
       # """
        print("cost_volume2", cost_volume.shape)
        disp1, disp2, disp3 = self.stackedhourglass(cost_volume, out_size=original_size)
        print("disp", disp1.shape, disp2.shape, disp3.shape)
        return cost_volume #disp1, disp2, disp3

    

In [42]:
H= 192*2
W=640*2
rinp = torch.randn(4, 3, 192*2, 640*2)
linp = torch.randn(4, 3, 192*2, 640*2)
print(rinp.shape, linp.shape)

torch.Size([4, 3, 384, 1280]) torch.Size([4, 3, 384, 1280])


In [43]:
inp, c0 c1 torch.Size([2, 3, 384, 1280]) torch.Size([2, 32, 192, 640]) torch.Size([2, 32, 192, 640])
c2 c3 c4 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet2 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet3 torch.Size([2, 320, 96, 320])
fusion out torch.Size([2, 32, 96, 320])
inp, c0 c1 torch.Size([2, 3, 384, 1280]) torch.Size([2, 32, 192, 640]) torch.Size([2, 32, 192, 640])
c2 c3 c4 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet2 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet3 torch.Size([2, 320, 96, 320])
fusion out torch.Size([2, 32, 96, 320])
lcost, rcost torch.Size([2, 32, 96, 320]) torch.Size([2, 32, 96, 320])
B, C, H, W 2 32 96 320
cost_volume1 torch.Size([2, 64, 48, 96, 320])

SyntaxError: invalid syntax (2032453164.py, line 1)

In [44]:
from torchsummary import summary
device = torch.device('cpu')
model = GCNetPlus(max_disp=192, k=32).to(device)
#out = model(rinp, linp)
#print(out.shape)
summary(model, [(3, 192*2, 640*2), (3, 192*2, 640*2)])

  self.disp_score = torch.range(0, max_disp - 1)  # [D]


inp, c0 c1 torch.Size([2, 3, 384, 1280]) torch.Size([2, 32, 192, 640]) torch.Size([2, 32, 192, 640])
c2 c3 c4 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet2 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet3 torch.Size([2, 320, 96, 320])
fusion out torch.Size([2, 32, 96, 320])
inp, c0 c1 torch.Size([2, 3, 384, 1280]) torch.Size([2, 32, 192, 640]) torch.Size([2, 32, 192, 640])
c2 c3 c4 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet2 torch.Size([2, 64, 96, 320]) torch.Size([2, 128, 96, 320]) torch.Size([2, 128, 96, 320])
costnet3 torch.Size([2, 320, 96, 320])
fusion out torch.Size([2, 32, 96, 320])
lcost, rcost torch.Size([2, 32, 96, 320]) torch.Size([2, 32, 96, 320])
B, C, H, W 2 32 96 320
cost_volume1 torch.Size([2, 64, 48, 96, 320])
cost_volume2 torch.Size([2, 64, 48, 96, 320])
disp torch.Size([2, 384, 1280]) torch.Size([2, 384, 1280]) tor

In [16]:

def soft_argmax(voxels):
    """
    Arguments: voxel patch in shape (batch_size, channel, H, W, depth)
    Return: 3D coordinates in shape (batch_size, channel, 3)
    """
    assert voxels.dim()==5
    # alpha is here to make the largest element really big, so it
    # would become very close to 1 after softmax
    alpha = 1000.0 
    N,C,H,W,D = voxels.shape
    soft_max = nn.functional.softmax(voxels.view(N,C,-1)*alpha,dim=2)
    soft_max = soft_max.view(voxels.shape)
    indices_kernel = torch.arange(start=0,end=H*W*D).unsqueeze(0)
    indices_kernel = indices_kernel.view((H,W,D))
    conv = soft_max*indices_kernel
    indices = conv.sum(2).sum(2).sum(2)
    z = indices%D
    y = (indices/D).floor()%W
    x = (((indices/D).floor())/W).floor()%H
    coords = torch.stack([x,y,z],dim=2)
    return coords

voxel = torch.randn(1,2,3,3,3) # (batch_size, channel, H, W, depth)
print(voxel.shape)
coords = soft_argmax(voxel)
coords.shape

torch.Size([1, 2, 3, 3, 3])


torch.Size([1, 2, 3])