In [None]:
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold

In [None]:
# Number of samples in ./data/train.h5 that are split into cross-validation folds.
# NOTE(review): hard-coded; presumably equal to len(file['LBL']) — confirm it matches the file.
TRAIN_COUNT = 20399

class OpFlowDataset(Dataset):
    """Map-style dataset over an HDF5 file holding 'OPF' (inputs) and 'LBL' (targets).

    Only the rows listed in ``list_ids`` are exposed, so several disjoint splits
    (e.g. cross-validation folds) can be served from the same file.

    The file is opened read-only in the constructor and kept in ``self.file``.
    Call :meth:`close` (or use the dataset in a ``with`` block) to release it.
    NOTE(review): the h5py handle is created before any DataLoader worker starts;
    with ``num_workers > 0`` it is shared with the workers — consider opening
    the file per worker if that setting is ever used.

    Args:
        data_path: path to the HDF5 file.
        list_ids: indexable sequence of row indices into 'OPF' / 'LBL'.
    """

    def __init__(self, data_path, list_ids):
        super(OpFlowDataset, self).__init__()
        self.file = h5py.File(data_path, 'r')
        self.list_ids = list_ids
        # ToTensor turns an (H, W, C) ndarray into a (C, H, W) tensor; values are
        # rescaled to [0, 1] only when the array is uint8.
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.list_ids)

    def __getitem__(self, index):
        """Return ``(transform(OPF[row]), LBL[row])`` for the ``index``-th listed row."""
        if torch.is_tensor(index):
            index = index.tolist()
        row = self.list_ids[index]
        X = self.file['OPF'][row]
        y = self.file['LBL'][row]
        return self.transform(X), y

    def close(self):
        """Close the underlying HDF5 file handle."""
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the file; never suppress an in-flight exception.
        self.close()
        return False

In [None]:
class Conv2dReLU(nn.Module):
    """``nn.Conv2d`` followed by a ReLU non-linearity.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: size of the convolution kernel.
        stride: convolution stride (default 1).
        padding: zero padding added to each side of the input (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv_options = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
        }
        # Only the convolution is a registered submodule; the ReLU is
        # parameter-free and applied functionally in forward().
        self.conv = nn.Conv2d(in_channels, out_channels, **conv_options)

    def forward(self, x):
        pre_activation = self.conv(x)
        return F.relu(pre_activation)

class SpeedNet(nn.Module):
    """CNN regressor mapping a 2-channel image to one scalar per sample.

    A feature extractor of seven Conv2dReLU stages (kernel 3) with interleaved
    max-pooling is followed by a 512 -> 500 -> 250 -> 1 MLP head, so
    ``forward`` returns a tensor of shape (B, 1).

    NOTE(review): the head assumes the feature map entering ``AvgPool2d((7, 2))``
    collapses to 1x1, so that Flatten yields exactly 512 features. That constrains
    the input height/width (e.g. a 160x320 input reduces to 7x3 there and works);
    confirm against the data.
    """

    def __init__(self):
        super().__init__()
        features = [
            Conv2dReLU(2, 8, 3), nn.MaxPool2d((1, 2)),
            Conv2dReLU(8, 16, 3), nn.MaxPool2d((1, 2)),
            Conv2dReLU(16, 32, 3), nn.MaxPool2d((2, 2)),
            Conv2dReLU(32, 64, 3), nn.MaxPool2d((2, 2)),
            Conv2dReLU(64, 128, 3), nn.MaxPool2d((2, 2)),
            Conv2dReLU(128, 256, 3),
            Conv2dReLU(256, 512, 3, padding=1), nn.MaxPool2d((2, 2)),
        ]
        head = [
            nn.AvgPool2d((7, 2)),
            nn.Dropout(.2),
            nn.Flatten(),
            nn.Linear(512, 500), nn.ReLU(),
            nn.Linear(500, 250), nn.ReLU(),
            nn.Linear(250, 1),
        ]
        # A single flat Sequential keeps the original child indices, and therefore
        # the same state_dict keys (layers.0.conv.weight, ...).
        self.layers = nn.Sequential(*features, *head)

    def forward(self, x):
        return self.layers(x)

In [None]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')

DATA_PATH = './data/train.h5'
SEED = 0
# Seed weight initialisation and DataLoader shuffling so runs are reproducible.
torch.manual_seed(SEED)

# Unshuffled K-fold: every validation fold is a contiguous block of indices.
kf = KFold(n_splits=5)

params = {
    'batch_size': 32,
    'shuffle': True
}
max_epochs = 100
# Keras-style patience: stop once `patience` consecutive epochs fail to improve
# on the best validation loss.
patience = 3
criterion = nn.MSELoss()


def train_one_epoch(net, loader, optimizer, criterion, epoch, device, log_every=100):
    """Run one optimisation pass over ``loader``.

    Prints the mean training loss of each window of ``log_every`` batches.
    """
    net.train()  # evaluate() leaves the net in eval mode; re-enable Dropout
    running_loss = 0.
    for i, (X, y) in enumerate(loader):
        X = X.to(device)
        # Flatten prediction and target to (B,) so MSELoss never broadcasts
        # (B,) against (B, 1), and a size-1 batch keeps its batch dimension.
        # The target is cast to float32 to match the (default float32) parameters.
        y = y.to(device).float().reshape(-1)
        optimizer.zero_grad()
        output = net(X).reshape(-1)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.


def evaluate(net, loader, criterion, device):
    """Return the per-sample mean of ``criterion`` over ``loader`` (NaN if empty).

    Runs in eval mode without gradients. Assumes ``criterion`` uses
    ``reduction='mean'`` (the MSELoss default), so each batch loss is
    re-weighted by its batch size before averaging.
    """
    net.eval()
    total_loss, n_samples = 0., 0
    with torch.no_grad():
        for X, y in loader:
            X = X.to(device)
            y = y.to(device).float().reshape(-1)
            output = net(X).reshape(-1)
            batch_size = y.numel()
            total_loss += criterion(output, y).item() * batch_size
            n_samples += batch_size
    return total_loss / n_samples if n_samples else float('nan')


fold_val_losses = []
for fold, (train_index, val_index) in enumerate(kf.split(np.arange(TRAIN_COUNT))):
    # train fold generator
    train_set = OpFlowDataset(DATA_PATH, train_index)
    train_loader = DataLoader(train_set, **params)

    # validation fold generator — evaluation order is irrelevant, so no shuffle
    valid_set = OpFlowDataset(DATA_PATH, val_index)
    valid_loader = DataLoader(valid_set, batch_size=params['batch_size'], shuffle=False)

    net = SpeedNet().to(device)
    # eps=1e-7 is the Keras Adam default (PyTorch's default is 1e-8).
    optimizer = optim.Adam(net.parameters(), lr=0.001, eps=1e-7)

    best_val_loss = np.inf
    val_count = 0

    try:
        for epoch in range(max_epochs):
            train_one_epoch(net, train_loader, optimizer, criterion, epoch, device)
            val_loss = evaluate(net, valid_loader, criterion, device)
            print('Valid. loss: ' + str(val_loss))
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                val_count = 0
            else:
                val_count += 1
                if val_count >= patience:
                    break
    finally:
        # OpFlowDataset keeps an open h5py handle in `.file`; release it per fold.
        train_set.file.close()
        valid_set.file.close()

    print('Fold %d best valid. loss: %.4f' % (fold + 1, best_val_loss))
    fold_val_losses.append(best_val_loss)

print('Mean best valid. loss over %d folds: %.4f'
      % (len(fold_val_losses), np.mean(fold_val_losses)))