In [1]:
import torchvision.models as models
from torch import nn
from torchscan import summary
import torch

In [2]:
# Reference DenseNet-169 built with pretrained=False; later cells inspect its layers.
densenet = models.densenet169(pretrained=False)

In [6]:
class UpSample(nn.Module):
    """Decoder stage: resize, concatenate a skip tensor, then apply two conv layers.

    The incoming tensor is bilinearly resized to the spatial size of the skip
    tensor, concatenated with it along the channel axis, and passed through
    two 3x3 convolutions, each followed by LeakyReLU(0.2).

    Note: this notebook defines a second class named ``UpSample`` in a later
    cell; running that cell rebinds the name.

    Args:
        in_c: Channel count of the concatenated tensor, i.e. channels of the
            resized input plus channels of the skip tensor.
        out_c: Channel count produced by both convolutions.
    """

    def __init__(self, in_c, out_c):
        super(UpSample, self).__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.relu1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.relu2 = nn.LeakyReLU(0.2)

    def forward(self, x, concat):
        """Resize ``x`` to match ``concat``, fuse the two and return the result.

        Args:
            x: Tensor to be resized; indexed as (N, C, H, W) by the code below.
            concat: Skip tensor whose last two dimensions give the target size.

        Returns:
            Tensor with ``out_c`` channels and the spatial size of ``concat``.
        """
        # Resize to the skip tensor's size instead of a fixed 2x factor: a
        # fixed factor makes torch.cat fail whenever the skip map is not
        # exactly twice as large (e.g. odd sizes after strided layers).
        # For exact 2x cases this gives the same result as scale_factor=2.
        x = nn.functional.interpolate(
            x, size=concat.shape[2:], mode='bilinear', align_corners=False
        )
        x = torch.cat([x, concat], dim=1)
        x = self.relu1(self.conv1(x))
        x = self.relu2(self.conv2(x))
        return x
    
class DepthNet(nn.Module):
    """Encoder-decoder network built on torchvision's DenseNet-169 feature layers.

    The backbone's ``features`` container is cut into five stages. The outputs
    of the first four stages are kept as skip tensors and merged back in by
    four ``UpSample`` stages, deepest skip first.

    The summary recorded in this notebook for a (3, 480, 640) input reports an
    output of shape (-1, 104, 240, 320).
    NOTE(review): a 104-channel output is unusual for a depth map - confirm
    whether ``conv3`` was meant to emit a single channel.

    Args:
        pretrained: Forwarded unchanged to ``models.densenet169``.
    """

    def __init__(self, pretrained = True):
        super().__init__()
        backbone = models.densenet169(pretrained).features

        # Encoder: consecutive slices of the backbone's layer container.
        self.conv1 = backbone[0]
        self.block1 = nn.Sequential(*backbone[1:4])
        self.block2 = nn.Sequential(*backbone[4:6])
        self.block3 = nn.Sequential(*backbone[6:8])
        self.block4 = nn.Sequential(*backbone[8:-1])  # the final backbone layer is left out

        # 1x1 convolution applied to the deepest feature map before decoding.
        self.conv2 = nn.Conv2d(1664, 1664, 1)

        # Decoder: each stage's input width is the running tensor's channels
        # plus the channels of the skip tensor it is concatenated with.
        self.upsample1 = UpSample(1920, 832)
        self.upsample2 = UpSample(960, 416)
        self.upsample3 = UpSample(480, 208)
        self.upsample4 = UpSample(272, 104)
        self.conv3 = nn.Conv2d(104, 104, 3, padding=1)

    def forward(self, x):
        """Run the encoder, then decode while merging the stored skip tensors.

        Args:
            x: Input batch; the notebook summarises the model with (3, 480, 640) inputs.

        Returns:
            The output of ``conv3`` applied to the last decoder stage.
        """
        skip1 = self.conv1(x)
        skip2 = self.block1(skip1)
        skip3 = self.block2(skip2)
        skip4 = self.block3(skip3)
        out = self.conv2(self.block4(skip4))

        # Deepest skip first, shallowest last.
        decoder_stages = (
            (self.upsample1, skip4),
            (self.upsample2, skip3),
            (self.upsample3, skip2),
            (self.upsample4, skip1),
        )
        for stage, skip in decoder_stages:
            out = stage(out, skip)
        return self.conv3(out)

In [7]:
# Display the backbone slice that DepthNet wraps as `block2`.
densenet.features[4:6]

