In [28]:
import torchvision.transforms as transforms

# Image preprocessing: resize every picture to a fixed 256x256 grid, then
# convert the PIL image into a tensor.
data_transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

In [29]:
from torchvision import datasets

# Segmentation masks need their own pipeline. They store one class id per
# pixel, so they must be resized with nearest-neighbour sampling (no blended
# ids) and converted with PILToTensor, which keeps the raw uint8 ids.
# ToTensor would divide them by 255 and turn the labels into fractions.
# The target size must stay in sync with the image pipeline (256x256).
mask_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),
])

# `download` expects a bool; the string "true" only worked because any
# non-empty string is truthy (so would "false").
train_data = datasets.VOCSegmentation(root="./data",
                                      year="2012",
                                      image_set="train",
                                      download=True,
                                      transform=data_transform,
                                      target_transform=mask_transform)

# The archive holds both the train and val splits and was already fetched and
# extracted just above, so skip the second (slow) extraction.
test_data = datasets.VOCSegmentation(root="./data",
                                     year="2012",
                                     image_set="val",
                                     download=False,
                                     transform=data_transform,
                                     target_transform=mask_transform)

# Untransformed copy of the training split (raw PIL images and masks).
viz_data = datasets.VOCSegmentation(root="./data", year="2012", image_set="train")

Using downloaded and verified file: ./data\VOCtrainval_11-May-2012.tar
Extracting ./data\VOCtrainval_11-May-2012.tar to ./data
Using downloaded and verified file: ./data\VOCtrainval_11-May-2012.tar
Extracting ./data\VOCtrainval_11-May-2012.tar to ./data


In [30]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

# Shuffle the training split so each epoch sees the samples in a different
# order; the evaluation split keeps a fixed order so results are comparable.
train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

# len(dataloader) is the number of batches; the last batch may hold fewer
# than BATCH_SIZE images.
print(f"train_dataloader: {len(train_dataloader)} * {BATCH_SIZE} images")
print(f"test_dataloader: {len(test_dataloader)} * {BATCH_SIZE} images")

train_dataloader: 92 * 16 images
test_dataloader: 91 * 16 images


In [207]:
import torchvision
from torch import nn
from torchvision import models

class CustomResNet101(nn.Module):
    """Pretrained ResNet-101 (torchvision default weights) used as a feature extractor.

    The forward pass runs the stem and the four residual stages only; the
    wrapped network's pooling / classification head is not called.
    """

    def __init__(self):
        super().__init__()
        pretrained_weights = models.ResNet101_Weights.DEFAULT
        self.resnet = models.resnet101(weights=pretrained_weights)

    def forward(self, x):
        """Return ``(layer4_output, layer3_output)`` computed from the batch ``x``."""
        backbone = self.resnet

        # Stem: convolution -> batch norm -> ReLU -> max pooling.
        for stem_op in (backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool):
            x = stem_op(x)

        x = backbone.layer2(backbone.layer1(x))
        auxiliary_x = backbone.layer3(x)  # res4b22
        x = backbone.layer4(auxiliary_x)  # res5c

        return x, auxiliary_x

In [206]:
import torch
import torch.nn.functional as F

class PyramidPooling(nn.Module):
    """Pyramid pooling: pool the input at several grid sizes and concatenate.

    Each branch adaptively average-pools the input to ``p x p``, applies a 1x1
    convolution down to ``in_channels // len(pool_sizes)`` channels, and is
    bilinearly resized before being concatenated with the input along the
    channel dimension.

    Args:
        in_channels: Number of channels of the incoming feature map.
        pool_sizes: Output grid size of each pooling branch.
        input_size: ``(height, width)`` every branch is resized to. It must
            equal the spatial size of the input for the concatenation to
            succeed. When ``None`` (default) the size is read from the input
            at call time, so any input resolution works.
    """

    def __init__(self, in_channels, pool_sizes, input_size=None):
        super().__init__()
        self.input_size = input_size
        branch_channels = in_channels // len(pool_sizes)
        # nn.ModuleList (not a plain list) registers the branches as
        # sub-modules, so their convolution weights show up in .parameters(),
        # follow .to(device) and are stored in state_dict().
        self.pooling_layers = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=p),
                nn.Conv2d(in_channels=in_channels, out_channels=branch_channels, kernel_size=1)
            )
            for p in pool_sizes
        ])

    def forward(self, x):
        """Return ``x`` concatenated with every resized pooling branch (dim=1)."""
        if self.input_size is not None:
            target_size = self.input_size
        else:
            target_size = tuple(x.shape[-2:])

        outputs = [x]
        for pool_layer in self.pooling_layers:
            layer_output = pool_layer(x)
            outputs.append(F.interpolate(layer_output, size=target_size, mode='bilinear', align_corners=True))

        return torch.cat(outputs, dim=1)

In [205]:
class PSPNet(nn.Module):
    """ResNet-101 encoder followed by a pyramid pooling module.

    ``n_classes`` is accepted but not used yet: no classification head is
    built, and ``forward`` returns the pyramid-pooled feature map. The
    encoder's auxiliary output is computed but discarded.
    """

    def __init__(self, n_classes):
        super().__init__()
        self.encoder = CustomResNet101()
        # input_size is fixed at 8x8, so the encoder output must have that
        # spatial size for the concatenation inside the pooling module to work.
        self.pyramid_pooling = PyramidPooling(
            in_channels=2048,
            pool_sizes=[6, 3, 2, 1],
            input_size=(8, 8),
        )

    def forward(self, x):
        """Encode ``x`` and return the pyramid-pooled features."""
        features, _aux_features = self.encoder(x)
        return self.pyramid_pooling(features)

In [208]:
# Smoke test: run a single training batch through a freshly built model.
model = PSPNet(n_classes=4)
# NOTE(review): `input` shadows the Python builtin of the same name for the
# rest of the kernel session; consider renaming once no other cell uses it.
input = next(iter(train_dataloader))[0]  # element 0 of the first batch
pred = model(input)
# Display the shape of the batch that was fed to the model.
input.shape

torch.Size([16, 3, 256, 256])

In [169]:
# Layer-by-layer overview of the model for a batch of 256x256 RGB inputs,
# showing each layer's output size and parameter count.
batch_size = 16
model = PSPNet(n_classes=4)

summary(model, input_size=(batch_size, 3, 256, 256), col_names=["output_size", "num_params"])

Layer (type:depth-idx)                             Output Shape              Param #
PSPNet                                             [16, 4096, 8, 8]          185,096
├─FeatureExtractor2: 1-1                           [16, 2048, 8, 8]          --
│    └─ResNet: 2-1                                 --                        2,049,000
│    │    └─Conv2d: 3-1                            [16, 64, 128, 128]        9,408
│    │    └─BatchNorm2d: 3-2                       [16, 64, 128, 128]        128
│    │    └─ReLU: 3-3                              [16, 64, 128, 128]        --
│    │    └─MaxPool2d: 3-4                         [16, 64, 64, 64]          --
│    │    └─Sequential: 3-5                        [16, 256, 64, 64]         215,808
│    │    └─Sequential: 3-6                        [16, 512, 32, 32]         1,219,584
│    │    └─Sequential: 3-7                        [16, 1024, 16, 16]        26,090,496
│    │    └─Sequential: 3-8                        [16, 2048, 8, 8]          14