In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from utils.tensorboard_helper import write_loss, write_model_pred_figures
from datasets.media_gesture import MediaGestureDataset
from datasets.transformations import BasicTransform, NormalizeMaxSpan, WristAsOrigin
from nets.hand_ges_rec_net import HandGesRecNet
import copy

## Init and start training
to monitor training we use tensorboard, command line: `tensorboard --logdir=runs`

In [3]:
batch_size = 32
lr = 1e-3
epochs = 20_000

# log setting
num_fig_per_epoch = 1 #  number of figure to log for each epoch

In [4]:
dataset = MediaGestureDataset(
    transform=nn.Sequential(
        BasicTransform(),
        NormalizeMaxSpan(1),
        WristAsOrigin()
    )
)
train_dataset, valid_dataset = random_split(dataset, [0.8, 0.2])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
writer = SummaryWriter()
model = HandGesRecNet(dataset.feature_cnt, dataset.class_cnt, dataset.label_idx_2_name)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [5]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size, epoch_loss = len(dataloader.dataset), 0
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item() * len(X)
    epoch_loss /= size
    return epoch_loss

def val_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item() * len(X)
            correct += (pred.argmax(1) == y.argmax(1)).type(torch.float).sum().item()

    test_loss /= size
    accuracy = correct / size
    return test_loss, accuracy

In [6]:
#
best_model = copy.deepcopy(model)
best_test_loss = float("inf")
# write model
writer.add_graph(model, input_to_model=next(iter(train_dataloader))[0])
# training
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loss = train_loop(train_dataloader, model, loss_fn, optimizer)
    print(f"train loss: {train_loss:>6f}  [{t+1:>5d}/{epochs:>5d}]")
    test_loss, accuracy = val_loop(valid_dataloader, model, loss_fn)
    print(f"Test Accuracy: {accuracy*100:>0.1f}% Test loss: {test_loss:>7f}\n")
    # save best model
    if test_loss <= best_test_loss and t > int(epochs * 0.9):
        best_model, best_test_loss = copy.deepcopy(model), test_loss
    # write to board
    write_loss(writer, train_loss, test_loss, t+1)
    if t % 500 == 0:
        write_model_pred_figures(writer, model, dataset, 4, t + 1, figsize=6)

print("Done!")

Epoch 1
-------------------------------
train loss: 0.701651  [    1/20000]
Test Accuracy: 20.0% Test loss: 0.715459

Epoch 2
-------------------------------
train loss: 0.701400  [    2/20000]
Test Accuracy: 20.0% Test loss: 0.715232

Epoch 3
-------------------------------
train loss: 0.701161  [    3/20000]
Test Accuracy: 20.0% Test loss: 0.714985

Epoch 4
-------------------------------
train loss: 0.700931  [    4/20000]
Test Accuracy: 23.3% Test loss: 0.714749

Epoch 5
-------------------------------
train loss: 0.700681  [    5/20000]
Test Accuracy: 23.3% Test loss: 0.714490

Epoch 6
-------------------------------
train loss: 0.700499  [    6/20000]
Test Accuracy: 26.7% Test loss: 0.714265

Epoch 7
-------------------------------
train loss: 0.700245  [    7/20000]
Test Accuracy: 26.7% Test loss: 0.713996

Epoch 8
-------------------------------
train loss: 0.700006  [    8/20000]
Test Accuracy: 26.7% Test loss: 0.713745

Epoch 9
-------------------------------
train loss: 0.69

## save model

In [8]:
save_path = "./trained_nets/hand_ges_rec_net"
best_model.label_idx_to_name = dataset.label_idx_2_name
best_model.transform = dataset.transform
torch.save(best_model, save_path)