In [73]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import h5py
from sklearn.model_selection import KFold
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from networks.correlation_package.correlation import Correlation

In [74]:
# Dataset size and input-dimension constants.
COUNT      = 20399  # number of sample indices handed to KFold.split in the training cell
WIDTH      = 300    # presumably the input width in pixels; not referenced in the visible cells (TODO confirm)
HEIGHT     = 150    # presumably the input height in pixels; not referenced in the visible cells (TODO confirm)
FLOW_CHAN  = 2      # presumably the channel count of an optical-flow field; not referenced in the visible cells
FRAME_CHAN = 6      # input channel count of OpFlowNet's first conv; FrameNet slices its input into channels 0:3 and 3:

In [75]:
# Location of the HDF5 training file and the dataset keys looked up inside it.
PATH_TRAIN    = './data/train.h5'  # path to the HDF5 training data file
OP_FLOW_ID    = 'OPF'              # default `category` key of CommaData (presumably the optical-flow dataset)
FRAME_PAIR_ID = 'FRM'              # `category` key passed to CommaData by the training cell (presumably stacked frame pairs)
LABEL_ID      = 'LBL'              # key of the label dataset that CommaData reads targets from

In [76]:
class CommaData(Dataset):
    """
    PyTorch dataset for Farneback optical flow / consecutive
    frame pairs + speed labels, read lazily from an HDF5 file.

    Args:
        data_path: path of the HDF5 file to open read-only.
        list_ids: optional sequence of row indices into the file that this
            dataset exposes (e.g. one cross-validation fold). When None,
            every row of the label dataset is used.
        category: key of the input dataset inside the file
            (defaults to OP_FLOW_ID).

    Each item is a tuple ``(transform(X), y)`` where ``X`` is the row read
    from ``file[category]`` and ``y`` the matching row of ``file[LABEL_ID]``.

    NOTE(review): the file handle is opened once in ``__init__``; sharing an
    open h5py handle across DataLoader worker processes is presumably unsafe,
    so keep ``num_workers=0`` or open the file per worker -- confirm against
    the h5py documentation.
    """
    def __init__(self, data_path, list_ids=None, category=OP_FLOW_ID):
        super(CommaData, self).__init__()
        # Define the attribute before opening so close()/__del__ stay safe
        # even when h5py.File raises (missing or corrupt file).
        self.file = None
        self.file = h5py.File(data_path, 'r')
        if list_ids is None:
            self.list_ids = np.arange(len(self.file[LABEL_ID]))
        else:
            self.list_ids = list_ids
        self.category = category
        self.transform = transforms.ToTensor()

    def __len__(self):
        """Number of samples exposed by this dataset (size of list_ids)."""
        return len(self.list_ids)

    def __getitem__(self, idx):
        """Return (transformed input, label) for position idx within list_ids."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Map the position within this subset to the row index in the file.
        i = self.list_ids[idx]
        X = self.file[self.category][i]
        y = self.file[LABEL_ID][i]
        return self.transform(X), y

    def close(self):
        """Close the underlying HDF5 file; safe to call more than once."""
        handle = getattr(self, 'file', None)
        if handle is not None:
            self.file = None
            handle.close()

    def __del__(self):
        # Finalizers can run on a half-constructed object or during
        # interpreter shutdown; never let cleanup errors escape from here.
        try:
            self.close()
        except Exception:
            pass

In [77]:
class Conv(nn.Module):
    """
    Conv2d block: convolution, then in-place ReLU, then optional BatchNorm2d.

    Note the order used in forward(): the normalization (when enabled) is
    applied to the activated output, i.e. conv -> ReLU -> batch norm.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels (and BatchNorm2d features).
        kernel_size: convolution kernel size.
        stride: convolution stride (default 1).
        padding: convolution zero-padding (default 0).
        batch_norm: when True, append a BatchNorm2d layer; otherwise
            ``self.norm`` is None.
    """
    def __init__(
        self, in_channels, out_channels, kernel_size, stride=1, padding=0, batch_norm=False
    ):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding
        )
        self.act = nn.ReLU(inplace=True)
        # None marks "no normalization"; forward() checks for it.
        self.norm = nn.BatchNorm2d(out_channels) if batch_norm else None

    def forward(self, x):
        """Apply conv -> ReLU (-> BatchNorm2d when configured) to x."""
        activated = self.act(self.conv(x))
        if self.norm is None:
            return activated
        return self.norm(activated)

class OpFlowNet(nn.Module):
    """
    Basic ConvNet for speed estimation: seven Conv blocks interleaved with
    max pooling, followed by a three-layer regression head with one output.

    Args:
        in_channels: number of channels of the input tensor. Defaults to
            FRAME_CHAN (the previously hard-coded value); pass FLOW_CHAN to
            feed 2-channel optical flow instead.

    The feature map is reduced to 1x1 by a global average pool before the
    linear head. For inputs whose final feature map is already 1x1 this is
    the identity. For other sizes (e.g. the HEIGHT x WIDTH = 150 x 300
    declared in this notebook, which reaches conv6 as a 2x2 map) it prevents
    the shape mismatch that flattening straight into Linear(512, 256) would
    raise. The pool has no parameters, so state_dict keys are unchanged.
    """
    def __init__(self, in_channels=None):
        super(OpFlowNet, self).__init__()

        # Resolve the default at call time so the historical behaviour
        # (FRAME_CHAN input channels) is kept when no argument is given.
        if in_channels is None:
            in_channels = FRAME_CHAN

        # conv layers
        self.conv0 = Conv(in_channels, 8, 3)
        self.conv1 = Conv(8, 16, 3)
        self.conv2 = Conv(16, 32, 3)
        self.conv3 = Conv(32, 64, 3)
        self.conv4 = Conv(64, 128, 3)
        self.conv5 = Conv(128, 256, 3)
        self.conv6 = Conv(256, 512, 3, padding=1)

        # pooling layers
        self.pool0 = nn.MaxPool2d((1, 2))  # halves the last (width) axis only
        self.pool1 = nn.MaxPool2d((2, 2))
        # collapses whatever spatial extent is left to 1x1 (512 features)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # linear layers
        self.linear0 = nn.Linear(512, 256)
        self.linear1 = nn.Linear(256, 64)
        self.linear2 = nn.Linear(64, 1)

        self.drop = nn.Dropout(.2)
        self.flat = nn.Flatten()

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, 1) predictions."""
        x = self.conv0(x)
        x = self.pool0(x)
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool1(x)
        x = self.conv3(x)
        x = self.pool1(x)
        x = self.conv4(x)
        x = self.pool1(x)
        x = self.conv5(x)
        x = self.pool1(x)
        x = self.conv6(x)
        # Guarantee exactly 512 features for linear0 regardless of input size.
        x = self.global_pool(x)
        x = self.drop(x)
        x = self.flat(x)
        x = F.relu(self.linear0(x))
        x = F.relu(self.linear1(x))
        x = self.linear2(x)
        return x

