In [1]:
%load_ext autoreload
%autoreload 2

In [208]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

In [224]:
# Dummy all-zeros batch of shape (1, 3, 223, 223). 223 is not a multiple of 32,
# so the 2x2 poolings floor the spatial size along the way.
inp = torch.zeros(1, 3, 223, 223)
class BinaryFCN32(nn.Module):
    """Fully-convolutional segmentation net with a single pool4 skip connection.

    Encoder: five ``3x3 conv + ReLU`` stages, each followed by a 2x2 max
    pooling (overall stride 32), then a sixth conv stage and a 1x1 classifier
    (``conv7``).

    Decoder: the ``conv7`` scores are upsampled 2x with a transposed
    convolution, added to 1x1 scores computed from the pool4 feature map, and
    the sum is upsampled 16x.  A bilinear resize follows each transposed
    convolution so the output matches the input size even when height/width
    are not multiples of 32.

    Note: despite the class name, the default head emits 21 channels and the
    pool4 fusion followed by a 16x upsample resembles an FCN-16s layout.  The
    name is kept so existing references keep working.

    Args:
        num_classes: number of output score channels.  Defaults to 21, the
            value that used to be hard-coded.
        in_channels: number of channels of the input image.  Defaults to 3.
    """

    def __init__(self, num_classes=21, in_channels=3):
        super().__init__()
        # NOTE: attribute names and construction order are kept stable so that
        # state_dict keys and seeded initialisation stay unchanged.
        self.conv1 = self._conv_relu(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self._conv_relu(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self._conv_relu(128, 256)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.conv4 = self._conv_relu(256, 512)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.conv5 = self._conv_relu(512, 512)
        self.pool5 = nn.MaxPool2d(2, 2)
        self.conv6 = self._conv_relu(512, 512)

        # 1x1 classifier on the deepest features (kept as a Sequential so the
        # parameter is still addressed as ``conv7.0.*``).
        self.conv7 = nn.Sequential(
            nn.Conv2d(512, num_classes, 1)
        )

        # 1x1 classifier on the pool4 features (skip connection).
        self.pool4prediction = nn.Conv2d(512, num_classes, 1)
        # kernel=4, stride=2, padding=1   -> output is exactly 2x the input.
        self.up2 = nn.ConvTranspose2d(num_classes, num_classes, 4, stride=2, padding=1)
        # kernel=32, stride=16, padding=8 -> output is exactly 16x the input.
        self.up16 = nn.ConvTranspose2d(num_classes, num_classes, 32, stride=16, padding=8)

    @staticmethod
    def _conv_relu(in_ch, out_ch):
        """Build one encoder stage: size-preserving 3x3 convolution + ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU()
        )

    def forward(self, x):
        """Compute per-pixel class scores.

        Args:
            x: input batch laid out as (N, in_channels, H, W).  H and W must be
                at least 32; otherwise pool5 would have to produce an empty
                feature map and PyTorch raises an error.

        Returns:
            Tensor of shape (N, num_classes, H, W) with the raw scores.
        """
        target_size = x.shape[2:]

        x = self.pool1(self.conv1(x))  # size / 2
        x = self.pool2(self.conv2(x))  # size / 4
        x = self.pool3(self.conv3(x))  # size / 8
        x = self.pool4(self.conv4(x))  # size / 16
        pool4 = x

        x = self.pool5(self.conv5(x))  # size / 32
        conv7 = self.conv7(self.conv6(x))

        pool4prediction = self.pool4prediction(pool4)
        conv7prediction2x = self.up2(conv7)

        # Floor division in the pooling layers means 2 * (size // 32) can be
        # smaller than size // 16, so resize to the pool4 prediction size.
        # align_corners=False is the implicit default; spelled out on purpose.
        conv7prediction2x = F.interpolate(
            conv7prediction2x,
            pool4prediction.shape[2:],
            mode="bilinear",
            align_corners=False,
        )

        combined = pool4prediction + conv7prediction2x

        # Upsample by 16, then correct any remaining mismatch with the input.
        final = self.up16(combined)
        final = F.interpolate(
            final, target_size, mode="bilinear", align_corners=False
        )

        return final
    
# Smoke test: build the model and push the dummy input through it.
# Only the output shape is inspected, so skip autograd bookkeeping.
net = BinaryFCN32()
with torch.no_grad():
    out = net(inp)
out.shape

torch.Size([1, 21, 223, 223])

In [228]:
# Verify that the network preserves the spatial size for several resolutions,
# including sizes that are not multiples of 32.
shapes = [(1, 3, 112, 112), (1, 3, 945, 673), (1, 3, 448, 448)]

for shape in shapes:
    inp = torch.rand(shape)
    # Shape check only: no_grad avoids keeping activations for a backward pass.
    with torch.no_grad():
        out = net(inp)

    print(inp.shape, out.shape)

    assert out.shape[2:] == inp.shape[2:], (
        f"spatial size changed: input {tuple(inp.shape)} -> output {tuple(out.shape)}"
    )

torch.Size([1, 3, 112, 112]) torch.Size([1, 21, 112, 112])
torch.Size([1, 3, 945, 673]) torch.Size([1, 21, 945, 673])
torch.Size([1, 3, 448, 448]) torch.Size([1, 21, 448, 448])


In [175]:
# Transposed-convolution shape check.  With kernel_size=3, stride=2, padding=1
# and output_padding=1 every spatial dimension is exactly doubled:
#   out = (in - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 2 * in
deconv = nn.ConvTranspose2d(
    in_channels=3,
    out_channels=1,
    kernel_size=3,
    stride=2,
    padding=1,
    output_padding=1,
    dilation=1,
)
input_tensor = torch.randn((1, 3, 4, 3))  # (batch_size, channels, height, width)
output_tensor = deconv(input_tensor)
print(input_tensor.shape)   # torch.Size([1, 3, 4, 3])
print(output_tensor.shape)  # torch.Size([1, 1, 8, 6])

torch.Size([1, 3, 4, 3])
torch.Size([1, 1, 8, 6])


In [174]:
# Transposed-convolution shape check for a 16x upsampler WITHOUT padding:
#   out = (in - 1) * 16 + (32 - 1) + 1 = 16 * in + 16
# i.e. the result overshoots an exact 16x upsample by 16 pixels
# (padding=8 would remove the overshoot).
deconv = nn.ConvTranspose2d(
    in_channels=3,
    out_channels=1,
    kernel_size=32,
    stride=16,
)
input_tensor = torch.randn((1, 3, 2, 2))  # (batch_size, channels, height, width)
output_tensor = deconv(input_tensor)
print(input_tensor.shape)   # torch.Size([1, 3, 2, 2])
print(output_tensor.shape)  # torch.Size([1, 1, 48, 48])

torch.Size([1, 3, 2, 2])
torch.Size([1, 1, 48, 48])
