Neural network models (architectures)

In [4]:
%load_ext autoreload
%autoreload 2

# Export cells
!python notebook2script.py models.ipynb

Converted models/unet.py to torchtrainer/models/unet.py
Converted models/edunet.py to torchtrainer/models/edunet.py
Converted models/resunet.py to torchtrainer/models/resunet.py


In [1]:
# U-Net simple
#export models/unet.py
'''U-Net architecture'''

import torch
import torch.nn.functional as F
from torch import nn

class DoubleConvolution(nn.Module):
    def __init__(self, in_channels, middle_channel, out_channels, kernel_size=3, p=1):
        super(DoubleConvolution, self).__init__()
        layers = [
            nn.Conv2d(in_channels, middle_channel, kernel_size=kernel_size, padding=p),
            nn.BatchNorm2d(middle_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(middle_channel, out_channels, kernel_size=kernel_size, padding=p),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        self.dconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.dconv(x)

class UNet(nn.Module):
    def __init__(self, num_channels, num_classes):
        super(UNet, self).__init__()

        reduce_by = 1

        self.l1_ = DoubleConvolution(num_channels, 64//reduce_by, 64//reduce_by)
        self.a1_dwn = nn.MaxPool2d(kernel_size=2, stride=2)
        self.l2_ = DoubleConvolution(64//reduce_by, 128//reduce_by, 128//reduce_by)
        self.a2_dwn = nn.MaxPool2d(kernel_size=2, stride=2)
        self.l3_ = DoubleConvolution(128//reduce_by, 256//reduce_by, 256//reduce_by)
        self.a3_dwn = nn.MaxPool2d(kernel_size=2, stride=2)
        self.l4_ = DoubleConvolution(256//reduce_by, 512//reduce_by, 512//reduce_by)
        self.a4_dwn = nn.MaxPool2d(kernel_size=2, stride=2)

        self.l_mid = DoubleConvolution(512//reduce_by, 1024//reduce_by, 1024//reduce_by)   
        
        self.a_mid_up = nn.ConvTranspose2d(1024//reduce_by, 512//reduce_by, kernel_size=2, stride=2)
        self._l4 = DoubleConvolution(1024//reduce_by, 512//reduce_by, 512//reduce_by)

        self.a4_up = nn.ConvTranspose2d(512//reduce_by, 256//reduce_by, kernel_size=2, stride=2)
        self._l3 = DoubleConvolution(512//reduce_by, 256//reduce_by, 256//reduce_by)

        self.a3_up = nn.ConvTranspose2d(256//reduce_by, 128//reduce_by, kernel_size=2, stride=2)
        self._l2 = DoubleConvolution(256//reduce_by, 128//reduce_by, 128//reduce_by)

        self.a2_up = nn.ConvTranspose2d(128//reduce_by, 64//reduce_by, kernel_size=2, stride=2)
        self._l1 = DoubleConvolution(128//reduce_by, 64//reduce_by, 64//reduce_by)

        self.final = nn.Conv2d(64//reduce_by, num_classes, kernel_size=1)
        self.reset_parameters()

    def forward(self, x): 
        
        a1_ = self.l1_(x)
        a1_dwn = self.a1_dwn(a1_)

        a2_ = self.l2_(a1_dwn)
        a2_dwn = self.a2_dwn(a2_)

        a3_ = self.l3_(a2_dwn)
        a3_dwn = self.a3_dwn(a3_)

        a4_ = self.l4_(a3_dwn)
        # a4_ = F.dropout(a4_, p=0.2)
        a4_dwn = self.a4_dwn(a4_)

        a_mid = self.l_mid(a4_dwn)                      
        
        a_mid_up = self.a_mid_up(a_mid)                              
        _a4 = self._l4(UNet.match_and_concat(a4_, a_mid_up))      
        # _a4 = F.dropout(_a4, p=0.2)

        a4_up = self.a4_up(_a4)                                
        _a3 = self._l3(UNet.match_and_concat(a3_, a4_up))      

        a3_up = self.a3_up(_a3)                                
        _a2 = self._l2(UNet.match_and_concat(a2_, a3_up))      
        # _a2 = F.dropout(_a2, p=0.2)

        a2_up = self.a2_up(_a2)                                
        _a1 = self._l1(UNet.match_and_concat(a1_, a2_up))     

        final = self.final(_a1)
        return F.log_softmax(final, 1)

    @staticmethod
    def match_and_concat(bypass, upsampled, crop=True):
        
        if crop:
            c_h = (bypass.shape[2] - upsampled.shape[2])
            c_w = (bypass.shape[3] - upsampled.shape[3])
            if c_h%2==0:
                c_hu = c_hd = c_h//2
            else:
                c_hu = c_h//2
                c_hd = c_h//2+1
            if c_w%2==0:
                c_wl = c_wr = c_w//2
            else:
                c_wl = c_w//2
                c_wr = c_w//2+1
                
            bypass = F.pad(bypass, (-c_wl, -c_wr, -c_hu, -c_hd))
        return torch.cat((upsampled, bypass), 1)

    def reset_parameters(self):

        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                    
    def get_shapes(self, img_shape):

        input_img = torch.zeros(img_shape)[None, None]
        input_img = input_img.to(next(model.parameters()).device)
        output = self(input_img)
        return output[0, 0].shape

In [2]:
# U-Net encoder decoder
#export models/edunet.py
'''U-Net architecture'''

import torch
import torch.nn.functional as F
from torch import nn
from torch import tensor

# For importing in both the notebook and in the .py file
try:
    import ActivationSampler
except ImportError:
    from torchtrainer.module_util import ActivationSampler

class DoubleConvolution(nn.Module):
    def __init__(self, in_channels, middle_channel, out_channels, kernel_size=3, p=1):
        super(DoubleConvolution, self).__init__()
        layers = [
            nn.Conv2d(in_channels, middle_channel, kernel_size=kernel_size, padding=p),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(middle_channel),
            nn.Conv2d(middle_channel, out_channels, kernel_size=kernel_size, padding=p),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        ]
        self.dconv = nn.Sequential(*layers)

    def forward(self, x):
        return self.dconv(x)
    
class ResDoubleConvolution(nn.Module):
    def __init__(self, in_channels, middle_channel, out_channels, kernel_size=3, p=1):
        super(ResDoubleConvolution, self).__init__()
        
        self.relu = nn.ReLU(inplace=True)
        layers = [
            nn.Conv2d(in_channels, middle_channel, kernel_size=kernel_size, padding=p, bias=False),
            nn.BatchNorm2d(middle_channel),
            self.relu,
            nn.Conv2d(middle_channel, out_channels, kernel_size=kernel_size, padding=p, bias=False),
            nn.BatchNorm2d(out_channels)
        ]
        self.dconv = nn.Sequential(*layers)

    def forward(self, x):
        
        return self.relu(self.dconv(x) + x)
    
class Concat(nn.Module):
    '''Module for concatenating two activations'''

    def __init__(self, concat_dim=1):
        super(Concat, self).__init__()    
        self.concat_dim = concat_dim
    
    def forward(self, x1, x2):
        # Inputs will be padded if not the same size
        
        x1, x2 = self.pad_inputs(x1, x2)
        return torch.cat((x1, x2), self.concat_dim)
    
    def pad_inputs(self, x1, x2):
        
        cd = self.concat_dim
        shape_diff = tensor(x2.shape[cd+1:]) - tensor(x1.shape[cd+1:])
        pad1 = []
        pad2 = []
        for sd in shape_diff.flip(0):
            sd_abs = abs(sd.item())
            if sd%2==0:
                pb = pe = sd_abs//2
            else:
                pb = sd_abs//2
                pe = pb + 1
                
            if sd>=0:
                pad1 += [pb, pe]
                pad2 += [0, 0]
            else:
                pad1 += [0, 0]
                pad2 += [pb, pe]
                
        x1 = F.pad(x1, pad1)
        x2 = F.pad(x2, pad2)
        
        return x1, x2
    
    def extra_repr(self):
        s = 'concat_dim={concat_dim}'
        return s.format(**self.__dict__)
        
    
class Encoder(nn.Module):
    '''Encoder part of U-Net'''
    
    def __init__(self, num_channels, ConvBlock, reduce_by=1):
        super(Encoder, self).__init__()

        self.l1_ = ConvBlock(num_channels, 64//reduce_by, 64//reduce_by)
        self.a1_dwn = nn.MaxPool2d(kernel_size=2, stride=2)
        self.l2_ = ConvBlock(64//reduce_by, 128//reduce_by, 128//reduce_by)
        self.a2_dwn = nn.MaxPool2d(kernel_size=2, stride=2)
        self.l3_ = ConvBlock(128//reduce_by, 256//reduce_by, 256//reduce_by)
        self.a3_dwn = nn.MaxPool2d(kernel_size=2, stride=2)
        self.l4_ = ConvBlock(256//reduce_by, 512//reduce_by, 512//reduce_by)
        self.a4_dwn = nn.MaxPool2d(kernel_size=2, stride=2)

        self.l_mid = ConvBlock(512//reduce_by, 1024//reduce_by, 1024//reduce_by)

    def forward(self, x):
        for layer in self.children(): x = layer(x)
        return x


class EDUNet(nn.Module):
    def __init__(self, num_channels, num_classes):
        super(EDUNet, self).__init__()

        reduce_by = 1            

        ConvBlock = DoubleConvolution
        self.encoder = Encoder(num_channels, ConvBlock)
                  
        self.a_mid_up = nn.ConvTranspose2d(1024//reduce_by, 512//reduce_by, kernel_size=2, stride=2)
        self.sample_a4_ = ActivationSampler(self.encoder.l4_)
        self.concat_a4 = Concat(1)
        self._l4 = ConvBlock(1024//reduce_by, 512//reduce_by, 512//reduce_by)

        self.a4_up = nn.ConvTranspose2d(512//reduce_by, 256//reduce_by, kernel_size=2, stride=2)
        self.sample_a3_ = ActivationSampler(self.encoder.l3_)
        self.concat_a3 = Concat(1)
        self._l3 = ConvBlock(512//reduce_by, 256//reduce_by, 256//reduce_by)

        self.a3_up = nn.ConvTranspose2d(256//reduce_by, 128//reduce_by, kernel_size=2, stride=2)
        self.sample_a2_ = ActivationSampler(self.encoder.l2_)
        self.concat_a2 = Concat(1)
        self._l2 = ConvBlock(256//reduce_by, 128//reduce_by, 128//reduce_by)

        self.a2_up = nn.ConvTranspose2d(128//reduce_by, 64//reduce_by, kernel_size=2, stride=2)
        self.sample_a1_ = ActivationSampler(self.encoder.l1_)
        self.concat_a1 = Concat(1)
        self._l1 = ConvBlock(128//reduce_by, 64//reduce_by, 64//reduce_by)

        self.final = nn.Conv2d(64//reduce_by, num_classes, kernel_size=1)
        self.reset_parameters()

    def forward(self, x):    
        a_mid = self.encoder(x)                        
        
        a_mid_up = self.a_mid_up(a_mid)  
        a4_ = self.sample_a4_()
        _a4 = self._l4(self.concat_a4(a4_, a_mid_up))      
        # _a4 = F.dropout(_a4, p=0.2)

        a4_up = self.a4_up(_a4)           
        a3_ = self.sample_a3_()
        _a3 = self._l3(self.concat_a3(a3_, a4_up))      

        a3_up = self.a3_up(_a3)      
        a2_ = self.sample_a2_()
        _a2 = self._l2(self.concat_a2(a2_, a3_up))      
        # _a2 = F.dropout(_a2, p=0.2)

        a2_up = self.a2_up(_a2)  
        a1_ = self.sample_a1_()
        _a1 = self._l1(self.concat_a1(a1_, a2_up))     

        final = self.final(_a1)
        return F.log_softmax(final, 1)

    def reset_parameters(self):

        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                    
    def get_shapes(self, img_shape):

        input_img = torch.zeros(img_shape)[None, None]
        input_img = input_img.to(next(model.parameters()).device)
        output = self(input_img)
        return output[0, 0].shape
    
class ActivationSampler(nn.Module):
    '''Generates a hook for sampling a layer activation'''
    
    def __init__(self, model):
        super(ActivationSampler, self).__init__()
        self.model_name = model.__class__.__name__
        self.activation = None
        model.register_forward_hook(self.get_hook())
        
    def forward(self, x=None):
        return self.activation
    
    def get_hook(self):
        def hook(model, input, output):
            self.activation = output
        return hook
    
    def extra_repr(self):
        return f'{self.model_name}'
    

In [3]:
# U-Net encoder decoder ResNet
#export models/resunet.py
'''U-Net architecture with residual blocks'''

import torch
import torch.nn.functional as F
from torch import nn
from torch import tensor

# For importing in both the notebook and in the .py file
try:
    import ActivationSampler
except ImportError:
    from torchtrainer.module_util import ActivationSampler

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

    
class ResBlock(nn.Module):

    def __init__(self, inplanes, planes, stride=1, norm_layer=None):
        super(ResBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.stride = stride
          
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        
        if (inplanes!=planes) or (stride>1):
            # If in and out planes are different, we also need to change the planes of the input
            # If stride is not 1, we need to change the size of the input
            reshape_input = nn.Sequential(
                                    conv1x1(inplanes, planes, stride),
                                    norm_layer(planes),
                            )
            self.reshape_input = reshape_input
        else:
            self.reshape_input = None

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.reshape_input is not None:
            identity = self.reshape_input(x)

        out += identity
        out = self.relu(out)

        return out
    
class Concat(nn.Module):
    '''Module for concatenating two activations'''

    def __init__(self, concat_dim=1):
        super(Concat, self).__init__()    
        self.concat_dim = concat_dim
    
    def forward(self, x1, x2):
        # Inputs will be padded if not the same size
        
        x1, x2 = self.pad_inputs(x1, x2)
        return torch.cat((x1, x2), self.concat_dim)
    
    def pad_inputs(self, x1, x2):
        
        cd = self.concat_dim
        shape_diff = tensor(x2.shape[cd+1:]) - tensor(x1.shape[cd+1:])
        pad1 = []
        pad2 = []
        for sd in shape_diff.flip(0):
            sd_abs = abs(sd.item())
            if sd%2==0:
                pb = pe = sd_abs//2
            else:
                pb = sd_abs//2
                pe = pb + 1
                
            if sd>=0:
                pad1 += [pb, pe]
                pad2 += [0, 0]
            else:
                pad1 += [0, 0]
                pad2 += [pb, pe]
                
        x1 = F.pad(x1, pad1)
        x2 = F.pad(x2, pad2)
        
        return x1, x2
    
    def extra_repr(self):
        s = 'concat_dim={concat_dim}'
        return s.format(**self.__dict__)
        
    
class Encoder(nn.Module):
    '''Encoder part of U-Net'''
    
    def __init__(self, num_channels, reduce_by=1):
        super(Encoder, self).__init__()
        
        #num_planes = 64
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=1, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        self.resblock1 = ResBlock(64, 64, stride=1)
        self.resblock2 = ResBlock(64, 128, stride=2)
        self.resblock3 = ResBlock(128, 256, stride=2)
        self.resblock4 = ResBlock(256, 512, stride=2)
        self.resblock_mid = ResBlock(512, 1024, stride=2)

    def forward(self, x):
        for layer in self.children(): x = layer(x)
        return x
    
class ResUNet(nn.Module):
    # TODO: fix output size being different than input
    
    def __init__(self, num_channels, num_classes):
        super(ResUNet, self).__init__()
         
        self.encoder = Encoder(num_channels)
                  
        self.a_mid_up = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.blur_mid_up = Blur()
        self.sample_a4_ = ActivationSampler(self.encoder.resblock4)
        self.concat_a4 = Concat(1)
        self._l4 = ResBlock(1024, 512, stride=1)

        self.a4_up = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.blur_a4_up = Blur()
        self.sample_a3_ = ActivationSampler(self.encoder.resblock3)
        self.concat_a3 = Concat(1)
        self._l3 = ResBlock(512, 256, stride=1)

        self.a3_up = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.blur_a3_up = Blur()
        self.sample_a2_ = ActivationSampler(self.encoder.resblock2)
        self.concat_a2 = Concat(1)
        self._l2 = ResBlock(256, 128, stride=1)

        self.a2_up = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.blur_a2_up = Blur()
        self.sample_a1_ = ActivationSampler(self.encoder.resblock1)
        self.concat_a1 = Concat(1)
        self._l1 = ResBlock(128, 64, stride=1)

        self.final = nn.Conv2d(64, num_classes, kernel_size=1)
        self.reset_parameters()

    def forward(self, x):    
        a_mid = self.encoder(x)                        
        
        a_mid_up = self.a_mid_up(a_mid)  
        a_mid_up = self.blur_mid_up(a_mid_up)
        a4_ = self.sample_a4_()
        _a4 = self._l4(self.concat_a4(a4_, a_mid_up))      
        # _a4 = F.dropout(_a4, p=0.2)

        a4_up = self.a4_up(_a4)   
        a4_up = self.blur_a4_up(a4_up)
        a3_ = self.sample_a3_()
        _a3 = self._l3(self.concat_a3(a3_, a4_up))      

        a3_up = self.a3_up(_a3)      
        a3_up = self.blur_a3_up(a3_up)
        a2_ = self.sample_a2_()
        _a2 = self._l2(self.concat_a2(a2_, a3_up))      
        # _a2 = F.dropout(_a2, p=0.2)

        a2_up = self.a2_up(_a2)  
        a2_up = self.blur_a2_up(a2_up)
        a1_ = self.sample_a1_()
        _a1 = self._l1(self.concat_a1(a1_, a2_up))     

        final = self.final(_a1)
        return F.log_softmax(final, 1)

    def reset_parameters(self):

        for module in self.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
                    
    def get_shapes(self, img_shape):

        input_img = torch.zeros(img_shape)[None, None]
        input_img = input_img.to(next(model.parameters()).device)
        output = self(input_img)
        return output[0, 0].shape

class Blur(nn.Module):
    
    def __init__(self):
        super(Blur, self).__init__()
        
        self.pad = nn.ReplicationPad2d((0,1,0,1))
        self.blur = nn.AvgPool2d(2, stride=1)
    
    def forward(self, x):
        
        return self.blur(self.pad(x))

In [3]:
# Dynamic UNet from Fastai (not working)
import fastai
import fastai.vision as fai_vision
import torch
import torchvision

def DUNet(n_classes, img_size):
    
    resnet = fastai.vision.models.resnet34(True)
    del resnet.avgpool
    del resnet.fc
    model = torch.nn.Sequential(*list(resnet.children()))
    
    return fai_vision.models.unet.DynamicUnet(model, n_classes=n_classes, img_size=img_size)

#resnet34 = torchvision.models.resnet34(False)
#unetai = DUNet(2, (1000,1000))
edunet = EDUNet(3, 2)

## Test model

In [3]:
xb = torch.randn((2, 3, 572, 572)).to('cuda')
model = ResUNet(3, 2)
model.to('cuda')

a1_ = ActivationSampler(model.encoder.resblock1)
a2_ = ActivationSampler(model.encoder.resblock2)
a3_ = ActivationSampler(model.encoder.resblock3)
a4_ = ActivationSampler(model.encoder.resblock4)
a_mid = ActivationSampler(model.encoder.resblock_mid)
a_mid_up = ActivationSampler(model.a_mid_up)
_a4 = ActivationSampler(model._l4)
_a3 = ActivationSampler(model._l3)
_a2 = ActivationSampler(model._l2)
_a1 = ActivationSampler(model._l1)

pred = model(xb)

In [7]:
print(a1_().shape)
print(a2_().shape)
print(a3_().shape)
print(a4_().shape)
print(a_mid().shape)
print(a_mid_up().shape)
print(_a4().shape)
print(_a3().shape)
print(_a2().shape)
print(_a1().shape)

torch.Size([2, 64, 572, 572])
torch.Size([2, 128, 286, 286])
torch.Size([2, 256, 143, 143])
torch.Size([2, 512, 72, 72])
torch.Size([2, 1024, 36, 36])
torch.Size([2, 512, 72, 72])
torch.Size([2, 512, 72, 72])
torch.Size([2, 256, 144, 144])
torch.Size([2, 128, 288, 288])
torch.Size([2, 64, 576, 576])


In [7]:
resnet34

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [17]:
unetai

DynamicUnet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05

In [4]:
from pytorchlb.torchsummary import summary
edunet.to('cuda')
summary(edunet, input_size=(3, 500, 500))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 500, 500]           1,792
              ReLU-2         [-1, 64, 500, 500]               0
       BatchNorm2d-3         [-1, 64, 500, 500]             128
            Conv2d-4         [-1, 64, 500, 500]          36,928
              ReLU-5         [-1, 64, 500, 500]               0
       BatchNorm2d-6         [-1, 64, 500, 500]             128
 DoubleConvolution-7         [-1, 64, 500, 500]               0
         MaxPool2d-8         [-1, 64, 250, 250]               0
            Conv2d-9        [-1, 128, 250, 250]          73,856
             ReLU-10        [-1, 128, 250, 250]               0
      BatchNorm2d-11        [-1, 128, 250, 250]             256
           Conv2d-12        [-1, 128, 250, 250]         147,584
             ReLU-13        [-1, 128, 250, 250]               0
      BatchNorm2d-14        [-1, 128, 2

(tensor(31043586), tensor(31043586))