Sequential(
  (denseblock1): _DenseBlock(
    (denselayer1): _DenseLayer(
      (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (denselayer2): _DenseLayer(
      (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (denselayer3): _

In [12]:
# Build DepthNet with pretrained=False and cast its floating-point parameters to float32.
depthnet = DepthNet(pretrained=False).float()

In [13]:
# Print a per-layer summary of DepthNet for a (3, 480, 640) input.
summary(depthnet, (3, 480, 640))



__________________________________________________________________________________________
Layer                        Type                  Output Shape              Param #        
depthnet                     DepthNet              (-1, 104, 240, 320)       0              
├─conv1                      Conv2d                (-1, 64, 240, 320)        9,408          
├─block1                     Sequential            (-1, 64, 120, 160)        0              
|    └─0                     BatchNorm2d           (-1, 64, 240, 320)        257            
|    └─1                     ReLU                  (-1, 64, 240, 320)        0              
|    └─2                     MaxPool2d             (-1, 64, 120, 160)        0              
├─block2                     Sequential            (-1, 128, 60, 80)         0              
|    └─0                     _DenseBlock           (-1, 256, 120, 160)       0              
|    |    └─denselayer1      _DenseLayer           (-1, 32, 120, 160)   



In [73]:
# Print a per-layer summary of the plain DenseNet-169 for the same (3, 480, 640) input.
summary(densenet, (3, 480, 640))

__________________________________________________________________________________________
Layer                        Type                  Output Shape              Param #        
densenet                     DenseNet              (-1, 1000)                0              
├─features                   Sequential            (-1, 1664, 15, 20)        0              
|    └─conv0                 Conv2d                (-1, 64, 240, 320)        9,408          
|    └─norm0                 BatchNorm2d           (-1, 64, 240, 320)        257            
|    └─relu0                 ReLU                  (-1, 64, 240, 320)        0              
|    └─pool0                 MaxPool2d             (-1, 64, 120, 160)        0              
|    └─denseblock1           _DenseBlock           (-1, 256, 120, 160)       0              
|    |    └─denselayer1      _DenseLayer           (-1, 32, 120, 160)        0              
|    |    |    └─norm1       BatchNorm2d           (-1, 64, 120, 160)   

In [97]:
class UpSample(nn.Sequential):
    """Decoder stage: resize to the skip tensor, concatenate, then two conv layers.

    The class derives from ``nn.Sequential`` but overrides ``forward``, so the
    registered layers are applied explicitly rather than chained automatically.

    Note: an earlier cell of this notebook defines another class with the same
    name; running this cell rebinds ``UpSample`` to this definition.

    Args:
        skip_input: Channel count of the concatenated tensor, i.e. channels of
            the resized input plus channels of the skip tensor.
        output_features: Channel count produced by both convolutions.
    """

    def __init__(self, skip_input, output_features):
        super(UpSample, self).__init__()
        self.convA = nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1)
        self.leakyreluB = nn.LeakyReLU(0.2)

    def forward(self, x, concat_with):
        """Resize ``x`` to the size of ``concat_with``, fuse them and return the result.

        Args:
            x: Tensor to be resized; indexed as (N, C, H, W) by the code below.
            concat_with: Skip tensor; dimensions 2 and 3 give the target size.

        Returns:
            Tensor with ``output_features`` channels and the spatial size of
            ``concat_with``.
        """
        # `nn.functional` is used because the notebook's import cell binds
        # `nn` but never defines the alias `F`, which raised NameError here.
        up_x = nn.functional.interpolate(
            x, size=concat_with.shape[2:], mode='bilinear', align_corners=True
        )
        merged = torch.cat([up_x, concat_with], dim=1)
        return self.leakyreluB(self.convB(self.leakyreluA(self.convA(merged))))

class Decoder(nn.Module):
    """Decoder that fuses selected encoder activations through four UpSample stages.

    Args:
        num_features: Channel count of the deepest encoder activation.
        decoder_width: Multiplier applied to ``num_features`` to obtain the
            decoder's base channel width (truncated to an int).
    """

    def __init__(self, num_features=2208, decoder_width = 0.5):
        super().__init__()
        width = int(num_features * decoder_width)

        # NOTE(review): padding=1 on a 1x1 convolution adds a one-pixel border
        # to the output - confirm this is intended.
        self.conv2 = nn.Conv2d(num_features, width, kernel_size=1, stride=1, padding=1)

        # Each stage halves the channel width; the added constant is the
        # channel count expected from the skip tensor fused at that stage.
        self.up1 = UpSample(skip_input=width + 384, output_features=width // 2)
        self.up2 = UpSample(skip_input=width // 2 + 192, output_features=width // 4)
        self.up3 = UpSample(skip_input=width // 4 + 96, output_features=width // 8)
        self.up4 = UpSample(skip_input=width // 8 + 96, output_features=width // 16)

        # Final 3x3 convolution down to a single output channel.
        self.conv3 = nn.Conv2d(width // 16, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, features):
        """Decode a sequence of encoder activations into a single-channel map.

        Args:
            features: Indexable collection of activations; entries 3, 4, 6, 8
                and 11 are read. With the ``Encoder`` in this notebook, entry 0
                is the input and entry i is the output of the i-th backbone layer.

        Returns:
            Output of the final convolution ``conv3``.
        """
        # Fetch every needed activation up front, shallowest to deepest.
        skip0, skip1, skip2, skip3, deepest = (features[i] for i in (3, 4, 6, 8, 11))

        out = self.conv2(deepest)
        for stage, skip in (
            (self.up1, skip3),
            (self.up2, skip2),
            (self.up3, skip1),
            (self.up4, skip0),
        ):
            out = stage(out, skip)
        return self.conv3(out)

class Encoder(nn.Module):
    """Wraps a torchvision DenseNet-161 and exposes every intermediate activation."""

    def __init__(self):
        super().__init__()
        # Function-scope import; the notebook's first cell imports the same module.
        import torchvision.models as models
        self.original_model = models.densenet161(pretrained=False)

    def forward(self, x):
        """Feed ``x`` through the backbone's feature layers one at a time.

        Args:
            x: Input passed to the first feature layer.

        Returns:
            A list whose first entry is ``x`` itself, followed by the output
            of each feature layer in registration order.
        """
        activations = [x]
        for layer in self.original_model.features._modules.values():
            # Every layer consumes the most recent activation.
            activations.append(layer(activations[-1]))
        return activations

class Model(nn.Module):
    """Full network: an ``Encoder`` followed by a ``Decoder``, both with default arguments."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for the encoder's result."""
        encoded = self.encoder(x)
        return self.decoder(encoded)

In [98]:
# Instantiate the encoder-decoder Model; its Encoder builds DenseNet-161 with pretrained=False.
t = Model()