In [1]:
import datetime
import os
from math import sqrt
from torch import nn

import numpy as np
import torchvision.transforms as standard_transforms
import torchvision.utils as vutils
from tensorboard import SummaryWriter
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
import sys
sys.path.append('../')
import utils.joint_transforms as joint_transforms
import utils.transforms as extended_transforms
from datasets import LIP
from models import *
from utils import check_mkdir, evaluate, AverageMeter, CrossEntropyLoss2d

In [2]:
# Experiment root and name; TensorBoard events are written to
# <ckpt_path>/exp/<exp_name>.
exp_name = 'lip-psp_net'
ckpt_path = './checkpoints/'
log_dir = os.path.join(ckpt_path, 'exp', exp_name)
writer = SummaryWriter(log_dir)

In [4]:
# Training / validation hyper-parameters for the LIP PSPNet run.
args = {
    'train_batch_size': 8,
    'lr_decay': 0.9,
    'max_iter': 3e4,
    'longer_size': 512,
    'crop_size': 473,
    'stride_rate': 2 / 3.,
    'weight_decay': 1e-4,
    'momentum': 0.9,
    'snapshot': '',
    'print_freq': 10,
    'val_save_to_img_file': True,
    'val_img_sample_rate': 0.01,  # randomly sample some validation results to display
    'val_img_display_size': 384,
}
# Reference schedule: lr = 1e-2 at batch size 16. Scale it by sqrt(batch / 16)
# using the batch size configured above. The previous hard-coded divisor
# assumed a batch of 4 while train_batch_size is 8.
args['lr'] = 1e-2 / sqrt(16. / args['train_batch_size'])

In [9]:
# Standard ImageNet per-channel mean / std, shared by every input transform
# below so train and val normalization cannot drift apart.
mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training augmentation from the project's joint_transforms module; these
# transforms take (image, mask) together (TODO confirm they keep both aligned).
train_joint_transform = joint_transforms.Compose([
    joint_transforms.RandomSized(args['crop_size']),
    # joint_transforms.Scale(args['longer_size']),
    joint_transforms.RandomRotate(10),
    joint_transforms.RandomHorizontallyFlip()
])
# Sliding-window cropper parameterized by crop size, stride rate and the
# dataset's ignore label (not used by the loaders in this notebook).
sliding_crop = joint_transforms.SlidingCrop(args['crop_size'], args['stride_rate'], LIP.ignore_label)
train_input_transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std)
])
val_input_transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std)
])
target_transform = extended_transforms.MaskToTensor()
# Display pipeline for sampled validation images. Note: `Scale` is the old
# torchvision name of `Resize` (deprecated/removed in newer releases).
visualize = standard_transforms.Compose([
    standard_transforms.Scale(args['val_img_display_size']),
    standard_transforms.ToTensor()
])

In [11]:
# LIP splits. Only the training split gets joint (image + mask) augmentation;
# validation uses the tensor/normalize transforms only.
train_set = LIP.LIP('train', joint_transform=train_joint_transform,
                    transform=train_input_transform, target_transform=target_transform)
# Shuffle so every epoch visits the training images in a new random order;
# with shuffle=False each epoch replayed the exact same batch sequence.
train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=8, shuffle=True)
val_set = LIP.LIP('val', transform=val_input_transform,
                  target_transform=target_transform)
# One image per batch: validation images are not cropped to a common size
# here, so they presumably differ in shape (TODO confirm).
val_loader = DataLoader(val_set, batch_size=1, num_workers=8, shuffle=False)

In [None]:
class _PyramidPoolingModule(nn.Module):
    def __init__(self, in_dim, reduction_dim, setting):
        super(_PyramidPoolingModule, self).__init__()
        self.features = []
        for s in setting:
            self.features.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(s),
                nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(reduction_dim, momentum=.95),
                nn.ReLU(inplace=True)
            ))
        self.features = nn.ModuleList(self.features)

    def forward(self, x):
        x_size = x.size()
        out = [x]
        for f in self.features:
            out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
        out = torch.cat(out, 1)
        return out


