In [1]:
### Imports ###
import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pickle
from HumanimalClassifier import HumanimalClassifier
import config
import importlib
import numpy as np
from UtilFunctions import process_landmarks

In [3]:
### Configs ###
# All tunable settings live in config.py. Reloading it here picks up edits to that
# file without restarting the kernel.
importlib.reload(config)
LandModelType = config.LandModelType
# Enum class of the configured member. Comparing against LandModelEnum.<Member>
# avoids member-to-member attribute access (LandModelType.Holistic), which is fragile.
LandModelEnum = type(LandModelType)


def landmark_filenames(model_type):
    """Return the (train, validation) landmark pickle paths for `model_type`.

    Holistic uses './Data/landmarks_holistic_{default,val}.pkl'. Every other type
    builds the stem from the landmark groups it uses: 'hand_' for HandAndPose /
    HandOnly, then 'pose_' for HandAndPose / PoseOnly, e.g. HandAndPose ->
    './Data/landmarks_hand_pose_{default,val}.pkl'.
    """
    enum_cls = type(model_type)
    if model_type == enum_cls.Holistic:
        stem = './Data/landmarks_holistic_'
    else:
        stem = './Data/landmarks_'
        if model_type in (enum_cls.HandAndPose, enum_cls.HandOnly):
            stem += 'hand_'
        if model_type in (enum_cls.HandAndPose, enum_cls.PoseOnly):
            stem += 'pose_'
    return stem + 'default.pkl', stem + 'val.pkl'


def load_pickle(path):
    """Unpickle and return the object stored at `path`.

    NOTE: pickle.load can execute arbitrary code; only load files produced by
    this project's own landmark-extraction step.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


## Load landmarks from file ##
landmarkFilename, landmarkFilenameVal = landmark_filenames(LandModelType)
landmarks_dataset = load_pickle(landmarkFilename)
landmarks_dataset_val = load_pickle(landmarkFilenameVal)

## Labels and model dimensions ##
Labels = config.Labels
label_map = config.GetLabelMap()
inv_label_map = config.GetInversLableMap()

InFeatures = config.GetInFetures()
OutClasses = len(Labels)

## Print relevant config ##
print(LandModelType)
print(f"Out Classes: {OutClasses}")
print(f"In Features: {InFeatures}")
print(f"Hidden layer: {config.Hiddenlayer}")
print(f"Batch Size: {config.batch_size}")
print(f"Learning Rate: {config.LearningRate}")
print(f"Num epochs: {config.num_epochs}")

LandmarkModelEnum.HandAndPose
Out Classes: 33
In Features: 225
Hidden layer: 450
Batch Size: 16
Learning Rate: 0.00025
Num epochs: 40


In [4]:
## Helpers: data preparation, training, evaluation, output path ##
def build_loader(dataset, shuffle):
    """Turn a list of (landmarks, label) pairs into tensors and a DataLoader.

    item[0] of every pair is passed to torch.stack, so it must be a tensor.
    item[1] is mapped to a class index through `label_map`.

    Returns:
        (landmarks_tensor, labels_tensor, tensor_dataset, data_loader)
    """
    landmarks = torch.stack([item[0] for item in dataset])
    labels = torch.tensor([label_map[item[1]] for item in dataset])
    tensor_ds = TensorDataset(landmarks, labels)
    loader = DataLoader(tensor_ds, batch_size=config.batch_size, shuffle=shuffle)
    return landmarks, labels, tensor_ds, loader


def train_one_epoch(model, loader, criterion, optimizer, num_classes):
    """Run one optimisation pass over `loader`.

    Targets are one-hot encoded (float) to match `criterion`. Accuracy is taken
    from the same forward pass used for the gradient step.

    Returns:
        (mean loss per batch, fraction of samples predicted correctly);
        NaN for a metric whose denominator is zero.
    """
    model.train()
    loss_total = 0.0
    correct_total = 0
    for inputs, labels in loader:
        inputs = inputs.float()
        labels = labels.long()  # F.one_hot requires int64 class indices
        targets = F.one_hot(labels, num_classes=num_classes).float()

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_total += loss.item()
        correct_total += (outputs.argmax(dim=1) == labels).sum().item()

    n_batches = len(loader)
    n_samples = len(loader.dataset)
    avg_loss = loss_total / n_batches if n_batches else float('nan')
    avg_acc = correct_total / n_samples if n_samples else float('nan')
    return avg_loss, avg_acc


def evaluate(model, loader):
    """Return the accuracy of `model` on `loader` (NaN if the loader is empty).

    Runs in eval mode without gradients and restores the model's previous
    train/eval mode afterwards.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            predicted = model(inputs.float()).argmax(dim=1)
            correct += (predicted == labels.long()).sum().item()
            total += labels.size(0)
    model.train(was_training)
    return correct / total if total else float('nan')