class FrameNet(nn.Module):
    """
    Two-stream ConvNet joined by the correlation layer from FlowNetC
    https://arxiv.org/pdf/1504.06852.pdf

    The input is split along the channel axis into channels 0:3 and 3:
    (presumably two stacked RGB frames -- confirm against the data). Each
    half goes through its own stride-2 Conv block, the two feature maps are
    correlated, and the result is reduced by a small conv/pool stack to a
    single regression output.

    Args:
        batch_norm: whether the two input-stream Conv blocks use batch
            normalization (default True).
    """
    def __init__(self, batch_norm=True):
        super().__init__()

        # one stem per input stream (weights are not shared)
        self.conv0a = Conv(3, 32, 3, stride=2, batch_norm=batch_norm)
        self.conv0b = Conv(3, 32, 3, stride=2, batch_norm=batch_norm)

        # correlation between the two streams; conv4 below consumes its
        # 121 output channels (presumably (2 * 10 / 2 + 1) ** 2 displacements
        # -- confirm against the correlation package)
        self.corr = Correlation(
            pad_size=10,
            kernel_size=1,
            max_displacement=10,
            stride1=1,
            stride2=2,
            corr_multiply=1,
        )
        self.corr_act = nn.LeakyReLU(0.1, inplace=True)

        # convolutions applied after the streams are merged
        self.conv4 = Conv(121, 64, 7, stride=2)
        self.conv5 = Conv(64, 64, 5, stride=2)

        # pooling
        self.pool0 = nn.MaxPool2d((1, 2))
        self.pool1 = nn.MaxPool2d((2, 2))
        self.pool2 = nn.AvgPool2d((3, 3))

        # regression head
        self.linear0 = nn.Linear(64, 32)
        self.linear2 = nn.Linear(32, 1)

        self.drop = nn.Dropout(.2)
        self.flat = nn.Flatten()

    def forward(self, x):
        """Run both streams, correlate them and regress one value per sample."""
        # channel-wise split into the two streams
        first, second = x[:, :3], x[:, 3:]

        # per-stream stem: strided conv, then width-only max pool
        first = self.pool0(self.conv0a(first))
        second = self.pool0(self.conv0b(second))

        # join the streams
        merged = self.corr_act(self.corr(first, second))

        # shared trunk
        merged = self.pool1(self.conv4(merged))
        merged = self.pool1(self.conv5(merged))
        merged = self.flat(self.drop(self.pool2(merged)))

        # head
        hidden = F.relu(self.linear0(merged))
        return self.linear2(hidden)

