In [44]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor, Normalize, Compose, Resize
from dataset import SegmentationDataset
from models import VCN32, VCN16, VCN8
from train import VCNTrainer

# Semantic Segmentation with Fully Convolutional Networks
Semantic Segmentation is a task where you look at every pixel of an image and classify it in one of C categories. It's like coloring images in a coloring book. Green for grass. Blue for sky. Yellow for sun. So on and so forth. Training a neural net for semantic segmentation is a much more fancy way of saying "I'm going to teach an AI to be an expert at coloring within the lines."

To do this task we'll use a fully convolutional network, a neural net that consists only of convolutional layers with ReLU activations and pooling layers mixed in.

## Config

In [46]:
BATCH_SIZE = 32
NUM_CLASSES = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 70
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
LR = 1e-3

## Dataset

In [47]:
torch.manual_seed(5)

<torch._C.Generator at 0x7f98261e8ed0>

In [48]:
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    Resize((224, 224), antialias=True)
])

target_transform = Compose([
    Resize((224, 224), antialias=True),
    lambda x: x[0] # first dimension is unnecessary
])

In [49]:
train_val_set = SegmentationDataset(root="data/stage1_train", train=True, transform=transform, target_transform=target_transform)
train_set, val_set = random_split(train_val_set, lengths=[0.8, 0.2]) # 80/20 split
test_set = SegmentationDataset(root="data/stage1_test", train=False)

In [50]:
print(f"train_set size: {len(train_set)}, val_set size: {len(val_set)}, test_set size: {len(test_set)}")

train_set size: 536, val_set size: 134, test_set size: 65


In [51]:
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

In [52]:
for (X, Y) in train_loader:
    print(X.shape, Y.shape)
    break

torch.Size([32, 3, 224, 224]) torch.Size([32, 224, 224])


## Models
VCN32, VCN16, and VCN8

In [53]:
vcn32 = VCN32(pretrained=True, num_classes=NUM_CLASSES)
vcn16 = VCN16(pretrained=True, num_classes=NUM_CLASSES)
vcn8 = VCN8(pretrained=True, num_classes=NUM_CLASSES)

## Training

In [54]:
torch.cuda.empty_cache()

In [55]:
cross_entropy = torch.nn.CrossEntropyLoss()

