# Facial Keypoints - Model Training

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

%load_ext autoreload
%autoreload 2

### Prepare the Dataset

In [None]:
# Split the training images into Training (80%) and Validation (20%) sets.
# Use the images in the test directory for testing on unseen data.

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.utils.data.sampler import SubsetRandomSampler

from facepoint.data_load import FacialKeypointsDataset
from facepoint.data_load import Rescale, RandomCrop, Normalize, ToTensor

BATCH_SIZE = 32
TRAIN_IMG_DIR = "data/training/"
TRAIN_IMG_KEYPTS = "data/training_frames_keypoints.csv"
TEST_IMG_DIR = "data/test/"
TEST_IMG_KEYPTS = "data/test_frames_keypoints.csv"
# Fraction of the training dataset used for training; the remainder is used
# for validation.
TRAIN_FRACTION = 0.8


# Define a Data Transform to apply to images in the dataset.
# NOTE(review): the validation split and the test set reuse this transform, so
# RandomCrop also makes their evaluation crops random. Consider a deterministic
# evaluation transform if reproducible metrics are needed.
data_transform = transforms.Compose([Rescale(250), RandomCrop(224), Normalize(), ToTensor()])


# Create training and validation data loaders. Both draw from the same dataset
# object; they are separated by partitioning its indices.
transformed_dataset = FacialKeypointsDataset(csv_file=TRAIN_IMG_KEYPTS,
                                             root_dir=TRAIN_IMG_DIR,
                                             transform=data_transform)
num_samples = len(transformed_dataset)
split_index = round(TRAIN_FRACTION * num_samples)
# Contiguous, non-overlapping index ranges that together cover every sample
# (no index is dropped at the boundary). Each sampler shuffles only within
# its own subset.
train_sampler = SubsetRandomSampler(range(split_index))
validation_sampler = SubsetRandomSampler(range(split_index, num_samples))

train_loader = DataLoader(transformed_dataset,
                          batch_size=BATCH_SIZE,
                          sampler=train_sampler)
validation_loader = DataLoader(transformed_dataset,
                               batch_size=BATCH_SIZE,
                               sampler=validation_sampler)

# Create the test dataset loader.
test_dataset = FacialKeypointsDataset(csv_file=TEST_IMG_KEYPTS,
                                      root_dir=TEST_IMG_DIR,
                                      transform=data_transform)
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         shuffle=True)

### Train the Model

In [None]:
# Get the Model

from facepoint.models import Net

# Instantiate the model imported from facepoint.models; its weights are
# (re)initialised later by the trainer's init_weights() call.
net = Net()
# Display the module structure.
print(net)

In [None]:
# Train the Model

import torch.optim as optim
import torch.nn as nn
from facepoint.model_trainer import ModelTrainer

# Training hyperparameters, named for easy tuning (matches the BATCH_SIZE style
# of the data cell).
LEARNING_RATE = 0.001
# Passed to ModelTrainer.train(); presumably the number of epochs — confirm in
# facepoint.model_trainer.
NUM_EPOCHS = 300

# Mean-squared-error loss between the network outputs and the target keypoints.
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

trainer = ModelTrainer(net, optimizer, criterion, train_loader, validation_loader)
# NOTE(review): the optimizer already holds references to net's parameters at
# this point. That is only correct if init_weights() re-initialises them in
# place rather than replacing the Parameter objects — confirm in ModelTrainer.
trainer.init_weights()
trainer.train(NUM_EPOCHS)

In [None]:
# Test on some sample data
from facepoint.utils import visualize_output, net_sample_output

# Run the network on samples drawn from the test loader, then plot its
# predictions next to the ground-truth keypoints.
# NOTE(review): the final visualize_output argument presumably sets how many
# samples are plotted — confirm in facepoint.utils.
test_images, test_outputs, gt_pts = net_sample_output(test_loader, net)
num_to_show = 4
visualize_output(test_images, test_outputs, gt_pts, num_to_show)

In [None]:
# Save model
# Persist the trained network's weights via the ModelTrainer helper
# (destination path and format are defined in facepoint.model_trainer).
trainer.save_weights()