In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [62]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from dataset import SegmentationDataset

# Semantic Segmentation with Fully Convolutional Networks
Semantic Segmentation is a task where you look at every pixel of an image and classify it in one of C categories. It's like coloring images in a coloring book. Green for grass. Blue for sky. Yellow for sun. So on and so forth. Training a neural net for semantic segmentation is a much more fancy way of saying "I'm going to teach an AI to be an expert at coloring within the lines."

To do this task we'll use a fully convolutional network, a neural net that consists only of convolutional layers with ReLU activations and pooling layers mixed in.

## Dataset

In [63]:
torch.manual_seed(5)

<torch._C.Generator at 0x105d50bb0>

In [69]:
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    Resize((224, 224), antialias=True)
])

target_transform = Compose([
    Resize((224, 224)),
    lambda x: x[0] # first dimension is unnecessary
])

In [70]:
train_val_set = SegmentationDataset(root="data/stage1_train", train=True, transform=transform, target_transform=target_transform)
train_set, val_set = random_split(train_val_set, lengths=[0.8, 0.2]) # 80/20 split
test_set = SegmentationDataset(root="data/stage1_test", train=False)

In [71]:
print(f"train_set size: {len(train_set)}, val_set size: {len(val_set)}, test_set size: {len(test_set)}")

train_set size: 536, val_set size: 134, test_set size: 65


In [72]:
BATCH_SIZE = 20

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

In [74]:
for (X, Y) in train_loader:
    print(Y.shape)
    print(Y)
    break

torch.Size([20, 224, 224])
tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  