In [1]:
import math
import os
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo
from torchvision.models.densenet import densenet121, densenet161
from torchvision.models.squeezenet import squeezenet1_1

In [2]:
def load_weight_sequential(target, source_state):
    """Copy tensors from ``source_state`` into ``target`` purely by position.

    The i-th value of ``source_state`` is stored under the i-th key of
    ``target.state_dict()``; source key names are ignored. ``zip`` stops at the
    shorter of the two, so surplus source entries are dropped, and the strict
    ``load_state_dict`` call reports any target keys that were left unfilled.
    """
    target_keys = target.state_dict().keys()
    remapped = dict(zip(target_keys, source_state.values()))
    target.load_state_dict(remapped)

In [4]:
# Download URLs of the pretrained torchvision ResNet checkpoints, keyed by the
# name of the factory function that loads them.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)

In [None]:
def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    """Bias-free 3x3 convolution whose padding equals its dilation.

    With stride 1 this keeps the spatial size unchanged for any dilation.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        bias=False,
    )

In [5]:
class BasicBlock(nn.Module):
    """Two 3x3-conv residual block (ResNet-18/34 style) with optional dilation.

    Produces ``planes * expansion`` channels (expansion is 1). When given,
    ``downsample`` maps the input onto the residual branch's shape.
    """
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        # Attribute names and their registration order define the state_dict
        # keys that pretrained checkpoints are loaded into; keep them as-is.
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += shortcut
        return self.relu(y)

In [6]:
class BottleNeck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck (ResNet-50+ style).

    Output has ``planes * expansion`` (= 4 * planes) channels. Stride and
    dilation are applied by the middle 3x3 conv, whose padding equals the
    dilation.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BottleNeck, self).__init__()
        out_planes = planes * self.expansion
        # Registration order (conv1, bn1, conv2, bn2, conv3, bn3, relu,
        # downsample) determines state_dict key order; keep it unchanged.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        y += shortcut
        return self.relu(y)

In [7]:
class ResNet(nn.Module):
    """Dilated ResNet backbone without a classification head.

    The stem (stride-2 conv + stride-2 max-pool) and layer2 (stride 2) reduce
    resolution 8x; layer3 and layer4 keep stride 1 and use dilation 2 and 4
    instead. ``forward`` returns ``(layer4_output, layer3_output)``.

    Args:
        block: residual block class exposing ``expansion`` and accepting
            ``(inplanes, planes, stride, downsample, dilation)``.
        layers: number of blocks in layer1..layer4.
    """
    def __init__(self, block, layers=(3, 4, 23, 3)):
        super(ResNet, self).__init__()
        self.inplanes = 64  # running channel count consumed by _make_layer
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, diliation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, diliation=4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # He-style init: std = sqrt(2 / (k_h * k_w * out_channels)).
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, 0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _make_layer(self, block, planes, blocks, stride=1, diliation=1):
        """Build one stage of ``blocks`` residual blocks.

        The first block gets ``stride`` and, when the shape changes, a 1x1
        conv + BN projection on the shortcut. ``diliation`` (sic) is the
        dilation applied to every block in the stage.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        layers = [block(self.inplanes, planes, stride, downsample, diliation)]
        self.inplanes = out_planes
        # The block constructors name this keyword `dilation`.
        layers.extend(block(self.inplanes, planes, dilation=diliation) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        x = self.layer1(x)
        x = self.layer2(x)
        x_3 = self.layer3(x)
        x = self.layer4(x_3)

        return x, x_3

In [8]:
class _DenseLayer(nn.Sequential):
    def __init__(self, num_input_feature, growth_rate, bn_size, drop_rate):
        super(_DenseLayer, self).__init__()
        self.add_module('norm.1', nn.BatchNorm2d(num_input_feature)),
        self.add_module('relu.1', nn.ReLU(inplace=True)),
        self.add_module('conv.1', nn.Conv2d(num_input_feature, bn_size * growth_rate, kernel_size=1, stride=1, bias=False)),
        self.add_module('norm.2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu.2', nn.ReLU(inplace=True)),
        self.add_module('conv.2', nn.Conv2d(bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False)),
        self.drop_rate = drop_rate
        
    def forward(self, x):
        new_features = super(_DenseLayer, self).forward(x)
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return torch.cat([x, new_features], 1)

In [9]:
class _DenseBlock(nn.Sequential):
    """Sequence of ``num_layers`` dense layers named denselayer1..denselayerN.

    Every layer appends ``growth_rate`` channels to its input, so layer k
    (1-based) receives ``num_input_features + (k - 1) * growth_rate`` channels.
    """
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
        super(_DenseBlock, self).__init__()
        channels = num_input_features
        for index in range(1, num_layers + 1):
            self.add_module('denselayer%d' % index,
                            _DenseLayer(channels, growth_rate, bn_size, drop_rate))
            channels += growth_rate

In [10]:
class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features, downsample=False):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, kernel_size=1, stride=1, bias=False))
        
        if downsample:
            self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
        else:
            self.add_module('pool', nn.AvgPool2d(kernel_size=1, stride=1))