In [78]:

# --- training configuration -------------------------------------------------
# NOTE(review): the Correlation layer used by FrameNet is presumably CUDA-only,
# so no CPU fallback is attempted here -- confirm before changing the device.
device = 'cuda:0'
max_epochs = 100
batch_size = 32
patience = 1   # epochs without validation improvement tolerated before stopping
n_splits = 5
seed = 0       # fixes the K-fold split, weight init and shuffling for reproducibility


def train_one_epoch(net, loader, optimizer, criterion, device, epoch, log_every=50):
    """Run one optimisation pass over `loader`.

    Prints the mean training loss of every `log_every` consecutive batches,
    labelled with the 1-based epoch and batch number.
    """
    net.train()
    running_loss = 0.
    for i, (X, y) in enumerate(loader):
        X = X.to(device)
        # Match the dtype of the network output so the loss never sees a
        # float32/float64 mix (labels read from HDF5 may be float64).
        y = y.to(device, dtype=torch.float32)
        optimizer.zero_grad()
        # squeeze(1) drops only the output-feature axis; a bare squeeze()
        # would also drop the batch axis when a batch holds a single sample.
        output = net(X).squeeze(1)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.


def evaluate(net, loader, criterion, device):
    """Return the per-sample mean loss of `net` over `loader`.

    Batch losses are weighted by batch size so that a smaller final batch
    does not get the same weight as a full one. Returns NaN when the loader
    yields no samples.
    """
    net.eval()
    total_loss = 0.
    total_samples = 0
    with torch.no_grad():
        for X, y in loader:
            X = X.to(device)
            y = y.to(device, dtype=torch.float32)
            output = net(X).squeeze(1)
            n_samples = y.size(0)
            total_loss += criterion(output, y).item() * n_samples
            total_samples += n_samples
    if total_samples == 0:
        return float('nan')
    return total_loss / total_samples


torch.manual_seed(seed)
criterion = nn.MSELoss()
kf = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

# best validation loss reached in each fold
fold_val_losses = []

for train_index, val_index in kf.split(np.arange(COUNT)):
    # train fold generator
    train_set = CommaData(PATH_TRAIN, train_index, FRAME_PAIR_ID)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # validation fold generator (order is irrelevant for evaluation)
    valid_set = CommaData(PATH_TRAIN, val_index, FRAME_PAIR_ID)
    valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

    # fresh model and optimizer per fold so folds stay independent
    net = FrameNet().to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.001, eps=1e-7)

    best_val_loss = np.inf
    val_count = 0

    for epoch in range(max_epochs):
        # train network on train folds
        train_one_epoch(net, train_loader, optimizer, criterion, device, epoch)

        # evaluate network on eval fold
        val_loss = evaluate(net, valid_loader, criterion, device)
        print('Valid. loss: ' + str(val_loss))

        # early stopping on the validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            val_count = 0
        else:
            val_count += 1
            if val_count >= patience:
                break

    fold_val_losses.append(best_val_loss)

print('Mean best valid. loss over folds: ' + str(np.mean(fold_val_losses)))

[1,    50] loss: 63.088


KeyboardInterrupt: 