In [28]:
import torchvision.transforms as transforms

# Preprocessing pipeline: force a fixed 256x256 resolution, then convert to a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

In [29]:
from torchvision import datasets

# Pascal VOC 2012 segmentation splits.
# `download` expects a bool; the previous string "true" only worked because any
# non-empty string is truthy (so "false" would have triggered a download too).
# NOTE(review): target_transform reuses data_transform, whose ToTensor step rescales
# 8-bit inputs to [0, 1]. If the masks store integer class ids they will come out as
# id / 255 -- confirm that is what the loss/metrics expect.
train_data = datasets.VOCSegmentation(root="./data",
                                      year="2012",
                                      image_set="train",
                                      download=True,
                                      transform=data_transform,
                                      target_transform=data_transform)

# The "val" split is used as the held-out test set in this notebook.
test_data = datasets.VOCSegmentation(root="./data",
                                     year="2012",
                                     image_set="val",
                                     download=True,
                                     transform=data_transform,
                                     target_transform=data_transform)

# Same training split with no transforms applied and no download requested
# (relies on the files fetched above); presumably kept for visualisation.
viz_data = datasets.VOCSegmentation(root="./data", year="2012", image_set="train")

Using downloaded and verified file: ./data\VOCtrainval_11-May-2012.tar
Extracting ./data\VOCtrainval_11-May-2012.tar to ./data
Using downloaded and verified file: ./data\VOCtrainval_11-May-2012.tar
Extracting ./data\VOCtrainval_11-May-2012.tar to ./data


In [30]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

# Training batches are reshuffled every epoch so the model does not see the samples
# in the same fixed order each time; evaluation order stays deterministic.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# len(dataloader) is the number of batches; the final batch may hold fewer than
# BATCH_SIZE images, so the product below is an upper bound on the image count.
print(f"train_dataloader: {len(train_dataloader)} * {BATCH_SIZE} images")
print(f"test_dataloader: {len(test_dataloader)} * {BATCH_SIZE} images")

train_dataloader: 92 * 16 images
test_dataloader: 91 * 16 images


In [110]:
## Structure of Resnet18
from torchvision import models
from torchinfo import summary

# Load ImageNet-pretrained ResNet-18 through the `weights=` API (the same API the
# ResNet-101 cells in this notebook use). IMAGENET1K_V1 is the checkpoint that the
# deprecated `pretrained=True` flag selected.
resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
batch_size = 16

# Per-layer table of output shapes and parameter counts for a 256x256 RGB batch.
summary(
    resnet,
    input_size=(batch_size, 3, 256, 256),
    col_names=["output_size", "num_params"],
)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [16, 1000]                --
├─Conv2d: 1-1                            [16, 64, 128, 128]        9,408
├─BatchNorm2d: 1-2                       [16, 64, 128, 128]        128
├─ReLU: 1-3                              [16, 64, 128, 128]        --
├─MaxPool2d: 1-4                         [16, 64, 64, 64]          --
├─Sequential: 1-5                        [16, 64, 64, 64]          --
│    └─BasicBlock: 2-1                   [16, 64, 64, 64]          --
│    │    └─Conv2d: 3-1                  [16, 64, 64, 64]          36,864
│    │    └─BatchNorm2d: 3-2             [16, 64, 64, 64]          128
│    │    └─ReLU: 3-3                    [16, 64, 64, 64]          --
│    │    └─Conv2d: 3-4                  [16, 64, 64, 64]          36,864
│    │    └─BatchNorm2d: 3-5             [16, 64, 64, 64]          128
│    │    └─ReLU: 3-6                    [16, 64, 64, 64]          --
│

In [5]:
import torch
from torch import nn
import torch.nn.functional as F

class conv2DBatchNorm(nn.Module):
    """Conv2d followed by BatchNorm2d, with an optional trailing ReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        bias: Whether the convolution has a bias term.
        activation: When truthy, an in-place ReLU is applied after the batch norm.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, bias, activation=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.batchnorm = nn.BatchNorm2d(num_features=out_channels)
        self.activation = activation
        # The ReLU submodule is only registered when it is going to be used.
        if activation:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm (-> ReLU when enabled) and return the result."""
        normalized = self.batchnorm(self.conv(x))
        if not self.activation:
            return normalized
        return self.relu(normalized)

In [158]:
import torchvision
from torchvision import models

