In [1]:
import torchvision.transforms as transforms

data_transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

In [2]:
from torchvision import datasets

train_data = datasets.VOCSegmentation(root="./data",
                                      year="2012",
                                      image_set="train",
                                      transform=data_transform,
                                      target_transform=data_transform)

test_data = datasets.VOCSegmentation(root="./data",
                                     year="2012",
                                     image_set="val",
                                     transform=data_transform,
                                     target_transform=data_transform)

viz_data = datasets.VOCSegmentation(root="./data", year="2012", image_set="train")

In [3]:
from torch.utils.data import DataLoader

BATCH_SIZE = 16

train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False)

print(f"train_dataloader: {len(train_dataloader)} * {BATCH_SIZE} images")
print(f"test_dataloader: {len(test_dataloader)} * {BATCH_SIZE} images")

train_dataloader: 92 * 16 images
test_dataloader: 91 * 16 images


In [4]:
import torch
from torch import nn
import torch.nn.functional as F

class conv2DBatchNorm(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, bias, activation=False):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size, stride,
                              padding, dilation, bias=bias)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.activation = activation
        if self.activation:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.batchnorm(x)
        if self.activation:
            outputs = self.relu(x)
        else:
            outputs = x

        return outputs

In [5]:
import torchvision
from torchvision import models

class CustomResNet101(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT)

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        auxiliary_x = self.resnet.layer3(x)
        x = self.resnet.layer4(auxiliary_x)

        return x, auxiliary_x

In [6]:
class PyramidPooling(nn.Module):
    def __init__(self, in_channels, pool_sizes, height, width):
        super().__init__()

        self.height = height
        self.width = width

        out_channels = int(in_channels / len(pool_sizes))
        
        conv = nn.Conv2d(in_channels=in_channels, out_channels=int(in_channels/len(pool_sizes)), kernel_size=1)

        self.pooling_layers = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=p), # Pool
                nn.Conv2d(in_channels=in_channels, out_channels=int(in_channels/len(pool_sizes)), kernel_size=1) # Conv
            )
            for p in pool_sizes
        ]
        
        ## pool__sizes: [6, 3, 2, 1]
        self.avpool_1 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[0])
        self.cbr_1 = conv

        self.avpool_2 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[1])
        self.cbr_2 = conv

        self.avpool_3 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[2])
        self.cbr_3 = conv

        self.avpool_4 = nn.AdaptiveAvgPool2d(output_size=pool_sizes[3])
        self.cbr_4 = conv

    def forward(self, x):
        out1 = self.cbr_1(self.avpool_1(x))
        out1 = F.interpolate(out1, size=(
            self.height, self.width), mode='bilinear', align_corners=True)

        out2 = self.cbr_2(self.avpool_2(x))
        out2 = F.interpolate(out2, size=(
            self.height, self.width), mode='bilinear', align_corners=True)

        out3 = self.cbr_3(self.avpool_3(x))
        out3 = F.interpolate(out3, size=(
            self.height, self.width), mode='bilinear', align_corners=True)

        out4 = self.cbr_4(self.avpool_4(x))
        out4 = F.interpolate(out4, size=(
            self.height, self.width), mode='bilinear', align_corners=True)

        output = torch.cat([x, out1, out2, out3, out4], dim=1)

        return output

In [7]:
from torch import nn
import torch.nn.functional as F

class PyramidPoolingModule(nn.Module):
    
    def __init__(self, pools, in_channels, input_shape):
        super().__init__()
        self.input_shape = input_shape
        self.pooling_layers = [
            nn.Sequential(
                nn.AdaptiveAvgPool2d(output_size=p), # Pool
                nn.Conv2d(in_channels=in_channels, out_channels=int(in_channels/len(pools)), kernel_size=1) # Conv
            ) 
            for p in pools
        ]
        
    def forward(self, x):
        outputs = [x]
        for pooling_layer in self.pooling_layers:
            layer_output = pooling_layer(x)
            outputs.append(F.interpolate(layer_output, size=self.input_shape, mode="bilinear", align_corners=True))

        return torch.cat(outputs, dim=1)
            

