In [1]:
%matplotlib inline
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
import d2l
from tqdm import tqdm

import os

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# RGB colour used in the VOC label images for each class; the position of a
# colour in this list is the class index (same order as VOC_CLASSES).
VOC_COLORMAP = [
    [0, 0, 0],        # 0  background
    [128, 0, 0],      # 1  aeroplane
    [0, 128, 0],      # 2  bicycle
    [128, 128, 0],    # 3  bird
    [0, 0, 128],      # 4  boat
    [128, 0, 128],    # 5  bottle
    [0, 128, 128],    # 6  bus
    [128, 128, 128],  # 7  car
    [64, 0, 0],       # 8  cat
    [192, 0, 0],      # 9  chair
    [64, 128, 0],     # 10 cow
    [192, 128, 0],    # 11 diningtable
    [64, 0, 128],     # 12 dog
    [192, 0, 128],    # 13 horse
    [64, 128, 128],   # 14 motorbike
    [192, 128, 128],  # 15 person
    [0, 64, 0],       # 16 potted plant
    [128, 64, 0],     # 17 sheep
    [0, 192, 0],      # 18 sofa
    [128, 192, 0],    # 19 train
    [0, 64, 128],     # 20 tv/monitor
]

# Human-readable class names, indexed by class id.
VOC_CLASSES = [
    'background',
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor',
]

def read_voc_images(voc_dir, is_train=True):
    """Read every VOC image and its segmentation label listed in a split file.

    Args:
        voc_dir: Dataset root; must contain ``ImageSets/Segmentation``,
            ``JPEGImages`` and ``SegmentationClass``.
        is_train: Read the names from ``train.txt`` when True, from
            ``val.txt`` otherwise.

    Returns:
        A ``(features, labels)`` pair of equally long lists holding whatever
        ``torchvision.io.read_image`` returns for each ``.jpg`` image and its
        ``.png`` label, in split-file order.
    """
    txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',
                             'train.txt' if is_train else 'val.txt')
    mode = torchvision.io.image.ImageReadMode.RGB
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    features, labels = [], []
    for fname in images:
        # Decode the photo as RGB too: a grayscale JPEG would otherwise come
        # back with a single channel and break 3-channel normalisation later.
        features.append(torchvision.io.read_image(os.path.join(
            voc_dir, 'JPEGImages', f'{fname}.jpg'), mode))
        labels.append(torchvision.io.read_image(os.path.join(
            voc_dir, 'SegmentationClass', f'{fname}.png'), mode))
    return features, labels

def voc_colormap2label():
    """Build the lookup table from a packed RGB value to a VOC class index.

    The table has one ``long`` entry per possible 24-bit colour; the entry at
    ``(R * 256 + G) * 256 + B`` holds the class index of that colour in
    ``VOC_COLORMAP``. Colours not listed there keep the initial value 0.
    """
    lookup = torch.zeros(256 ** 3, dtype=torch.long)
    for class_idx, (red, green, blue) in enumerate(VOC_COLORMAP):
        packed_rgb = (red * 256 + green) * 256 + blue
        lookup[packed_rgb] = class_idx
    return lookup