In [56]:
vcn32_optimizer = torch.optim.SGD(vcn32.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
vcn16_optimizer = torch.optim.SGD(vcn16.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
vcn8_optimizer = torch.optim.SGD(vcn8.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)

In [57]:
vcn32_trainer = VCNTrainer(
    model=vcn32,
    criterion=cross_entropy,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=vcn32_optimizer,
    epochs=EPOCHS,
    device=DEVICE,
    checkpoint="vcn32.pth"
)

In [None]:
vcn32_trainer.train()

Epoch [1/70]: 100%|██████████| 17/17 [02:45<00:00,  9.72s/it, summary=Loss: 0.6675023436546326, Val Loss: 0.6632475137710572]


Epoch [1/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [2/70]: 100%|██████████| 17/17 [02:45<00:00,  9.74s/it, summary=Loss: 0.6367778182029724, Val Loss: 0.6294708013534546]


Epoch [2/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [3/70]: 100%|██████████| 17/17 [02:44<00:00,  9.70s/it, summary=Loss: 0.5951410531997681, Val Loss: 0.5991144061088562]


Epoch [3/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [4/70]: 100%|██████████| 17/17 [02:44<00:00,  9.69s/it, summary=Loss: 0.5899320840835571, Val Loss: 0.575039541721344] 


Epoch [4/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [5/70]: 100%|██████████| 17/17 [02:45<00:00,  9.71s/it, summary=Loss: 0.5609776377677917, Val Loss: 0.5460091829299927]


Epoch [5/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [6/70]: 100%|██████████| 17/17 [02:47<00:00,  9.84s/it, summary=Loss: 0.5249054431915283, Val Loss: 0.5273350596427917]


Epoch [6/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [7/70]: 100%|██████████| 17/17 [02:45<00:00,  9.71s/it, summary=Loss: 0.5158845782279968, Val Loss: 0.5048510372638703]


Epoch [7/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [8/70]: 100%|██████████| 17/17 [02:45<00:00,  9.71s/it, summary=Loss: 0.5035630464553833, Val Loss: 0.5026364862918854]  


Epoch [8/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [9/70]: 100%|██████████| 17/17 [02:45<00:00,  9.76s/it, summary=Loss: 0.4650995135307312, Val Loss: 0.4851538360118866] 


Epoch [9/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [10/70]: 100%|██████████| 17/17 [02:46<00:00,  9.78s/it, summary=Loss: 0.5126964449882507, Val Loss: 0.4748647093772888]  


Epoch [10/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [11/70]: 100%|██████████| 17/17 [02:46<00:00,  9.78s/it, summary=Loss: 0.4747842252254486, Val Loss: 0.47506238222122193]


Epoch [11/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [12/70]: 100%|██████████| 17/17 [02:46<00:00,  9.77s/it, summary=Loss: 0.4492447078227997, Val Loss: 0.46348450779914857] 


Epoch [12/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [13/70]: 100%|██████████| 17/17 [02:45<00:00,  9.73s/it, summary=Loss: 0.4532831013202667, Val Loss: 0.45325044393539426] 


Epoch [13/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [14/70]: 100%|██████████| 17/17 [02:45<00:00,  9.73s/it, summary=Loss: 0.44953399896621704, Val Loss: 0.45444504022598264]


Epoch [14/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [15/70]: 100%|██████████| 17/17 [02:45<00:00,  9.72s/it, summary=Loss: 0.4875869154930115, Val Loss: 0.44419227838516234] 


Epoch [15/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [16/70]: 100%|██████████| 17/17 [02:45<00:00,  9.71s/it, summary=Loss: 0.45591309666633606, Val Loss: 0.43889243006706236]


Epoch [16/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [17/70]: 100%|██████████| 17/17 [02:45<00:00,  9.72s/it, summary=Loss: 0.40672600269317627, Val Loss: 0.4345195770263672] 


Epoch [17/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [18/70]: 100%|██████████| 17/17 [02:44<00:00,  9.69s/it, summary=Loss: 0.4496973752975464, Val Loss: 0.4420386254787445]  


Epoch [18/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [19/70]: 100%|██████████| 17/17 [02:45<00:00,  9.73s/it, summary=Loss: 0.40854164958000183, Val Loss: 0.42373030781745913]


Epoch [19/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [20/70]: 100%|██████████| 17/17 [02:45<00:00,  9.71s/it, summary=Loss: 0.4180986285209656, Val Loss: 0.41624556183815004]


Epoch [20/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [21/70]: 100%|██████████| 17/17 [02:44<00:00,  9.68s/it, summary=Loss: 0.3883826732635498, Val Loss: 0.41332321166992186] 


Epoch [21/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [22/70]: 100%|██████████| 17/17 [02:44<00:00,  9.70s/it, summary=Loss: 0.35051316022872925, Val Loss: 0.42052493095397947]


Epoch [22/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [23/70]: 100%|██████████| 17/17 [02:44<00:00,  9.71s/it, summary=Loss: 0.43299517035484314, Val Loss: 0.41227291226387025]


Epoch [23/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [24/70]: 100%|██████████| 17/17 [02:45<00:00,  9.71s/it, summary=Loss: 0.47888296842575073, Val Loss: 0.4202948033809662] 


Epoch [24/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [25/70]: 100%|██████████| 17/17 [02:44<00:00,  9.67s/it, summary=Loss: 0.4623141586780548, Val Loss: 0.41584287881851195]


Epoch [25/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [26/70]: 100%|██████████| 17/17 [02:44<00:00,  9.69s/it, summary=Loss: 0.4053952097892761, Val Loss: 0.4201237797737122]  


Epoch [26/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [27/70]: 100%|██████████| 17/17 [02:45<00:00,  9.76s/it, summary=Loss: 0.37936580181121826, Val Loss: 0.40087261199951174]


Epoch [27/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [28/70]: 100%|██████████| 17/17 [02:44<00:00,  9.69s/it, summary=Loss: 0.4453032612800598, Val Loss: 0.4193118572235107] 


Epoch [28/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [42/70]: 100%|██████████| 17/17 [02:44<00:00,  9.70s/it, summary=Loss: 0.4450494647026062, Val Loss: 0.3933979690074921] 


Epoch [42/70]: Train Pixel Acc: 27.142089257578032, Val Pixel Acc: 23.122253667091837


Epoch [43/70]:  59%|█████▉    | 10/17 [01:36<01:07,  9.59s/it, summary=Loss: 0.40890181064605713, Val Loss: 0.40001668930053713]