In [8]:
class DecodePSPFeature(nn.Module):
    def __init__(self, height, width, n_classes):
        super().__init__()

        self.height = height
        self.width = width

        self.cbr = conv2DBatchNorm(
            in_channels=4096, out_channels=64, kernel_size=3, stride=1,
            padding=1, dilation=1,  bias=False, activation=True)
        self.dropout = nn.Dropout2d(p=0.1)
        self.classification = nn.Conv2d(
            in_channels=64, out_channels=n_classes, kernel_size=1,
            stride=1, padding=0)

    def forward(self, x):
        x = self.cbr(x)
        x = self.dropout(x)
        x = self.classification(x)
        output = F.interpolate(
            x, size=(self.height, self.width),
            mode='bilinear', align_corners=True)

        return output

In [9]:
class PSPNet(nn.Module):
    def __init__(self, n_classes):
        super().__init__()

        full_img_size = 256
        feature_map_size = 8

        self.feature_extractor = CustomResNet101()
        self.pyramid_pooling = PyramidPooling(in_channels=2048, pool_sizes=[6, 3, 2, 1],height=feature_map_size, width=feature_map_size)
        #self.pyramid_pooling = PyramidPoolingModule(input_shape=(8, 8), in_channels=2048, pools = (1, 2, 3, 6))
        self.decode_feature = DecodePSPFeature(
            height=full_img_size, width=full_img_size,
            n_classes=n_classes)

    def forward(self, x):
        encoder_outputs, encoder_auxiliary = self.feature_extractor(x)
        pyramid_outputs = self.pyramid_pooling(encoder_outputs)
        docoder_outputs = self.decode_feature(pyramid_outputs)
        return docoder_outputs

In [10]:
torch.manual_seed(32)
model = PSPNet(n_classes=20)
input = next(iter(train_dataloader))[0]
pred = model(input)
pred