class FeatureExtractor(nn.Module):
    """Truncated pretrained ResNet-101 used as a convolutional encoder.

    The full torchvision ResNet-101 is stored on ``self.resnet``; only its first
    seven child modules are executed in ``forward``.
    """

    def __init__(self):
        super().__init__()
        self.resnet = models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT)

    def forward(self, x):
        """Run ``x`` through the first seven children of the ResNet, in order.

        The list of children is materialised once per call instead of being
        rebuilt on every loop iteration.
        """
        # In torchvision's ResNet the first seven children should be
        # conv1, bn1, relu, maxpool, layer1, layer2 and layer3 -- TODO confirm.
        stages = list(self.resnet.children())[:7]
        for stage in stages:
            x = stage(x)
        encoder_outputs = x
        return encoder_outputs

In [172]:
import torchvision
from torchvision import models

class CustomResNet101(nn.Module):
    """Pretrained ResNet-101 backbone exposing both the layer4 and layer3 feature maps."""

    def __init__(self):
        super().__init__()
        pretrained_weights = torchvision.models.ResNet101_Weights.DEFAULT
        self.resnet = models.resnet101(weights=pretrained_weights)

    def forward(self, x):
        """Return ``(layer4_output, layer3_output)`` for the input batch ``x``.

        The ResNet's average-pool and fully connected head are not applied.
        """
        backbone = self.resnet

        # Stem and the first two residual stages are applied back to back.
        early_stages = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
        )
        for stage in early_stages:
            x = stage(x)

        # layer3's output is kept so it can be returned alongside the final features.
        auxiliary_x = backbone.layer3(x)
        final_features = backbone.layer4(auxiliary_x)

        return final_features, auxiliary_x

In [7]:
class PyramidPooling(nn.Module):
    """Pyramid pooling module with four pooled branches concatenated onto the input.

    Each branch adaptively average-pools the input to one of the first four
    ``pool_sizes``, applies a 1x1 conv + batch norm + ReLU, and bilinearly resizes
    the result to ``(height, width)``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        pool_sizes: Output sizes for the adaptive pooling layers; the first four
            entries are used, e.g. [6, 3, 2, 1].
        height: Target height every branch is resized to.
        width: Target width every branch is resized to.
    """

    def __init__(self, in_channels, pool_sizes, height, width):
        super().__init__()

        self.height = height
        self.width = width

        # Every branch emits an equal share of the input channel count.
        branch_channels = int(in_channels / len(pool_sizes))

        def make_cbr():
            # 1x1 conv -> batch norm -> ReLU, used to shrink the channel dimension.
            return conv2DBatchNorm(
                in_channels, branch_channels, kernel_size=1, stride=1,
                padding=0, dilation=1, bias=False, activation=True)

        self.avpool_1 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[0])
        self.cbr_1 = make_cbr()

        self.avpool_2 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[1])
        self.cbr_2 = make_cbr()

        self.avpool_3 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[2])
        self.cbr_3 = make_cbr()

        self.avpool_4 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[3])
        self.cbr_4 = make_cbr()

    def forward(self, x):
        """Return ``x`` concatenated along dim 1 with the four resized branch outputs."""
        target_size = (self.height, self.width)
        branches = (
            (self.avpool_1, self.cbr_1),
            (self.avpool_2, self.cbr_2),
            (self.avpool_3, self.cbr_3),
            (self.avpool_4, self.cbr_4),
        )

        pieces = [x]
        for pool, cbr in branches:
            pooled = cbr(pool(x))
            pieces.append(F.interpolate(
                pooled, size=target_size, mode='bilinear', align_corners=True))

        return torch.cat(pieces, dim=1)

In [8]:
class DecodePSPFeature(nn.Module):
    """Decoder head: 3x3 conv block, dropout, 1x1 classifier, bilinear upsampling.

    Args:
        height: Height of the returned score map.
        width: Width of the returned score map.
        n_classes: Number of output channels of the classifier.
        in_channels: Number of channels of the incoming feature map. Defaults to
            256, the value that used to be hard-coded, so existing callers are
            unaffected; pass the real channel count of the features being decoded.
    """

    def __init__(self, height, width, n_classes, in_channels=256):
        super().__init__()

        self.height = height
        self.width = width

        self.cbr = conv2DBatchNorm(
            in_channels=in_channels, out_channels=64, kernel_size=3, stride=1,
            padding=1, dilation=1,  bias=False, activation=True)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(
            in_channels=64, out_channels=n_classes, kernel_size=1,
            stride=1, padding=0)

    def forward(self, x):
        """Return per-class scores for ``x`` resized to ``(height, width)``."""
        x = self.cbr(x)
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(
            x, size=(self.height, self.width),
            mode='bilinear', align_corners=True)

        return output

