In [1]:
%load_ext autoreload
%autoreload 2

In [33]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from dataset import SegmentationDataset
from models import VCN32, VCN16, VCN8
from train import VCNTrainer

# Semantic Segmentation with Fully Convolutional Networks
Semantic Segmentation is a task where you look at every pixel of an image and classify it in one of C categories. It's like coloring images in a coloring book. Green for grass. Blue for sky. Yellow for sun. So on and so forth. Training a neural net for semantic segmentation is a much more fancy way of saying "I'm going to teach an AI to be an expert at coloring within the lines."

To do this task we'll use a fully convolutional network, a neural net that consists only of convolutional layers with ReLU activations and pooling layers mixed in.

## Config

## Dataset

In [16]:
torch.manual_seed(5)

<torch._C.Generator at 0x7f60f17c50d0>

In [17]:
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    Resize((224, 224), antialias=True)
])

target_transform = Compose([
    Resize((224, 224), antialias=True),
    lambda x: x[0] # first dimension is unnecessary
])

In [18]:
train_val_set = SegmentationDataset(root="data/stage1_train", train=True, transform=transform, target_transform=target_transform)
train_set, val_set = random_split(train_val_set, lengths=[0.8, 0.2]) # 80/20 split
test_set = SegmentationDataset(root="data/stage1_test", train=False)

In [19]:
print(f"train_set size: {len(train_set)}, val_set size: {len(val_set)}, test_set size: {len(test_set)}")

train_set size: 536, val_set size: 134, test_set size: 65


In [20]:
BATCH_SIZE = 32

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

In [21]:
for (X, Y) in train_loader:
    print(X.shape, Y.shape)
    break

torch.Size([32, 3, 224, 224]) torch.Size([32, 224, 224])


## Device

In [23]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## Models
VCN32, VCN16, and VCN8

In [25]:
NUM_CLASSES = 2
vcn32 = VCN32(pretrained=True, num_classes=NUM_CLASSES)

Downloading: "https://download.pytorch.org/models/vgg11-8a719046.pth" to /home/studio-lab-user/.cache/torch/hub/checkpoints/vgg11-8a719046.pth
100%|██████████| 507M/507M [00:01<00:00, 367MB/s] 


## Training

In [55]:
torch.cuda.empty_cache()

In [56]:
EPOCHS = 3
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
LR = 1e-3

In [57]:
cross_entropy = torch.nn.CrossEntropyLoss()

In [58]:
vcn32_optimizer = torch.optim.SGD(vcn32.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [59]:
vcn32_trainer = VCNTrainer(
    model=vcn32,
    criterion=cross_entropy,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=vcn32_optimizer,
    epochs=EPOCHS,
    device=device,
    checkpoint="vcn32.pth"
)

In [60]:
vcn32_trainer.train()

Epoch [1/3]:   0%|          | 0/17 [00:11<?, ?it/s, summary=Loss: 0.5870959162712097, Val Loss: 0.5986226081848145]


Epoch [1/3]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


{'model': VCN32(
   (down1): DownSample(
     (conv_pool): Sequential(
       (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (1): ReLU()
       (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     )
   )
   (down2): DownSample(
     (conv_pool): Sequential(
       (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (1): ReLU()
       (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     )
   )
   (down3): DownSample(
     (conv_pool): Sequential(
       (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (1): ReLU()
       (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
       (3): ReLU()
       (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     )
   )
   (down4): DownSample(
     (conv_pool): Sequential(
       (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)