In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [2]:
def conv_relu_bn(in_channels, out_channels, kernel=3, double=True):
    if double:
        return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels),
                nn.Conv2d(out_channels, out_channels, kernel, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_channels)
                )
    else:
        return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel, padding=1),
        nn.ReLU(inplace=True),
        )


def upconv(in_channels, out_channels, mode='transpose', kernel_size=2):
    if mode == 'transpose':
        return nn.Sequential(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=2), 
                             nn.ReLU(inplace=True),
                             nn.BatchNorm2d(out_channels),
                             nn.Conv2d(out_channels, out_channels, kernel_size, padding=1),
                             nn.ReLU(inplace=True),
                             nn.BatchNorm2d(out_channels))
                             
    else:
         return nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2), 
                             nn.ReLU(inplace=True),
                             nn.BatchNorm2d(out_channels),
                             nn.Conv2d(out_channels, out_channels, kernel_size, padding=1),
                             nn.ReLU(inplace=True),
                             nn.BatchNorm2d(out_channels))

In [3]:
class unet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        
        self.conv1 = conv_relu_bn(self.n_channels,32)
        self.conv2 = conv_relu_bn(32,64)
        self.conv3 = conv_relu_bn(64,128)
        self.conv4 = conv_relu_bn(128,256)
        self.conv5 = conv_relu_bn(256,512)
        
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        self.up4 = upconv(512, 256, )
        self.up3 = upconv(512+256, 128)
        self.up2 = upconv(256+128, 64)
        self.up1 = upconv(128+64, 32)
        self.singleconv = conv_relu_bn(32,32, double=False)
        
        self.output = nn.Sequential(nn.Conv2d(32, n_classes, 1),nn.Sigmoid())
        
    def forward(self, x):
        conv1 = self.conv1(x)
        x = self.maxpool(conv1)
        
        conv2 = self.conv2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.conv3(x)
        x = self.maxpool(conv3)
        
        conv4 = self.conv4(x)
        x = self.maxpool(conv4)
       
        x = self.conv5(x)

        x = self.up4(x)
        print(x.shape)
        x = torch.cat([conv4, x], dim=1)

        x = self.up3(x)
        x = torch.cat([conv3, x], dim=1)

        x = self.up2(x)
        x = torch.cat([conv2, x], dim=1)

        x = self.up1(x)
        x = torch.cat([conv1, x], dim=1)
        x = self.singleconv(x)
        out = self.output(x)
        
        return out

In [8]:
model = unet(3,1).to('cuda')
summary(model, input_size=(3,1024,1024))

RuntimeError: CUDA error: out of memory