In [14]:
import os
import pickle
import random
import time

import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

from datasets import CustomCOCOADataset

In [None]:
# Dataset split to load; both lookups below are keyed on it.
phase = 'val'

# Image directory for each split.
root_dict = dict(
    train="../data/COCOA/train2014",
    val="../data/COCOA/val2014",
)
img_root = root_dict[phase]

# Amodal annotation file for the selected split.
annot_path = "../data/COCOA/annotations/COCO_amodal_{}2014.json".format(phase)

# The dataset is built from the annotation path alone; img_root is not passed to it here.
data_reader = CustomCOCOADataset(annot_path)

In [None]:
from torch.utils.data import DataLoader

# Batch the dataset into groups of 64, reshuffled every epoch.
# NOTE(review): shuffling a validation loader is unusual — confirm this is intended.
val_dataloader = DataLoader(
    dataset=data_reader,
    batch_size=64,
    shuffle=True,
)

In [19]:
from collections import OrderedDict

import torch
import torch.nn as nn


class UNet(nn.Module):
    """Four-level U-Net with two output heads.

    The encoder applies four conv blocks, each followed by 2x2 max pooling,
    then a bottleneck block. The decoder mirrors it with transposed
    convolutions, concatenating the matching encoder output (skip connection)
    before each decoder block. A final 1x1 convolution maps to
    ``out_channels`` channels.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Total number of output channels. Channel 0 feeds the
            sigmoid head, the remaining ``out_channels - 1`` channels feed
            the tanh head.
        init_features: Channel width of the first encoder block; it doubles
            at every level down to ``16 * init_features`` in the bottleneck.
    """

    def __init__(self, in_channels=3, out_channels=9, init_features=32):
        super(UNet, self).__init__()

        features = init_features
        # Encoder: channel width doubles while resolution halves.
        self.encoder1 = UNet._block(in_channels, features, name="enc1")
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = UNet._block(features, features * 2, name="enc2")
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")

        # Decoder: each block takes twice the channels of its up-convolution
        # output because the skip connection is concatenated along dim 1.
        self.upconv4 = nn.ConvTranspose2d(
            features * 16, features * 8, kernel_size=2, stride=2
        )
        self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
        self.upconv3 = nn.ConvTranspose2d(
            features * 8, features * 4, kernel_size=2, stride=2
        )
        self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
        self.upconv2 = nn.ConvTranspose2d(
            features * 4, features * 2, kernel_size=2, stride=2
        )
        self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
        self.upconv1 = nn.ConvTranspose2d(
            features * 2, features, kernel_size=2, stride=2
        )
        self.decoder1 = UNet._block(features * 2, features, name="dec1")

        # 1x1 convolution producing the raw values for both heads.
        self.conv = nn.Conv2d(
            in_channels=features, out_channels=out_channels, kernel_size=1
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``. ``H`` and ``W``
                must be divisible by 16 so the four pool/up-conv stages
                produce feature maps that line up with the skip connections.

        Returns:
            A pair ``(seg, graph)``:
            ``seg`` is ``sigmoid`` of output channel 0, shape ``(N, H, W)``;
            ``graph`` is ``tanh`` of output channels 1 and up, shape
            ``(N, out_channels - 1, H, W)``.
        """
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        dec4 = self.upconv4(bottleneck)
        dec4 = torch.cat((dec4, enc4), dim=1)
        dec4 = self.decoder4(dec4)
        dec3 = self.upconv3(dec4)
        dec3 = torch.cat((dec3, enc3), dim=1)
        dec3 = self.decoder3(dec3)
        dec2 = self.upconv2(dec3)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)
        dec1 = self.upconv1(dec2)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        dec0 = self.conv(dec1)
        # Split along the channel axis (dim 1). Indexing dim 0 here would
        # split the batch into "first sample" / "remaining samples" instead.
        return torch.sigmoid(dec0[:, 0]), torch.tanh(dec0[:, 1:])

    @staticmethod
    def _block(in_channels, features, name):
        """Build one conv block: two (3x3 conv -> batch norm -> ReLU) stages.

        The convolutions use ``padding=1`` so spatial size is preserved, and
        ``bias=False`` because each is followed by batch normalisation.
        Sub-layers are named ``<name>conv1``, ``<name>norm1``, ... so that
        ``state_dict`` keys are descriptive.
        """
        return nn.Sequential(
            OrderedDict(
                [
                    (
                        name + "conv1",
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm1", nn.BatchNorm2d(num_features=features)),
                    (name + "relu1", nn.ReLU(inplace=True)),
                    (
                        name + "conv2",
                        nn.Conv2d(
                            in_channels=features,
                            out_channels=features,
                            kernel_size=3,
                            padding=1,
                            bias=False,
                        ),
                    ),
                    (name + "norm2", nn.BatchNorm2d(num_features=features)),
                    (name + "relu2", nn.ReLU(inplace=True)),
                ]
            )
        )

In [24]:
from tqdm import tqdm

def train(epochs, device, net, losses, optimizer, train_loader, val_loader,
          in_key=0, target_key=1, mask_key=2, loss_alpha=0.5, scheduler=None, 
          checkpoint=False, checkpoint_dir='./models/', exp_name='unet'):
    """Train and validate a two-headed network, logging per-epoch mean losses.

    Each epoch runs one pass over ``train_loader`` (with optimisation) and
    one pass over ``val_loader`` (no gradient tracking, no optimisation).
    The total loss is
    ``loss_alpha * seg_loss + (1 - loss_alpha) * graph_loss``, where the
    graph loss is computed only on elements whose mask value is non-zero.

    Args:
        epochs: Number of epochs to run.
        device: Device the network and every batch are moved to.
        net: Module whose forward pass returns ``(output_seg, output_graph)``.
        losses: Pair ``(loss_seg, loss_graph)`` of callables taking
            ``(prediction, target)``.
        optimizer: Optimizer stepped after every training batch.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        in_key: Index/key of the network input inside a batch.
        target_key: Index/key of the ``(target_seg, target_graph)`` pair
            inside a batch.
        mask_key: Index/key of the graph mask inside a batch.
        loss_alpha: Weight of the segmentation loss in the total loss.
        scheduler: Optional LR scheduler, stepped once per epoch.
        checkpoint: If true, save ``net.state_dict()`` after every epoch.
        checkpoint_dir: Parent directory for this run's output folder.
        exp_name: Prefix of the run folder name (a timestamp is appended).

    Returns:
        A list with one dict per epoch holding the keys
        ``train_mean_loss_seg``, ``train_mean_loss_graph``,
        ``val_mean_loss_seg``, ``val_mean_loss_graph`` and ``time_elapsed``.
        The same list is pickled into the run folder after every epoch.
    """
    net = net.to(device)
    loss_seg, loss_graph = losses

    # Create the per-run output directory. makedirs also creates the parent
    # (e.g. './models/') when it does not exist yet.
    run_dir = os.path.join(checkpoint_dir, exp_name + '_' + str(time.time()))
    os.makedirs(run_dir, exist_ok=True)

    # Phases and Logging
    phases = {'train': train_loader,
              'val': val_loader}
    start_time = time.time()
    train_log = []

    # Training
    for i in range(epochs):
        epoch_data = {'train_mean_loss_seg': 0.0, 'train_mean_loss_graph': 0.0,
                      'val_mean_loss_seg': 0.0, 'val_mean_loss_graph': 0.0}
        for phase, loader in phases.items():
            is_train = phase == 'train'
            net.train(is_train)

            # [segmentation loss sum, graph loss sum] over the batches of this phase.
            running_losses = np.zeros(2)
            for batch in tqdm(loader):
                _in = batch[in_key].to(device)
                _mask = batch[mask_key].to(device)
                # The target is a (segmentation, graph) pair, so each member
                # is moved separately; a pair has no .to() of its own.
                _out_seg, _out_graph = batch[target_key]
                _out_seg, _out_graph = _out_seg.to(device), _out_graph.to(device)

                # Only track gradients while training.
                with torch.set_grad_enabled(is_train):
                    output_seg, output_graph = net(_in)

                    # Apply graph loss to masked outputs
                    keep = _mask != 0
                    output_graph, _out_graph = output_graph[keep], _out_graph[keep]
                    loss0 = loss_seg(output_seg, _out_seg)
                    if output_graph.numel() > 0:
                        loss1 = loss_graph(output_graph, _out_graph)
                    else:
                        # Nothing selected by the mask: contribute an exact
                        # zero instead of a reduction over an empty tensor.
                        loss1 = output_graph.sum()
                    loss = loss_alpha * loss0 + (1 - loss_alpha) * loss1

                # Optimize
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Log batch results
                running_losses += [loss0.item(), loss1.item()]

            # Release cached GPU blocks once per phase rather than per batch.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Log phase results (guard against an empty loader).
            num_batches = max(len(loader), 1)
            running_loss_seg, running_loss_graph = running_losses
            epoch_data[phase + '_mean_loss_seg'] = running_loss_seg / num_batches
            epoch_data[phase + '_mean_loss_graph'] = running_loss_graph / num_batches

        # Display Progress
        duration_elapsed = time.time() - start_time
        print('\n-- Finished Epoch {}/{} --'.format(i, epochs - 1))
        print('Training Loss (Segmentation): {}'.format(epoch_data['train_mean_loss_seg']))
        print('Training Loss (Graph): {}'.format(epoch_data['train_mean_loss_graph']))
        print('Validation Loss (Segmentation): {}'.format(epoch_data['val_mean_loss_seg']))
        print('Validation Loss (Graph): {}'.format(epoch_data['val_mean_loss_graph']))
        print('Time since start: {}'.format(duration_elapsed))
        epoch_data['time_elapsed'] = duration_elapsed
        train_log.append(epoch_data)

        # Scheduler
        if scheduler:
            scheduler.step()

        # Checkpoint
        if checkpoint:
            path = os.path.join(run_dir, 'checkpoint_' + str(i) + '_' + str(time.time()))
            torch.save(net.state_dict(), path)

        # Save train_log (a new timestamped file every epoch)
        path = os.path.join(run_dir, 'train_log_' + str(time.time()))
        with open(path, 'wb') as fp:
            pickle.dump(train_log, fp)

    return train_log

In [25]:
# Pick the GPU when one is available, otherwise fall back to the CPU,
# then build the network with its default arguments directly on that device.
cuda_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = UNet().to(cuda_device)

In [26]:
# Criteria for the two network heads, in the order train() unpacks them:
# crit1 -> segmentation output, crit2 -> graph output.
# NOTE(review): UNet.forward already applies a sigmoid to the segmentation head,
# while CrossEntropyLoss expects raw class scores — confirm this pairing is intended.
crit1 = nn.CrossEntropyLoss()
crit2 = nn.MSELoss()

# SGD with momentum and a small L2 penalty over all network parameters.
optimizer = optim.SGD(
    params=net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-05,
)

# TODO: add scheduler
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[epochs to drop], gamma=0.1)

In [None]:
# Expects loaders to return (input image, (target segmentation, target graph), mask)

# NOTE(review): train_loader and val_loader are not defined in the visible cells;
# they must be created before this cell is run.
# train_loader = ...
# val_loader = ...
train(
    epochs=10,
    device=cuda_device,
    net=net,
    losses=[crit1, crit2],
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_alpha=0.5,
    scheduler=None,
    checkpoint=True,
)