In [9]:
class AuxiliaryPSPlayers(nn.Module):
    """Auxiliary head: 3x3 conv block, dropout, 1x1 classifier, bilinear upsampling.

    Args:
        in_channels: Number of channels of the incoming feature map.
        height: Height of the returned score map.
        width: Width of the returned score map.
        n_classes: Number of output channels of the classifier.
    """

    def __init__(self, in_channels, height, width, n_classes):
        super().__init__()

        self.height, self.width = height, width

        hidden_channels = 64
        self.cbr = conv2DBatchNorm(
            in_channels=in_channels, out_channels=hidden_channels,
            kernel_size=3, stride=1, padding=1,
            dilation=1, bias=False, activation=True)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(
            in_channels=hidden_channels, out_channels=n_classes,
            kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return per-class scores for ``x`` resized to ``(height, width)``."""
        scores = self.classification(self.dropout(self.cbr(x)))
        return F.interpolate(
            scores, size=(self.height, self.width),
            mode='bilinear', align_corners=True)

In [170]:
class PSPNet(nn.Module):
    """PSPNet-style model: ResNet-101 encoder followed by pyramid pooling.

    The decoder and auxiliary heads are constructed but not yet applied in
    ``forward``, which currently returns the pyramid-pooled feature map.

    Args:
        n_classes: Number of output channels of the decoder and auxiliary heads.
    """

    def __init__(self, n_classes):
        super().__init__()

        # Input images are 256x256; the encoder's final feature map is 8x8 for
        # that input size (see the torchinfo summary of this model).
        full_img_size = 256
        feature_map_size = 8

        self.feature_extractor = CustomResNet101()
        self.pyramid_pooling = PyramidPooling(
            in_channels=2048, pool_sizes=[6, 3, 2, 1],
            height=feature_map_size, width=feature_map_size)
        # NOTE(review): DecodePSPFeature is built for 256 input channels, but the
        # pyramid output has 2048 + 4 * 512 = 4096 channels, so it cannot be
        # applied to `pyramid_outputs` as configured.
        self.decode_feature = DecodePSPFeature(
            height=full_img_size, width=full_img_size,
            n_classes=n_classes)
        # NOTE(review): in_channels=64 does not match the 1024-channel layer3
        # features returned by the encoder -- confirm before wiring this head in.
        self.aux = AuxiliaryPSPlayers(
            in_channels=64, n_classes=n_classes,
            height=full_img_size, width=full_img_size)

    def forward(self, x):
        """Return the pyramid-pooled encoder features for the input batch ``x``."""
        # CustomResNet101 returns (layer4 features, layer3 features). Only the
        # layer4 features feed the pyramid pooling; passing the tuple straight
        # through would fail inside the first pooling layer.
        encoder_outputs, _auxiliary_features = self.feature_extractor(x)
        pyramid_outputs = self.pyramid_pooling(encoder_outputs)
        # TODO: apply self.decode_feature (and self.aux on the layer3 features)
        # once their input channel counts match the tensors they receive.
        return pyramid_outputs

In [171]:
# Smoke test: build the network and push one batch from the training loader through it.
model = PSPNet(n_classes=4)
# [0] takes the first element of the batch (presumably the image tensor -- confirm).
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = next(iter(train_dataloader))[0]
pred = model(input)
# Displays the shape of the batch that was fed in (not the shape of `pred`).
input.shape

NameError: name 'CustomResNet101' is not defined

In [169]:
# Rebuild the model and print its per-layer output shapes and parameter counts
# for a batch of 256x256 RGB images.
model = PSPNet(n_classes=4)
batch_size = 16

summary(model, input_size=(batch_size, 3, 256, 256), col_names=["output_size", "num_params"])

Layer (type:depth-idx)                             Output Shape              Param #
PSPNet                                             [16, 4096, 8, 8]          185,096
├─FeatureExtractor2: 1-1                           [16, 2048, 8, 8]          --
│    └─ResNet: 2-1                                 --                        2,049,000
│    │    └─Conv2d: 3-1                            [16, 64, 128, 128]        9,408
│    │    └─BatchNorm2d: 3-2                       [16, 64, 128, 128]        128
│    │    └─ReLU: 3-3                              [16, 64, 128, 128]        --
│    │    └─MaxPool2d: 3-4                         [16, 64, 64, 64]          --
│    │    └─Sequential: 3-5                        [16, 256, 64, 64]         215,808
│    │    └─Sequential: 3-6                        [16, 512, 32, 32]         1,219,584
│    │    └─Sequential: 3-7                        [16, 1024, 16, 16]        26,090,496
│    │    └─Sequential: 3-8                        [16, 2048, 8, 8]          14