In [1]:
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import RandomSampler

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

from torchsummary import summary

In [24]:
class Decoder(nn.Module):
    def __init__(self, in_channels, channels, out_channels):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.Conv2d(in_channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),

            nn.Conv2d(channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True), 
        )

    def forward(self, x):
        x = self.decoder(torch.cat(x, 1))
        
        return x

In [37]:
class SimpleRes34Unet(nn.Module):
    def __init__(self, n_classes=4):
        super(SimpleRes34Unet, self).__init__()
        self.resnet = torchvision.models.resnet34(True)
        self.encode1 = nn.Sequential(self.resnet.conv1,
                                     self.resnet.bn1,
                                     self.resnet.relu)
        self.encode2 = nn.Sequential(self.resnet.layer1)
        self.encode3 = nn.Sequential(self.resnet.layer2)
        self.encode4 = nn.Sequential(self.resnet.layer3)
        self.encode5 = nn.Sequential(self.resnet.layer4)

        self.resnet = None # delete resnet
        
        self.decode5 = Decoder(512, 256, 64)
        self.decode4 = Decoder( 64, 256, 64)
        self.decode3 = Decoder( 64, 128, 64)
        self.decode2 = Decoder( 64,  64, 64)
        self.decode1 = Decoder( 64,  32, 64)
        self.logit = nn.Conv2d( 64, n_classes, kernel_size=1, bias=False)
        
    def forward(self, x):
        x = self.encode1(x)
        e2 = self.encode2(x)
        e3 = self.encode3(e2)
        e4 = self.encode4(e3)
        e5 = self.encode5(e4)
        
        d5 = self.decode5([e5, ])
        d4 = self.decode4([d5, e4])
        d3 = self.decode3([d4, e3])
        d2 = self.decode2([d3, e2])
        d1 = self.decode1([d2, ])
        
        f = self.logit(x)
        
        return f

In [38]:
ru = SimpleRes34Unet(n_classes=4)  # same as the default n_classes; spelled out for the reader

In [39]:
ru  # last expression of the cell: displays the module hierarchy

SimpleRes34Unet(
  (resnet): None
  (encode1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (encode2): Sequential(
    (0): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace)
        (

In [40]:
# torchsummary defaults to device="cuda" and builds its dummy input on that
# device, so the model and the `device` argument must agree. Fall back to CPU
# when no GPU is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
summary(ru.to(DEVICE), input_size=(3, 256, 1600), device=DEVICE)

AttributeError: 'list' object has no attribute 'size'