def voc_label_indices(colormap, colormap2label):
    """将VOC标签中的RGB值映射到它们的类别索引"""
    colormap = colormap.permute(1, 2, 0).numpy().astype('int32')
    idx = ((colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256
           + colormap[:, :, 2])
    return colormap2label[idx]

def voc_rand_crop(feature, label, height, width):
    """Cut the same random ``height`` x ``width`` window out of image and label.

    The window is drawn once from ``feature`` and then applied to both inputs
    so that pixels and annotations stay aligned.
    """
    top, left, crop_h, crop_w = torchvision.transforms.RandomCrop.get_params(
        feature, (height, width))
    crop = torchvision.transforms.functional.crop
    cropped_feature = crop(feature, top, left, crop_h, crop_w)
    cropped_label = crop(label, top, left, crop_h, crop_w)
    return cropped_feature, cropped_label

class VOCSegDataset(torch.utils.data.Dataset):
    """Custom dataset that loads the VOC segmentation images into memory.

    Images smaller than the crop size are dropped at construction time; every
    access returns a fresh random crop of the normalised image together with
    the per-pixel class indices of the matching label crop.
    """

    def __init__(self, is_train, crop_size, voc_dir):
        """Read the chosen split from ``voc_dir`` and pre-normalise the images.

        Args:
            is_train: Forwarded to ``read_voc_images`` to select the split.
            crop_size: ``(height, width)`` of the crops returned by indexing.
            voc_dir: Root directory of the VOC dataset.
        """
        self.transform = torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        raw_features, raw_labels = read_voc_images(voc_dir, is_train=is_train)
        # Filter first, then normalise only the images that are kept.
        self.features = list(map(self.normalize_image,
                                 self.filter(raw_features)))
        self.labels = self.filter(raw_labels)
        self.colormap2label = voc_colormap2label()
        print('read ' + str(len(self.features)) + ' examples')

    def normalize_image(self, img):
        """Scale ``img`` to floats in [0, 1] units and apply ``self.transform``."""
        scaled = img.float() / 255
        return self.transform(scaled)

    def filter(self, imgs):
        """Return the images whose height and width both reach the crop size."""
        kept = []
        for img in imgs:
            tall_enough = img.shape[1] >= self.crop_size[0]
            if tall_enough and img.shape[2] >= self.crop_size[1]:
                kept.append(img)
        return kept

    def __getitem__(self, idx):
        """Return ``(cropped image, class-index map)`` for example ``idx``."""
        image, mask = self.features[idx], self.labels[idx]
        image, mask = voc_rand_crop(image, mask, *self.crop_size)
        return (image, voc_label_indices(mask, self.colormap2label))

    def __len__(self):
        """Number of examples that survived the size filter."""
        return len(self.features)

In [4]:
# Backbone: pretrained ResNet-18 with its last two child modules removed
# (in torchvision's ResNet these are the pooling and classification head),
# leaving a fully convolutional feature extractor.
pretrained_net = torchvision.models.resnet18(pretrained=True)
net = nn.Sequential(*list(pretrained_net.children())[:-2])

num_classes = 21

# 1x1 convolution turns the 512 backbone channels into per-class scores.
net.add_module(
    'final_conv',
    nn.Conv2d(in_channels=512, out_channels=num_classes, kernel_size=1))
# Transposed convolution upsamples the score map by the stride (32x):
# out = (in - 1) * 32 - 2 * 16 + 64 = 32 * in.
net.add_module(
    'transpose_conv',
    nn.ConvTranspose2d(in_channels=num_classes, out_channels=num_classes,
                       kernel_size=64, stride=32, padding=16))



In [5]:
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Build transposed-convolution weights holding a bilinear upsampling filter.

    Returns a tensor of shape ``(in_channels, out_channels, kernel_size,
    kernel_size)`` that is zero everywhere except at the channel pairs
    ``(i, i)``, which all receive the same separable triangular filter.
    The paired indexing requires ``in_channels == out_channels``.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels are centred on a pixel, even kernels between two pixels.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    positions = torch.arange(kernel_size)
    rows = positions.reshape(-1, 1)
    cols = positions.reshape(1, -1)
    row_profile = 1 - torch.abs(rows - center) / factor
    col_profile = 1 - torch.abs(cols - center) / factor
    filt = row_profile * col_profile
    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    weight[range(in_channels), range(out_channels), :, :] = filt
    return weight

# Stand-alone 2x upsampling layer initialised with the bilinear filter.
# It is not part of `net`.
conv_trans = nn.ConvTranspose2d(
    in_channels=3, out_channels=3, kernel_size=4, stride=2, padding=1,
    bias=False)
conv_trans.weight.data.copy_(bilinear_kernel(3, 3, 4));

# Start the network's upsampling layer from bilinear interpolation weights.
W = bilinear_kernel(num_classes, num_classes, 64)
net.transpose_conv.weight.data.copy_(W);

# Move all parameters to the selected device before the optimiser is built.
net = net.to(device)


In [6]:
# Image-side preprocessing for torchvision's built-in VOCSegmentation dataset.
# Only the disabled experiment below refers to it; VOCSegDataset does its own
# normalisation and cropping.
trans = transforms.Compose([transforms.Resize((320, 480)),
                            transforms.ToTensor()])

# Disabled experiment using torchvision's dataset class. Note that only the
# image goes through `trans` here; the target is just converted to a tensor,
# so targets keep their original (varying) sizes.
# train_set = datasets.VOCSegmentation(root='./data', year='2012', image_set='train',
#                                         download=True, transform=trans,
#                                         target_transform=transforms.ToTensor())
# test_set = datasets.VOCSegmentation(root='./data', year='2012', image_set='val',
#                                         download=True, transform=trans,
#                                         target_transform=transforms.ToTensor())

# Custom dataset: every example is a random 320x480 crop.
train_set = VOCSegDataset(is_train=True, crop_size=(320, 480),
                          voc_dir='./data/VOCdevkit/VOC2012')
test_set = VOCSegDataset(is_train=False, crop_size=(320, 480),
                         voc_dir='./data/VOCdevkit/VOC2012')

read 1114 examples


: 

In [None]:
# Hold out 20% of the training examples for validation.
# This cell rebinds `train_set` to the training subset. If the cell is run
# again, recover the full dataset first so the split is not applied to an
# already-split subset (which would silently shrink the training data).
if isinstance(train_set, torch.utils.data.Subset):
    train_set = train_set.dataset
train_size = int(0.8 * len(train_set))
valid_size = len(train_set) - train_size
# A fixed seed makes the train/validation partition identical on every run.
train_set, valid_set = torch.utils.data.random_split(
    train_set, [train_size, valid_size],
    generator=torch.Generator().manual_seed(0))

train_iter = DataLoader(train_set, batch_size=64, shuffle=True)
# Evaluation metrics do not depend on sample order, so no shuffling is needed.
valid_iter = DataLoader(valid_set, batch_size=64, shuffle=False)
test_iter = DataLoader(test_set, batch_size=64, shuffle=False)

# Number of batches per loader.
len(train_iter), len(valid_iter), len(test_iter)

(19, 5, 23)

In [None]:
# Earlier variant, kept for reference: per-example loss averaged over pixels.
# def loss(inputs, targets):
#     return F.cross_entropy(inputs, targets, reduction='none').mean(1).mean(1)

# Per-pixel cross entropy; the training loop reduces it with .mean().
loss = nn.CrossEntropyLoss(reduction='none')

# Learning rate and weight decay for plain SGD.
lr = 0.001
wd = 1e-3
trainer = torch.optim.SGD(params=net.parameters(), lr=lr, weight_decay=wd)

In [None]:
# Number of full passes over the training loader.
num_epochs = 5

def train(net, train_iter, valid_iter, num_epochs, trainer, loss, device):
    """Train ``net``, evaluate it after every epoch and plot the curves.

    Args:
        net: Model to optimise; called as ``net(X)``.
        train_iter: Iterable of ``(X, y)`` training batches.
        valid_iter: Iterable of ``(X, y)`` validation batches.
        num_epochs: Number of passes over ``train_iter``.
        trainer: Optimiser providing ``zero_grad()`` and ``step()``.
        loss: Callable returning an unreduced loss tensor for ``(y_hat, y)``;
            it is reduced here with ``.mean()``.
        device: Device every batch is moved to before the forward pass.

    Per epoch this records the mean of the per-batch losses and the fraction
    of label elements predicted correctly (argmax over dimension 1), for both
    the training and the validation data, prints them, and finally draws a
    loss plot and an accuracy plot side by side. Returns ``None``.
    """
    train_loss, valid_loss = [], []
    train_acc, valid_acc = [], []
    for epoch in range(num_epochs):
        # ---- training pass ----
        net.train()
        running_loss, num_batches = 0.0, 0
        acc_sum, n = 0.0, 0
        for X, y in tqdm(train_iter):
            trainer.zero_grad()
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y).mean()
            l.backward()
            trainer.step()
            running_loss += l.item()
            num_batches += 1
            acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.numel()
        # max(..., 1) keeps an empty iterable from raising ZeroDivisionError.
        train_loss.append(running_loss / max(num_batches, 1))
        train_acc.append(acc_sum / max(n, 1))

        # ---- validation pass ----
        net.eval()
        running_loss, num_batches = 0.0, 0
        acc_sum, n = 0.0, 0
        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for X, y in valid_iter:
                X, y = X.to(device), y.to(device)
                y_hat = net(X)
                running_loss += loss(y_hat, y).mean().item()
                num_batches += 1
                acc_sum += (y_hat.argmax(dim=1) == y).sum().item()
                n += y.numel()
        # Average over all validation batches (same reduction as train_loss),
        # not just the loss of whichever batch happened to come last.
        valid_loss.append(running_loss / max(num_batches, 1))
        valid_acc.append(acc_sum / max(n, 1))

        print(f'epoch: {epoch}, train_loss: {train_loss[-1]:.4f}, valid_loss: {valid_loss[-1]:.4f}, train_acc: {train_acc[-1]:.4f}, valid_acc: {valid_acc[-1]:.4f}')

    # Left panel: loss curves; right panel: accuracy curves.
    plt.subplot(1, 2, 1)
    plt.plot(train_loss, label='train')
    plt.plot(valid_loss, label='valid')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(train_acc, label='train')
    plt.plot(valid_acc, label='valid')
    plt.xlabel('epoch')
    plt.ylabel('accuracy')
    plt.legend()
    plt.show()
    
# Fine-tune the network and plot the training / validation curves.
train(net=net, train_iter=train_iter, valid_iter=valid_iter,
      num_epochs=num_epochs, trainer=trainer, loss=loss, device=device)

  0%|          | 0/19 [00:00<?, ?it/s]

  0%|          | 0/19 [00:12<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [1, 375, 500] at entry 0 and [1, 500, 333] at entry 3