In [11]:
class DenseNet(nn.Module):
    """DenseNet feature extractor (DenseNet-121 layout by default) for PSPNet.

    Only the first transition halves the resolution. With the stride-4 stem,
    the output therefore stays at stride 8, like the dilated ResNet
    extractors, and matches PSPNet's three 2x upsampling stages.

    Args:
        growth_rate: channels added by each dense layer.
        batch_config: number of layers in each dense block.
        num_init_features: channels produced by the stem conv.
        bn_size: bottleneck width multiplier inside each dense layer.
        drop_rate: dropout probability on new features.
        pretrained: copy weights from torchvision's ``densenet121``. This only
            lines up with the default configuration.

    ``forward`` returns ``(out, deep_features)``. ``out`` is the last block's
    output. ``deep_features`` is the output of ``self.blocks[5]``, which is the
    third transition with the default four blocks; it is None when there are
    fewer than six blocks.
    """
    def __init__(self, growth_rate=32, batch_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, pretrained=True):
        super(DenseNet, self).__init__()

        self.start_features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Reference modules whose weights are copied, consumed in order via
        # `start`. torchvision's densenet121 `features` children are: conv0,
        # norm0, relu0, pool0, denseblock1, transition1, ..., denseblock4, norm5.
        # The reference model is only built when weights are actually needed.
        init_weights = list(densenet121(pretrained=True).features.children()) if pretrained else []
        start = 0
        for child in self.start_features.children():
            if pretrained:
                child.load_state_dict(init_weights[start].state_dict())
            start += 1

        num_features = num_init_features
        self.blocks = nn.ModuleList()
        for i, num_layers in enumerate(batch_config):
            block = _DenseBlock(num_layers, num_features, bn_size, growth_rate, drop_rate)
            if pretrained:
                block.load_state_dict(init_weights[start].state_dict())
            start += 1
            # NOTE: each block is registered both in self.blocks and as a named
            # attribute, so it appears under two prefixes in state_dict().
            self.blocks.append(block)
            setattr(self, 'denseblock%d' % (i + 1), block)

            num_features = num_features + num_layers * growth_rate
            if i != len(batch_config) - 1:
                # Only the first transition downsamples (keeps output stride 8).
                downsample = i < 1
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2,
                                    downsample=downsample)
                if pretrained:
                    trans.load_state_dict(init_weights[start].state_dict())
                start += 1
                self.blocks.append(trans)
                setattr(self, 'transition%d' % (i + 1), trans)
                num_features = num_features // 2

    def forward(self, x):
        out = self.start_features(x)
        deep_features = None
        for i, block in enumerate(self.blocks):
            out = block(out)
            if i == 5:
                deep_features = out

        return out, deep_features
        

In [12]:
class Fire(nn.Module):
    """SqueezeNet Fire module.

    A 1x1 "squeeze" conv is followed by parallel 1x1 and 3x3 "expand" convs,
    whose ReLU outputs are concatenated along channels. The output therefore
    has ``expand1x1_planes + expand3x3_planes`` channels at the input's
    spatial size.

    Args:
        diliation: (sic) dilation of the 3x3 expand conv. Its padding equals
            the dilation so spatial size is preserved.
    """
    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes, diliation=1):
        super(Fire, self).__init__()
        self.inplanes = inplanes
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        # nn.Conv2d's keyword is `dilation`; the misspelled kwarg raised TypeError.
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3,
                                   padding=diliation, dilation=diliation)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze_activation(self.squeeze(x))
        return torch.cat([
            self.expand1x1_activation(self.expand1x1(x)),
            self.expand3x3_activation(self.expand3x3(x))
        ], 1)

