In [31]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
torch.__version__

'1.4.0+cpu'

### SharedConv

In [52]:
class SharedConv(nn.Module):
    def __init__(self):
        super(SharedConv, self).__init__()
        self.bbnet = torchvision.models.resnet50(pretrained=True)
        self.conv1 = nn.Conv2d(3072, 128, 1)
        self.bn1 = nn.BatchNorm2d(128)
        self.conv2 = nn.Conv2d(640, 64, 1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(320, 32, 1)
        self.bn3 = nn.BatchNorm2d(32)
        
    def __bbnetforward(self,x):
        for name,layer in self.bbnet.named_children():
            x = layer(x)
            
            if name=='layer1':
                down1 = x
            elif name=='layer2':
                down2 = x
            elif name=='layer3':
                down3 = x
            elif name =='layer4':
                down4 = x
                break
                
        # torch.Size([1, 256, 64, 64])
        # torch.Size([1, 512, 32, 32])
        # torch.Size([1, 1024, 16, 16])
        # torch.Size([1, 2048, 8, 8])   
        return down1, down2, down3, down4
        
    def forward(self, x):
        down1, down2, down3, down4 = self.__bbnetforward(x)
        
        o = F.interpolate(down4,scale_factor=2, mode='bilinear', align_corners=True)
        o = torch.cat([o,down3], dim=1) # torch.Size([1, 3072, 16, 16])
        o = self.conv1(o)
        o = self.bn1(o)
        o = F.relu(o) # torch.Size([1, 128, 16, 16])
        print(o.size())
        
        o = F.interpolate(o,scale_factor=2, mode='bilinear', align_corners=True)
        o = torch.cat([o,down2], dim=1) # torch.Size([1, 3584, 32, 32])
        o = self.conv2(o)
        o = self.bn2(o)
        o = F.relu(o) # torch.Size([1, 64, 32, 32])
        print(o.size())
        
        
        o = F.interpolate(o,scale_factor=2, mode='bilinear', align_corners=True)
        o = torch.cat([o,down1], dim=1) # torch.Size([1, 3840, 64, 64])
        o = self.conv3(o)
        o = self.bn3(o)
        o = F.relu(o) # torch.Size([1, 32, 64, 64])
        print(o.size())
        
        return o

In [54]:
net = SharedConv()
x = torch.randn(1,3,256,256)
y = net.forward(x)
y.shape

torch.Size([1, 128, 16, 16])
torch.Size([1, 64, 32, 32])
torch.Size([1, 32, 64, 64])


torch.Size([1, 32, 64, 64])

### Detector

In [None]:
class Detector(nn.Module):
    def __init__(self):
        super(Detector, self).__init__()
        self.score = nn.Conv2d(32,1,1)
        self.geo = nn.Conv2d(32,4,1)
        self.angle = nn.Conv2d(32,1,1)
        
    def forward(self, x):
        score = self.score(x)
        score = torch.sigmoid(score)
        
        geo = self.geo(x)
        geo = torch.sigmoid(geo)
        
        angle = self.angel(x)
        angel = torch.sigmoid(angle)
        
        return x