def model_filename(model_type):
    """Return the .pth path for the classifier trained on `model_type` landmarks.

    Holistic -> './Data/Model_classifier_holistic.pth'; otherwise '_hand' and/or
    '_pose' are appended, e.g. HandAndPose -> './Data/Model_classifier_hand_pose.pth'.
    """
    enum_cls = type(model_type)  # avoid member-to-member enum attribute access
    if model_type == enum_cls.Holistic:
        return './Data/Model_classifier_holistic.pth'
    name = './Data/Model_classifier'
    if model_type in (enum_cls.HandAndPose, enum_cls.HandOnly):
        name += '_hand'
    if model_type in (enum_cls.HandAndPose, enum_cls.PoseOnly):
        name += '_pose'
    return name + '.pth'


## Train / validation data ##
landmarks_tensor, labels_tensor, tensor_dataset, data_loader = build_loader(
    landmarks_dataset, shuffle=True)
# Evaluation order does not affect accuracy, so keep validation deterministic.
landmarks_tensor_val, labels_tensor_val, tensor_dataset_val, data_loader_val = build_loader(
    landmarks_dataset_val, shuffle=False)

## Model, loss and optimiser ##
model = HumanimalClassifier(in_feat=InFeatures, hiddenlayer=config.Hiddenlayer, num_classes=OutClasses)
# NOTE(review): MSE against one-hot targets. For multi-class classification
# nn.CrossEntropyLoss on raw logits is the usual choice -- check whether
# HumanimalClassifier's forward ends in a softmax before switching.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=config.LearningRate)

## Training loop: one training pass + one validation pass per epoch ##
for epoch in range(config.num_epochs):
    avg_loss, avg_acc = train_one_epoch(model, data_loader, criterion, optimizer, OutClasses)
    print(f'Epoch [{epoch + 1}/{config.num_epochs}], Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.2f}')

    valid_accuracy = evaluate(model, data_loader_val)
    print(f'Epoch [{epoch + 1}/{config.num_epochs}], Validation Accuracy: {valid_accuracy:.2f}')

print('Finished Training')

## Save ##
# InFeatures is intentionally left untouched here. It already holds
# config.GetInFetures(), the value the model was built with. Adding the per-group
# feature counts to it would make a re-run of this cell build a mis-sized model.
ModelFilename = model_filename(LandModelType)
for name, param in model.named_parameters():
    print(f'{name}: {param.shape}')
torch.save(model.state_dict(), ModelFilename)



Epoch [1/40], Loss: 0.0168, Accuracy: 0.68
Epoch [1/40], Validation Accuracy: 0.81
Epoch [2/40], Loss: 0.0089, Accuracy: 0.88
Epoch [2/40], Validation Accuracy: 0.84
Epoch [3/40], Loss: 0.0063, Accuracy: 0.93
Epoch [3/40], Validation Accuracy: 0.86
Epoch [4/40], Loss: 0.0051, Accuracy: 0.94
Epoch [4/40], Validation Accuracy: 0.89
Epoch [5/40], Loss: 0.0044, Accuracy: 0.95
Epoch [5/40], Validation Accuracy: 0.90
Epoch [6/40], Loss: 0.0038, Accuracy: 0.96
Epoch [6/40], Validation Accuracy: 0.88
Epoch [7/40], Loss: 0.0035, Accuracy: 0.97
Epoch [7/40], Validation Accuracy: 0.89
Epoch [8/40], Loss: 0.0032, Accuracy: 0.97
Epoch [8/40], Validation Accuracy: 0.90
Epoch [9/40], Loss: 0.0029, Accuracy: 0.97
Epoch [9/40], Validation Accuracy: 0.89
Epoch [10/40], Loss: 0.0027, Accuracy: 0.97
Epoch [10/40], Validation Accuracy: 0.91
Epoch [11/40], Loss: 0.0026, Accuracy: 0.98
Epoch [11/40], Validation Accuracy: 0.90
Epoch [12/40], Loss: 0.0024, Accuracy: 0.98
Epoch [12/40], Validation Accuracy: 0.9