tensor([[[[ 1.8625e-01,  1.8634e-01,  1.8643e-01,  ...,  2.0709e-01,
            2.1605e-01,  2.2501e-01],
          [ 1.8610e-01,  1.8629e-01,  1.8647e-01,  ...,  2.2033e-01,
            2.2973e-01,  2.3913e-01],
          [ 1.8595e-01,  1.8623e-01,  1.8652e-01,  ...,  2.3357e-01,
            2.4341e-01,  2.5325e-01],
          ...,
          [ 3.0366e-01,  2.9213e-01,  2.8060e-01,  ..., -1.0854e-01,
           -1.2055e-01, -1.3257e-01],
          [ 3.0569e-01,  2.9403e-01,  2.8238e-01,  ..., -1.1351e-01,
           -1.2563e-01, -1.3775e-01],
          [ 3.0772e-01,  2.9594e-01,  2.8416e-01,  ..., -1.1848e-01,
           -1.3070e-01, -1.4292e-01]],

         [[-7.4000e-02, -7.0625e-02, -6.7251e-02,  ...,  3.4862e-01,
            3.5892e-01,  3.6922e-01],
          [-7.3670e-02, -7.0410e-02, -6.7150e-02,  ...,  3.3603e-01,
            3.4577e-01,  3.5552e-01],
          [-7.3340e-02, -7.0195e-02, -6.7049e-02,  ...,  3.2344e-01,
            3.3263e-01,  3.4181e-01],
          ...,
     

In [11]:
from torchinfo import summary

model = PSPNet(n_classes=20)
batch_size = 16

summary(
    model,
    input_size=(batch_size, 3, 256, 256),
    col_names=["output_size", "num_params"],
)

Layer (type:depth-idx)                             Output Shape              Param #
PSPNet                                             [16, 20, 256, 256]        --
├─CustomResNet101: 1-1                             [16, 2048, 8, 8]          --
│    └─ResNet: 2-1                                 --                        2,049,000
│    │    └─Conv2d: 3-1                            [16, 64, 128, 128]        9,408
│    │    └─BatchNorm2d: 3-2                       [16, 64, 128, 128]        128
│    │    └─ReLU: 3-3                              [16, 64, 128, 128]        --
│    │    └─MaxPool2d: 3-4                         [16, 64, 64, 64]          --
│    │    └─Sequential: 3-5                        [16, 256, 64, 64]         215,808
│    │    └─Sequential: 3-6                        [16, 512, 32, 32]         1,219,584
│    │    └─Sequential: 3-7                        [16, 1024, 16, 16]        26,090,496
│    │    └─Sequential: 3-8                        [16, 2048, 8, 8]          14,964,

In [12]:
class PSPLoss(nn.Module):
    def __init__(self, aux_weight=0.4):
        super().__init__()
        self.aux_weight = aux_weight

    def forward(self, outputs, targets):
        loss = F.cross_entropy(outputs, targets, reduction='mean')
        #loss_aux = F.cross_entropy(outputs[1], targets, reduction='mean')

        return loss #+ self.aux_weight * loss_aux

criterion = PSPLoss(aux_weight=0.4)

In [13]:
loss_fn = PSPLoss()
LEARNING_RATE = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [14]:
from tqdm import tqdm

## For the model training loop.
if torch.cuda.is_available():
    DEVICE = 'cpu'
else: DEVICE = 'cpu'

def train_fn(loader, model, optimizer, loss_fn, device=DEVICE):
    model.train()
    train_loss = 0.
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=device)
        targets = targets.to(device=device)

        predictions = model(data)
        targets = torch.argmax(targets, dim=1)
        loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loop.set_postfix(loss=loss.item())
        train_loss += loss.detach().cpu().numpy() * BATCH_SIZE

    train_loss = train_loss / (BATCH_SIZE * len(train_dataloader))
    return train_loss

## For the model validation loop.
def valid_fn(loader, model, loss_fn, device=DEVICE):
    model.eval()
    valid_loss = 0.
    loop = tqdm(loader)

    with torch.no_grad():
        for batch_idx, (data, targets) in enumerate(loop):
            data = data.to(device=device)
            targets = targets.to(device=device)

            predictions = model(data)
            targets = torch.argmax(targets, dim=1)
            loss = loss_fn(predictions, targets)
            valid_loss += loss * BATCH_SIZE

            loop.set_postfix(loss=loss.item())

        valid_loss = valid_loss / (BATCH_SIZE * len(test_dataloader))
    return valid_loss

In [18]:
## For the train & validation loop.
NUM_EPOCHS = 1

## DeepLabv3 model
model.to(device=DEVICE)

best_loss = 100
for epoch in range(NUM_EPOCHS):
    print('-------------')
    print('Epoch {}/{}'.format(epoch+1, NUM_EPOCHS))
    print('-------------')

    train_loss = train_fn(train_dataloader, model, optimizer, loss_fn, DEVICE)
    valid_loss = valid_fn(test_dataloader, model, loss_fn, DEVICE)

    if valid_loss < best_loss:
        checkpoint = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, "./checkpoint.pth")
        print('best model saved!')
        best_loss = valid_loss

    print(f'Train Loss: {train_loss},  Valid Loss: {valid_loss}')

100%|██████████| 92/92 [18:08<00:00, 11.84s/it, loss=1.07]
100%|██████████| 91/91 [05:52<00:00,  3.87s/it, loss=1.1] 


best model saved!
Train Loss: 1.2260723995125813,  Valid Loss: 1.0915895700454712
-------------
Epoch 1/1
-------------


100%|██████████| 92/92 [17:39<00:00, 11.52s/it, loss=0.895]
100%|██████████| 91/91 [05:43<00:00,  3.77s/it, loss=0.859]


best model saved!
Train Loss: 0.9901133944158969,  Valid Loss: 0.8448228240013123