In [13]:
class SqueezeNet(nn.Module):
    """SqueezeNet 1.1 feature extractor for PSPNet.

    Channel layout follows torchvision's ``squeezenet1_1`` features: a 3->64
    3x3 first conv, then the same Fire sequence. This lets
    ``load_weight_sequential`` copy its weights positionally.

    Resolution is reduced by the stride-2 conv and two stride-2 max-pools,
    giving roughly 1/8 of the input. The third max-pool of the reference
    network is replaced by dilation (2 in feat_3, 4 in feat_4), so the output
    stride matches PSPNet's three 2x upsampling stages.

    ``forward`` returns ``(f4, f3)`` with 512 and 256 channels respectively.
    """
    def __init__(self, pretrained=False):
        super(SqueezeNet, self).__init__()

        self.feat_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
        )
        self.feat_2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            Fire(64, 16, 64, 64),
            Fire(128, 16, 64, 64),
        )
        self.feat_3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
            Fire(128, 32, 128, 128, diliation=2),
            Fire(256, 32, 128, 128, diliation=2),
        )
        self.feat_4 = nn.Sequential(
            Fire(256, 48, 192, 192, diliation=4),
            Fire(384, 48, 192, 192, diliation=4),
            Fire(384, 64, 256, 256, diliation=4),
            Fire(512, 64, 256, 256, diliation=4),
        )

        if pretrained:
            # Only the convolutional trunk is copied; the classifier is not used.
            weights = squeezenet1_1(pretrained=True).features.state_dict()
            load_weight_sequential(self, weights)

    def forward(self, x):
        f1 = self.feat_1(x)
        f2 = self.feat_2(f1)
        f3 = self.feat_3(f2)
        f4 = self.feat_4(f3)

        return f4, f3

In [None]:
def squeezenet(pretrained=True):
    """Return a SqueezeNet extractor, optionally with pretrained weights."""
    return SqueezeNet(pretrained=pretrained)

def densenet(pretrained=True):
    """Return a DenseNet extractor with default settings, optionally pretrained."""
    return DenseNet(pretrained=pretrained)

def _resnet(block, layers, arch, pretrained):
    """Build a dilated ResNet backbone and optionally load torchvision weights.

    The torchvision checkpoint also holds the ImageNet classifier (``fc.*``).
    This backbone has no ``fc`` layer, and a strict ``load_state_dict`` of the
    full checkpoint would raise on those unexpected keys, so they are dropped.
    All remaining keys are still loaded strictly.
    """
    model = ResNet(block, layers)
    if pretrained:
        state = model_zoo.load_url(model_urls[arch])
        backbone_state = {k: v for k, v in state.items() if not k.startswith('fc.')}
        model.load_state_dict(backbone_state)
    return model


def resnet18(pretrained=True):
    """Dilated ResNet-18 backbone (BasicBlock, [2, 2, 2, 2])."""
    return _resnet(BasicBlock, [2, 2, 2, 2], 'resnet18', pretrained)


def resnet34(pretrained=True):
    """Dilated ResNet-34 backbone (BasicBlock, [3, 4, 6, 3])."""
    return _resnet(BasicBlock, [3, 4, 6, 3], 'resnet34', pretrained)


def resnet50(pretrained=True):
    """Dilated ResNet-50 backbone (BottleNeck, [3, 4, 6, 3])."""
    return _resnet(BottleNeck, [3, 4, 6, 3], 'resnet50', pretrained)


def resnet101(pretrained=True):
    """Dilated ResNet-101 backbone (BottleNeck, [3, 4, 23, 3])."""
    return _resnet(BottleNeck, [3, 4, 23, 3], 'resnet101', pretrained)