class PSPNet(nn.Module):
    """PSPNet head on a ResNet-101 backbone whose layer3/layer4 are dilated.

    Args:
        num_classes: output channels of the main (and auxiliary) classifier.
        pretrained: if True, load backbone weights from ``res101_path``.
        use_aux: if True, build a 1x1 auxiliary classifier on the layer3
            features; it is only evaluated in training mode.

    forward(x) returns logits bilinearly upsampled to x's spatial size. In
    training mode with ``use_aux`` it returns ``(main_logits, aux_logits)``.
    """

    def __init__(self, num_classes, pretrained=True, use_aux=True):
        super(PSPNet, self).__init__()
        self.use_aux = use_aux

        backbone = models.resnet101()
        if pretrained:
            backbone.load_state_dict(torch.load(res101_path))

        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        # Remove the stride of layer3/layer4 (3x3 conv2 and the downsample
        # projection). Dilate conv2 instead (rate 2 in layer3, 4 in layer4)
        # with matching padding.
        for stage, rate in ((self.layer3, 2), (self.layer4, 4)):
            for name, module in stage.named_modules():
                if 'conv2' in name:
                    module.dilation = (rate, rate)
                    module.padding = (rate, rate)
                    module.stride = (1, 1)
                elif 'downsample.0' in name:
                    module.stride = (1, 1)

        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        # 4096 input channels = 2048 backbone channels + 4 pyramid branches x 512.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )

        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)

        initialize_weights(self.ppm, self.final)

    def forward(self, x):
        out_size = x.size()[2:]
        want_aux = self.training and self.use_aux

        feats = self.layer3(self.layer2(self.layer1(self.layer0(x))))
        if want_aux:
            aux = self.aux_logits(feats)

        logits = self.final(self.ppm(self.layer4(feats)))
        logits = F.upsample(logits, out_size, mode='bilinear')
        if want_aux:
            return logits, F.upsample(aux, out_size, mode='bilinear')
        return logits


# Experimental variant (deformable convolutions in layer3/layer4); not recommended for use.
class PSPNetDeform(nn.Module):
    """PSPNet variant whose layer3/layer4 3x3 convs are wrapped in Conv2dDeformable.

    Unlike PSPNet, dilation in layer3/layer4 is left untouched. Their strides
    are set to 1 and conv2 padding to 1, then every block's conv2 is replaced
    by ``Conv2dDeformable(conv2)``. All outputs are upsampled to the fixed
    ``input_size`` given at construction, not to the size of the actual input.

    Args:
        num_classes: output channels of the main (and auxiliary) classifier.
        input_size: target size passed to F.upsample for every output;
            presumably an (H, W) pair (TODO confirm).
        pretrained: if True, load backbone weights from ``res101_path``.
        use_aux: if True, build a 1x1 auxiliary classifier on the layer3
            features; it is only evaluated in training mode.
    """
    def __init__(self, num_classes, input_size, pretrained=True, use_aux=True):
        super(PSPNetDeform, self).__init__()
        self.input_size = input_size
        self.use_aux = use_aux
        resnet = models.resnet101()
        if pretrained:
            resnet.load_state_dict(torch.load(res101_path))
        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4

        # Remove the stride of layer3/layer4: conv2 gets stride 1 / padding 1,
        # and the downsample projection gets stride 1.
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.padding = (1, 1)
                m.stride = (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.padding = (1, 1)
                m.stride = (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        # Wrap each (already reconfigured) conv2. This must come after the
        # loops above; presumably Conv2dDeformable reuses the wrapped conv's
        # settings/weights (TODO confirm).
        for idx in range(len(self.layer3)):
            self.layer3[idx].conv2 = Conv2dDeformable(self.layer3[idx].conv2)
        for idx in range(len(self.layer4)):
            self.layer4[idx].conv2 = Conv2dDeformable(self.layer4[idx].conv2)

        self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
        # 4096 input channels = 2048 backbone channels + 4 pyramid branches x 512.
        self.final = nn.Sequential(
            nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512, momentum=.95),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )

        if use_aux:
            self.aux_logits = nn.Conv2d(1024, num_classes, kernel_size=1)
            initialize_weights(self.aux_logits)

        initialize_weights(self.ppm, self.final)

    def forward(self, x):
        """Return logits upsampled to ``self.input_size``.

        Returns a ``(main, aux)`` pair in training mode with ``use_aux``.
        """
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        if self.training and self.use_aux:
            # Auxiliary prediction taken from the layer3 features.
            aux = self.aux_logits(x)
        x = self.layer4(x)
        x = self.ppm(x)
        x = self.final(x)
        if self.training and self.use_aux:
            return F.upsample(x, self.input_size, mode='bilinear'), F.upsample(aux, self.input_size, mode='bilinear')
        return F.upsample(x, self.input_size, mode='bilinear')

In [None]:
# Sanity check: pull the first training batch and show the image / mask shapes.
# The builtin next() works on both Python 2 and 3 iterators; the old
# `.next()` method is Python 2 only. The previous code also fetched and
# silently discarded one batch first.
dataiter = iter(train_loader)
img, gts = next(dataiter)
print(img.size(), gts.size())

In [None]:
# 20 output classes (presumably LIP's 19 part labels + background); defaults
# spelled out: backbone weights loaded from res101_path, auxiliary head enabled.
net = PSPNet(20, pretrained=True, use_aux=True)

In [12]:
# Data-loading sanity check: iterate the training loader for 3 epochs and
# confirm each batch's images and masks agree in batch size and spatial size.
# The recorded output shows images as (N, 3, H, W) and masks as (N, H, W).
for epoch in range(3):
    for inputs, gts in train_loader:
        print(inputs.size(), gts.size())
        assert inputs.size(0) == gts.size(0), 'image/mask batch size mismatch'
        assert inputs.size()[2:] == gts.size()[1:], 'image/mask spatial size mismatch'

(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
(torch.Size([8, 3, 473, 473]), torch.Size([8, 473, 473]))
