In [3]:
import numpy as np
import torch 
import torch.nn as nn

In [8]:
class UNET(nn.Module):
    def __init__(self, in_channels, out_channels, should_pad=True):
        super().__init__()
        self.name = 'UNET'
        if should_pad:
            conv1_pad = 3
            gen_pad = 1
        else:
            conv1_pad = 0
            gen_pad = 0
        self.conv1 = self.contract_block(in_channels, 32, 7, conv1_pad)
        self.conv2 = self.contract_block(32, 64, 3, gen_pad)
        self.conv3 = self.contract_block(64, 128, 3, gen_pad)

        self.upconv3 = self.expand_block(128, 64, 3, gen_pad)
        self.upconv2 = self.expand_block(64*2, 32, 3, gen_pad)
        self.upconv1 = self.expand_block(32*2, out_channels, 3, gen_pad)
        
        self.softmax = torch.nn.LogSoftmax(dim=1)

    def __call__(self, x):

        # downsampling part
        conv1 = self.conv1(x)
        print('convv1', conv1.shape)
        conv2 = self.conv2(conv1)
        print('convv2', conv2.shape)
        conv3 = self.conv3(conv2)
        print('convv3', conv3.shape)
        upconv3 = self.upconv3(conv3)
        print('upconv3', upconv3.shape)

        cat1_trim = 6
        cat2_trim = 18
        upconv2 = self.upconv2(torch.cat([upconv3, conv2[:, :, cat1_trim:-cat1_trim, cat1_trim:-cat1_trim]], 1))
        upconv1 = self.upconv1(torch.cat([upconv2, conv1[:, :, cat2_trim:-cat2_trim, cat2_trim:-cat2_trim]], 1))
        xout = self.softmax(upconv1)

        return xout

    def contract_block(self, in_channels, out_channels, kernel_size, padding):

        contract = nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
                                 )

        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):

        expand = nn.Sequential(torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding),
                            torch.nn.BatchNorm2d(out_channels),
                            torch.nn.ReLU(),
                            torch.nn.ConvTranspose2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1) 
                            )
        return expand

In [9]:
model = UNET(1, 2, should_pad=False)

In [10]:
input_im = np.ones((1, 1, 1516, 2700))
input_tensor = torch.tensor(input_im, dtype=torch.float)

In [11]:
x = model(input_tensor)

convv1 torch.Size([1, 32, 752, 1344])
convv2 torch.Size([1, 64, 374, 670])
convv3 torch.Size([1, 128, 185, 333])
upconv3 torch.Size([1, 64, 362, 658])


In [22]:
x.conv2

AttributeError: 'Tensor' object has no attribute 'conv2'