def resnet152(pretrained=True):
    """Dilated ResNet-152 backbone (BottleNeck, [3, 8, 36, 3])."""
    return _resnet(BottleNeck, [3, 8, 36, 3], 'resnet152', pretrained)

In [14]:
class PSPModule(nn.Module):
    """Pyramid pooling module.

    For each grid size in ``sizes``, the input is average-pooled to
    ``size x size``, projected with a bias-free 1x1 conv, and bilinearly
    resized back. The results are concatenated with the input and fused by a
    1x1 conv followed by ReLU.

    Args:
        features: input channels.
        out_features: output channels.
        sizes: pooling grid sizes.
    """
    def __init__(self, features, out_features=1024, sizes=(1, 2, 3, 6)):
        super().__init__()
        self.stages = nn.ModuleList([self._make_stage(features, size) for size in sizes])
        self.bottleneck = nn.Conv2d(features * (len(sizes) + 1), out_features, kernel_size=1)
        self.relu = nn.ReLU()

    def _make_stage(self, features, size):
        prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
        conv = nn.Conv2d(features, features, kernel_size=1, bias=False)
        return nn.Sequential(prior, conv)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        # F.upsample is deprecated; interpolate with align_corners=False is the
        # same bilinear resampling it defaulted to.
        priors = [F.interpolate(stage(feats), size=(h, w), mode='bilinear', align_corners=False)
                  for stage in self.stages]
        priors.append(feats)
        bottle = self.bottleneck(torch.cat(priors, 1))
        return self.relu(bottle)

In [15]:
class PSPUpsample(nn.Module):
    """Bilinear 2x spatial upsampling followed by 3x3 conv + BatchNorm + ReLU."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        h, w = 2 * x.size(2), 2 * x.size(3)
        # F.upsample is deprecated; interpolate with align_corners=False is the
        # same bilinear resampling it defaulted to.
        p = F.interpolate(x, size=(h, w), mode='bilinear', align_corners=False)
        return self.conv(p)

In [16]:
class PSPNet(nn.Module):
    """Pyramid Scene Parsing network: backbone, PSP module, and an 8x decoder.

    ``forward`` returns ``(seg, cls)``:

    - ``seg``: per-pixel log-probabilities with ``n_classes`` channels, at 8x
      the backbone's main-output resolution (three 2x upsampling stages).
    - ``cls``: ``(N, n_classes)`` logits from a global max-pool of the
      backbone's auxiliary features.

    Args:
        n_classes: number of output classes.
        size: pyramid pooling grid sizes.
        psp_size: channels of the backbone's main output.
        deep_feature_size: channels of the backbone's auxiliary output.
        backend: name of a module-level extractor factory (e.g. 'resnet34',
            'densenet', 'squeezenet'), or any callable taking ``pretrained``.
        pretrained: forwarded to the backend factory.

    Raises:
        ValueError: if ``backend`` names no known factory.
    """
    def __init__(self, n_classes=18, size=(1, 2, 3, 6), psp_size=2048, deep_feature_size=1024,
                 backend='resnet34', pretrained=True):
        super().__init__()
        # The extractor factories are module-level functions, not attributes
        # of this class, so they are looked up in the module namespace.
        if callable(backend):
            factory = backend
        else:
            factory = globals().get(backend)
            if not callable(factory):
                raise ValueError('Unknown backend: {!r}'.format(backend))
        self.feats = factory(pretrained)
        self.psp = PSPModule(psp_size, 1024, size)
        self.drop_1 = nn.Dropout2d(p=0.3)

        self.up_1 = PSPUpsample(1024, 256)
        self.up_2 = PSPUpsample(256, 64)
        self.up_3 = PSPUpsample(64, 64)

        self.drop_2 = nn.Dropout2d(p=0.15)
        self.final = nn.Sequential(
            nn.Conv2d(64, n_classes, 1),
            nn.LogSoftmax(dim=1),  # normalize over the class channel
        )

        self.classifier = nn.Sequential(
            nn.Linear(deep_feature_size, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes),
        )

    def forward(self, x):
        f, class_f = self.feats(x)
        p = self.drop_1(self.psp(f))

        for up in (self.up_1, self.up_2, self.up_3):
            p = self.drop_2(up(p))

        # Global max-pool to (N, C, 1, 1), then flatten to (N, C) for the Linear head.
        aux = F.adaptive_max_pool2d(input=class_f, output_size=(1, 1)).view(-1, class_f.size(1))

        return self.final(p), self.classifier(aux)

In [17]:
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.autograd import Variable
from torch.utils.data import DataLoader

from tqdm import tqdm
import numpy as np

In [19]:
# Zero-argument factories for each supported backbone. psp_size is the
# backbone's main-output channel count (fed to PSPModule); deep_feature_size is
# its auxiliary-output channel count (fed to the classifier head). PSPNet's
# pyramid-size keyword is `size` (passing `sizes=` raised TypeError).
models = {
    'squeezenet' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=512, deep_feature_size=256, backend='squeezenet'),
    'densenet' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=1024, deep_feature_size=512, backend='densenet'),
    'resnet18' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=512, deep_feature_size=256, backend='resnet18'),
    'resnet34' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=512, deep_feature_size=256, backend='resnet34'),
    'resnet50' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=2048, deep_feature_size=1024, backend='resnet50'),
    'resnet101' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=2048, deep_feature_size=1024, backend='resnet101'),
    'resnet152' : lambda: PSPNet(size=(1, 2, 3, 6), psp_size=2048, deep_feature_size=1024, backend='resnet152'),
}

In [23]:
# Training configuration read by train(). 'milestones' is a comma-separated
# list of epochs at which MultiStepLR decays the learning rate.
params = dict(
    data_path='./data',
    model_path='./models',
    backend='resnet50',
    snapshot='./models/pspnet50_ADE20K.pth',
    crop_x=256,
    crop_y=256,
    batch_size=16,
    alpha=1.0,
    epochs=30,
    gpu='mps',
    start_lr=0.001,
    milestones='10,20',
)

In [24]:
def train(params, train_loader=None, class_weights=None):
    """Train a network and save its state_dict after every epoch.

    Args:
        params: configuration dict. Reads 'gpu', 'snapshot', 'backend',
            'model_path', 'start_lr', 'milestones' (comma-separated epoch
            indices), 'epochs' and 'alpha'.
        train_loader: iterable of ``(inputs, labels)`` batches. Required.
        class_weights: optional per-class weights forwarded to
            ``cross_entropy2d``. Moved to the training device if it is a tensor.

    Raises:
        ValueError: if ``train_loader`` is not supplied. This is checked before
            any network is built.
    """
    if train_loader is None:
        raise ValueError('train() needs a train_loader yielding (inputs, labels) batches')

    # 'cpu' / 'mps' select a non-CUDA device directly; any other value is
    # treated as CUDA device id(s) and exported via CUDA_VISIBLE_DEVICES.
    device_name = str(params['gpu'])
    if device_name in ('cpu', 'mps'):
        device = torch.device(device_name)
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = device_name
        device = torch.device('cuda')

    net, starting_epoch = build_network(params['snapshot'], params['backend'])
    net = net.to(device)
    model_path = params['model_path']
    os.makedirs(model_path, exist_ok=True)

    if isinstance(class_weights, torch.Tensor):
        class_weights = class_weights.to(device)

    optimizer = optim.Adam(net.parameters(), lr=params['start_lr'])
    milestones = [int(e) for e in params['milestones'].split(',')]
    scheduler = MultiStepLR(optimizer, milestones=milestones)

    for epoch in range(starting_epoch, params['epochs']):
        net.train()
        for inputs, labels in tqdm(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs, aux = net(inputs)
            seg_loss = cross_entropy2d(input=outputs, target=labels, weight=class_weights)
            # NOTE(review): `aux` is presumably PSPNet's classifier output, shape
            # (N, n_classes), while `labels` look like per-pixel targets. Confirm
            # cross_entropy2d supports this pairing; a per-image class target
            # may be intended here.
            aux_loss = cross_entropy2d(input=aux, target=labels, weight=class_weights)
            loss = seg_loss + params['alpha'] * aux_loss
            loss.backward()
            optimizer.step()

        scheduler.step()
        torch.save(net.state_dict(), os.path.join(model_path, 'pspnet{}_{}.pth'.format(params['backend'], epoch)))